1use std::io::Error as IoError;
2use std::path::Path;
3use std::fmt;
4use std::fmt::Display;
5
6use log::trace;
7
8use kf_protocol::bytes::Buf;
9use kf_protocol::bytes::BufMut;
10use kf_protocol::Decoder;
11use kf_protocol::Encoder;
12use kf_protocol::Version;
13
14use crate::Request;
15use crate::RequestHeader;
16use crate::ResponseMessage;
17
18#[derive(Debug)]
20pub struct RequestMessage<R> {
21 pub header: RequestHeader,
22 pub request: R,
23}
24
25impl<R> RequestMessage<R> {
26 pub fn get_mut_header(&mut self) -> &mut RequestHeader {
27 &mut self.header
28 }
29}
30
31
32impl <R>fmt::Display for RequestMessage<R> where R: Display{
33 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
34 write!(f,"{} {}",self.header,self.request)
35 }
36}
37
38
39impl<R> Default for RequestMessage<R>
40where
41 R: Request + Default,
42{
43 fn default() -> Self {
44 let mut header = RequestHeader::default();
45 header.set_api_version(R::DEFAULT_API_VERSION);
46
47 Self {
48 header,
49 request: R::default(),
50 }
51 }
52}
53
54impl<R> RequestMessage<R>
55where
56 R: Request,
57{
58 pub fn new(header: RequestHeader, request: R) -> Self {
61 Self { header, request }
62 }
63
64 pub fn new_request(request: R) -> Self {
66 let mut header = RequestHeader::new(R::API_KEY);
67 header.set_api_version(R::DEFAULT_API_VERSION);
68
69 Self { header, request }
70 }
71
72 pub fn get_header_request(self) -> (RequestHeader, R) {
73 (self.header, self.request)
74 }
75
76 pub fn request(&self) -> &R {
77 &self.request
78 }
79
80 pub fn new_response(&self, response: R::Response) -> ResponseMessage<R::Response> {
81 Self::response_with_header(&self.header, response)
82 }
83
84 pub fn response_with_header<H>(header: H, response: R::Response) -> ResponseMessage<R::Response>
85 where
86 H: Into<i32>,
87 {
88 ResponseMessage::new(header.into(), response)
89 }
90
91 pub fn decode_response<T>(
92 &self,
93 src: &mut T,
94 version: Version,
95 ) -> Result<ResponseMessage<R::Response>, IoError>
96 where
97 T: Buf,
98 {
99 ResponseMessage::decode_from(src, version)
100 }
101
102 pub fn decode_response_from_file<H: AsRef<Path>>(
103 &self,
104 file_name: H,
105 version: Version,
106 ) -> Result<ResponseMessage<R::Response>, IoError> {
107 ResponseMessage::decode_from_file(file_name, version)
108 }
109
110 pub fn set_client_id<T>(mut self, client_id: T) -> Self
112 where
113 T: Into<String>,
114 {
115 self.header.set_client_id(client_id);
116 self
117 }
118}
119
120impl<R> Decoder for RequestMessage<R>
121where
122 R: Request,
123{
124 fn decode<T>(&mut self, src: &mut T, version: Version) -> Result<(), IoError>
125 where
126 T: Buf,
127 {
128 self.header.decode(src, version)?;
129 self.request.decode(src, self.header.api_version())?;
130 Ok(())
131 }
132}
133
134impl<R> Encoder for RequestMessage<R>
135where
136 R: Request,
137{
138 fn write_size(&self, version: Version) -> usize {
139 self.header.write_size(version) + self.request.write_size(self.header.api_version())
140 }
141
142 fn encode<T>(&self, out: &mut T, version: Version) -> Result<(), IoError>
143 where
144 T: BufMut,
145 {
146 let len = self.write_size(version) as i32;
147 trace!("encoding kf request: {} version: {}, len: {}", std::any::type_name::<R>(),version,len);
148 len.encode(out, version)?;
149
150 trace!("encoding request header: {:#?}", &self.header);
151 self.header.encode(out, version)?;
152
153 trace!("encoding request: {:#?}", &self.request);
154 self.request.encode(out, self.header.api_version())?;
155 Ok(())
156 }
157}
158
#[cfg(test)]
mod test {

    use std::io::Cursor;
    use std::io::Error as IoError;
    use std::convert::TryInto;
    use kf_protocol::bytes::Buf;
    use kf_protocol::bytes::BufMut;
    use kf_protocol::Decoder;
    use kf_protocol::Encoder;
    use kf_protocol::Version;
    use kf_protocol_derive::Encode;
    use kf_protocol_derive::Decode;

    use super::RequestHeader;
    use super::RequestMessage;
    use crate::KfRequestMessage;

    use crate::Request;
    use crate::AllKfApiKey;

    /// Field-less request body used as a fixture.
    #[derive(Decode, Encode, Debug, Default)]
    pub struct ApiVersionRequest {}

    impl Request for ApiVersionRequest {
        const API_KEY: u16 = AllKfApiKey::ApiVersion as u16;

