// fluvio_spu_schema/produce/request.rs

1use std::fmt::Debug;
2use std::io::{Error, ErrorKind};
3use std::marker::PhantomData;
4use std::time::Duration;
5use bytes::{Buf, BufMut};
6
7use fluvio_protocol::record::RawRecords;
8use fluvio_protocol::Encoder;
9use fluvio_protocol::Decoder;
10use fluvio_protocol::derive::FluvioDefault;
11use fluvio_protocol::Version;
12use fluvio_protocol::api::Request;
13use fluvio_protocol::record::RecordSet;
14use fluvio_types::PartitionId;
15
16use crate::COMMON_VERSION;
17use crate::isolation::Isolation;
18
19use super::ProduceResponse;
20use crate::server::smartmodule::SmartModuleInvocation;
21
22pub type DefaultProduceRequest = ProduceRequest<RecordSet<RawRecords>>;
23pub type DefaultPartitionRequest = PartitionProduceData<RecordSet<RawRecords>>;
24pub type DefaultTopicRequest = TopicProduceData<RecordSet<RawRecords>>;
25
26const PRODUCER_TRANSFORMATION_API_VERSION: i16 = 8;
27
28#[derive(FluvioDefault, Debug)]
29pub struct ProduceRequest<R> {
30    /// The transactional ID, or null if the producer is not transactional.
31    #[fluvio(min_version = 3)]
32    pub transactional_id: Option<String>,
33
34    /// ReadUncommitted - Just wait for leader to write message (only wait for LEO update).
35    /// ReadCommitted - Wait for messages to be committed (wait for HW).
36    pub isolation: Isolation,
37
38    /// The timeout to await a response.
39    pub timeout: Duration,
40
41    /// Each topic to produce to.
42    pub topics: Vec<TopicProduceData<R>>,
43
44    #[fluvio(min_version = PRODUCER_TRANSFORMATION_API)]
45    pub smartmodules: Vec<SmartModuleInvocation>,
46
47    pub data: PhantomData<R>,
48}
49
50impl<R> Request for ProduceRequest<R>
51where
52    R: Debug + Decoder + Encoder,
53{
54    const API_KEY: u16 = 0;
55
56    const MIN_API_VERSION: i16 = 0;
57    const DEFAULT_API_VERSION: i16 = COMMON_VERSION;
58
59    type Response = ProduceResponse;
60}
61
/// Records to produce to a single topic, grouped by partition.
#[derive(Encoder, Decoder, FluvioDefault, Debug)]
pub struct TopicProduceData<R> {
    /// The topic name.
    pub name: String,

    /// Each partition to produce to.
    pub partitions: Vec<PartitionProduceData<R>>,
    /// Type marker for the record payload type `R`; holds no value.
    pub data: PhantomData<R>,
}
71
/// Records to produce to a single partition of a topic.
#[derive(Encoder, Decoder, FluvioDefault, Debug)]
pub struct PartitionProduceData<R> {
    /// The partition index.
    pub partition_index: PartitionId,

