// fluvio_spu_schema/produce/request.rs

1use std::fmt::Debug;
2use std::io::{Error, ErrorKind};
3use std::marker::PhantomData;
4use std::time::Duration;
5use bytes::Buf;
6use bytes::BufMut;
7
8use fluvio_protocol::record::RawRecords;
9use fluvio_protocol::Encoder;
10use fluvio_protocol::Decoder;
11use fluvio_protocol::derive::FluvioDefault;
12use fluvio_protocol::Version;
13use fluvio_protocol::api::Request;
14use fluvio_protocol::record::RecordSet;
15use fluvio_types::PartitionId;
16
17use crate::COMMON_VERSION;
18use crate::isolation::Isolation;
19
20use super::ProduceResponse;
21use crate::server::smartmodule::SmartModuleInvocation;
22
23pub type DefaultProduceRequest = ProduceRequest<RecordSet<RawRecords>>;
24pub type DefaultPartitionRequest = PartitionProduceData<RecordSet<RawRecords>>;
25pub type DefaultTopicRequest = TopicProduceData<RecordSet<RawRecords>>;
26
/// Produce request: record data for one or more topic partitions, together with the
/// acknowledgement mode (`isolation`), a timeout and zero or more SmartModule invocations.
///
/// `R` is the per-partition record payload type (e.g. `DefaultProduceRequest` uses
/// `RecordSet<RawRecords>`). The wire format is implemented by the hand-written
/// `Encoder`/`Decoder` impls for this type, not by a derive.
#[derive(FluvioDefault, Debug)]
pub struct ProduceRequest<R> {
    /// The transactional ID, or null if the producer is not transactional.
    // NOTE(review): the hand-written `Encoder`/`Decoder` impls read/write this field at every
    // version, so `min_version = 3` does not appear to gate it on the wire — confirm intent.
    #[fluvio(min_version = 3)]
    pub transactional_id: Option<String>,

    /// Acknowledgement mode, sent as the i16 `acks` field:
    /// - `ReadUncommitted`: just wait for the leader to write the message (only wait for LEO update).
    /// - `ReadCommitted`: wait for the messages to be committed (wait for HW).
    pub isolation: Isolation,

    /// The timeout to await a response. Encoded as whole milliseconds in an i32;
    /// encoding fails if the value does not fit.
    pub timeout: Duration,

    /// Each topic to produce to.
    pub topics: Vec<TopicProduceData<R>>,

    /// SmartModule invocations attached to this request; encoded after `topics`.
    pub smartmodules: Vec<SmartModuleInvocation>,

    // NOTE(review): `R` is already used by `topics`, so this marker looks redundant;
    // it is a public field, so removing it would be a breaking change.
    pub data: PhantomData<R>,
}
47
48impl<R> Request for ProduceRequest<R>
49where
50    R: Debug + Decoder + Encoder,
51{
52    const API_KEY: u16 = 0;
53
54    const MIN_API_VERSION: i16 = 0;
55    const DEFAULT_API_VERSION: i16 = COMMON_VERSION;
56
57    type Response = ProduceResponse;
58}
59
/// One topic's section of a [`ProduceRequest`].
///
/// Encoding is derived; the derived codec presumably follows declaration order, so keep
/// the field order stable (it is the wire layout) — TODO confirm against the derive macro.
#[derive(Encoder, Decoder, FluvioDefault, Debug)]
pub struct TopicProduceData<R> {
    /// The topic name.
    pub name: String,

    /// Each partition to produce to.
    pub partitions: Vec<PartitionProduceData<R>>,
    // NOTE(review): `R` is already used by `partitions`, so this marker looks redundant;
    // it is a public field, so removing it would be a breaking change.
    pub data: PhantomData<R>,
}
69
/// One partition's section of a [`TopicProduceData`]: the target partition and its records.
///
/// Encoding is derived; keep the field order stable since the derived codec presumably
/// writes fields in declaration order — TODO confirm against the derive macro.
#[derive(Encoder, Decoder, FluvioDefault, Debug)]
pub struct PartitionProduceData<R> {
    /// The partition index.
    pub partition_index: PartitionId,

