Skip to main content

camel_component_sql/
health.rs

1use async_trait::async_trait;
2use camel_api::{AsyncHealthCheck, CheckResult};
3use camel_component_api::CamelError;
4use sqlx::AnyPool;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::OnceCell;
10
11type ProbeFuture = Pin<Box<dyn Future<Output = Result<(), CamelError>> + Send>>;
12
13trait SqlHealthProbe: Send + Sync {
14    fn probe(&self) -> ProbeFuture;
15}
16
17struct SqlPoolProbe {
18    pool: Arc<OnceCell<Arc<AnyPool>>>,
19}
20
21impl SqlPoolProbe {
22    fn new(pool: Arc<OnceCell<Arc<AnyPool>>>) -> Self {
23        Self { pool }
24    }
25}
26
27impl SqlHealthProbe for SqlPoolProbe {
28    fn probe(&self) -> ProbeFuture {
29        let pool = Arc::clone(&self.pool);
30        Box::pin(async move {
31            let pool = pool.get().ok_or_else(|| {
32                CamelError::ProcessorError("SQL connection pool not initialized".to_string())
33            })?;
34
35            sqlx::query("SELECT 1")
36                .execute(pool.as_ref())
37                .await
38                .map_err(|e| {
39                    CamelError::ProcessorError(format!("SQL health check failed: {}", e))
40                })?;
41
42            Ok(())
43        })
44    }
45}
46
47pub struct SqlHealthCheck {
48    probe: Arc<dyn SqlHealthProbe>,
49    timeout: Duration,
50}
51
52impl SqlHealthCheck {
53    pub fn new(pool: Arc<OnceCell<Arc<AnyPool>>>) -> Self {
54        Self {
55            probe: Arc::new(SqlPoolProbe::new(pool)),
56            timeout: Duration::from_secs(2),
57        }
58    }
59
60    #[cfg(test)]
61    fn with_probe_for_tests(probe: Arc<dyn SqlHealthProbe>, timeout: Duration) -> Self {
62        Self { probe, timeout }
63    }
64}
65
66#[async_trait]
67impl AsyncHealthCheck for SqlHealthCheck {
68    fn name(&self) -> &str {
69        "sql"
70    }
71
72    async fn check(&self) -> CheckResult {
73        match tokio::time::timeout(self.timeout, self.probe.probe()).await {
74            Ok(Ok(())) => CheckResult::healthy(self.name()),
75            Ok(Err(err)) => CheckResult::unhealthy(self.name(), &err.to_string()),
76            Err(_) => CheckResult::unhealthy(self.name(), "SELECT 1 timed out"),
77        }
78    }
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::HealthStatus;

    /// Test double whose behaviour is supplied by a closure, so each test can
    /// script exactly what a probe attempt does.
    struct MockProbe {
        responder: Arc<dyn Fn() -> ProbeFuture + Send + Sync>,
    }

    impl MockProbe {
        /// Wraps `f`; it is invoked once per call to `probe`.
        fn new<F>(f: F) -> Self
        where
            F: Fn() -> ProbeFuture + Send + Sync + 'static,
        {
            Self {
                responder: Arc::new(f),
            }
        }
    }

    impl SqlHealthProbe for MockProbe {
        fn probe(&self) -> ProbeFuture {
            (self.responder)()
        }
    }

    #[tokio::test]
    async fn sql_health_check_healthy_when_probe_succeeds() {
        let probe = Arc::new(MockProbe::new(|| Box::pin(async { Ok(()) })));
        let check = SqlHealthCheck::with_probe_for_tests(probe, Duration::from_millis(50));

        let result = check.check().await;

        assert_eq!(result.name, "sql");
        assert_eq!(result.status, HealthStatus::Healthy);
        assert!(result.message.is_none());
    }

    #[tokio::test]
    async fn sql_health_check_unhealthy_when_probe_fails() {
        let probe = Arc::new(MockProbe::new(|| {
            Box::pin(async {
                Err(CamelError::ProcessorError(
                    "simulated sql error".to_string(),
                ))
            })
        }));
        let check = SqlHealthCheck::with_probe_for_tests(probe, Duration::from_millis(50));

        let result = check.check().await;

        assert_eq!(result.name, "sql");
        assert_eq!(result.status, HealthStatus::Unhealthy);
        assert!(
            result
                .message
                .as_deref()
                .is_some_and(|m| m.contains("simulated sql error"))
        );
    }

    #[tokio::test]
    async fn sql_health_check_unhealthy_when_probe_times_out() {
        // The probe never resolves, so the timeout is the only way `check` can
        // finish. Racing a real sleep against the timeout instead would let a
        // long enough scheduler stall make the probe win, because the timeout
        // wrapper polls the inner future before looking at its deadline.
        let probe = Arc::new(MockProbe::new(|| {
            Box::pin(std::future::pending::<Result<(), CamelError>>())
        }));
        let check = SqlHealthCheck::with_probe_for_tests(probe, Duration::from_millis(5));

        let result = check.check().await;

        assert_eq!(result.name, "sql");
        assert_eq!(result.status, HealthStatus::Unhealthy);
        assert_eq!(result.message.as_deref(), Some("SELECT 1 timed out"));
    }
}