fluvio_spu_schema/server/smartmodule.rs

1#![allow(deprecated)]
2
3use std::fmt::{Debug, self};
4use std::io;
5use std::io::Error as IoError;
6use std::io::Read;
7
8use bytes::BufMut;
9use flate2::{
10    Compression,
11    bufread::{GzEncoder, GzDecoder},
12};
13
14use fluvio_protocol::{Encoder, Decoder, Version};
15use fluvio_smartmodule::dataplane::smartmodule::SmartModuleExtraParams;
16
17// The fluvio COMMON_VERSION in fluvio-spu-schema/src/lib.rs
18// that introduced the smartmodule name to SmartModuleInvocations
19pub const COMMON_VERSION_HAS_SM_NAME: Version = 25;
20
21/// The request payload when using a Consumer SmartModule.
22///
23/// This includes the WASM module name as well as the invocation being used.
24/// It also carries any data that is required for specific invocations of SmartModules.
25#[derive(Debug, Default, Clone)]
26pub struct SmartModuleInvocation {
27    pub wasm: SmartModuleInvocationWasm,
28    pub kind: SmartModuleKind,
29    pub params: SmartModuleExtraParams,
30    // only included in PROD_API_HAS_SM_NAME, or later
31    // if decoding a version before this, None will be filled in
32    pub name: Option<String>, // option for backward compatibility
33}
34
35impl Decoder for SmartModuleInvocation {
36    fn decode<T>(&mut self, src: &mut T, version: Version) -> Result<(), IoError>
37    where
38        T: bytes::Buf,
39    {
40        self.wasm.decode(src, version)?;
41        self.kind.decode(src, version)?;
42        self.params.decode(src, version)?;
43        if version < COMMON_VERSION_HAS_SM_NAME {
44            self.name = None;
45        } else {
46            self.name.decode(src, version)?;
47        }
48        Ok(())
49    }
50}
51
52impl Encoder for SmartModuleInvocation {
53    fn write_size(&self, version: Version) -> usize {
54        let mut size = self.wasm.write_size(version);
55        size += self.kind.write_size(version);
56        size += self.params.write_size(version);
57        if version >= COMMON_VERSION_HAS_SM_NAME {
58            size += self.name.write_size(version);
59        }
60        size
61    }
62
63    fn encode<T>(&self, dest: &mut T, version: Version) -> Result<(), IoError>
64    where
65        T: BufMut,
66    {
67        self.wasm.encode(dest, version)?;
68        self.kind.encode(dest, version)?;
69        self.params.encode(dest, version)?;
70        if version >= COMMON_VERSION_HAS_SM_NAME {
71            self.name.encode(dest, version)?;
72        }
73        Ok(())
74    }
75}
76
/// Identifies which WASM module a [`SmartModuleInvocation`] runs.
///
/// The `#[fluvio(tag = N)]` values are the variant discriminants used by the
/// `Encoder`/`Decoder` derives; changing them changes the wire format.
/// `Debug` is implemented manually (not derived) so ad-hoc payloads are
/// summarized by size instead of dumped.
#[derive(Clone, Encoder, Decoder)]
pub enum SmartModuleInvocationWasm {
    /// Name of a predefined SmartModule; carries no WASM bytes.
    #[fluvio(tag = 0)]
    Predefined(String),
    /// Gzip-compressed WASM module payload (built by `adhoc_from_bytes`).
    #[fluvio(tag = 1)]
    AdHoc(Vec<u8>),
}
86
87impl SmartModuleInvocationWasm {
88    pub fn adhoc_from_bytes(bytes: &[u8]) -> io::Result<Self> {
89        Ok(Self::AdHoc(zip(bytes)?))
90    }
91
92    /// consume and get the raw bytes of the WASM module
93    pub fn into_raw(self) -> io::Result<Vec<u8>> {
94        match self {
95            Self::AdHoc(gzipped) => Ok(unzip(gzipped.as_ref())?),
96            _ => Err(io::Error::new(
97                io::ErrorKind::InvalidData,
98                "unable to represent as raw data",
99            )),
100        }
101    }
102}
103
104impl Default for SmartModuleInvocationWasm {
105    fn default() -> Self {
106        Self::AdHoc(Vec::new())
107    }
108}
109
110impl Debug for SmartModuleInvocationWasm {
111    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
112        match self {
113            Self::Predefined(module) => write!(f, "Predefined{module}"),
114            Self::AdHoc(bytes) => f
115                .debug_tuple("Adhoc")
116                .field(&format!("{} bytes", bytes.len()))
117                .finish(),
118        }
119    }
120}
121
/// Indicates the type of SmartModule as well as any special data required.
///
/// Each variant's `#[fluvio(tag = N)]` is the discriminant encoded on the wire
/// (e.g. `Filter` encodes as `0x00`, and `0x01` decodes as `Map`); do not renumber.
/// `min_version` / `max_version` presumably restrict the protocol versions in
/// which a variant may be encoded/decoded — confirm against the
/// `fluvio_protocol` derive documentation.
#[derive(Debug, Clone, Encoder, Decoder, Default)]
pub enum SmartModuleKind {
    /// Filter SmartModule; this is the `Default` kind.
    #[default]
    #[fluvio(tag = 0)]
    Filter,
    #[fluvio(tag = 1)]
    Map,
    #[fluvio(tag = 2)]
    #[fluvio(min_version = ARRAY_MAP_WASM_API)]
    ArrayMap,
    /// NOTE(review): `accumulator` presumably holds the initial aggregate
    /// state supplied by the client — confirm.
    #[fluvio(tag = 3)]
    Aggregate { accumulator: Vec<u8> },
    #[fluvio(tag = 4)]
    #[fluvio(min_version = ARRAY_MAP_WASM_API)]
    FilterMap,
    #[fluvio(tag = 5)]
    #[fluvio(min_version = SMART_MODULE_API, max_version = CHAIN_SMARTMODULE_API)]
    Join(String),
    #[fluvio(tag = 6)]
    #[fluvio(min_version = SMART_MODULE_API, max_version = CHAIN_SMARTMODULE_API)]
    JoinStream {
        topic: String,
        derivedstream: String,
    },
    /// Kind-agnostic SmartModule; any kind-specific data travels in
    /// [`SmartModuleContextData`]. Displays as `smartmodule`.
    #[fluvio(tag = 7)]
    #[fluvio(min_version = GENERIC_SMARTMODULE_API)]
    Generic(SmartModuleContextData),
}
151
152impl std::fmt::Display for SmartModuleKind {
153    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
154        let name = match self {
155            SmartModuleKind::Filter => "filter",
156            SmartModuleKind::Map => "map",
157            SmartModuleKind::ArrayMap => "array_map",
158            SmartModuleKind::Aggregate { .. } => "aggregate",
159            SmartModuleKind::FilterMap => "filter_map",
160            SmartModuleKind::Join(..) => "join",
161            SmartModuleKind::JoinStream { .. } => "join_stream",
162            SmartModuleKind::Generic(..) => "smartmodule",
163        };
164        out.write_str(name)
165    }
166}
167
/// Kind-specific data carried by [`SmartModuleKind::Generic`].
///
/// The data-carrying variants mirror the shapes of the legacy
/// `SmartModuleKind::{Aggregate, Join, JoinStream}` variants.
/// `#[fluvio(tag = N)]` is each variant's wire discriminant; do not renumber.
#[derive(Debug, Clone, Encoder, Decoder, Default)]
pub enum SmartModuleContextData {
    /// No extra data; this is the `Default`.
    #[default]
    #[fluvio(tag = 0)]
    None,
    #[fluvio(tag = 1)]
    Aggregate { accumulator: Vec<u8> },
    #[fluvio(tag = 2)]
    Join(String),
    #[fluvio(tag = 3)]
    JoinStream {
        topic: String,
        derivedstream: String,
    },
}
183
184fn zip(raw: &[u8]) -> io::Result<Vec<u8>> {
185    let mut encoder = GzEncoder::new(raw, Compression::default());
186    let mut buffer = Vec::with_capacity(raw.len());
187    encoder.read_to_end(&mut buffer)?;
188    Ok(buffer)
189}
190
191fn unzip(compressed: &[u8]) -> io::Result<Vec<u8>> {
192    let mut decoder = GzDecoder::new(compressed);
193    let mut buffer = Vec::with_capacity(compressed.len());
194    decoder.read_to_end(&mut buffer)?;
195    Ok(buffer)
196}
197
#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn test_encode_smartmodulekind() {
        let mut dest = Vec::new();
        let value: SmartModuleKind = SmartModuleKind::Filter;
        value.encode(&mut dest, 0).expect("should encode");
        assert_eq!(dest.len(), 1);
        assert_eq!(dest[0], 0x00);
    }

    #[test]
    fn test_decode_smartmodulekind() {
        let bytes = vec![0x01];
        let mut value: SmartModuleKind = Default::default();
        value
            .decode(&mut io::Cursor::new(bytes), 0)
            .expect("should decode");
        assert!(matches!(value, SmartModuleKind::Map));
    }

    #[test]
    fn test_gzip_smartmoduleinvocationwasm() {
        let bytes = vec![0xde, 0xad, 0xbe, 0xef];
        let value: SmartModuleInvocationWasm =
            SmartModuleInvocationWasm::adhoc_from_bytes(&bytes).expect("should encode");
        if let SmartModuleInvocationWasm::AdHoc(compressed_bytes) = value {
            let decompressed_bytes = unzip(&compressed_bytes).expect("should decompress");
            assert_eq!(decompressed_bytes, bytes);
        } else {
            panic!("not adhoc")
        }
    }

    /// `into_raw` must return exactly the bytes given to `adhoc_from_bytes`.
    #[test]
    fn test_into_raw_roundtrip() {
        let raw = vec![0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00];
        let wasm = SmartModuleInvocationWasm::adhoc_from_bytes(&raw).expect("should compress");
        assert_eq!(wasm.into_raw().expect("should decompress"), raw);
    }

    /// A predefined module has no payload, so `into_raw` reports `InvalidData`.
    #[test]
    fn test_into_raw_predefined_is_invalid_data() {
        let wasm = SmartModuleInvocationWasm::Predefined("my-module".to_owned());
        let err = wasm.into_raw().expect_err("predefined has no raw bytes");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    /// `name` is only on the wire from `COMMON_VERSION_HAS_SM_NAME` onward.
    #[test]
    fn test_invocation_name_gated_by_version() {
        let invocation = SmartModuleInvocation {
            wasm: SmartModuleInvocationWasm::Predefined("my-module".to_owned()),
            name: Some("my-module".to_owned()),
            ..Default::default()
        };

        // At the gating version the name is written and read back.
        let new_version = COMMON_VERSION_HAS_SM_NAME;
        let mut new_bytes = Vec::new();
        invocation
            .encode(&mut new_bytes, new_version)
            .expect("should encode");
        let new_len = new_bytes.len();
        assert_eq!(new_len, invocation.write_size(new_version));
        let mut decoded = SmartModuleInvocation::default();
        decoded
            .decode(&mut io::Cursor::new(new_bytes), new_version)
            .expect("should decode");
        assert_eq!(decoded.name.as_deref(), Some("my-module"));
        assert!(matches!(
            decoded.wasm,
            SmartModuleInvocationWasm::Predefined(ref n) if n == "my-module"
        ));

        // Before it, the name is neither written nor read, and decoding clears a stale value.
        let old_version = COMMON_VERSION_HAS_SM_NAME - 1;
        let mut old_bytes = Vec::new();
        invocation
            .encode(&mut old_bytes, old_version)
            .expect("should encode");
        assert_eq!(old_bytes.len(), invocation.write_size(old_version));
        assert!(old_bytes.len() < new_len);
        let mut decoded = SmartModuleInvocation {
            name: Some("stale".to_owned()),
            ..Default::default()
        };
        decoded
            .decode(&mut io::Cursor::new(old_bytes), old_version)
            .expect("should decode");
        assert!(decoded.name.is_none());
    }
}