Skip to main content

camel_core/
datasource.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::RwLock;
4
5use camel_api::datasource::{DatasourceCatalog, DatasourceConfig, DatasourceHandle, PoolFactory};
6use camel_api::error::CamelError;
7use camel_api::health::{AsyncHealthCheck, CheckResult};
8use dashmap::DashMap;
9use tokio::sync::OnceCell;
10
11use crate::health_registry::HealthCheckRegistry;
12
/// Key of the pool cache: `(datasource name, factory name)`.
type CacheKey = (String, String);
14
15pub struct RuntimeDatasourceCatalog {
16    configs: HashMap<String, DatasourceConfig>,
17    factories: RwLock<HashMap<String, Arc<dyn PoolFactory>>>,
18    pools: DashMap<CacheKey, Arc<OnceCell<DatasourceHandle>>>,
19    health_registry: Option<Arc<HealthCheckRegistry>>,
20}
21
22impl RuntimeDatasourceCatalog {
23    pub fn new(configs: HashMap<String, DatasourceConfig>) -> Self {
24        Self {
25            configs,
26            factories: RwLock::new(HashMap::new()),
27            pools: DashMap::new(),
28            health_registry: None,
29        }
30    }
31
32    pub fn with_health_registry(mut self, registry: Arc<HealthCheckRegistry>) -> Self {
33        self.health_registry = Some(registry);
34        self
35    }
36
37    fn resolve_factory(
38        &self,
39        config: &DatasourceConfig,
40    ) -> Result<Arc<dyn PoolFactory>, CamelError> {
41        let factories = self.factories.read().expect("factory lock poisoned"); // allow-unwrap
42        if let Some(ref provider) = config.provider {
43            let factory = factories.get(provider).ok_or_else(|| {
44                CamelError::Config(format!("unknown datasource provider '{}'", provider))
45            })?;
46            return Ok(factory.clone());
47        }
48
49        let matches: Vec<_> = factories
50            .values()
51            .filter(|entry| entry.matches(config))
52            .collect();
53
54        match matches.len() {
55            0 => Err(CamelError::Config(format!(
56                "no matching factory for datasource url '{}'",
57                scheme_hint(&config.db_url)
58            ))),
59            1 => Ok(matches[0].clone()),
60            _ => {
61                let names: Vec<_> = matches.iter().map(|m| m.name()).collect();
62                Err(CamelError::Config(format!(
63                    "ambiguous datasource: {} factories match '{}'. Set explicit 'provider' field.",
64                    names.len(),
65                    scheme_hint(&config.db_url)
66                )))
67            }
68        }
69    }
70}
71
/// Render the scheme of a database URL for messages without leaking credentials.
///
/// Returns `"<scheme>://..."` only when `db_url` starts with a well-formed
/// RFC 3986 scheme immediately followed by `://`. A well-formed scheme is an
/// ASCII letter followed by letters, digits, `+`, `-` or `.`.
///
/// Anything else yields `"[REDACTED]"`. This covers an empty scheme, or a
/// key/value connection string that merely embeds a URL later on; the text
/// before its first `://` could itself contain secrets.
fn scheme_hint(db_url: &str) -> String {
    let is_scheme = |candidate: &str| {
        let mut chars = candidate.chars();
        matches!(chars.next(), Some(first) if first.is_ascii_alphabetic())
            && chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
    };

    match db_url.split_once("://") {
        Some((scheme, _)) if is_scheme(scheme) => format!("{}://...", scheme),
        _ => "[REDACTED]".to_string(),
    }
}
81
82impl DatasourceCatalog for RuntimeDatasourceCatalog {
83    fn get_config(&self, name: &str) -> Option<DatasourceConfig> {
84        self.configs.get(name).cloned()
85    }
86
87    fn get_pool<'a>(&'a self, name: &'a str) -> camel_api::datasource::GetPoolFuture<'a> {
88        Box::pin(async move {
89            let config = self.configs.get(name).ok_or_else(|| {
90                CamelError::Config(format!("datasource '{}' not found in catalog", name))
91            })?;
92
93            let factory = self.resolve_factory(config)?;
94            let cache_key: CacheKey = (name.to_string(), factory.name().to_string());
95
96            let cell = self
97                .pools
98                .entry(cache_key)
99                .or_insert_with(|| Arc::new(OnceCell::new()))
100                .clone();
101
102            let handle = cell
103                .get_or_try_init(|| async {
104                    let inner = factory.create(config).await?;
105                    let handle =
106                        DatasourceHandle::new(name.to_string(), factory.name().to_string(), inner);
107
108                    if let Some(ref registry) = self.health_registry {
109                        // Uses pseudo-route IDs "datasource:<name>" for health registration.
110                        // mark_route_started activates the check in check_all() reports.
111                        // This is intentional: datasources are global services, not per-route,
112                        // but the HealthCheckRegistry API is route-scoped.
113                        let factory_ref = factory.clone();
114                        let handle_for_check = handle.clone();
115                        let ds_name = name.to_string();
116                        registry.register_for_route(
117                            &format!("datasource:{}", ds_name),
118                            std::sync::Arc::new(DatasourceHealthCheck {
119                                check_name: format!("datasource:{}", ds_name),
120                                factory: factory_ref,
121                                handle: handle_for_check,
122                            }),
123                        );
124                        registry.mark_route_started(&format!("datasource:{}", ds_name));
125                    }
126
127                    Ok::<DatasourceHandle, CamelError>(handle)
128                })
129                .await?;
130            Ok(handle.clone())
131        })
132    }
133
134    fn register_factory(
135        &self,
136        kind: &str,
137        factory: Arc<dyn PoolFactory>,
138    ) -> Result<(), CamelError> {
139        let mut factories = self.factories.write().expect("factory lock poisoned"); // allow-unwrap
140        if factories.contains_key(kind) {
141            return Err(CamelError::Config(format!(
142                "factory '{}' already registered",
143                kind
144            )));
145        }
146        factories.insert(kind.to_string(), factory);
147        Ok(())
148    }
149}
150
151struct DatasourceHealthCheck {
152    check_name: String,
153    factory: Arc<dyn PoolFactory>,
154    handle: DatasourceHandle,
155}
156
157#[async_trait::async_trait]
158impl AsyncHealthCheck for DatasourceHealthCheck {
159    fn name(&self) -> &str {
160        &self.check_name
161    }
162
163    async fn check(&self) -> CheckResult {
164        let status = self.factory.check(&self.handle).await;
165        CheckResult {
166            name: self.check_name.clone(),
167            status,
168            message: None,
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use std::any::Any;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use camel_api::datasource::{CheckFuture, CreatePoolFuture};
    use camel_api::lifecycle::HealthStatus;

    /// Factory that counts `create` calls and always reports healthy.
    struct MockFactory {
        name: &'static str,
        schemes: &'static [&'static str],
        create_count: Arc<AtomicUsize>,
    }

    impl PoolFactory for MockFactory {
        fn create<'a>(&'a self, _config: &'a DatasourceConfig) -> CreatePoolFuture<'a> {
            let count = self.create_count.clone();
            Box::pin(async move {
                count.fetch_add(1, Ordering::SeqCst);
                Ok(Arc::new("mock_pool") as Arc<dyn Any + Send + Sync>)
            })
        }

        fn check<'a>(&'a self, _handle: &'a DatasourceHandle) -> CheckFuture<'a> {
            Box::pin(async { HealthStatus::Healthy })
        }

        fn supported_schemes(&self) -> &[&str] {
            self.schemes
        }

        fn name(&self) -> &'static str {
            self.name
        }
    }

    /// Config with only `db_url` set; every optional field is `None`/empty.
    fn make_config(db_url: &str) -> DatasourceConfig {
        DatasourceConfig {
            db_url: db_url.to_string(),
            provider: None,
            max_connections: None,
            min_connections: None,
            idle_timeout_secs: None,
            max_lifetime_secs: None,
            ssl_mode: None,
            ssl_root_cert: None,
            ssl_cert: None,
            ssl_key: None,
            extra: std::collections::HashMap::new(),
        }
    }

    #[tokio::test]
    async fn register_factory_and_get_pool() {
        let mut configs = HashMap::new();
        configs.insert(
            "mydb".to_string(),
            make_config("postgresql://localhost/mydb"),
        );

        let catalog = RuntimeDatasourceCatalog::new(configs);
        let factory = Arc::new(MockFactory {
            name: "pg",
            schemes: &["postgresql", "postgres"],
            create_count: Arc::new(AtomicUsize::new(0)),
        });
        catalog.register_factory("postgresql", factory).unwrap();

        let handle = catalog.get_pool("mydb").await.unwrap();
        assert_eq!(handle.name, "mydb");
        assert_eq!(handle.provider, "pg");
    }

    #[tokio::test]
    async fn shared_pool_for_same_datasource() {
        let mut configs = HashMap::new();
        configs.insert(
            "mydb".to_string(),
            make_config("postgresql://localhost/mydb"),
        );

        let count = Arc::new(AtomicUsize::new(0));
        let catalog = RuntimeDatasourceCatalog::new(configs);
        let factory = Arc::new(MockFactory {
            name: "pg",
            schemes: &["postgresql", "postgres"],
            create_count: count.clone(),
        });
        catalog.register_factory("postgresql", factory).unwrap();

        let h1 = catalog.get_pool("mydb").await.unwrap();
        let h2 = catalog.get_pool("mydb").await.unwrap();

        assert_eq!(h1.name, h2.name);
        assert_eq!(h1.provider, h2.provider);
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn unknown_datasource_returns_error() {
        let configs = HashMap::new();
        let catalog = RuntimeDatasourceCatalog::new(configs);

        let result = catalog.get_pool("nonexistent").await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("not found"));
    }

    #[tokio::test]
    async fn duplicate_factory_returns_error() {
        let configs = HashMap::new();
        let catalog = RuntimeDatasourceCatalog::new(configs);
        let factory = Arc::new(MockFactory {
            name: "pg",
            schemes: &["postgresql"],
            create_count: Arc::new(AtomicUsize::new(0)),
        });

        catalog.register_factory("pg", factory.clone()).unwrap();
        let result = catalog.register_factory("pg", factory);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("already registered"));
    }

    #[tokio::test]
    async fn no_matching_factory_returns_error() {
        let mut configs = HashMap::new();
        configs.insert("mydb".to_string(), make_config("mongodb://localhost/mydb"));

        let catalog = RuntimeDatasourceCatalog::new(configs);
        let factory = Arc::new(MockFactory {
            name: "pg",
            schemes: &["postgresql"],
            create_count: Arc::new(AtomicUsize::new(0)),
        });
        catalog.register_factory("postgresql", factory).unwrap();

        let result = catalog.get_pool("mydb").await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("no matching factory"));
    }

    #[tokio::test]
    async fn error_messages_do_not_leak_credentials() {
        let mut configs = HashMap::new();
        configs.insert(
            "mydb".to_string(),
            make_config("mongodb://admin:s3cret@localhost/mydb"),
        );
        let catalog = RuntimeDatasourceCatalog::new(configs);

        let msg = catalog.get_pool("mydb").await.unwrap_err().to_string();
        assert!(msg.contains("mongodb://..."), "unexpected error: {}", msg);
        assert!(!msg.contains("s3cret"), "credentials leaked: {}", msg);
        assert!(!msg.contains("localhost"), "host leaked: {}", msg);
    }

    #[test]
    fn scheme_hint_shows_only_scheme() {
        assert_eq!(
            scheme_hint("postgresql://admin:s3cret@localhost/mydb"),
            "postgresql://..."
        );
        assert_eq!(scheme_hint("no scheme separator"), "[REDACTED]");
    }

    #[tokio::test]
    async fn explicit_provider_overrides_scheme() {
        let mut configs = HashMap::new();
        configs.insert(
            "mydb".to_string(),
            DatasourceConfig {
                provider: Some("mysql_factory".to_string()),
                ..make_config("postgresql://localhost/mydb")
            },
        );

        let pg_count = Arc::new(AtomicUsize::new(0));
        let mysql_count = Arc::new(AtomicUsize::new(0));

        let catalog = RuntimeDatasourceCatalog::new(configs);
        let pg_factory = Arc::new(MockFactory {
            name: "pg",
            schemes: &["postgresql"],
            create_count: pg_count.clone(),
        });
        let mysql_factory = Arc::new(MockFactory {
            name: "mysql_factory",
            schemes: &["mysql"],
            create_count: mysql_count.clone(),
        });

        catalog.register_factory("postgresql", pg_factory).unwrap();
        catalog
            .register_factory("mysql_factory", mysql_factory)
            .unwrap();

        let handle = catalog.get_pool("mydb").await.unwrap();
        assert_eq!(handle.provider, "mysql_factory");
        assert_eq!(pg_count.load(Ordering::SeqCst), 0);
        assert_eq!(mysql_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn get_config_returns_clone() {
        let mut configs = HashMap::new();
        let original = make_config("postgresql://localhost/mydb");
        configs.insert("mydb".to_string(), original.clone());

        let catalog = RuntimeDatasourceCatalog::new(configs);
        let retrieved = catalog.get_config("mydb");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().db_url, original.db_url);
    }

    #[tokio::test]
    async fn get_pool_before_factory_registered_returns_clear_error() {
        let mut configs = HashMap::new();
        configs.insert(
            "mydb".to_string(),
            make_config("postgresql://localhost/mydb"),
        );

        let catalog = RuntimeDatasourceCatalog::new(configs);

        let result = catalog.get_pool("mydb").await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("no matching factory"));
    }

    #[tokio::test]
    async fn ambiguous_factory_returns_error() {
        let mut configs = HashMap::new();
        configs.insert("orders".into(), make_config("postgres://localhost/test"));
        let catalog = RuntimeDatasourceCatalog::new(configs);
        catalog
            .register_factory(
                "mock1",
                Arc::new(MockFactory {
                    name: "mock1",
                    schemes: &["postgres"],
                    create_count: Arc::new(AtomicUsize::new(0)),
                }),
            )
            .unwrap();

        struct MockFactory2;
        impl PoolFactory for MockFactory2 {
            fn create<'a>(&'a self, config: &'a DatasourceConfig) -> CreatePoolFuture<'a> {
                Box::pin(async move {
                    Ok(Arc::new(config.db_url.clone()) as Arc<dyn Any + Send + Sync>)
                })
            }
            fn check<'a>(&'a self, _handle: &'a DatasourceHandle) -> CheckFuture<'a> {
                Box::pin(async { HealthStatus::Healthy })
            }
            fn supported_schemes(&self) -> &[&str] {
                &["postgres"]
            }
            fn name(&self) -> &'static str {
                "mock2"
            }
        }
        catalog
            .register_factory("mock2", Arc::new(MockFactory2))
            .unwrap();

        let result = catalog.get_pool("orders").await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("ambiguous"),
            "expected ambiguous error, got: {}",
            msg
        );
        assert!(!msg.contains("localhost"), "url leaked: {}", msg);
    }

    #[tokio::test]
    async fn bad_downcast_returns_clear_error() {
        let mut configs = HashMap::new();
        configs.insert(
            "mydb".to_string(),
            make_config("postgresql://localhost/mydb"),
        );

        let catalog = RuntimeDatasourceCatalog::new(configs);
        let factory = Arc::new(MockFactory {
            name: "pg",
            schemes: &["postgresql"],
            create_count: Arc::new(AtomicUsize::new(0)),
        });
        catalog.register_factory("postgresql", factory).unwrap();

        let handle = catalog.get_pool("mydb").await.unwrap();

        let result: Result<Arc<String>, CamelError> = handle.downcast();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("failed to downcast"));
        assert!(err.to_string().contains("mydb"));
        assert!(err.to_string().contains("pg"));
    }

    #[tokio::test]
    async fn health_check_registered_after_pool_creation() {
        let mut configs = HashMap::new();
        configs.insert(
            "orders".to_string(),
            make_config("postgresql://localhost/orders"),
        );

        let registry = Arc::new(HealthCheckRegistry::new(std::time::Duration::from_secs(5)));
        let catalog = RuntimeDatasourceCatalog::new(configs).with_health_registry(registry.clone());
        catalog
            .register_factory(
                "postgresql",
                Arc::new(MockFactory {
                    name: "pg",
                    schemes: &["postgresql", "postgres"],
                    create_count: Arc::new(AtomicUsize::new(0)),
                }),
            )
            .unwrap();

        // Intentionally do NOT call `registry.mark_route_started` here:
        // get_pool itself must register and activate the check. Calling it from
        // the test would hide a regression in that activation.
        let _ = catalog.get_pool("orders").await.unwrap();

        let report = registry.check_all().await;
        assert!(
            report
                .services
                .iter()
                .any(|s| s.name.starts_with("datasource:")),
            "expected datasource health check in report, got: {:?}",
            report.services
        );
    }
}