1use std::io::Error as IoError;
2use std::path::Path;
3use std::fmt;
4use std::fmt::Display;
5
6use tracing::trace;
7
8use crate::core::bytes::Buf;
9use crate::core::bytes::BufMut;
10use crate::core::Decoder;
11use crate::core::Encoder;
12use crate::core::Version;
13
14use crate::api::Request;
15use crate::api::RequestHeader;
16use crate::response::ResponseMessage;
17
18#[derive(Debug)]
20pub struct RequestMessage<R> {
21 pub header: RequestHeader,
22 pub request: R,
23}
24
25impl<R> RequestMessage<R> {
26 #[allow(unused)]
27 pub fn get_mut_header(&mut self) -> &mut RequestHeader {
28 &mut self.header
29 }
30}
31
32impl<R> fmt::Display for RequestMessage<R>
33where
34 R: Display,
35{
36 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
37 write!(f, "{} {}", self.header, self.request)
38 }
39}
40
41impl<R> Default for RequestMessage<R>
42where
43 R: Request + Default,
44{
45 fn default() -> Self {
46 let mut header = RequestHeader::default();
47 header.set_api_version(R::DEFAULT_API_VERSION);
48
49 Self {
50 header,
51 request: R::default(),
52 }
53 }
54}
55
56impl<R> RequestMessage<R>
57where
58 R: Request,
59{
60 #[allow(unused)]
63 pub fn new(header: RequestHeader, request: R) -> Self {
64 Self { header, request }
65 }
66
67 #[allow(unused)]
69 pub fn new_request(request: R) -> Self {
70 let mut header = RequestHeader::new(R::API_KEY);
71 header.set_api_version(R::DEFAULT_API_VERSION);
72
73 Self { header, request }
74 }
75
76 #[allow(unused)]
77 pub fn get_header_request(self) -> (RequestHeader, R) {
78 (self.header, self.request)
79 }
80
81 #[allow(unused)]
82 #[allow(unused)]
83 pub fn request(&self) -> &R {
84 &self.request
85 }
86
87 #[allow(unused)]
88 pub fn new_response(&self, response: R::Response) -> ResponseMessage<R::Response> {
89 Self::response_with_header(&self.header, response)
90 }
91
92 pub fn response_with_header<H>(header: H, response: R::Response) -> ResponseMessage<R::Response>
93 where
94 H: Into<i32>,
95 {
96 ResponseMessage::new(header.into(), response)
97 }
98
99 #[allow(unused)]
100 pub fn decode_response<T>(
101 &self,
102 src: &mut T,
103 version: Version,
104 ) -> Result<ResponseMessage<R::Response>, IoError>
105 where
106 T: Buf,
107 {
108 ResponseMessage::decode_from(src, version)
109 }
110
111 #[allow(unused)]
112 pub fn decode_response_from_file<H: AsRef<Path>>(
113 &self,
114 file_name: H,
115 version: Version,
116 ) -> Result<ResponseMessage<R::Response>, IoError> {
117 ResponseMessage::decode_from_file(file_name, version)
118 }
119
120 #[allow(unused)]
122 pub fn set_client_id<T>(mut self, client_id: T) -> Self
123 where
124 T: Into<String>,
125 {
126 self.header.set_client_id(client_id);
127 self
128 }
129}
130
131impl<R> Decoder for RequestMessage<R>
132where
133 R: Request,
134{
135 fn decode<T>(&mut self, src: &mut T, version: Version) -> Result<(), IoError>
136 where
137 T: Buf,
138 {
139 self.header.decode(src, version)?;
140 self.request.decode(src, self.header.api_version())?;
141 Ok(())
142 }
143}
144
145impl<R> Encoder for RequestMessage<R>
146where
147 R: Request,
148{
149 fn write_size(&self, version: Version) -> usize {
150 self.header.write_size(version) + self.request.write_size(self.header.api_version())
151 }
152
153 fn encode<T>(&self, out: &mut T, version: Version) -> Result<(), IoError>
154 where
155 T: BufMut,
156 {
157 let len = self.write_size(version);
158 trace!(
159 "encoding kf request: {} version: {}, len: {}",
160 std::any::type_name::<R>(),
161 version,
162 len
163 );
164
165 trace!("encoding request header: {:#?}", &self.header);
166 self.header.encode(out, version)?;
167
168 trace!("encoding request: {:#?}", &self.request);
169 self.request.encode(out, self.header.api_version())?;
170 Ok(())
171 }
172}
173
#[cfg(test)]
mod test {
    use std::convert::TryInto;
    use std::io::Cursor;
    use std::io::Error as IoError;

    use crate::core::bytes::Buf;
    use crate::core::bytes::BufMut;
    use crate::core::Decoder;
    use crate::core::Encoder;
    use crate::core::Version;
    use crate::derive::Decoder;
    use crate::derive::Encoder;

    use super::RequestHeader;
    use super::RequestMessage;
    use crate::ApiMessage;
    use crate::Request;

    /// API key used by `ApiVersionRequest::API_KEY`.
    #[repr(u16)]
    #[derive(PartialEq, Debug, Clone, Copy, Encoder, Decoder)]
    #[fluvio(encode_discriminant)]
    pub enum TestApiKey {
        ApiVersion = 0,
    }

