use crate::client::AthenaClient;
use crate::parser::query_builder::sanitize_identifier;
use anyhow::{Context, Result};
use csv::ReaderBuilder;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value, json};
use std::{collections::HashMap, fs, path::Path};
use tokio::time::{Duration, sleep};
use tracing::info;
use uuid::Uuid;
/// Row shape of the Sequin CDC CSV export, prior to any JSON parsing.
///
/// The JSON-bearing columns (`record`, `record_pk`, `changes`) arrive as raw
/// strings here and are decoded into `serde_json::Value`s by
/// `SequinEvent::from_csv`.
#[derive(Debug, Deserialize)]
struct RawSequinEvent {
    // Monotonic sequence number used for ordering and checkpointing.
    seq: i64,
    source_table_schema: String,
    source_table_name: String,
    // Primary-key value(s); may be JSON or a plain scalar string
    // (see `parse_json_field_allow_plain`).
    record_pk: String,
    // Full row image as a JSON string; blank is treated as NULL.
    record: String,
    // Optional JSON payload of changed values — presumably the pre-update
    // values for updates; TODO confirm against the Sequin export format.
    changes: Option<String>,
    // Action name, matched case-insensitively: insert / update / delete.
    action: String,
}
/// A fully parsed change-data-capture event, ready to be turned into SQL.
///
/// Built either from a CSV export row (`from_csv`) or from a query result row
/// (`from_query_row`).
#[derive(Debug, Clone)]
pub struct SequinEvent {
    // Monotonic sequence number; drives ordering and the `CdcState` checkpoint.
    pub seq: i64,
    pub source_table_schema: String,
    pub source_table_name: String,
    // Full row image; `Value::Null` when the source column was absent/blank.
    pub record: Value,
    // Primary-key hint: may be a JSON array/object or a plain scalar value.
    pub record_pk: Value,
    pub action: SequinAction,
    // Extra change payload; `Value::Null` when absent or unparseable.
    pub changes: Value,
}
impl SequinEvent {
    /// Map key in the form `schema.table`; matches `CdcTableConfig::key`.
    pub fn table_key(&self) -> String {
        format!("{}.{}", self.source_table_schema, self.source_table_name)
    }

    /// Converts a raw CSV row into a parsed event.
    ///
    /// `record` must be valid JSON (blank becomes NULL); `record_pk` may be
    /// JSON or a plain string; an unparseable `changes` silently becomes NULL.
    ///
    /// # Errors
    /// Fails when `record` is non-blank invalid JSON or `action` is not one of
    /// insert/update/delete.
    fn from_csv(raw: RawSequinEvent) -> Result<Self> {
        let record = parse_json_field(&raw.record).context("parsing record field")?;
        let record_pk = parse_json_field_allow_plain(&raw.record_pk);
        let action = SequinAction::try_from(raw.action.as_str())?;
        let changes = parse_json_field_optional(raw.changes.as_deref());
        Ok(Self {
            seq: raw.seq,
            source_table_schema: raw.source_table_schema,
            source_table_name: raw.source_table_name,
            record,
            record_pk,
            action,
            changes,
        })
    }

    /// Builds an event from one JSON object of a query result set.
    ///
    /// `seq`, `source_table_schema`, `source_table_name` and `action` are
    /// required columns; `record`, `record_pk` and `changes` default to NULL
    /// when absent and are passed through `normalize_db_value` (blank strings
    /// become NULL, JSON-looking strings are parsed).
    ///
    /// # Errors
    /// Fails when the row is not an object, a required column is missing or
    /// has the wrong type, or the action name is unsupported.
    pub fn from_query_row(row: &Value) -> Result<Self> {
        let map = row
            .as_object()
            .context("expected sequin row to be an object")?;
        let seq = map
            .get("seq")
            .and_then(|value| value.as_i64())
            .context("missing seq column")?;
        let source_table_schema = map
            .get("source_table_schema")
            .and_then(|value| value.as_str())
            .context("missing source_table_schema")?
            .to_string();
        let source_table_name = map
            .get("source_table_name")
            .and_then(|value| value.as_str())
            .context("missing source_table_name")?
            .to_string();
        // map/transpose so an absent column yields Null rather than an error.
        let record_val = map
            .get("record")
            .map(|value| normalize_db_value(value))
            .transpose()
            .context("normalizing record column")?
            .unwrap_or(Value::Null);
        let record_pk_val = map
            .get("record_pk")
            .map(|value| normalize_db_value(value))
            .transpose()
            .context("normalizing record_pk column")?
            .unwrap_or(Value::Null);
        let changes_val = map
            .get("changes")
            .map(|value| normalize_db_value(value))
            .transpose()
            .context("normalizing changes column")?
            .unwrap_or(Value::Null);
        let action = map
            .get("action")
            .and_then(|value| value.as_str())
            .context("missing action column")
            .and_then(|text| SequinAction::try_from(text))?;
        Ok(Self {
            seq,
            source_table_schema,
            source_table_name,
            record: record_val,
            record_pk: record_pk_val,
            action,
            changes: changes_val,
        })
    }