    /// The record data to be produced.
    pub records: R,
}
80
81impl<R> Encoder for ProduceRequest<R>
82where
83    R: Encoder + Decoder + Default + Debug,
84{
85    fn write_size(&self, version: Version) -> usize {
86        self.transactional_id.write_size(version)
87            + IsolationData(0i16).write_size(version)
88            + TimeoutData(0i32).write_size(version)
89            + self.topics.write_size(version)
90            + if version >= PRODUCER_TRANSFORMATION_API_VERSION {
91                self.smartmodules.write_size(version)
92            } else {
93                0
94            }
95    }
96
97    fn encode<T>(&self, dest: &mut T, version: Version) -> Result<(), Error>
98    where
99        T: BufMut,
100    {
101        self.transactional_id.encode(dest, version)?;
102        IsolationData::from(self.isolation).encode(dest, version)?;
103        TimeoutData::try_from(self.timeout)?.encode(dest, version)?;
104        self.topics.encode(dest, version)?;
105        if version >= PRODUCER_TRANSFORMATION_API_VERSION {
106            self.smartmodules.encode(dest, version)?;
107        }
108        Ok(())
109    }
110}
111
112impl<R> Decoder for ProduceRequest<R>
113where
114    R: Decoder + Encoder + Default + Debug,
115{
116    fn decode<T>(&mut self, src: &mut T, version: Version) -> Result<(), Error>
117    where
118        T: Buf,
119    {
120        self.transactional_id = Decoder::decode_from(src, version)?;
121        self.isolation = Isolation::from(IsolationData::decode_from(src, version)?);
122        self.timeout = Duration::try_from(TimeoutData::decode_from(src, version)?)?;
123        self.topics = Decoder::decode_from(src, version)?;
124        if version >= PRODUCER_TRANSFORMATION_API_VERSION {
125            self.smartmodules.decode(src, version)?;
126        }
127        Ok(())
128    }
129}
130
131impl<R: Encoder + Decoder + Default + Debug + Clone> Clone for ProduceRequest<R> {
132    fn clone(&self) -> Self {
133        Self {
134            transactional_id: self.transactional_id.clone(),
135            isolation: self.isolation,
136            timeout: self.timeout,
137            topics: self.topics.clone(),
138            data: self.data,
139            smartmodules: self.smartmodules.clone(),
140        }
141    }
142}
143
144impl<R: Encoder + Decoder + Default + Debug + Clone> Clone for TopicProduceData<R> {
145    fn clone(&self) -> Self {
146        Self {
147            name: self.name.clone(),
148            partitions: self.partitions.clone(),
149            data: self.data,
150        }
151    }
152}
153
154impl<R: Encoder + Decoder + Default + Debug + Clone> Clone for PartitionProduceData<R> {
155    fn clone(&self) -> Self {
156        Self {
157            partition_index: self.partition_index,
158            records: self.records.clone(),
159        }
160    }
161}
162
/// Isolation is represented in binary format as i16 value (field `acks` in Kafka wire protocol).
///
/// `ReadUncommitted` is written as `1` and `ReadCommitted` as `-1`. When reading, any
/// negative value maps to `ReadCommitted` and every other value to `ReadUncommitted`.
#[derive(Encoder, Decoder, FluvioDefault, Debug)]
struct IsolationData(i16);
166
167impl From<Isolation> for IsolationData {
168    fn from(isolation: Isolation) -> Self {
169        IsolationData(match isolation {
170            Isolation::ReadUncommitted => 1,
171            Isolation::ReadCommitted => -1,
172        })
173    }
174}
175
176impl From<IsolationData> for Isolation {
177    fn from(data: IsolationData) -> Self {
178        match data.0 {
179            acks if acks < 0 => Isolation::ReadCommitted,
180            _ => Isolation::ReadUncommitted,
181        }
182    }
183}
184
/// Timeout duration is represented in binary format as i32 value (field `timeout_ms` in Kafka wire protocol).
///
/// Converting from a `Duration` fails when the whole-millisecond count does not fit in an
/// `i32`. Converting back to a `Duration` fails for negative values.
#[derive(Encoder, Decoder, FluvioDefault, Debug)]
struct TimeoutData(i32);
188
189impl TryFrom<Duration> for TimeoutData {
190    type Error = Error;
191
192    fn try_from(value: Duration) -> Result<Self, Self::Error> {
193        value.as_millis().try_into().map(TimeoutData).map_err(|_e| {
194            Error::new(
195                ErrorKind::InvalidInput,
196                "Timeout must fit into 4 bytes integer value",
197            )
198        })
199    }
200}
201
202impl TryFrom<TimeoutData> for Duration {
203    type Error = Error;
204
205    fn try_from(value: TimeoutData) -> Result<Self, Self::Error> {
206        u64::try_from(value.0)
207            .map(Duration::from_millis)
208            .map_err(|_e| {
209                Error::new(
210                    ErrorKind::InvalidInput,
211                    "Timeout must be positive integer value",
212                )
213            })
214    }
215}
216
// Re-export the file-backed request aliases (`FileProduceRequest`, `FileTopicRequest`,
// `FilePartitionRequest`) from `mod file`; only compiled with the `file` feature.
#[cfg(feature = "file")]
pub use file::*;
219
#[cfg(feature = "file")]
mod file {
    use std::io::Error as IoError;

    use tracing::trace;
    use bytes::BytesMut;

    use fluvio_protocol::Version;
    use fluvio_protocol::store::FileWrite;
    use fluvio_protocol::store::StoreValue;

