fluvio_protocol_api/
request.rs

1use std::io::Error as IoError;
2use std::path::Path;
3use std::fmt;
4use std::fmt::Display;
5
6use tracing::trace;
7
8use crate::core::bytes::Buf;
9use crate::core::bytes::BufMut;
10use crate::core::Decoder;
11use crate::core::Encoder;
12use crate::core::Version;
13
14use crate::api::Request;
15use crate::api::RequestHeader;
16use crate::response::ResponseMessage;
17
18/// Start of API request
19#[derive(Debug)]
20pub struct RequestMessage<R> {
21    pub header: RequestHeader,
22    pub request: R,
23}
24
25impl<R> RequestMessage<R> {
26    #[allow(unused)]
27    pub fn get_mut_header(&mut self) -> &mut RequestHeader {
28        &mut self.header
29    }
30}
31
32impl<R> fmt::Display for RequestMessage<R>
33where
34    R: Display,
35{
36    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
37        write!(f, "{} {}", self.header, self.request)
38    }
39}
40
41impl<R> Default for RequestMessage<R>
42where
43    R: Request + Default,
44{
45    fn default() -> Self {
46        let mut header = RequestHeader::default();
47        header.set_api_version(R::DEFAULT_API_VERSION);
48
49        Self {
50            header,
51            request: R::default(),
52        }
53    }
54}
55
56impl<R> RequestMessage<R>
57where
58    R: Request,
59{
60    /// create with header, this assume header is constructed from higher request
61    /// no api key check is performed since it is already done
62    #[allow(unused)]
63    pub fn new(header: RequestHeader, request: R) -> Self {
64        Self { header, request }
65    }
66
67    /// create from request, header is implicilty created from key in the request
68    #[allow(unused)]
69    pub fn new_request(request: R) -> Self {
70        let mut header = RequestHeader::new(R::API_KEY);
71        header.set_api_version(R::DEFAULT_API_VERSION);
72
73        Self { header, request }
74    }
75
76    #[allow(unused)]
77    pub fn get_header_request(self) -> (RequestHeader, R) {
78        (self.header, self.request)
79    }
80
81    #[allow(unused)]
82    #[allow(unused)]
83    pub fn request(&self) -> &R {
84        &self.request
85    }
86
87    #[allow(unused)]
88    pub fn new_response(&self, response: R::Response) -> ResponseMessage<R::Response> {
89        Self::response_with_header(&self.header, response)
90    }
91
92    pub fn response_with_header<H>(header: H, response: R::Response) -> ResponseMessage<R::Response>
93    where
94        H: Into<i32>,
95    {
96        ResponseMessage::new(header.into(), response)
97    }
98
99    #[allow(unused)]
100    pub fn decode_response<T>(
101        &self,
102        src: &mut T,
103        version: Version,
104    ) -> Result<ResponseMessage<R::Response>, IoError>
105    where
106        T: Buf,
107    {
108        ResponseMessage::decode_from(src, version)
109    }
110
111    #[allow(unused)]
112    pub fn decode_response_from_file<H: AsRef<Path>>(
113        &self,
114        file_name: H,
115        version: Version,
116    ) -> Result<ResponseMessage<R::Response>, IoError> {
117        ResponseMessage::decode_from_file(file_name, version)
118    }
119
120    /// helper function to set client id
121    #[allow(unused)]
122    pub fn set_client_id<T>(mut self, client_id: T) -> Self
123    where
124        T: Into<String>,
125    {
126        self.header.set_client_id(client_id);
127        self
128    }
129}
130
131impl<R> Decoder for RequestMessage<R>
132where
133    R: Request,
134{
135    fn decode<T>(&mut self, src: &mut T, version: Version) -> Result<(), IoError>
136    where
137        T: Buf,
138    {
139        self.header.decode(src, version)?;
140        self.request.decode(src, self.header.api_version())?;
141        Ok(())
142    }
143}
144
145impl<R> Encoder for RequestMessage<R>
146where
147    R: Request,
148{
149    fn write_size(&self, version: Version) -> usize {
150        self.header.write_size(version) + self.request.write_size(self.header.api_version())
151    }
152
153    fn encode<T>(&self, out: &mut T, version: Version) -> Result<(), IoError>
154    where
155        T: BufMut,
156    {
157        let len = self.write_size(version);
158        trace!(
159            "encoding kf request: {} version: {}, len: {}",
160            std::any::type_name::<R>(),
161            version,
162            len
163        );
164
165        trace!("encoding request header: {:#?}", &self.header);
166        self.header.encode(out, version)?;
167
168        trace!("encoding request: {:#?}", &self.request);
169        self.request.encode(out, self.header.api_version())?;
170        Ok(())
171    }
172}
173
#[cfg(test)]
mod test {

    use std::io::Cursor;
    use std::io::Error as IoError;
    use std::convert::TryInto;
    use crate::core::bytes::Buf;
    use crate::core::bytes::BufMut;
    use crate::core::Decoder;
    use crate::core::Encoder;
    use crate::core::Version;
    use crate::derive::Encoder;
    use crate::derive::Decoder;

    use super::RequestHeader;
    use super::RequestMessage;
    use crate::ApiMessage;

    use crate::Request;