        type Response = ApiVersionResponse;
    }

    /// Response type paired with [`ApiVersionRequest`].
    #[derive(Encode, Decode, Default, Debug)]
    pub struct ApiVersionResponse {
        pub error_code: i16,
        pub api_versions: Vec<ApiVersion>,
        pub throttle_time_ms: i32,
    }

    /// One entry of [`ApiVersionResponse::api_versions`].
    #[derive(Encode, Decode, Default, Debug)]
    pub struct ApiVersion {
        pub api_key: i16,
        pub min_version: i16,
        pub max_version: i16,
    }

    /// API-key enum local to the tests; its only variant has the value 18.
    #[derive(PartialEq, Debug, Encode, Decode, Clone, Copy)]
    #[repr(u16)]
    pub enum TestApiEnum {
        ApiVersion = 18,
    }

    impl Default for TestApiEnum {
        fn default() -> TestApiEnum {
            TestApiEnum::ApiVersion
        }
    }

    /// Request enum local to the tests, wrapping the single supported request.
    pub enum TestApiRequest {
        ApiVersionRequest(RequestMessage<ApiVersionRequest>),
    }

    impl Default for TestApiRequest {
        fn default() -> TestApiRequest {
            let message = RequestMessage::<ApiVersionRequest>::default();
            TestApiRequest::ApiVersionRequest(message)
        }
    }

    impl KfRequestMessage for TestApiRequest {
        type ApiKey = TestApiEnum;

        /// Picks the request type from the header's API key and decodes the
        /// body with the header's API version.
        fn decode_with_header<T>(src: &mut T, header: RequestHeader) -> Result<Self, IoError>
        where
            Self: Default + Sized,
            Self::ApiKey: Sized,
            T: Buf,
        {
            let api_key: TestApiEnum = header.api_key().try_into()?;
            match api_key {
                TestApiEnum::ApiVersion => {
                    let request = ApiVersionRequest::decode_from(src, header.api_version())?;
                    let message = RequestMessage::new(header, request);
                    Ok(TestApiRequest::ApiVersionRequest(message))
                }
            }
        }
    }

    impl Encoder for TestApiRequest {
        /// Forwards to the wrapped message.
        fn write_size(&self, version: Version) -> usize {
            match self {
                TestApiRequest::ApiVersionRequest(message) => message.write_size(version),
            }
        }

        /// Forwards to the wrapped message.
        fn encode<T>(&self, out: &mut T, version: Version) -> Result<(), IoError>
        where
            T: BufMut,
        {
            match self {
                TestApiRequest::ApiVersionRequest(message) => message.encode(out, version),
            }
        }
    }

    /// Decodes a fixed byte sequence and checks every header field.
    #[test]
    fn test_decode_header() -> Result<(), IoError> {
        // api key 0x0012, api version 0x0001, correlation id 0x00000001,
        // then a 0x000a-length string spelling "consumer-1".
        let data = [
            0x00, 0x12, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x0a, 0x63, 0x6f, 0x6e, 0x73,
            0x75, 0x6d, 0x65, 0x72, 0x2d, 0x31,
        ];

        let mut cursor = Cursor::new(&data);
        let header: RequestHeader = RequestHeader::decode_from(&mut cursor, 0)?;

        assert_eq!(header.api_key(), TestApiEnum::ApiVersion as u16);
        assert_eq!(header.api_version(), 1);
        assert_eq!(header.correlation_id(), 1);
        assert_eq!(header.client_id(), "consumer-1");

        Ok(())
    }

    /// Encodes a header and compares the output with the expected bytes.
    #[test]
    fn test_encode_header() {
        let expected_result = [
            0x00, 0x12, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x0a, 0x63, 0x6f, 0x6e, 0x73,
            0x75, 0x6d, 0x65, 0x72, 0x2d, 0x31,
        ];

        let client = String::from("consumer-1");
        let req_header = RequestHeader::new_with_client(TestApiEnum::ApiVersion as u16, client);

        let mut result = Vec::new();
        let req_result = req_header.encode(&mut result, 0);

        assert!(req_result.is_ok());
        assert_eq!(result, expected_result);
    }

    /// Round-trips a full message: encode it, read the length prefix, decode
    /// the remainder and check that the correlation id survived.
    #[test]
    fn test_encode_message() {
        let mut message = RequestMessage::new_request(ApiVersionRequest {});
        message
            .get_mut_header()
            .set_client_id("consumer-1".to_owned())
            .set_correlation_id(5);

        let mut out = Vec::new();
        message.encode(&mut out, 0).expect("encode work");
        let mut encode_bytes = Cursor::new(&out);

        // Consume the length prefix so the cursor points at the header.
        let mut frame_len: i32 = 0;
        frame_len
            .decode(&mut encode_bytes, 0)
            .expect("cant decode len");

        let decoded = RequestMessage::<ApiVersionRequest>::decode_from(&mut encode_bytes, 0)
            .unwrap_or_else(|err| panic!("error: {}", err));
        assert_eq!(decoded.header.correlation_id(), 5);
    }

}