use anyhow::{anyhow, Result};
use async_trait::async_trait;
use drasi_core::models::{
Element, ElementMetadata, ElementPropertyMap, ElementReference, SourceChange,
};
use log::{debug, error, info, warn};
use std::collections::HashMap;
use std::sync::Arc;
use tokio_postgres::{Client, NoTls, Row, Transaction};
use drasi_lib::bootstrap::{
BootstrapContext, BootstrapProvider, BootstrapRequest, BootstrapResult,
};
use drasi_lib::channels::SourceChangeEvent;
pub use crate::config::{PostgresBootstrapConfig, SslMode, TableKeyConfig};
/// Bootstrap provider that loads a query's initial data from PostgreSQL by
/// snapshotting the tables whose names match the query's labels.
pub struct PostgresBootstrapProvider {
    config: PostgresConfig,
}

impl PostgresBootstrapProvider {
    /// Creates a provider from an explicit [`PostgresBootstrapConfig`].
    pub fn new(postgres_config: PostgresBootstrapConfig) -> Self {
        let config = PostgresConfig::from_bootstrap_config(postgres_config);
        Self { config }
    }

    /// Returns a [`PostgresBootstrapProviderBuilder`] initialised with defaults.
    pub fn builder() -> PostgresBootstrapProviderBuilder {
        PostgresBootstrapProviderBuilder::default()
    }
}
/// Fluent builder for [`PostgresBootstrapProvider`].
///
/// Fields that are never set keep these defaults: host `localhost`, port
/// `5432`, slot name `drasi_slot`, publication name `drasi_pub`, SSL mode
/// `Disable`, and an empty database, user, password, table list and
/// table-key list.
pub struct PostgresBootstrapProviderBuilder {
    host: String,
    port: u16,
    database: String,
    user: String,
    password: String,
    tables: Vec<String>,
    slot_name: String,
    publication_name: String,
    ssl_mode: SslMode,
    table_keys: Vec<TableKeyConfig>,
}

impl PostgresBootstrapProviderBuilder {
    /// Creates a builder populated with the defaults listed on the type.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the server host.
    pub fn with_host(self, host: impl Into<String>) -> Self {
        Self {
            host: host.into(),
            ..self
        }
    }

    /// Sets the server port.
    pub fn with_port(self, port: u16) -> Self {
        Self { port, ..self }
    }

    /// Sets the database to connect to.
    pub fn with_database(self, database: impl Into<String>) -> Self {
        Self {
            database: database.into(),
            ..self
        }
    }

    /// Sets the user name used to authenticate.
    pub fn with_user(self, user: impl Into<String>) -> Self {
        Self {
            user: user.into(),
            ..self
        }
    }

    /// Sets the password used to authenticate.
    pub fn with_password(self, password: impl Into<String>) -> Self {
        Self {
            password: password.into(),
            ..self
        }
    }

    /// Replaces the whole table list.
    pub fn with_tables(self, tables: Vec<String>) -> Self {
        Self { tables, ..self }
    }

    /// Appends a single table to the table list.
    pub fn with_table(mut self, table: impl Into<String>) -> Self {
        let table = table.into();
        self.tables.push(table);
        self
    }

    /// Sets the `slot_name` value carried in the configuration.
    pub fn with_slot_name(self, slot_name: impl Into<String>) -> Self {
        Self {
            slot_name: slot_name.into(),
            ..self
        }
    }

    /// Sets the `publication_name` value carried in the configuration.
    pub fn with_publication_name(self, publication_name: impl Into<String>) -> Self {
        Self {
            publication_name: publication_name.into(),
            ..self
        }
    }

    /// Sets the SSL mode carried in the configuration.
    pub fn with_ssl_mode(self, ssl_mode: SslMode) -> Self {
        Self { ssl_mode, ..self }
    }

    /// Replaces the whole list of explicit table-key configurations.
    pub fn with_table_keys(self, table_keys: Vec<TableKeyConfig>) -> Self {
        Self { table_keys, ..self }
    }

    /// Appends an explicit key-column configuration for `table`.
    pub fn with_table_key(mut self, table: impl Into<String>, key_columns: Vec<String>) -> Self {
        let entry = TableKeyConfig {
            table: table.into(),
            key_columns,
        };
        self.table_keys.push(entry);
        self
    }

