1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::RwLock;
4
5use camel_api::datasource::{DatasourceCatalog, DatasourceConfig, DatasourceHandle, PoolFactory};
6use camel_api::error::CamelError;
7use camel_api::health::{AsyncHealthCheck, CheckResult};
8use dashmap::DashMap;
9use tokio::sync::OnceCell;
10
11use crate::health_registry::HealthCheckRegistry;
12
/// Key of the pool cache: `(datasource name, factory name)`, in that order.
type CacheKey = (String, String);
14
/// Datasource catalog that resolves a pool factory for a named datasource and
/// caches the resulting pool handle so each datasource is created only once.
pub struct RuntimeDatasourceCatalog {
    /// Datasource configurations keyed by datasource name. Supplied to `new`;
    /// the methods in this file only read it.
    configs: HashMap<String, DatasourceConfig>,
    /// Registered pool factories keyed by the `kind` passed to
    /// `register_factory`; an explicit `provider` in a config is looked up here.
    factories: RwLock<HashMap<String, Arc<dyn PoolFactory>>>,
    /// One lazily initialised handle per `(datasource name, factory name)`.
    pools: DashMap<CacheKey, Arc<OnceCell<DatasourceHandle>>>,
    /// Optional registry; when present, `get_pool` registers a health check
    /// for every pool it creates.
    health_registry: Option<Arc<HealthCheckRegistry>>,
}
21
22impl RuntimeDatasourceCatalog {
23 pub fn new(configs: HashMap<String, DatasourceConfig>) -> Self {
24 Self {
25 configs,
26 factories: RwLock::new(HashMap::new()),
27 pools: DashMap::new(),
28 health_registry: None,
29 }
30 }
31
32 pub fn with_health_registry(mut self, registry: Arc<HealthCheckRegistry>) -> Self {
33 self.health_registry = Some(registry);
34 self
35 }
36
37 fn resolve_factory(
38 &self,
39 config: &DatasourceConfig,
40 ) -> Result<Arc<dyn PoolFactory>, CamelError> {
41 let factories = self.factories.read().expect("factory lock poisoned"); if let Some(ref provider) = config.provider {
43 let factory = factories.get(provider).ok_or_else(|| {
44 CamelError::Config(format!("unknown datasource provider '{}'", provider))
45 })?;
46 return Ok(factory.clone());
47 }
48
49 let matches: Vec<_> = factories
50 .values()
51 .filter(|entry| entry.matches(config))
52 .collect();
53
54 match matches.len() {
55 0 => Err(CamelError::Config(format!(
56 "no matching factory for datasource url '{}'",
57 scheme_hint(&config.db_url)
58 ))),
59 1 => Ok(matches[0].clone()),
60 _ => {
61 let names: Vec<_> = matches.iter().map(|m| m.name()).collect();
62 Err(CamelError::Config(format!(
63 "ambiguous datasource: {} factories match '{}'. Set explicit 'provider' field.",
64 names.len(),
65 scheme_hint(&config.db_url)
66 )))
67 }
68 }
69 }
70}
71
/// Produces a log-safe hint for a datasource URL that reveals only its scheme.
///
/// Returns `"<scheme>://..."` when the text before the first `"://"` is a
/// well-formed URI scheme: an ASCII letter followed by ASCII letters, digits,
/// `+`, `-` or `.`. Every other input yields `"[REDACTED]"`.
///
/// The scheme is validated because a URL written without one can still contain
/// `"://"` further along (for example inside a query parameter). Echoing the
/// prefix in that case would put the user name, password and host into an
/// error message.
fn scheme_hint(db_url: &str) -> String {
    match db_url.split_once("://") {
        Some((scheme, _)) if is_uri_scheme(scheme) => format!("{}://...", scheme),
        _ => "[REDACTED]".to_string(),
    }
}