    /// Row image after the event: NULL for deletes, the full record otherwise.
    pub fn new_values(&self) -> Value {
        match self.action {
            SequinAction::Delete => Value::Null,
            _ => self.record.clone(),
        }
    }

    /// Row image before the event: NULL for inserts, the `changes` payload for
    /// updates, and the full record for deletes.
    pub fn old_values(&self) -> Value {
        match self.action {
            SequinAction::Insert => Value::Null,
            SequinAction::Update => self.changes.clone(),
            SequinAction::Delete => self.record.clone(),
        }
    }

    /// Lowercase action name for logging and audit payloads.
    pub fn action_name(&self) -> &'static str {
        match self.action {
            SequinAction::Insert => "insert",
            SequinAction::Update => "update",
            SequinAction::Delete => "delete",
        }
    }

    /// UUID extracted from the PK value. Falls back to a random UUID when the
    /// value is not parseable — see `parse_uuid_from_value`.
    pub fn record_id(&self) -> Uuid {
        parse_uuid_from_value(&self.record_pk)
    }
}
/// The three change kinds a Sequin CDC event can carry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SequinAction {
    Insert,
    Update,
    Delete,
}
impl TryFrom<&str> for SequinAction {
    type Error = anyhow::Error;

    /// Parses an action name case-insensitively into a `SequinAction`.
    ///
    /// # Errors
    /// Fails (reporting the lowercased input) for anything other than
    /// insert / update / delete.
    fn try_from(value: &str) -> Result<Self> {
        let normalized = value.to_lowercase();
        if normalized == "insert" {
            Ok(Self::Insert)
        } else if normalized == "update" {
            Ok(Self::Update)
        } else if normalized == "delete" {
            Ok(Self::Delete)
        } else {
            Err(anyhow::anyhow!("unsupported sequin action: {}", normalized))
        }
    }
}
/// Per-table CDC configuration: target schema/table plus explicit primary-key
/// columns. When `pk_columns` is empty (the serde default), PK columns are
/// inferred from events at apply time.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CdcTableConfig {
    pub schema: String,
    pub table: String,
    #[serde(default)]
    pub pk_columns: Vec<String>,
}
impl CdcTableConfig {
    /// Lookup key in the form `schema.table`; matches `SequinEvent::table_key`.
    pub fn key(&self) -> String {
        format!("{}.{}", self.schema, self.table)
    }

    /// Fully qualified, sanitized table reference, or `None` when either part
    /// fails sanitization.
    pub fn qualified_name(&self) -> Option<String> {
        match (sanitize_identifier(&self.schema), sanitize_identifier(&self.table)) {
            (Some(schema), Some(table)) => Some(format!("{}.{}", schema, table)),
            _ => None,
        }
    }