    /// The record data to be produced.
    pub records: R,
}
78
79impl<R> Encoder for ProduceRequest<R>
80where
81    R: Encoder + Decoder + Default + Debug,
82{
83    fn write_size(&self, version: Version) -> usize {
84        self.transactional_id.write_size(version)
85            + IsolationData(0i16).write_size(version)
86            + TimeoutData(0i32).write_size(version)
87            + self.topics.write_size(version)
88            + self.smartmodules.write_size(version)
89    }
90
91    fn encode<T>(&self, dest: &mut T, version: Version) -> Result<(), Error>
92    where
93        T: BufMut,
94    {
95        self.transactional_id.encode(dest, version)?;
96        IsolationData::from(self.isolation).encode(dest, version)?;
97        TimeoutData::try_from(self.timeout)?.encode(dest, version)?;
98        self.topics.encode(dest, version)?;
99        self.smartmodules.encode(dest, version)?;
100        Ok(())
101    }
102}
103
104impl<R> Decoder for ProduceRequest<R>
105where
106    R: Decoder + Encoder + Default + Debug,
107{
108    fn decode<T>(&mut self, src: &mut T, version: Version) -> Result<(), Error>
109    where
110        T: Buf,
111    {
112        self.transactional_id = Decoder::decode_from(src, version)?;
113        self.isolation = Isolation::from(IsolationData::decode_from(src, version)?);
114        self.timeout = Duration::try_from(TimeoutData::decode_from(src, version)?)?;
115        self.topics = Decoder::decode_from(src, version)?;
116        self.smartmodules.decode(src, version)?;
117        Ok(())
118    }
119}
120
121impl<R: Encoder + Decoder + Default + Debug + Clone> Clone for ProduceRequest<R> {
122    fn clone(&self) -> Self {
123        Self {
124            transactional_id: self.transactional_id.clone(),
125            isolation: self.isolation,
126            timeout: self.timeout,
127            topics: self.topics.clone(),
128            data: self.data,
129            smartmodules: self.smartmodules.clone(),
130        }
131    }
132}
133
134impl<R: Encoder + Decoder + Default + Debug + Clone> Clone for TopicProduceData<R> {
135    fn clone(&self) -> Self {
136        Self {
137            name: self.name.clone(),
138            partitions: self.partitions.clone(),
139            data: self.data,
140        }
141    }
142}
143
144impl<R: Encoder + Decoder + Default + Debug + Clone> Clone for PartitionProduceData<R> {
145    fn clone(&self) -> Self {
146        Self {
147            partition_index: self.partition_index,
148            records: self.records.clone(),
149        }
150    }
151}
152
/// Isolation is represented in binary format as i16 value (field `acks` in Kafka wire protocol).
///
/// `ReadUncommitted` is written as `1` and `ReadCommitted` as `-1`. When decoding, any
/// negative value maps to `ReadCommitted` and every other value to `ReadUncommitted`.
#[derive(Encoder, Decoder, FluvioDefault, Debug)]
struct IsolationData(i16);
156
157impl From<Isolation> for IsolationData {
158    fn from(isolation: Isolation) -> Self {
159        IsolationData(match isolation {
160            Isolation::ReadUncommitted => 1,
161            Isolation::ReadCommitted => -1,
162        })
163    }
164}
165
166impl From<IsolationData> for Isolation {
167    fn from(data: IsolationData) -> Self {
168        match data.0 {
169            acks if acks < 0 => Isolation::ReadCommitted,
170            _ => Isolation::ReadUncommitted,
171        }
172    }
173}
174
/// Timeout duration is represented in binary format as i32 value (field `timeout_ms` in Kafka wire protocol).
///
/// Converting from a `Duration` keeps whole milliseconds and fails with `InvalidInput` if
/// they exceed `i32::MAX`; converting back to a `Duration` rejects negative values.
#[derive(Encoder, Decoder, FluvioDefault, Debug)]
struct TimeoutData(i32);
178
179impl TryFrom<Duration> for TimeoutData {
180    type Error = Error;
181
182    fn try_from(value: Duration) -> Result<Self, Self::Error> {
183        value.as_millis().try_into().map(TimeoutData).map_err(|_e| {
184            Error::new(
185                ErrorKind::InvalidInput,
186                "Timeout must fit into 4 bytes integer value",
187            )
188        })
189    }
190}
191
192impl TryFrom<TimeoutData> for Duration {
193    type Error = Error;
194
195    fn try_from(value: TimeoutData) -> Result<Self, Self::Error> {
196        u64::try_from(value.0)
197            .map(Duration::from_millis)
198            .map_err(|_e| {
199                Error::new(
200                    ErrorKind::InvalidInput,
201                    "Timeout must be positive integer value",
202                )
203            })
204    }
205}
206
// With the `file` feature, re-export the public items of the `file` module
// (the `File*` request aliases) from this module.
#[cfg(feature = "file")]
pub use self::file::*;
209
#[cfg(feature = "file")]
mod file {
    use std::io::Error as IoError;