    impl Default for TestApiKey {
        fn default() -> TestApiKey {
            TestApiKey::ApiVersion
        }
    }

    /// Empty request payload; in these tests only the header carries data.
    #[derive(Decoder, Encoder, Debug, Default)]
    pub struct ApiVersionRequest {}

    impl Request for ApiVersionRequest {
        const API_KEY: u16 = TestApiKey::ApiVersion as u16;

        type Response = ApiVersionResponse;
    }

    #[derive(Encoder, Decoder, Default, Debug)]
    pub struct ApiVersionResponse {
        pub error_code: i16,
        pub api_versions: Vec<ApiVersion>,
        pub throttle_time_ms: i32,
    }

    #[derive(Encoder, Decoder, Default, Debug)]
    pub struct ApiVersion {
        pub api_key: i16,
        pub min_version: i16,
        pub max_version: i16,
    }

    /// API key used by the header tests and by `TestApiRequest` (18 = 0x0012 on the wire).
    #[repr(u16)]
    #[derive(PartialEq, Debug, Encoder, Decoder, Clone, Copy)]
    #[fluvio(encode_discriminant)]
    pub enum TestApiEnum {
        ApiVersion = 18,
    }

    impl Default for TestApiEnum {
        fn default() -> TestApiEnum {
            TestApiEnum::ApiVersion
        }
    }

    /// Wire form of a request header with api_key 18, api_version 1, correlation_id 1
    /// and client_id "consumer-1"; shared by the header encode and decode tests.
    const CONSUMER_1_HEADER: [u8; 20] = [
        0x00, 0x12, // api_key: 18
        0x00, 0x01, // api_version: 1
        0x00, 0x00, 0x00, 0x01, // correlation_id: 1
        0x00, 0x0a, // client_id length: 10
        0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6d, 0x65, 0x72, 0x2d, 0x31, // "consumer-1"
    ];

    #[test]
    fn test_decode_header() -> Result<(), IoError> {
        let header: RequestHeader =
            RequestHeader::decode_from(&mut Cursor::new(&CONSUMER_1_HEADER), 0)?;

        assert_eq!(header.api_key(), TestApiEnum::ApiVersion as u16);
        assert_eq!(header.api_version(), 1);
        assert_eq!(header.correlation_id(), 1);
        assert_eq!(header.client_id(), "consumer-1");

        Ok(())
    }

    #[test]
    fn test_encode_header() {
        let req_header = RequestHeader::new_with_client(
            TestApiEnum::ApiVersion as u16,
            String::from("consumer-1"),
        );

        let mut result = vec![];
        req_header
            .encode(&mut result, 0)
            .expect("encoding a request header into a Vec should succeed");

        assert_eq!(result, CONSUMER_1_HEADER);
    }

    pub enum TestApiRequest {
        ApiVersionRequest(RequestMessage<ApiVersionRequest>),
    }

    impl Default for TestApiRequest {
        fn default() -> TestApiRequest {
            TestApiRequest::ApiVersionRequest(RequestMessage::<ApiVersionRequest>::default())
        }
    }

    impl ApiMessage for TestApiRequest {
        type ApiKey = TestApiEnum;

        fn decode_with_header<T>(src: &mut T, header: RequestHeader) -> Result<Self, IoError>
        where
            Self: Default + Sized,
            Self::ApiKey: Sized,
            T: Buf,
        {
            match header.api_key().try_into()? {
                TestApiEnum::ApiVersion => {
                    let request = ApiVersionRequest::decode_from(src, header.api_version())?;
                    Ok(TestApiRequest::ApiVersionRequest(RequestMessage::new(
                        header, request,
                    )))
                }
            }
        }
    }

    impl Encoder for TestApiRequest {
        fn write_size(&self, version: Version) -> usize {
            match self {
                TestApiRequest::ApiVersionRequest(request) => request.write_size(version),
            }
        }

        fn encode<T>(&self, out: &mut T, version: Version) -> Result<(), IoError>
        where
            T: BufMut,
        {
            match self {
                TestApiRequest::ApiVersionRequest(request) => {
                    request.encode(out, version)?;
                }
            }
            Ok(())
        }
    }

    #[test]
    fn test_encode_message() {
        let mut message = RequestMessage::new_request(ApiVersionRequest {});
        message
            .get_mut_header()
            .set_client_id("consumer-1".to_owned())
            .set_correlation_id(5);

        let mut out = vec![];
        message.encode(&mut out, 0).expect("encode work");
        // `write_size` must agree with the number of bytes `encode` actually produced.
        assert_eq!(out.len(), message.write_size(0));

        let mut encode_bytes = Cursor::new(&out);
        let res_msg_result: Result<RequestMessage<ApiVersionRequest>, IoError> =
            Decoder::decode_from(&mut encode_bytes, 0);

        let msg = res_msg_result.expect("round-trip decode should succeed");
        assert_eq!(msg.header.correlation_id(), 5);
        assert_eq!(msg.header.client_id(), "consumer-1");
        assert_eq!(msg.header.api_key(), TestApiKey::ApiVersion as u16);
    }
}