    /// Primary-key columns that survive sanitization; invalid names are
    /// silently dropped.
    pub fn sanitized_pk(&self) -> Vec<String> {
        let mut sanitized = Vec::with_capacity(self.pk_columns.len());
        for column in &self.pk_columns {
            if let Some(clean) = sanitize_identifier(column) {
                sanitized.push(clean);
            }
        }
        sanitized
    }
}
/// Top-level shape of the YAML table-configuration file.
#[derive(Debug, Deserialize)]
pub struct CdcConfig {
    pub tables: Vec<CdcTableConfig>,
}
/// Persistent checkpoint: the highest sequence number already applied.
/// `None` (the default) means nothing has been processed yet.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CdcState {
    pub last_seq: Option<i64>,
}
impl CdcState {
pub fn load(path: &Path) -> Result<Self> {
if !path.is_file() {
return Ok(Self::default());
}
let bytes = fs::read_to_string(path)?;
let state: Self = serde_json::from_str(&bytes)?;
Ok(state)
}
pub fn save(&self, path: &Path) -> Result<()> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?;
}
let data = serde_json::to_string_pretty(self)?;
fs::write(path, data)?;
Ok(())
}
pub fn mark_seq(&mut self, seq: i64) {
if self.last_seq.map_or(true, |current| seq > current) {
self.last_seq = Some(seq);
}
}
}
/// Loads per-table CDC configuration from a YAML file, keyed by
/// `schema.table`. A `None` path yields an empty map.
///
/// # Errors
/// Fails when the file cannot be read or is not valid YAML for `CdcConfig`.
pub fn load_table_configs(path: Option<&Path>) -> Result<HashMap<String, CdcTableConfig>> {
    let path = match path {
        Some(path) => path,
        None => return Ok(HashMap::new()),
    };
    let contents = fs::read_to_string(path)
        .with_context(|| format!("reading table config from {}", path.display()))?;
    let config: CdcConfig = serde_yaml::from_str(&contents)
        .with_context(|| format!("parsing table config at {}", path.display()))?;
    Ok(config
        .tables
        .into_iter()
        .map(|table| (table.key(), table))
        .collect())
}
/// Replays CDC events from a CSV export in `seq` order, skipping anything at
/// or below the existing checkpoint and checkpointing after every applied
/// event. Returns the final state.
///
/// # Errors
/// Fails on unreadable/invalid CSV rows, SQL build/execute failures, audit
/// logging failures, or state-file I/O errors.
pub async fn backfill_from_csv(
    client: &AthenaClient,
    csv_path: &Path,
    table_configs: &HashMap<String, CdcTableConfig>,
    state_path: &Path,
    dry_run: bool,
    audit_logger: Option<&AuditLogger>,
) -> Result<CdcState> {
    let mut reader = ReaderBuilder::new()
        .trim(csv::Trim::All)
        .from_path(csv_path)
        .context("opening CDC CSV file")?;
    // Parse everything up front so events can be replayed strictly by seq.
    let mut events = Vec::new();
    for row in reader.deserialize() {
        let raw: RawSequinEvent = row.context("parsing CSV row")?;
        events.push(SequinEvent::from_csv(raw)?);
    }
    events.sort_by_key(|event| event.seq);
    let mut state = CdcState::load(state_path)?;
    for event in events {
        // Anything at or below the checkpoint was already applied.
        let already_applied = state.last_seq.map_or(false, |last| event.seq <= last);
        if already_applied {
            continue;
        }
        apply_event(client, &event, table_configs, dry_run, audit_logger).await?;
        state.mark_seq(event.seq);
        // Checkpoint per event so a crash never replays completed work.
        state.save(state_path)?;
    }
    Ok(state)
}
/// Polls the sequin table forever, applying new events in order and
/// checkpointing after each one. Sleeps `poll_interval` whenever a poll comes
/// back empty. This function only returns on error.
///
/// # Errors
/// Fails on an invalid `sequin_table` reference, fetch/apply failures, audit
/// logging failures, or state-file I/O errors.
pub async fn stream_events(
    client: &AthenaClient,
    table_configs: &HashMap<String, CdcTableConfig>,
    sequin_table: &str,
    batch_size: usize,
    poll_interval: Duration,
    state_path: &Path,
    dry_run: bool,
    audit_logger: Option<&AuditLogger>,
) -> Result<()> {
    let mut state = CdcState::load(state_path)?;
    let qualified_table =
        sanitize_table_reference(sequin_table).context("invalid sequin table reference")?;
    loop {
        let since = state.last_seq.unwrap_or(0);
        let events = fetch_batch(client, &qualified_table, since, batch_size).await?;
        if events.is_empty() {
            // Nothing new; back off before polling again.
            sleep(poll_interval).await;
        } else {
            for event in events {
                apply_event(client, &event, table_configs, dry_run, audit_logger).await?;
                state.mark_seq(event.seq);
                state.save(state_path)?;
            }
        }
    }
}
/// Fetches up to `limit` sequin events with `seq > since`, ordered ascending.
///
/// Malformed rows are now a hard error instead of being silently skipped:
/// dropping them would lose events, and — worse — a batch consisting only of
/// bad rows would come back empty, so `stream_events` would sleep and refetch
/// the exact same rows forever, silently stalling the pipeline.
///
/// # Errors
/// Fails when the query fails or any returned row cannot be parsed.
async fn fetch_batch(
    client: &AthenaClient,
    table: &str,
    since: i64,
    limit: usize,
) -> Result<Vec<SequinEvent>> {
    // `table` is sanitized by the caller; `since` and `limit` are integers,
    // so this string interpolation cannot inject SQL.
    let condition = if since > 0 {
        format!("WHERE seq > {} ", since)
    } else {
        String::new()
    };
    let sql = format!(
        "SELECT * FROM {} {}ORDER BY seq ASC LIMIT {}",
        table,
        condition,
        limit.max(1)
    );
    let result = client
        .execute_sql(&sql)
        .await
        .context("fetching sequin events")?;
    let mut events = Vec::with_capacity(result.rows.len());
    for row in &result.rows {
        let event = SequinEvent::from_query_row(row)
            .with_context(|| format!("parsing sequin row {}", row))?;
        events.push(event);
    }
    Ok(events)
}
/// Builds the SQL for one CDC event, executes it (unless `dry_run`), and
/// records the event through the optional audit logger.
///
/// # Errors
/// Fails when SQL generation, execution, or audit logging fails.
async fn apply_event(
    client: &AthenaClient,
    event: &SequinEvent,
    configs: &HashMap<String, CdcTableConfig>,
    dry_run: bool,
    audit_logger: Option<&AuditLogger>,
) -> Result<()> {
    let table_config = resolve_table_config(event, configs);
    // The statement is built even in dry-run mode so malformed events fail fast.
    let sql = match event.action {
        SequinAction::Insert => build_insert_sql(&event.record, &table_config, &event.record_pk)?,
        SequinAction::Update => build_update_sql(&event.record, &table_config, &event.record_pk)?,
        SequinAction::Delete => build_delete_sql(&event.record, &table_config, &event.record_pk)?,
    };
    let mode = if dry_run { "(dry run)" } else { "(executing)" };
    info!(
        "CDC {} {} seq={} {}",
        event.action_name(),
        event.table_key(),
        event.seq,
        mode
    );
    if !dry_run {
        client.execute_sql(&sql).await?;
    }
    if let Some(logger) = audit_logger {
        logger.log(event, dry_run).await?;
    }
    Ok(())
}
/// Looks up the configured target table for an event, or synthesizes a config
/// from the event itself: PK columns are inferred from the record/hint and
/// default to a single `id` column when inference yields nothing.
fn resolve_table_config(
    event: &SequinEvent,
    configs: &HashMap<String, CdcTableConfig>,
) -> CdcTableConfig {
    match configs.get(&event.table_key()) {
        Some(config) => config.clone(),
        None => {
            let mut pk_columns = infer_pk_columns(event.record.as_object(), &event.record_pk);
            if pk_columns.is_empty() {
                // Last-resort fallback: assume a conventional `id` column.
                pk_columns = vec!["id".to_string()];
            }
            CdcTableConfig {
                schema: event.source_table_schema.clone(),
                table: event.source_table_name.clone(),
                pk_columns,
            }
        }
    }
}
/// Derives candidate primary-key column names from a PK hint value.
///
/// Array hints contribute their string elements; object hints contribute
/// their keys; any other (scalar) hint is matched by stringified value
/// against the record's columns. The result is sorted and de-duplicated.
#[doc(hidden)]
pub fn infer_pk_columns(record: Option<&Map<String, Value>>, pk_hint: &Value) -> Vec<String> {
    let mut hints: Vec<String> = match pk_hint {
        Value::Array(items) => items
            .iter()
            .filter_map(|item| item.as_str().map(str::to_string))
            .collect(),
        Value::Object(map) => map.keys().cloned().collect(),
        other => {
            let mut matched = Vec::new();
            if let (Some(text), Some(record_map)) = (value_to_string(other), record) {
                // Treat any column whose stringified value equals the hint as
                // a PK candidate.
                for (column, value) in record_map {
                    let is_match = value_to_string(value)
                        .map(|candidate| candidate == text)
                        .unwrap_or(false);
                    if is_match {
                        matched.push(column.clone());
                    }
                }
            }
            matched
        }
    };
    hints.sort();
    hints.dedup();
    hints
}
/// Builds an `INSERT ... ON CONFLICT ... DO UPDATE` (upsert) statement for an
/// insert event. Columns whose names fail sanitization are silently dropped.
/// PK columns come from the config, falling back to inference from the
/// record and hint.
///
/// # Errors
/// Fails when the record is not a JSON object, when no column survives
/// sanitization, or when the target table identifier is invalid.
#[doc(hidden)]
pub fn build_insert_sql(
    record: &Value,
    config: &CdcTableConfig,
    pk_hint: &Value,
) -> Result<String> {
    // Pair each sanitized column name with its rendered SQL literal. The raw
    // name is not needed past this point (the original code cloned it into
    // the tuple and never used it).
    let columns: Vec<(String, String)> = record
        .as_object()
        .context("insert record must be an object")?
        .iter()
        .filter_map(|(raw, value)| {
            sanitize_identifier(raw).map(|sanitized| (sanitized, value_to_sql_literal(value)))
        })
        .collect();
    if columns.is_empty() {
        anyhow::bail!("insert record contains no valid columns");
    }
    let names: Vec<String> = columns.iter().map(|(name, _)| name.clone()).collect();
    let values: Vec<String> = columns.into_iter().map(|(_, literal)| literal).collect();
    let pk_columns = if config.pk_columns.is_empty() {
        infer_pk_columns(record.as_object(), pk_hint)
    } else {
        config.pk_columns.clone()
    };
    // NOTE(review): `ON CONFLICT` is PostgreSQL syntax — confirm the target
    // engine behind `AthenaClient` supports it.
    let conflict_clause = build_conflict_clause(&pk_columns);
    let table = config
        .qualified_name()
        .context("invalid target table identifier")?;
    Ok(format!(
        "INSERT INTO {table} ({columns}) VALUES ({values}){conflict};",
        table = table,
        columns = names.join(", "),
        values = values.join(", "),
        conflict = conflict_clause
    ))
}
/// Renders an ` ON CONFLICT (...) DO UPDATE SET ...` clause for the given PK
/// columns, or an empty string when none of them survive sanitization.
fn build_conflict_clause(pk_columns: &[String]) -> String {
    let sanitized: Vec<String> = pk_columns
        .iter()
        .filter_map(|column| sanitize_identifier(column))
        .collect();
    if sanitized.is_empty() {
        String::new()
    } else {
        let assignments: Vec<String> = sanitized
            .iter()
            .map(|column| format!("{column} = EXCLUDED.{column}"))
            .collect();
        format!(
            " ON CONFLICT ({}) DO UPDATE SET {}",
            sanitized.join(", "),
            assignments.join(", ")
        )
    }
}
/// Builds an `UPDATE ... SET ... WHERE ...` statement for an update event.
/// PK columns are excluded from the SET list and used, with their values from
/// the record, to build the WHERE clause.
///
/// # Errors
/// Fails when the record is not an object, when no non-PK column survives
/// sanitization, when no PK value is available for the WHERE clause, or when
/// the target table identifier is invalid.
#[doc(hidden)]
pub fn build_update_sql(
    record: &Value,
    config: &CdcTableConfig,
    pk_hint: &Value,
) -> Result<String> {
    let map = record
        .as_object()
        .context("update record must be an object")?;
    let pk_columns = if config.pk_columns.is_empty() {
        infer_pk_columns(Some(map), pk_hint)
    } else {
        config.pk_columns.clone()
    };
    let assignments = map
        .iter()
        .filter(|(column, _)| !pk_columns.contains(column))
        .filter_map(|(raw, value)| {
            sanitize_identifier(raw)
                .map(|sanitized| format!("{} = {}", sanitized, value_to_sql_literal(value)))
        })
        .collect::<Vec<_>>();
    let where_clause = build_where_clause(map, &pk_columns)?;
    if assignments.is_empty() {
        anyhow::bail!("update record contains no columns to update");
    }
    let table = config
        .qualified_name()
        .context("invalid target table identifier")?;
    // BUG FIX: the original used `where = where_clause` as a named `format!`
    // argument — `where` is a Rust keyword and is rejected by the macro
    // ("expected identifier, found keyword `where`"). Positional args avoid it.
    Ok(format!(
        "UPDATE {} SET {} WHERE {};",
        table,
        assignments.join(", "),
        where_clause
    ))
}
/// Builds a `DELETE FROM ... WHERE ...` statement for a delete event, keying
/// the WHERE clause on the (configured or inferred) PK columns.
///
/// # Errors
/// Fails when the record is not an object, when no PK value is available for
/// the WHERE clause, or when the target table identifier is invalid.
#[doc(hidden)]
pub fn build_delete_sql(
    record: &Value,
    config: &CdcTableConfig,
    pk_hint: &Value,
) -> Result<String> {
    let map = record
        .as_object()
        .context("delete record must be an object")?;
    let pk_columns = if config.pk_columns.is_empty() {
        infer_pk_columns(Some(map), pk_hint)
    } else {
        config.pk_columns.clone()
    };
    let where_clause = build_where_clause(map, &pk_columns)?;
    let table = config
        .qualified_name()
        .context("invalid target table identifier")?;
    // BUG FIX: `where = where_clause` used the Rust keyword `where` as a named
    // `format!` argument, which does not compile; use positional args instead.
    Ok(format!("DELETE FROM {} WHERE {};", table, where_clause))
}
/// Joins `pk = literal` conditions with ` AND `, looking each PK value up in
/// the record map. Columns that are absent or fail sanitization are skipped.
///
/// # Errors
/// Fails when no usable PK condition could be produced.
fn build_where_clause(map: &Map<String, Value>, pk_columns: &[String]) -> Result<String> {
    let parts: Vec<String> = pk_columns
        .iter()
        .filter_map(|pk| {
            let value = map.get(pk)?;
            let sanitized = sanitize_identifier(pk)?;
            Some(format!("{} = {}", sanitized, value_to_sql_literal(value)))
        })
        .collect();
    if parts.is_empty() {
        anyhow::bail!("no primary key values available for WHERE clause");
    }
    Ok(parts.join(" AND "))
}
/// Renders a JSON value as a SQL literal: NULL/TRUE/FALSE keywords, bare
/// numbers, quoted-and-escaped strings, and quoted JSON cast to `jsonb` for
/// arrays/objects.
#[doc(hidden)]
pub fn value_to_sql_literal(value: &Value) -> String {
    match value {
        Value::Null => "NULL".to_string(),
        // Static literals instead of `to_string().to_uppercase()` — identical
        // output with no intermediate allocations.
        Value::Bool(true) => "TRUE".to_string(),
        Value::Bool(false) => "FALSE".to_string(),
        Value::Number(num) => num.to_string(),
        Value::String(text) => format!("'{}'", escape_string(text)),
        Value::Array(_) | Value::Object(_) => {
            // NOTE(review): `::jsonb` is a PostgreSQL cast — confirm the
            // target engine supports it.
            let json = serde_json::to_string(value).unwrap_or_default();
            format!("'{}'::jsonb", escape_string(&json))
        }
    }
}
/// Best-effort scalar stringification used for PK-hint matching: `null`
/// becomes an empty string, scalars use their natural text form, and
/// arrays/objects fall back to their JSON encoding.
fn value_to_string(value: &Value) -> Option<String> {
    match value {
        Value::Null => Some(String::new()),
        Value::Bool(flag) => Some(flag.to_string()),
        Value::Number(num) => Some(num.to_string()),
        Value::String(text) => Some(text.clone()),
        other => serde_json::to_string(other).ok(),
    }
}
/// Doubles single quotes so `text` is safe inside a SQL string literal.
fn escape_string(text: &str) -> String {
    let mut escaped = String::with_capacity(text.len());
    for ch in text.chars() {
        if ch == '\'' {
            escaped.push_str("''");
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
/// Parses a required JSON column, treating blank/whitespace-only input as NULL.
///
/// # Errors
/// Fails when non-blank input is not valid JSON.
fn parse_json_field(content: &str) -> Result<Value> {
    match content.trim() {
        "" => Ok(Value::Null),
        _ => serde_json::from_str(content).context("parsing JSON field"),
    }
}
/// Parses an optional JSON column; missing or unparseable input becomes NULL.
fn parse_json_field_optional(raw: Option<&str>) -> Value {
    match raw {
        Some(text) => serde_json::from_str(text).unwrap_or(Value::Null),
        None => Value::Null,
    }
}
/// Parses a column that may hold JSON or a plain scalar: blank input becomes
/// NULL, and anything that fails JSON parsing is kept verbatim as a string.
fn parse_json_field_allow_plain(raw: &str) -> Value {
    if raw.trim().is_empty() {
        Value::Null
    } else {
        match serde_json::from_str(raw) {
            Ok(value) => value,
            Err(_) => Value::String(raw.to_string()),
        }
    }
}
/// Normalizes a raw DB cell: blank strings become NULL and strings that look
/// like JSON objects/arrays are parsed (keeping the raw string if parsing
/// fails); everything else passes through unchanged.
///
/// Returns `Result` for signature compatibility with callers that attach
/// context; it currently never produces an error.
fn normalize_db_value(value: &Value) -> Result<Value> {
    let text = match value {
        Value::String(text) => text,
        other => return Ok(other.clone()),
    };
    if text.trim().is_empty() {
        return Ok(Value::Null);
    }
    let first = text.trim_start().chars().next();
    if matches!(first, Some('{') | Some('[')) {
        // Parse embedded JSON, falling back to the original string verbatim.
        Ok(serde_json::from_str(text).unwrap_or_else(|_| Value::String(text.clone())))
    } else {
        Ok(Value::String(text.clone()))
    }
}
/// Extracts a UUID from a PK value, trying the raw string form first and then
/// the value's JSON encoding.
///
/// NOTE(review): when no UUID can be parsed this fabricates a *random* one,
/// so repeated events for the same row would get different audit
/// `record_id`s — confirm this fallback is intentional.
fn parse_uuid_from_value(value: &Value) -> Uuid {
    let direct = value.as_str().and_then(|text| Uuid::parse_str(text).ok());
    if let Some(uuid) = direct {
        return uuid;
    }
    serde_json::to_string(value)
        .ok()
        .and_then(|encoded| Uuid::parse_str(&encoded).ok())
        .unwrap_or_else(Uuid::new_v4)
}
/// Validates and sanitizes a `table` or `schema.table` reference, trimming
/// whitespace and ignoring empty dot-separated segments.
///
/// # Errors
/// Fails when a part does not sanitize or when the reference has zero or more
/// than two non-empty parts.
#[doc(hidden)]
pub fn sanitize_table_reference(reference: &str) -> Result<String> {
    let parts: Vec<&str> = reference
        .split('.')
        .map(str::trim)
        .filter(|part| !part.is_empty())
        .collect();
    if parts.len() == 1 {
        let table = parts[0];
        sanitize_identifier(table)
            .ok_or_else(|| anyhow::anyhow!("invalid table name '{}'", table))
    } else if parts.len() == 2 {
        let (schema_raw, table_raw) = (parts[0], parts[1]);
        let schema = sanitize_identifier(schema_raw)
            .ok_or_else(|| anyhow::anyhow!("invalid schema '{}'", schema_raw))?;
        let table = sanitize_identifier(table_raw)
            .ok_or_else(|| anyhow::anyhow!("invalid table '{}'", table_raw))?;
        Ok(format!("{}.{}", schema, table))
    } else {
        Err(anyhow::anyhow!(
            "table reference must be `table` or `schema.table`"
        ))
    }
}
/// Writes one audit row per applied CDC event, tagged with a fixed source and
/// username supplied at construction time.
pub struct AuditLogger {
    client: AthenaClient,
    // Value written into every audit row's `source` field.
    source: String,
    // Value written into every audit row's `username` field.
    user: String,
}
impl AuditLogger {
    /// Creates a logger that stamps every entry with `source` and `user`.
    pub fn new(client: AthenaClient, source: impl Into<String>, user: impl Into<String>) -> Self {
        Self {
            client,
            source: source.into(),
            user: user.into(),
        }
    }

    /// Inserts one entry into `audit_logs` for `event`, recording the before
    /// and after row images and whether this was a dry run.
    ///
    /// Note: `record_id` may be a freshly generated random UUID when the PK
    /// value is not a parseable UUID — see `parse_uuid_from_value`.
    ///
    /// # Errors
    /// Fails when the insert cannot be executed.
    pub async fn log(&self, event: &SequinEvent, dry_run: bool) -> Result<()> {
        let payload = json!({
            "table_name": event.table_key(),
            "record_id": event.record_id().to_string(),
            // Field is named `commit_lsn` but carries the sequin `seq` value.
            "commit_lsn": event.seq,
            "action": event.action_name(),
            "old_values": event.old_values(),
            "new_values": event.new_values(),
            "source": self.source,
            "username": self.user,
            "metadata": {
                "dry_run": dry_run
            }
        });
        self.client
            .insert("audit_logs")
            .payload(payload)
            .execute()
            .await
            .context("writing audit log entry")?;
        Ok(())
    }
}