use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use crate::dialect::SqlDialect;
use crate::error::Result;
use crate::types::{TableMetadata, Value};
/// How the source decides which rows to read on each poll.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum QueryMode {
    /// Read the full table on every poll; no offset tracking (the default).
    #[default]
    Bulk,
    /// Follow a strictly increasing column (e.g. an auto-increment id).
    Incrementing {
        /// Name of the increasing column.
        column: String,
    },
    /// Follow a timestamp column.
    Timestamp {
        /// Name of the timestamp column.
        column: String,
        /// Optional lag: rows with timestamps within `delay` of the
        /// database's current time are excluded from polls.
        delay: Option<Duration>,
    },
    /// Follow a (timestamp, incrementing) column pair; the incrementing
    /// column breaks ties between rows sharing a timestamp.
    TimestampIncrementing {
        /// Name of the tie-breaking increasing column.
        incrementing_column: String,
        /// Name of the timestamp column.
        timestamp_column: String,
        /// Optional lag applied to the timestamp column.
        delay: Option<Duration>,
    },
    /// Run a user-supplied query verbatim (no generated WHERE/ORDER/LIMIT).
    Custom {
        /// The SQL text to execute as-is.
        query: String,
    },
}
impl QueryMode {
    /// Creates an [`Incrementing`](Self::Incrementing) mode over `column`.
    pub fn incrementing(column: impl Into<String>) -> Self {
        Self::Incrementing {
            column: column.into(),
        }
    }

    /// Creates a [`Timestamp`](Self::Timestamp) mode over `column` with no delay.
    pub fn timestamp(column: impl Into<String>) -> Self {
        Self::Timestamp {
            column: column.into(),
            delay: None,
        }
    }

    /// Creates a [`Timestamp`](Self::Timestamp) mode whose polls exclude rows
    /// with timestamps within `delay` of the database's current time.
    pub fn timestamp_with_delay(column: impl Into<String>, delay: Duration) -> Self {
        Self::Timestamp {
            column: column.into(),
            delay: Some(delay),
        }
    }

    /// Creates a [`TimestampIncrementing`](Self::TimestampIncrementing) mode
    /// with no delay.
    pub fn timestamp_incrementing(
        timestamp_column: impl Into<String>,
        incrementing_column: impl Into<String>,
    ) -> Self {
        Self::TimestampIncrementing {
            incrementing_column: incrementing_column.into(),
            timestamp_column: timestamp_column.into(),
            delay: None,
        }
    }

    /// Creates a [`TimestampIncrementing`](Self::TimestampIncrementing) mode
    /// with a delay, mirroring [`timestamp_with_delay`](Self::timestamp_with_delay).
    /// Without this constructor the variant's `delay` field could only be set
    /// by constructing the variant literally.
    pub fn timestamp_incrementing_with_delay(
        timestamp_column: impl Into<String>,
        incrementing_column: impl Into<String>,
        delay: Duration,
    ) -> Self {
        Self::TimestampIncrementing {
            incrementing_column: incrementing_column.into(),
            timestamp_column: timestamp_column.into(),
            delay: Some(delay),
        }
    }

