1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::time::{Duration, Instant};
6
7use async_trait::async_trait;
8use camel_bridge::channel::connect_channel;
9use camel_bridge::download::{default_cache_dir_for_spec, ensure_binary_for_spec};
10use camel_bridge::health::wait_for_health;
11use camel_bridge::process::{BridgeError, BridgeProcess, BridgeProcessConfig};
12use camel_bridge::reconnect::BridgeReconnectHandler;
13use camel_bridge::spec::XML_BRIDGE;
14use dashmap::DashMap;
15use sha2::{Digest, Sha256};
16use tokio::sync::{Mutex, RwLock, watch};
17use tonic::Code;
18use tonic::transport::Channel;
19use tracing::error;
20
21use crate::error::ValidatorError;
22use crate::proto;
23use crate::proto::{
24 HealthCheckRequest, RegisterSchemaRequest, RegisterSchemaResponse, ValidateResponse,
25 ValidateWithRequest,
26};
27
/// Identifier under which a schema is registered with the bridge.
///
/// Values produced by [`XsdBridgeBackend::schema_id_for`] have the form `xsd-<hex sha256>`.
pub type SchemaId = String;
29
30#[derive(Debug, Clone)]
31pub enum BridgeState {
32 Starting,
33 Ready { channel: Channel },
34 Degraded(String),
35 Restarting { attempt: u32, next_at: Instant },
36 Stopped,
37}
38
39pub struct XmlBridgeSlot {
40 pub state_rx: watch::Receiver<BridgeState>,
41 pub(crate) state_tx: watch::Sender<BridgeState>,
42 pub process: Arc<Mutex<Option<BridgeProcess>>>,
43}
44
45impl std::fmt::Debug for XmlBridgeSlot {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 f.debug_struct("XmlBridgeSlot").finish()
48 }
49}
50
51type ConnectFuture = Pin<Box<dyn Future<Output = Result<Channel, BridgeError>> + Send>>;
52type ConnectFn = dyn Fn(u16) -> ConnectFuture + Send + Sync;
53
54#[async_trait]
55pub trait XsdBridge: Send + Sync {
56 async fn register(&self, xsd_bytes: Vec<u8>) -> Result<SchemaId, ValidatorError>;
57 async fn validate(&self, schema_id: &str, doc_bytes: Vec<u8>) -> Result<(), ValidatorError>;
58}
59
60#[async_trait]
61trait XsdBridgeRpc: Send + Sync {
62 async fn register_schema(
63 &self,
64 channel: Channel,
65 request: RegisterSchemaRequest,
66 ) -> Result<RegisterSchemaResponse, ValidatorError>;
67
68 async fn validate_with(
69 &self,
70 channel: Channel,
71 request: ValidateWithRequest,
72 ) -> Result<ValidateResponse, ValidatorError>;
73}
74
/// Production [`XsdBridgeRpc`] that issues real gRPC calls through the generated
/// `XsdValidatorClient`.
#[derive(Debug)]
struct GrpcXsdBridgeRpc;
77
78#[async_trait]
79impl XsdBridgeRpc for GrpcXsdBridgeRpc {
80 async fn register_schema(
81 &self,
82 channel: Channel,
83 request: RegisterSchemaRequest,
84 ) -> Result<RegisterSchemaResponse, ValidatorError> {
85 let mut client = proto::xsd_validator_client::XsdValidatorClient::new(channel);
86 let response = client.register_schema(request).await.map_err(|e| {
87 ValidatorError::transport_with_source("xml-bridge register_schema RPC failed", e)
88 })?;
89 Ok(response.into_inner())
90 }
91
92 async fn validate_with(
93 &self,
94 channel: Channel,
95 request: ValidateWithRequest,
96 ) -> Result<ValidateResponse, ValidatorError> {
97 let mut client = proto::xsd_validator_client::XsdValidatorClient::new(channel);
98 let response = client.validate_with(request).await.map_err(|e| {
99 ValidatorError::transport_with_source("xml-bridge validate_with RPC failed", e)
100 })?;
101 Ok(response.into_inner())
102 }
103}
104
105#[derive(Clone)]
106pub struct XsdBridgeBackend {
107 channel: Arc<RwLock<Option<Channel>>>,
108 schemas: Arc<DashMap<SchemaId, Vec<u8>>>,
109 schema_cache_max_entries: Arc<AtomicUsize>,
110 slot: Arc<XmlBridgeSlot>,
111 rpc: Arc<dyn XsdBridgeRpc>,
112 connect_fn: Arc<ConnectFn>,
113 start_lock: Arc<Mutex<()>>,
114 bridge_version: String,
115 bridge_cache_dir: std::path::PathBuf,
116 bridge_start_timeout_ms: u64,
117}
118
119impl std::fmt::Debug for XsdBridgeBackend {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 f.debug_struct("XsdBridgeBackend")
122 .field("bridge_version", &self.bridge_version)
123 .field("bridge_cache_dir", &self.bridge_cache_dir)
124 .finish()
125 }
126}
127
128impl XsdBridgeBackend {
129 pub fn new() -> Self {
130 let (state_tx, state_rx) = watch::channel(BridgeState::Stopped);
131 let slot = Arc::new(XmlBridgeSlot {
132 state_rx,
133 state_tx,
134 process: Arc::new(Mutex::new(None)),
135 });
136
137 Self {
138 channel: Arc::new(RwLock::new(None)),
139 schemas: Arc::new(DashMap::new()),
140 schema_cache_max_entries: Arc::new(AtomicUsize::new(
141 crate::config::DEFAULT_SCHEMA_CACHE_MAX_ENTRIES,
142 )),
143 slot,
144 rpc: Arc::new(GrpcXsdBridgeRpc),
145 connect_fn: Arc::new(|port| Box::pin(connect_channel(port))),
146 start_lock: Arc::new(Mutex::new(())),
147 bridge_version: crate::BRIDGE_VERSION.to_string(),
148 bridge_cache_dir: default_cache_dir_for_spec(&XML_BRIDGE),
149 bridge_start_timeout_ms: 30_000,
150 }
151 }
152
153 #[cfg(test)]
154 fn for_test(rpc: Arc<dyn XsdBridgeRpc>, connect_fn: Arc<ConnectFn>, channel: Channel) -> Self {
155 let (state_tx, state_rx) = watch::channel(BridgeState::Ready {
156 channel: channel.clone(),
157 });
158 let slot = Arc::new(XmlBridgeSlot {
159 state_rx,
160 state_tx,
161 process: Arc::new(Mutex::new(None)),
162 });
163 Self {
164 channel: Arc::new(RwLock::new(Some(channel))),
165 schemas: Arc::new(DashMap::new()),
166 schema_cache_max_entries: Arc::new(AtomicUsize::new(
167 crate::config::DEFAULT_SCHEMA_CACHE_MAX_ENTRIES,
168 )),
169 slot,
170 rpc,
171 connect_fn,
172 start_lock: Arc::new(Mutex::new(())),
173 bridge_version: crate::BRIDGE_VERSION.to_string(),
174 bridge_cache_dir: default_cache_dir_for_spec(&XML_BRIDGE),
175 bridge_start_timeout_ms: 30_000,
176 }
177 }
178
179 pub fn schema_id_for(xsd_bytes: &[u8]) -> SchemaId {
180 let mut hasher = Sha256::new();
181 hasher.update(xsd_bytes);
182 format!("xsd-{}", hex::encode(hasher.finalize()))
183 }
184
185 pub fn set_schema_cache_max_entries(&self, max_entries: usize) {
188 if self.schemas.len() > max_entries {
189 self.schemas.clear();
190 }
191 self.schema_cache_max_entries
192 .store(max_entries, Ordering::Relaxed);
193 }
194
195 async fn ensure_bridge_ready(&self) -> Result<Channel, ValidatorError> {
196 if let Some(ch) = self.channel.read().await.clone() {
197 return Ok(ch);
198 }
199
200 let _guard = self.start_lock.lock().await;
201 if let Some(ch) = self.channel.read().await.clone() {
202 return Ok(ch);
203 }
204
205 let _ = self.slot.state_tx.send(BridgeState::Starting);
206 let (process, channel, port) = self.start_bridge_process().await?;
207 {
208 let mut process_guard = self.slot.process.lock().await;
209 *process_guard = Some(process);
210 }
211 {
212 let mut ch_guard = self.channel.write().await;
213 *ch_guard = Some(channel.clone());
214 }
215 let _ = self.slot.state_tx.send(BridgeState::Ready {
216 channel: channel.clone(),
217 });
218 self.on_reconnect(port).map_err(|e| {
219 ValidatorError::transport_with_source("xml-bridge reconnect handler failed", e)
220 })?;
221
222 Ok(channel)
223 }
224
225 async fn restart_bridge(&self) -> Result<Channel, ValidatorError> {
226 let _guard = self.start_lock.lock().await;
227
228 let _ = self.slot.state_tx.send(BridgeState::Restarting {
229 attempt: 0,
230 next_at: Instant::now(),
231 });
232
233 let old_process = {
234 let mut process_guard = self.slot.process.lock().await;
235 process_guard.take()
236 };
237 if let Some(p) = old_process {
238 let _ = p.stop().await;
239 }
240
241 let (process, channel, port) = self.start_bridge_process().await?;
242 {
243 let mut process_guard = self.slot.process.lock().await;
244 *process_guard = Some(process);
245 }
246 {
247 let mut ch_guard = self.channel.write().await;
248 *ch_guard = Some(channel.clone());
249 }
250
251 self.on_reconnect(port).map_err(|e| {
252 ValidatorError::transport_with_source("xml-bridge reconnect handler failed", e)
253 })?;
254 let _ = self.slot.state_tx.send(BridgeState::Ready {
255 channel: channel.clone(),
256 });
257
258 Ok(channel)
259 }
260
261 async fn start_bridge_process(&self) -> Result<(BridgeProcess, Channel, u16), ValidatorError> {
262 let binary_path =
263 ensure_binary_for_spec(&XML_BRIDGE, &self.bridge_version, &self.bridge_cache_dir)
264 .await
265 .map_err(|e| {
266 ValidatorError::endpoint(format!("XML bridge binary unavailable: {e}"))
267 })?;
268
269 let process_config = BridgeProcessConfig::xml(binary_path, self.bridge_start_timeout_ms);
270 let process = BridgeProcess::start(&process_config)
271 .await
272 .map_err(|e| ValidatorError::endpoint(format!("XML bridge start failed: {e}")))?;
273 let port = process.grpc_port();
274 let channel = (self.connect_fn)(port).await.map_err(|e| {
275 ValidatorError::endpoint(format!("XML bridge channel connect failed: {e}"))
276 })?;
277
278 wait_for_health(&channel, Duration::from_secs(10), |ch| {
279 let mut client = proto::health_client::HealthClient::new(ch);
280 async move {
281 let resp = client.check(HealthCheckRequest {}).await?;
282 Ok(resp.into_inner().status == "SERVING")
283 }
284 })
285 .await
286 .map_err(|e| ValidatorError::endpoint(format!("XML bridge health check failed: {e}")))?;
287
288 Ok((process, channel, port))
289 }
290
291 fn is_transport_error(msg: &str) -> bool {
292 msg.contains(&Code::Unavailable.to_string())
293 || msg.contains(&Code::Unknown.to_string())
294 || msg.contains("transport")
295 }
296
297 pub async fn shutdown(&self) {
298 let mut guard = self.slot.process.lock().await;
299 if let Some(p) = guard.take()
300 && let Err(e) = p.stop().await
301 {
302 tracing::warn!("Failed to stop XSD bridge process: {}", e);
303 }
304 }
305
306 async fn register_with_channel(
307 &self,
308 channel: Channel,
309 schema_id: SchemaId,
310 xsd_bytes: Vec<u8>,
311 ) -> Result<SchemaId, ValidatorError> {
312 let response = self
313 .rpc
314 .register_schema(
315 channel,
316 RegisterSchemaRequest {
317 schema_id: schema_id.clone(),
318 schema: xsd_bytes.clone(),
319 },
320 )
321 .await?;
322
323 if let Some(err) = response.error {
324 return Err(ValidatorError::from_bridge_error(&err));
325 }
326
327 let max_entries = self.schema_cache_max_entries.load(Ordering::Relaxed);
331 if self.schemas.len() >= max_entries && !self.schemas.contains_key(&schema_id) {
332 self.schemas.clear();
333 }
334
335 self.schemas.insert(schema_id.clone(), xsd_bytes);
336 Ok(schema_id)
337 }
338}
339
340impl BridgeReconnectHandler for XsdBridgeBackend {
341 fn on_reconnect(&self, _port: u16) -> Result<(), BridgeError> {
342 let this = self.clone();
343 tokio::spawn(async move {
344 let Some(channel) = this.channel.read().await.clone() else {
345 return;
346 };
347
348 let schemas: Vec<(SchemaId, Vec<u8>)> = this
349 .schemas
350 .iter()
351 .map(|entry| (entry.key().clone(), entry.value().clone()))
352 .collect();
353
354 for (schema_id, schema_bytes) in schemas {
355 if let Err(e) = this
356 .rpc
357 .register_schema(
358 channel.clone(),
359 RegisterSchemaRequest {
360 schema_id: schema_id.clone(),
361 schema: schema_bytes,
362 },
363 )
364 .await
365 {
366 error!(schema_id = %schema_id, error = %e, "re-seed schema failed after reconnect");
367 }
368 }
369 });
370 Ok(())
371 }
372}
373
374#[async_trait]
375impl XsdBridge for XsdBridgeBackend {
376 async fn register(&self, xsd_bytes: Vec<u8>) -> Result<SchemaId, ValidatorError> {
377 let schema_id = Self::schema_id_for(&xsd_bytes);
378 if self.schemas.contains_key(&schema_id) {
379 return Ok(schema_id);
380 }
381
382 let channel = self.ensure_bridge_ready().await?;
383 match self
384 .register_with_channel(channel.clone(), schema_id.clone(), xsd_bytes.clone())
385 .await
386 {
387 Ok(id) => Ok(id),
388 Err(e) if Self::is_transport_error(&e.to_string()) => {
389 let restarted = self.restart_bridge().await?;
390 self.register_with_channel(restarted, schema_id, xsd_bytes)
391 .await
392 }
393 Err(e) => Err(e),
394 }
395 }
396
397 async fn validate(&self, schema_id: &str, doc_bytes: Vec<u8>) -> Result<(), ValidatorError> {
398 let channel = self.ensure_bridge_ready().await?;
399 let req = ValidateWithRequest {
400 schema_id: schema_id.to_string(),
401 document: doc_bytes.clone(),
402 };
403
404 let response = match self.rpc.validate_with(channel.clone(), req).await {
405 Ok(resp) => resp,
406 Err(e) if Self::is_transport_error(&e.to_string()) => {
407 let restarted = self.restart_bridge().await?;
408 self.rpc
409 .validate_with(
410 restarted,
411 ValidateWithRequest {
412 schema_id: schema_id.to_string(),
413 document: doc_bytes,
414 },
415 )
416 .await?
417 }
418 Err(e) => return Err(e),
419 };
420
421 if let Some(err) = response.error {
422 return Err(ValidatorError::from_bridge_error(&err));
423 }
424 if response.valid {
425 return Ok(());
426 }
427
428 let details = response
429 .errors
430 .iter()
431 .map(|e| format!("{}:{} {}", e.line, e.column, e.message))
432 .collect::<Vec<_>>()
433 .join("\n");
434 Err(ValidatorError::validation(format!(
435 "XSD validation failed:\n{details}"
436 )))
437 }
438}
439
440impl Default for XsdBridgeBackend {
441 fn default() -> Self {
442 Self::new()
443 }
444}
445
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tonic::transport::Endpoint;

    /// RPC double: counts `register_schema` calls and answers `validate_with` with a fixed
    /// verdict.
    #[derive(Debug)]
    struct MockRpc {
        register_calls: Arc<AtomicUsize>,
        validate_ok: bool,
    }

    #[async_trait]
    impl XsdBridgeRpc for MockRpc {
        async fn register_schema(
            &self,
            _channel: Channel,
            request: RegisterSchemaRequest,
        ) -> Result<RegisterSchemaResponse, ValidatorError> {
            self.register_calls.fetch_add(1, Ordering::SeqCst);
            Ok(RegisterSchemaResponse {
                schema_id: request.schema_id,
                error: None,
            })
        }

        async fn validate_with(
            &self,
            _channel: Channel,
            _request: ValidateWithRequest,
        ) -> Result<ValidateResponse, ValidatorError> {
            Ok(ValidateResponse {
                valid: self.validate_ok,
                errors: Vec::new(),
                error: None,
            })
        }
    }

    /// A lazily-connecting channel; the mock RPC never actually uses it.
    fn lazy_channel() -> Channel {
        Endpoint::from_static("http://127.0.0.1:65535").connect_lazy()
    }

    /// Builds a `Ready` backend over `MockRpc` and returns it with its register-call counter.
    fn mock_backend(validate_ok: bool) -> (XsdBridgeBackend, Arc<AtomicUsize>) {
        let calls = Arc::new(AtomicUsize::new(0));
        let rpc = Arc::new(MockRpc {
            register_calls: Arc::clone(&calls),
            validate_ok,
        });
        let connector =
            Arc::new(|_port: u16| Box::pin(async move { Ok(lazy_channel()) }) as ConnectFuture);
        (
            XsdBridgeBackend::for_test(rpc, connector, lazy_channel()),
            calls,
        )
    }

    /// Polls `calls` until it reaches `expected`; panics after one second.
    async fn wait_for_calls(calls: &AtomicUsize, expected: usize) {
        tokio::time::timeout(Duration::from_secs(1), async {
            while calls.load(Ordering::SeqCst) < expected {
                tokio::time::sleep(Duration::from_millis(1)).await;
            }
        })
        .await
        .expect("timed out waiting for register_schema calls");
    }

    #[tokio::test]
    async fn xsd_bridge_reconnect_reseeds() {
        let (backend, calls) = mock_backend(true);

        backend.register(b"<xsd:a/>".to_vec()).await.unwrap();
        backend.register(b"<xsd:b/>".to_vec()).await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 2);

        backend.on_reconnect(50051).unwrap();
        // The re-seed runs on a spawned task: poll with a bound instead of a fixed sleep.
        wait_for_calls(&calls, 4).await;
        assert_eq!(calls.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn registering_same_schema_twice_hits_bridge_once() {
        let (backend, calls) = mock_backend(true);

        let first = backend.register(b"<xsd:a/>".to_vec()).await.unwrap();
        let second = backend.register(b"<xsd:a/>".to_vec()).await.unwrap();

        assert_eq!(first, second);
        assert_eq!(first, XsdBridgeBackend::schema_id_for(b"<xsd:a/>"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn validate_reports_bridge_verdict() {
        let (ok_backend, _) = mock_backend(true);
        assert!(ok_backend.validate("xsd-any", b"<doc/>".to_vec()).await.is_ok());

        let (bad_backend, _) = mock_backend(false);
        assert!(bad_backend.validate("xsd-any", b"<doc/>".to_vec()).await.is_err());
    }

    #[tokio::test]
    async fn shrinking_cache_limit_clears_oversized_cache() {
        let (backend, _) = mock_backend(true);
        backend.register(b"<xsd:a/>".to_vec()).await.unwrap();
        backend.register(b"<xsd:b/>".to_vec()).await.unwrap();
        assert_eq!(backend.schemas.len(), 2);

        backend.set_schema_cache_max_entries(1);
        assert!(backend.schemas.is_empty());

        backend.register(b"<xsd:c/>".to_vec()).await.unwrap();
        backend.register(b"<xsd:d/>".to_vec()).await.unwrap();
        assert_eq!(backend.schemas.len(), 1);
    }
}