    use tracing::trace;
    use bytes::BytesMut;

    use fluvio_protocol::Version;
    use fluvio_protocol::store::FileWrite;
    use fluvio_protocol::store::StoreValue;

    use crate::file::FileRecordSet;

    use super::*;

    /// [`ProduceRequest`] whose partitions carry [`FileRecordSet`] payloads.
    pub type FileProduceRequest = ProduceRequest<FileRecordSet>;
    /// [`TopicProduceData`] whose partitions carry [`FileRecordSet`] payloads.
    pub type FileTopicRequest = TopicProduceData<FileRecordSet>;
    /// [`PartitionProduceData`] carrying a [`FileRecordSet`] payload.
    pub type FilePartitionRequest = PartitionProduceData<FileRecordSet>;

    impl FileWrite for FileProduceRequest {
        /// Encodes the request in the same field order as the `Encoder` impl:
        /// transactional id, `acks`, `timeout_ms`, topics, smartmodules.
        ///
        /// Record payloads are handed to `FileRecordSet::file_encode` through the topics;
        /// all other fields are written into `src`.
        ///
        /// # Errors
        /// Fails if the timeout does not fit into an i32 millisecond count, or if any
        /// nested encoding fails.
        fn file_encode(
            &self,
            src: &mut BytesMut,
            data: &mut Vec<StoreValue>,
            version: Version,
        ) -> Result<(), IoError> {
            trace!("file encoding produce request");
            self.transactional_id.encode(src, version)?;
            IsolationData::from(self.isolation).encode(src, version)?;
            TimeoutData::try_from(self.timeout)?.encode(src, version)?;
            self.topics.file_encode(src, data, version)?;
            // `Encoder::write_size` counts `smartmodules` and `Decoder::decode` reads them
            // right after `topics`; omitting them here made file-encoded requests shorter
            // than their reported size and undecodable by that decoder.
            self.smartmodules.encode(src, version)?;
            Ok(())
        }
    }

    impl FileWrite for FileTopicRequest {
        /// Writes the topic name, then delegates each partition to its `file_encode`.
        fn file_encode(
            &self,
            src: &mut BytesMut,
            data: &mut Vec<StoreValue>,
            version: Version,
        ) -> Result<(), IoError> {
            trace!("file encoding produce topic request");
            self.name.encode(src, version)?;
            self.partitions.file_encode(src, data, version)?;
            Ok(())
        }
    }

    impl FileWrite for FilePartitionRequest {
        /// Writes the partition index, then delegates the records to
        /// `FileRecordSet::file_encode`.
        fn file_encode(
            &self,
            src: &mut BytesMut,
            data: &mut Vec<StoreValue>,
            version: Version,
        ) -> Result<(), IoError> {
            trace!("file encoding for partition request");
            self.partition_index.encode(src, version)?;
            self.records.file_encode(src, data, version)?;
            Ok(())
        }
    }
}
273
#[cfg(test)]
mod tests {
    // Round-trip and golden-byte tests for the produce request wire format.

    use std::io::{Error, ErrorKind};
    use std::time::Duration;

    use fluvio_protocol::{Decoder, Encoder};
    use fluvio_protocol::api::Request;
    use fluvio_protocol::record::Batch;
    use fluvio_protocol::record::{Record, RecordData, RecordSet};
    use fluvio_smartmodule::dataplane::smartmodule::{SmartModuleExtraParams, Lookback};

    use crate::produce::DefaultProduceRequest;
    use crate::produce::TopicProduceData;
    use crate::produce::PartitionProduceData;
    use crate::isolation::Isolation;
    use crate::server::smartmodule::{
        SmartModuleInvocation, SmartModuleInvocationWasm, SmartModuleKind,
    };