    /// Consumes the builder and produces the configured provider.
    pub fn build(self) -> PostgresBootstrapProvider {
        let Self {
            host,
            port,
            database,
            user,
            password,
            tables,
            slot_name,
            publication_name,
            ssl_mode,
            table_keys,
        } = self;
        PostgresBootstrapProvider::new(PostgresBootstrapConfig {
            host,
            port,
            database,
            user,
            password,
            tables,
            slot_name,
            publication_name,
            ssl_mode,
            table_keys,
        })
    }
}

impl Default for PostgresBootstrapProviderBuilder {
    fn default() -> Self {
        Self {
            host: String::from("localhost"),
            port: 5432,
            database: String::new(),
            user: String::new(),
            password: String::new(),
            tables: Vec::new(),
            slot_name: String::from("drasi_slot"),
            publication_name: String::from("drasi_pub"),
            ssl_mode: SslMode::Disable,
            table_keys: Vec::new(),
        }
    }
}
#[async_trait]
impl BootstrapProvider for PostgresBootstrapProvider {
    /// Snapshots the tables matching the request's node and relation labels and
    /// streams their rows to `event_tx`.
    ///
    /// The returned result carries only the event count; no sequence or
    /// source-position information is reported. `_settings` is ignored.
    async fn bootstrap(
        &self,
        request: BootstrapRequest,
        context: &BootstrapContext,
        event_tx: drasi_lib::channels::BootstrapEventSender,
        _settings: Option<&drasi_lib::config::SourceSubscriptionSettings>,
    ) -> Result<BootstrapResult> {
        let query_id = request.query_id.clone();
        let (node_label_count, relation_label_count) =
            (request.node_labels.len(), request.relation_labels.len());
        info!(
            "Starting PostgreSQL bootstrap for query '{}' with {} node labels and {} relation labels",
            query_id, node_label_count, relation_label_count
        );

        let mut handler =
            PostgresBootstrapHandler::new(self.config.clone(), context.source_id.clone());
        let count = handler.execute(request, context, event_tx).await?;
        info!("Completed PostgreSQL bootstrap for query {query_id}: sent {count} records");

        let result = BootstrapResult {
            event_count: count,
            last_sequence: None,
            sequences_aligned: false,
            source_position: None,
        };
        Ok(result)
    }
}
/// Internal copy of [`PostgresBootstrapConfig`] owned by the provider and
/// cloned into each bootstrap handler.
///
/// `Debug` is implemented by hand so the password is never written to logs.
#[derive(Clone)]
struct PostgresConfig {
    pub host: String,
    pub port: u16,
    pub database: String,
    pub user: String,
    pub password: String,
    // The next four fields are carried over from the shared configuration but
    // are not read by the bootstrap handler: tables come from the request's
    // labels, and the connection is made without TLS or replication.
    #[allow(dead_code)]
    pub tables: Vec<String>,
    #[allow(dead_code)]
    pub slot_name: String,
    #[allow(dead_code)]
    pub publication_name: String,
    #[allow(dead_code)]
    pub ssl_mode: SslMode,
    pub table_keys: Vec<TableKeyConfig>,
}

impl std::fmt::Debug for PostgresConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PostgresConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("database", &self.database)
            .field("user", &self.user)
            // Never expose the credential in diagnostic output.
            .field("password", &"<redacted>")
            .field("tables", &self.tables)
            .field("slot_name", &self.slot_name)
            .field("publication_name", &self.publication_name)
            .field("ssl_mode", &self.ssl_mode)
            .field("table_keys", &self.table_keys)
            .finish()
    }
}

