use crate::common::ErrorHandling;
use async_trait::async_trait;
use base64::Engine;
use drasi_core::{
interface::{
ElementIndex, MiddlewareError, MiddlewareSetupError, SourceMiddleware,
SourceMiddlewareFactory,
},
models::{Element, ElementValue, SourceChange, SourceMiddlewareConfig},
};
use serde::Deserialize;
use serde_json::Value;
use std::sync::Arc;
#[cfg(test)]
mod tests;
/// Encodings that the decoder middleware knows how to reverse.
///
/// Because of `rename_all = "snake_case"`, variants are deserialized from their
/// snake_case names (for example `JsonEscape` is written as `json_escape`).
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum EncodingType {
/// Base64 text, decoded with the base64 crate's `STANDARD` engine.
Base64,
/// URL-safe base64 text, decoded with the base64 crate's `URL_SAFE_NO_PAD` engine.
Base64url,
/// Hexadecimal text, decoded with `hex::decode`.
Hex,
/// URL-encoded text, decoded with `urlencoding::decode`.
Url,
/// Text containing JSON string escape sequences, decoded by parsing it as a
/// JSON string literal.
JsonEscape,
}
/// Lookup table pairing each [`EncodingType`] variant with the lowercase name
/// used for it in diagnostic messages.
///
/// Keep this table in sync with the enum: a variant that is missing here is
/// reported as `"unknown"` by `Decoder::get_encoding_type_name`.
const ENCODING_TYPE_NAMES: &[(&str, EncodingType)] = &[
("base64", EncodingType::Base64),
("base64url", EncodingType::Base64url),
("hex", EncodingType::Hex),
("url", EncodingType::Url),
("json_escape", EncodingType::JsonEscape),
];
/// Configuration for a single decoder middleware instance.
///
/// Deserialized from the middleware's JSON configuration object. Unknown
/// fields are rejected (`deny_unknown_fields`).
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct DecoderConfig {
/// Encoding that the target property's value is expected to be in.
pub encoding_type: EncodingType,
/// Name of the element property holding the encoded string.
pub target_property: String,
/// Name of the property that receives the decoded string. When absent, the
/// decoded value is written back to `target_property`.
pub output_property: Option<String>,
/// When `true`, surrounding double quotes are removed from the encoded value
/// before it is decoded. Defaults to `false`.
#[serde(default)]
pub strip_quotes: bool,
/// Policy applied when a value cannot be decoded: `ErrorHandling::Fail`
/// turns the problem into a middleware error, any other value lets the
/// change pass through unmodified. Defaults to `ErrorHandling::default()`.
#[serde(default)]
pub on_error: ErrorHandling,
/// Maximum accepted length, in bytes, of the encoded string. Defaults to the
/// value returned by `default_max_size`.
#[serde(default = "default_max_size")]
pub max_size_bytes: usize,
}
/// Serde default for `DecoderConfig::max_size_bytes`: one mebibyte (1,048,576 bytes).
fn default_max_size() -> usize {
    const BYTES_PER_MIB: usize = 1 << 20;
    BYTES_PER_MIB
}
/// Source middleware that decodes an encoded string property on inserted and
/// updated elements.
pub struct Decoder {
// Middleware instance name; used as the `[name]` prefix of log and error messages.
name: String,
// Validated settings controlling what is decoded and how failures are handled.
config: DecoderConfig,
}
#[async_trait]
impl SourceMiddleware for Decoder {
    /// Decodes the configured property on inserted and updated elements.
    ///
    /// `Delete` and `Future` changes are forwarded untouched. The result always
    /// contains exactly one change on success.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `decode_property`.
    async fn process(
        &self,
        source_change: SourceChange,
        _element_index: &dyn ElementIndex,
    ) -> Result<Vec<SourceChange>, MiddlewareError> {
        let forwarded = match source_change {
            SourceChange::Insert { mut element } => {
                self.decode_property(&mut element)?;
                SourceChange::Insert { element }
            }
            SourceChange::Update { mut element } => {
                self.decode_property(&mut element)?;
                SourceChange::Update { element }
            }
            // Nothing to decode for these kinds of change.
            untouched @ (SourceChange::Delete { .. } | SourceChange::Future { .. }) => untouched,
        };

        Ok(vec![forwarded])
    }
}
impl Decoder {
fn get_element_value_type_name(value: &ElementValue) -> &'static str {
match value {
ElementValue::Null => "Null",
ElementValue::Bool(_) => "Bool",
ElementValue::Float(_) => "Float",
ElementValue::Integer(_) => "Integer",
ElementValue::String(_) => "String",
ElementValue::List(_) => "List",
ElementValue::Object(_) => "Object",
}
}
fn get_encoding_type_name(&self) -> &'static str {
for (name, encoding_type) in ENCODING_TYPE_NAMES {
if *encoding_type == self.config.encoding_type {
return name;
}
}
"unknown"
}
fn decode_property(&self, element: &mut Element) -> Result<(), MiddlewareError> {
let target_prop_name = &self.config.target_property;
let output_prop_name = self
.config
.output_property
.as_deref()
.unwrap_or(target_prop_name);
match element.get_properties().get(target_prop_name) {
Some(ElementValue::String(s)) => {
let encoded_str = s.to_string();
if encoded_str.len() > self.config.max_size_bytes {
let msg = format!(
"[{}] Encoded string in property '{}' exceeds size limit ({} > {})",
self.name,
target_prop_name,
encoded_str.len(),
self.config.max_size_bytes
);
log::warn!("{msg}");
return if self.config.on_error == ErrorHandling::Fail {
Err(MiddlewareError::SourceChangeError(msg))
} else {
Ok(())
};
}
let processed_str = if self.config.strip_quotes {
encoded_str.trim_matches('"')
} else {
&encoded_str
}
.to_string();
let decoded_result = match self.config.encoding_type {
EncodingType::Base64 => self.decode_base64(&processed_str),
EncodingType::Base64url => self.decode_base64url(&processed_str),
EncodingType::Hex => self.decode_hex(&processed_str),
EncodingType::Url => self.decode_url(&processed_str),
EncodingType::JsonEscape => self.decode_json_escape(&processed_str),
};
match decoded_result {
Ok(decoded_string) => {
match element {
Element::Node { properties, .. }
| Element::Relation { properties, .. } => {
if output_prop_name != target_prop_name
&& properties.get(output_prop_name).is_some()
{
log::warn!(
"[{}] Output property '{}' specified in config already exists and will be overwritten.",
self.name,
output_prop_name
);
}
properties.insert(
output_prop_name,
ElementValue::String(decoded_string.into()),
);
}
}
Ok(())
}
Err(e) => {
let encoding_name = self.get_encoding_type_name();
let msg = format!(
"[{}] Failed to decode property '{}' using {} encoding: {}",
self.name, target_prop_name, encoding_name, e
);
log::warn!("{msg}");
if self.config.on_error == ErrorHandling::Fail {
Err(MiddlewareError::SourceChangeError(msg))
} else {
Ok(())
}
}
}
}
Some(value) => {
let type_name = Self::get_element_value_type_name(value);
let msg = format!(
"[{}] Target property '{}' is not a string value (Type: {}).",
self.name, target_prop_name, type_name
);
log::warn!("{msg}");
if self.config.on_error == ErrorHandling::Fail {
Err(MiddlewareError::SourceChangeError(msg))
} else {
Ok(())
}
}
None => {
let msg = format!(
"[{}] Target property '{}' not found in element.",
self.name, target_prop_name
);
log::warn!("{msg}");
if self.config.on_error == ErrorHandling::Fail {
Err(MiddlewareError::SourceChangeError(msg))
} else {
Ok(())
}
}
}
}
pub fn decode_base64(&self, encoded: &str) -> Result<String, String> {
base64::engine::general_purpose::STANDARD
.decode(encoded.as_bytes())
.map_err(|e| format!("Invalid base64 encoding: {e}"))
.and_then(|bytes| {
String::from_utf8(bytes)
.map_err(|e| format!("Decoded bytes are not valid UTF-8: {e}"))
})
}
pub fn decode_base64url(&self, encoded: &str) -> Result<String, String> {
base64::engine::general_purpose::URL_SAFE_NO_PAD
.decode(encoded.as_bytes())
.map_err(|e| format!("Invalid base64url encoding: {e}"))
.and_then(|bytes| {
String::from_utf8(bytes)
.map_err(|e| format!("Decoded bytes are not valid UTF-8: {e}"))
})
}
pub fn decode_hex(&self, encoded: &str) -> Result<String, String> {
hex::decode(encoded)
.map_err(|e| format!("Invalid hex encoding: {e}"))
.and_then(|bytes| {
String::from_utf8(bytes)
.map_err(|e| format!("Decoded bytes are not valid UTF-8: {e}"))
})
}
pub fn decode_url(&self, encoded: &str) -> Result<String, String> {
urlencoding::decode(encoded)
.map(|cow| cow.into_owned())
.map_err(|e| format!("Invalid URL encoding: {e}"))
}
pub fn decode_json_escape(&self, encoded: &str) -> Result<String, String> {
let json_value_str = format!("\"{encoded}\"");
serde_json::from_str::<String>(&json_value_str)
.map_err(|e| format!("Invalid JSON escape sequence: {e}"))
}
}
/// Stateless factory that builds decoder middleware instances.
pub struct DecoderFactory {}

