1use anyhow::{anyhow, Result};
18use async_trait::async_trait;
19use drasi_core::models::{
20 Element, ElementMetadata, ElementPropertyMap, ElementReference, SourceChange,
21};
22use log::{debug, error, info, warn};
23use std::collections::HashMap;
24use std::sync::Arc;
25use tokio_postgres::{Client, NoTls, Row, Transaction};
26
27use drasi_lib::bootstrap::{
28 BootstrapContext, BootstrapProvider, BootstrapRequest, BootstrapResult,
29};
30use drasi_lib::channels::SourceChangeEvent;
31
32pub use crate::config::{PostgresBootstrapConfig, SslMode, TableKeyConfig};
33
34pub struct PostgresBootstrapProvider {
39 config: PostgresConfig,
40}
41
42impl PostgresBootstrapProvider {
43 pub fn new(postgres_config: PostgresBootstrapConfig) -> Self {
45 Self {
46 config: PostgresConfig::from_bootstrap_config(postgres_config),
47 }
48 }
49
50 pub fn builder() -> PostgresBootstrapProviderBuilder {
52 PostgresBootstrapProviderBuilder::new()
53 }
54}
55
56pub struct PostgresBootstrapProviderBuilder {
73 host: String,
74 port: u16,
75 database: String,
76 user: String,
77 password: String,
78 tables: Vec<String>,
79 slot_name: String,
80 publication_name: String,
81 ssl_mode: SslMode,
82 table_keys: Vec<TableKeyConfig>,
83}
84
85impl PostgresBootstrapProviderBuilder {
86 pub fn new() -> Self {
88 Self {
89 host: "localhost".to_string(), port: 5432,
91 database: String::new(),
92 user: String::new(),
93 password: String::new(),
94 tables: Vec::new(),
95 slot_name: "drasi_slot".to_string(),
96 publication_name: "drasi_pub".to_string(),
97 ssl_mode: SslMode::Disable,
98 table_keys: Vec::new(),
99 }
100 }
101
102 pub fn with_host(mut self, host: impl Into<String>) -> Self {
104 self.host = host.into();
105 self
106 }
107
108 pub fn with_port(mut self, port: u16) -> Self {
110 self.port = port;
111 self
112 }
113
114 pub fn with_database(mut self, database: impl Into<String>) -> Self {
116 self.database = database.into();
117 self
118 }
119
120 pub fn with_user(mut self, user: impl Into<String>) -> Self {
122 self.user = user.into();
123 self
124 }
125
126 pub fn with_password(mut self, password: impl Into<String>) -> Self {
128 self.password = password.into();
129 self
130 }
131
132 pub fn with_tables(mut self, tables: Vec<String>) -> Self {
134 self.tables = tables;
135 self
136 }
137
138 pub fn with_table(mut self, table: impl Into<String>) -> Self {
140 self.tables.push(table.into());
141 self
142 }
143
144 pub fn with_slot_name(mut self, slot_name: impl Into<String>) -> Self {
146 self.slot_name = slot_name.into();
147 self
148 }
149
150 pub fn with_publication_name(mut self, publication_name: impl Into<String>) -> Self {
152 self.publication_name = publication_name.into();
153 self
154 }
155
156 pub fn with_ssl_mode(mut self, ssl_mode: SslMode) -> Self {
158 self.ssl_mode = ssl_mode;
159 self
160 }
161
162 pub fn with_table_keys(mut self, table_keys: Vec<TableKeyConfig>) -> Self {
164 self.table_keys = table_keys;
165 self
166 }
167
168 pub fn with_table_key(mut self, table: impl Into<String>, key_columns: Vec<String>) -> Self {
170 self.table_keys.push(TableKeyConfig {
171 table: table.into(),
172 key_columns,
173 });
174 self
175 }
176
177 pub fn build(self) -> PostgresBootstrapProvider {
179 let config = PostgresBootstrapConfig {
180 host: self.host,
181 port: self.port,
182 database: self.database,
183 user: self.user,
184 password: self.password,
185 tables: self.tables,
186 slot_name: self.slot_name,
187 publication_name: self.publication_name,
188 ssl_mode: self.ssl_mode,
189 table_keys: self.table_keys,
190 };
191 PostgresBootstrapProvider::new(config)
192 }
193}
194
195impl Default for PostgresBootstrapProviderBuilder {
196 fn default() -> Self {
197 Self::new()
198 }
199}
200
201#[async_trait]
202impl BootstrapProvider for PostgresBootstrapProvider {
203 async fn bootstrap(
204 &self,
205 request: BootstrapRequest,
206 context: &BootstrapContext,
207 event_tx: drasi_lib::channels::BootstrapEventSender,
208 _settings: Option<&drasi_lib::config::SourceSubscriptionSettings>,
209 ) -> Result<BootstrapResult> {
210 info!(
211 "Starting PostgreSQL bootstrap for query '{}' with {} node labels and {} relation labels",
212 request.query_id,
213 request.node_labels.len(),
214 request.relation_labels.len()
215 );
216
217 let mut handler =
219 PostgresBootstrapHandler::new(self.config.clone(), context.source_id.clone());
220
221 let query_id = request.query_id.clone();
223
224 let count = handler.execute(request, context, event_tx).await?;
226
227 info!("Completed PostgreSQL bootstrap for query {query_id}: sent {count} records");
228
229 Ok(BootstrapResult {
235 event_count: count,
236 last_sequence: None,
237 sequences_aligned: false,
238 })
239 }
240}
241
242#[derive(Debug, Clone)]
244struct PostgresConfig {
245 pub host: String,
246 pub port: u16,
247 pub database: String,
248 pub user: String,
249 pub password: String,
250 #[allow(dead_code)]
251 pub tables: Vec<String>,
252 #[allow(dead_code)]
253 pub slot_name: String,
254 #[allow(dead_code)]
255 pub publication_name: String,
256 #[allow(dead_code)]
257 pub ssl_mode: SslMode,
258 pub table_keys: Vec<TableKeyConfig>,
259}
260
261impl PostgresConfig {
262 fn from_bootstrap_config(postgres_config: PostgresBootstrapConfig) -> Self {
263 PostgresConfig {
264 host: postgres_config.host.clone(),
265 port: postgres_config.port,
266 database: postgres_config.database.clone(),
267 user: postgres_config.user.clone(),
268 password: postgres_config.password.clone(),
269 tables: postgres_config.tables.clone(),
270 slot_name: postgres_config.slot_name.clone(),
271 publication_name: postgres_config.publication_name.clone(),
272 ssl_mode: postgres_config.ssl_mode,
273 table_keys: postgres_config.table_keys.clone(),
274 }
275 }
276}
277
278struct PostgresBootstrapHandler {
280 config: PostgresConfig,
281 source_id: String,
282 table_primary_keys: HashMap<String, Vec<String>>,
284}
285
286impl PostgresBootstrapHandler {
287 fn new(config: PostgresConfig, source_id: String) -> Self {
288 Self {
289 config,
290 source_id,
291 table_primary_keys: HashMap::new(),
292 }
293 }
294
295 async fn execute(
297 &mut self,
298 request: BootstrapRequest,
299 context: &BootstrapContext,
300 event_tx: drasi_lib::channels::BootstrapEventSender,
301 ) -> Result<usize> {
302 info!(
303 "Bootstrap: Connecting to PostgreSQL at {}:{}",
304 self.config.host, self.config.port
305 );
306
307 let mut client = self.connect().await?;
309
310 self.query_primary_keys(&client).await?;
312
313 info!("Bootstrap: Connected, creating snapshot transaction...");
314 let (transaction, lsn) = self.create_snapshot(&mut client).await?;
316
317 info!("Bootstrap snapshot created at LSN: {lsn}");
318
319 let tables = self.resolve_tables(&request, &transaction).await?;
321 info!(
322 "Resolved {} labels to {} tables",
323 request.node_labels.len() + request.relation_labels.len(),
324 tables.len()
325 );
326
327 let mut total_count = 0;
329 for table in &tables {
330 let count = self
331 .bootstrap_table(&transaction, table, context, &event_tx)
332 .await?;
333 info!("Bootstrapped {count} rows from table '{table}'");
334 total_count += count;
335 }
336
337 transaction.commit().await?;
339
340 info!("Bootstrap completed: {total_count} total elements sent");
341 Ok(total_count)
342 }
343
344 async fn connect(&self) -> Result<Client> {
346 let connection_string = format!(
347 "host={} port={} user={} password={} dbname={}",
348 self.config.host,
349 self.config.port,
350 self.config.user,
351 self.config.password,
352 self.config.database
353 );
354
355 let (client, connection) = tokio_postgres::connect(&connection_string, NoTls).await?;
356
357 tokio::spawn(async move {
359 if let Err(e) = connection.await {
360 error!("PostgreSQL connection error: {e}");
361 }
362 });
363
364 Ok(client)
365 }
366
367 async fn create_snapshot<'a>(
369 &self,
370 client: &'a mut Client,
371 ) -> Result<(Transaction<'a>, String)> {
372 let transaction = client
374 .build_transaction()
375 .isolation_level(tokio_postgres::IsolationLevel::RepeatableRead)
376 .start()
377 .await?;
378
379 let row = transaction
381 .query_one("SELECT pg_current_wal_lsn()::text", &[])
382 .await?;
383 let lsn: String = row.get(0);
384
385 Ok((transaction, lsn))
386 }
387
388 async fn resolve_tables(
392 &self,
393 request: &BootstrapRequest,
394 transaction: &Transaction<'_>,
395 ) -> Result<Vec<String>> {
396 let mut tables = Vec::new();
397
398 let all_labels: Vec<String> = request
400 .node_labels
401 .iter()
402 .chain(request.relation_labels.iter())
403 .cloned()
404 .collect();
405
406 for label in all_labels {
407 if self.table_exists(transaction, &label).await? {
408 tables.push(label);
409 } else {
410 warn!("Table '{label}' does not exist, skipping");
411 }
412 }
413
414 Ok(tables)
415 }
416
417 async fn table_exists(&self, transaction: &Transaction<'_>, table_name: &str) -> Result<bool> {
419 let row = transaction
420 .query_one(
421 "SELECT EXISTS (
422 SELECT 1 FROM information_schema.tables
423 WHERE table_schema = 'public'
424 AND table_name = $1
425 )",
426 &[&table_name],
427 )
428 .await?;
429
430 Ok(row.get(0))
431 }
432
433 async fn bootstrap_table(
435 &self,
436 transaction: &Transaction<'_>,
437 table: &str,
438 context: &BootstrapContext,
439 event_tx: &drasi_lib::channels::BootstrapEventSender,
440 ) -> Result<usize> {
441 debug!("Starting bootstrap of table '{table}'");
442
443 let columns = self.get_table_columns(transaction, table).await?;
445
446 let query = format!("SELECT * FROM \"{}\"", table.replace('"', "\"\""));
448 let rows = transaction.query(&query, &[]).await?;
449
450 let mut count = 0;
451 let mut batch = Vec::new();
452 let batch_size = 1000;
453
454 for row in rows {
455 let source_change = self.row_to_source_change(&row, table, &columns).await?;
456
457 batch.push(SourceChangeEvent {
458 source_id: self.source_id.clone(),
459 change: source_change,
460 timestamp: chrono::Utc::now(),
461 });
462
463 if batch.len() >= batch_size {
464 self.send_batch(&mut batch, context, event_tx).await?;
465 count += batch_size;
466 }
467 }
468
469 if !batch.is_empty() {
471 count += batch.len();
472 self.send_batch(&mut batch, context, event_tx).await?;
473 }
474
475 Ok(count)
476 }
477
478 async fn get_table_columns(
480 &self,
481 transaction: &Transaction<'_>,
482 table_name: &str,
483 ) -> Result<Vec<ColumnInfo>> {
484 let rows = transaction
485 .query(
486 "SELECT column_name,
487 CASE
488 WHEN data_type = 'character varying' THEN 1043
489 WHEN data_type = 'integer' THEN 23
490 WHEN data_type = 'bigint' THEN 20
491 WHEN data_type = 'smallint' THEN 21
492 WHEN data_type = 'text' THEN 25
493 WHEN data_type = 'boolean' THEN 16
494 WHEN data_type = 'numeric' THEN 1700
495 WHEN data_type = 'real' THEN 700
496 WHEN data_type = 'double precision' THEN 701
497 WHEN data_type = 'timestamp without time zone' THEN 1114
498 WHEN data_type = 'timestamp with time zone' THEN 1184
499 WHEN data_type = 'date' THEN 1082
500 WHEN data_type = 'uuid' THEN 2950
501 WHEN data_type = 'json' THEN 114
502 WHEN data_type = 'jsonb' THEN 3802
503 ELSE 25 -- Default to text
504 END as type_oid
505 FROM information_schema.columns
506 WHERE table_schema = 'public' AND table_name = $1
507 ORDER BY ordinal_position",
508 &[&table_name],
509 )
510 .await?;
511
512 let mut columns = Vec::new();
513 for row in rows {
514 columns.push(ColumnInfo {
515 name: row.get(0),
516 type_oid: row.get::<_, i32>(1),
517 });
518 }
519
520 Ok(columns)
521 }
522
523 async fn query_primary_keys(&mut self, client: &Client) -> Result<()> {
525 info!("Querying primary key information from PostgreSQL system catalogs");
526
527 let query = r#"
528 SELECT
529 n.nspname as schema_name,
530 c.relname as table_name,
531 a.attname as column_name
532 FROM pg_constraint con
533 JOIN pg_class c ON con.conrelid = c.oid
534 JOIN pg_namespace n ON c.relnamespace = n.oid
535 JOIN pg_attribute a ON a.attrelid = c.oid
536 WHERE con.contype = 'p' -- Primary key constraint
537 AND a.attnum = ANY(con.conkey)
538 AND n.nspname NOT IN ('pg_catalog', 'information_schema')
539 ORDER BY n.nspname, c.relname, array_position(con.conkey, a.attnum)
540 "#;
541
542 let rows = client.query(query, &[]).await?;
543
544 let mut primary_keys: HashMap<String, Vec<String>> = HashMap::new();
545
546 for row in rows {
547 let schema: &str = row.get(0);
548 let table: &str = row.get(1);
549 let column: &str = row.get(2);
550
551 let table_key = if schema == "public" {
553 table.to_string()
554 } else {
555 format!("{schema}.{table}")
556 };
557
558 primary_keys
559 .entry(table_key.clone())
560 .or_default()
561 .push(column.to_string());
562
563 debug!("Found primary key column '{column}' for table '{table_key}'");
564 }
565
566 for table_key_config in &self.config.table_keys {
568 let table_name = &table_key_config.table;
569 let key_columns = &table_key_config.key_columns;
570
571 if !key_columns.is_empty() {
572 info!(
573 "Using user-configured key columns for table '{table_name}': {key_columns:?}"
574 );
575 primary_keys.insert(table_name.clone(), key_columns.clone());
576 }
577 }
578
579 self.table_primary_keys = primary_keys.clone();
581
582 info!("Found primary keys for {} tables", primary_keys.len());
583 for (table, keys) in &primary_keys {
584 info!("Table '{table}' primary key columns: {keys:?}");
585 }
586
587 Ok(())
588 }
589
590 async fn row_to_source_change(
592 &self,
593 row: &Row,
594 table: &str,
595 columns: &[ColumnInfo],
596 ) -> Result<SourceChange> {
597 let mut properties = ElementPropertyMap::new();
598
599 let pk_columns = self.table_primary_keys.get(table);
601
602 let mut pk_values = Vec::new();
604
605 for (idx, column) in columns.iter().enumerate() {
606 let is_pk = pk_columns
608 .map(|pks| pks.contains(&column.name))
609 .unwrap_or(false);
610
611 let element_value = match column.type_oid {
613 16 => {
614 if let Ok(Some(val)) = row.try_get::<_, Option<bool>>(idx) {
616 drasi_core::models::ElementValue::Bool(val)
617 } else {
618 drasi_core::models::ElementValue::Null
619 }
620 }
621 21 | 23 | 20 => {
622 if let Ok(Some(val)) = row.try_get::<_, Option<i64>>(idx) {
624 drasi_core::models::ElementValue::Integer(val)
625 } else if let Ok(Some(val)) = row.try_get::<_, Option<i32>>(idx) {
626 drasi_core::models::ElementValue::Integer(val as i64)
627 } else if let Ok(Some(val)) = row.try_get::<_, Option<i16>>(idx) {
628 drasi_core::models::ElementValue::Integer(val as i64)
629 } else {
630 drasi_core::models::ElementValue::Null
631 }
632 }
633 700 | 701 => {
634 if let Ok(Some(val)) = row.try_get::<_, Option<f64>>(idx) {
636 drasi_core::models::ElementValue::Float(ordered_float::OrderedFloat(val))
637 } else if let Ok(Some(val)) = row.try_get::<_, Option<f32>>(idx) {
638 drasi_core::models::ElementValue::Float(ordered_float::OrderedFloat(
639 val as f64,
640 ))
641 } else {
642 drasi_core::models::ElementValue::Null
643 }
644 }
645 1700 => {
646 if let Ok(Some(val)) = row.try_get::<_, Option<rust_decimal::Decimal>>(idx) {
648 drasi_core::models::ElementValue::Float(ordered_float::OrderedFloat(
649 val.to_string().parse::<f64>().unwrap_or(0.0),
650 ))
651 } else {
652 drasi_core::models::ElementValue::Null
653 }
654 }
655 25 | 1043 | 19 => {
656 if let Ok(Some(val)) = row.try_get::<_, Option<String>>(idx) {
658 drasi_core::models::ElementValue::String(std::sync::Arc::from(val))
659 } else {
660 drasi_core::models::ElementValue::Null
661 }
662 }
663 1114 | 1184 => {
664 if let Ok(Some(val)) = row.try_get::<_, Option<chrono::NaiveDateTime>>(idx) {
666 drasi_core::models::ElementValue::String(std::sync::Arc::from(
667 val.to_string(),
668 ))
669 } else if let Ok(Some(val)) =
670 row.try_get::<_, Option<chrono::DateTime<chrono::Utc>>>(idx)
671 {
672 drasi_core::models::ElementValue::String(std::sync::Arc::from(
673 val.to_string(),
674 ))
675 } else {
676 drasi_core::models::ElementValue::Null
677 }
678 }
679 _ => {
680 if let Ok(Some(val)) = row.try_get::<_, Option<String>>(idx) {
682 drasi_core::models::ElementValue::String(std::sync::Arc::from(val))
683 } else {
684 drasi_core::models::ElementValue::Null
685 }
686 }
687 };
688
689 if is_pk && !matches!(element_value, drasi_core::models::ElementValue::Null) {
691 let value_str = match &element_value {
692 drasi_core::models::ElementValue::Integer(i) => i.to_string(),
693 drasi_core::models::ElementValue::Float(f) => f.to_string(),
694 drasi_core::models::ElementValue::String(s) => s.to_string(),
695 drasi_core::models::ElementValue::Bool(b) => b.to_string(),
696 _ => format!("{element_value:?}"),
697 };
698 pk_values.push(value_str);
699 }
700
701 properties.insert(&column.name, element_value);
702 }
703
704 let elem_id = if !pk_values.is_empty() {
707 format!("{}:{}", table, pk_values.join("_"))
709 } else if pk_columns.is_none() || pk_columns.map(|pks| pks.is_empty()).unwrap_or(true) {
710 warn!(
712 "No primary key found for table '{table}'. Consider adding 'table_keys' configuration."
713 );
714 format!("{}:{}", table, uuid::Uuid::new_v4())
716 } else {
717 format!("{}:{}", table, uuid::Uuid::new_v4())
719 };
720
721 let metadata = ElementMetadata {
722 reference: ElementReference::new(&self.source_id, &elem_id),
723 labels: Arc::from(vec![Arc::from(table)]),
724 effective_from: chrono::Utc::now().timestamp_millis() as u64,
725 };
726
727 let element = Element::Node {
728 metadata,
729 properties,
730 };
731
732 Ok(SourceChange::Insert { element })
733 }
734
735 async fn send_batch(
737 &self,
738 batch: &mut Vec<SourceChangeEvent>,
739 context: &BootstrapContext,
740 event_tx: &drasi_lib::channels::BootstrapEventSender,
741 ) -> Result<()> {
742 for event in batch.drain(..) {
743 let sequence = context.next_sequence();
745
746 let bootstrap_event = drasi_lib::channels::BootstrapEvent {
747 source_id: event.source_id,
748 change: event.change,
749 timestamp: event.timestamp,
750 sequence,
751 };
752 event_tx.send(bootstrap_event).await.map_err(|e| {
753 anyhow!("Failed to send bootstrap event to channel (channel may be closed): {e}")
754 })?;
755 }
756 Ok(())
757 }
758}
759
/// Column metadata read from `information_schema.columns` for one table.
#[derive(Debug)]
struct ColumnInfo {
    /// Column name exactly as stored in the catalog.
    name: String,
    /// Type OID assigned by the mapping in `get_table_columns`; selects how the
    /// column's values are decoded.
    type_oid: i32,
}
765
#[cfg(test)]
mod tests {
    use drasi_core::models::validate_effective_from;

    /// Elements are stamped with `Utc::now().timestamp_millis()`; a value of
    /// that form must be accepted by `validate_effective_from`.
    #[test]
    fn effective_from_uses_milliseconds() {
        let effective_from = chrono::Utc::now().timestamp_millis() as u64;
        let verdict = validate_effective_from(effective_from);
        assert!(
            verdict.is_ok(),
            "Postgres bootstrapper effective_from ({effective_from}) should be in millisecond range"
        );
    }

    /// Guards against regressing to nanosecond timestamps: such values must be
    /// rejected by `validate_effective_from`.
    #[test]
    fn effective_from_rejects_nanoseconds_pattern() {
        let now_nanos = chrono::Utc::now()
            .timestamp_nanos_opt()
            .expect("current time fits in i64 nanoseconds");
        let bad_effective_from = now_nanos as u64;
        let verdict = validate_effective_from(bad_effective_from);
        assert!(
            verdict.is_err(),
            "Nanosecond timestamp ({bad_effective_from}) should be rejected"
        );
    }
}