1use anyhow::{anyhow, Result};
18use async_trait::async_trait;
19use drasi_core::models::{
20 Element, ElementMetadata, ElementPropertyMap, ElementReference, SourceChange,
21};
22use log::{debug, error, info, warn};
23use std::collections::HashMap;
24use std::sync::Arc;
25use tokio_postgres::{Client, NoTls, Row, Transaction};
26
27use drasi_lib::bootstrap::{BootstrapContext, BootstrapProvider, BootstrapRequest};
28use drasi_lib::channels::SourceChangeEvent;
29
30pub use crate::config::{PostgresBootstrapConfig, SslMode, TableKeyConfig};
31
/// Bootstrap provider that reads initial data from a PostgreSQL database.
///
/// Build it with [`PostgresBootstrapProvider::new`] from a
/// [`PostgresBootstrapConfig`], or fluently through
/// [`PostgresBootstrapProvider::builder`].
pub struct PostgresBootstrapProvider {
    /// Internal copy of the settings, produced by
    /// `PostgresConfig::from_bootstrap_config`.
    config: PostgresConfig,
}
39
40impl PostgresBootstrapProvider {
41 pub fn new(postgres_config: PostgresBootstrapConfig) -> Self {
43 Self {
44 config: PostgresConfig::from_bootstrap_config(postgres_config),
45 }
46 }
47
48 pub fn builder() -> PostgresBootstrapProviderBuilder {
50 PostgresBootstrapProviderBuilder::new()
51 }
52}
53
/// Fluent builder for [`PostgresBootstrapProvider`].
///
/// Every field mirrors the field of the same name in
/// [`PostgresBootstrapConfig`]; `build` copies them over one-to-one.
pub struct PostgresBootstrapProviderBuilder {
    /// Server host name; `new` initialises it to `"localhost"`.
    host: String,
    /// Server TCP port; `new` initialises it to `5432`.
    port: u16,
    /// Database name; empty until set.
    database: String,
    /// Login user; empty until set.
    user: String,
    /// Login password; empty until set.
    password: String,
    /// Table names, set wholesale by `with_tables` or appended by `with_table`.
    tables: Vec<String>,
    /// Replication slot name; `new` initialises it to `"drasi_slot"`.
    slot_name: String,
    /// Publication name; `new` initialises it to `"drasi_pub"`.
    publication_name: String,
    /// Requested SSL mode; `new` initialises it to `SslMode::Disable`.
    ssl_mode: SslMode,
    /// User-supplied key columns per table.
    table_keys: Vec<TableKeyConfig>,
}
82
83impl PostgresBootstrapProviderBuilder {
84 pub fn new() -> Self {
86 Self {
87 host: "localhost".to_string(), port: 5432,
89 database: String::new(),
90 user: String::new(),
91 password: String::new(),
92 tables: Vec::new(),
93 slot_name: "drasi_slot".to_string(),
94 publication_name: "drasi_pub".to_string(),
95 ssl_mode: SslMode::Disable,
96 table_keys: Vec::new(),
97 }
98 }
99
100 pub fn with_host(mut self, host: impl Into<String>) -> Self {
102 self.host = host.into();
103 self
104 }
105
106 pub fn with_port(mut self, port: u16) -> Self {
108 self.port = port;
109 self
110 }
111
112 pub fn with_database(mut self, database: impl Into<String>) -> Self {
114 self.database = database.into();
115 self
116 }
117
118 pub fn with_user(mut self, user: impl Into<String>) -> Self {
120 self.user = user.into();
121 self
122 }
123
124 pub fn with_password(mut self, password: impl Into<String>) -> Self {
126 self.password = password.into();
127 self
128 }
129
130 pub fn with_tables(mut self, tables: Vec<String>) -> Self {
132 self.tables = tables;
133 self
134 }
135
136 pub fn with_table(mut self, table: impl Into<String>) -> Self {
138 self.tables.push(table.into());
139 self
140 }
141
142 pub fn with_slot_name(mut self, slot_name: impl Into<String>) -> Self {
144 self.slot_name = slot_name.into();
145 self
146 }
147
148 pub fn with_publication_name(mut self, publication_name: impl Into<String>) -> Self {
150 self.publication_name = publication_name.into();
151 self
152 }
153
154 pub fn with_ssl_mode(mut self, ssl_mode: SslMode) -> Self {
156 self.ssl_mode = ssl_mode;
157 self
158 }
159
160 pub fn with_table_keys(mut self, table_keys: Vec<TableKeyConfig>) -> Self {
162 self.table_keys = table_keys;
163 self
164 }
165
166 pub fn with_table_key(mut self, table: impl Into<String>, key_columns: Vec<String>) -> Self {
168 self.table_keys.push(TableKeyConfig {
169 table: table.into(),
170 key_columns,
171 });
172 self
173 }
174
175 pub fn build(self) -> PostgresBootstrapProvider {
177 let config = PostgresBootstrapConfig {
178 host: self.host,
179 port: self.port,
180 database: self.database,
181 user: self.user,
182 password: self.password,
183 tables: self.tables,
184 slot_name: self.slot_name,
185 publication_name: self.publication_name,
186 ssl_mode: self.ssl_mode,
187 table_keys: self.table_keys,
188 };
189 PostgresBootstrapProvider::new(config)
190 }
191}
192
193impl Default for PostgresBootstrapProviderBuilder {
194 fn default() -> Self {
195 Self::new()
196 }
197}
198
199#[async_trait]
200impl BootstrapProvider for PostgresBootstrapProvider {
201 async fn bootstrap(
202 &self,
203 request: BootstrapRequest,
204 context: &BootstrapContext,
205 event_tx: drasi_lib::channels::BootstrapEventSender,
206 _settings: Option<&drasi_lib::config::SourceSubscriptionSettings>,
207 ) -> Result<usize> {
208 info!(
209 "Starting PostgreSQL bootstrap for query '{}' with {} node labels and {} relation labels",
210 request.query_id,
211 request.node_labels.len(),
212 request.relation_labels.len()
213 );
214
215 let mut handler =
217 PostgresBootstrapHandler::new(self.config.clone(), context.source_id.clone());
218
219 let query_id = request.query_id.clone();
221
222 let count = handler.execute(request, context, event_tx).await?;
224
225 info!("Completed PostgreSQL bootstrap for query {query_id}: sent {count} records");
226
227 Ok(count)
228 }
229}
230
231#[derive(Debug, Clone)]
233struct PostgresConfig {
234 pub host: String,
235 pub port: u16,
236 pub database: String,
237 pub user: String,
238 pub password: String,
239 #[allow(dead_code)]
240 pub tables: Vec<String>,
241 #[allow(dead_code)]
242 pub slot_name: String,
243 #[allow(dead_code)]
244 pub publication_name: String,
245 #[allow(dead_code)]
246 pub ssl_mode: SslMode,
247 pub table_keys: Vec<TableKeyConfig>,
248}
249
250impl PostgresConfig {
251 fn from_bootstrap_config(postgres_config: PostgresBootstrapConfig) -> Self {
252 PostgresConfig {
253 host: postgres_config.host.clone(),
254 port: postgres_config.port,
255 database: postgres_config.database.clone(),
256 user: postgres_config.user.clone(),
257 password: postgres_config.password.clone(),
258 tables: postgres_config.tables.clone(),
259 slot_name: postgres_config.slot_name.clone(),
260 publication_name: postgres_config.publication_name.clone(),
261 ssl_mode: postgres_config.ssl_mode,
262 table_keys: postgres_config.table_keys.clone(),
263 }
264 }
265}
266
/// Per-bootstrap worker: owns the connection settings and the key-column
/// lookup table used while converting rows into source changes.
struct PostgresBootstrapHandler {
    /// Connection settings and user-configured table keys.
    config: PostgresConfig,
    /// Source id stamped on every emitted event and element reference.
    source_id: String,
    /// Key columns per table, filled by `query_primary_keys`. Tables in the
    /// `public` schema are keyed by bare name, others by `schema.table`.
    table_primary_keys: HashMap<String, Vec<String>>,
}
274
275impl PostgresBootstrapHandler {
276 fn new(config: PostgresConfig, source_id: String) -> Self {
277 Self {
278 config,
279 source_id,
280 table_primary_keys: HashMap::new(),
281 }
282 }
283
284 async fn execute(
286 &mut self,
287 request: BootstrapRequest,
288 context: &BootstrapContext,
289 event_tx: drasi_lib::channels::BootstrapEventSender,
290 ) -> Result<usize> {
291 info!(
292 "Bootstrap: Connecting to PostgreSQL at {}:{}",
293 self.config.host, self.config.port
294 );
295
296 let mut client = self.connect().await?;
298
299 self.query_primary_keys(&client).await?;
301
302 info!("Bootstrap: Connected, creating snapshot transaction...");
303 let (transaction, lsn) = self.create_snapshot(&mut client).await?;
305
306 info!("Bootstrap snapshot created at LSN: {lsn}");
307
308 let tables = self.map_labels_to_tables(&request, &transaction).await?;
310 info!(
311 "Mapped {} labels to {} tables",
312 request.node_labels.len() + request.relation_labels.len(),
313 tables.len()
314 );
315
316 let mut total_count = 0;
318 for (label, table_name) in tables {
319 let count = self
320 .bootstrap_table(&transaction, &label, &table_name, context, &event_tx)
321 .await?;
322 info!("Bootstrapped {count} rows from table '{table_name}' with label '{label}'");
323 total_count += count;
324 }
325
326 transaction.commit().await?;
328
329 info!("Bootstrap completed: {total_count} total elements sent");
330 Ok(total_count)
331 }
332
333 async fn connect(&self) -> Result<Client> {
335 let connection_string = format!(
336 "host={} port={} user={} password={} dbname={}",
337 self.config.host,
338 self.config.port,
339 self.config.user,
340 self.config.password,
341 self.config.database
342 );
343
344 let (client, connection) = tokio_postgres::connect(&connection_string, NoTls).await?;
345
346 tokio::spawn(async move {
348 if let Err(e) = connection.await {
349 error!("PostgreSQL connection error: {e}");
350 }
351 });
352
353 Ok(client)
354 }
355
356 async fn create_snapshot<'a>(
358 &self,
359 client: &'a mut Client,
360 ) -> Result<(Transaction<'a>, String)> {
361 let transaction = client
363 .build_transaction()
364 .isolation_level(tokio_postgres::IsolationLevel::RepeatableRead)
365 .start()
366 .await?;
367
368 let row = transaction
370 .query_one("SELECT pg_current_wal_lsn()::text", &[])
371 .await?;
372 let lsn: String = row.get(0);
373
374 Ok((transaction, lsn))
375 }
376
377 async fn map_labels_to_tables(
379 &self,
380 request: &BootstrapRequest,
381 transaction: &Transaction<'_>,
382 ) -> Result<Vec<(String, String)>> {
383 let mut tables = Vec::new();
384
385 let all_labels: Vec<String> = request
387 .node_labels
388 .iter()
389 .chain(request.relation_labels.iter())
390 .cloned()
391 .collect();
392
393 for label in all_labels {
394 let table_name = label.to_lowercase();
396
397 let exists = self.table_exists(transaction, &table_name).await?;
399
400 if exists {
401 tables.push((label, table_name));
402 } else {
403 warn!("Table '{table_name}' for label '{label}' does not exist, skipping");
404 }
405 }
406
407 Ok(tables)
408 }
409
410 async fn table_exists(&self, transaction: &Transaction<'_>, table_name: &str) -> Result<bool> {
412 let row = transaction
413 .query_one(
414 "SELECT EXISTS (
415 SELECT 1 FROM information_schema.tables
416 WHERE table_schema = 'public'
417 AND table_name = $1
418 )",
419 &[&table_name],
420 )
421 .await?;
422
423 Ok(row.get(0))
424 }
425
426 async fn bootstrap_table(
428 &self,
429 transaction: &Transaction<'_>,
430 label: &str,
431 table_name: &str,
432 context: &BootstrapContext,
433 event_tx: &drasi_lib::channels::BootstrapEventSender,
434 ) -> Result<usize> {
435 debug!("Starting bootstrap of table '{table_name}' with label '{label}'");
436
437 let columns = self.get_table_columns(transaction, table_name).await?;
439
440 let query = format!("SELECT * FROM {table_name}");
442 let rows = transaction.query(&query, &[]).await?;
443
444 let mut count = 0;
445 let mut batch = Vec::new();
446 let batch_size = 1000;
447
448 for row in rows {
449 let source_change = self
450 .row_to_source_change(&row, label, table_name, &columns)
451 .await?;
452
453 batch.push(SourceChangeEvent {
454 source_id: self.source_id.clone(),
455 change: source_change,
456 timestamp: chrono::Utc::now(),
457 });
458
459 if batch.len() >= batch_size {
460 self.send_batch(&mut batch, context, event_tx).await?;
461 count += batch_size;
462 }
463 }
464
465 if !batch.is_empty() {
467 count += batch.len();
468 self.send_batch(&mut batch, context, event_tx).await?;
469 }
470
471 Ok(count)
472 }
473
474 async fn get_table_columns(
476 &self,
477 transaction: &Transaction<'_>,
478 table_name: &str,
479 ) -> Result<Vec<ColumnInfo>> {
480 let rows = transaction
481 .query(
482 "SELECT column_name,
483 CASE
484 WHEN data_type = 'character varying' THEN 1043
485 WHEN data_type = 'integer' THEN 23
486 WHEN data_type = 'bigint' THEN 20
487 WHEN data_type = 'smallint' THEN 21
488 WHEN data_type = 'text' THEN 25
489 WHEN data_type = 'boolean' THEN 16
490 WHEN data_type = 'numeric' THEN 1700
491 WHEN data_type = 'real' THEN 700
492 WHEN data_type = 'double precision' THEN 701
493 WHEN data_type = 'timestamp without time zone' THEN 1114
494 WHEN data_type = 'timestamp with time zone' THEN 1184
495 WHEN data_type = 'date' THEN 1082
496 WHEN data_type = 'uuid' THEN 2950
497 WHEN data_type = 'json' THEN 114
498 WHEN data_type = 'jsonb' THEN 3802
499 ELSE 25 -- Default to text
500 END as type_oid
501 FROM information_schema.columns
502 WHERE table_schema = 'public' AND table_name = $1
503 ORDER BY ordinal_position",
504 &[&table_name],
505 )
506 .await?;
507
508 let mut columns = Vec::new();
509 for row in rows {
510 columns.push(ColumnInfo {
511 name: row.get(0),
512 type_oid: row.get::<_, i32>(1),
513 });
514 }
515
516 Ok(columns)
517 }
518
519 async fn query_primary_keys(&mut self, client: &Client) -> Result<()> {
521 info!("Querying primary key information from PostgreSQL system catalogs");
522
523 let query = r#"
524 SELECT
525 n.nspname as schema_name,
526 c.relname as table_name,
527 a.attname as column_name
528 FROM pg_constraint con
529 JOIN pg_class c ON con.conrelid = c.oid
530 JOIN pg_namespace n ON c.relnamespace = n.oid
531 JOIN pg_attribute a ON a.attrelid = c.oid
532 WHERE con.contype = 'p' -- Primary key constraint
533 AND a.attnum = ANY(con.conkey)
534 AND n.nspname NOT IN ('pg_catalog', 'information_schema')
535 ORDER BY n.nspname, c.relname, array_position(con.conkey, a.attnum)
536 "#;
537
538 let rows = client.query(query, &[]).await?;
539
540 let mut primary_keys: HashMap<String, Vec<String>> = HashMap::new();
541
542 for row in rows {
543 let schema: &str = row.get(0);
544 let table: &str = row.get(1);
545 let column: &str = row.get(2);
546
547 let table_key = if schema == "public" {
549 table.to_string()
550 } else {
551 format!("{schema}.{table}")
552 };
553
554 primary_keys
555 .entry(table_key.clone())
556 .or_default()
557 .push(column.to_string());
558
559 debug!("Found primary key column '{column}' for table '{table_key}'");
560 }
561
562 for table_key_config in &self.config.table_keys {
564 let table_name = &table_key_config.table;
565 let key_columns = &table_key_config.key_columns;
566
567 if !key_columns.is_empty() {
568 info!(
569 "Using user-configured key columns for table '{table_name}': {key_columns:?}"
570 );
571 primary_keys.insert(table_name.clone(), key_columns.clone());
572 }
573 }
574
575 self.table_primary_keys = primary_keys.clone();
577
578 info!("Found primary keys for {} tables", primary_keys.len());
579 for (table, keys) in &primary_keys {
580 info!("Table '{table}' primary key columns: {keys:?}");
581 }
582
583 Ok(())
584 }
585
586 async fn row_to_source_change(
588 &self,
589 row: &Row,
590 label: &str,
591 table_name: &str,
592 columns: &[ColumnInfo],
593 ) -> Result<SourceChange> {
594 let mut properties = ElementPropertyMap::new();
595
596 let pk_columns = self.table_primary_keys.get(table_name);
598
599 let mut pk_values = Vec::new();
601
602 for (idx, column) in columns.iter().enumerate() {
603 let is_pk = pk_columns
605 .map(|pks| pks.contains(&column.name))
606 .unwrap_or(false);
607
608 let element_value = match column.type_oid {
610 16 => {
611 if let Ok(Some(val)) = row.try_get::<_, Option<bool>>(idx) {
613 drasi_core::models::ElementValue::Bool(val)
614 } else {
615 drasi_core::models::ElementValue::Null
616 }
617 }
618 21 | 23 | 20 => {
619 if let Ok(Some(val)) = row.try_get::<_, Option<i64>>(idx) {
621 drasi_core::models::ElementValue::Integer(val)
622 } else if let Ok(Some(val)) = row.try_get::<_, Option<i32>>(idx) {
623 drasi_core::models::ElementValue::Integer(val as i64)
624 } else if let Ok(Some(val)) = row.try_get::<_, Option<i16>>(idx) {
625 drasi_core::models::ElementValue::Integer(val as i64)
626 } else {
627 drasi_core::models::ElementValue::Null
628 }
629 }
630 700 | 701 => {
631 if let Ok(Some(val)) = row.try_get::<_, Option<f64>>(idx) {
633 drasi_core::models::ElementValue::Float(ordered_float::OrderedFloat(val))
634 } else if let Ok(Some(val)) = row.try_get::<_, Option<f32>>(idx) {
635 drasi_core::models::ElementValue::Float(ordered_float::OrderedFloat(
636 val as f64,
637 ))
638 } else {
639 drasi_core::models::ElementValue::Null
640 }
641 }
642 1700 => {
643 if let Ok(Some(val)) = row.try_get::<_, Option<rust_decimal::Decimal>>(idx) {
645 drasi_core::models::ElementValue::Float(ordered_float::OrderedFloat(
646 val.to_string().parse::<f64>().unwrap_or(0.0),
647 ))
648 } else {
649 drasi_core::models::ElementValue::Null
650 }
651 }
652 25 | 1043 | 19 => {
653 if let Ok(Some(val)) = row.try_get::<_, Option<String>>(idx) {
655 drasi_core::models::ElementValue::String(std::sync::Arc::from(val))
656 } else {
657 drasi_core::models::ElementValue::Null
658 }
659 }
660 1114 | 1184 => {
661 if let Ok(Some(val)) = row.try_get::<_, Option<chrono::NaiveDateTime>>(idx) {
663 drasi_core::models::ElementValue::String(std::sync::Arc::from(
664 val.to_string(),
665 ))
666 } else if let Ok(Some(val)) =
667 row.try_get::<_, Option<chrono::DateTime<chrono::Utc>>>(idx)
668 {
669 drasi_core::models::ElementValue::String(std::sync::Arc::from(
670 val.to_string(),
671 ))
672 } else {
673 drasi_core::models::ElementValue::Null
674 }
675 }
676 _ => {
677 if let Ok(Some(val)) = row.try_get::<_, Option<String>>(idx) {
679 drasi_core::models::ElementValue::String(std::sync::Arc::from(val))
680 } else {
681 drasi_core::models::ElementValue::Null
682 }
683 }
684 };
685
686 if is_pk && !matches!(element_value, drasi_core::models::ElementValue::Null) {
688 let value_str = match &element_value {
689 drasi_core::models::ElementValue::Integer(i) => i.to_string(),
690 drasi_core::models::ElementValue::Float(f) => f.to_string(),
691 drasi_core::models::ElementValue::String(s) => s.to_string(),
692 drasi_core::models::ElementValue::Bool(b) => b.to_string(),
693 _ => format!("{element_value:?}"),
694 };
695 pk_values.push(value_str);
696 }
697
698 properties.insert(&column.name, element_value);
699 }
700
701 let elem_id = if !pk_values.is_empty() {
704 format!("{}:{}", table_name, pk_values.join("_"))
706 } else if pk_columns.is_none() || pk_columns.map(|pks| pks.is_empty()).unwrap_or(true) {
707 warn!(
709 "No primary key found for table '{table_name}'. Consider adding 'table_keys' configuration."
710 );
711 format!("{}:{}", table_name, uuid::Uuid::new_v4())
713 } else {
714 format!("{}:{}", table_name, uuid::Uuid::new_v4())
716 };
717
718 let metadata = ElementMetadata {
719 reference: ElementReference::new(&self.source_id, &elem_id),
720 labels: Arc::from(vec![Arc::from(label)]),
721 effective_from: chrono::Utc::now().timestamp_nanos_opt().unwrap_or(0) as u64,
722 };
723
724 let element = Element::Node {
725 metadata,
726 properties,
727 };
728
729 Ok(SourceChange::Insert { element })
730 }
731
732 async fn send_batch(
734 &self,
735 batch: &mut Vec<SourceChangeEvent>,
736 context: &BootstrapContext,
737 event_tx: &drasi_lib::channels::BootstrapEventSender,
738 ) -> Result<()> {
739 for event in batch.drain(..) {
740 let sequence = context.next_sequence();
742
743 let bootstrap_event = drasi_lib::channels::BootstrapEvent {
744 source_id: event.source_id,
745 change: event.change,
746 timestamp: event.timestamp,
747 sequence,
748 };
749 event_tx.send(bootstrap_event).await.map_err(|e| {
750 anyhow!("Failed to send bootstrap event to channel (channel may be closed): {e}")
751 })?;
752 }
753 Ok(())
754 }
755}
756
/// Name and type code of one table column, as produced by `get_table_columns`.
#[derive(Debug)]
struct ColumnInfo {
    /// Column name from `information_schema.columns`.
    name: String,
    /// Numeric type code computed by the `CASE` expression in
    /// `get_table_columns` (`25` for any unlisted data type).
    type_oid: i32,
}