1use anyhow::{anyhow, Result};
18use async_trait::async_trait;
19use drasi_core::models::{
20 Element, ElementMetadata, ElementPropertyMap, ElementReference, SourceChange,
21};
22use log::{debug, error, info, warn};
23use std::collections::HashMap;
24use std::sync::Arc;
25use tokio_postgres::{Client, NoTls, Row, Transaction};
26
27use drasi_lib::bootstrap::{
28 BootstrapContext, BootstrapProvider, BootstrapRequest, BootstrapResult,
29};
30use drasi_lib::channels::SourceChangeEvent;
31
32pub use crate::config::{PostgresBootstrapConfig, SslMode, TableKeyConfig};
33
34pub struct PostgresBootstrapProvider {
39 config: PostgresConfig,
40}
41
42impl PostgresBootstrapProvider {
43 pub fn new(postgres_config: PostgresBootstrapConfig) -> Self {
45 Self {
46 config: PostgresConfig::from_bootstrap_config(postgres_config),
47 }
48 }
49
50 pub fn builder() -> PostgresBootstrapProviderBuilder {
52 PostgresBootstrapProviderBuilder::new()
53 }
54}
55
56pub struct PostgresBootstrapProviderBuilder {
73 host: String,
74 port: u16,
75 database: String,
76 user: String,
77 password: String,
78 tables: Vec<String>,
79 slot_name: String,
80 publication_name: String,
81 ssl_mode: SslMode,
82 table_keys: Vec<TableKeyConfig>,
83}
84
85impl PostgresBootstrapProviderBuilder {
86 pub fn new() -> Self {
88 Self {
89 host: "localhost".to_string(), port: 5432,
91 database: String::new(),
92 user: String::new(),
93 password: String::new(),
94 tables: Vec::new(),
95 slot_name: "drasi_slot".to_string(),
96 publication_name: "drasi_pub".to_string(),
97 ssl_mode: SslMode::Disable,
98 table_keys: Vec::new(),
99 }
100 }
101
102 pub fn with_host(mut self, host: impl Into<String>) -> Self {
104 self.host = host.into();
105 self
106 }
107
108 pub fn with_port(mut self, port: u16) -> Self {
110 self.port = port;
111 self
112 }
113
114 pub fn with_database(mut self, database: impl Into<String>) -> Self {
116 self.database = database.into();
117 self
118 }
119
120 pub fn with_user(mut self, user: impl Into<String>) -> Self {
122 self.user = user.into();
123 self
124 }
125
126 pub fn with_password(mut self, password: impl Into<String>) -> Self {
128 self.password = password.into();
129 self
130 }
131
132 pub fn with_tables(mut self, tables: Vec<String>) -> Self {
134 self.tables = tables;
135 self
136 }
137
138 pub fn with_table(mut self, table: impl Into<String>) -> Self {
140 self.tables.push(table.into());
141 self
142 }
143
144 pub fn with_slot_name(mut self, slot_name: impl Into<String>) -> Self {
146 self.slot_name = slot_name.into();
147 self
148 }
149
150 pub fn with_publication_name(mut self, publication_name: impl Into<String>) -> Self {
152 self.publication_name = publication_name.into();
153 self
154 }
155
156 pub fn with_ssl_mode(mut self, ssl_mode: SslMode) -> Self {
158 self.ssl_mode = ssl_mode;
159 self
160 }
161
162 pub fn with_table_keys(mut self, table_keys: Vec<TableKeyConfig>) -> Self {
164 self.table_keys = table_keys;
165 self
166 }
167
168 pub fn with_table_key(mut self, table: impl Into<String>, key_columns: Vec<String>) -> Self {
170 self.table_keys.push(TableKeyConfig {
171 table: table.into(),
172 key_columns,
173 });
174 self
175 }
176
177 pub fn build(self) -> PostgresBootstrapProvider {
179 let config = PostgresBootstrapConfig {
180 host: self.host,
181 port: self.port,
182 database: self.database,
183 user: self.user,
184 password: self.password,
185 tables: self.tables,
186 slot_name: self.slot_name,
187 publication_name: self.publication_name,
188 ssl_mode: self.ssl_mode,
189 table_keys: self.table_keys,
190 };
191 PostgresBootstrapProvider::new(config)
192 }
193}
194
195impl Default for PostgresBootstrapProviderBuilder {
196 fn default() -> Self {
197 Self::new()
198 }
199}
200
201#[async_trait]
202impl BootstrapProvider for PostgresBootstrapProvider {
203 async fn bootstrap(
204 &self,
205 request: BootstrapRequest,
206 context: &BootstrapContext,
207 event_tx: drasi_lib::channels::BootstrapEventSender,
208 _settings: Option<&drasi_lib::config::SourceSubscriptionSettings>,
209 ) -> Result<BootstrapResult> {
210 info!(
211 "Starting PostgreSQL bootstrap for query '{}' with {} node labels and {} relation labels",
212 request.query_id,
213 request.node_labels.len(),
214 request.relation_labels.len()
215 );
216
217 let mut handler =
219 PostgresBootstrapHandler::new(self.config.clone(), context.source_id.clone());
220
221 let query_id = request.query_id.clone();
223
224 let count = handler.execute(request, context, event_tx).await?;
226
227 info!("Completed PostgreSQL bootstrap for query {query_id}: sent {count} records");
228
229 Ok(BootstrapResult {
235 event_count: count,
236 last_sequence: None,
237 sequences_aligned: false,
238 source_position: None,
239 })
240 }
241}
242
243#[derive(Debug, Clone)]
245struct PostgresConfig {
246 pub host: String,
247 pub port: u16,
248 pub database: String,
249 pub user: String,
250 pub password: String,
251 #[allow(dead_code)]
252 pub tables: Vec<String>,
253 #[allow(dead_code)]
254 pub slot_name: String,
255 #[allow(dead_code)]
256 pub publication_name: String,
257 #[allow(dead_code)]
258 pub ssl_mode: SslMode,
259 pub table_keys: Vec<TableKeyConfig>,
260}
261
262impl PostgresConfig {
263 fn from_bootstrap_config(postgres_config: PostgresBootstrapConfig) -> Self {
264 PostgresConfig {
265 host: postgres_config.host.clone(),
266 port: postgres_config.port,
267 database: postgres_config.database.clone(),
268 user: postgres_config.user.clone(),
269 password: postgres_config.password.clone(),
270 tables: postgres_config.tables.clone(),
271 slot_name: postgres_config.slot_name.clone(),
272 publication_name: postgres_config.publication_name.clone(),
273 ssl_mode: postgres_config.ssl_mode,
274 table_keys: postgres_config.table_keys.clone(),
275 }
276 }
277}
278
279struct PostgresBootstrapHandler {
281 config: PostgresConfig,
282 source_id: String,
283 table_primary_keys: HashMap<String, Vec<String>>,
285}
286
287impl PostgresBootstrapHandler {
288 fn new(config: PostgresConfig, source_id: String) -> Self {
289 Self {
290 config,
291 source_id,
292 table_primary_keys: HashMap::new(),
293 }
294 }
295
296 async fn execute(
298 &mut self,
299 request: BootstrapRequest,
300 context: &BootstrapContext,
301 event_tx: drasi_lib::channels::BootstrapEventSender,
302 ) -> Result<usize> {
303 info!(
304 "Bootstrap: Connecting to PostgreSQL at {}:{}",
305 self.config.host, self.config.port
306 );
307
308 let mut client = self.connect().await?;
310
311 self.query_primary_keys(&client).await?;
313
314 info!("Bootstrap: Connected, creating snapshot transaction...");
315 let (transaction, lsn) = self.create_snapshot(&mut client).await?;
317
318 info!("Bootstrap snapshot created at LSN: {lsn}");
319
320 let tables = self.resolve_tables(&request, &transaction).await?;
322 info!(
323 "Resolved {} labels to {} tables",
324 request.node_labels.len() + request.relation_labels.len(),
325 tables.len()
326 );
327
328 let mut total_count = 0;
330 for table in &tables {
331 let count = self
332 .bootstrap_table(&transaction, table, context, &event_tx)
333 .await?;
334 info!("Bootstrapped {count} rows from table '{table}'");
335 total_count += count;
336 }
337
338 transaction.commit().await?;
340
341 info!("Bootstrap completed: {total_count} total elements sent");
342 Ok(total_count)
343 }
344
345 async fn connect(&self) -> Result<Client> {
347 let connection_string = format!(
348 "host={} port={} user={} password={} dbname={}",
349 self.config.host,
350 self.config.port,
351 self.config.user,
352 self.config.password,
353 self.config.database
354 );
355
356 let (client, connection) = tokio_postgres::connect(&connection_string, NoTls).await?;
357
358 tokio::spawn(async move {
360 if let Err(e) = connection.await {
361 error!("PostgreSQL connection error: {e}");
362 }
363 });
364
365 Ok(client)
366 }
367
368 async fn create_snapshot<'a>(
370 &self,
371 client: &'a mut Client,
372 ) -> Result<(Transaction<'a>, String)> {
373 let transaction = client
375 .build_transaction()
376 .isolation_level(tokio_postgres::IsolationLevel::RepeatableRead)
377 .start()
378 .await?;
379
380 let row = transaction
382 .query_one("SELECT pg_current_wal_lsn()::text", &[])
383 .await?;
384 let lsn: String = row.get(0);
385
386 Ok((transaction, lsn))
387 }
388
389 async fn resolve_tables(
393 &self,
394 request: &BootstrapRequest,
395 transaction: &Transaction<'_>,
396 ) -> Result<Vec<String>> {
397 let mut tables = Vec::new();
398
399 let all_labels: Vec<String> = request
401 .node_labels
402 .iter()
403 .chain(request.relation_labels.iter())
404 .cloned()
405 .collect();
406
407 for label in all_labels {
408 if self.table_exists(transaction, &label).await? {
409 tables.push(label);
410 } else {
411 warn!("Table '{label}' does not exist, skipping");
412 }
413 }
414
415 Ok(tables)
416 }
417
418 async fn table_exists(&self, transaction: &Transaction<'_>, table_name: &str) -> Result<bool> {
420 let row = transaction
421 .query_one(
422 "SELECT EXISTS (
423 SELECT 1 FROM information_schema.tables
424 WHERE table_schema = 'public'
425 AND table_name = $1
426 )",
427 &[&table_name],
428 )
429 .await?;
430
431 Ok(row.get(0))
432 }
433
434 async fn bootstrap_table(
436 &self,
437 transaction: &Transaction<'_>,
438 table: &str,
439 context: &BootstrapContext,
440 event_tx: &drasi_lib::channels::BootstrapEventSender,
441 ) -> Result<usize> {
442 debug!("Starting bootstrap of table '{table}'");
443
444 let columns = self.get_table_columns(transaction, table).await?;
446
447 let query = format!("SELECT * FROM \"{}\"", table.replace('"', "\"\""));
449 let rows = transaction.query(&query, &[]).await?;
450
451 let mut count = 0;
452 let mut batch = Vec::new();
453 let batch_size = 1000;
454
455 for row in rows {
456 let source_change = self.row_to_source_change(&row, table, &columns).await?;
457
458 batch.push(SourceChangeEvent {
459 source_id: self.source_id.clone(),
460 change: source_change,
461 timestamp: chrono::Utc::now(),
462 sequence: None,
463 });
464
465 if batch.len() >= batch_size {
466 self.send_batch(&mut batch, context, event_tx).await?;
467 count += batch_size;
468 }
469 }
470
471 if !batch.is_empty() {
473 count += batch.len();
474 self.send_batch(&mut batch, context, event_tx).await?;
475 }
476
477 Ok(count)
478 }
479
480 async fn get_table_columns(
482 &self,
483 transaction: &Transaction<'_>,
484 table_name: &str,
485 ) -> Result<Vec<ColumnInfo>> {
486 let rows = transaction
487 .query(
488 "SELECT column_name,
489 CASE
490 WHEN data_type = 'character varying' THEN 1043
491 WHEN data_type = 'integer' THEN 23
492 WHEN data_type = 'bigint' THEN 20
493 WHEN data_type = 'smallint' THEN 21
494 WHEN data_type = 'text' THEN 25
495 WHEN data_type = 'boolean' THEN 16
496 WHEN data_type = 'numeric' THEN 1700
497 WHEN data_type = 'real' THEN 700
498 WHEN data_type = 'double precision' THEN 701
499 WHEN data_type = 'timestamp without time zone' THEN 1114
500 WHEN data_type = 'timestamp with time zone' THEN 1184
501 WHEN data_type = 'date' THEN 1082
502 WHEN data_type = 'uuid' THEN 2950
503 WHEN data_type = 'json' THEN 114
504 WHEN data_type = 'jsonb' THEN 3802
505 ELSE 25 -- Default to text
506 END as type_oid
507 FROM information_schema.columns
508 WHERE table_schema = 'public' AND table_name = $1
509 ORDER BY ordinal_position",
510 &[&table_name],
511 )
512 .await?;
513
514 let mut columns = Vec::new();
515 for row in rows {
516 columns.push(ColumnInfo {
517 name: row.get(0),
518 type_oid: row.get::<_, i32>(1),
519 });
520 }
521
522 Ok(columns)
523 }
524
525 async fn query_primary_keys(&mut self, client: &Client) -> Result<()> {
527 info!("Querying primary key information from PostgreSQL system catalogs");
528
529 let query = r#"
530 SELECT
531 n.nspname as schema_name,
532 c.relname as table_name,
533 a.attname as column_name
534 FROM pg_constraint con
535 JOIN pg_class c ON con.conrelid = c.oid
536 JOIN pg_namespace n ON c.relnamespace = n.oid
537 JOIN pg_attribute a ON a.attrelid = c.oid
538 WHERE con.contype = 'p' -- Primary key constraint
539 AND a.attnum = ANY(con.conkey)
540 AND n.nspname NOT IN ('pg_catalog', 'information_schema')
541 ORDER BY n.nspname, c.relname, array_position(con.conkey, a.attnum)
542 "#;
543
544 let rows = client.query(query, &[]).await?;
545
546 let mut primary_keys: HashMap<String, Vec<String>> = HashMap::new();
547
548 for row in rows {
549 let schema: &str = row.get(0);
550 let table: &str = row.get(1);
551 let column: &str = row.get(2);
552
553 let table_key = if schema == "public" {
555 table.to_string()
556 } else {
557 format!("{schema}.{table}")
558 };
559
560 primary_keys
561 .entry(table_key.clone())
562 .or_default()
563 .push(column.to_string());
564
565 debug!("Found primary key column '{column}' for table '{table_key}'");
566 }
567
568 for table_key_config in &self.config.table_keys {
570 let table_name = &table_key_config.table;
571 let key_columns = &table_key_config.key_columns;
572
573 if !key_columns.is_empty() {
574 info!(
575 "Using user-configured key columns for table '{table_name}': {key_columns:?}"
576 );
577 primary_keys.insert(table_name.clone(), key_columns.clone());
578 }
579 }
580
581 self.table_primary_keys = primary_keys.clone();
583
584 info!("Found primary keys for {} tables", primary_keys.len());
585 for (table, keys) in &primary_keys {
586 info!("Table '{table}' primary key columns: {keys:?}");
587 }
588
589 Ok(())
590 }
591
592 async fn row_to_source_change(
594 &self,
595 row: &Row,
596 table: &str,
597 columns: &[ColumnInfo],
598 ) -> Result<SourceChange> {
599 let mut properties = ElementPropertyMap::new();
600
601 let pk_columns = self.table_primary_keys.get(table);
603
604 let mut pk_values = Vec::new();
606
607 for (idx, column) in columns.iter().enumerate() {
608 let is_pk = pk_columns
610 .map(|pks| pks.contains(&column.name))
611 .unwrap_or(false);
612
613 let element_value = match column.type_oid {
615 16 => {
616 if let Ok(Some(val)) = row.try_get::<_, Option<bool>>(idx) {
618 drasi_core::models::ElementValue::Bool(val)
619 } else {
620 drasi_core::models::ElementValue::Null
621 }
622 }
623 21 | 23 | 20 => {
624 if let Ok(Some(val)) = row.try_get::<_, Option<i64>>(idx) {
626 drasi_core::models::ElementValue::Integer(val)
627 } else if let Ok(Some(val)) = row.try_get::<_, Option<i32>>(idx) {
628 drasi_core::models::ElementValue::Integer(val as i64)
629 } else if let Ok(Some(val)) = row.try_get::<_, Option<i16>>(idx) {
630 drasi_core::models::ElementValue::Integer(val as i64)
631 } else {
632 drasi_core::models::ElementValue::Null
633 }
634 }
635 700 | 701 => {
636 if let Ok(Some(val)) = row.try_get::<_, Option<f64>>(idx) {
638 drasi_core::models::ElementValue::Float(ordered_float::OrderedFloat(val))
639 } else if let Ok(Some(val)) = row.try_get::<_, Option<f32>>(idx) {
640 drasi_core::models::ElementValue::Float(ordered_float::OrderedFloat(
641 val as f64,
642 ))
643 } else {
644 drasi_core::models::ElementValue::Null
645 }
646 }
647 1700 => {
648 if let Ok(Some(val)) = row.try_get::<_, Option<rust_decimal::Decimal>>(idx) {
650 drasi_core::models::ElementValue::Float(ordered_float::OrderedFloat(
651 val.to_string().parse::<f64>().unwrap_or(0.0),
652 ))
653 } else {
654 drasi_core::models::ElementValue::Null
655 }
656 }
657 25 | 1043 | 19 => {
658 if let Ok(Some(val)) = row.try_get::<_, Option<String>>(idx) {
660 drasi_core::models::ElementValue::String(std::sync::Arc::from(val))
661 } else {
662 drasi_core::models::ElementValue::Null
663 }
664 }
665 1114 | 1184 => {
666 if let Ok(Some(val)) = row.try_get::<_, Option<chrono::NaiveDateTime>>(idx) {
668 drasi_core::models::ElementValue::String(std::sync::Arc::from(
669 val.to_string(),
670 ))
671 } else if let Ok(Some(val)) =
672 row.try_get::<_, Option<chrono::DateTime<chrono::Utc>>>(idx)
673 {
674 drasi_core::models::ElementValue::String(std::sync::Arc::from(
675 val.to_string(),
676 ))
677 } else {
678 drasi_core::models::ElementValue::Null
679 }
680 }
681 _ => {
682 if let Ok(Some(val)) = row.try_get::<_, Option<String>>(idx) {
684 drasi_core::models::ElementValue::String(std::sync::Arc::from(val))
685 } else {
686 drasi_core::models::ElementValue::Null
687 }
688 }
689 };
690
691 if is_pk && !matches!(element_value, drasi_core::models::ElementValue::Null) {
693 let value_str = match &element_value {
694 drasi_core::models::ElementValue::Integer(i) => i.to_string(),
695 drasi_core::models::ElementValue::Float(f) => f.to_string(),
696 drasi_core::models::ElementValue::String(s) => s.to_string(),
697 drasi_core::models::ElementValue::Bool(b) => b.to_string(),
698 _ => format!("{element_value:?}"),
699 };
700 pk_values.push(value_str);
701 }
702
703 properties.insert(&column.name, element_value);
704 }
705
706 let elem_id = if !pk_values.is_empty() {
709 format!("{}:{}", table, pk_values.join("_"))
711 } else if pk_columns.is_none() || pk_columns.map(|pks| pks.is_empty()).unwrap_or(true) {
712 warn!(
714 "No primary key found for table '{table}'. Consider adding 'table_keys' configuration."
715 );
716 format!("{}:{}", table, uuid::Uuid::new_v4())
718 } else {
719 format!("{}:{}", table, uuid::Uuid::new_v4())
721 };
722
723 let metadata = ElementMetadata {
724 reference: ElementReference::new(&self.source_id, &elem_id),
725 labels: Arc::from(vec![Arc::from(table)]),
726 effective_from: chrono::Utc::now().timestamp_millis() as u64,
727 };
728
729 let element = Element::Node {
730 metadata,
731 properties,
732 };
733
734 Ok(SourceChange::Insert { element })
735 }
736
737 async fn send_batch(
739 &self,
740 batch: &mut Vec<SourceChangeEvent>,
741 context: &BootstrapContext,
742 event_tx: &drasi_lib::channels::BootstrapEventSender,
743 ) -> Result<()> {
744 for event in batch.drain(..) {
745 let sequence = context.next_sequence();
747
748 let bootstrap_event = drasi_lib::channels::BootstrapEvent {
749 source_id: event.source_id,
750 change: event.change,
751 timestamp: event.timestamp,
752 sequence,
753 };
754 event_tx.send(bootstrap_event).await.map_err(|e| {
755 anyhow!("Failed to send bootstrap event to channel (channel may be closed): {e}")
756 })?;
757 }
758 Ok(())
759 }
760}
761
/// Name and type OID of one table column, as produced by
/// `get_table_columns`.
#[derive(Debug)]
struct ColumnInfo {
    /// Column name as listed in `information_schema.columns`.
    name: String,
    /// Type OID derived from the column's `data_type`; data types without an
    /// explicit mapping are reported as 25 (text).
    type_oid: i32,
}
767
#[cfg(test)]
mod tests {
    use drasi_core::models::validate_effective_from;

    /// A timestamp built the way the bootstrapper builds `effective_from`
    /// (milliseconds since the epoch) must pass validation.
    #[test]
    fn effective_from_uses_milliseconds() {
        let effective_from = chrono::Utc::now().timestamp_millis() as u64;
        let outcome = validate_effective_from(effective_from);
        assert!(
            outcome.is_ok(),
            "Postgres bootstrapper effective_from ({effective_from}) should be in millisecond range"
        );
    }

    /// A nanosecond-resolution timestamp must be rejected by validation.
    #[test]
    fn effective_from_rejects_nanoseconds_pattern() {
        let now = chrono::Utc::now();
        let bad_effective_from = now.timestamp_nanos_opt().unwrap() as u64;
        let outcome = validate_effective_from(bad_effective_from);
        assert!(
            outcome.is_err(),
            "Nanosecond timestamp ({bad_effective_from}) should be rejected"
        );
    }
}