    use crate::file::FileRecordSet;

    use super::*;

    /// [`ProduceRequest`] whose partition records are a [`FileRecordSet`].
    pub type FileProduceRequest = ProduceRequest<FileRecordSet>;
    /// Per-topic entry of a [`FileProduceRequest`].
    pub type FileTopicRequest = TopicProduceData<FileRecordSet>;
    /// Per-partition entry of a [`FileProduceRequest`].
    pub type FilePartitionRequest = PartitionProduceData<FileRecordSet>;

    impl FileWrite for FileProduceRequest {
        /// Writes the request header fields into `src` and hands record payloads to the
        /// topics' `FileWrite` impls (which may push entries into `data`).
        ///
        /// The field order and version gating must match the `Encoder` / `Decoder` impls
        /// of `ProduceRequest`; otherwise the peer decodes a different layout.
        fn file_encode(
            &self,
            src: &mut BytesMut,
            data: &mut Vec<StoreValue>,
            version: Version,
        ) -> Result<(), IoError> {
            trace!("file encoding produce request");
            self.transactional_id.encode(src, version)?;
            IsolationData::from(self.isolation).encode(src, version)?;
            TimeoutData::try_from(self.timeout)?.encode(src, version)?;
            self.topics.file_encode(src, data, version)?;
            // The decoder reads the SmartModule list at these versions (and `write_size`
            // counts it), so it must be written here too or the stream desynchronizes.
            if version >= PRODUCER_TRANSFORMATION_API_VERSION {
                self.smartmodules.encode(src, version)?;
            }
            Ok(())
        }
    }

    impl FileWrite for FileTopicRequest {
        /// Writes the topic name, then file-encodes every partition entry.
        fn file_encode(
            &self,
            src: &mut BytesMut,
            data: &mut Vec<StoreValue>,
            version: Version,
        ) -> Result<(), IoError> {
            trace!("file encoding produce topic request");
            self.name.encode(src, version)?;
            self.partitions.file_encode(src, data, version)?;
            Ok(())
        }
    }

    impl FileWrite for FilePartitionRequest {
        /// Writes the partition index, then file-encodes the record set.
        fn file_encode(
            &self,
            src: &mut BytesMut,
            data: &mut Vec<StoreValue>,
            version: Version,
        ) -> Result<(), IoError> {
            trace!("file encoding for partition request");
            self.partition_index.encode(src, version)?;
            self.records.file_encode(src, data, version)?;
            Ok(())
        }
    }
}
283
#[cfg(test)]
mod tests {
    use std::io::{Cursor, Error, ErrorKind};
    use std::time::Duration;

    use fluvio_protocol::{Decoder, Encoder};
    use fluvio_protocol::api::Request;
    use fluvio_protocol::record::Batch;
    use fluvio_protocol::record::{Record, RecordData, RecordSet};
    use fluvio_smartmodule::dataplane::smartmodule::{SmartModuleExtraParams, Lookback};

    use crate::produce::DefaultProduceRequest;
    use crate::produce::TopicProduceData;
    use crate::produce::PartitionProduceData;
    use crate::isolation::Isolation;
    use crate::produce::request::PRODUCER_TRANSFORMATION_API_VERSION;
    use crate::server::smartmodule::{
        SmartModuleInvocation, SmartModuleInvocationWasm, SmartModuleKind,
    };

    /// Expected encoding of `sample_request(SmartModuleExtraParams::default())` at
    /// `PRODUCER_TRANSFORMATION_API_VERSION`.
    const SAMPLE_BYTES_TRANSFORMATION_VERSION: [u8; 33] = [
        0x01, 0x00, 0x04, 0x74, 0x5f, 0x69, 0x64, 0xff, 0xff, 0x00, 0x00, 0x03, 0xe8, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x04, 0xde, 0xad,
        0xbe, 0xef, 0x00, 0x00, 0x00,
    ];

    /// Expected encoding of the sample request with a `Lookback::last(1)` parameter at
    /// `DefaultProduceRequest::MAX_API_VERSION`.
    const SAMPLE_BYTES_LOOKBACK_LAST: [u8; 43] = [
        0x01, 0x00, 0x04, 0x74, 0x5f, 0x69, 0x64, 0xff, 0xff, 0x00, 0x00, 0x03, 0xe8, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x04, 0xde, 0xad,
        0xbe, 0xef, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
        0x00,
    ];

