// courier/sinks/sql.rs

1use std::collections::BTreeMap;
2
3use anyhow::{Result, anyhow};
4use async_trait::async_trait;
5use serde::Deserialize;
6use serde_json::Value;
7use sqlx::postgres::PgPoolOptions;
8use sqlx::sqlite::SqlitePoolOptions;
9use sqlx::{PgPool, SqlitePool};
10
11use crate::config::{parse_config, redact_secret};
12use crate::envelope::Envelope;
13use crate::pipeline::ErrorPolicy;
14use crate::retry::RetryPolicy;
15use crate::sinks::{ManagedSink, Sink, WriteOne};
16use crate::sources::sql::{SqlDriver, validate_driver_dsn};
17
18/// SQL sink that inserts one row per envelope using explicit column mappings.
19///
20/// Upsert is intentionally not included in the first version because conflict
21/// targets and update column behavior need a more precise cross-driver API.
22pub struct SqlSink {
23    id: String,
24    db: SinkDb,
25    insert_sql: String,
26    columns: Vec<(String, String)>,
27}
28
29impl SqlSink {
30    pub fn new(
31        id: impl Into<String>,
32        driver: SqlDriver,
33        dsn: &str,
34        table: &str,
35        columns: BTreeMap<String, String>,
36    ) -> Result<Self> {
37        validate_identifier(table, "table")?;
38        if columns.is_empty() {
39            return Err(anyhow!(
40                "invalid config for component type 'sql': columns must not be empty"
41            ));
42        }
43        for column in columns.keys() {
44            validate_identifier(column, "column")?;
45        }
46
47        let db_columns: Vec<_> = columns.keys().cloned().collect();
48        let placeholders = placeholders(driver, db_columns.len());
49        let insert_sql = format!(
50            "INSERT INTO {} ({}) VALUES ({})",
51            quote_identifier(table),
52            db_columns
53                .iter()
54                .map(|c| quote_identifier(c))
55                .collect::<Vec<_>>()
56                .join(", "),
57            placeholders.join(", ")
58        );
59
60        let db = match driver {
61            SqlDriver::Postgres => {
62                SinkDb::Postgres(PgPoolOptions::new().connect_lazy(dsn).map_err(|e| {
63                    anyhow!("invalid config for component type 'sql': invalid postgres dsn: {e}")
64                })?)
65            }
66            SqlDriver::Sqlite => {
67                SinkDb::Sqlite(SqlitePoolOptions::new().connect_lazy(dsn).map_err(|e| {
68                    anyhow!("invalid config for component type 'sql': invalid sqlite dsn: {e}")
69                })?)
70            }
71        };
72
73        Ok(Self {
74            id: id.into(),
75            db,
76            insert_sql,
77            columns: columns.into_iter().collect(),
78        })
79    }
80}
81
82#[async_trait]
83impl WriteOne for SqlSink {
84    fn id(&self) -> &str {
85        &self.id
86    }
87
88    async fn write(&self, env: &Envelope) -> Result<()> {
89        let env_value = serde_json::to_value(env)?;
90        match &self.db {
91            SinkDb::Postgres(pool) => {
92                let mut query = sqlx::query(&self.insert_sql);
93                for (_, path) in &self.columns {
94                    query = bind_pg_value(query, extract_path(&env_value, path));
95                }
96                query.execute(pool).await?;
97            }
98            SinkDb::Sqlite(pool) => {
99                let mut query = sqlx::query(&self.insert_sql);
100                for (_, path) in &self.columns {
101                    query = bind_sqlite_value(query, extract_path(&env_value, path));
102                }
103                query.execute(pool).await?;
104            }
105        }
106        Ok(())
107    }
108}
109
110enum SinkDb {
111    Postgres(PgPool),
112    Sqlite(SqlitePool),
113}
114
115fn bind_pg_value<'q>(
116    query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
117    value: Option<&Value>,
118) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {
119    match value {
120        None | Some(Value::Null) => query.bind(None::<String>),
121        Some(Value::Bool(v)) => query.bind(*v),
122        Some(Value::Number(n)) => {
123            if let Some(v) = n.as_i64() {
124                query.bind(v)
125            } else if let Some(v) = n.as_f64() {
126                query.bind(v)
127            } else {
128                query.bind(n.to_string())
129            }
130        }
131        Some(Value::String(v)) => query.bind(v.clone()),
132        Some(other) => query.bind(sqlx::types::Json(other.clone())),
133    }
134}
135
136fn bind_sqlite_value<'q>(
137    query: sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>>,
138    value: Option<&Value>,
139) -> sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>> {
140    match value {
141        None | Some(Value::Null) => query.bind(None::<String>),
142        Some(Value::Bool(v)) => query.bind(*v),
143        Some(Value::Number(n)) => {
144            if let Some(v) = n.as_i64() {
145                query.bind(v)
146            } else if let Some(v) = n.as_f64() {
147                query.bind(v)
148            } else {
149                query.bind(n.to_string())
150            }
151        }
152        Some(Value::String(v)) => query.bind(v.clone()),
153        Some(other) => query.bind(other.to_string()),
154    }
155}
156
157fn extract_path<'a>(env: &'a Value, dotted: &str) -> Option<&'a Value> {
158    let mut current = env;
159    for segment in dotted.split('.') {
160        current = current.get(segment)?;
161    }
162    Some(current)
163}
164
165fn placeholders(driver: SqlDriver, count: usize) -> Vec<String> {
166    match driver {
167        SqlDriver::Postgres => (1..=count).map(|i| format!("${i}")).collect(),
168        SqlDriver::Sqlite => (0..count).map(|_| "?".to_string()).collect(),
169    }
170}
171
172fn validate_identifier(identifier: &str, label: &str) -> Result<()> {
173    if identifier.is_empty() {
174        return Err(anyhow!(
175            "invalid config for component type 'sql': {label} must not be empty"
176        ));
177    }
178    if identifier
179        .split('.')
180        .any(|part| part.is_empty() || !part.chars().all(|c| c.is_ascii_alphanumeric() || c == '_'))
181    {
182        return Err(anyhow!(
183            "invalid config for component type 'sql': invalid {label} identifier '{}'",
184            redact_secret(identifier)
185        ));
186    }
187    Ok(())
188}
189
/// Double-quotes each dot-separated part of `identifier` for use in SQL.
///
/// For example, `schema.table` becomes `"schema"."table"`. Embedded `"` characters are
/// doubled.
fn quote_identifier(identifier: &str) -> String {
    let mut quoted = String::with_capacity(identifier.len() + 2);
    for (index, part) in identifier.split('.').enumerate() {
        if index > 0 {
            quoted.push('.');
        }
        quoted.push('"');
        quoted.push_str(&part.replace('"', "\"\""));
        quoted.push('"');
    }
    quoted
}
197
198#[derive(Debug, Deserialize)]
199struct SqlSinkConfig {
200    driver: SqlDriver,
201    dsn: String,
202    table: String,
203    #[serde(default)]
204    mode: SqlSinkMode,
205    columns: BTreeMap<String, String>,
206}
207
208#[derive(Debug, Default, Deserialize, Eq, PartialEq)]
209#[serde(rename_all = "snake_case")]
210enum SqlSinkMode {
211    #[default]
212    Insert,
213}
214
215/// Registry factory for [`SqlSink`]. Registered by
216/// `courier::registry::register_builtin` under kind `"sql"`.
217pub fn sql_sink_factory(
218    id: &str,
219    config: Value,
220    on_error: ErrorPolicy,
221    retry: Option<RetryPolicy>,
222) -> Result<Box<dyn Sink>> {
223    let config: SqlSinkConfig = parse_config("sql", config)?;
224    validate_driver_dsn("sql", config.driver, &config.dsn)?;
225    match config.mode {
226        SqlSinkMode::Insert => {}
227    }
228
229    let sql = SqlSink::new(
230        id,
231        config.driver,
232        &config.dsn,
233        &config.table,
234        config.columns,
235    )?;
236    let mut sink = ManagedSink::new(sql).with_error_policy(on_error);
237    if let Some(policy) = retry {
238        sink = sink.with_retry(policy);
239    }
240    Ok(Box::new(sink))
241}