    /// Returns `true` for every mode except [`Bulk`](Self::Bulk).
    pub fn is_incremental(&self) -> bool {
        !matches!(self, Self::Bulk)
    }
}
/// Resume position for an incremental source, persisted between polls.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct SourceOffset {
    /// Last value seen in the incrementing column, if any.
    pub incrementing: Option<Value>,
    /// Last value seen in the timestamp column, if any.
    pub timestamp: Option<Value>,
    /// Extra implementation-defined offset entries, keyed by name.
    pub custom: HashMap<String, Value>,
}
impl SourceOffset {
pub fn with_incrementing(value: impl Into<Value>) -> Self {
Self {
incrementing: Some(value.into()),
..Default::default()
}
}
pub fn with_timestamp(value: impl Into<Value>) -> Self {
Self {
timestamp: Some(value.into()),
..Default::default()
}
}
pub fn set_incrementing(&mut self, value: impl Into<Value>) {
self.incrementing = Some(value.into());
}
pub fn set_timestamp(&mut self, value: impl Into<Value>) {
self.timestamp = Some(value.into());
}
pub fn is_empty(&self) -> bool {
self.incrementing.is_none() && self.timestamp.is_none() && self.custom.is_empty()
}
}
/// Configuration for polling one table as a record source.
#[derive(Debug, Clone)]
pub struct TableSourceConfig {
    /// Optional schema qualifying the table name.
    pub schema: Option<String>,
    /// Table to poll.
    pub table: String,
    /// How rows are selected on each poll.
    pub mode: QueryMode,
    /// Columns to project; `None` (or empty) means `SELECT *`.
    pub columns: Option<Vec<String>>,
    /// Extra SQL filter ANDed onto the generated WHERE clause.
    pub where_clause: Option<String>,
    /// Maximum rows fetched per poll (used as the LIMIT).
    pub batch_size: u32,
    /// How often the table is polled.
    pub poll_interval: Duration,
    /// Explicit destination topic; defaults to the table name when unset.
    pub topic: Option<String>,
}
impl TableSourceConfig {
pub fn bulk(table: impl Into<String>) -> Self {
Self {
schema: None,
table: table.into(),
mode: QueryMode::Bulk,
columns: None,
where_clause: None,
batch_size: 1000,
poll_interval: Duration::from_secs(5),
topic: None,
}
}
pub fn incrementing(table: impl Into<String>, column: impl Into<String>) -> Self {
Self {
schema: None,
table: table.into(),
mode: QueryMode::incrementing(column),
columns: None,
where_clause: None,
batch_size: 1000,
poll_interval: Duration::from_secs(1),
topic: None,
}
}
pub fn timestamp(table: impl Into<String>, column: impl Into<String>) -> Self {
Self {
schema: None,
table: table.into(),
mode: QueryMode::timestamp(column),
columns: None,
where_clause: None,
batch_size: 1000,
poll_interval: Duration::from_secs(1),
topic: None,
}
}
pub fn with_schema(mut self, schema: impl Into<String>) -> Self {
self.schema = Some(schema.into());
self
}
pub fn with_columns(mut self, columns: Vec<String>) -> Self {
self.columns = Some(columns);
self
}
pub fn with_where(mut self, clause: impl Into<String>) -> Self {
self.where_clause = Some(clause.into());
self
}
pub fn with_batch_size(mut self, size: u32) -> Self {
self.batch_size = size;
self
}
pub fn with_poll_interval(mut self, interval: Duration) -> Self {
self.poll_interval = interval;
self
}
pub fn with_topic(mut self, topic: impl Into<String>) -> Self {
self.topic = Some(topic.into());
self
}
pub fn topic_name(&self) -> &str {
self.topic.as_deref().unwrap_or(&self.table)
}
}
/// One row emitted by a source.
#[derive(Debug, Clone)]
pub struct SourceRecord {
    /// Schema the row came from, if the table is schema-qualified.
    pub schema: Option<String>,
    /// Table the row came from.
    pub table: String,
    /// Key column values identifying the row.
    pub key: Vec<Value>,
    /// Column name → value for the row.
    pub values: HashMap<String, Value>,
    /// Offset to persist once this record is handled.
    pub offset: SourceOffset,
    /// Optional explicit partitioning key for downstream routing.
    pub partition_key: Option<String>,
}
impl SourceRecord {
    /// Assembles a record; the partition key starts out unset.
    pub fn new(
        schema: Option<String>,
        table: impl Into<String>,
        key: Vec<Value>,
        values: HashMap<String, Value>,
        offset: SourceOffset,
    ) -> Self {
        let table = table.into();
        Self {
            schema,
            table,
            key,
            values,
            offset,
            partition_key: None,
        }
    }