    /// Api key enum whose single variant supplies `ApiVersionRequest::API_KEY`.
    #[repr(u16)]
    #[derive(PartialEq, Debug, Clone, Copy, Encoder, Decoder)]
    #[fluvio(encode_discriminant)]
    pub enum TestApiKey {
        ApiVersion = 0,
    }

    impl Default for TestApiKey {
        fn default() -> TestApiKey {
            TestApiKey::ApiVersion
        }
    }

    /// Empty request body used to exercise `RequestMessage`.
    #[derive(Decoder, Encoder, Debug, Default)]
    pub struct ApiVersionRequest {}

    impl Request for ApiVersionRequest {
        const API_KEY: u16 = TestApiKey::ApiVersion as u16;

        type Response = ApiVersionResponse;
    }

    #[derive(Encoder, Decoder, Default, Debug)]
    pub struct ApiVersionResponse {
        pub error_code: i16,
        pub api_versions: Vec<ApiVersion>,
        pub throttle_time_ms: i32,
    }

    #[derive(Encoder, Decoder, Default, Debug)]
    pub struct ApiVersion {
        pub api_key: i16,
        pub min_version: i16,
        pub max_version: i16,
    }

    /// Api key enum matching the raw header bytes used below (api key 18).
    #[repr(u16)]
    #[derive(PartialEq, Debug, Encoder, Decoder, Clone, Copy)]
    #[fluvio(encode_discriminant)]
    pub enum TestApiEnum {
        ApiVersion = 18,
    }

    impl Default for TestApiEnum {
        fn default() -> TestApiEnum {
            TestApiEnum::ApiVersion
        }
    }

    #[test]
    fn test_decode_header() -> Result<(), IoError> {
        // API versions request
        // API key: API Versions (18)
        // API version: 1
        // correlation id: 1
        // string length: 10
        // client id: consumer-1
        let data = [
            0x00, 0x12, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x0a, 0x63, 0x6f, 0x6e, 0x73,
            0x75, 0x6d, 0x65, 0x72, 0x2d, 0x31,
        ];

        let header: RequestHeader = RequestHeader::decode_from(&mut Cursor::new(&data), 0)?;

        assert_eq!(header.api_key(), TestApiEnum::ApiVersion as u16);
        assert_eq!(header.api_version(), 1);
        assert_eq!(header.correlation_id(), 1);
        assert_eq!(header.client_id(), "consumer-1");

        Ok(())
    }

    #[test]
    fn test_encode_header() {
        let req_header = RequestHeader::new_with_client(
            TestApiEnum::ApiVersion as u16,
            String::from("consumer-1"),
        );
        let expected_result = [
            0x00, 0x12, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x0a, 0x63, 0x6f, 0x6e, 0x73,
            0x75, 0x6d, 0x65, 0x72, 0x2d, 0x31,
        ];

        let mut result = vec![];
        // `expect` surfaces the underlying I/O error on failure, unlike `assert!(is_ok())`.
        req_header
            .encode(&mut result, 0)
            .expect("encoding the header should succeed");

        assert_eq!(result, expected_result);
    }

    pub enum TestApiRequest {
        ApiVersionRequest(RequestMessage<ApiVersionRequest>),
    }

    impl Default for TestApiRequest {
        fn default() -> TestApiRequest {
            TestApiRequest::ApiVersionRequest(RequestMessage::<ApiVersionRequest>::default())
        }
    }

    impl ApiMessage for TestApiRequest {
        type ApiKey = TestApiEnum;

        fn decode_with_header<T>(src: &mut T, header: RequestHeader) -> Result<Self, IoError>
        where
            Self: Default + Sized,
            Self::ApiKey: Sized,
            T: Buf,
        {
            match header.api_key().try_into()? {
                TestApiEnum::ApiVersion => {
                    let request = ApiVersionRequest::decode_from(src, header.api_version())?;
                    Ok(TestApiRequest::ApiVersionRequest(RequestMessage::new(
                        header, request,
                    )))
                }
            }
        }
    }

    impl Encoder for TestApiRequest {
        fn write_size(&self, version: Version) -> usize {
            match self {
                TestApiRequest::ApiVersionRequest(message) => message.write_size(version),
            }
        }

        fn encode<T>(&self, src: &mut T, version: Version) -> Result<(), IoError>
        where
            T: BufMut,
        {
            match self {
                TestApiRequest::ApiVersionRequest(message) => {
                    message.encode(src, version)?;
                }
            }
            Ok(())
        }
    }

    #[test]
    fn test_encode_message() {
        let mut message = RequestMessage::new_request(ApiVersionRequest {});
        message
            .get_mut_header()
            .set_client_id("consumer-1".to_owned())
            .set_correlation_id(5);

        let mut out = vec![];
        message.encode(&mut out, 0).expect("encode work");
        let mut encode_bytes = Cursor::new(&out);

        let res_msg_result: Result<RequestMessage<ApiVersionRequest>, IoError> =
            Decoder::decode_from(&mut encode_bytes, 0);

        let msg = res_msg_result.expect("decoding the encoded message should succeed");
        // The whole header must round-trip, not just the correlation id.
        assert_eq!(msg.header.api_key(), ApiVersionRequest::API_KEY);
        assert_eq!(msg.header.correlation_id(), 5);
        assert_eq!(msg.header.client_id(), "consumer-1");
    }
}