Skip to main content

camel_core/
datasource.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::RwLock;
4
5use camel_api::datasource::{DatasourceCatalog, DatasourceConfig, DatasourceHandle, PoolFactory};
6use camel_api::error::CamelError;
7use camel_api::health::{AsyncHealthCheck, CheckResult};
8use dashmap::DashMap;
9use tokio::sync::OnceCell;
10
11use crate::health_registry::HealthCheckRegistry;
12
/// Key of the pool cache: `(datasource name, factory name)`.
///
/// Because the factory name is part of the key, pools built by different
/// factories for the same datasource name occupy separate cache slots.
type CacheKey = (String, String);
14
15pub struct RuntimeDatasourceCatalog {
16    configs: HashMap<String, DatasourceConfig>,
17    factories: RwLock<HashMap<String, Arc<dyn PoolFactory>>>,
18    pools: DashMap<CacheKey, Arc<OnceCell<DatasourceHandle>>>,
19    health_registry: Option<Arc<HealthCheckRegistry>>,
20}
21
22impl RuntimeDatasourceCatalog {
23    pub fn new(configs: HashMap<String, DatasourceConfig>) -> Self {
24        Self {
25            configs,
26            factories: RwLock::new(HashMap::new()),
27            pools: DashMap::new(),
28            health_registry: None,
29        }
30    }
31
32    pub fn with_health_registry(mut self, registry: Arc<HealthCheckRegistry>) -> Self {
33        self.health_registry = Some(registry);
34        self
35    }
36
37    fn resolve_factory(
38        &self,
39        config: &DatasourceConfig,
40    ) -> Result<Arc<dyn PoolFactory>, CamelError> {
41        let factories = self.factories.read().expect("factory lock poisoned"); // allow-unwrap
42        if let Some(ref provider) = config.provider {
43            let factory = factories.get(provider).ok_or_else(|| {
44                CamelError::Config(format!("unknown datasource provider '{}'", provider))
45            })?;
46            return Ok(factory.clone());
47        }
48
49        let matches: Vec<_> = factories
50            .values()
51            .filter(|entry| entry.matches(config))
52            .collect();
53
54        match matches.len() {
55            0 => Err(CamelError::Config(format!(
56                "no matching factory for datasource url '{}'",
57                scheme_hint(&config.db_url)
58            ))),
59            1 => Ok(matches[0].clone()),
60            _ => {
61                let names: Vec<_> = matches.iter().map(|m| m.name()).collect();
62                Err(CamelError::Config(format!(
63                    "ambiguous datasource: {} factories match '{}'. Set explicit 'provider' field.",
64                    names.len(),
65                    scheme_hint(&config.db_url)
66                )))
67            }
68        }
69    }
70}
71
/// Extracts the scheme portion of a database URL for safe display in error
/// messages, without leaking credentials.
///
/// Returns `"<scheme>://..."`, dropping everything after the first `://`
/// (user, password, host, path, query). If the text before the first `://` is
/// not a valid RFC 3986 scheme (`ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )`),
/// for example the malformed `user:secret@host://db`, or if there is no `://`
/// at all, the whole URL is replaced by `[REDACTED]` instead.
fn scheme_hint(db_url: &str) -> String {
    // Validate before echoing: the prefix is only known to be credential-free
    // when it is a syntactically valid scheme.
    let is_scheme = |candidate: &str| {
        let mut chars = candidate.chars();
        matches!(chars.next(), Some(c) if c.is_ascii_alphabetic())
            && chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
    };

    match db_url.split_once("://") {
        Some((scheme, _)) if is_scheme(scheme) => format!("{}://...", scheme),
        _ => "[REDACTED]".to_string(),
    }
}
81
82impl DatasourceCatalog for RuntimeDatasourceCatalog {
83    fn get_config(&self, name: &str) -> Option<DatasourceConfig> {
84        self.configs.get(name).cloned()
85    }
86
87    fn get_pool<'a>(&'a self, name: &'a str) -> camel_api::datasource::GetPoolFuture<'a> {
88        Box::pin(async move {
89            let config = self.configs.get(name).ok_or_else(|| {
90                CamelError::Config(format!("datasource '{}' not found in catalog", name))
91            })?;
92
93            let factory = self.resolve_factory(config)?;
94            let cache_key: CacheKey = (name.to_string(), factory.name().to_string());
95
96            let cell = self
97                .pools
98                .entry(cache_key)
99                .or_insert_with(|| Arc::new(OnceCell::new()))
100                .clone();
101
102            let handle = cell
103                .get_or_try_init(|| async {
104                    let inner = factory.create(config).await?;
105                    let handle =
106                        DatasourceHandle::new(name.to_string(), factory.name().to_string(), inner);
107
108                    if let Some(ref registry) = self.health_registry {
109                        // Uses pseudo-route IDs "datasource:<name>" for health registration.
110                        // mark_route_started activates the check in check_all() reports.
111                        // This is intentional: datasources are global services, not per-route,
112                        // but the HealthCheckRegistry API is route-scoped.
113                        let factory_ref = factory.clone();
114                        let handle_for_check = handle.clone();
115                        let ds_name = name.to_string();
116                        registry.register_for_route(
117                            &format!("datasource:{}", ds_name),
118                            std::sync::Arc::new(DatasourceHealthCheck {
119                                check_name: format!("datasource:{}", ds_name),
120                                factory: factory_ref,
121                                handle: handle_for_check,
122                            }),
123                        );
124                        registry.mark_route_started(&format!("datasource:{}", ds_name));
125                    }
126
127                    Ok::<DatasourceHandle, CamelError>(handle)
128                })
129                .await?;
130            Ok(handle.clone())
131        })
132    }
133
134    fn register_factory(
135        &self,
136        kind: &str,
137        factory: Arc<dyn PoolFactory>,
138    ) -> Result<(), CamelError> {
139        let mut factories = self.factories.write().expect("factory lock poisoned"); // allow-unwrap
140        if factories.contains_key(kind) {
141            return Err(CamelError::Config(format!(
142                "factory '{}' already registered",
143                kind
144            )));
145        }
146        factories.insert(kind.to_string(), factory);
147        Ok(())
148    }
149}
150
151struct DatasourceHealthCheck {
152    check_name: String,
153    factory: Arc<dyn PoolFactory>,
154    handle: DatasourceHandle,
155}
156
157#[async_trait::async_trait]
158impl AsyncHealthCheck for DatasourceHealthCheck {
159    fn name(&self) -> &str {
160        &self.check_name
161    }
162
163    async fn check(&self) -> CheckResult {
164        let status = self.factory.check(&self.handle).await;
165        CheckResult {
166            name: self.check_name.clone(),
167            status,
168            message: None,
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use std::any::Any;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use camel_api::datasource::{CheckFuture, CreatePoolFuture};
    use camel_api::lifecycle::HealthStatus;

    /// Test factory that accepts the given URL schemes and counts `create` calls.
    struct MockFactory {
        name: &'static str,
        schemes: &'static [&'static str],
        create_count: Arc<AtomicUsize>,
    }

    impl PoolFactory for MockFactory {
        fn create<'a>(&'a self, _config: &'a DatasourceConfig) -> CreatePoolFuture<'a> {
            let count = self.create_count.clone();
            Box::pin(async move {
                count.fetch_add(1, Ordering::SeqCst);
                Ok(Arc::new("mock_pool") as Arc<dyn Any + Send + Sync>)
            })
        }

        fn check<'a>(&'a self, _handle: &'a DatasourceHandle) -> CheckFuture<'a> {
            Box::pin(async { HealthStatus::Healthy })
        }

        fn supported_schemes(&self) -> &[&str] {
            self.schemes
        }

        fn name(&self) -> &'static str {
            self.name
        }
    }

    /// Builds a mock factory together with a handle to its `create` counter.
    fn mock(
        name: &'static str,
        schemes: &'static [&'static str],
    ) -> (Arc<MockFactory>, Arc<AtomicUsize>) {
        let count = Arc::new(AtomicUsize::new(0));
        let factory = Arc::new(MockFactory {
            name,
            schemes,
            create_count: count.clone(),
        });
        (factory, count)
    }

    /// Config for `db_url` with every optional setting left unset.
    fn make_config(db_url: &str) -> DatasourceConfig {
        DatasourceConfig {
            db_url: db_url.to_string(),
            provider: None,
            max_connections: None,
            min_connections: None,
            idle_timeout_secs: None,
            max_lifetime_secs: None,
            ssl_mode: None,
            ssl_root_cert: None,
            ssl_cert: None,
            ssl_key: None,
        }
    }

    /// Catalog containing the single datasource `name`.
    fn catalog_with(name: &str, config: DatasourceConfig) -> RuntimeDatasourceCatalog {
        let mut configs = HashMap::new();
        configs.insert(name.to_string(), config);
        RuntimeDatasourceCatalog::new(configs)
    }

    #[tokio::test]
    async fn register_factory_and_get_pool() {
        let catalog = catalog_with("mydb", make_config("postgresql://localhost/mydb"));
        let (factory, _) = mock("pg", &["postgresql", "postgres"]);
        catalog.register_factory("postgresql", factory).unwrap();

        let handle = catalog.get_pool("mydb").await.unwrap();
        assert_eq!(handle.name, "mydb");
        assert_eq!(handle.provider, "pg");
    }

    #[tokio::test]
    async fn shared_pool_for_same_datasource() {
        let catalog = catalog_with("mydb", make_config("postgresql://localhost/mydb"));
        let (factory, count) = mock("pg", &["postgresql", "postgres"]);
        catalog.register_factory("postgresql", factory).unwrap();

        let h1 = catalog.get_pool("mydb").await.unwrap();
        let h2 = catalog.get_pool("mydb").await.unwrap();

        assert_eq!(h1.name, h2.name);
        assert_eq!(h1.provider, h2.provider);
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn unknown_datasource_returns_error() {
        let catalog = RuntimeDatasourceCatalog::new(HashMap::new());

        let result = catalog.get_pool("nonexistent").await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("not found"));
    }

    #[test]
    fn duplicate_factory_returns_error() {
        let catalog = RuntimeDatasourceCatalog::new(HashMap::new());
        let (factory, _) = mock("pg", &["postgresql"]);

        catalog.register_factory("pg", factory.clone()).unwrap();
        let result = catalog.register_factory("pg", factory);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("already registered"));
    }

    #[tokio::test]
    async fn no_matching_factory_returns_error() {
        let catalog = catalog_with("mydb", make_config("mongodb://localhost/mydb"));
        let (factory, _) = mock("pg", &["postgresql"]);
        catalog.register_factory("postgresql", factory).unwrap();

        let result = catalog.get_pool("mydb").await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("no matching factory"));
    }

    #[tokio::test]
    async fn explicit_provider_overrides_scheme() {
        let config = DatasourceConfig {
            provider: Some("mysql_factory".to_string()),
            ..make_config("postgresql://localhost/mydb")
        };
        let catalog = catalog_with("mydb", config);

        let (pg_factory, pg_count) = mock("pg", &["postgresql"]);
        let (mysql_factory, mysql_count) = mock("mysql_factory", &["mysql"]);
        catalog.register_factory("postgresql", pg_factory).unwrap();
        catalog
            .register_factory("mysql_factory", mysql_factory)
            .unwrap();

        let handle = catalog.get_pool("mydb").await.unwrap();
        assert_eq!(handle.provider, "mysql_factory");
        assert_eq!(pg_count.load(Ordering::SeqCst), 0);
        assert_eq!(mysql_count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn get_config_returns_clone() {
        let original = make_config("postgresql://localhost/mydb");
        let catalog = catalog_with("mydb", original.clone());

        let retrieved = catalog.get_config("mydb");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().db_url, original.db_url);
    }

    #[tokio::test]
    async fn get_pool_before_factory_registered_returns_clear_error() {
        let catalog = catalog_with("mydb", make_config("postgresql://localhost/mydb"));

        let result = catalog.get_pool("mydb").await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("no matching factory"));
    }

    #[tokio::test]
    async fn ambiguous_factory_returns_error() {
        let catalog = catalog_with("orders", make_config("postgres://localhost/test"));
        // Two distinct factories claiming the same scheme.
        let (first, _) = mock("mock1", &["postgres"]);
        let (second, _) = mock("mock2", &["postgres"]);
        catalog.register_factory("mock1", first).unwrap();
        catalog.register_factory("mock2", second).unwrap();

        let result = catalog.get_pool("orders").await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("ambiguous"),
            "expected ambiguous error, got: {}",
            msg
        );
    }

    #[tokio::test]
    async fn bad_downcast_returns_clear_error() {
        let catalog = catalog_with("mydb", make_config("postgresql://localhost/mydb"));
        let (factory, _) = mock("pg", &["postgresql"]);
        catalog.register_factory("postgresql", factory).unwrap();

        let handle = catalog.get_pool("mydb").await.unwrap();

        let result: Result<Arc<String>, CamelError> = handle.downcast();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("failed to downcast"));
        assert!(err.to_string().contains("mydb"));
        assert!(err.to_string().contains("pg"));
    }

    #[tokio::test]
    async fn health_check_registered_after_pool_creation() {
        let registry = Arc::new(HealthCheckRegistry::new(std::time::Duration::from_secs(5)));
        let catalog = catalog_with("orders", make_config("postgresql://localhost/orders"))
            .with_health_registry(registry.clone());
        let (factory, _) = mock("pg", &["postgresql", "postgres"]);
        catalog.register_factory("postgresql", factory).unwrap();

        // Deliberately no manual `mark_route_started` here: get_pool itself must
        // activate the check for it to show up in check_all().
        let _handle = catalog.get_pool("orders").await.unwrap();

        let report = registry.check_all().await;
        assert!(
            report
                .services
                .iter()
                .any(|s| s.name == "datasource:orders"),
            "expected datasource health check in report, got: {:?}",
            report.services
        );
    }
}