1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use async_trait::async_trait;
7use camel_bridge::channel::connect_channel;
8use camel_bridge::download::{default_cache_dir_for_spec, ensure_binary_for_spec};
9use camel_bridge::health::wait_for_health;
10use camel_bridge::process::{BridgeError, BridgeProcess, BridgeProcessConfig};
11use camel_bridge::reconnect::BridgeReconnectHandler;
12use camel_bridge::spec::XML_BRIDGE;
13use dashmap::DashMap;
14use sha2::{Digest, Sha256};
15use tokio::sync::{Mutex, RwLock, watch};
16use tonic::Code;
17use tonic::transport::Channel;
18use tracing::warn;
19
20use crate::error::ValidatorError;
21use crate::proto;
22use crate::proto::{
23 HealthCheckRequest, RegisterSchemaRequest, RegisterSchemaResponse, ValidateResponse,
24 ValidateWithRequest,
25};
26
/// Content-addressed identifier for a registered XSD schema.
///
/// Produced by `XsdBridgeBackend::schema_id_for` as `"xsd-"` followed by the
/// hex-encoded SHA-256 digest of the schema bytes.
pub type SchemaId = String;
28
/// Lifecycle state of the XML bridge, published through the watch channel
/// held in [`XmlBridgeSlot`].
#[derive(Debug, Clone)]
pub enum BridgeState {
    /// A bridge start triggered by first use is in progress.
    Starting,
    /// The bridge is running and `channel` is a gRPC channel connected to it.
    Ready { channel: Channel },
    /// The bridge is unusable; presumably carries a human-readable reason.
    Degraded(String),
    /// A restart is in progress. Within this file it is published with
    /// `attempt: 0` and `next_at` set to the moment the restart begins.
    Restarting { attempt: u32, next_at: Instant },
    /// No bridge has been started (initial state of `XsdBridgeBackend::new`).
    Stopped,
}
37
/// Shared holder for the bridge's observable state and its running process.
pub struct XmlBridgeSlot {
    /// Receiver side of the state watch channel; observers read [`BridgeState`] here.
    pub state_rx: watch::Receiver<BridgeState>,
    /// Sender side used by the backend to publish state transitions.
    pub(crate) state_tx: watch::Sender<BridgeState>,
    /// The currently running bridge process, if one has been started.
    pub process: Arc<Mutex<Option<BridgeProcess>>>,
}
43
44impl std::fmt::Debug for XmlBridgeSlot {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 f.debug_struct("XmlBridgeSlot").finish()
47 }
48}
49
/// Boxed, `Send` future resolving to a connected gRPC channel or a bridge error.
type ConnectFuture = Pin<Box<dyn Future<Output = Result<Channel, BridgeError>> + Send>>;
/// Opens a channel to the bridge listening on the given gRPC port. Stored as a
/// trait object so tests can inject a fake connector.
type ConnectFn = dyn Fn(u16) -> ConnectFuture + Send + Sync;
52
/// Public interface for registering XSD schemas and validating XML documents
/// against them.
#[async_trait]
pub trait XsdBridge: Send + Sync {
    /// Registers `xsd_bytes` and returns the id to pass to [`XsdBridge::validate`].
    async fn register(&self, xsd_bytes: Vec<u8>) -> Result<SchemaId, ValidatorError>;
    /// Validates `doc_bytes` against the schema registered as `schema_id`;
    /// `Ok(())` means the document is valid.
    async fn validate(&self, schema_id: &str, doc_bytes: Vec<u8>) -> Result<(), ValidatorError>;
}
58
/// Raw RPC calls made to the bridge. Kept behind a trait so the backend can be
/// driven by a mock implementation in tests.
#[async_trait]
trait XsdBridgeRpc: Send + Sync {
    /// Sends a `RegisterSchemaRequest` on `channel` and returns the bridge's reply.
    async fn register_schema(
        &self,
        channel: Channel,
        request: RegisterSchemaRequest,
    ) -> Result<RegisterSchemaResponse, ValidatorError>;