    #[test]
    fn test_encode_decode_produce_request_isolation_timeout() -> Result<(), Error> {
        let request = DefaultProduceRequest {
            isolation: Isolation::ReadCommitted,
            timeout: Duration::from_millis(123456),
            ..Default::default()
        };

        let version = DefaultProduceRequest::DEFAULT_API_VERSION;
        let mut bytes = request.as_bytes(version)?;

        let decoded: DefaultProduceRequest = Decoder::decode_from(&mut bytes, version)?;

        assert_eq!(request.isolation, decoded.isolation);
        assert_eq!(request.timeout, decoded.timeout);
        Ok(())
    }

    #[test]
    fn test_encode_produce_request_timeout_too_big() {
        let request = DefaultProduceRequest {
            isolation: Isolation::ReadCommitted,
            timeout: Duration::from_millis(u64::MAX),
            ..Default::default()
        };

        let version = DefaultProduceRequest::DEFAULT_API_VERSION;
        let result = request.as_bytes(version).expect_err("expected error");

        assert_eq!(result.kind(), ErrorKind::InvalidInput);
        assert_eq!(
            result.to_string(),
            "Timeout must fit into 4 bytes integer value"
        );
    }

    #[test]
    fn test_default_produce_request_clone() {
        //given
        let request = DefaultProduceRequest {
            transactional_id: Some("transaction_id".to_string()),
            isolation: Default::default(),
            timeout: Duration::from_millis(100),
            topics: vec![TopicProduceData {
                name: "topic".to_string(),
                partitions: vec![PartitionProduceData {
                    partition_index: 1,
                    records: RecordSet {
                        batches: vec![
                            Batch::from(vec![Record::new(RecordData::from("some raw data"))])
                                .try_into()
                                .expect("compressed batch"),
                        ],
                    },
                }],
                data: Default::default(),
            }],
            data: Default::default(),
            smartmodules: Default::default(),
        };
        let version = DefaultProduceRequest::DEFAULT_API_VERSION;

        //when
        #[allow(clippy::redundant_clone)]
        let cloned = request.clone();
        let bytes = request.as_bytes(version).expect("encoded request");
        let cloned_bytes = cloned.as_bytes(version).expect("encoded cloned request");

        //then
        assert_eq!(bytes, cloned_bytes);
    }

    #[test]
    fn test_encode_produce_request() {
        //given
        let name = "adhoc-test";
        let mut dest = Vec::new();
        let params = SmartModuleExtraParams::default();
        let value = DefaultProduceRequest {
            transactional_id: Some("t_id".into()),
            isolation: Isolation::ReadCommitted,
            timeout: Duration::from_secs(1),
            topics: vec![],
            smartmodules: vec![SmartModuleInvocation {
                wasm: SmartModuleInvocationWasm::AdHoc(vec![0xde, 0xad, 0xbe, 0xef]),
                kind: SmartModuleKind::Filter,
                params,
                name: Some(name.to_string()),
            }],
            data: std::marker::PhantomData,
        };
        let version = DefaultProduceRequest::DEFAULT_API_VERSION;
        //when
        value.encode(&mut dest, version).expect("encode failure");

        //then
        let expected = vec![
            0x01, 0x00, 0x04, 0x74, 0x5f, 0x69, 0x64, 0xff, 0xff, 0x00, 0x00, 0x03, 0xe8, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x04, 0xde, 0xad,
            0xbe, 0xef, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x0a, 0x61, 0x64, 0x68, 0x6f, 0x63,
            0x2d, 0x74, 0x65, 0x73, 0x74,
        ];
        assert_eq!(dest, expected);
    }