impl PostgresConfig {
    /// Converts the public configuration into the internal form.
    ///
    /// The input is taken by value, so its fields are moved rather than cloned.
    fn from_bootstrap_config(postgres_config: PostgresBootstrapConfig) -> Self {
        PostgresConfig {
            host: postgres_config.host,
            port: postgres_config.port,
            database: postgres_config.database,
            user: postgres_config.user,
            password: postgres_config.password,
            tables: postgres_config.tables,
            slot_name: postgres_config.slot_name,
            publication_name: postgres_config.publication_name,
            ssl_mode: postgres_config.ssl_mode,
            table_keys: postgres_config.table_keys,
        }
    }
}
struct PostgresBootstrapHandler {
config: PostgresConfig,
source_id: String,
table_primary_keys: HashMap<String, Vec<String>>,
}
impl PostgresBootstrapHandler {
fn new(config: PostgresConfig, source_id: String) -> Self {
Self {
config,
source_id,
table_primary_keys: HashMap::new(),
}
}
async fn execute(
&mut self,
request: BootstrapRequest,
context: &BootstrapContext,
event_tx: drasi_lib::channels::BootstrapEventSender,
) -> Result<usize> {
info!(
"Bootstrap: Connecting to PostgreSQL at {}:{}",
self.config.host, self.config.port
);
let mut client = self.connect().await?;
self.query_primary_keys(&client).await?;
info!("Bootstrap: Connected, creating snapshot transaction...");
let (transaction, lsn) = self.create_snapshot(&mut client).await?;
info!("Bootstrap snapshot created at LSN: {lsn}");
let tables = self.resolve_tables(&request, &transaction).await?;
info!(
"Resolved {} labels to {} tables",
request.node_labels.len() + request.relation_labels.len(),
tables.len()
);
let mut total_count = 0;
for table in &tables {
let count = self
.bootstrap_table(&transaction, table, context, &event_tx)
.await?;
info!("Bootstrapped {count} rows from table '{table}'");
total_count += count;
}
transaction.commit().await?;
info!("Bootstrap completed: {total_count} total elements sent");
Ok(total_count)
}
async fn connect(&self) -> Result<Client> {
let connection_string = format!(
"host={} port={} user={} password={} dbname={}",
self.config.host,
self.config.port,
self.config.user,
self.config.password,
self.config.database
);
let (client, connection) = tokio_postgres::connect(&connection_string, NoTls).await?;
tokio::spawn(async move {
if let Err(e) = connection.await {
error!("PostgreSQL connection error: {e}");
}
});
Ok(client)
}
async fn create_snapshot<'a>(
&self,
client: &'a mut Client,
) -> Result<(Transaction<'a>, String)> {
let transaction = client
.build_transaction()
.isolation_level(tokio_postgres::IsolationLevel::RepeatableRead)
.start()
.await?;
let row = transaction
.query_one("SELECT pg_current_wal_lsn()::text", &[])
.await?;
let lsn: String = row.get(0);
Ok((transaction, lsn))
}
async fn resolve_tables(
&self,
request: &BootstrapRequest,
transaction: &Transaction<'_>,
) -> Result<Vec<String>> {
let mut tables = Vec::new();
let all_labels: Vec<String> = request
.node_labels
.iter()
.chain(request.relation_labels.iter())
.cloned()
.collect();
for label in all_labels {
if self.table_exists(transaction, &label).await? {
tables.push(label);
} else {
warn!("Table '{label}' does not exist, skipping");
}
}
Ok(tables)
}
async fn table_exists(&self, transaction: &Transaction<'_>, table_name: &str) -> Result<bool> {
let row = transaction
.query_one(
"SELECT EXISTS (
SELECT 1 FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = $1
)",
&[&table_name],
)
.await?;
Ok(row.get(0))
}
async fn bootstrap_table(
&self,
transaction: &Transaction<'_>,
table: &str,
context: &BootstrapContext,
event_tx: &drasi_lib::channels::BootstrapEventSender,
) -> Result<usize> {
debug!("Starting bootstrap of table '{table}'");
let columns = self.get_table_columns(transaction, table).await?;
let query = format!("SELECT * FROM \"{}\"", table.replace('"', "\"\""));
let rows = transaction.query(&query, &[]).await?;
let mut count = 0;
let mut batch = Vec::new();
let batch_size = 1000;
for row in rows {
let source_change = self.row_to_source_change(&row, table, &columns).await?;
batch.push(SourceChangeEvent {
source_id: self.source_id.clone(),
change: source_change,
timestamp: chrono::Utc::now(),
});
if batch.len() >= batch_size {
self.send_batch(&mut batch, context, event_tx).await?;
count += batch_size;
}
}
if !batch.is_empty() {
count += batch.len();
self.send_batch(&mut batch, context, event_tx).await?;
}
Ok(count)
}
async fn get_table_columns(
&self,
transaction: &Transaction<'_>,
table_name: &str,
) -> Result<Vec<ColumnInfo>> {
let rows = transaction
.query(
"SELECT column_name,
CASE
WHEN data_type = 'character varying' THEN 1043
WHEN data_type = 'integer' THEN 23
WHEN data_type = 'bigint' THEN 20
WHEN data_type = 'smallint' THEN 21
WHEN data_type = 'text' THEN 25
WHEN data_type = 'boolean' THEN 16
WHEN data_type = 'numeric' THEN 1700
WHEN data_type = 'real' THEN 700
WHEN data_type = 'double precision' THEN 701
WHEN data_type = 'timestamp without time zone' THEN 1114
WHEN data_type = 'timestamp with time zone' THEN 1184
WHEN data_type = 'date' THEN 1082
WHEN data_type = 'uuid' THEN 2950
WHEN data_type = 'json' THEN 114
WHEN data_type = 'jsonb' THEN 3802
ELSE 25 -- Default to text
END as type_oid
FROM information_schema.columns
WHERE table_schema = 'public' AND table_name = $1
ORDER BY ordinal_position",
&[&table_name],
)
.await?;
let mut columns = Vec::new();
for row in rows {
columns.push(ColumnInfo {
name: row.get(0),
type_oid: row.get::<_, i32>(1),
});
}
Ok(columns)
}
async fn query_primary_keys(&mut self, client: &Client) -> Result<()> {
info!("Querying primary key information from PostgreSQL system catalogs");
let query = r#"
SELECT
n.nspname as schema_name,
c.relname as table_name,
a.attname as column_name
FROM pg_constraint con
JOIN pg_class c ON con.conrelid = c.oid
JOIN pg_namespace n ON c.relnamespace = n.oid
JOIN pg_attribute a ON a.attrelid = c.oid
WHERE con.contype = 'p' -- Primary key constraint
AND a.attnum = ANY(con.conkey)
AND n.nspname NOT IN ('pg_catalog', 'information_schema')
ORDER BY n.nspname, c.relname, array_position(con.conkey, a.attnum)
"#;
let rows = client.query(query, &[]).await?;
let mut primary_keys: HashMap<String, Vec<String>> = HashMap::new();
for row in rows {
let schema: &str = row.get(0);
let table: &str = row.get(1);
let column: &str = row.get(2);
let table_key = if schema == "public" {
table.to_string()
} else {
format!("{schema}.{table}")
};
primary_keys
.entry(table_key.clone())
.or_default()
.push(column.to_string());
debug!("Found primary key column '{column}' for table '{table_key}'");
}
for table_key_config in &self.config.table_keys {
let table_name = &table_key_config.table;
let key_columns = &table_key_config.key_columns;
if !key_columns.is_empty() {
info!(
"Using user-configured key columns for table '{table_name}': {key_columns:?}"
);
primary_keys.insert(table_name.clone(), key_columns.clone());
}
}
self.table_primary_keys = primary_keys.clone();
info!("Found primary keys for {} tables", primary_keys.len());
for (table, keys) in &primary_keys {
info!("Table '{table}' primary key columns: {keys:?}");
}
Ok(())
}
async fn row_to_source_change(
&self,
row: &Row,
table: &str,
columns: &[ColumnInfo],
) -> Result<SourceChange> {
let mut properties = ElementPropertyMap::new();
let pk_columns = self.table_primary_keys.get(table);
let mut pk_values = Vec::new();
for (idx, column) in columns.iter().enumerate() {
let is_pk = pk_columns
.map(|pks| pks.contains(&column.name))
.unwrap_or(false);
let element_value = match column.type_oid {
16 => {
if let Ok(Some(val)) = row.try_get::<_, Option<bool>>(idx) {
drasi_core::models::ElementValue::Bool(val)
} else {
drasi_core::models::ElementValue::Null
}
}
21 | 23 | 20 => {
if let Ok(Some(val)) = row.try_get::<_, Option<i64>>(idx) {
drasi_core::models::ElementValue::Integer(val)
} else if let Ok(Some(val)) = row.try_get::<_, Option<i32>>(idx) {
drasi_core::models::ElementValue::Integer(val as i64)
} else if let Ok(Some(val)) = row.try_get::<_, Option<i16>>(idx) {
drasi_core::models::ElementValue::Integer(val as i64)
} else {
drasi_core::models::ElementValue::Null
}
}
700 | 701 => {
if let Ok(Some(val)) = row.try_get::<_, Option<f64>>(idx) {
drasi_core::models::ElementValue::Float(ordered_float::OrderedFloat(val))
} else if let Ok(Some(val)) = row.try_get::<_, Option<f32>>(idx) {
drasi_core::models::ElementValue::Float(ordered_float::OrderedFloat(
val as f64,
))
} else {
drasi_core::models::ElementValue::Null
}
}
1700 => {
if let Ok(Some(val)) = row.try_get::<_, Option<rust_decimal::Decimal>>(idx) {
drasi_core::models::ElementValue::Float(ordered_float::OrderedFloat(
val.to_string().parse::<f64>().unwrap_or(0.0),
))
} else {
drasi_core::models::ElementValue::Null
}
}
25 | 1043 | 19 => {
if let Ok(Some(val)) = row.try_get::<_, Option<String>>(idx) {
drasi_core::models::ElementValue::String(std::sync::Arc::from(val))
} else {
drasi_core::models::ElementValue::Null
}
}
1114 | 1184 => {
if let Ok(Some(val)) = row.try_get::<_, Option<chrono::NaiveDateTime>>(idx) {
drasi_core::models::ElementValue::String(std::sync::Arc::from(
val.to_string(),
))
} else if let Ok(Some(val)) =
row.try_get::<_, Option<chrono::DateTime<chrono::Utc>>>(idx)
{
drasi_core::models::ElementValue::String(std::sync::Arc::from(
val.to_string(),
))
} else {
drasi_core::models::ElementValue::Null
}
}
_ => {
if let Ok(Some(val)) = row.try_get::<_, Option<String>>(idx) {
drasi_core::models::ElementValue::String(std::sync::Arc::from(val))
} else {
drasi_core::models::ElementValue::Null
}
}
};
if is_pk && !matches!(element_value, drasi_core::models::ElementValue::Null) {
let value_str = match &element_value {
drasi_core::models::ElementValue::Integer(i) => i.to_string(),
drasi_core::models::ElementValue::Float(f) => f.to_string(),
drasi_core::models::ElementValue::String(s) => s.to_string(),
drasi_core::models::ElementValue::Bool(b) => b.to_string(),
_ => format!("{element_value:?}"),
};
pk_values.push(value_str);
}
properties.insert(&column.name, element_value);
}
let elem_id = if !pk_values.is_empty() {
format!("{}:{}", table, pk_values.join("_"))
} else if pk_columns.is_none() || pk_columns.map(|pks| pks.is_empty()).unwrap_or(true) {
warn!(
"No primary key found for table '{table}'. Consider adding 'table_keys' configuration."
);
format!("{}:{}", table, uuid::Uuid::new_v4())
} else {
format!("{}:{}", table, uuid::Uuid::new_v4())
};
let metadata = ElementMetadata {
reference: ElementReference::new(&self.source_id, &elem_id),
labels: Arc::from(vec![Arc::from(table)]),
effective_from: chrono::Utc::now().timestamp_millis() as u64,
};
let element = Element::Node {
metadata,
properties,
};
Ok(SourceChange::Insert { element })
}
async fn send_batch(
&self,
batch: &mut Vec<SourceChangeEvent>,
context: &BootstrapContext,
event_tx: &drasi_lib::channels::BootstrapEventSender,
) -> Result<()> {
for event in batch.drain(..) {
let sequence = context.next_sequence();
let bootstrap_event = drasi_lib::channels::BootstrapEvent {
source_id: event.source_id,
change: event.change,
timestamp: event.timestamp,
sequence,
};
event_tx.send(bootstrap_event).await.map_err(|e| {
anyhow!("Failed to send bootstrap event to channel (channel may be closed): {e}")
})?;
}
Ok(())
}
}
/// Name and type code of one table column, as produced by `get_table_columns`
/// from `information_schema.columns`.
#[derive(Debug)]
struct ColumnInfo {
    /// Column name; also used as the element property key.
    name: String,
    /// Type OID chosen by `get_table_columns` from `data_type` (unrecognised
    /// types are reported as 25). Selects the decoding branch in
    /// `row_to_source_change`.
    type_oid: i32,
}
#[cfg(test)]
mod tests {
    use drasi_core::models::validate_effective_from;

    /// The bootstrapper stamps `effective_from` with epoch milliseconds; such a
    /// value must pass drasi-core's range validation.
    #[test]
    fn effective_from_uses_milliseconds() {
        let effective_from = chrono::Utc::now().timestamp_millis() as u64;
        let outcome = validate_effective_from(effective_from);
        assert!(
            outcome.is_ok(),
            "Postgres bootstrapper effective_from ({effective_from}) should be in millisecond range"
        );
    }

    /// Guards against regressing to nanosecond timestamps, which the validator
    /// must reject.
    #[test]
    fn effective_from_rejects_nanoseconds_pattern() {
        let now_nanos = chrono::Utc::now().timestamp_nanos_opt().unwrap();
        let bad_effective_from = now_nanos as u64;
        let outcome = validate_effective_from(bad_effective_from);
        assert!(
            outcome.is_err(),
            "Nanosecond timestamp ({bad_effective_from}) should be rejected"
        );
    }
}