1use anyhow::{anyhow, Result};
18use async_trait::async_trait;
19use drasi_core::models::{
20 Element, ElementMetadata, ElementPropertyMap, ElementReference, SourceChange,
21};
22use log::{debug, error, info, warn};
23use std::collections::HashMap;
24use std::sync::Arc;
25use tokio_postgres::{Client, NoTls, Row, Transaction};
26
27use drasi_lib::bootstrap::{BootstrapContext, BootstrapProvider, BootstrapRequest};
28use drasi_lib::channels::SourceChangeEvent;
29
30pub use crate::config::{PostgresBootstrapConfig, SslMode, TableKeyConfig};
31
/// Bootstrap provider that loads the current contents of PostgreSQL tables for a query.
///
/// Create it with [`PostgresBootstrapProvider::new`] from a [`PostgresBootstrapConfig`], or
/// fluently via [`PostgresBootstrapProvider::builder`].
pub struct PostgresBootstrapProvider {
    /// Internal copy of the connection settings and table-key overrides.
    config: PostgresConfig,
}
39
40impl PostgresBootstrapProvider {
41 pub fn new(postgres_config: PostgresBootstrapConfig) -> Self {
43 Self {
44 config: PostgresConfig::from_bootstrap_config(postgres_config),
45 }
46 }
47
48 pub fn builder() -> PostgresBootstrapProviderBuilder {
50 PostgresBootstrapProviderBuilder::new()
51 }
52}
53
/// Fluent builder for [`PostgresBootstrapProvider`].
///
/// `PostgresBootstrapProviderBuilder::new` starts from these values:
/// - host `localhost`, port `5432`
/// - slot `drasi_slot`, publication `drasi_pub`
/// - `SslMode::Disable`
/// - empty database, user, password, table list and key overrides
pub struct PostgresBootstrapProviderBuilder {
    host: String,
    port: u16,
    database: String,
    user: String,
    password: String,
    tables: Vec<String>,
    slot_name: String,
    publication_name: String,
    ssl_mode: SslMode,
    // Per-table key-column overrides, copied into the built configuration.
    table_keys: Vec<TableKeyConfig>,
}
82
83impl PostgresBootstrapProviderBuilder {
84 pub fn new() -> Self {
86 Self {
87 host: "localhost".to_string(), port: 5432,
89 database: String::new(),
90 user: String::new(),
91 password: String::new(),
92 tables: Vec::new(),
93 slot_name: "drasi_slot".to_string(),
94 publication_name: "drasi_pub".to_string(),
95 ssl_mode: SslMode::Disable,
96 table_keys: Vec::new(),
97 }
98 }
99
100 pub fn with_host(mut self, host: impl Into<String>) -> Self {
102 self.host = host.into();
103 self
104 }
105
106 pub fn with_port(mut self, port: u16) -> Self {
108 self.port = port;
109 self
110 }
111
112 pub fn with_database(mut self, database: impl Into<String>) -> Self {
114 self.database = database.into();
115 self
116 }
117
118 pub fn with_user(mut self, user: impl Into<String>) -> Self {
120 self.user = user.into();
121 self
122 }
123
124 pub fn with_password(mut self, password: impl Into<String>) -> Self {
126 self.password = password.into();
127 self
128 }
129
130 pub fn with_tables(mut self, tables: Vec<String>) -> Self {
132 self.tables = tables;
133 self
134 }
135
136 pub fn with_table(mut self, table: impl Into<String>) -> Self {
138 self.tables.push(table.into());
139 self
140 }
141
142 pub fn with_slot_name(mut self, slot_name: impl Into<String>) -> Self {
144 self.slot_name = slot_name.into();
145 self
146 }
147
148 pub fn with_publication_name(mut self, publication_name: impl Into<String>) -> Self {
150 self.publication_name = publication_name.into();
151 self
152 }
153
154 pub fn with_ssl_mode(mut self, ssl_mode: SslMode) -> Self {
156 self.ssl_mode = ssl_mode;
157 self
158 }
159
160 pub fn with_table_keys(mut self, table_keys: Vec<TableKeyConfig>) -> Self {
162 self.table_keys = table_keys;
163 self
164 }
165
166 pub fn with_table_key(mut self, table: impl Into<String>, key_columns: Vec<String>) -> Self {
168 self.table_keys.push(TableKeyConfig {
169 table: table.into(),
170 key_columns,
171 });
172 self
173 }
174
175 pub fn build(self) -> PostgresBootstrapProvider {
177 let config = PostgresBootstrapConfig {
178 host: self.host,
179 port: self.port,
180 database: self.database,
181 user: self.user,
182 password: self.password,
183 tables: self.tables,
184 slot_name: self.slot_name,
185 publication_name: self.publication_name,
186 ssl_mode: self.ssl_mode,
187 table_keys: self.table_keys,
188 };
189 PostgresBootstrapProvider::new(config)
190 }
191}
192
193impl Default for PostgresBootstrapProviderBuilder {
194 fn default() -> Self {
195 Self::new()
196 }
197}
198
199#[async_trait]
200impl BootstrapProvider for PostgresBootstrapProvider {
201 async fn bootstrap(
202 &self,
203 request: BootstrapRequest,
204 context: &BootstrapContext,
205 event_tx: drasi_lib::channels::BootstrapEventSender,
206 _settings: Option<&drasi_lib::config::SourceSubscriptionSettings>,
207 ) -> Result<usize> {
208 info!(
209 "Starting PostgreSQL bootstrap for query '{}' with {} node labels and {} relation labels",
210 request.query_id,
211 request.node_labels.len(),
212 request.relation_labels.len()
213 );
214
215 let mut handler =
217 PostgresBootstrapHandler::new(self.config.clone(), context.source_id.clone());
218
219 let query_id = request.query_id.clone();
221
222 let count = handler.execute(request, context, event_tx).await?;
224
225 info!("Completed PostgreSQL bootstrap for query {query_id}: sent {count} records");
226
227 Ok(count)
228 }
229}
230
231#[derive(Debug, Clone)]
233struct PostgresConfig {
234 pub host: String,
235 pub port: u16,
236 pub database: String,
237 pub user: String,
238 pub password: String,
239 #[allow(dead_code)]
240 pub tables: Vec<String>,
241 #[allow(dead_code)]
242 pub slot_name: String,
243 #[allow(dead_code)]
244 pub publication_name: String,
245 #[allow(dead_code)]
246 pub ssl_mode: SslMode,
247 pub table_keys: Vec<TableKeyConfig>,
248}
249
250impl PostgresConfig {
251 fn from_bootstrap_config(postgres_config: PostgresBootstrapConfig) -> Self {
252 PostgresConfig {
253 host: postgres_config.host.clone(),
254 port: postgres_config.port,
255 database: postgres_config.database.clone(),
256 user: postgres_config.user.clone(),
257 password: postgres_config.password.clone(),
258 tables: postgres_config.tables.clone(),
259 slot_name: postgres_config.slot_name.clone(),
260 publication_name: postgres_config.publication_name.clone(),
261 ssl_mode: postgres_config.ssl_mode,
262 table_keys: postgres_config.table_keys.clone(),
263 }
264 }
265}
266
/// State for one bootstrap run.
///
/// A run connects, opens a snapshot transaction, and sends table rows as insert events.
struct PostgresBootstrapHandler {
    /// Connection settings and user-configured key overrides.
    config: PostgresConfig,
    /// Source id stamped on every emitted event and element reference.
    source_id: String,
    /// Key columns per table, filled by `query_primary_keys`.
    /// Tables outside the `public` schema are keyed as `schema.table`.
    table_primary_keys: HashMap<String, Vec<String>>,
}
274
275impl PostgresBootstrapHandler {
276 fn new(config: PostgresConfig, source_id: String) -> Self {
277 Self {
278 config,
279 source_id,
280 table_primary_keys: HashMap::new(),
281 }
282 }
283
284 async fn execute(
286 &mut self,
287 request: BootstrapRequest,
288 context: &BootstrapContext,
289 event_tx: drasi_lib::channels::BootstrapEventSender,
290 ) -> Result<usize> {
291 info!(
292 "Bootstrap: Connecting to PostgreSQL at {}:{}",
293 self.config.host, self.config.port
294 );
295
296 let mut client = self.connect().await?;
298
299 self.query_primary_keys(&client).await?;
301
302 info!("Bootstrap: Connected, creating snapshot transaction...");
303 let (transaction, lsn) = self.create_snapshot(&mut client).await?;
305
306 info!("Bootstrap snapshot created at LSN: {lsn}");
307
308 let tables = self.resolve_tables(&request, &transaction).await?;
310 info!(
311 "Resolved {} labels to {} tables",
312 request.node_labels.len() + request.relation_labels.len(),
313 tables.len()
314 );
315
316 let mut total_count = 0;
318 for table in &tables {
319 let count = self
320 .bootstrap_table(&transaction, table, context, &event_tx)
321 .await?;
322 info!("Bootstrapped {count} rows from table '{table}'");
323 total_count += count;
324 }
325
326 transaction.commit().await?;
328
329 info!("Bootstrap completed: {total_count} total elements sent");
330 Ok(total_count)
331 }
332
333 async fn connect(&self) -> Result<Client> {
335 let connection_string = format!(
336 "host={} port={} user={} password={} dbname={}",
337 self.config.host,
338 self.config.port,
339 self.config.user,
340 self.config.password,
341 self.config.database
342 );
343
344 let (client, connection) = tokio_postgres::connect(&connection_string, NoTls).await?;
345
346 tokio::spawn(async move {
348 if let Err(e) = connection.await {
349 error!("PostgreSQL connection error: {e}");
350 }
351 });
352
353 Ok(client)
354 }
355
356 async fn create_snapshot<'a>(
358 &self,
359 client: &'a mut Client,
360 ) -> Result<(Transaction<'a>, String)> {
361 let transaction = client
363 .build_transaction()
364 .isolation_level(tokio_postgres::IsolationLevel::RepeatableRead)
365 .start()
366 .await?;
367
368 let row = transaction
370 .query_one("SELECT pg_current_wal_lsn()::text", &[])
371 .await?;
372 let lsn: String = row.get(0);
373
374 Ok((transaction, lsn))
375 }
376
377 async fn resolve_tables(
381 &self,
382 request: &BootstrapRequest,
383 transaction: &Transaction<'_>,
384 ) -> Result<Vec<String>> {
385 let mut tables = Vec::new();
386
387 let all_labels: Vec<String> = request
389 .node_labels
390 .iter()
391 .chain(request.relation_labels.iter())
392 .cloned()
393 .collect();
394
395 for label in all_labels {
396 if self.table_exists(transaction, &label).await? {
397 tables.push(label);
398 } else {
399 warn!("Table '{label}' does not exist, skipping");
400 }
401 }
402
403 Ok(tables)
404 }
405
406 async fn table_exists(&self, transaction: &Transaction<'_>, table_name: &str) -> Result<bool> {
408 let row = transaction
409 .query_one(
410 "SELECT EXISTS (
411 SELECT 1 FROM information_schema.tables
412 WHERE table_schema = 'public'
413 AND table_name = $1
414 )",
415 &[&table_name],
416 )
417 .await?;
418
419 Ok(row.get(0))
420 }
421
422 async fn bootstrap_table(
424 &self,
425 transaction: &Transaction<'_>,
426 table: &str,
427 context: &BootstrapContext,
428 event_tx: &drasi_lib::channels::BootstrapEventSender,
429 ) -> Result<usize> {
430 debug!("Starting bootstrap of table '{table}'");
431
432 let columns = self.get_table_columns(transaction, table).await?;
434
435 let query = format!("SELECT * FROM \"{}\"", table.replace('"', "\"\""));
437 let rows = transaction.query(&query, &[]).await?;
438
439 let mut count = 0;
440 let mut batch = Vec::new();
441 let batch_size = 1000;
442
443 for row in rows {
444 let source_change = self.row_to_source_change(&row, table, &columns).await?;
445
446 batch.push(SourceChangeEvent {
447 source_id: self.source_id.clone(),
448 change: source_change,
449 timestamp: chrono::Utc::now(),
450 });
451
452 if batch.len() >= batch_size {
453 self.send_batch(&mut batch, context, event_tx).await?;
454 count += batch_size;
455 }
456 }
457
458 if !batch.is_empty() {
460 count += batch.len();
461 self.send_batch(&mut batch, context, event_tx).await?;
462 }
463
464 Ok(count)
465 }
466
467 async fn get_table_columns(
469 &self,
470 transaction: &Transaction<'_>,
471 table_name: &str,
472 ) -> Result<Vec<ColumnInfo>> {
473 let rows = transaction
474 .query(
475 "SELECT column_name,
476 CASE
477 WHEN data_type = 'character varying' THEN 1043
478 WHEN data_type = 'integer' THEN 23
479 WHEN data_type = 'bigint' THEN 20
480 WHEN data_type = 'smallint' THEN 21
481 WHEN data_type = 'text' THEN 25
482 WHEN data_type = 'boolean' THEN 16
483 WHEN data_type = 'numeric' THEN 1700
484 WHEN data_type = 'real' THEN 700
485 WHEN data_type = 'double precision' THEN 701
486 WHEN data_type = 'timestamp without time zone' THEN 1114
487 WHEN data_type = 'timestamp with time zone' THEN 1184
488 WHEN data_type = 'date' THEN 1082
489 WHEN data_type = 'uuid' THEN 2950
490 WHEN data_type = 'json' THEN 114
491 WHEN data_type = 'jsonb' THEN 3802
492 ELSE 25 -- Default to text
493 END as type_oid
494 FROM information_schema.columns
495 WHERE table_schema = 'public' AND table_name = $1
496 ORDER BY ordinal_position",
497 &[&table_name],
498 )
499 .await?;
500
501 let mut columns = Vec::new();
502 for row in rows {
503 columns.push(ColumnInfo {
504 name: row.get(0),
505 type_oid: row.get::<_, i32>(1),
506 });
507 }
508
509 Ok(columns)
510 }
511
512 async fn query_primary_keys(&mut self, client: &Client) -> Result<()> {
514 info!("Querying primary key information from PostgreSQL system catalogs");
515
516 let query = r#"
517 SELECT
518 n.nspname as schema_name,
519 c.relname as table_name,
520 a.attname as column_name
521 FROM pg_constraint con
522 JOIN pg_class c ON con.conrelid = c.oid
523 JOIN pg_namespace n ON c.relnamespace = n.oid
524 JOIN pg_attribute a ON a.attrelid = c.oid
525 WHERE con.contype = 'p' -- Primary key constraint
526 AND a.attnum = ANY(con.conkey)
527 AND n.nspname NOT IN ('pg_catalog', 'information_schema')
528 ORDER BY n.nspname, c.relname, array_position(con.conkey, a.attnum)
529 "#;
530
531 let rows = client.query(query, &[]).await?;
532
533 let mut primary_keys: HashMap<String, Vec<String>> = HashMap::new();
534
535 for row in rows {
536 let schema: &str = row.get(0);
537 let table: &str = row.get(1);
538 let column: &str = row.get(2);
539
540 let table_key = if schema == "public" {
542 table.to_string()
543 } else {
544 format!("{schema}.{table}")
545 };
546
547 primary_keys
548 .entry(table_key.clone())
549 .or_default()
550 .push(column.to_string());
551
552 debug!("Found primary key column '{column}' for table '{table_key}'");
553 }
554
555 for table_key_config in &self.config.table_keys {
557 let table_name = &table_key_config.table;
558 let key_columns = &table_key_config.key_columns;
559
560 if !key_columns.is_empty() {
561 info!(
562 "Using user-configured key columns for table '{table_name}': {key_columns:?}"
563 );
564 primary_keys.insert(table_name.clone(), key_columns.clone());
565 }
566 }
567
568 self.table_primary_keys = primary_keys.clone();
570
571 info!("Found primary keys for {} tables", primary_keys.len());
572 for (table, keys) in &primary_keys {
573 info!("Table '{table}' primary key columns: {keys:?}");
574 }
575
576 Ok(())
577 }
578
579 async fn row_to_source_change(
581 &self,
582 row: &Row,
583 table: &str,
584 columns: &[ColumnInfo],
585 ) -> Result<SourceChange> {
586 let mut properties = ElementPropertyMap::new();
587
588 let pk_columns = self.table_primary_keys.get(table);
590
591 let mut pk_values = Vec::new();
593
594 for (idx, column) in columns.iter().enumerate() {
595 let is_pk = pk_columns
597 .map(|pks| pks.contains(&column.name))
598 .unwrap_or(false);
599
600 let element_value = match column.type_oid {
602 16 => {
603 if let Ok(Some(val)) = row.try_get::<_, Option<bool>>(idx) {
605 drasi_core::models::ElementValue::Bool(val)
606 } else {
607 drasi_core::models::ElementValue::Null
608 }
609 }
610 21 | 23 | 20 => {
611 if let Ok(Some(val)) = row.try_get::<_, Option<i64>>(idx) {
613 drasi_core::models::ElementValue::Integer(val)
614 } else if let Ok(Some(val)) = row.try_get::<_, Option<i32>>(idx) {
615 drasi_core::models::ElementValue::Integer(val as i64)
616 } else if let Ok(Some(val)) = row.try_get::<_, Option<i16>>(idx) {
617 drasi_core::models::ElementValue::Integer(val as i64)
618 } else {
619 drasi_core::models::ElementValue::Null
620 }
621 }
622 700 | 701 => {
623 if let Ok(Some(val)) = row.try_get::<_, Option<f64>>(idx) {
625 drasi_core::models::ElementValue::Float(ordered_float::OrderedFloat(val))
626 } else if let Ok(Some(val)) = row.try_get::<_, Option<f32>>(idx) {
627 drasi_core::models::ElementValue::Float(ordered_float::OrderedFloat(
628 val as f64,
629 ))
630 } else {
631 drasi_core::models::ElementValue::Null
632 }
633 }
634 1700 => {
635 if let Ok(Some(val)) = row.try_get::<_, Option<rust_decimal::Decimal>>(idx) {
637 drasi_core::models::ElementValue::Float(ordered_float::OrderedFloat(
638 val.to_string().parse::<f64>().unwrap_or(0.0),
639 ))
640 } else {
641 drasi_core::models::ElementValue::Null
642 }
643 }
644 25 | 1043 | 19 => {
645 if let Ok(Some(val)) = row.try_get::<_, Option<String>>(idx) {
647 drasi_core::models::ElementValue::String(std::sync::Arc::from(val))
648 } else {
649 drasi_core::models::ElementValue::Null
650 }
651 }
652 1114 | 1184 => {
653 if let Ok(Some(val)) = row.try_get::<_, Option<chrono::NaiveDateTime>>(idx) {
655 drasi_core::models::ElementValue::String(std::sync::Arc::from(
656 val.to_string(),
657 ))
658 } else if let Ok(Some(val)) =
659 row.try_get::<_, Option<chrono::DateTime<chrono::Utc>>>(idx)
660 {
661 drasi_core::models::ElementValue::String(std::sync::Arc::from(
662 val.to_string(),
663 ))
664 } else {
665 drasi_core::models::ElementValue::Null
666 }
667 }
668 _ => {
669 if let Ok(Some(val)) = row.try_get::<_, Option<String>>(idx) {
671 drasi_core::models::ElementValue::String(std::sync::Arc::from(val))
672 } else {
673 drasi_core::models::ElementValue::Null
674 }
675 }
676 };
677
678 if is_pk && !matches!(element_value, drasi_core::models::ElementValue::Null) {
680 let value_str = match &element_value {
681 drasi_core::models::ElementValue::Integer(i) => i.to_string(),
682 drasi_core::models::ElementValue::Float(f) => f.to_string(),
683 drasi_core::models::ElementValue::String(s) => s.to_string(),
684 drasi_core::models::ElementValue::Bool(b) => b.to_string(),
685 _ => format!("{element_value:?}"),
686 };
687 pk_values.push(value_str);
688 }
689
690 properties.insert(&column.name, element_value);
691 }
692
693 let elem_id = if !pk_values.is_empty() {
696 format!("{}:{}", table, pk_values.join("_"))
698 } else if pk_columns.is_none() || pk_columns.map(|pks| pks.is_empty()).unwrap_or(true) {
699 warn!(
701 "No primary key found for table '{table}'. Consider adding 'table_keys' configuration."
702 );
703 format!("{}:{}", table, uuid::Uuid::new_v4())
705 } else {
706 format!("{}:{}", table, uuid::Uuid::new_v4())
708 };
709
710 let metadata = ElementMetadata {
711 reference: ElementReference::new(&self.source_id, &elem_id),
712 labels: Arc::from(vec![Arc::from(table)]),
713 effective_from: chrono::Utc::now().timestamp_nanos_opt().unwrap_or(0) as u64,
714 };
715
716 let element = Element::Node {
717 metadata,
718 properties,
719 };
720
721 Ok(SourceChange::Insert { element })
722 }
723
724 async fn send_batch(
726 &self,
727 batch: &mut Vec<SourceChangeEvent>,
728 context: &BootstrapContext,
729 event_tx: &drasi_lib::channels::BootstrapEventSender,
730 ) -> Result<()> {
731 for event in batch.drain(..) {
732 let sequence = context.next_sequence();
734
735 let bootstrap_event = drasi_lib::channels::BootstrapEvent {
736 source_id: event.source_id,
737 change: event.change,
738 timestamp: event.timestamp,
739 sequence,
740 };
741 event_tx.send(bootstrap_event).await.map_err(|e| {
742 anyhow!("Failed to send bootstrap event to channel (channel may be closed): {e}")
743 })?;
744 }
745 Ok(())
746 }
747}
748
/// Name and type code of one table column, as read from `information_schema.columns`.
#[derive(Debug)]
struct ColumnInfo {
    /// Column name as reported by `information_schema.columns`.
    name: String,
    /// Type code derived from `data_type` by a fixed SQL `CASE` mapping to PostgreSQL type
    /// OIDs. Unmapped types receive a fallback code.
    type_oid: i32,
}