use std::io::Read;
use std::sync::Arc;
use serde::{Serialize, de::DeserializeOwned};
use super::FormatError;
/// Shared, thread-safe callback that decodes a complete byte buffer into a JSON value.
pub type DeserializeFn =
    Arc<dyn Fn(&[u8]) -> Result<serde_json::Value, FormatError> + Send + Sync>;

/// Shared, thread-safe callback that encodes a JSON value into a byte buffer.
pub type SerializeFn =
    Arc<dyn Fn(&serde_json::Value) -> Result<Vec<u8>, FormatError> + Send + Sync>;

/// Shared, thread-safe callback that turns a boxed reader into an iterator of decoded
/// JSON values, where each item can fail independently.
///
/// Note: the callback itself is `Send + Sync`, but neither the `Box<dyn Read>` it
/// receives nor the iterator it returns is required to be `Send`.
pub type StreamDeserializeFn = Arc<
    dyn Fn(Box<dyn Read>) -> Box<dyn Iterator<Item = Result<serde_json::Value, FormatError>>>
        + Send
        + Sync,
>;
/// A user-registered data format identified by a name and a set of file extensions.
///
/// Each capability (whole-buffer decode, encode, streaming decode) is optional; a
/// `None` callback means the format does not support that operation. Cloning is
/// cheap: the callbacks are reference-counted and the metadata is `'static`.
#[derive(Clone)]
pub struct CustomFormat {
    /// Human-readable format name, also used in "unsupported operation" errors.
    pub name: &'static str,
    /// File extensions this format claims, compared case-insensitively (ASCII).
    pub extensions: &'static [&'static str],
    /// Optional decoder for a complete byte buffer.
    pub deserialize_fn: Option<DeserializeFn>,
    /// Optional encoder producing a byte buffer.
    pub serialize_fn: Option<SerializeFn>,
    /// Optional decoder that yields values one at a time from a reader.
    pub stream_deserialize_fn: Option<StreamDeserializeFn>,
}
// Hand-written because the callback fields are opaque closures that do not implement
// `Debug`; only whether each one is present is reported.
impl std::fmt::Debug for CustomFormat {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CustomFormat")
            .field("name", &self.name)
            .field("extensions", &self.extensions)
            .field("has_deserialize", &self.deserialize_fn.is_some())
            .field("has_serialize", &self.serialize_fn.is_some())
            // Previously omitted, so streaming support was invisible in debug output.
            .field("has_stream_deserialize", &self.stream_deserialize_fn.is_some())
            .finish()
    }
}
impl CustomFormat {
    /// Creates a format named `name` that claims the given file `extensions`.
    ///
    /// No callbacks are registered yet. Attach them with
    /// [`CustomFormat::with_deserialize`], [`CustomFormat::with_serialize`] and
    /// [`CustomFormat::with_stream_deserialize`].
    #[must_use]
    pub fn new(name: &'static str, extensions: &'static [&'static str]) -> Self {
        Self {
            name,
            extensions,
            deserialize_fn: None,
            serialize_fn: None,
            stream_deserialize_fn: None,
        }
    }

    /// Registers the whole-buffer decoder and returns the updated format.
    ///
    /// Any previously registered decoder is replaced.
    #[must_use]
    pub fn with_deserialize<F>(mut self, f: F) -> Self
    where
        F: Fn(&[u8]) -> Result<serde_json::Value, FormatError> + Send + Sync + 'static,
    {
        self.deserialize_fn = Some(Arc::new(f));
        self
    }

    /// Registers the encoder and returns the updated format.
    ///
    /// Any previously registered encoder is replaced.
    #[must_use]
    pub fn with_serialize<F>(mut self, f: F) -> Self
    where
        F: Fn(&serde_json::Value) -> Result<Vec<u8>, FormatError> + Send + Sync + 'static,
    {
        self.serialize_fn = Some(Arc::new(f));
        self
    }

    /// Registers the streaming decoder and returns the updated format.
    ///
    /// Any previously registered streaming decoder is replaced.
    #[must_use]
    pub fn with_stream_deserialize<F>(mut self, f: F) -> Self
    where
        F: Fn(Box<dyn Read>) -> Box<dyn Iterator<Item = Result<serde_json::Value, FormatError>>>
            + Send
            + Sync
            + 'static,
    {
        self.stream_deserialize_fn = Some(Arc::new(f));
        self
    }

    /// Hands `reader` to the registered streaming decoder and returns its iterator.
    ///
    /// # Errors
    ///
    /// Returns `FormatError::Other` (wrapping an `io::Error` of kind `Unsupported`) if no
    /// streaming decoder is registered. Decoding failures of individual values are
    /// reported as `Err` items by the returned iterator, not by this method.
    pub fn stream_deserialize_values(
        &self,
        reader: Box<dyn Read>,
    ) -> Result<Box<dyn Iterator<Item = Result<serde_json::Value, FormatError>>>, FormatError> {
        let stream = self
            .stream_deserialize_fn
            .as_ref()
            .ok_or_else(|| self.unsupported("streaming deserialization"))?;
        Ok(stream(reader))
    }

    /// Decodes `bytes` with the registered decoder, then converts the resulting JSON
    /// value into `T`.
    ///
    /// # Errors
    ///
    /// - `FormatError::Other` (an `io::Error` of kind `Unsupported`) if no decoder is
    ///   registered.
    /// - Any error returned by the decoder callback, unchanged.
    /// - `FormatError::Serde` if the decoded JSON value cannot be converted into `T`.
    pub fn deserialize<T: DeserializeOwned>(&self, bytes: &[u8]) -> Result<T, FormatError> {
        let deserialize_fn = self
            .deserialize_fn
            .as_ref()
            .ok_or_else(|| self.unsupported("deserialization"))?;
        let value = deserialize_fn(bytes)?;
        serde_json::from_value(value).map_err(|e| FormatError::Serde(Box::new(e)))
    }

    /// Converts `value` into a JSON value, then encodes it with the registered encoder.
    ///
    /// # Errors
    ///
    /// - `FormatError::Other` (an `io::Error` of kind `Unsupported`) if no encoder is
    ///   registered. This is checked before `value` is converted.
    /// - `FormatError::Serde` if `value` cannot be converted into a JSON value.
    /// - Any error returned by the encoder callback, unchanged.
    pub fn serialize<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, FormatError> {
        let serialize_fn = self
            .serialize_fn
            .as_ref()
            .ok_or_else(|| self.unsupported("serialization"))?;
        let json_value =
            serde_json::to_value(value).map_err(|e| FormatError::Serde(Box::new(e)))?;
        serialize_fn(&json_value)
    }

    /// Returns `true` if `ext` equals one of this format's extensions, ignoring ASCII case.
    ///
    /// `ext` is compared as given. A leading `.` is not stripped, so it must match the
    /// form used in `extensions`.
    #[must_use]
    pub fn matches_extension(&self, ext: &str) -> bool {
        // `eq_ignore_ascii_case` already folds case on both sides, so no lowercase copy
        // of `ext` needs to be allocated.
        self.extensions
            .iter()
            .any(|known| known.eq_ignore_ascii_case(ext))
    }

    /// Builds the error returned when `operation` has no registered callback.
    fn unsupported(&self, operation: &str) -> FormatError {
        FormatError::Other(Box::new(std::io::Error::new(
            std::io::ErrorKind::Unsupported,
            format!(
                "Custom format '{}' does not support {}",
                self.name, operation
            ),
        )))
    }
}