    /// Builds the request shared by the SmartModule encode tests: transaction id `t_id`,
    /// `ReadCommitted`, 1s timeout, no topics and one ad-hoc filter SmartModule
    /// (`de ad be ef`) with the given params.
    fn sample_request(params: SmartModuleExtraParams) -> DefaultProduceRequest {
        DefaultProduceRequest {
            transactional_id: Some("t_id".into()),
            isolation: Isolation::ReadCommitted,
            timeout: Duration::from_secs(1),
            topics: vec![],
            smartmodules: vec![SmartModuleInvocation {
                wasm: SmartModuleInvocationWasm::AdHoc(vec![0xde, 0xad, 0xbe, 0xef]),
                kind: SmartModuleKind::Filter,
                params,
            }],
            data: std::marker::PhantomData,
        }
    }

    /// Decodes `bytes` at `version` and asserts the result matches `sample_request` with
    /// `expected_lookback` as its SmartModule lookback.
    fn assert_decodes_to_sample(bytes: &[u8], version: i16, expected_lookback: Option<&Lookback>) {
        let mut value = DefaultProduceRequest::default();

        value
            .decode(&mut Cursor::new(bytes.to_vec()), version)
            .unwrap();

        assert_eq!(value.transactional_id, Some("t_id".into()));
        assert_eq!(value.isolation, Isolation::ReadCommitted);
        assert_eq!(value.timeout, Duration::from_secs(1));
        assert!(value.topics.is_empty());
        let sm = match value.smartmodules.first() {
            Some(wasm) => wasm,
            _ => panic!("should have smartmodule payload"),
        };
        assert_eq!(sm.params.lookback(), expected_lookback);
        let wasm = match &sm.wasm {
            SmartModuleInvocationWasm::AdHoc(wasm) => wasm.as_slice(),
            // Only reachable if more wasm variants exist; kept for forward compatibility.
            #[allow(unreachable_patterns)]
            _ => panic!("should be SmartModuleInvocationWasm::AdHoc"),
        };
        assert_eq!(wasm, vec![0xde, 0xad, 0xbe, 0xef]);
        assert!(matches!(sm.kind, SmartModuleKind::Filter));
    }

    #[test]
    fn test_encode_decode_produce_request_isolation_timeout() -> Result<(), Error> {
        let request = DefaultProduceRequest {
            isolation: Isolation::ReadCommitted,
            timeout: Duration::from_millis(123456),
            ..Default::default()
        };

        let version = DefaultProduceRequest::DEFAULT_API_VERSION;
        let mut bytes = request.as_bytes(version)?;

        let decoded: DefaultProduceRequest = Decoder::decode_from(&mut bytes, version)?;

        assert_eq!(request.isolation, decoded.isolation);
        assert_eq!(request.timeout, decoded.timeout);
        Ok(())
    }

    #[test]
    fn test_encode_decode_produce_request_read_uncommitted() -> Result<(), Error> {
        let request = DefaultProduceRequest {
            isolation: Isolation::ReadUncommitted,
            timeout: Duration::from_millis(0),
            ..Default::default()
        };

        let version = DefaultProduceRequest::DEFAULT_API_VERSION;
        let mut bytes = request.as_bytes(version)?;

        let decoded: DefaultProduceRequest = Decoder::decode_from(&mut bytes, version)?;

        assert_eq!(decoded.isolation, Isolation::ReadUncommitted);
        assert_eq!(decoded.timeout, Duration::from_millis(0));
        Ok(())
    }

    #[test]
    fn test_encode_produce_request_timeout_too_big() {
        let request = DefaultProduceRequest {
            isolation: Isolation::ReadCommitted,
            timeout: Duration::from_millis(u64::MAX),
            ..Default::default()
        };

        let version = DefaultProduceRequest::DEFAULT_API_VERSION;
        let result = request.as_bytes(version).expect_err("expected error");

        assert_eq!(result.kind(), ErrorKind::InvalidInput);
        assert_eq!(
            result.to_string(),
            "Timeout must fit into 4 bytes integer value"
        );
    }

