Skip to main content

courier/sources/
sql.rs

1use std::time::{Duration, Instant};
2
3use anyhow::{Result, anyhow};
4use async_trait::async_trait;
5use serde::Deserialize;
6use serde_json::{Map, Number, Value};
7use sqlx::postgres::PgPoolOptions;
8use sqlx::sqlite::SqlitePoolOptions;
9use sqlx::types::chrono;
10use sqlx::{Column, PgPool, Row, SqlitePool, TypeInfo};
11use tokio::sync::mpsc::Sender;
12use tokio::time::{MissedTickBehavior, interval};
13use tokio_util::sync::CancellationToken;
14
15use crate::config::{parse_config, redact_secret};
16use crate::envelope::Envelope;
17use crate::observability::{NodeCtx, SendStopped, SourceCtx};
18use crate::retry::RetryPolicy;
19use crate::sources::Source;
20use crate::sources::retry::PollScheduler;
21
22/// Polls a SQL query and emits one envelope per returned row.
23///
24/// This first version is stateless: every poll executes the configured query
25/// as-is. Operators who need incremental behavior should express it in SQL
26/// until Courier has durable checkpoint storage.
27///
28/// When `retry` is configured, consecutive query failures schedule the next
29/// attempt sooner than the normal cadence — see `PollScheduler` for the rule.
30pub struct SqlQueryPollSource {
31    id: String,
32    driver: SqlDriver,
33    dsn: String,
34    query: String,
35    poll_interval: Duration,
36    retry: Option<RetryPolicy>,
37    source_ctx: SourceCtx,
38}
39
40impl SqlQueryPollSource {
41    pub fn new(
42        id: impl Into<String>,
43        driver: SqlDriver,
44        dsn: impl Into<String>,
45        query: impl Into<String>,
46        poll_interval: Duration,
47    ) -> Self {
48        let id = id.into();
49        Self {
50            source_ctx: SourceCtx::new(&id),
51            id,
52            driver,
53            dsn: dsn.into(),
54            query: query.into(),
55            poll_interval,
56            retry: None,
57        }
58    }
59
60    pub fn with_retry(mut self, retry: RetryPolicy) -> Self {
61        self.retry = Some(retry);
62        self
63    }
64}
65
66#[async_trait]
67impl Source for SqlQueryPollSource {
68    fn id(&self) -> &str {
69        &self.id
70    }
71
72    fn set_node_ctx(&mut self, ctx: NodeCtx) {
73        self.source_ctx = SourceCtx::from_node_ctx(ctx);
74    }
75
76    async fn run(self: Box<Self>, tx: Sender<Envelope>, cancel: CancellationToken) {
77        let mut scheduler = PollScheduler::new(self.poll_interval, self.retry.clone());
78        let mut ticker = interval(self.poll_interval);
79        ticker.set_missed_tick_behavior(MissedTickBehavior::Delay);
80        ticker.tick().await;
81        let source_ctx = self.source_ctx.clone();
82
83        let db = match SourceDb::connect(self.driver, &self.dsn).await {
84            Ok(db) => db,
85            Err(e) => {
86                log::error!(
87                    "[{}] failed to connect SQL source: {e}",
88                    redact_secret(&self.id)
89                );
90                return;
91            }
92        };
93
94        log::info!(
95            "[{}] starting SQL poll loop every {:?}",
96            redact_secret(&self.id),
97            self.poll_interval
98        );
99
100        loop {
101            let start = Instant::now();
102            let rows = tokio::select! {
103                _ = cancel.cancelled() => return,
104                result = db.fetch_rows(&self.query) => match result {
105                    Ok(rows) => rows,
106                    Err(e) => {
107                        let delay = scheduler.record_failure();
108                        log::error!(
109                            "[{}] SQL query failed (consecutive failures: {}), next attempt in {:?}: {e}",
110                            redact_secret(&self.id),
111                            scheduler.consecutive_failures(),
112                            delay,
113                        );
114                        // Backoff bypasses the ticker — sleep an arbitrary
115                        // duration, then reset() so normal cadence resumes
116                        // `interval` after this point.
117                        tokio::select! {
118                            _ = cancel.cancelled() => return,
119                            _ = tokio::time::sleep(delay) => {}
120                        }
121                        ticker.reset();
122                        continue;
123                    }
124                },
125            };
126
127            scheduler.record_success();
128
129            for payload in rows {
130                let env = Envelope::new(&self.id, payload);
131                match source_ctx.send(&tx, env, &cancel).await {
132                    Ok(()) => {}
133                    Err(SendStopped::Cancelled) => return,
134                    Err(SendStopped::DownstreamClosed) => {
135                        log::info!("[{}] downstream closed, stopping", redact_secret(&self.id));
136                        return;
137                    }
138                }
139            }
140
141            let elapsed = start.elapsed();
142            if elapsed > self.poll_interval {
143                log::warn!(
144                    "[{}] SQL poll took {:?}, exceeding interval {:?}",
145                    redact_secret(&self.id),
146                    elapsed,
147                    self.poll_interval,
148                );
149            }
150
151            tokio::select! {
152                _ = cancel.cancelled() => return,
153                _ = ticker.tick() => {}
154            }
155        }
156    }
157}
158
159enum SourceDb {
160    Postgres(PgPool),
161    Sqlite(SqlitePool),
162}
163
164impl SourceDb {
165    async fn connect(driver: SqlDriver, dsn: &str) -> Result<Self> {
166        match driver {
167            SqlDriver::Postgres => Ok(Self::Postgres(PgPoolOptions::new().connect(dsn).await?)),
168            SqlDriver::Sqlite => Ok(Self::Sqlite(SqlitePoolOptions::new().connect(dsn).await?)),
169        }
170    }
171
172    async fn fetch_rows(&self, query: &str) -> Result<Vec<Value>> {
173        match self {
174            Self::Postgres(pool) => {
175                let rows = sqlx::query(query).fetch_all(pool).await?;
176                rows.into_iter().map(pg_row_to_json).collect()
177            }
178            Self::Sqlite(pool) => {
179                let rows = sqlx::query(query).fetch_all(pool).await?;
180                rows.into_iter().map(sqlite_row_to_json).collect()
181            }
182        }
183    }
184}
185
186fn pg_row_to_json(row: sqlx::postgres::PgRow) -> Result<Value> {
187    let mut object = Map::new();
188    for column in row.columns() {
189        let name = column.name();
190        let type_name = column.type_info().name();
191        let value = match type_name {
192            "BOOL" => row
193                .try_get::<Option<bool>, _>(name)?
194                .map(Value::Bool)
195                .unwrap_or(Value::Null),
196            "INT2" => row
197                .try_get::<Option<i16>, _>(name)?
198                .map(|v| Value::Number(Number::from(v)))
199                .unwrap_or(Value::Null),
200            "INT4" => row
201                .try_get::<Option<i32>, _>(name)?
202                .map(|v| Value::Number(Number::from(v)))
203                .unwrap_or(Value::Null),
204            "INT8" => row
205                .try_get::<Option<i64>, _>(name)?
206                .map(|v| Value::Number(Number::from(v)))
207                .unwrap_or(Value::Null),
208            "FLOAT4" => row
209                .try_get::<Option<f32>, _>(name)?
210                .and_then(|v| Number::from_f64(v as f64))
211                .map(Value::Number)
212                .unwrap_or(Value::Null),
213            "FLOAT8" => row
214                .try_get::<Option<f64>, _>(name)?
215                .and_then(Number::from_f64)
216                .map(Value::Number)
217                .unwrap_or(Value::Null),
218            "JSON" | "JSONB" => row
219                .try_get::<Option<Value>, _>(name)?
220                .unwrap_or(Value::Null),
221            "TIMESTAMP" => row
222                .try_get::<Option<chrono::NaiveDateTime>, _>(name)?
223                .map(|v| Value::String(v.to_string()))
224                .unwrap_or(Value::Null),
225            "TIMESTAMPTZ" => row
226                .try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(name)?
227                .map(|v| Value::String(v.to_rfc3339()))
228                .unwrap_or(Value::Null),
229            _ => row
230                .try_get::<Option<String>, _>(name)?
231                .map(Value::String)
232                .unwrap_or(Value::Null),
233        };
234        object.insert(name.to_string(), value);
235    }
236    Ok(Value::Object(object))
237}
238
239fn sqlite_row_to_json(row: sqlx::sqlite::SqliteRow) -> Result<Value> {
240    let mut object = Map::new();
241    for column in row.columns() {
242        let name = column.name();
243        let type_name = column.type_info().name().to_ascii_uppercase();
244        let value = match type_name.as_str() {
245            "BOOLEAN" | "BOOL" => row
246                .try_get::<Option<bool>, _>(name)?
247                .map(Value::Bool)
248                .unwrap_or(Value::Null),
249            "INTEGER" | "INT" | "BIGINT" => row
250                .try_get::<Option<i64>, _>(name)?
251                .map(|v| Value::Number(Number::from(v)))
252                .unwrap_or(Value::Null),
253            "REAL" | "DOUBLE" | "FLOAT" => row
254                .try_get::<Option<f64>, _>(name)?
255                .and_then(Number::from_f64)
256                .map(Value::Number)
257                .unwrap_or(Value::Null),
258            _ => row
259                .try_get::<Option<String>, _>(name)?
260                .map(Value::String)
261                .unwrap_or(Value::Null),
262        };
263        object.insert(name.to_string(), value);
264    }
265    Ok(Value::Object(object))
266}
267
268#[derive(Debug, Clone, Copy, Deserialize)]
269#[serde(rename_all = "snake_case")]
270pub enum SqlDriver {
271    Postgres,
272    Sqlite,
273}
274
275pub(crate) fn validate_driver_dsn(kind: &str, driver: SqlDriver, dsn: &str) -> Result<()> {
276    let valid = match driver {
277        SqlDriver::Postgres => dsn.starts_with("postgres://") || dsn.starts_with("postgresql://"),
278        SqlDriver::Sqlite => dsn.starts_with("sqlite:"),
279    };
280    if !valid {
281        return Err(anyhow!(
282            "invalid config for component type '{kind}': dsn does not match driver {driver:?}"
283        ));
284    }
285    Ok(())
286}
287
288#[derive(Debug, Deserialize)]
289struct SqlQueryPollSourceConfig {
290    driver: SqlDriver,
291    dsn: String,
292    query: String,
293    poll_interval_secs: u64,
294}
295
296/// Registry factory for [`SqlQueryPollSource`]. Registered by
297/// `courier::registry::register_builtin` under kind `"sql_query_poll"`.
298/// The optional `retry` policy is extracted by the registry and threaded
299/// into the source's `PollScheduler`.
300pub fn sql_query_poll_source_factory(
301    id: &str,
302    config: Value,
303    retry: Option<RetryPolicy>,
304) -> Result<Box<dyn Source>> {
305    let config: SqlQueryPollSourceConfig = parse_config("sql_query_poll", config)?;
306    validate_driver_dsn("sql_query_poll", config.driver, &config.dsn)?;
307    if config.query.trim().is_empty() {
308        return Err(anyhow!(
309            "invalid config for component type 'sql_query_poll': query must not be empty"
310        ));
311    }
312    if config.poll_interval_secs == 0 {
313        return Err(anyhow!(
314            "invalid config for component type 'sql_query_poll': poll_interval_secs must be greater than 0"
315        ));
316    }
317
318    let mut source = SqlQueryPollSource::new(
319        id,
320        config.driver,
321        config.dsn,
322        config.query,
323        Duration::from_secs(config.poll_interval_secs),
324    );
325    if let Some(policy) = retry {
326        source = source.with_retry(policy);
327    }
328    Ok(Box::new(source))
329}