brainwires_proxy/convert/
mod.rs1pub mod detect;
4pub mod json_transform;
5
6use crate::error::{ProxyError, ProxyResult};
7use crate::types::FormatId;
8use bytes::Bytes;
9use std::collections::HashMap;
10
11#[async_trait::async_trait]
13pub trait Converter: Send + Sync {
14 fn source(&self) -> &FormatId;
16 fn target(&self) -> &FormatId;
18 async fn convert(&self, body: Bytes) -> ProxyResult<Bytes>;
20}
21
22#[async_trait::async_trait]
24pub trait StreamConverter: Send + Sync {
25 fn source(&self) -> &FormatId;
26 fn target(&self) -> &FormatId;
27 async fn convert_chunk(&self, chunk: Bytes) -> ProxyResult<Vec<Bytes>>;
29 async fn flush(&self) -> ProxyResult<Vec<Bytes>>;
31}
32
33pub trait FormatDetector: Send + Sync {
35 fn detect(&self, body: &[u8], content_type: Option<&str>) -> Option<FormatId>;
37 fn name(&self) -> &str;
39}
40
41pub struct ConversionRegistry {
43 converters: HashMap<(FormatId, FormatId), Box<dyn Converter>>,
44 stream_converters: HashMap<(FormatId, FormatId), Box<dyn StreamConverter>>,
45 detectors: Vec<Box<dyn FormatDetector>>,
46}
47
48impl ConversionRegistry {
49 pub fn new() -> Self {
50 Self {
51 converters: HashMap::new(),
52 stream_converters: HashMap::new(),
53 detectors: Vec::new(),
54 }
55 }
56
57 pub fn register_converter(&mut self, converter: impl Converter + 'static) {
59 let key = (converter.source().clone(), converter.target().clone());
60 self.converters.insert(key, Box::new(converter));
61 }
62
63 pub fn register_stream_converter(&mut self, converter: impl StreamConverter + 'static) {
65 let key = (converter.source().clone(), converter.target().clone());
66 self.stream_converters.insert(key, Box::new(converter));
67 }
68
69 pub fn register_detector(&mut self, detector: impl FormatDetector + 'static) {
71 self.detectors.push(Box::new(detector));
72 }
73
74 pub fn get_converter(&self, source: &FormatId, target: &FormatId) -> Option<&dyn Converter> {
76 self.converters
77 .get(&(source.clone(), target.clone()))
78 .map(|c| c.as_ref())
79 }
80
81 pub fn get_stream_converter(
83 &self,
84 source: &FormatId,
85 target: &FormatId,
86 ) -> Option<&dyn StreamConverter> {
87 self.stream_converters
88 .get(&(source.clone(), target.clone()))
89 .map(|c| c.as_ref())
90 }
91
92 pub fn detect_format(&self, body: &[u8], content_type: Option<&str>) -> Option<FormatId> {
94 for detector in &self.detectors {
95 if let Some(fmt) = detector.detect(body, content_type) {
96 return Some(fmt);
97 }
98 }
99 None
100 }
101
102 pub async fn convert(
104 &self,
105 body: Bytes,
106 source: Option<&FormatId>,
107 target: &FormatId,
108 content_type: Option<&str>,
109 ) -> ProxyResult<Bytes> {
110 let detected;
111 let source = match source {
112 Some(s) => s,
113 None => {
114 detected = self
115 .detect_format(&body, content_type)
116 .ok_or(ProxyError::FormatDetectionFailed)?;
117 &detected
118 }
119 };
120
121 let converter = self.get_converter(source, target).ok_or_else(|| {
122 ProxyError::UnsupportedConversion {
123 src: source.to_string(),
124 dst: target.to_string(),
125 }
126 })?;
127
128 converter.convert(body).await
129 }
130}
131
132impl Default for ConversionRegistry {
133 fn default() -> Self {
134 Self::new()
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Test converter mapping `text` -> `upper` by upper-casing the body.
    struct UpperCaseConverter {
        source: FormatId,
        target: FormatId,
    }

    impl UpperCaseConverter {
        fn new() -> Self {
            Self {
                source: FormatId::new("text"),
                target: FormatId::new("upper"),
            }
        }
    }

    #[async_trait::async_trait]
    impl Converter for UpperCaseConverter {
        fn source(&self) -> &FormatId {
            &self.source
        }
        fn target(&self) -> &FormatId {
            &self.target
        }
        async fn convert(&self, body: Bytes) -> ProxyResult<Bytes> {
            let text = String::from_utf8_lossy(&body).to_uppercase();
            Ok(Bytes::from(text))
        }
    }

    /// Test stream converter that forwards every chunk unchanged.
    struct PassThroughStreamConverter {
        source: FormatId,
        target: FormatId,
    }

    #[async_trait::async_trait]
    impl StreamConverter for PassThroughStreamConverter {
        fn source(&self) -> &FormatId {
            &self.source
        }
        fn target(&self) -> &FormatId {
            &self.target
        }
        async fn convert_chunk(&self, chunk: Bytes) -> ProxyResult<Vec<Bytes>> {
            Ok(vec![chunk])
        }
        async fn flush(&self) -> ProxyResult<Vec<Bytes>> {
            Ok(Vec::new())
        }
    }

    /// Reports `text` whenever the content type mentions `text/plain`.
    struct TextDetector;

    impl FormatDetector for TextDetector {
        fn detect(&self, _body: &[u8], content_type: Option<&str>) -> Option<FormatId> {
            if content_type?.contains("text/plain") {
                Some(FormatId::new("text"))
            } else {
                None
            }
        }
        fn name(&self) -> &str {
            "text_detector"
        }
    }

    /// Claims every payload as `catch_all`; used to check detector priority.
    struct CatchAllDetector;

    impl FormatDetector for CatchAllDetector {
        fn detect(&self, _body: &[u8], _content_type: Option<&str>) -> Option<FormatId> {
            Some(FormatId::new("catch_all"))
        }
        fn name(&self) -> &str {
            "catch_all_detector"
        }
    }

    #[test]
    fn register_and_lookup_converter() {
        let mut registry = ConversionRegistry::new();
        registry.register_converter(UpperCaseConverter::new());

        let source = FormatId::new("text");
        let target = FormatId::new("upper");
        assert!(registry.get_converter(&source, &target).is_some());
        // Lookup is directional: the reverse pair is not registered.
        assert!(registry.get_converter(&target, &source).is_none());
        // Buffered converters are not exposed as stream converters.
        assert!(registry.get_stream_converter(&source, &target).is_none());
    }

    #[tokio::test]
    async fn register_and_lookup_stream_converter() {
        let mut registry = ConversionRegistry::new();
        registry.register_stream_converter(PassThroughStreamConverter {
            source: FormatId::new("sse"),
            target: FormatId::new("ndjson"),
        });

        let source = FormatId::new("sse");
        let target = FormatId::new("ndjson");
        let converter = registry
            .get_stream_converter(&source, &target)
            .expect("stream converter was registered above");
        let out = converter.convert_chunk(Bytes::from("chunk")).await.unwrap();
        assert_eq!(out, vec![Bytes::from("chunk")]);
        assert!(converter.flush().await.unwrap().is_empty());

        assert!(registry.get_stream_converter(&target, &source).is_none());
        // Stream converters are not exposed as buffered converters.
        assert!(registry.get_converter(&source, &target).is_none());
    }

    #[test]
    fn detectors_are_consulted_in_registration_order() {
        let mut registry = ConversionRegistry::new();
        registry.register_detector(TextDetector);
        registry.register_detector(CatchAllDetector);

        // The earlier detector wins when both would match...
        assert!(registry.detect_format(b"hi", Some("text/plain")) == Some(FormatId::new("text")));
        // ...and later detectors are reached only when earlier ones decline.
        assert!(registry.detect_format(b"hi", None) == Some(FormatId::new("catch_all")));
    }

    #[test]
    fn detect_format_without_detectors_is_none() {
        let registry = ConversionRegistry::new();
        assert!(registry.detect_format(b"hi", Some("text/plain")).is_none());
    }

    #[tokio::test]
    async fn convert_body() {
        let mut registry = ConversionRegistry::new();
        registry.register_converter(UpperCaseConverter::new());

        let source = FormatId::new("text");
        let target = FormatId::new("upper");
        let result = registry
            .convert(Bytes::from("hello"), Some(&source), &target, None)
            .await
            .unwrap();
        assert_eq!(result.as_ref(), b"HELLO");
    }

    #[tokio::test]
    async fn auto_detect_source_format() {
        let mut registry = ConversionRegistry::new();
        registry.register_converter(UpperCaseConverter::new());
        registry.register_detector(TextDetector);

        let target = FormatId::new("upper");
        let result = registry
            .convert(Bytes::from("world"), None, &target, Some("text/plain"))
            .await
            .unwrap();
        assert_eq!(result.as_ref(), b"WORLD");
    }

    #[tokio::test]
    async fn detection_failure_returns_error() {
        let target = FormatId::new("upper");

        // No detectors at all.
        let registry = ConversionRegistry::new();
        let result = registry
            .convert(Bytes::from("data"), None, &target, None)
            .await;
        assert!(matches!(result, Err(ProxyError::FormatDetectionFailed)));

        // A detector that declines (no content type) also fails detection,
        // even though a suitable converter exists.
        let mut registry = ConversionRegistry::new();
        registry.register_converter(UpperCaseConverter::new());
        registry.register_detector(TextDetector);
        let result = registry
            .convert(Bytes::from("data"), None, &target, None)
            .await;
        assert!(matches!(result, Err(ProxyError::FormatDetectionFailed)));
    }

    #[tokio::test]
    async fn unsupported_conversion_returns_error() {
        let mut registry = ConversionRegistry::new();
        registry.register_converter(UpperCaseConverter::new());

        let source = FormatId::new("text");
        let target = FormatId::new("nonexistent");
        let result = registry
            .convert(Bytes::from("data"), Some(&source), &target, None)
            .await;
        assert!(matches!(
            result,
            Err(ProxyError::UnsupportedConversion { .. })
        ));
    }
}