    #[test]
    fn test_decode_produce_request() {
        //given
        let bytes = vec![
            0x01, 0x00, 0x04, 0x74, 0x5f, 0x69, 0x64, 0xff, 0xff, 0x00, 0x00, 0x03, 0xe8, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x04, 0xde, 0xad,
            0xbe, 0xef, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x0a, 0x61, 0x64, 0x68, 0x6f, 0x63,
            0x2d, 0x74, 0x65, 0x73, 0x74,
        ];
        let mut value = DefaultProduceRequest::default();
        let version = DefaultProduceRequest::DEFAULT_API_VERSION;
        //when
        value
            .decode(&mut std::io::Cursor::new(bytes), version)
            .expect("decode failure");

        //then
        assert_eq!(value.transactional_id, Some("t_id".into()));
        assert_eq!(value.isolation, Isolation::ReadCommitted);
        assert_eq!(value.timeout, Duration::from_secs(1));
        assert!(value.topics.is_empty());
        let sm = match value.smartmodules.first() {
            Some(wasm) => wasm,
            _ => panic!("should have smartmodule payload"),
        };
        assert_eq!(&sm.name, &Some("adhoc-test".to_string()));
        assert!(sm.params.lookback().is_none());
        let wasm = match &sm.wasm {
            SmartModuleInvocationWasm::AdHoc(wasm) => wasm.as_slice(),
            #[allow(unreachable_patterns)]
            _ => panic!("should be SmartModuleInvocationWasm::AdHoc"),
        };
        assert_eq!(wasm, vec![0xde, 0xad, 0xbe, 0xef]);
        assert!(matches!(sm.kind, SmartModuleKind::Filter));
    }

    #[test]
    fn test_encode_produce_request_no_name_version() {
        use crate::server::smartmodule::COMMON_VERSION_HAS_SM_NAME;
        //given
        let name = "adhoc-tes-debugttttttt";
        let mut dest = Vec::new();
        let mut params = SmartModuleExtraParams::default();
        params.set_lookback(Some(Lookback::last(1)));
        let value = DefaultProduceRequest {
            transactional_id: Some("t_id".into()),
            isolation: Isolation::ReadCommitted,
            timeout: Duration::from_secs(1),
            topics: vec![],
            smartmodules: vec![SmartModuleInvocation {
                wasm: SmartModuleInvocationWasm::AdHoc(vec![0xde, 0xad, 0xbe, 0xef]),
                kind: SmartModuleKind::Filter,
                params,
                name: Some(name.to_string()),
            }],
            data: std::marker::PhantomData,
        };
        let version = COMMON_VERSION_HAS_SM_NAME - 1;
        //when
        value.encode(&mut dest, version).expect("should encode");

        //then
        let expected = vec![
            // pre SM_NAME, smartmodule is not encoded into this message
            0x01, 0x00, 0x04, 0x74, 0x5f, 0x69, 0x64, 0xff, 0xff, 0x00, 0x00, 0x03, 0xe8, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x04, 0xde, 0xad,
            0xbe, 0xef, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
            0x00,
        ];
        assert_eq!(dest, expected);
    }

    #[test]
    fn test_decode_produce_request_no_name_version() {
        use crate::server::smartmodule::COMMON_VERSION_HAS_SM_NAME;
        //given
        let bytes = vec![
            // pre SM_NAME, smartmodule is not encoded into this message
            0x01, 0x00, 0x04, 0x74, 0x5f, 0x69, 0x64, 0xff, 0xff, 0x00, 0x00, 0x03, 0xe8, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x04, 0xde, 0xad,
            0xbe, 0xef, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
            0x00,
        ];
        let mut value = DefaultProduceRequest::default();
        let version = COMMON_VERSION_HAS_SM_NAME - 1;
        //when
        value
            .decode(&mut std::io::Cursor::new(bytes), version)
            .expect("decode failure");

        //then
        assert_eq!(value.transactional_id, Some("t_id".into()));
        assert_eq!(value.isolation, Isolation::ReadCommitted);
        assert_eq!(value.timeout, Duration::from_secs(1));
        assert!(value.topics.is_empty());
        let sm = match value.smartmodules.first() {
            Some(wasm) => wasm,
            _ => panic!("should have smartmodule payload"),
        };
        assert_eq!(sm.params.lookback(), Some(&Lookback { last: 1, age: None }));
        let wasm = match &sm.wasm {
            SmartModuleInvocationWasm::AdHoc(wasm) => wasm.as_slice(),
            #[allow(unreachable_patterns)]
            _ => panic!("should be SmartModuleInvocationWasm::AdHoc"),
        };
        assert_eq!(wasm, vec![0xde, 0xad, 0xbe, 0xef]);
        assert!(matches!(sm.kind, SmartModuleKind::Filter));
        assert!(sm.name.is_none());
    }

