1use anyhow::{anyhow, Result};
18use async_trait::async_trait;
19use drasi_core::models::{
20 Element, ElementMetadata, ElementPropertyMap, ElementReference, SourceChange,
21};
22use log::{debug, error, info, warn};
23use std::collections::HashMap;
24use std::sync::Arc;
25use tokio_postgres::{Client, NoTls, Row, Transaction};
26
27use drasi_lib::bootstrap::{
28 BootstrapContext, BootstrapProvider, BootstrapRequest, BootstrapResult,
29};
30use drasi_lib::channels::SourceChangeEvent;
31
32pub use crate::config::{PostgresBootstrapConfig, SslMode, TableKeyConfig};
33
34pub struct PostgresBootstrapProvider {
39 config: PostgresConfig,
40}
41
42impl PostgresBootstrapProvider {
43 pub fn new(postgres_config: PostgresBootstrapConfig) -> Self {
45 Self {
46 config: PostgresConfig::from_bootstrap_config(postgres_config),
47 }
48 }
49
50 pub fn builder() -> PostgresBootstrapProviderBuilder {
52 PostgresBootstrapProviderBuilder::new()
53 }
54}
55
56pub struct PostgresBootstrapProviderBuilder {
73 host: String,
74 port: u16,
75 database: String,
76 user: String,
77 password: String,
78 tables: Vec<String>,
79 slot_name: String,
80 publication_name: String,
81 ssl_mode: SslMode,
82 table_keys: Vec<TableKeyConfig>,
83}
84
85impl PostgresBootstrapProviderBuilder {
86 pub fn new() -> Self {
88 Self {
89 host: "localhost".to_string(), port: 5432,
91 database: String::new(),
92 user: String::new(),
93 password: String::new(),
94 tables: Vec::new(),
95 slot_name: "drasi_slot".to_string(),
96 publication_name: "drasi_pub".to_string(),
97 ssl_mode: SslMode::Disable,
98 table_keys: Vec::new(),
99 }
100 }
101
102 pub fn with_host(mut self, host: impl Into<String>) -> Self {
104 self.host = host.into();
105 self
106 }
107
108 pub fn with_port(mut self, port: u16) -> Self {
110 self.port = port;
111 self
112 }
113
114 pub fn with_database(mut self, database: impl Into<String>) -> Self {
116 self.database = database.into();
117 self
118 }
119
120 pub fn with_user(mut self, user: impl Into<String>) -> Self {
122 self.user = user.into();
123 self
124 }
125
126 pub fn with_password(mut self, password: impl Into<String>) -> Self {
128 self.password = password.into();
129 self
130 }
131
132 pub fn with_tables(mut self, tables: Vec<String>) -> Self {
134 self.tables = tables;
135 self
136 }
137
138 pub fn with_table(mut self, table: impl Into<String>) -> Self {
140 self.tables.push(table.into());
141 self
142 }
143
144 pub fn with_slot_name(mut self, slot_name: impl Into<String>) -> Self {
146 self.slot_name = slot_name.into();
147 self
148 }
149
150 pub fn with_publication_name(mut self, publication_name: impl Into<String>) -> Self {
152 self.publication_name = publication_name.into();
153 self
154 }
155
156 pub fn with_ssl_mode(mut self, ssl_mode: SslMode) -> Self {
158 self.ssl_mode = ssl_mode;
159 self
160 }
161
162 pub fn with_table_keys(mut self, table_keys: Vec<TableKeyConfig>) -> Self {
164 self.table_keys = table_keys;
165 self
166 }
167
168 pub fn with_table_key(mut self, table: impl Into<String>, key_columns: Vec<String>) -> Self {
170 self.table_keys.push(TableKeyConfig {
171 table: table.into(),
172 key_columns,
173 });
174 self
175 }
176
177 pub fn build(self) -> PostgresBootstrapProvider {
179 let config = PostgresBootstrapConfig {
180 host: self.host,
181 port: self.port,
182 database: self.database,
183 user: self.user,
184 password: self.password,
185 tables: self.tables,
186 slot_name: self.slot_name,
187 publication_name: self.publication_name,
188 ssl_mode: self.ssl_mode,
189 table_keys: self.table_keys,
190 };
191 PostgresBootstrapProvider::new(config)
192 }
193}
194
195impl Default for PostgresBootstrapProviderBuilder {
196 fn default() -> Self {
197 Self::new()
198 }
199}
200
201#[async_trait]
202impl BootstrapProvider for PostgresBootstrapProvider {
203 async fn bootstrap(
204 &self,
205 request: BootstrapRequest,
206 context: &BootstrapContext,
207 event_tx: drasi_lib::channels::BootstrapEventSender,
208 _settings: Option<&drasi_lib::config::SourceSubscriptionSettings>,
209 ) -> Result<BootstrapResult> {
210 info!(
211 "Starting PostgreSQL bootstrap for query '{}' with {} node labels and {} relation labels",
212 request.query_id,
213 request.node_labels.len(),
214 request.relation_labels.len()
215 );
216
217 let mut handler =
219 PostgresBootstrapHandler::new(self.config.clone(), context.source_id.clone());
220
221 let query_id = request.query_id.clone();
223
224 let count = handler.execute(request, context, event_tx).await?;
226
227 info!("Completed PostgreSQL bootstrap for query {query_id}: sent {count} records");
228
229 Ok(BootstrapResult {
235 event_count: count,
236 last_sequence: None,
237 sequences_aligned: false,
238 source_position: None,
239 })
240 }
241}
242
243#[derive(Debug, Clone)]
245struct PostgresConfig {
246 pub host: String,
247 pub port: u16,
248 pub database: String,
249 pub user: String,
250 pub password: String,
251 #[allow(dead_code)]
252 pub tables: Vec<String>,
253 #[allow(dead_code)]
254 pub slot_name: String,
255 #[allow(dead_code)]
256 pub publication_name: String,
257 #[allow(dead_code)]
258 pub ssl_mode: SslMode,
259 pub table_keys: Vec<TableKeyConfig>,
260}
261
262impl PostgresConfig {
263 fn from_bootstrap_config(postgres_config: PostgresBootstrapConfig) -> Self {
264 PostgresConfig {
265 host: postgres_config.host.clone(),
266 port: postgres_config.port,
267 database: postgres_config.database.clone(),
268 user: postgres_config.user.clone(),
269 password: postgres_config.password.clone(),
270 tables: postgres_config.tables.clone(),
271 slot_name: postgres_config.slot_name.clone(),
272 publication_name: postgres_config.publication_name.clone(),
273 ssl_mode: postgres_config.ssl_mode,
274 table_keys: postgres_config.table_keys.clone(),
275 }
276 }
277}
278
279struct PostgresBootstrapHandler {
281 config: PostgresConfig,
282 source_id: String,
283 table_primary_keys: HashMap<String, Vec<String>>,
285}
286
287impl PostgresBootstrapHandler {
288 fn new(config: PostgresConfig, source_id: String) -> Self {
289 Self {
290 config,
291 source_id,
292 table_primary_keys: HashMap::new(),
293 }
294 }
295
296 async fn execute(
298 &mut self,
299 request: BootstrapRequest,
300 context: &BootstrapContext,
301 event_tx: drasi_lib::channels::BootstrapEventSender,
302 ) -> Result<usize> {
303 info!(
304 "Bootstrap: Connecting to PostgreSQL at {}:{}",
305 self.config.host, self.config.port
306 );
307
308 let mut client = self.connect().await?;
310
311 self.query_primary_keys(&client).await?;
313
314 info!("Bootstrap: Connected, creating snapshot transaction...");
315 let (transaction, lsn) = self.create_snapshot(&mut client).await?;
317
318 info!("Bootstrap snapshot created at LSN: {lsn}");
319
320 let tables = self.resolve_tables(&request, &transaction).await?;
322 info!(
323 "Resolved {} labels to {} tables",
324 request.node_labels.len() + request.relation_labels.len(),
325 tables.len()
326 );
327
328 let mut total_count = 0;
330 for table in &tables {
331 let count = self
332 .bootstrap_table(&transaction, table, context, &event_tx)
333 .await?;
334 info!("Bootstrapped {count} rows from table '{table}'");
335 total_count += count;
336 }
337
338 transaction.commit().await?;
340
341 info!("Bootstrap completed: {total_count} total elements sent");
342 Ok(total_count)
343 }
344
345 async fn connect(&self) -> Result<Client> {
347 let connection_string = format!(
348 "host={} port={} user={} password={} dbname={}",
349 self.config.host,
350 self.config.port,
351 self.config.user,
352 self.config.password,
353 self.config.database
354 );
355
356 let (client, connection) = tokio_postgres::connect(&connection_string, NoTls).await?;
357
358 tokio::spawn(async move {
360 if let Err(e) = connection.await {
361 error!("PostgreSQL connection error: {e}");
362 }
363 });
364
365 Ok(client)
366 }
367
368 async fn create_snapshot<'a>(
370 &self,
371 client: &'a mut Client,
372 ) -> Result<(Transaction<'a>, String)> {
373 let transaction = client
375 .build_transaction()
376 .isolation_level(tokio_postgres::IsolationLevel::RepeatableRead)
377 .start()
378 .await?;
379
380 let row = transaction
382 .query_one("SELECT pg_current_wal_lsn()::text", &[])
383 .await?;
384 let lsn: String = row.get(0);
385
386 Ok((transaction, lsn))
387 }
388
389 async fn resolve_tables(
393 &self,
394 request: &BootstrapRequest,
395 transaction: &Transaction<'_>,
396 ) -> Result<Vec<String>> {
397 let mut tables = Vec::new();
398
399 let all_labels: Vec<String> = request
401 .node_labels
402 .iter()
403 .chain(request.relation_labels.iter())
404 .cloned()
405 .collect();
406
407 for label in all_labels {
408 if self.table_exists(transaction, &label).await? {
409 tables.push(label);
410 } else {
411 warn!("Table '{label}' does not exist, skipping");
412 }
413 }
414
415 Ok(tables)
416 }
417
418 async fn table_exists(&self, transaction: &Transaction<'_>, table_name: &str) -> Result<bool> {
420 let row = transaction
421 .query_one(
422 "SELECT EXISTS (
423 SELECT 1 FROM information_schema.tables
424 WHERE table_schema = 'public'
425 AND table_name = $1
426 )",
427 &[&table_name],
428 )
429 .await?;
430
431 Ok(row.get(0))
432 }
433
434 async fn bootstrap_table(
436 &self,
437 transaction: &Transaction<'_>,
438 table: &str,
439 context: &BootstrapContext,
440 event_tx: &drasi_lib::channels::BootstrapEventSender,
441 ) -> Result<usize> {
442 debug!("Starting bootstrap of table '{table}'");
443
444 let columns = self.get_table_columns(transaction, table).await?;
446
447 let query = format!("SELECT * FROM \"{}\"", table.replace('"', "\"\""));
449 let rows = transaction.query(&query, &[]).await?;
450
451 let mut count = 0;
452 let mut batch = Vec::new();
453 let batch_size = 1000;
454
455 for row in rows {
456 let source_change = self.row_to_source_change(&row, table, &columns).await?;
457
458 batch.push(SourceChangeEvent {
459 source_id: self.source_id.clone(),
460 change: source_change,
461 timestamp: chrono::Utc::now(),
462 });
463
464 if batch.len() >= batch_size {
465 self.send_batch(&mut batch, context, event_tx).await?;
466 count += batch_size;
467 }
468 }
469
470 if !batch.is_empty() {
472 count += batch.len();
473 self.send_batch(&mut batch, context, event_tx).await?;
474 }
475
476 Ok(count)
477 }
478
479 async fn get_table_columns(
481 &self,
482 transaction: &Transaction<'_>,
483 table_name: &str,
484 ) -> Result<Vec<ColumnInfo>> {
485 let rows = transaction
486 .query(
487 "SELECT column_name,
488 CASE
489 WHEN data_type = 'character varying' THEN 1043
490 WHEN data_type = 'integer' THEN 23
491 WHEN data_type = 'bigint' THEN 20
492 WHEN data_type = 'smallint' THEN 21
493 WHEN data_type = 'text' THEN 25
494 WHEN data_type = 'boolean' THEN 16
495 WHEN data_type = 'numeric' THEN 1700
496 WHEN data_type = 'real' THEN 700
497 WHEN data_type = 'double precision' THEN 701
498 WHEN data_type = 'timestamp without time zone' THEN 1114
499 WHEN data_type = 'timestamp with time zone' THEN 1184
500 WHEN data_type = 'date' THEN 1082
501 WHEN data_type = 'uuid' THEN 2950
502 WHEN data_type = 'json' THEN 114
503 WHEN data_type = 'jsonb' THEN 3802
504 ELSE 25 -- Default to text
505 END as type_oid
506 FROM information_schema.columns
507 WHERE table_schema = 'public' AND table_name = $1
508 ORDER BY ordinal_position",
509 &[&table_name],
510 )
511 .await?;
512
513 let mut columns = Vec::new();
514 for row in rows {
515 columns.push(ColumnInfo {
516 name: row.get(0),
517 type_oid: row.get::<_, i32>(1),
518 });
519 }
520
521 Ok(columns)
522 }
523
524 async fn query_primary_keys(&mut self, client: &Client) -> Result<()> {
526 info!("Querying primary key information from PostgreSQL system catalogs");
527
528 let query = r#"
529 SELECT
530 n.nspname as schema_name,
531 c.relname as table_name,
532 a.attname as column_name
533 FROM pg_constraint con
534 JOIN pg_class c ON con.conrelid = c.oid
535 JOIN pg_namespace n ON c.relnamespace = n.oid
536 JOIN pg_attribute a ON a.attrelid = c.oid
537 WHERE con.contype = 'p' -- Primary key constraint
538 AND a.attnum = ANY(con.conkey)
539 AND n.nspname NOT IN ('pg_catalog', 'information_schema')
540 ORDER BY n.nspname, c.relname, array_position(con.conkey, a.attnum)
541 "#;
542
543 let rows = client.query(query, &[]).await?;
544
545 let mut primary_keys: HashMap<String, Vec<String>> = HashMap::new();
546
547 for row in rows {
548 let schema: &str = row.get(0);
549 let table: &str = row.get(1);
550 let column: &str = row.get(2);
551
552 let table_key = if schema == "public" {
554 table.to_string()
555 } else {
556 format!("{schema}.{table}")
557 };
558
559 primary_keys
560 .entry(table_key.clone())
561 .or_default()
562 .push(column.to_string());
563
564 debug!("Found primary key column '{column}' for table '{table_key}'");
565 }
566
567 for table_key_config in &self.config.table_keys {
569 let table_name = &table_key_config.table;
570 let key_columns = &table_key_config.key_columns;
571
572 if !key_columns.is_empty() {
573 info!(
574 "Using user-configured key columns for table '{table_name}': {key_columns:?}"
575 );
576 primary_keys.insert(table_name.clone(), key_columns.clone());
577 }
578 }
579
580 self.table_primary_keys = primary_keys.clone();
582
583 info!("Found primary keys for {} tables", primary_keys.len());
584 for (table, keys) in &primary_keys {
585 info!("Table '{table}' primary key columns: {keys:?}");
586 }
587
588 Ok(())
589 }
590
591 async fn row_to_source_change(
593 &self,
594 row: &Row,
595 table: &str,
596 columns: &[ColumnInfo],
597 ) -> Result<SourceChange> {
598 let mut properties = ElementPropertyMap::new();
599
600 let pk_columns = self.table_primary_keys.get(table);
602
603 let mut pk_values = Vec::new();
605
606 for (idx, column) in columns.iter().enumerate() {
607 let is_pk = pk_columns
609 .map(|pks| pks.contains(&column.name))
610 .unwrap_or(false);
611
612 let element_value = match column.type_oid {
614 16 => {
615 if let Ok(Some(val)) = row.try_get::<_, Option<bool>>(idx) {
617 drasi_core::models::ElementValue::Bool(val)
618 } else {
619 drasi_core::models::ElementValue::Null
620 }
621 }
622 21 | 23 | 20 => {
623 if let Ok(Some(val)) = row.try_get::<_, Option<i64>>(idx) {
625 drasi_core::models::ElementValue::Integer(val)
626 } else if let Ok(Some(val)) = row.try_get::<_, Option<i32>>(idx) {
627 drasi_core::models::ElementValue::Integer(val as i64)
628 } else if let Ok(Some(val)) = row.try_get::<_, Option<i16>>(idx) {
629 drasi_core::models::ElementValue::Integer(val as i64)
630 } else {
631 drasi_core::models::ElementValue::Null
632 }
633 }
634 700 | 701 => {
635 if let Ok(Some(val)) = row.try_get::<_, Option<f64>>(idx) {
637 drasi_core::models::ElementValue::Float(ordered_float::OrderedFloat(val))
638 } else if let Ok(Some(val)) = row.try_get::<_, Option<f32>>(idx) {
639 drasi_core::models::ElementValue::Float(ordered_float::OrderedFloat(
640 val as f64,
641 ))
642 } else {
643 drasi_core::models::ElementValue::Null
644 }
645 }
646 1700 => {
647 if let Ok(Some(val)) = row.try_get::<_, Option<rust_decimal::Decimal>>(idx) {
649 drasi_core::models::ElementValue::Float(ordered_float::OrderedFloat(
650 val.to_string().parse::<f64>().unwrap_or(0.0),
651 ))
652 } else {
653 drasi_core::models::ElementValue::Null
654 }
655 }
656 25 | 1043 | 19 => {
657 if let Ok(Some(val)) = row.try_get::<_, Option<String>>(idx) {
659 drasi_core::models::ElementValue::String(std::sync::Arc::from(val))
660 } else {
661 drasi_core::models::ElementValue::Null
662 }
663 }
664 1114 | 1184 => {
665 if let Ok(Some(val)) = row.try_get::<_, Option<chrono::NaiveDateTime>>(idx) {
667 drasi_core::models::ElementValue::String(std::sync::Arc::from(
668 val.to_string(),
669 ))
670 } else if let Ok(Some(val)) =
671 row.try_get::<_, Option<chrono::DateTime<chrono::Utc>>>(idx)
672 {
673 drasi_core::models::ElementValue::String(std::sync::Arc::from(
674 val.to_string(),
675 ))
676 } else {
677 drasi_core::models::ElementValue::Null
678 }
679 }
680 _ => {
681 if let Ok(Some(val)) = row.try_get::<_, Option<String>>(idx) {
683 drasi_core::models::ElementValue::String(std::sync::Arc::from(val))
684 } else {
685 drasi_core::models::ElementValue::Null
686 }
687 }
688 };
689
690 if is_pk && !matches!(element_value, drasi_core::models::ElementValue::Null) {
692 let value_str = match &element_value {
693 drasi_core::models::ElementValue::Integer(i) => i.to_string(),
694 drasi_core::models::ElementValue::Float(f) => f.to_string(),
695 drasi_core::models::ElementValue::String(s) => s.to_string(),
696 drasi_core::models::ElementValue::Bool(b) => b.to_string(),
697 _ => format!("{element_value:?}"),
698 };
699 pk_values.push(value_str);
700 }
701
702 properties.insert(&column.name, element_value);
703 }
704
705 let elem_id = if !pk_values.is_empty() {
708 format!("{}:{}", table, pk_values.join("_"))
710 } else if pk_columns.is_none() || pk_columns.map(|pks| pks.is_empty()).unwrap_or(true) {
711 warn!(
713 "No primary key found for table '{table}'. Consider adding 'table_keys' configuration."
714 );
715 format!("{}:{}", table, uuid::Uuid::new_v4())
717 } else {
718 format!("{}:{}", table, uuid::Uuid::new_v4())
720 };
721
722 let metadata = ElementMetadata {
723 reference: ElementReference::new(&self.source_id, &elem_id),
724 labels: Arc::from(vec![Arc::from(table)]),
725 effective_from: chrono::Utc::now().timestamp_millis() as u64,
726 };
727
728 let element = Element::Node {
729 metadata,
730 properties,
731 };
732
733 Ok(SourceChange::Insert { element })
734 }
735
736 async fn send_batch(
738 &self,
739 batch: &mut Vec<SourceChangeEvent>,
740 context: &BootstrapContext,
741 event_tx: &drasi_lib::channels::BootstrapEventSender,
742 ) -> Result<()> {
743 for event in batch.drain(..) {
744 let sequence = context.next_sequence();
746
747 let bootstrap_event = drasi_lib::channels::BootstrapEvent {
748 source_id: event.source_id,
749 change: event.change,
750 timestamp: event.timestamp,
751 sequence,
752 };
753 event_tx.send(bootstrap_event).await.map_err(|e| {
754 anyhow!("Failed to send bootstrap event to channel (channel may be closed): {e}")
755 })?;
756 }
757 Ok(())
758 }
759}
760
/// Name and numeric type code of one table column.
#[derive(Debug)]
struct ColumnInfo {
    /// Column name as reported by the database.
    name: String,
    /// Numeric type code used to pick a decoder for the column's values.
    type_oid: i32,
}
766
#[cfg(test)]
mod tests {
    use drasi_core::models::validate_effective_from;

    /// A millisecond timestamp, as produced for `effective_from`, must be
    /// accepted by the validator.
    #[test]
    fn effective_from_uses_milliseconds() {
        let effective_from = chrono::Utc::now().timestamp_millis() as u64;
        let verdict = validate_effective_from(effective_from);
        assert!(
            verdict.is_ok(),
            "Postgres bootstrapper effective_from ({effective_from}) should be in millisecond range"
        );
    }

    /// A nanosecond timestamp must be rejected by the validator.
    #[test]
    fn effective_from_rejects_nanoseconds_pattern() {
        let bad_effective_from = chrono::Utc::now().timestamp_nanos_opt().unwrap() as u64;
        let verdict = validate_effective_from(bad_effective_from);
        assert!(
            verdict.is_err(),
            "Nanosecond timestamp ({bad_effective_from}) should be rejected"
        );
    }
}