kf_protocol_api/
request.rs

1use std::io::Error as IoError;
2use std::path::Path;
3use std::fmt;
4use std::fmt::Display;
5
6use log::trace;
7
8use kf_protocol::bytes::Buf;
9use kf_protocol::bytes::BufMut;
10use kf_protocol::Decoder;
11use kf_protocol::Encoder;
12use kf_protocol::Version;
13
14use crate::Request;
15use crate::RequestHeader;
16use crate::ResponseMessage;
17
18/// Start of API request
19#[derive(Debug)]
20pub struct RequestMessage<R> {
21    pub header: RequestHeader,
22    pub request: R,
23}
24
25impl<R> RequestMessage<R> {
26    pub fn get_mut_header(&mut self) -> &mut RequestHeader {
27        &mut self.header
28    }
29}
30
31
32impl <R>fmt::Display for RequestMessage<R> where R: Display{
33    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
34        write!(f,"{} {}",self.header,self.request)      
35    }
36}
37
38
39impl<R> Default for RequestMessage<R>
40where
41    R: Request + Default,
42{
43    fn default() -> Self {
44        let mut header = RequestHeader::default();
45        header.set_api_version(R::DEFAULT_API_VERSION);
46
47        Self {
48            header,
49            request: R::default(),
50        }
51    }
52}
53
54impl<R> RequestMessage<R>
55where
56    R: Request,
57{
58    /// create with header, this assume header is constructed from higher request
59    /// no api key check is performed since it is already done
60    pub fn new(header: RequestHeader, request: R) -> Self {
61        Self { header, request }
62    }
63
64    /// create from request, header is implicilty created from key in the request
65    pub fn new_request(request: R) -> Self {
66        let mut header = RequestHeader::new(R::API_KEY);
67        header.set_api_version(R::DEFAULT_API_VERSION);
68
69        Self { header, request }
70    }
71
72    pub fn get_header_request(self) -> (RequestHeader, R) {
73        (self.header, self.request)
74    }
75
76    pub fn request(&self) -> &R {
77        &self.request
78    }
79
80    pub fn new_response(&self, response: R::Response) -> ResponseMessage<R::Response> {
81        Self::response_with_header(&self.header, response)
82    }
83
84    pub fn response_with_header<H>(header: H, response: R::Response) -> ResponseMessage<R::Response>
85    where
86        H: Into<i32>,
87    {
88        ResponseMessage::new(header.into(), response)
89    }
90
91    pub fn decode_response<T>(
92        &self,
93        src: &mut T,
94        version: Version,
95    ) -> Result<ResponseMessage<R::Response>, IoError>
96    where
97        T: Buf,
98    {
99        ResponseMessage::decode_from(src, version)
100    }
101
102    pub fn decode_response_from_file<H: AsRef<Path>>(
103        &self,
104        file_name: H,
105        version: Version,
106    ) -> Result<ResponseMessage<R::Response>, IoError> {
107        ResponseMessage::decode_from_file(file_name, version)
108    }
109
110    /// helper function to set client id
111    pub fn set_client_id<T>(mut self, client_id: T) -> Self
112    where
113        T: Into<String>,
114    {
115        self.header.set_client_id(client_id);
116        self
117    }
118}
119
120impl<R> Decoder for RequestMessage<R>
121where
122    R: Request,
123{
124    fn decode<T>(&mut self, src: &mut T, version: Version) -> Result<(), IoError>
125    where
126        T: Buf,
127    {
128        self.header.decode(src, version)?;
129        self.request.decode(src, self.header.api_version())?;
130        Ok(())
131    }
132}
133
134impl<R> Encoder for RequestMessage<R>
135where
136    R: Request,
137{
138    fn write_size(&self, version: Version) -> usize {
139        self.header.write_size(version) + self.request.write_size(self.header.api_version())
140    }
141
142    fn encode<T>(&self, out: &mut T, version: Version) -> Result<(), IoError>
143    where
144        T: BufMut,
145    {
146        let len = self.write_size(version) as i32;
147        trace!("encoding kf request: {} version: {}, len: {}", std::any::type_name::<R>(),version,len);
148        len.encode(out, version)?;
149
150        trace!("encoding request header: {:#?}", &self.header);
151        self.header.encode(out, version)?;
152
153        trace!("encoding request: {:#?}", &self.request);
154        self.request.encode(out, self.header.api_version())?;
155        Ok(())
156    }
157}
158
#[cfg(test)]
mod test {

    use std::convert::TryInto;
    use std::io::Cursor;
    use std::io::Error as IoError;

    use kf_protocol::bytes::Buf;
    use kf_protocol::bytes::BufMut;
    use kf_protocol::Decoder;
    use kf_protocol::Encoder;
    use kf_protocol::Version;
    use kf_protocol_derive::Decode;
    use kf_protocol_derive::Encode;

    use super::RequestHeader;
    use super::RequestMessage;
    use crate::AllKfApiKey;
    use crate::KfRequestMessage;
    use crate::Request;

