use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sqlx::types::Json;
use std::borrow::Cow;
use tokio::sync::Mutex;
use tracing::info;
use crate::common::message::SharedMessage;
use crate::common::sql::SqlPool;
use crate::error::{Error, Result};
use crate::sink::Sink;
use crate::transform::value::ValueSource;
pub use crate::common::sql::SqlDriver;
/// Returns `true` when `s` is a plain, unquoted SQL identifier: an ASCII letter or
/// underscore, followed by any number of ASCII letters, digits, or underscores.
///
/// The empty string and any non-ASCII character are rejected.
fn is_ident(s: &str) -> bool {
    match s.as_bytes().split_first() {
        None => false,
        Some((&head, tail)) => {
            let head_ok = head == b'_' || head.is_ascii_alphabetic();
            head_ok && tail.iter().all(|&b| b == b'_' || b.is_ascii_alphanumeric())
        }
    }
}
fn validate_column_ident(name: &str) -> Result<()> {
if is_ident(name) {
Ok(())
} else {
Err(Error::config(format!(
"Invalid SQL identifier for column '{}'. Identifiers must match [A-Za-z_][A-Za-z0-9_]*",
name
)))
}
}
fn validate_table_ident(name: &str) -> Result<()> {
let parts: Vec<&str> = name.split('.').collect();
if parts.is_empty() || parts.iter().any(|p| p.is_empty()) {
return Err(Error::config(format!(
"Invalid SQL identifier for table '{}'",
name
)));
}
for part in parts {
if !is_ident(part) {
return Err(Error::config(format!(
"Invalid SQL identifier for table '{}'. Identifiers must match [A-Za-z_][A-Za-z0-9_]* (schema-qualified names like 'public.events' are allowed)",
name
)));
}
}
Ok(())
}
/// Maps one destination column to the value written into it for each message.
///
/// Deserialized from each entry of the sink's `columns` list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnMapping {
    /// Column name in the target table. `SqlSink::new` requires a plain identifier
    /// (`[A-Za-z_][A-Za-z0-9_]*`).
    pub name: String,
    /// Expression the value is taken from, e.g. a JSONPath such as `$.user.id` or a
    /// template such as `{{ $.type }}: {{ $.content }}`; compiled by
    /// `ValueSource::compile`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub from: Option<String>,
    /// Static value for the column, e.g. `"pipeflow"`, or a builtin such as `"$UUID"`,
    /// `"$NOW"` or `"$TIMESTAMP"`; compiled by `ValueSource::compile`.
    /// NOTE(review): the tests show that leaving both `from` and `value` unset is
    /// rejected; whether setting both is allowed is decided by `ValueSource::compile`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub value: Option<Value>,
    /// When upserting, leave this column out of the `DO UPDATE SET` list so the
    /// value from the first insert is kept. Defaults to `false`.
    #[serde(default)]
    pub insert_only: bool,
    /// Optional SQL type, spelled `type` in config. On Postgres it becomes a cast on
    /// the column's placeholder (`$n::<type>`), and integer types affect how numbers
    /// are bound. The SQLite `?` placeholders do not use it.
    #[serde(default, rename = "type", skip_serializing_if = "Option::is_none")]
    pub sql_type: Option<String>,
}
/// Turns the sink's INSERT into an upsert
/// (`INSERT ... ON CONFLICT(...) DO UPDATE SET ...` or `... DO NOTHING`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SqlUpsertConfig {
    /// Columns forming the conflict target; defaults to `["id"]` when omitted.
    /// `SqlSink::new` rejects an empty list, and any name that is not a valid
    /// identifier or does not appear in the sink's `columns`.
    #[serde(default = "default_conflict_columns")]
    pub conflict_columns: Vec<String>,
}
/// Serde default for `SqlUpsertConfig::conflict_columns`: a single `id` column.
fn default_conflict_columns() -> Vec<String> {
    let id = String::from("id");
    vec![id]
}
/// User-facing configuration for [`SqlSink`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SqlSinkConfig {
    /// Database driver (`sqlite` or `postgres` in config). Uses `SqlDriver`'s
    /// `Default` when omitted. NOTE(review): which driver that is is defined in
    /// `crate::common::sql`; confirm there.
    #[serde(default)]
    pub driver: SqlDriver,
    /// Connection string passed to `SqlPool::connect`, e.g. `"test.db"` or
    /// `"postgres://localhost/test"`.
    pub connection: String,
    /// Target table. Must be an identifier, optionally schema-qualified
    /// (`public.events`).
    pub table: String,
    /// Optional upsert behaviour. When absent, each message is a plain `INSERT`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upsert: Option<SqlUpsertConfig>,
    /// Column mappings in insert order. At least one is required.
    pub columns: Vec<ColumnMapping>,
}
/// A [`ColumnMapping`] whose value source has been compiled, ready for per-message use.
struct CompiledColumn {
    /// Column name (checked as an identifier in `SqlSink::new`).
    name: String,
    /// Produces this column's value from each incoming message.
    source: ValueSource,
    /// Excluded from the upsert `DO UPDATE SET` list.
    insert_only: bool,
    /// Optional type appended as a cast to this column's Postgres placeholder.
    sql_type: Option<String>,
}
/// Sink that writes one row per message into a SQL table (SQLite or Postgres).
pub struct SqlSink {
    /// Sink identifier; returned by `Sink::id` and attached to log events.
    id: String,
    /// Connection pool for the configured driver.
    pool: SqlPool,
    /// Target table name (validated in `SqlSink::new`).
    table: String,
    /// Compiled column mappings, in insert order.
    columns: Vec<CompiledColumn>,
    /// Complete statement (INSERT plus optional upsert clause), used as-is for rows
    /// in which no value is NULL.
    write_sql: String,
    /// `INSERT INTO <table> (<cols>) VALUES `, reused to build a per-row statement
    /// with literal `NULL`s when some values are null.
    insert_prefix: String,
    /// ` ON CONFLICT(...) ...` suffix, or empty when upsert is disabled.
    upsert_clause: String,
    /// `Some` only for SQLite. Held around each write so writes through this sink
    /// run one at a time. NOTE(review): presumably to avoid SQLite "database is
    /// locked" errors under concurrent writers; confirm.
    write_lock: Option<Mutex<()>>,
}
impl SqlSink {
pub async fn new(id: impl Into<String>, config: SqlSinkConfig) -> Result<Self> {
let id = id.into();
if config.columns.is_empty() {
return Err(Error::config(
"SQL sink requires at least one column mapping",
));
}
let columns = Self::compile_columns(&config.columns)?;
validate_table_ident(&config.table)?;
for col in &columns {
validate_column_ident(&col.name)?;
}
let pool = SqlPool::connect(config.driver, &config.connection).await?;
let column_names: Vec<&str> = columns.iter().map(|c| c.name.as_str()).collect();
let placeholders: Vec<String> = match config.driver {
SqlDriver::Sqlite => std::iter::repeat_n("?", columns.len())
.map(|s| s.to_string())
.collect(),
SqlDriver::Postgres => columns
.iter()
.enumerate()
.map(|(i, col)| {
if let Some(t) = &col.sql_type {
format!("${}::{}", i + 1, t)
} else {
format!("${}", i + 1)
}
})
.collect(),
};
let insert_prefix = format!(
"INSERT INTO {} ({}) VALUES ",
config.table,
column_names.join(", ")
);
let insert_sql = format!("{}({})", insert_prefix, placeholders.join(", "));
let upsert_clause = if let Some(upsert) = &config.upsert {
if upsert.conflict_columns.is_empty() {
return Err(Error::config(
"sql.upsert.conflict_columns must not be empty",
));
}
for c in &upsert.conflict_columns {
validate_column_ident(c)?;
}
for c in &upsert.conflict_columns {
if !columns.iter().any(|col| col.name == *c) {
return Err(Error::config(format!(
"Upsert conflict column '{}' is not present in columns list",
c
)));
}
}
let update_columns: Vec<&str> = columns
.iter()
.filter(|col| {
!upsert.conflict_columns.iter().any(|c| c == &col.name) && !col.insert_only
})
.map(|col| col.name.as_str())
.collect();
if update_columns.is_empty() {
format!(
" ON CONFLICT({}) DO NOTHING",
upsert.conflict_columns.join(", ")
)
} else {
let assignments: Vec<String> = update_columns
.iter()
.map(|c| format!("{} = excluded.{}", c, c))
.collect();
format!(
" ON CONFLICT({}) DO UPDATE SET {}",
upsert.conflict_columns.join(", "),
assignments.join(", ")
)
}
} else {
String::new()
};
let write_sql = format!("{}{}", insert_sql, upsert_clause);
let write_lock = if let SqlDriver::Sqlite = config.driver {
Some(Mutex::new(()))
} else {
None
};
info!(
sink_id = %id,
driver = ?config.driver,
table = %config.table,
upsert = config.upsert.is_some(),
columns = columns.len(),
"SQL sink created"
);
Ok(Self {
id,
pool,
table: config.table,
columns,
write_sql,
insert_prefix,
upsert_clause,
write_lock,
})
}
fn compile_columns(mappings: &[ColumnMapping]) -> Result<Vec<CompiledColumn>> {
let mut columns = Vec::with_capacity(mappings.len());
for m in mappings {
let source = ValueSource::compile(m.from.as_deref(), m.value.as_ref())
.map_err(|e| Error::config(format!("Column '{}': {}", m.name, e)))?;
columns.push(CompiledColumn {
name: m.name.clone(),
source,
insert_only: m.insert_only,
sql_type: m.sql_type.clone(),
});
}
Ok(columns)
}
}
#[async_trait]
impl Sink for SqlSink {
fn id(&self) -> &str {
&self.id
}
#[tracing::instrument(skip(self, msg), fields(sink_id = %self.id, table = %self.table))]
async fn process(&self, msg: SharedMessage) -> Result<()> {
let values: Vec<Value> = self
.columns
.iter()
.map(|col| col.source.resolve(&msg))
.collect();
let has_null = values.iter().any(|v| v.is_null());
let (sql, bound_values): (Cow<'_, str>, Vec<&Value>) = if has_null {
let mut bound = Vec::with_capacity(values.len());
let mut values_clause = String::new();
let mut placeholder_index = 1;
for (i, value) in values.iter().enumerate() {
if !values_clause.is_empty() {
values_clause.push_str(", ");
}
if value.is_null() {
values_clause.push_str("NULL");
} else {
match &self.pool {
SqlPool::Sqlite(_) => values_clause.push('?'),
SqlPool::Postgres(_) => {
if let Some(t) = &self.columns[i].sql_type {
values_clause.push_str(&format!("${}::{}", placeholder_index, t));
} else {
values_clause.push_str(&format!("${}", placeholder_index));
}
placeholder_index += 1;
}
}
bound.push(value);
}
}
let sql = format!(
"{}({}){}",
self.insert_prefix, values_clause, self.upsert_clause
);
(Cow::Owned(sql), bound)
} else {
(
Cow::Borrowed(self.write_sql.as_str()),
values.iter().collect(),
)
};
let _lock = if let Some(mutex) = &self.write_lock {
Some(mutex.lock().await)
} else {
None
};
match &self.pool {
SqlPool::Sqlite(pool) => {
let mut query = sqlx::query(sql.as_ref());
for value in &bound_values {
let value = *value;
query = match value {
Value::Null => query.bind(None::<String>),
Value::Bool(b) => query.bind(*b),
Value::Number(n) => {
if let Some(i) = n.as_i64() {
query.bind(i)
} else if let Some(f) = n.as_f64() {
query.bind(f)
} else {
query.bind(n.to_string())
}
}
Value::String(s) => query.bind(s.as_str()),
Value::Array(_) | Value::Object(_) => {
let json_str = serde_json::to_string(value).map_err(|e| {
Error::sink(format!(
"Validation failed: JSON serialization error: {}",
e
))
})?;
query.bind(json_str)
}
};
}
query
.execute(pool)
.await
.map_err(|e| Error::sink(format!("Failed to insert row: {}", e)))?;
}
SqlPool::Postgres(pool) => {
let mut query = sqlx::query(sql.as_ref());
for (idx, value) in bound_values.iter().enumerate() {
let value = *value;
let sql_type = if has_null {
let mut col_idx = 0;
let mut bound_idx = 0;
for (i, v) in values.iter().enumerate() {
if !v.is_null() {
if bound_idx == idx {
col_idx = i;
break;
}
bound_idx += 1;
}
}
self.columns
.get(col_idx)
.and_then(|c| c.sql_type.as_deref())
} else {
self.columns.get(idx).and_then(|c| c.sql_type.as_deref())
};
let is_int_type = sql_type.is_some_and(|t| {
let t = t.to_lowercase();
t == "bigint"
|| t == "int"
|| t == "integer"
|| t == "smallint"
|| t == "int2"
|| t == "int4"
|| t == "int8"
});
query = match value {
Value::Null => query.bind(None::<String>),
Value::Bool(b) => query.bind(*b),
Value::Number(n) => {
if let Some(i) = n.as_i64() {
query.bind(i)
} else if let Some(f) = n.as_f64() {
if is_int_type {
query.bind(f as i64)
} else {
query.bind(f)
}
} else {
query.bind(n.to_string())
}
}
Value::String(s) => query.bind(s.as_str()),
Value::Array(_) | Value::Object(_) => query.bind(Json(value.clone())),
};
}
query
.execute(pool)
.await
.map_err(|e| Error::sink(format!("Failed to insert row: {}", e)))?;
}
}
tracing::debug!(
sink_id = %self.id,
table = %self.table,
"Inserted row"
);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a non-insert-only, untyped column mapping.
    fn mapping(name: &str, from: Option<&str>, value: Option<serde_json::Value>) -> ColumnMapping {
        ColumnMapping {
            name: name.to_owned(),
            from: from.map(str::to_owned),
            value,
            insert_only: false,
            sql_type: None,
        }
    }

    /// Deserializes a sink config from YAML, panicking on malformed input.
    fn parse(yaml: &str) -> SqlSinkConfig {
        serde_yaml::from_str(yaml).unwrap()
    }

    #[test]
    fn test_config_deserialize() {
        let config = parse(
            r#"
            driver: sqlite
            connection: "test.db"
            table: events
            columns:
              - name: id
                value: "$UUID"
              - name: user_id
                from: "$.user.id"
              - name: source
                value: "pipeflow"
            "#,
        );
        assert_eq!(config.driver, SqlDriver::Sqlite);
        assert_eq!(config.table, "events");
        assert_eq!(config.columns.len(), 3);
        assert!(config.upsert.is_none());
    }

    #[test]
    fn test_config_postgres_driver() {
        let config = parse(
            r#"
            driver: postgres
            connection: "postgres://localhost/test"
            table: logs
            columns:
              - name: msg
                from: "$.message"
            "#,
        );
        assert_eq!(config.driver, SqlDriver::Postgres);
    }

    #[test]
    fn test_column_mapping_validation() {
        // Neither `from` nor `value`: compilation must fail.
        let mappings = [mapping("bad", None, None)];
        assert!(SqlSink::compile_columns(&mappings).is_err());
    }

    #[test]
    fn test_column_mapping_static_value() {
        let mappings = [mapping("source", None, Some(json!("pipeflow")))];
        let columns = SqlSink::compile_columns(&mappings).unwrap();
        assert_eq!(columns.len(), 1);
        assert_eq!(columns[0].name, "source");
    }

    #[test]
    fn test_column_mapping_builtin_vars() {
        let mappings = [
            mapping("id", None, Some(json!("$UUID"))),
            mapping("ts", None, Some(json!("$NOW"))),
            mapping("epoch", None, Some(json!("$TIMESTAMP"))),
        ];
        let columns = SqlSink::compile_columns(&mappings).unwrap();
        assert_eq!(columns.len(), 3);
    }

    #[test]
    fn test_column_mapping_jsonpath() {
        let mappings = [mapping("user_id", Some("$.data.user.id"), None)];
        let columns = SqlSink::compile_columns(&mappings).unwrap();
        assert_eq!(columns.len(), 1);
    }

    #[test]
    fn test_column_mapping_template() {
        let mappings = [mapping("message", Some("{{ $.type }}: {{ $.content }}"), None)];
        let columns = SqlSink::compile_columns(&mappings).unwrap();
        assert_eq!(columns.len(), 1);
    }

    #[test]
    fn test_upsert_config_deserialize() {
        let config = parse(
            r#"
            driver: sqlite
            connection: "test.db"
            table: events
            upsert:
              conflict_columns: ["id"]
            columns:
              - name: id
                from: "$.id"
              - name: created_at
                value: "$TIMESTAMP"
                insert_only: true
              - name: value
                from: "$.value"
            "#,
        );
        let upsert = config.upsert.as_ref().unwrap();
        assert_eq!(upsert.conflict_columns, vec!["id".to_string()]);
        assert_eq!(config.columns.len(), 3);
        assert!(config.columns[1].insert_only);
    }

    #[test]
    fn test_validate_table_ident_rejects_invalid() {
        for bad in ["my table", "123table", "", "table-name"] {
            assert!(validate_table_ident(bad).is_err(), "{bad:?} should be rejected");
        }
    }

    #[test]
    fn test_validate_table_ident_accepts_valid() {
        for good in ["events", "my_table", "public.events", "_private"] {
            assert!(validate_table_ident(good).is_ok(), "{good:?} should be accepted");
        }
    }

    #[test]
    fn test_validate_column_ident_rejects_invalid() {
        for bad in ["bad column", "123col", "col-name"] {
            assert!(validate_column_ident(bad).is_err(), "{bad:?} should be rejected");
        }
    }

    #[test]
    fn test_validate_column_ident_accepts_valid() {
        for good in ["id", "user_id", "_hidden"] {
            assert!(validate_column_ident(good).is_ok(), "{good:?} should be accepted");
        }
    }
}