database_replicator/xmin/reconciler.rs

// ABOUTME: Reconciler for xmin-based sync - detects deleted rows in source
// ABOUTME: Compares primary keys between source and target to find orphaned rows

use anyhow::{Context, Result};
use std::cmp::Ordering;
use std::collections::HashSet;
use tokio_postgres::types::ToSql;
use tokio_postgres::Client;

use super::writer::ChangeWriter;

/// Reconciler detects rows that exist in target but not in source (deletions).
///
/// Since xmin-based sync only sees modified rows, it cannot detect deletions.
/// The Reconciler performs periodic full-table primary key comparisons to find
/// rows that need to be deleted from the target.
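///
/// # Example
///
/// A minimal usage sketch (marked `ignore`: it assumes two already-established
/// `tokio_postgres` connections and that this module is reachable at its crate
/// path; connection setup is elided):
///
/// ```ignore
/// let reconciler = Reconciler::new(&source_client, &target_client);
/// let pk_columns = vec!["id".to_string()];
/// let deleted = reconciler
///     .reconcile_table("public", "users", &pk_columns)
///     .await?;
/// tracing::info!("deleted {} orphaned rows from target", deleted);
/// ```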
pub struct Reconciler<'a> {
    source_client: &'a Client,
    target_client: &'a Client,
}

impl<'a> Reconciler<'a> {
    /// Create a new Reconciler with source and target database connections.
    pub fn new(source_client: &'a Client, target_client: &'a Client) -> Self {
        Self {
            source_client,
            target_client,
        }
    }

    /// Find rows that exist in target but not in source (orphaned rows).
    ///
    /// This performs a primary key comparison between source and target tables.
    /// Returns the primary key values of rows that should be deleted from target.
    ///
    /// # Arguments
    ///
    /// * `schema` - Schema name
    /// * `table` - Table name
    /// * `primary_key_columns` - Primary key column names
    ///
    /// # Returns
    ///
    /// A vector of primary key value tuples for orphaned rows.
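    ///
    /// Note: this loads every PK from both tables into memory; for very large
    /// tables prefer [`Reconciler::reconcile_table_batched`].
    ///
    /// # Example
    ///
    /// A dry-run sketch (marked `ignore`; same connection assumptions as the
    /// struct-level example) that inspects orphans without deleting them:
    ///
    /// ```ignore
    /// let pk_columns = vec!["tenant_id".to_string(), "id".to_string()];
    /// let orphans = reconciler
    ///     .find_orphaned_rows("public", "orders", &pk_columns)
    ///     .await?;
    /// // Each entry is one PK tuple rendered as text, e.g. ["42", "7"].
    /// for pk in &orphans {
    ///     tracing::warn!("orphaned row in target: {:?}", pk);
    /// }
    /// ```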
    pub async fn find_orphaned_rows(
        &self,
        schema: &str,
        table: &str,
        primary_key_columns: &[String],
    ) -> Result<Vec<Vec<String>>> {
        // Get all PKs from source
        let source_pks = self
            .get_all_primary_keys(self.source_client, schema, table, primary_key_columns)
            .await
            .context("Failed to get source primary keys")?;

        // Get all PKs from target
        let target_pks = self
            .get_all_primary_keys(self.target_client, schema, table, primary_key_columns)
            .await
            .context("Failed to get target primary keys")?;

        // Find PKs in target that don't exist in source
        let source_set: HashSet<Vec<String>> = source_pks.into_iter().collect();
        let orphaned: Vec<Vec<String>> = target_pks
            .into_iter()
            .filter(|pk| !source_set.contains(pk))
            .collect();

        tracing::info!(
            "Found {} orphaned rows in {}.{} that need deletion",
            orphaned.len(),
            schema,
            table
        );

        Ok(orphaned)
    }

    /// Reconcile a table by deleting orphaned rows from target.
    ///
    /// This is a convenience method that finds orphaned rows and deletes them.
    ///
    /// # Returns
    ///
    /// The number of rows deleted from target.
    pub async fn reconcile_table(
        &self,
        schema: &str,
        table: &str,
        primary_key_columns: &[String],
    ) -> Result<u64> {
        let orphaned = self
            .find_orphaned_rows(schema, table, primary_key_columns)
            .await?;

        if orphaned.is_empty() {
            tracing::info!("No orphaned rows found in {}.{}", schema, table);
            return Ok(0);
        }

        // Convert string PKs to ToSql values
        let pk_values: Vec<Vec<Box<dyn ToSql + Sync + Send>>> = orphaned
            .into_iter()
            .map(|pk| {
                pk.into_iter()
                    .map(|v| Box::new(v) as Box<dyn ToSql + Sync + Send>)
                    .collect()
            })
            .collect();

        // Delete orphaned rows
        let writer = ChangeWriter::new(self.target_client);
        let deleted = writer
            .delete_rows(schema, table, primary_key_columns, pk_values)
            .await?;

        tracing::info!(
            "Deleted {} orphaned rows from {}.{}",
            deleted,
            schema,
            table
        );

        Ok(deleted)
    }

    /// Get all primary key values from a table.
    ///
    /// Note: Uses `::text` cast for both SELECT and ORDER BY to ensure consistent
    /// lexicographic ordering that matches Rust string comparison.
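    ///
    /// For PK columns `(tenant_id, id)` the generated query has roughly this
    /// shape (column, schema, and table names are illustrative):
    ///
    /// ```sql
    /// SELECT "tenant_id"::text, "id"::text
    /// FROM "public"."orders"
    /// ORDER BY "tenant_id"::text, "id"::text
    /// ```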
    async fn get_all_primary_keys(
        &self,
        client: &Client,
        schema: &str,
        table: &str,
        primary_key_columns: &[String],
    ) -> Result<Vec<Vec<String>>> {
        // Use ::text cast for both SELECT and ORDER BY to match Rust comparison
        let pk_cols_text: Vec<String> = primary_key_columns
            .iter()
            .map(|c| format!("\"{}\"::text", c))
            .collect();

        let query = format!(
            "SELECT {} FROM \"{}\".\"{}\" ORDER BY {}",
            pk_cols_text.join(", "),
            schema,
            table,
            pk_cols_text.join(", ")
        );

        let rows = client
            .query(&query, &[])
            .await
            .with_context(|| format!("Failed to get primary keys from {}.{}", schema, table))?;

        let pks: Vec<Vec<String>> = rows
            .iter()
            .map(|row| {
                (0..primary_key_columns.len())
                    .map(|i| row.get::<_, String>(i))
                    .collect()
            })
            .collect();

        Ok(pks)
    }

    /// Get count of rows in source and target for comparison.
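    ///
    /// # Example
    ///
    /// A quick drift-check sketch (marked `ignore`; illustrative only):
    ///
    /// ```ignore
    /// let (src, tgt) = reconciler.get_row_counts("public", "users").await?;
    /// if src != tgt {
    ///     tracing::warn!("row count drift: source={} target={}", src, tgt);
    /// }
    /// ```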
    pub async fn get_row_counts(&self, schema: &str, table: &str) -> Result<(i64, i64)> {
        let query = format!("SELECT COUNT(*) FROM \"{}\".\"{}\"", schema, table);

        let source_row = self
            .source_client
            .query_one(&query, &[])
            .await
            .context("Failed to get source row count")?;
        let source_count: i64 = source_row.get(0);

        let target_row = self
            .target_client
            .query_one(&query, &[])
            .await
            .context("Failed to get target row count")?;
        let target_count: i64 = target_row.get(0);

        Ok((source_count, target_count))
    }

    /// Check if a table exists in the target database.
    pub async fn table_exists_in_target(&self, schema: &str, table: &str) -> Result<bool> {
        let query = "SELECT EXISTS (
            SELECT 1 FROM information_schema.tables
            WHERE table_schema = $1 AND table_name = $2
        )";

        let row = self
            .target_client
            .query_one(query, &[&schema, &table])
            .await
            .context("Failed to check if table exists")?;

        Ok(row.get(0))
    }

    /// Reconcile a table using batched streaming comparison (memory-efficient).
    ///
    /// Uses merge-join comparison on sorted primary keys fetched in batches.
    /// This avoids loading all PKs into memory, making it suitable for tables
    /// with millions of rows.
    ///
    /// # Arguments
    ///
    /// * `schema` - Schema name
    /// * `table` - Table name
    /// * `primary_key_columns` - Primary key column names
    /// * `batch_size` - Number of PKs to fetch per batch (also used as the
    ///   delete batch size for accumulated orphans)
    ///
    /// # Returns
    ///
    /// The number of orphaned rows deleted from target.
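    ///
    /// # Example
    ///
    /// A sketch for a large table (marked `ignore`; same assumptions as the
    /// struct-level example, and the batch size of 10,000 is an arbitrary
    /// illustration):
    ///
    /// ```ignore
    /// let pk_columns = vec!["id".to_string()];
    /// let deleted = reconciler
    ///     .reconcile_table_batched("public", "events", &pk_columns, 10_000)
    ///     .await?;
    /// tracing::info!("deleted {} orphaned rows", deleted);
    /// ```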
    pub async fn reconcile_table_batched(
        &self,
        schema: &str,
        table: &str,
        primary_key_columns: &[String],
        batch_size: usize,
    ) -> Result<u64> {
        tracing::info!(
            "Starting batched reconciliation for {}.{} (batch size: {})",
            schema,
            table,
            batch_size
        );

        let writer = ChangeWriter::new(self.target_client);
        let mut total_deleted = 0u64;
        let mut orphans_batch: Vec<Vec<String>> = Vec::new();

        // Initialize batch readers for both source and target
        let mut source_reader = PkBatchReader::new(
            self.source_client,
            schema,
            table,
            primary_key_columns,
            batch_size,
        );
        let mut target_reader = PkBatchReader::new(
            self.target_client,
            schema,
            table,
            primary_key_columns,
            batch_size,
        );

        // Fetch initial batches
        let mut source_batch = source_reader.fetch_next().await?;
        let mut target_batch = target_reader.fetch_next().await?;
        let mut source_idx = 0;
        let mut target_idx = 0;
        let mut comparisons = 0u64;

        // Merge-join comparison loop
        loop {
            // Refill source batch if exhausted
            if source_idx >= source_batch.len() && !source_reader.exhausted {
                source_batch = source_reader.fetch_next().await?;
                source_idx = 0;
            }

            // Refill target batch if exhausted
            if target_idx >= target_batch.len() && !target_reader.exhausted {
                target_batch = target_reader.fetch_next().await?;
                target_idx = 0;
            }

            // Check termination conditions
            let source_exhausted = source_idx >= source_batch.len();
            let target_exhausted = target_idx >= target_batch.len();

            if source_exhausted && target_exhausted {
                // Both exhausted - done
                break;
            }

            if source_exhausted {
                // Source exhausted but target has more - all remaining are orphans
                while target_idx < target_batch.len() {
                    orphans_batch.push(target_batch[target_idx].clone());
                    target_idx += 1;

                    // Delete batch when full
                    if orphans_batch.len() >= batch_size {
                        total_deleted += self
                            .delete_orphan_batch(
                                &writer,
                                schema,
                                table,
                                primary_key_columns,
                                &orphans_batch,
                            )
                            .await?;
                        orphans_batch.clear();
                    }
                }

                // Fetch more from target
                if !target_reader.exhausted {
                    target_batch = target_reader.fetch_next().await?;
                    target_idx = 0;
                }
                continue;
            }

            if target_exhausted {
                // Target exhausted but source has more - no more orphans possible
                break;
            }

            // Compare current PKs
            let source_pk = &source_batch[source_idx];
            let target_pk = &target_batch[target_idx];
            comparisons += 1;

            match compare_pks(source_pk, target_pk) {
                Ordering::Equal => {
                    // PKs match - both exist, advance both
                    source_idx += 1;
                    target_idx += 1;
                }
                Ordering::Less => {
                    // Source PK < Target PK - row exists in source but not yet
                    // in target (pending sync); just advance source
                    source_idx += 1;
                }
                Ordering::Greater => {
                    // Source PK > Target PK - target has orphan
                    orphans_batch.push(target_pk.clone());
                    target_idx += 1;

                    // Delete batch when full
                    if orphans_batch.len() >= batch_size {
                        total_deleted += self
                            .delete_orphan_batch(
                                &writer,
                                schema,
                                table,
                                primary_key_columns,
                                &orphans_batch,
                            )
                            .await?;
                        orphans_batch.clear();
                    }
                }
            }

            // Log progress periodically
            if comparisons.is_multiple_of(100_000) {
                tracing::info!(
                    "Reconciliation progress for {}.{}: {} comparisons, {} orphans found",
                    schema,
                    table,
                    comparisons,
                    total_deleted + orphans_batch.len() as u64
                );
            }
        }

        // Delete remaining orphans
        if !orphans_batch.is_empty() {
            total_deleted += self
                .delete_orphan_batch(&writer, schema, table, primary_key_columns, &orphans_batch)
                .await?;
        }

        tracing::info!(
            "Completed reconciliation for {}.{}: {} comparisons, {} orphans deleted",
            schema,
            table,
            comparisons,
            total_deleted
        );

        Ok(total_deleted)
    }

    /// Delete a batch of orphan rows.
    async fn delete_orphan_batch(
        &self,
        writer: &ChangeWriter<'_>,
        schema: &str,
        table: &str,
        primary_key_columns: &[String],
        orphans: &[Vec<String>],
    ) -> Result<u64> {
        if orphans.is_empty() {
            return Ok(0);
        }

        tracing::debug!(
            "Deleting batch of {} orphan rows from {}.{}",
            orphans.len(),
            schema,
            table
        );

        // Convert string PKs to ToSql values
        let pk_values: Vec<Vec<Box<dyn ToSql + Sync + Send>>> = orphans
            .iter()
            .map(|pk| {
                pk.iter()
                    .map(|v| Box::new(v.clone()) as Box<dyn ToSql + Sync + Send>)
                    .collect()
            })
            .collect();

        writer
            .delete_rows(schema, table, primary_key_columns, pk_values)
            .await
    }
}

/// Compare two primary key tuples lexicographically, element by element.
/// A shorter tuple that is a prefix of a longer one sorts first.
fn compare_pks(a: &[String], b: &[String]) -> Ordering {
    for (av, bv) in a.iter().zip(b.iter()) {
        match av.cmp(bv) {
            Ordering::Equal => continue,
            other => return other,
        }
    }
    a.len().cmp(&b.len())
}

/// Batch reader for primary keys using keyset pagination.
///
/// Fetches PKs in sorted order using `WHERE pk > last_pk LIMIT batch_size`,
/// which is more efficient than OFFSET for large tables.
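///
/// After the first batch, pagination uses a SQL row-value comparison on the
/// text-cast key; for a two-column key the query has roughly this shape
/// (names and the batch size are illustrative):
///
/// ```sql
/// SELECT "tenant_id"::text, "id"::text
/// FROM "public"."orders"
/// WHERE ("tenant_id"::text, "id"::text) > ($1, $2)
/// ORDER BY "tenant_id"::text, "id"::text
/// LIMIT 10000
/// ```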
struct PkBatchReader<'a> {
    client: &'a Client,
    schema: String,
    table: String,
    pk_columns: Vec<String>,
    batch_size: usize,
    last_pk: Option<Vec<String>>,
    pub exhausted: bool,
}

impl<'a> PkBatchReader<'a> {
    fn new(
        client: &'a Client,
        schema: &str,
        table: &str,
        pk_columns: &[String],
        batch_size: usize,
    ) -> Self {
        Self {
            client,
            schema: schema.to_string(),
            table: table.to_string(),
            pk_columns: pk_columns.to_vec(),
            batch_size,
            last_pk: None,
            exhausted: false,
        }
    }

    /// Fetch the next batch of primary keys.
    ///
    /// IMPORTANT: Both SELECT and ORDER BY use `::text` cast to ensure the SQL
    /// stream order matches the lexicographic comparison used in Rust. Without
    /// this, numeric PKs would be ordered numerically in SQL (1, 2, 10) but
    /// compared lexicographically in Rust ("1" < "10" < "2"), causing false
    /// orphan detection and data loss.
    async fn fetch_next(&mut self) -> Result<Vec<Vec<String>>> {
        if self.exhausted {
            return Ok(Vec::new());
        }

        // Cast PKs to text for both SELECT and ORDER BY to ensure SQL stream
        // order matches Rust's lexicographic string comparison
        let pk_cols_text: Vec<String> = self
            .pk_columns
            .iter()
            .map(|c| format!("\"{}\"::text", c))
            .collect();

        let query = if self.last_pk.is_some() {
            // Keyset pagination: WHERE (pk1::text, pk2::text, ...) > ($1, $2, ...)
            // Must use text-cast columns in WHERE to match ORDER BY ordering
            let params: Vec<String> = (1..=self.pk_columns.len())
                .map(|i| format!("${}", i))
                .collect();

            format!(
                "SELECT {} FROM \"{}\".\"{}\" WHERE ({}) > ({}) ORDER BY {} LIMIT {}",
                pk_cols_text.join(", "),
                self.schema,
                self.table,
                pk_cols_text.join(", "),
                params.join(", "),
                pk_cols_text.join(", "),
                self.batch_size
            )
        } else {
            // First batch: no WHERE clause
            format!(
                "SELECT {} FROM \"{}\".\"{}\" ORDER BY {} LIMIT {}",
                pk_cols_text.join(", "),
                self.schema,
                self.table,
                pk_cols_text.join(", "),
                self.batch_size
            )
        };

        // Build parameters for keyset pagination
        let params: Vec<&(dyn ToSql + Sync)> = if let Some(ref last) = self.last_pk {
            last.iter().map(|s| s as &(dyn ToSql + Sync)).collect()
        } else {
            Vec::new()
        };

        let rows = self.client.query(&query, &params).await.with_context(|| {
            format!(
                "Failed to fetch PK batch from {}.{}",
                self.schema, self.table
            )
        })?;

        if rows.len() < self.batch_size {
            self.exhausted = true;
        }

        let pks: Vec<Vec<String>> = rows
            .iter()
            .map(|row| {
                (0..self.pk_columns.len())
                    .map(|i| row.get::<_, String>(i))
                    .collect()
            })
            .collect();

        // Update last_pk for next iteration
        if let Some(last_row) = pks.last() {
            self.last_pk = Some(last_row.clone());
        }

        Ok(pks)
    }
}

/// Configuration for reconciliation behavior.
#[derive(Debug, Clone)]
pub struct ReconcileConfig {
    /// Whether to actually delete orphaned rows (false = dry run)
    pub delete_orphans: bool,
    /// Maximum number of orphans to delete in one batch
    pub max_deletes: Option<usize>,
    /// Tables to skip during reconciliation
    pub skip_tables: Vec<String>,
}

impl Default for ReconcileConfig {
    fn default() -> Self {
        Self {
            delete_orphans: true,
            max_deletes: None,
            skip_tables: Vec::new(),
        }
    }
}

/// Result of a reconciliation operation.
#[derive(Debug, Clone)]
pub struct ReconcileResult {
    pub schema: String,
    pub table: String,
    pub source_count: i64,
    pub target_count: i64,
    pub orphaned_count: usize,
    pub deleted_count: u64,
}

impl ReconcileResult {
    /// Check if the table is in sync (same row count, no orphans).
    pub fn is_in_sync(&self) -> bool {
        self.source_count == self.target_count && self.orphaned_count == 0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reconcile_config_default() {
        let config = ReconcileConfig::default();
        assert!(config.delete_orphans);
        assert!(config.max_deletes.is_none());
        assert!(config.skip_tables.is_empty());
    }
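
    // A non-default, dry-run configuration sketch; the field values here are
    // arbitrary illustrations.
    #[test]
    fn test_reconcile_config_dry_run() {
        let config = ReconcileConfig {
            delete_orphans: false,
            max_deletes: Some(1_000),
            skip_tables: vec!["audit_log".to_string()],
        };
        assert!(!config.delete_orphans);
        assert_eq!(config.max_deletes, Some(1_000));
        assert_eq!(config.skip_tables, vec!["audit_log".to_string()]);
    }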

    #[test]
    fn test_reconcile_result_in_sync() {
        let result = ReconcileResult {
            schema: "public".to_string(),
            table: "users".to_string(),
            source_count: 100,
            target_count: 100,
            orphaned_count: 0,
            deleted_count: 0,
        };
        assert!(result.is_in_sync());
    }

    #[test]
    fn test_reconcile_result_not_in_sync() {
        let result = ReconcileResult {
            schema: "public".to_string(),
            table: "users".to_string(),
            source_count: 100,
            target_count: 105,
            orphaned_count: 5,
            deleted_count: 0,
        };
        assert!(!result.is_in_sync());
    }
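
    // Unit checks for compare_pks, including the lexicographic-vs-numeric
    // ordering pitfall documented on PkBatchReader::fetch_next
    // ("1" < "10" < "2" as text).
    #[test]
    fn test_compare_pks_lexicographic() {
        let a = vec!["1".to_string()];
        let b = vec!["10".to_string()];
        let c = vec!["2".to_string()];
        assert_eq!(compare_pks(&a, &b), Ordering::Less);
        assert_eq!(compare_pks(&b, &c), Ordering::Less);
        assert_eq!(compare_pks(&a, &a), Ordering::Equal);
        // A shorter tuple that is a prefix of a longer one sorts first.
        let short = vec!["x".to_string()];
        let long = vec!["x".to_string(), "y".to_string()];
        assert_eq!(compare_pks(&short, &long), Ordering::Less);
    }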
}