    /// Attaches an explicit partition key, consuming and returning the record.
    pub fn with_partition_key(mut self, key: impl Into<String>) -> Self {
        let key = key.into();
        self.partition_key = Some(key);
        self
    }
}
/// Outcome of one poll: the fetched records plus the offset to resume from.
#[derive(Debug)]
pub struct PollResult {
    /// Records fetched by this poll.
    pub records: Vec<SourceRecord>,
    /// Offset to persist and resume from on the next poll.
    pub offset: SourceOffset,
    /// Whether the source reported more rows remaining beyond this batch
    /// (set by the `TableSource` implementation).
    pub has_more: bool,
}
impl PollResult {
    /// A result carrying no records; `offset` is passed through unchanged
    /// and `has_more` is false.
    pub fn empty(offset: SourceOffset) -> Self {
        Self {
            records: Vec::new(),
            offset,
            has_more: false,
        }
    }

    /// True when this poll produced no records.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of records fetched by this poll.
    pub fn len(&self) -> usize {
        self.records.len()
    }
}
/// Plain snapshot of polling counters, as produced by `AtomicSourceStats::snapshot`.
#[derive(Debug, Clone, Default)]
pub struct SourceStats {
    /// Total records returned across all polls.
    pub records_polled: u64,
    /// Total number of polls performed.
    pub polls: u64,
    /// Number of polls that returned zero records.
    pub empty_polls: u64,
    /// Cumulative time spent polling, in milliseconds.
    pub total_poll_time_ms: u64,
    /// `records_polled / polls`, or 0.0 before the first poll.
    pub avg_records_per_poll: f64,
}
/// Thread-safe polling counters updated with relaxed atomics; read them
/// via `snapshot`, which derives the average as well.
#[derive(Debug, Default)]
#[allow(missing_docs)]
pub struct AtomicSourceStats {
    /// Total records returned across all polls.
    pub records_polled: AtomicU64,
    /// Total number of polls performed.
    pub polls: AtomicU64,
    /// Number of polls that returned zero records.
    pub empty_polls: AtomicU64,
    /// Cumulative time spent polling, in milliseconds.
    pub total_poll_time_ms: AtomicU64,
}
impl AtomicSourceStats {
    /// Folds one completed poll into the counters. A poll that returned
    /// zero records is additionally counted as an empty poll.
    pub fn record_poll(&self, records: u64, duration_ms: u64) {
        if records == 0 {
            self.empty_polls.fetch_add(1, Ordering::Relaxed);
        }
        self.polls.fetch_add(1, Ordering::Relaxed);
        self.records_polled.fetch_add(records, Ordering::Relaxed);
        self.total_poll_time_ms
            .fetch_add(duration_ms, Ordering::Relaxed);
    }