    #[test]
    fn test_decode_produce_request_negative_timeout() {
        // transactional_id: None, acks: 1 (ReadUncommitted), timeout_ms: -1, no topics.
        let bytes = vec![
            0x00, 0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
        ];
        let version = DefaultProduceRequest::DEFAULT_API_VERSION;

        let result: Result<DefaultProduceRequest, Error> =
            Decoder::decode_from(&mut Cursor::new(bytes), version);
        let err = result.expect_err("negative timeout must be rejected");

        assert_eq!(err.kind(), ErrorKind::InvalidInput);
        assert_eq!(err.to_string(), "Timeout must be positive integer value");
    }

    #[test]
    fn test_default_produce_request_clone() {
        //given
        let request = DefaultProduceRequest {
            transactional_id: Some("transaction_id".to_string()),
            isolation: Default::default(),
            timeout: Duration::from_millis(100),
            topics: vec![TopicProduceData {
                name: "topic".to_string(),
                partitions: vec![PartitionProduceData {
                    partition_index: 1,
                    records: RecordSet {
                        batches: vec![Batch::from(vec![Record::new(RecordData::from(
                            "some raw data",
                        ))])
                        .try_into()
                        .expect("compressed batch")],
                    },
                }],
                data: Default::default(),
            }],
            data: Default::default(),
            smartmodules: Default::default(),
        };
        let version = DefaultProduceRequest::DEFAULT_API_VERSION;

        //when
        // The clone itself is what this test exercises, so clippy's lint does not apply.
        #[allow(clippy::redundant_clone)]
        let cloned = request.clone();
        let bytes = request.as_bytes(version).expect("encoded request");
        let cloned_bytes = cloned.as_bytes(version).expect("encoded cloned request");

        //then
        assert_eq!(bytes, cloned_bytes);
    }

    #[test]
    fn test_encode_produce_request() {
        //given
        let mut dest: Vec<u8> = Vec::new();
        let value = sample_request(SmartModuleExtraParams::default());

        //when
        value
            .encode(&mut dest, PRODUCER_TRANSFORMATION_API_VERSION)
            .expect("should encode");

        //then
        assert_eq!(dest, SAMPLE_BYTES_TRANSFORMATION_VERSION.to_vec());
    }

    #[test]
    fn test_decode_produce_request() {
        assert_decodes_to_sample(
            &SAMPLE_BYTES_TRANSFORMATION_VERSION,
            PRODUCER_TRANSFORMATION_API_VERSION,
            None,
        );
    }

    #[test]
    fn test_encode_produce_request_last_version() {
        //given
        let mut dest: Vec<u8> = Vec::new();
        let mut params = SmartModuleExtraParams::default();
        params.set_lookback(Some(Lookback::last(1)));
        let value = sample_request(params);

        //when
        value
            .encode(&mut dest, DefaultProduceRequest::MAX_API_VERSION)
            .expect("should encode");

        //then
        assert_eq!(dest, SAMPLE_BYTES_LOOKBACK_LAST.to_vec());
    }

    #[test]
    fn test_encode_produce_request_prev_version() {
        //given
        let mut dest: Vec<u8> = Vec::new();
        let mut params = SmartModuleExtraParams::default();
        params.set_lookback(Some(Lookback::age(Duration::from_secs(20), Some(1))));
        let value = sample_request(params);

        //when
        value
            .encode(&mut dest, DefaultProduceRequest::MAX_API_VERSION - 1)
            .expect("should encode");

        //then
        let expected: Vec<u8> = vec![
            0x01, 0x00, 0x04, 0x74, 0x5f, 0x69, 0x64, 0xff, 0xff, 0x00, 0x00, 0x03, 0xe8, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x04, 0xde, 0xad,
            0xbe, 0xef, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
            0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00,
        ];
        assert_eq!(dest, expected);
    }

    #[test]
    fn test_decode_produce_request_last_version() {
        assert_decodes_to_sample(
            &SAMPLE_BYTES_LOOKBACK_LAST,
            DefaultProduceRequest::MAX_API_VERSION,
            Some(&Lookback::last(1)),
        );
    }

    #[test]
    fn test_decode_produce_request_prev_version() {
        assert_decodes_to_sample(
            &SAMPLE_BYTES_LOOKBACK_LAST,
            DefaultProduceRequest::MAX_API_VERSION - 1,
            Some(&Lookback::last(1)),
        );
    }
}