Skip to main content

camel_component_validator/
xsd_bridge.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use async_trait::async_trait;
7use camel_bridge::channel::connect_channel;
8use camel_bridge::download::{default_cache_dir_for_spec, ensure_binary_for_spec};
9use camel_bridge::health::wait_for_health;
10use camel_bridge::process::{BridgeError, BridgeProcess, BridgeProcessConfig};
11use camel_bridge::reconnect::BridgeReconnectHandler;
12use camel_bridge::spec::XML_BRIDGE;
13use dashmap::DashMap;
14use sha2::{Digest, Sha256};
15use tokio::sync::{Mutex, RwLock, watch};
16use tonic::Code;
17use tonic::transport::Channel;
18use tracing::warn;
19
20use crate::error::ValidatorError;
21use crate::proto;
22use crate::proto::{
23    HealthCheckRequest, RegisterSchemaRequest, RegisterSchemaResponse, ValidateResponse,
24    ValidateWithRequest,
25};
26
/// Identifier under which a registered XSD is known to the bridge.
///
/// [`XsdBridgeBackend::schema_id_for`] produces these as `"xsd-"` plus a hex SHA-256 digest.
pub type SchemaId = std::string::String;
28
29#[derive(Debug, Clone)]
30pub enum BridgeState {
31    Starting,
32    Ready { channel: Channel },
33    Degraded(String),
34    Restarting { attempt: u32, next_at: Instant },
35    Stopped,
36}
37
38pub struct XmlBridgeSlot {
39    pub state_rx: watch::Receiver<BridgeState>,
40    pub(crate) state_tx: watch::Sender<BridgeState>,
41    pub process: Arc<Mutex<Option<BridgeProcess>>>,
42}
43
44impl std::fmt::Debug for XmlBridgeSlot {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        f.debug_struct("XmlBridgeSlot").finish()
47    }
48}
49
50type ConnectFuture = Pin<Box<dyn Future<Output = Result<Channel, BridgeError>> + Send>>;
51type ConnectFn = dyn Fn(u16) -> ConnectFuture + Send + Sync;
52
53#[async_trait]
54pub trait XsdBridge: Send + Sync {
55    async fn register(&self, xsd_bytes: Vec<u8>) -> Result<SchemaId, ValidatorError>;
56    async fn validate(&self, schema_id: &str, doc_bytes: Vec<u8>) -> Result<(), ValidatorError>;
57}
58
59#[async_trait]
60trait XsdBridgeRpc: Send + Sync {
61    async fn register_schema(
62        &self,
63        channel: Channel,
64        request: RegisterSchemaRequest,
65    ) -> Result<RegisterSchemaResponse, ValidatorError>;
66
67    async fn validate_with(
68        &self,
69        channel: Channel,
70        request: ValidateWithRequest,
71    ) -> Result<ValidateResponse, ValidatorError>;
72}
73
/// Production [`XsdBridgeRpc`] that talks to the bridge with the generated tonic client.
#[derive(Debug)]
struct GrpcXsdBridgeRpc;
76
77#[async_trait]
78impl XsdBridgeRpc for GrpcXsdBridgeRpc {
79    async fn register_schema(
80        &self,
81        channel: Channel,
82        request: RegisterSchemaRequest,
83    ) -> Result<RegisterSchemaResponse, ValidatorError> {
84        let mut client = proto::xsd_validator_client::XsdValidatorClient::new(channel);
85        let response = client.register_schema(request).await.map_err(|e| {
86            ValidatorError::Transport(format!("xml-bridge register_schema RPC failed: {e}"))
87        })?;
88        Ok(response.into_inner())
89    }
90
91    async fn validate_with(
92        &self,
93        channel: Channel,
94        request: ValidateWithRequest,
95    ) -> Result<ValidateResponse, ValidatorError> {
96        let mut client = proto::xsd_validator_client::XsdValidatorClient::new(channel);
97        let response = client.validate_with(request).await.map_err(|e| {
98            ValidatorError::Transport(format!("xml-bridge validate_with RPC failed: {e}"))
99        })?;
100        Ok(response.into_inner())
101    }
102}
103
104#[derive(Clone)]
105pub struct XsdBridgeBackend {
106    channel: Arc<RwLock<Option<Channel>>>,
107    schemas: Arc<DashMap<SchemaId, Vec<u8>>>,
108    slot: Arc<XmlBridgeSlot>,
109    rpc: Arc<dyn XsdBridgeRpc>,
110    connect_fn: Arc<ConnectFn>,
111    start_lock: Arc<Mutex<()>>,
112    bridge_version: String,
113    bridge_cache_dir: std::path::PathBuf,
114    bridge_start_timeout_ms: u64,
115}
116
117impl std::fmt::Debug for XsdBridgeBackend {
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        f.debug_struct("XsdBridgeBackend")
120            .field("bridge_version", &self.bridge_version)
121            .field("bridge_cache_dir", &self.bridge_cache_dir)
122            .finish()
123    }
124}
125
126impl XsdBridgeBackend {
127    pub fn new() -> Self {
128        let (state_tx, state_rx) = watch::channel(BridgeState::Stopped);
129        let slot = Arc::new(XmlBridgeSlot {
130            state_rx,
131            state_tx,
132            process: Arc::new(Mutex::new(None)),
133        });
134
135        Self {
136            channel: Arc::new(RwLock::new(None)),
137            schemas: Arc::new(DashMap::new()),
138            slot,
139            rpc: Arc::new(GrpcXsdBridgeRpc),
140            connect_fn: Arc::new(|port| Box::pin(connect_channel(port))),
141            start_lock: Arc::new(Mutex::new(())),
142            bridge_version: crate::BRIDGE_VERSION.to_string(),
143            bridge_cache_dir: default_cache_dir_for_spec(&XML_BRIDGE),
144            bridge_start_timeout_ms: 30_000,
145        }
146    }
147
148    #[cfg(test)]
149    fn for_test(rpc: Arc<dyn XsdBridgeRpc>, connect_fn: Arc<ConnectFn>, channel: Channel) -> Self {
150        let (state_tx, state_rx) = watch::channel(BridgeState::Ready {
151            channel: channel.clone(),
152        });
153        let slot = Arc::new(XmlBridgeSlot {
154            state_rx,
155            state_tx,
156            process: Arc::new(Mutex::new(None)),
157        });
158        Self {
159            channel: Arc::new(RwLock::new(Some(channel))),
160            schemas: Arc::new(DashMap::new()),
161            slot,
162            rpc,
163            connect_fn,
164            start_lock: Arc::new(Mutex::new(())),
165            bridge_version: crate::BRIDGE_VERSION.to_string(),
166            bridge_cache_dir: default_cache_dir_for_spec(&XML_BRIDGE),
167            bridge_start_timeout_ms: 30_000,
168        }
169    }
170
171    pub fn schema_id_for(xsd_bytes: &[u8]) -> SchemaId {
172        let mut hasher = Sha256::new();
173        hasher.update(xsd_bytes);
174        format!("xsd-{}", hex::encode(hasher.finalize()))
175    }
176
177    async fn ensure_bridge_ready(&self) -> Result<Channel, ValidatorError> {
178        if let Some(ch) = self.channel.read().await.clone() {
179            return Ok(ch);
180        }
181
182        let _guard = self.start_lock.lock().await;
183        if let Some(ch) = self.channel.read().await.clone() {
184            return Ok(ch);
185        }
186
187        let _ = self.slot.state_tx.send(BridgeState::Starting);
188        let (process, channel, port) = self.start_bridge_process().await?;
189        {
190            let mut process_guard = self.slot.process.lock().await;
191            *process_guard = Some(process);
192        }
193        {
194            let mut ch_guard = self.channel.write().await;
195            *ch_guard = Some(channel.clone());
196        }
197        let _ = self.slot.state_tx.send(BridgeState::Ready {
198            channel: channel.clone(),
199        });
200        self.on_reconnect(port).map_err(|e| {
201            ValidatorError::Transport(format!("xml-bridge reconnect handler failed: {e}"))
202        })?;
203
204        Ok(channel)
205    }
206
207    async fn restart_bridge(&self) -> Result<Channel, ValidatorError> {
208        let _guard = self.start_lock.lock().await;
209
210        let _ = self.slot.state_tx.send(BridgeState::Restarting {
211            attempt: 0,
212            next_at: Instant::now(),
213        });
214
215        let old_process = {
216            let mut process_guard = self.slot.process.lock().await;
217            process_guard.take()
218        };
219        if let Some(p) = old_process {
220            let _ = p.stop().await;
221        }
222
223        let (process, channel, port) = self.start_bridge_process().await?;
224        {
225            let mut process_guard = self.slot.process.lock().await;
226            *process_guard = Some(process);
227        }
228        {
229            let mut ch_guard = self.channel.write().await;
230            *ch_guard = Some(channel.clone());
231        }
232
233        self.on_reconnect(port).map_err(|e| {
234            ValidatorError::Transport(format!("xml-bridge reconnect handler failed: {e}"))
235        })?;
236        let _ = self.slot.state_tx.send(BridgeState::Ready {
237            channel: channel.clone(),
238        });
239
240        Ok(channel)
241    }
242
243    async fn start_bridge_process(&self) -> Result<(BridgeProcess, Channel, u16), ValidatorError> {
244        let binary_path =
245            ensure_binary_for_spec(&XML_BRIDGE, &self.bridge_version, &self.bridge_cache_dir)
246                .await
247                .map_err(|e| {
248                    ValidatorError::endpoint(format!("XML bridge binary unavailable: {e}"))
249                })?;
250
251        let process_config = BridgeProcessConfig::xml(binary_path, self.bridge_start_timeout_ms);
252        let process = BridgeProcess::start(&process_config)
253            .await
254            .map_err(|e| ValidatorError::endpoint(format!("XML bridge start failed: {e}")))?;
255        let port = process.grpc_port();
256        let channel = (self.connect_fn)(port).await.map_err(|e| {
257            ValidatorError::endpoint(format!("XML bridge channel connect failed: {e}"))
258        })?;
259
260        wait_for_health(&channel, Duration::from_secs(10), |ch| {
261            let mut client = proto::health_client::HealthClient::new(ch);
262            async move {
263                let resp = client.check(HealthCheckRequest {}).await?;
264                Ok(resp.into_inner().status == "SERVING")
265            }
266        })
267        .await
268        .map_err(|e| ValidatorError::endpoint(format!("XML bridge health check failed: {e}")))?;
269
270        Ok((process, channel, port))
271    }
272
273    fn is_transport_error(msg: &str) -> bool {
274        msg.contains(&Code::Unavailable.to_string())
275            || msg.contains(&Code::Unknown.to_string())
276            || msg.contains("transport")
277    }
278
279    async fn register_with_channel(
280        &self,
281        channel: Channel,
282        schema_id: SchemaId,
283        xsd_bytes: Vec<u8>,
284    ) -> Result<SchemaId, ValidatorError> {
285        let response = self
286            .rpc
287            .register_schema(
288                channel,
289                RegisterSchemaRequest {
290                    schema_id: schema_id.clone(),
291                    schema: xsd_bytes.clone(),
292                },
293            )
294            .await?;
295
296        if let Some(err) = response.error {
297            return Err(ValidatorError::from_bridge_error(&err));
298        }
299
300        self.schemas.insert(schema_id.clone(), xsd_bytes);
301        Ok(schema_id)
302    }
303}
304
305impl BridgeReconnectHandler for XsdBridgeBackend {
306    fn on_reconnect(&self, _port: u16) -> Result<(), BridgeError> {
307        let this = self.clone();
308        tokio::spawn(async move {
309            let Some(channel) = this.channel.read().await.clone() else {
310                return;
311            };
312
313            let schemas: Vec<(SchemaId, Vec<u8>)> = this
314                .schemas
315                .iter()
316                .map(|entry| (entry.key().clone(), entry.value().clone()))
317                .collect();
318
319            for (schema_id, schema_bytes) in schemas {
320                if let Err(e) = this
321                    .rpc
322                    .register_schema(
323                        channel.clone(),
324                        RegisterSchemaRequest {
325                            schema_id: schema_id.clone(),
326                            schema: schema_bytes,
327                        },
328                    )
329                    .await
330                {
331                    warn!(schema_id = %schema_id, error = %e, "re-seed schema failed after reconnect");
332                }
333            }
334        });
335        Ok(())
336    }
337}
338
339#[async_trait]
340impl XsdBridge for XsdBridgeBackend {
341    async fn register(&self, xsd_bytes: Vec<u8>) -> Result<SchemaId, ValidatorError> {
342        let schema_id = Self::schema_id_for(&xsd_bytes);
343        if self.schemas.contains_key(&schema_id) {
344            return Ok(schema_id);
345        }
346
347        let channel = self.ensure_bridge_ready().await?;
348        match self
349            .register_with_channel(channel.clone(), schema_id.clone(), xsd_bytes.clone())
350            .await
351        {
352            Ok(id) => Ok(id),
353            Err(e) if Self::is_transport_error(&e.to_string()) => {
354                let restarted = self.restart_bridge().await?;
355                self.register_with_channel(restarted, schema_id, xsd_bytes)
356                    .await
357            }
358            Err(e) => Err(e),
359        }
360    }
361
362    async fn validate(&self, schema_id: &str, doc_bytes: Vec<u8>) -> Result<(), ValidatorError> {
363        let channel = self.ensure_bridge_ready().await?;
364        let req = ValidateWithRequest {
365            schema_id: schema_id.to_string(),
366            document: doc_bytes.clone(),
367        };
368
369        let response = match self.rpc.validate_with(channel.clone(), req).await {
370            Ok(resp) => resp,
371            Err(e) if Self::is_transport_error(&e.to_string()) => {
372                let restarted = self.restart_bridge().await?;
373                self.rpc
374                    .validate_with(
375                        restarted,
376                        ValidateWithRequest {
377                            schema_id: schema_id.to_string(),
378                            document: doc_bytes,
379                        },
380                    )
381                    .await?
382            }
383            Err(e) => return Err(e),
384        };
385
386        if let Some(err) = response.error {
387            return Err(ValidatorError::from_bridge_error(&err));
388        }
389        if response.valid {
390            return Ok(());
391        }
392
393        let details = response
394            .errors
395            .iter()
396            .map(|e| format!("{}:{} {}", e.line, e.column, e.message))
397            .collect::<Vec<_>>()
398            .join("\n");
399        Err(ValidatorError::validation(format!(
400            "XSD validation failed:\n{details}"
401        )))
402    }
403}
404
405impl Default for XsdBridgeBackend {
406    fn default() -> Self {
407        Self::new()
408    }
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tonic::transport::Endpoint;

    /// In-memory stand-in for the bridge RPCs: counts `register_schema` calls and answers
    /// every `validate_with` with a fixed verdict.
    #[derive(Debug)]
    struct MockRpc {
        register_calls: Arc<AtomicUsize>,
        validate_ok: bool,
    }

    #[async_trait]
    impl XsdBridgeRpc for MockRpc {
        async fn register_schema(
            &self,
            _channel: Channel,
            request: RegisterSchemaRequest,
        ) -> Result<RegisterSchemaResponse, ValidatorError> {
            self.register_calls.fetch_add(1, Ordering::SeqCst);
            Ok(RegisterSchemaResponse {
                schema_id: request.schema_id,
                error: None,
            })
        }

        async fn validate_with(
            &self,
            _channel: Channel,
            _request: ValidateWithRequest,
        ) -> Result<ValidateResponse, ValidatorError> {
            Ok(ValidateResponse {
                valid: self.validate_ok,
                errors: Vec::new(),
                error: None,
            })
        }
    }

    /// A channel that only dials when used; the mock RPC never uses it.
    fn lazy_channel() -> Channel {
        Endpoint::from_static("http://127.0.0.1:65535").connect_lazy()
    }

    /// Backend wired to `MockRpc`, returned together with its register-call counter.
    fn mock_backend(validate_ok: bool) -> (XsdBridgeBackend, Arc<AtomicUsize>) {
        let calls = Arc::new(AtomicUsize::new(0));
        let rpc = Arc::new(MockRpc {
            register_calls: Arc::clone(&calls),
            validate_ok,
        });
        let connector =
            Arc::new(|_port: u16| Box::pin(async move { Ok(lazy_channel()) }) as ConnectFuture);
        (XsdBridgeBackend::for_test(rpc, connector, lazy_channel()), calls)
    }

    #[tokio::test]
    async fn xsd_bridge_reconnect_reseeds() {
        let (backend, calls) = mock_backend(true);

        let _id_a = backend.register(b"<xsd:a/>".to_vec()).await.unwrap();
        let _id_b = backend.register(b"<xsd:b/>".to_vec()).await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 2);

        backend.on_reconnect(50051).unwrap();

        // The re-seed runs on a spawned task; poll against a deadline instead of a single
        // fixed sleep, which is flaky on a loaded machine.
        let deadline = Instant::now() + Duration::from_secs(2);
        while calls.load(Ordering::SeqCst) < 4 && Instant::now() < deadline {
            tokio::time::sleep(Duration::from_millis(5)).await;
        }

        assert_eq!(calls.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn register_skips_rpc_for_cached_schema() {
        let (backend, calls) = mock_backend(true);

        let first = backend.register(b"<xsd:a/>".to_vec()).await.unwrap();
        let second = backend.register(b"<xsd:a/>".to_vec()).await.unwrap();

        assert_eq!(first, second);
        assert_eq!(first, XsdBridgeBackend::schema_id_for(b"<xsd:a/>"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn validate_maps_verdict_to_result() {
        let (ok_backend, _) = mock_backend(true);
        assert!(ok_backend.validate("xsd-any", b"<doc/>".to_vec()).await.is_ok());

        let (bad_backend, _) = mock_backend(false);
        assert!(bad_backend.validate("xsd-any", b"<doc/>".to_vec()).await.is_err());
    }

    #[test]
    fn schema_id_is_prefixed_sha256_hex() {
        let id = XsdBridgeBackend::schema_id_for(b"");
        assert_eq!(
            id,
            "xsd-e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        assert_ne!(id, XsdBridgeBackend::schema_id_for(b"x"));
    }
}