    /// Sends a `ValidateWithRequest` on `channel` and returns the bridge's reply.
    async fn validate_with(
        &self,
        channel: Channel,
        request: ValidateWithRequest,
    ) -> Result<ValidateResponse, ValidatorError>;
}
73
/// Production [`XsdBridgeRpc`] that issues calls through the
/// `proto::xsd_validator_client` gRPC client.
#[derive(Debug)]
struct GrpcXsdBridgeRpc;
76
77#[async_trait]
78impl XsdBridgeRpc for GrpcXsdBridgeRpc {
79 async fn register_schema(
80 &self,
81 channel: Channel,
82 request: RegisterSchemaRequest,
83 ) -> Result<RegisterSchemaResponse, ValidatorError> {
84 let mut client = proto::xsd_validator_client::XsdValidatorClient::new(channel);
85 let response = client.register_schema(request).await.map_err(|e| {
86 ValidatorError::Transport(format!("xml-bridge register_schema RPC failed: {e}"))
87 })?;
88 Ok(response.into_inner())
89 }
90
91 async fn validate_with(
92 &self,
93 channel: Channel,
94 request: ValidateWithRequest,
95 ) -> Result<ValidateResponse, ValidatorError> {
96 let mut client = proto::xsd_validator_client::XsdValidatorClient::new(channel);
97 let response = client.validate_with(request).await.map_err(|e| {
98 ValidatorError::Transport(format!("xml-bridge validate_with RPC failed: {e}"))
99 })?;
100 Ok(response.into_inner())
101 }
102}
103
/// [`XsdBridge`] implementation that starts the external XML bridge process on
/// demand, talks to it over gRPC, and caches registered schemas so they can be
/// registered again after the bridge restarts.
///
/// Clones share the same channel, schema cache, bridge slot and start lock.
#[derive(Clone)]
pub struct XsdBridgeBackend {
    /// Channel to the running bridge; `None` until a bridge has been started.
    channel: Arc<RwLock<Option<Channel>>>,
    /// Successfully registered schemas keyed by content-hash id; replayed on reconnect.
    schemas: Arc<DashMap<SchemaId, Vec<u8>>>,
    /// Published bridge state plus the running process handle.
    slot: Arc<XmlBridgeSlot>,
    /// RPC layer (gRPC in production, a mock in tests).
    rpc: Arc<dyn XsdBridgeRpc>,
    /// Opens a channel to a freshly started bridge on its gRPC port.
    connect_fn: Arc<ConnectFn>,
    /// Serializes bridge start and restart so only one caller spawns a process at a time.
    start_lock: Arc<Mutex<()>>,
    /// Version passed to `ensure_binary_for_spec` when resolving the bridge binary.
    bridge_version: String,
    /// Cache directory passed to `ensure_binary_for_spec`.
    bridge_cache_dir: std::path::PathBuf,
    /// Startup timeout passed to `BridgeProcessConfig::xml`, in milliseconds.
    bridge_start_timeout_ms: u64,
}
116
117impl std::fmt::Debug for XsdBridgeBackend {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 f.debug_struct("XsdBridgeBackend")
120 .field("bridge_version", &self.bridge_version)
121 .field("bridge_cache_dir", &self.bridge_cache_dir)
122 .finish()
123 }
124}
125
126impl XsdBridgeBackend {
127 pub fn new() -> Self {
128 let (state_tx, state_rx) = watch::channel(BridgeState::Stopped);
129 let slot = Arc::new(XmlBridgeSlot {
130 state_rx,
131 state_tx,
132 process: Arc::new(Mutex::new(None)),
133 });
134
135 Self {
136 channel: Arc::new(RwLock::new(None)),
137 schemas: Arc::new(DashMap::new()),
138 slot,
139 rpc: Arc::new(GrpcXsdBridgeRpc),
140 connect_fn: Arc::new(|port| Box::pin(connect_channel(port))),
141 start_lock: Arc::new(Mutex::new(())),
142 bridge_version: crate::BRIDGE_VERSION.to_string(),
143 bridge_cache_dir: default_cache_dir_for_spec(&XML_BRIDGE),
144 bridge_start_timeout_ms: 30_000,
145 }
146 }
147
148 #[cfg(test)]
149 fn for_test(rpc: Arc<dyn XsdBridgeRpc>, connect_fn: Arc<ConnectFn>, channel: Channel) -> Self {
150 let (state_tx, state_rx) = watch::channel(BridgeState::Ready {
151 channel: channel.clone(),
152 });
153 let slot = Arc::new(XmlBridgeSlot {
154 state_rx,
155 state_tx,
156 process: Arc::new(Mutex::new(None)),
157 });
158 Self {
159 channel: Arc::new(RwLock::new(Some(channel))),
160 schemas: Arc::new(DashMap::new()),
161 slot,
162 rpc,
163 connect_fn,
164 start_lock: Arc::new(Mutex::new(())),
165 bridge_version: crate::BRIDGE_VERSION.to_string(),
166 bridge_cache_dir: default_cache_dir_for_spec(&XML_BRIDGE),
167 bridge_start_timeout_ms: 30_000,
168 }
169 }
170
171 pub fn schema_id_for(xsd_bytes: &[u8]) -> SchemaId {
172 let mut hasher = Sha256::new();
173 hasher.update(xsd_bytes);
174 format!("xsd-{}", hex::encode(hasher.finalize()))
175 }
176
177 async fn ensure_bridge_ready(&self) -> Result<Channel, ValidatorError> {
178 if let Some(ch) = self.channel.read().await.clone() {
179 return Ok(ch);
180 }
181
182 let _guard = self.start_lock.lock().await;
183 if let Some(ch) = self.channel.read().await.clone() {
184 return Ok(ch);
185 }
186
187 let _ = self.slot.state_tx.send(BridgeState::Starting);
188 let (process, channel, port) = self.start_bridge_process().await?;
189 {
190 let mut process_guard = self.slot.process.lock().await;
191 *process_guard = Some(process);
192 }
193 {
194 let mut ch_guard = self.channel.write().await;
195 *ch_guard = Some(channel.clone());
196 }
197 let _ = self.slot.state_tx.send(BridgeState::Ready {
198 channel: channel.clone(),
199 });
200 self.on_reconnect(port).map_err(|e| {
201 ValidatorError::Transport(format!("xml-bridge reconnect handler failed: {e}"))
202 })?;
203
204 Ok(channel)
205 }
206
207 async fn restart_bridge(&self) -> Result<Channel, ValidatorError> {
208 let _guard = self.start_lock.lock().await;
209
210 let _ = self.slot.state_tx.send(BridgeState::Restarting {
211 attempt: 0,
212 next_at: Instant::now(),
213 });
214
215 let old_process = {
216 let mut process_guard = self.slot.process.lock().await;
217 process_guard.take()
218 };
219 if let Some(p) = old_process {
220 let _ = p.stop().await;
221 }
222
223 let (process, channel, port) = self.start_bridge_process().await?;
224 {
225 let mut process_guard = self.slot.process.lock().await;
226 *process_guard = Some(process);
227 }
228 {
229 let mut ch_guard = self.channel.write().await;
230 *ch_guard = Some(channel.clone());
231 }
232
233 self.on_reconnect(port).map_err(|e| {
234 ValidatorError::Transport(format!("xml-bridge reconnect handler failed: {e}"))
235 })?;
236 let _ = self.slot.state_tx.send(BridgeState::Ready {
237 channel: channel.clone(),
238 });
239
240 Ok(channel)
241 }
242
243 async fn start_bridge_process(&self) -> Result<(BridgeProcess, Channel, u16), ValidatorError> {
244 let binary_path =
245 ensure_binary_for_spec(&XML_BRIDGE, &self.bridge_version, &self.bridge_cache_dir)
246 .await
247 .map_err(|e| {
248 ValidatorError::endpoint(format!("XML bridge binary unavailable: {e}"))
249 })?;
250
251 let process_config = BridgeProcessConfig::xml(binary_path, self.bridge_start_timeout_ms);
252 let process = BridgeProcess::start(&process_config)
253 .await
254 .map_err(|e| ValidatorError::endpoint(format!("XML bridge start failed: {e}")))?;
255 let port = process.grpc_port();
256 let channel = (self.connect_fn)(port).await.map_err(|e| {
257 ValidatorError::endpoint(format!("XML bridge channel connect failed: {e}"))
258 })?;
259
260 wait_for_health(&channel, Duration::from_secs(10), |ch| {
261 let mut client = proto::health_client::HealthClient::new(ch);
262 async move {
263 let resp = client.check(HealthCheckRequest {}).await?;
264 Ok(resp.into_inner().status == "SERVING")
265 }
266 })
267 .await
268 .map_err(|e| ValidatorError::endpoint(format!("XML bridge health check failed: {e}")))?;
269
270 Ok((process, channel, port))
271 }
272
273 fn is_transport_error(msg: &str) -> bool {
274 msg.contains(&Code::Unavailable.to_string())
275 || msg.contains(&Code::Unknown.to_string())
276 || msg.contains("transport")
277 }
278
279 async fn register_with_channel(
280 &self,
281 channel: Channel,
282 schema_id: SchemaId,
283 xsd_bytes: Vec<u8>,
284 ) -> Result<SchemaId, ValidatorError> {
285 let response = self
286 .rpc
287 .register_schema(
288 channel,
289 RegisterSchemaRequest {
290 schema_id: schema_id.clone(),
291 schema: xsd_bytes.clone(),
292 },
293 )
294 .await?;
295
296 if let Some(err) = response.error {
297 return Err(ValidatorError::from_bridge_error(&err));
298 }
299
300 self.schemas.insert(schema_id.clone(), xsd_bytes);
301 Ok(schema_id)
302 }
303}
304
305impl BridgeReconnectHandler for XsdBridgeBackend {
306 fn on_reconnect(&self, _port: u16) -> Result<(), BridgeError> {
307 let this = self.clone();
308 tokio::spawn(async move {
309 let Some(channel) = this.channel.read().await.clone() else {
310 return;
311 };
312
313 let schemas: Vec<(SchemaId, Vec<u8>)> = this
314 .schemas
315 .iter()
316 .map(|entry| (entry.key().clone(), entry.value().clone()))
317 .collect();
318
319 for (schema_id, schema_bytes) in schemas {
320 if let Err(e) = this
321 .rpc
322 .register_schema(
323 channel.clone(),
324 RegisterSchemaRequest {
325 schema_id: schema_id.clone(),
326 schema: schema_bytes,
327 },
328 )
329 .await
330 {
331 warn!(schema_id = %schema_id, error = %e, "re-seed schema failed after reconnect");
332 }
333 }
334 });
335 Ok(())
336 }
337}
338
339#[async_trait]
340impl XsdBridge for XsdBridgeBackend {
341 async fn register(&self, xsd_bytes: Vec<u8>) -> Result<SchemaId, ValidatorError> {
342 let schema_id = Self::schema_id_for(&xsd_bytes);
343 if self.schemas.contains_key(&schema_id) {
344 return Ok(schema_id);
345 }
346
347 let channel = self.ensure_bridge_ready().await?;
348 match self
349 .register_with_channel(channel.clone(), schema_id.clone(), xsd_bytes.clone())
350 .await
351 {
352 Ok(id) => Ok(id),
353 Err(e) if Self::is_transport_error(&e.to_string()) => {
354 let restarted = self.restart_bridge().await?;
355 self.register_with_channel(restarted, schema_id, xsd_bytes)
356 .await
357 }
358 Err(e) => Err(e),
359 }
360 }
361
362 async fn validate(&self, schema_id: &str, doc_bytes: Vec<u8>) -> Result<(), ValidatorError> {
363 let channel = self.ensure_bridge_ready().await?;
364 let req = ValidateWithRequest {
365 schema_id: schema_id.to_string(),
366 document: doc_bytes.clone(),
367 };
368
369 let response = match self.rpc.validate_with(channel.clone(), req).await {
370 Ok(resp) => resp,
371 Err(e) if Self::is_transport_error(&e.to_string()) => {
372 let restarted = self.restart_bridge().await?;
373 self.rpc
374 .validate_with(
375 restarted,
376 ValidateWithRequest {
377 schema_id: schema_id.to_string(),
378 document: doc_bytes,
379 },
380 )
381 .await?
382 }
383 Err(e) => return Err(e),
384 };
385
386 if let Some(err) = response.error {
387 return Err(ValidatorError::from_bridge_error(&err));
388 }
389 if response.valid {
390 return Ok(());
391 }
392
393 let details = response
394 .errors
395 .iter()
396 .map(|e| format!("{}:{} {}", e.line, e.column, e.message))
397 .collect::<Vec<_>>()
398 .join("\n");
399 Err(ValidatorError::validation(format!(
400 "XSD validation failed:\n{details}"
401 )))
402 }
403}
404
405impl Default for XsdBridgeBackend {
406 fn default() -> Self {
407 Self::new()
408 }
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tonic::transport::Endpoint;

    /// Mock RPC layer: counts `register_schema` calls and answers
    /// `validate_with` with a fixed `valid` flag.
    #[derive(Debug)]
    struct MockRpc {
        register_calls: Arc<AtomicUsize>,
        validate_ok: bool,
    }

    #[async_trait]
    impl XsdBridgeRpc for MockRpc {
        async fn register_schema(
            &self,
            _channel: Channel,
            request: RegisterSchemaRequest,
        ) -> Result<RegisterSchemaResponse, ValidatorError> {
            self.register_calls.fetch_add(1, Ordering::SeqCst);
            Ok(RegisterSchemaResponse {
                schema_id: request.schema_id,
                error: None,
            })
        }

        async fn validate_with(
            &self,
            _channel: Channel,
            _request: ValidateWithRequest,
        ) -> Result<ValidateResponse, ValidatorError> {
            Ok(ValidateResponse {
                valid: self.validate_ok,
                errors: Vec::new(),
                error: None,
            })
        }
    }

    /// A channel that connects only on first use; the mock RPC never uses it.
    fn lazy_channel() -> Channel {
        Endpoint::from_static("http://127.0.0.1:65535").connect_lazy()
    }

    /// Backend wired to `MockRpc` with a channel already installed, so no
    /// bridge process is ever started.
    fn mock_backend(calls: Arc<AtomicUsize>, validate_ok: bool) -> XsdBridgeBackend {
        let rpc = Arc::new(MockRpc {
            register_calls: calls,
            validate_ok,
        });
        let connector =
            Arc::new(|_port: u16| Box::pin(async move { Ok(lazy_channel()) }) as ConnectFuture);
        XsdBridgeBackend::for_test(rpc, connector, lazy_channel())
    }

    #[tokio::test]
    async fn xsd_bridge_reconnect_reseeds() {
        let calls = Arc::new(AtomicUsize::new(0));
        let backend = mock_backend(Arc::clone(&calls), true);

        let _id_a = backend.register(b"<xsd:a/>".to_vec()).await.unwrap();
        let _id_b = backend.register(b"<xsd:b/>".to_vec()).await.unwrap();

        backend.on_reconnect(50051).unwrap();

        // Re-seeding runs on a spawned task. Poll with a deadline instead of
        // a fixed sleep, which is flaky on a loaded machine.
        let deadline = Instant::now() + Duration::from_secs(2);
        while calls.load(Ordering::SeqCst) < 4 && Instant::now() < deadline {
            tokio::time::sleep(Duration::from_millis(5)).await;
        }

        // Two initial registrations plus one re-seed per cached schema.
        assert_eq!(calls.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn register_is_idempotent_for_same_schema() {
        let calls = Arc::new(AtomicUsize::new(0));
        let backend = mock_backend(Arc::clone(&calls), true);

        let first = backend.register(b"<xsd:a/>".to_vec()).await.unwrap();
        let second = backend.register(b"<xsd:a/>".to_vec()).await.unwrap();

        assert_eq!(first, second);
        // The second call is served from the local cache without an RPC.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn validate_follows_valid_flag() {
        let ok_backend = mock_backend(Arc::new(AtomicUsize::new(0)), true);
        assert!(ok_backend
            .validate("xsd-any", b"<doc/>".to_vec())
            .await
            .is_ok());

        let bad_backend = mock_backend(Arc::new(AtomicUsize::new(0)), false);
        assert!(bad_backend
            .validate("xsd-any", b"<doc/>".to_vec())
            .await
            .is_err());
    }
}