    /// Minimal empty request used to exercise `RequestMessage` framing.
    #[derive(Decode, Encode, Debug, Default)]
    pub struct ApiVersionRequest {}

    impl Request for ApiVersionRequest {
        const API_KEY: u16 = AllKfApiKey::ApiVersion as u16;

        type Response = ApiVersionResponse;
    }

    #[derive(Encode, Decode, Default, Debug)]
    pub struct ApiVersionResponse {
        pub error_code: i16,
        pub api_versions: Vec<ApiVersion>,
        pub throttle_time_ms: i32,
    }

    #[derive(Encode, Decode, Default, Debug)]
    pub struct ApiVersion {
        pub api_key: i16,
        pub min_version: i16,
        pub max_version: i16,
    }

    /// Test-local api key enum; `ApiVersion` uses wire value 18.
    #[derive(PartialEq, Debug, Encode, Decode, Clone, Copy)]
    #[repr(u16)]
    pub enum TestApiEnum {
        ApiVersion = 18,
    }

    impl Default for TestApiEnum {
        fn default() -> TestApiEnum {
            TestApiEnum::ApiVersion
        }
    }

    #[test]
    fn test_decode_header() -> Result<(), IoError> {
        // API versions request
        // API key: API Versions (18)
        // API version: 1
        // correlation id: 1
        // string length: 10
        // client id: consumer-1
        let data = [
            0x00, 0x12, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x0a, 0x63, 0x6f, 0x6e, 0x73,
            0x75, 0x6d, 0x65, 0x72, 0x2d, 0x31,
        ];

        let header: RequestHeader = RequestHeader::decode_from(&mut Cursor::new(&data), 0)?;

        assert_eq!(header.api_key(), TestApiEnum::ApiVersion as u16);
        assert_eq!(header.api_version(), 1);
        assert_eq!(header.correlation_id(), 1);
        assert_eq!(header.client_id(), "consumer-1");

        Ok(())
    }

    #[test]
    fn test_encode_header() {
        let req_header = RequestHeader::new_with_client(
            TestApiEnum::ApiVersion as u16,
            String::from("consumer-1"),
        );
        // Same wire bytes that `test_decode_header` decodes.
        let expected_result = [
            0x00, 0x12, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x0a, 0x63, 0x6f, 0x6e, 0x73,
            0x75, 0x6d, 0x65, 0x72, 0x2d, 0x31,
        ];

        let mut result = vec![];
        // `expect` surfaces the actual encode error instead of a bare `is_ok()` failure.
        req_header
            .encode(&mut result, 0)
            .expect("header encode should succeed");

        assert_eq!(result, expected_result);
    }

    pub enum TestApiRequest {
        ApiVersionRequest(RequestMessage<ApiVersionRequest>),
    }

    impl Default for TestApiRequest {
        fn default() -> TestApiRequest {
            TestApiRequest::ApiVersionRequest(RequestMessage::<ApiVersionRequest>::default())
        }
    }

    impl KfRequestMessage for TestApiRequest {
        type ApiKey = TestApiEnum;

        fn decode_with_header<T>(src: &mut T, header: RequestHeader) -> Result<Self, IoError>
        where
            Self: Default + Sized,
            Self::ApiKey: Sized,
            T: Buf,
        {
            match header.api_key().try_into()? {
                TestApiEnum::ApiVersion => {
                    let request = ApiVersionRequest::decode_from(src, header.api_version())?;
                    Ok(TestApiRequest::ApiVersionRequest(RequestMessage::new(
                        header, request,
                    )))
                }
            }
        }
    }

    impl Encoder for TestApiRequest {
        fn write_size(&self, version: Version) -> usize {
            match self {
                TestApiRequest::ApiVersionRequest(request) => request.write_size(version),
            }
        }

        fn encode<T>(&self, src: &mut T, version: Version) -> Result<(), IoError>
        where
            T: BufMut,
        {
            match self {
                TestApiRequest::ApiVersionRequest(request) => {
                    request.encode(src, version)?;
                }
            }
            Ok(())
        }
    }

    #[test]
    fn test_encode_message() {
        let mut message = RequestMessage::new_request(ApiVersionRequest {});
        message
            .get_mut_header()
            .set_client_id("consumer-1".to_owned())
            .set_correlation_id(5);

        let mut out = vec![];
        message.encode(&mut out, 0).expect("encode work");
        let mut encode_bytes = Cursor::new(&out);

        // The frame starts with an i32 length prefix that does not count itself.
        let mut len: i32 = 0;
        len.decode(&mut encode_bytes, 0).expect("cant decode len");
        assert_eq!(len as usize, out.len() - 4);

        // decode back
        let msg = RequestMessage::<ApiVersionRequest>::decode_from(&mut encode_bytes, 0)
            .unwrap_or_else(|err| panic!("error: {}", err));

        assert_eq!(msg.header.api_key(), ApiVersionRequest::API_KEY);
        assert_eq!(msg.header.correlation_id(), 5);
        assert_eq!(msg.header.client_id(), "consumer-1");
    }
}