    /// Captures a point-in-time copy of the counters, deriving the average
    /// records-per-poll (0.0 before the first poll).
    pub fn snapshot(&self) -> SourceStats {
        let polls = self.polls.load(Ordering::Relaxed);
        let records_polled = self.records_polled.load(Ordering::Relaxed);
        let avg_records_per_poll = match polls {
            0 => 0.0,
            n => records_polled as f64 / n as f64,
        };
        SourceStats {
            records_polled,
            polls,
            empty_polls: self.empty_polls.load(Ordering::Relaxed),
            total_poll_time_ms: self.total_poll_time_ms.load(Ordering::Relaxed),
            avg_records_per_poll,
        }
    }
}
/// A pollable source of rows from a single database table.
#[async_trait]
pub trait TableSource: Send + Sync {
    /// Fetches one batch of records, resuming from `offset`.
    async fn poll(&self, offset: &SourceOffset) -> Result<PollResult>;
    /// Returns metadata describing the underlying table.
    async fn table_metadata(&self) -> Result<TableMetadata>;
    /// The configuration this source was built from.
    fn config(&self) -> &TableSourceConfig;
    /// A point-in-time snapshot of polling statistics.
    fn stats(&self) -> SourceStats;
}
/// Renders dialect-specific polling SQL for a `TableSourceConfig`.
pub struct SourceQueryBuilder<'a> {
    // Table, mode, and filtering options driving the generated SQL.
    config: &'a TableSourceConfig,
    // Dialect used for identifier quoting, placeholders, and LIMIT syntax.
    dialect: &'a dyn SqlDialect,
}
impl<'a> SourceQueryBuilder<'a> {
pub fn new(config: &'a TableSourceConfig, dialect: &'a dyn SqlDialect) -> Self {
Self { config, dialect }
}
pub fn build_poll_query(&self, offset: &SourceOffset) -> (String, Vec<Value>) {
let columns = self
.config
.columns
.as_ref()
.map(|c| c.iter().map(String::as_str).collect::<Vec<_>>())
.unwrap_or_default();
let cols_str: Vec<_> = if columns.is_empty() {
vec!["*".to_string()]
} else {
columns
.iter()
.map(|c| self.dialect.quote_identifier(c))
.collect()
};
let table = match &self.config.schema {
Some(s) => format!(
"{}.{}",
self.dialect.quote_identifier(s),
self.dialect.quote_identifier(&self.config.table)
),
None => self.dialect.quote_identifier(&self.config.table),
};
let mut conditions = Vec::new();
let mut params = Vec::new();
match &self.config.mode {
QueryMode::Bulk => {
}
QueryMode::Incrementing { column } => {
if let Some(ref val) = offset.incrementing {
conditions.push(format!(
"{} > {}",
self.dialect.quote_identifier(column),
self.dialect.placeholder(params.len() + 1)
));
params.push(val.clone());
}
}
QueryMode::Timestamp { column, delay } => {
if let Some(ref val) = offset.timestamp {
conditions.push(format!(
"{} > {}",
self.dialect.quote_identifier(column),
self.dialect.placeholder(params.len() + 1)
));
params.push(val.clone());
}
if let Some(_delay) = delay {
conditions.push(format!(
"{} < {} - INTERVAL '{}' SECOND",
self.dialect.quote_identifier(column),
self.dialect.current_timestamp(),
_delay.as_secs()
));
}
}
QueryMode::TimestampIncrementing {
incrementing_column,
timestamp_column,
delay: _,
} => {
if let (Some(ref ts), Some(ref inc)) = (&offset.timestamp, &offset.incrementing) {
conditions.push(format!(
"({} > {} OR ({} = {} AND {} > {}))",
self.dialect.quote_identifier(timestamp_column),
self.dialect.placeholder(params.len() + 1),
self.dialect.quote_identifier(timestamp_column),
self.dialect.placeholder(params.len() + 2),
self.dialect.quote_identifier(incrementing_column),
self.dialect.placeholder(params.len() + 3)
));
params.push(ts.clone());
params.push(ts.clone());
params.push(inc.clone());
}
}
QueryMode::Custom { query } => {
return (query.clone(), params);
}
}
if let Some(ref clause) = self.config.where_clause {
conditions.push(format!("({})", clause));
}
let where_clause = if conditions.is_empty() {
String::new()
} else {
format!(" WHERE {}", conditions.join(" AND "))
};
let order_by = match &self.config.mode {
QueryMode::Bulk => String::new(),
QueryMode::Incrementing { column } => {
format!(" ORDER BY {} ASC", self.dialect.quote_identifier(column))
}
QueryMode::Timestamp { column, .. } => {
format!(" ORDER BY {} ASC", self.dialect.quote_identifier(column))
}
QueryMode::TimestampIncrementing {
incrementing_column,
timestamp_column,
..
} => {
format!(
" ORDER BY {} ASC, {} ASC",
self.dialect.quote_identifier(timestamp_column),
self.dialect.quote_identifier(incrementing_column)
)
}
QueryMode::Custom { .. } => String::new(),
};
let limit = self
.dialect
.limit_offset_sql(Some(u64::from(self.config.batch_size)), None);
let sql = format!(
"SELECT {} FROM {}{}{}{}",
cols_str.join(", "),
table,
where_clause,
order_by,
limit
);
(sql, params)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialect::PostgresDialect;

    #[test]
    fn test_query_mode() {
        // Every mode except Bulk is incremental.
        assert!(QueryMode::timestamp("updated_at").is_incremental());
        assert!(QueryMode::incrementing("id").is_incremental());
        assert!(!QueryMode::Bulk.is_incremental());
    }

    #[test]
    fn test_source_offset() {
        let mut off = SourceOffset::with_incrementing(100_i64);
        assert_eq!(off.incrementing, Some(Value::Int64(100)));
        assert!(!off.is_empty());
        off.set_timestamp("2024-01-01T00:00:00Z");
        assert!(off.timestamp.is_some());
    }

    #[test]
    fn test_table_source_config() {
        let cfg = TableSourceConfig::incrementing("users", "id")
            .with_schema("public")
            .with_batch_size(500)
            .with_poll_interval(Duration::from_secs(2))
            .with_topic("user-changes");
        assert_eq!(cfg.topic_name(), "user-changes");
        assert_eq!(cfg.batch_size, 500);
        assert_eq!(cfg.schema, Some("public".into()));
        assert_eq!(cfg.table, "users");
        match &cfg.mode {
            QueryMode::Incrementing { column } => assert_eq!(column, "id"),
            _ => panic!("Expected Incrementing mode"),
        }
    }

    #[test]
    fn test_source_record() {
        let row = HashMap::from([
            ("id".into(), Value::Int32(1)),
            ("name".into(), Value::String("test".into())),
        ]);
        let record = SourceRecord::new(
            Some("public".into()),
            "users",
            vec![Value::Int32(1)],
            row,
            SourceOffset::with_incrementing(1_i64),
        )
        .with_partition_key("user-1");
        assert_eq!(record.partition_key, Some("user-1".into()));
        assert_eq!(record.table, "users");
    }

    #[test]
    fn test_poll_result() {
        let empty = PollResult::empty(SourceOffset::default());
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());
        assert!(!empty.has_more);
    }