impl DecoderFactory {
    /// Creates a new factory. The factory holds no state.
    pub fn new() -> Self {
        Self {}
    }
}

impl Default for DecoderFactory {
    /// Equivalent to [`DecoderFactory::new`].
    fn default() -> Self {
        DecoderFactory::new()
    }
}
impl SourceMiddlewareFactory for DecoderFactory {
    /// Returns the middleware kind name, `"decoder"`.
    fn name(&self) -> String {
        String::from("decoder")
    }

    /// Builds a [`Decoder`] from the JSON configuration carried by `config`.
    ///
    /// # Errors
    ///
    /// Returns `MiddlewareSetupError::InvalidConfiguration` when the
    /// configuration cannot be deserialized into [`DecoderConfig`], when
    /// `target_property` is empty, when `output_property` is present but
    /// empty, or when `max_size_bytes` is zero.
    fn create(
        &self,
        config: &SourceMiddlewareConfig,
    ) -> Result<Arc<dyn SourceMiddleware>, MiddlewareSetupError> {
        let raw_settings = Value::Object(config.config.clone());
        let settings = serde_json::from_value::<DecoderConfig>(raw_settings).map_err(|e| {
            MiddlewareSetupError::InvalidConfiguration(format!(
                "[{}] Invalid decoder configuration: {}",
                config.name, e
            ))
        })?;

        // Validation runs in a fixed order so the first problem found is the one reported.
        if settings.target_property.is_empty() {
            return Err(MiddlewareSetupError::InvalidConfiguration(format!(
                "[{}] Missing or empty 'target_property' field in decoder configuration",
                config.name
            )));
        }

        if settings.output_property.as_deref() == Some("") {
            return Err(MiddlewareSetupError::InvalidConfiguration(format!(
                "[{}] 'output_property' cannot be empty if provided",
                config.name
            )));
        }

        if settings.max_size_bytes == 0 {
            return Err(MiddlewareSetupError::InvalidConfiguration(format!(
                "[{}] 'max_size_bytes' must be greater than zero",
                config.name
            )));
        }

        log::info!(
            "[{}] Creating Decoder middleware with config: {:?}",
            config.name,
            settings
        );

        let middleware = Decoder {
            name: config.name.to_string(),
            config: settings,
        };
        Ok(Arc::new(middleware))
    }
}