    #[test]
    fn test_encode_produce_request_prev_version() {
        //given
        let name = "adhoc-test";
        let mut dest = Vec::new();
        let mut params = SmartModuleExtraParams::default();
        params.set_lookback(Some(Lookback::age(Duration::from_secs(20), Some(1))));
        let value = DefaultProduceRequest {
            transactional_id: Some("t_id".into()),
            isolation: Isolation::ReadCommitted,
            timeout: Duration::from_secs(1),
            topics: vec![],
            smartmodules: vec![SmartModuleInvocation {
                wasm: SmartModuleInvocationWasm::AdHoc(vec![0xde, 0xad, 0xbe, 0xef]),
                kind: SmartModuleKind::Filter,
                params,
                name: Some(name.to_string()),
            }],
            data: std::marker::PhantomData,
        };
        //when
        // Pinned to an explicit older protocol version.
        let version = 8i16;
        value.encode(&mut dest, version).expect("should encode");

        //then
        let expected = vec![
            0x01, 0x00, 0x04, 0x74, 0x5f, 0x69, 0x64, 0xff, 0xff, 0x00, 0x00, 0x03, 0xe8, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x04, 0xde, 0xad,
            0xbe, 0xef, 0x00, 0x00, 0x00,
        ];
        assert_eq!(dest, expected);
    }

    #[test]
    fn test_decode_produce_request_last_version() {
        //given
        let bytes = vec![
            0x01, 0x00, 0x04, 0x74, 0x5f, 0x69, 0x64, 0xff, 0xff, 0x00, 0x00, 0x03, 0xe8, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x04, 0xde, 0xad,
            0xbe, 0xef, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
            0x00,
        ];
        let mut value = DefaultProduceRequest::default();

        //when
        let version = 8i16;
        value
            .decode(&mut std::io::Cursor::new(bytes), version)
            .expect("decode failure");

        //then
        assert_eq!(value.transactional_id, Some("t_id".into()));
        assert_eq!(value.isolation, Isolation::ReadCommitted);
        assert_eq!(value.timeout, Duration::from_secs(1));
        assert!(value.topics.is_empty());
        let sm = match value.smartmodules.first() {
            Some(wasm) => wasm,
            _ => panic!("should have smartmodule payload"),
        };
        assert_eq!(sm.params.lookback(), None);
        let wasm = match &sm.wasm {
            SmartModuleInvocationWasm::AdHoc(wasm) => wasm.as_slice(),
            #[allow(unreachable_patterns)]
            _ => panic!("should be SmartModuleInvocationWasm::AdHoc"),
        };
        assert_eq!(wasm, vec![0xde, 0xad, 0xbe, 0xef]);
        assert!(matches!(sm.kind, SmartModuleKind::Filter));
    }

    #[test]
    fn test_decode_produce_request_prev_version() {
        //given
        let bytes = vec![
            0x01, 0x00, 0x04, 0x74, 0x5f, 0x69, 0x64, 0xff, 0xff, 0x00, 0x00, 0x03, 0xe8, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x04, 0xde, 0xad,
            0xbe, 0xef, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
            0x00, 0x00,
        ];
        let mut value = DefaultProduceRequest::default();

        //when
        value
            .decode(
                &mut std::io::Cursor::new(bytes),
                DefaultProduceRequest::MAX_API_VERSION - 1,
            )
            .expect("decode failure");

        //then
        assert_eq!(value.transactional_id, Some("t_id".into()));
        assert_eq!(value.isolation, Isolation::ReadCommitted);
        assert_eq!(value.timeout, Duration::from_secs(1));
        assert!(value.topics.is_empty());
        let sm = match value.smartmodules.first() {
            Some(wasm) => wasm,
            _ => panic!("should have smartmodule payload"),
        };
        assert_eq!(sm.params.lookback(), Some(&Lookback::last(1)));
        let wasm = match &sm.wasm {
            SmartModuleInvocationWasm::AdHoc(wasm) => wasm.as_slice(),
            #[allow(unreachable_patterns)]
            _ => panic!("should be SmartModuleInvocationWasm::AdHoc"),
        };
        assert_eq!(wasm, vec![0xde, 0xad, 0xbe, 0xef]);
        assert!(matches!(sm.kind, SmartModuleKind::Filter));
    }
}