Skip to main content

drasi_middleware/decoder/
mod.rs

1// Copyright 2024 The Drasi Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::common::ErrorHandling;
16use async_trait::async_trait;
17use base64::Engine;
18use drasi_core::{
19    interface::{
20        ElementIndex, MiddlewareError, MiddlewareSetupError, SourceMiddleware,
21        SourceMiddlewareFactory,
22    },
23    models::{Element, ElementValue, SourceChange, SourceMiddlewareConfig},
24};
25use serde::Deserialize;
26use serde_json::Value;
27use std::sync::Arc;
28
29#[cfg(test)]
30mod tests;
31
32/// Specifies the encoding type of the target property's string value.
33#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
34#[serde(rename_all = "snake_case")]
35pub enum EncodingType {
36    Base64,
37    Base64url,
38    Hex,
39    Url,
40    JsonEscape,
41}
42
43/// Mapping of encoding types to their string representations for error messages
44const ENCODING_TYPE_NAMES: &[(&str, EncodingType)] = &[
45    ("base64", EncodingType::Base64),
46    ("base64url", EncodingType::Base64url),
47    ("hex", EncodingType::Hex),
48    ("url", EncodingType::Url),
49    ("json_escape", EncodingType::JsonEscape),
50];
51
52/// Configuration for the Decoder middleware.
53#[derive(Debug, Clone, Deserialize)]
54#[serde(deny_unknown_fields)]
55pub struct DecoderConfig {
56    /// The encoding type to decode from.
57    pub encoding_type: EncodingType,
58    /// The property containing the encoded string.
59    pub target_property: String,
60    /// Optional property name to store the decoded string.
61    /// If None, `target_property` is overwritten.
62    pub output_property: Option<String>,
63    /// If true, surrounding double quotes (`"`) are removed from the target string before decoding.
64    #[serde(default)]
65    pub strip_quotes: bool,
66    /// How to handle errors during decoding or if the target property is missing/invalid.
67    #[serde(default)]
68    pub on_error: ErrorHandling,
69    /// Maximum allowed size for encoded strings to prevent excessive memory usage.
70    #[serde(default = "default_max_size")]
71    pub max_size_bytes: usize,
72}
73
/// Serde default for `DecoderConfig::max_size_bytes`: one mebibyte.
fn default_max_size() -> usize {
    const ONE_MIB: usize = 1 << 20;
    ONE_MIB
}
77
78/// Middleware that decodes a string property using a specified encoding.
79pub struct Decoder {
80    name: String,
81    config: DecoderConfig,
82}
83
84#[async_trait]
85impl SourceMiddleware for Decoder {
86    async fn process(
87        &self,
88        source_change: SourceChange,
89        _element_index: &dyn ElementIndex,
90    ) -> Result<Vec<SourceChange>, MiddlewareError> {
91        match source_change {
92            SourceChange::Insert { mut element } => match self.decode_property(&mut element) {
93                Ok(_) => Ok(vec![SourceChange::Insert { element }]),
94                Err(e) => Err(e),
95            },
96            SourceChange::Update { mut element } => match self.decode_property(&mut element) {
97                Ok(_) => Ok(vec![SourceChange::Update { element }]),
98                Err(e) => Err(e),
99            },
100            // Pass through other change types
101            SourceChange::Delete { .. } | SourceChange::Future { .. } => Ok(vec![source_change]),
102        }
103    }
104}
105
106impl Decoder {
107    /// Helper method to get the type name of an ElementValue
108    fn get_element_value_type_name(value: &ElementValue) -> &'static str {
109        match value {
110            ElementValue::Null => "Null",
111            ElementValue::Bool(_) => "Bool",
112            ElementValue::Float(_) => "Float",
113            ElementValue::Integer(_) => "Integer",
114            ElementValue::String(_) => "String",
115            ElementValue::List(_) => "List",
116            ElementValue::Object(_) => "Object",
117        }
118    }
119
120    /// Helper method to get the encoding type name for error messages
121    fn get_encoding_type_name(&self) -> &'static str {
122        for (name, encoding_type) in ENCODING_TYPE_NAMES {
123            if *encoding_type == self.config.encoding_type {
124                return name;
125            }
126        }
127        "unknown"
128    }
129
130    /// Decodes the target property within the given element.
131    fn decode_property(&self, element: &mut Element) -> Result<(), MiddlewareError> {
132        let target_prop_name = &self.config.target_property;
133        let output_prop_name = self
134            .config
135            .output_property
136            .as_deref()
137            .unwrap_or(target_prop_name);
138
139        // Get the property value without cloning
140        match element.get_properties().get(target_prop_name) {
141            Some(ElementValue::String(s)) => {
142                let encoded_str = s.to_string();
143
144                // Check size limit
145                if encoded_str.len() > self.config.max_size_bytes {
146                    let msg = format!(
147                        "[{}] Encoded string in property '{}' exceeds size limit ({} > {})",
148                        self.name,
149                        target_prop_name,
150                        encoded_str.len(),
151                        self.config.max_size_bytes
152                    );
153                    log::warn!("{msg}");
154                    return if self.config.on_error == ErrorHandling::Fail {
155                        Err(MiddlewareError::SourceChangeError(msg))
156                    } else {
157                        Ok(())
158                    };
159                }
160
161                // Step 1: Strip quotes if requested
162                let processed_str = if self.config.strip_quotes {
163                    encoded_str.trim_matches('"')
164                } else {
165                    &encoded_str
166                }
167                .to_string();
168
169                // Step 2: Decode the string based on encoding type
170                let decoded_result = match self.config.encoding_type {
171                    EncodingType::Base64 => self.decode_base64(&processed_str),
172                    EncodingType::Base64url => self.decode_base64url(&processed_str),
173                    EncodingType::Hex => self.decode_hex(&processed_str),
174                    EncodingType::Url => self.decode_url(&processed_str),
175                    EncodingType::JsonEscape => self.decode_json_escape(&processed_str),
176                };
177
178                match decoded_result {
179                    Ok(decoded_string) => {
180                        // Get mutable access to properties to check for collision and insert
181                        match element {
182                            Element::Node { properties, .. }
183                            | Element::Relation { properties, .. } => {
184                                // Check for potential overwrite collision
185                                if output_prop_name != target_prop_name
186                                    && properties.get(output_prop_name).is_some()
187                                {
188                                    log::warn!(
189                                        "[{}] Output property '{}' specified in config already exists and will be overwritten.",
190                                        self.name,
191                                        output_prop_name
192                                    );
193                                }
194
195                                // Update the element with the decoded string
196                                properties.insert(
197                                    output_prop_name,
198                                    ElementValue::String(decoded_string.into()),
199                                );
200                            }
201                        }
202                        Ok(())
203                    }
204                    Err(e) => {
205                        let encoding_name = self.get_encoding_type_name();
206                        let msg = format!(
207                            "[{}] Failed to decode property '{}' using {} encoding: {}",
208                            self.name, target_prop_name, encoding_name, e
209                        );
210                        log::warn!("{msg}");
211                        if self.config.on_error == ErrorHandling::Fail {
212                            Err(MiddlewareError::SourceChangeError(msg))
213                        } else {
214                            Ok(())
215                        }
216                    }
217                }
218            }
219            Some(value) => {
220                // Handle non-string property types
221                let type_name = Self::get_element_value_type_name(value);
222                let msg = format!(
223                    "[{}] Target property '{}' is not a string value (Type: {}).",
224                    self.name, target_prop_name, type_name
225                );
226                log::warn!("{msg}");
227                if self.config.on_error == ErrorHandling::Fail {
228                    Err(MiddlewareError::SourceChangeError(msg))
229                } else {
230                    Ok(())
231                }
232            }
233            None => {
234                // Handle missing property
235                let msg = format!(
236                    "[{}] Target property '{}' not found in element.",
237                    self.name, target_prop_name
238                );
239                log::warn!("{msg}");
240                if self.config.on_error == ErrorHandling::Fail {
241                    Err(MiddlewareError::SourceChangeError(msg))
242                } else {
243                    Ok(())
244                }
245            }
246        }
247    }
248
249    /// Decodes a base64 encoded string.
250    pub fn decode_base64(&self, encoded: &str) -> Result<String, String> {
251        base64::engine::general_purpose::STANDARD
252            .decode(encoded.as_bytes())
253            .map_err(|e| format!("Invalid base64 encoding: {e}"))
254            .and_then(|bytes| {
255                String::from_utf8(bytes)
256                    .map_err(|e| format!("Decoded bytes are not valid UTF-8: {e}"))
257            })
258    }
259
260    /// Decodes a base64url encoded string.
261    pub fn decode_base64url(&self, encoded: &str) -> Result<String, String> {
262        base64::engine::general_purpose::URL_SAFE_NO_PAD
263            .decode(encoded.as_bytes())
264            .map_err(|e| format!("Invalid base64url encoding: {e}"))
265            .and_then(|bytes| {
266                String::from_utf8(bytes)
267                    .map_err(|e| format!("Decoded bytes are not valid UTF-8: {e}"))
268            })
269    }
270
271    /// Decodes a hex encoded string.
272    pub fn decode_hex(&self, encoded: &str) -> Result<String, String> {
273        hex::decode(encoded)
274            .map_err(|e| format!("Invalid hex encoding: {e}"))
275            .and_then(|bytes| {
276                String::from_utf8(bytes)
277                    .map_err(|e| format!("Decoded bytes are not valid UTF-8: {e}"))
278            })
279    }
280
281    /// Decodes a URL encoded string.
282    pub fn decode_url(&self, encoded: &str) -> Result<String, String> {
283        urlencoding::decode(encoded)
284            .map(|cow| cow.into_owned())
285            .map_err(|e| format!("Invalid URL encoding: {e}"))
286    }
287
288    /// Decodes a JSON escaped string.
289    pub fn decode_json_escape(&self, encoded: &str) -> Result<String, String> {
290        let json_value_str = format!("\"{encoded}\"");
291        serde_json::from_str::<String>(&json_value_str)
292            .map_err(|e| format!("Invalid JSON escape sequence: {e}"))
293    }
294}
295
/// Stateless factory that builds [`Decoder`] middleware instances from
/// source middleware configuration.
pub struct DecoderFactory;
298
299impl DecoderFactory {
300    pub fn new() -> Self {
301        DecoderFactory {}
302    }
303}
304
305impl Default for DecoderFactory {
306    fn default() -> Self {
307        Self::new()
308    }
309}
310
311impl SourceMiddlewareFactory for DecoderFactory {
312    fn name(&self) -> String {
313        "decoder".to_string()
314    }
315
316    fn create(
317        &self,
318        config: &SourceMiddlewareConfig,
319    ) -> Result<Arc<dyn SourceMiddleware>, MiddlewareSetupError> {
320        let decoder_config: DecoderConfig =
321            match serde_json::from_value(Value::Object(config.config.clone())) {
322                Ok(cfg) => cfg,
323                Err(e) => {
324                    return Err(MiddlewareSetupError::InvalidConfiguration(format!(
325                        "[{}] Invalid decoder configuration: {}",
326                        config.name, e
327                    )));
328                }
329            };
330
331        // Basic validation
332        if decoder_config.target_property.is_empty() {
333            return Err(MiddlewareSetupError::InvalidConfiguration(format!(
334                "[{}] Missing or empty 'target_property' field in decoder configuration",
335                config.name
336            )));
337        }
338
339        if let Some(output_prop) = &decoder_config.output_property {
340            if output_prop.is_empty() {
341                return Err(MiddlewareSetupError::InvalidConfiguration(format!(
342                    "[{}] 'output_property' cannot be empty if provided",
343                    config.name
344                )));
345            }
346        }
347
348        // Validate max_size_bytes is reasonable
349        if decoder_config.max_size_bytes == 0 {
350            return Err(MiddlewareSetupError::InvalidConfiguration(format!(
351                "[{}] 'max_size_bytes' must be greater than zero",
352                config.name
353            )));
354        }
355
356        log::info!(
357            "[{}] Creating Decoder middleware with config: {:?}",
358            config.name,
359            decoder_config
360        );
361
362        Ok(Arc::new(Decoder {
363            name: config.name.to_string(),
364            config: decoder_config,
365        }))
366    }
367}