// brainwires_proxy/convert/mod.rs

//! Body conversion system — registry-based format conversion for request/response bodies.
2
3pub mod detect;
4pub mod json_transform;
5
6use crate::error::{ProxyError, ProxyResult};
7use crate::types::FormatId;
8use bytes::Bytes;
9use std::collections::HashMap;
10
11/// Converts a complete body between two formats atomically.
12#[async_trait::async_trait]
13pub trait Converter: Send + Sync {
14    /// Source format this converter reads.
15    fn source(&self) -> &FormatId;
16    /// Target format this converter produces.
17    fn target(&self) -> &FormatId;
18    /// Convert the body bytes.
19    async fn convert(&self, body: Bytes) -> ProxyResult<Bytes>;
20}
21
22/// Converts streaming data chunk-by-chunk (for SSE, WebSocket, etc.).
23#[async_trait::async_trait]
24pub trait StreamConverter: Send + Sync {
25    fn source(&self) -> &FormatId;
26    fn target(&self) -> &FormatId;
27    /// Process one chunk, returning zero or more output chunks.
28    async fn convert_chunk(&self, chunk: Bytes) -> ProxyResult<Vec<Bytes>>;
29    /// Flush any buffered data at end of stream.
30    async fn flush(&self) -> ProxyResult<Vec<Bytes>>;
31}
32
33/// Detects the format of a body payload.
34pub trait FormatDetector: Send + Sync {
35    /// Inspect bytes and return the detected format, or `None` if unknown.
36    fn detect(&self, body: &[u8], content_type: Option<&str>) -> Option<FormatId>;
37    /// Human-readable detector name.
38    fn name(&self) -> &str;
39}
40
41/// Registry mapping `(source, target)` format pairs to converters.
42pub struct ConversionRegistry {
43    converters: HashMap<(FormatId, FormatId), Box<dyn Converter>>,
44    stream_converters: HashMap<(FormatId, FormatId), Box<dyn StreamConverter>>,
45    detectors: Vec<Box<dyn FormatDetector>>,
46}
47
48impl ConversionRegistry {
49    pub fn new() -> Self {
50        Self {
51            converters: HashMap::new(),
52            stream_converters: HashMap::new(),
53            detectors: Vec::new(),
54        }
55    }
56
57    /// Register an atomic converter for a format pair.
58    pub fn register_converter(&mut self, converter: impl Converter + 'static) {
59        let key = (converter.source().clone(), converter.target().clone());
60        self.converters.insert(key, Box::new(converter));
61    }
62
63    /// Register a streaming converter for a format pair.
64    pub fn register_stream_converter(&mut self, converter: impl StreamConverter + 'static) {
65        let key = (converter.source().clone(), converter.target().clone());
66        self.stream_converters.insert(key, Box::new(converter));
67    }
68
69    /// Register a format detector.
70    pub fn register_detector(&mut self, detector: impl FormatDetector + 'static) {
71        self.detectors.push(Box::new(detector));
72    }
73
74    /// Look up an atomic converter.
75    pub fn get_converter(&self, source: &FormatId, target: &FormatId) -> Option<&dyn Converter> {
76        self.converters
77            .get(&(source.clone(), target.clone()))
78            .map(|c| c.as_ref())
79    }
80
81    /// Look up a streaming converter.
82    pub fn get_stream_converter(
83        &self,
84        source: &FormatId,
85        target: &FormatId,
86    ) -> Option<&dyn StreamConverter> {
87        self.stream_converters
88            .get(&(source.clone(), target.clone()))
89            .map(|c| c.as_ref())
90    }
91
92    /// Detect the format of a body payload using registered detectors.
93    pub fn detect_format(&self, body: &[u8], content_type: Option<&str>) -> Option<FormatId> {
94        for detector in &self.detectors {
95            if let Some(fmt) = detector.detect(body, content_type) {
96                return Some(fmt);
97            }
98        }
99        None
100    }
101
102    /// Convert body between formats, auto-detecting source if not provided.
103    pub async fn convert(
104        &self,
105        body: Bytes,
106        source: Option<&FormatId>,
107        target: &FormatId,
108        content_type: Option<&str>,
109    ) -> ProxyResult<Bytes> {
110        let detected;
111        let source = match source {
112            Some(s) => s,
113            None => {
114                detected = self
115                    .detect_format(&body, content_type)
116                    .ok_or(ProxyError::FormatDetectionFailed)?;
117                &detected
118            }
119        };
120
121        let converter = self.get_converter(source, target).ok_or_else(|| {
122            ProxyError::UnsupportedConversion {
123                src: source.to_string(),
124                dst: target.to_string(),
125            }
126        })?;
127
128        converter.convert(body).await
129    }
130}
131
132impl Default for ConversionRegistry {
133    fn default() -> Self {
134        Self::new()
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// A simple test converter that upper-cases body text.
    struct UpperCaseConverter {
        source: FormatId,
        target: FormatId,
    }

    impl UpperCaseConverter {
        fn new() -> Self {
            Self {
                source: FormatId::new("text"),
                target: FormatId::new("upper"),
            }
        }
    }

    #[async_trait::async_trait]
    impl Converter for UpperCaseConverter {
        fn source(&self) -> &FormatId {
            &self.source
        }
        fn target(&self) -> &FormatId {
            &self.target
        }
        async fn convert(&self, body: Bytes) -> ProxyResult<Bytes> {
            let text = String::from_utf8_lossy(&body).to_uppercase();
            Ok(Bytes::from(text))
        }
    }

    /// A detector that identifies the "text" format from a `text/plain`
    /// content type (the body itself is ignored).
    struct TextDetector;

    impl FormatDetector for TextDetector {
        fn detect(&self, _body: &[u8], content_type: Option<&str>) -> Option<FormatId> {
            if content_type?.contains("text/plain") {
                Some(FormatId::new("text"))
            } else {
                None
            }
        }
        fn name(&self) -> &str {
            "text_detector"
        }
    }

    // Purely synchronous: no async runtime needed.
    #[test]
    fn register_and_lookup_converter() {
        let mut registry = ConversionRegistry::new();
        registry.register_converter(UpperCaseConverter::new());

        let source = FormatId::new("text");
        let target = FormatId::new("upper");
        assert!(registry.get_converter(&source, &target).is_some());
        assert!(registry.get_converter(&target, &source).is_none());
    }

    #[tokio::test]
    async fn convert_body() {
        let mut registry = ConversionRegistry::new();
        registry.register_converter(UpperCaseConverter::new());

        let source = FormatId::new("text");
        let target = FormatId::new("upper");
        let result = registry
            .convert(Bytes::from("hello"), Some(&source), &target, None)
            .await
            .unwrap();
        assert_eq!(result.as_ref(), b"HELLO");
    }

    #[tokio::test]
    async fn auto_detect_source_format() {
        let mut registry = ConversionRegistry::new();
        registry.register_converter(UpperCaseConverter::new());
        registry.register_detector(TextDetector);

        let target = FormatId::new("upper");
        let result = registry
            .convert(Bytes::from("world"), None, &target, Some("text/plain"))
            .await
            .unwrap();
        assert_eq!(result.as_ref(), b"WORLD");
    }

    #[test]
    fn detect_format_without_matching_detector_returns_none() {
        let mut registry = ConversionRegistry::new();
        registry.register_detector(TextDetector);

        assert!(registry.detect_format(b"data", None).is_none());
        assert!(registry
            .detect_format(b"data", Some("application/json"))
            .is_none());
    }

    #[tokio::test]
    async fn detection_failure_returns_error() {
        let registry = ConversionRegistry::new(); // no detectors
        let target = FormatId::new("upper");
        let result = registry
            .convert(Bytes::from("data"), None, &target, None)
            .await;
        // Pin the specific variant; `is_err()` alone would accept any failure.
        assert!(
            matches!(result, Err(ProxyError::FormatDetectionFailed)),
            "expected FormatDetectionFailed, got {:?}",
            result
        );
    }

    #[tokio::test]
    async fn unsupported_conversion_returns_error() {
        let mut registry = ConversionRegistry::new();
        registry.register_converter(UpperCaseConverter::new());

        let source = FormatId::new("text");
        let target = FormatId::new("nonexistent");
        let result = registry
            .convert(Bytes::from("data"), Some(&source), &target, None)
            .await;
        match result {
            Err(ProxyError::UnsupportedConversion { src, dst }) => {
                assert_eq!(src, source.to_string());
                assert_eq!(dst, target.to_string());
            }
            other => panic!("expected UnsupportedConversion, got {:?}", other),
        }
    }
}