/// Reports whether `candidate` is a syntactically valid, non-empty URI scheme.
fn is_uri_scheme(candidate: &str) -> bool {
    let mut chars = candidate.chars();
    match chars.next() {
        Some(first) if first.is_ascii_alphabetic() => {
            chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
        }
        _ => false,
    }
}
81
82impl DatasourceCatalog for RuntimeDatasourceCatalog {
83 fn get_config(&self, name: &str) -> Option<DatasourceConfig> {
84 self.configs.get(name).cloned()
85 }
86
87 fn get_pool<'a>(&'a self, name: &'a str) -> camel_api::datasource::GetPoolFuture<'a> {
88 Box::pin(async move {
89 let config = self.configs.get(name).ok_or_else(|| {
90 CamelError::Config(format!("datasource '{}' not found in catalog", name))
91 })?;
92
93 let factory = self.resolve_factory(config)?;
94 let cache_key: CacheKey = (name.to_string(), factory.name().to_string());
95
96 let cell = self
97 .pools
98 .entry(cache_key)
99 .or_insert_with(|| Arc::new(OnceCell::new()))
100 .clone();
101
102 let handle = cell
103 .get_or_try_init(|| async {
104 let inner = factory.create(config).await?;
105 let handle =
106 DatasourceHandle::new(name.to_string(), factory.name().to_string(), inner);
107
108 if let Some(ref registry) = self.health_registry {
109 let factory_ref = factory.clone();
114 let handle_for_check = handle.clone();
115 let ds_name = name.to_string();
116 registry.register_for_route(
117 &format!("datasource:{}", ds_name),
118 std::sync::Arc::new(DatasourceHealthCheck {
119 check_name: format!("datasource:{}", ds_name),
120 factory: factory_ref,
121 handle: handle_for_check,
122 }),
123 );
124 registry.mark_route_started(&format!("datasource:{}", ds_name));
125 }
126
127 Ok::<DatasourceHandle, CamelError>(handle)
128 })
129 .await?;
130 Ok(handle.clone())
131 })
132 }
133
134 fn register_factory(
135 &self,
136 kind: &str,
137 factory: Arc<dyn PoolFactory>,
138 ) -> Result<(), CamelError> {
139 let mut factories = self.factories.write().expect("factory lock poisoned"); if factories.contains_key(kind) {
141 return Err(CamelError::Config(format!(
142 "factory '{}' already registered",
143 kind
144 )));
145 }
146 factories.insert(kind.to_string(), factory);
147 Ok(())
148 }
149}
150
/// Health check for one created datasource pool; it delegates to the factory
/// that built the pool.
struct DatasourceHealthCheck {
    /// Name reported by the check; `get_pool` sets it to `datasource:<name>`.
    check_name: String,
    /// Factory whose `check` method is called for the handle.
    factory: Arc<dyn PoolFactory>,
    /// Handle of the pool being checked.
    handle: DatasourceHandle,
}
156
157#[async_trait::async_trait]
158impl AsyncHealthCheck for DatasourceHealthCheck {
159 fn name(&self) -> &str {
160 &self.check_name
161 }
162
163 async fn check(&self) -> CheckResult {
164 let status = self.factory.check(&self.handle).await;
165 CheckResult {
166 name: self.check_name.clone(),
167 status,
168 message: None,
169 }
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use std::any::Any;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use camel_api::datasource::{CheckFuture, CreatePoolFuture};
    use camel_api::lifecycle::HealthStatus;

    /// Pool factory double: counts `create` calls and always reports healthy.
    struct MockFactory {
        name: &'static str,
        schemes: &'static [&'static str],
        create_count: Arc<AtomicUsize>,
    }

    impl PoolFactory for MockFactory {
        fn create<'a>(&'a self, _config: &'a DatasourceConfig) -> CreatePoolFuture<'a> {
            let calls = Arc::clone(&self.create_count);
            Box::pin(async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok(Arc::new("mock_pool") as Arc<dyn Any + Send + Sync>)
            })
        }

        fn check<'a>(&'a self, _handle: &'a DatasourceHandle) -> CheckFuture<'a> {
            Box::pin(async { HealthStatus::Healthy })
        }

        fn supported_schemes(&self) -> &[&str] {
            self.schemes
        }

        fn name(&self) -> &'static str {
            self.name
        }
    }

    /// Builds a config for `db_url` with every optional setting left unset.
    fn make_config(db_url: &str) -> DatasourceConfig {
        DatasourceConfig {
            db_url: db_url.to_string(),
            provider: None,
            max_connections: None,
            min_connections: None,
            idle_timeout_secs: None,
            max_lifetime_secs: None,
            ssl_mode: None,
            ssl_root_cert: None,
            ssl_cert: None,
            ssl_key: None,
            extra: HashMap::new(),
        }
    }

    /// Fresh zeroed call counter for a `MockFactory`.
    fn new_counter() -> Arc<AtomicUsize> {
        Arc::new(AtomicUsize::new(0))
    }

    /// Builds a `MockFactory` already erased to the trait object the catalog stores.
    fn mock_factory(
        name: &'static str,
        schemes: &'static [&'static str],
        create_count: Arc<AtomicUsize>,
    ) -> Arc<dyn PoolFactory> {
        Arc::new(MockFactory {
            name,
            schemes,
            create_count,
        })
    }

    /// Catalog holding exactly one datasource, `name`, described by `config`.
    fn catalog_with(name: &str, config: DatasourceConfig) -> RuntimeDatasourceCatalog {
        let mut configs = HashMap::new();
        configs.insert(name.to_string(), config);
        RuntimeDatasourceCatalog::new(configs)
    }

    /// Catalog with no datasources at all.
    fn empty_catalog() -> RuntimeDatasourceCatalog {
        RuntimeDatasourceCatalog::new(HashMap::new())
    }

    #[tokio::test]
    async fn register_factory_and_get_pool() {
        let catalog = catalog_with("mydb", make_config("postgresql://localhost/mydb"));
        catalog
            .register_factory(
                "postgresql",
                mock_factory("pg", &["postgresql", "postgres"], new_counter()),
            )
            .unwrap();

        let handle = catalog.get_pool("mydb").await.unwrap();

        assert_eq!(handle.name, "mydb");
        assert_eq!(handle.provider, "pg");
    }

    #[tokio::test]
    async fn shared_pool_for_same_datasource() {
        let creations = new_counter();
        let catalog = catalog_with("mydb", make_config("postgresql://localhost/mydb"));
        catalog
            .register_factory(
                "postgresql",
                mock_factory("pg", &["postgresql", "postgres"], creations.clone()),
            )
            .unwrap();

        let first = catalog.get_pool("mydb").await.unwrap();
        let second = catalog.get_pool("mydb").await.unwrap();

        assert_eq!(first.name, second.name);
        assert_eq!(first.provider, second.provider);
        // The second lookup must be served from the cache.
        assert_eq!(creations.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn unknown_datasource_returns_error() {
        let catalog = empty_catalog();

        let message = catalog
            .get_pool("nonexistent")
            .await
            .expect_err("lookup of an unknown datasource must fail")
            .to_string();

        assert!(message.contains("not found"));
    }

    #[tokio::test]
    async fn duplicate_factory_returns_error() {
        let catalog = empty_catalog();
        let factory = mock_factory("pg", &["postgresql"], new_counter());

        catalog.register_factory("pg", factory.clone()).unwrap();
        let message = catalog
            .register_factory("pg", factory)
            .expect_err("second registration under the same key must fail")
            .to_string();

        assert!(message.contains("already registered"));
    }

    #[tokio::test]
    async fn no_matching_factory_returns_error() {
        let catalog = catalog_with("mydb", make_config("mongodb://localhost/mydb"));
        catalog
            .register_factory(
                "postgresql",
                mock_factory("pg", &["postgresql"], new_counter()),
            )
            .unwrap();

        let message = catalog
            .get_pool("mydb")
            .await
            .expect_err("no registered factory handles this url")
            .to_string();

        assert!(message.contains("no matching factory"));
    }

    #[tokio::test]
    async fn explicit_provider_overrides_scheme() {
        let config = DatasourceConfig {
            provider: Some("mysql_factory".to_string()),
            ..make_config("postgresql://localhost/mydb")
        };
        let catalog = catalog_with("mydb", config);

        let pg_creations = new_counter();
        let mysql_creations = new_counter();
        catalog
            .register_factory(
                "postgresql",
                mock_factory("pg", &["postgresql"], pg_creations.clone()),
            )
            .unwrap();
        catalog
            .register_factory(
                "mysql_factory",
                mock_factory("mysql_factory", &["mysql"], mysql_creations.clone()),
            )
            .unwrap();

        let handle = catalog.get_pool("mydb").await.unwrap();

        assert_eq!(handle.provider, "mysql_factory");
        assert_eq!(pg_creations.load(Ordering::SeqCst), 0);
        assert_eq!(mysql_creations.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn get_config_returns_clone() {
        let original = make_config("postgresql://localhost/mydb");
        let catalog = catalog_with("mydb", original.clone());

        let retrieved = catalog
            .get_config("mydb")
            .expect("config for a known datasource must be returned");

        assert_eq!(retrieved.db_url, original.db_url);
    }

    #[tokio::test]
    async fn get_pool_before_factory_registered_returns_clear_error() {
        let catalog = catalog_with("mydb", make_config("postgresql://localhost/mydb"));

        let message = catalog
            .get_pool("mydb")
            .await
            .expect_err("no factory has been registered yet")
            .to_string();

        assert!(message.contains("no matching factory"));
    }

    #[tokio::test]
    async fn ambiguous_factory_returns_error() {
        /// Second factory claiming the same scheme as the `mock1` factory.
        struct MockFactory2;

        impl PoolFactory for MockFactory2 {
            fn create<'a>(&'a self, config: &'a DatasourceConfig) -> CreatePoolFuture<'a> {
                Box::pin(async move {
                    Ok(Arc::new(config.db_url.clone()) as Arc<dyn Any + Send + Sync>)
                })
            }

            fn check<'a>(&'a self, _handle: &'a DatasourceHandle) -> CheckFuture<'a> {
                Box::pin(async { HealthStatus::Healthy })
            }

            fn supported_schemes(&self) -> &[&str] {
                &["postgres"]
            }

            fn name(&self) -> &'static str {
                "mock2"
            }
        }

        let catalog = catalog_with("orders", make_config("postgres://localhost/test"));
        catalog
            .register_factory("mock1", mock_factory("mock1", &["postgres"], new_counter()))
            .unwrap();
        catalog
            .register_factory("mock2", Arc::new(MockFactory2))
            .unwrap();

        let msg = catalog
            .get_pool("orders")
            .await
            .expect_err("two factories match, so resolution must fail")
            .to_string();

        assert!(
            msg.contains("ambiguous"),
            "expected ambiguous error, got: {}",
            msg
        );
    }

    #[tokio::test]
    async fn bad_downcast_returns_clear_error() {
        let catalog = catalog_with("mydb", make_config("postgresql://localhost/mydb"));
        catalog
            .register_factory(
                "postgresql",
                mock_factory("pg", &["postgresql"], new_counter()),
            )
            .unwrap();

        let handle = catalog.get_pool("mydb").await.unwrap();

        // The mock pool is a `&'static str`, so asking for a `String` must fail.
        let result: Result<Arc<String>, CamelError> = handle.downcast();
        let message = result
            .expect_err("downcast to the wrong type must fail")
            .to_string();

        assert!(message.contains("failed to downcast"));
        assert!(message.contains("mydb"));
        assert!(message.contains("pg"));
    }

    #[tokio::test]
    async fn health_check_registered_after_pool_creation() {
        let registry = Arc::new(HealthCheckRegistry::new(std::time::Duration::from_secs(5)));
        let catalog = catalog_with("orders", make_config("postgresql://localhost/orders"))
            .with_health_registry(registry.clone());
        catalog
            .register_factory(
                "postgresql",
                mock_factory("pg", &["postgresql", "postgres"], new_counter()),
            )
            .unwrap();

        let _ = catalog.get_pool("orders").await.unwrap();
        registry.mark_route_started("datasource:orders");

        let report = registry.check_all().await;
        let has_datasource_check = report
            .services
            .iter()
            .any(|s| s.name.starts_with("datasource:"));

        assert!(
            has_datasource_check,
            "expected datasource health check in report, got: {:?}",
            report.services
        );
    }
}