    #[test]
    fn test_atomic_source_stats() {
        let stats = AtomicSourceStats::default();
        for (records, ms) in [(100, 50), (0, 10), (50, 30)] {
            stats.record_poll(records, ms);
        }
        let snap = stats.snapshot();
        assert_eq!(snap.records_polled, 150);
        assert_eq!(snap.polls, 3);
        assert_eq!(snap.empty_polls, 1);
        assert_eq!(snap.total_poll_time_ms, 90);
        assert!((snap.avg_records_per_poll - 50.0).abs() < 0.01);
    }

    #[test]
    fn test_source_query_builder_bulk() {
        let cfg = TableSourceConfig::bulk("users").with_schema("public");
        let dialect = PostgresDialect;
        let (sql, params) =
            SourceQueryBuilder::new(&cfg, &dialect).build_poll_query(&SourceOffset::default());
        assert!(params.is_empty());
        assert!(sql.contains("SELECT *"));
        assert!(sql.contains("\"public\".\"users\""));
        assert!(sql.contains("LIMIT 1000"));
    }

    #[test]
    fn test_source_query_builder_incrementing() {
        let cfg = TableSourceConfig::incrementing("users", "id");
        let dialect = PostgresDialect;
        let offset = SourceOffset::with_incrementing(100_i64);
        let (sql, params) = SourceQueryBuilder::new(&cfg, &dialect).build_poll_query(&offset);
        assert_eq!(params.len(), 1);
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("\"id\" > $1"));
        assert!(sql.contains("ORDER BY \"id\" ASC"));
    }

    #[test]
    fn test_source_query_builder_timestamp() {
        let cfg =
            TableSourceConfig::timestamp("events", "created_at").with_where("status = 'active'");
        let dialect = PostgresDialect;
        let offset = SourceOffset::with_timestamp("2024-01-01T00:00:00Z");
        let (sql, params) = SourceQueryBuilder::new(&cfg, &dialect).build_poll_query(&offset);
        assert_eq!(params.len(), 1);
        assert!(sql.contains("\"created_at\" > $1"));
        assert!(sql.contains("(status = 'active')"));
        assert!(sql.contains("ORDER BY \"created_at\" ASC"));
    }
}