use anyhow::{Context, Result};
use std::cmp::Ordering;
use std::collections::HashSet;
use tokio_postgres::types::ToSql;
use tokio_postgres::Client;

use super::writer::ChangeWriter;

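/// Detects and removes "orphaned" rows: rows that still exist in the target
/// database after the corresponding row has been deleted from the source.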
pub struct Reconciler<'a> {
    source_client: &'a Client,
    target_client: &'a Client,
}

impl<'a> Reconciler<'a> {
    pub fn new(source_client: &'a Client, target_client: &'a Client) -> Self {
        Self {
            source_client,
            target_client,
        }
    }

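    /// Returns the primary keys of rows present in the target but absent from
    /// the source. Both key sets are materialized in memory, so prefer
    /// [`Self::reconcile_table_batched`] for tables whose keys do not fit
    /// comfortably in RAM.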
    pub async fn find_orphaned_rows(
        &self,
        schema: &str,
        table: &str,
        primary_key_columns: &[String],
    ) -> Result<Vec<Vec<String>>> {
        let source_pks = self
            .get_all_primary_keys(self.source_client, schema, table, primary_key_columns)
            .await
            .context("Failed to get source primary keys")?;

        let target_pks = self
            .get_all_primary_keys(self.target_client, schema, table, primary_key_columns)
            .await
            .context("Failed to get target primary keys")?;

        // Set difference: any target key missing from the source set is orphaned.
        let source_set: HashSet<Vec<String>> = source_pks.into_iter().collect();
        let orphaned: Vec<Vec<String>> = target_pks
            .into_iter()
            .filter(|pk| !source_set.contains(pk))
            .collect();

        tracing::info!(
            "Found {} orphaned rows in {}.{} that need deletion",
            orphaned.len(),
            schema,
            table
        );

        Ok(orphaned)
    }

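    /// Finds all orphaned rows in `schema.table` and deletes them from the
    /// target in a single pass, returning the number of rows deleted.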
    pub async fn reconcile_table(
        &self,
        schema: &str,
        table: &str,
        primary_key_columns: &[String],
    ) -> Result<u64> {
        let orphaned = self
            .find_orphaned_rows(schema, table, primary_key_columns)
            .await?;

        if orphaned.is_empty() {
            tracing::info!("No orphaned rows found in {}.{}", schema, table);
            return Ok(0);
        }

        // Box each key value so it can be passed as a dynamic SQL parameter.
        let pk_values: Vec<Vec<Box<dyn ToSql + Sync + Send>>> = orphaned
            .into_iter()
            .map(|pk| {
                pk.into_iter()
                    .map(|v| Box::new(v) as Box<dyn ToSql + Sync + Send>)
                    .collect()
            })
            .collect();

        let writer = ChangeWriter::new(self.target_client);
        let deleted = writer
            .delete_rows(schema, table, primary_key_columns, pk_values)
            .await?;

        tracing::info!(
            "Deleted {} orphaned rows from {}.{}",
            deleted,
            schema,
            table
        );

        Ok(deleted)
    }

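    /// Fetches every primary key of `schema.table` through the given client,
    /// with each key column cast to text so composite and non-string keys can
    /// be compared uniformly.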
    async fn get_all_primary_keys(
        &self,
        client: &Client,
        schema: &str,
        table: &str,
        primary_key_columns: &[String],
    ) -> Result<Vec<Vec<String>>> {
        // Cast each key column to text so all key types can be handled as strings.
        let pk_cols_text: Vec<String> = primary_key_columns
            .iter()
            .map(|c| format!("\"{}\"::text", c))
            .collect();

        let query = format!(
            "SELECT {} FROM \"{}\".\"{}\" ORDER BY {}",
            pk_cols_text.join(", "),
            schema,
            table,
            pk_cols_text.join(", ")
        );

        let rows = client
            .query(&query, &[])
            .await
            .with_context(|| format!("Failed to get primary keys from {}.{}", schema, table))?;

        let pks: Vec<Vec<String>> = rows
            .iter()
            .map(|row| {
                (0..primary_key_columns.len())
                    .map(|i| row.get::<_, String>(i))
                    .collect()
            })
            .collect();

        Ok(pks)
    }

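    /// Returns `(source_count, target_count)` for `schema.table` as a coarse
    /// consistency check.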
    pub async fn get_row_counts(&self, schema: &str, table: &str) -> Result<(i64, i64)> {
        let query = format!("SELECT COUNT(*) FROM \"{}\".\"{}\"", schema, table);

        let source_row = self
            .source_client
            .query_one(&query, &[])
            .await
            .context("Failed to get source row count")?;
        let source_count: i64 = source_row.get(0);

        let target_row = self
            .target_client
            .query_one(&query, &[])
            .await
            .context("Failed to get target row count")?;
        let target_count: i64 = target_row.get(0);

        Ok((source_count, target_count))
    }

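    /// Checks `information_schema.tables` on the target for `schema.table`.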
    pub async fn table_exists_in_target(&self, schema: &str, table: &str) -> Result<bool> {
        let query = "SELECT EXISTS (
            SELECT 1 FROM information_schema.tables
            WHERE table_schema = $1 AND table_name = $2
        )";

        let row = self
            .target_client
            .query_one(query, &[&schema, &table])
            .await
            .context("Failed to check if table exists")?;

        Ok(row.get(0))
    }

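    /// Streaming variant of [`Self::reconcile_table`]: reads both tables in
    /// primary-key order via keyset pagination and merge-joins the two streams,
    /// so memory use is bounded by `batch_size` rather than by table size.
    /// Orphans are deleted in batches of at most `batch_size` rows.
    ///
    /// A minimal usage sketch (marked `ignore` since it needs two live
    /// connections; the schema, table, and key names are illustrative):
    ///
    /// ```ignore
    /// let reconciler = Reconciler::new(&source_client, &target_client);
    /// let deleted = reconciler
    ///     .reconcile_table_batched("public", "users", &["id".to_string()], 10_000)
    ///     .await?;
    /// println!("deleted {deleted} orphaned rows");
    /// ```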
    pub async fn reconcile_table_batched(
        &self,
        schema: &str,
        table: &str,
        primary_key_columns: &[String],
        batch_size: usize,
    ) -> Result<u64> {
        tracing::info!(
            "Starting batched reconciliation for {}.{} (batch size: {})",
            schema,
            table,
            batch_size
        );

        let writer = ChangeWriter::new(self.target_client);
        let mut total_deleted = 0u64;
        let mut orphans_batch: Vec<Vec<String>> = Vec::new();

        let mut source_reader = PkBatchReader::new(
            self.source_client,
            schema,
            table,
            primary_key_columns,
            batch_size,
        );
        let mut target_reader = PkBatchReader::new(
            self.target_client,
            schema,
            table,
            primary_key_columns,
            batch_size,
        );

        // Prime both sides of the merge-join with their first sorted batches.
        let mut source_batch = source_reader.fetch_next().await?;
        let mut target_batch = target_reader.fetch_next().await?;
        let mut source_idx = 0;
        let mut target_idx = 0;
        let mut comparisons = 0u64;

        loop {
            // Refill whichever cursor has drained its batch, if more rows remain.
            if source_idx >= source_batch.len() && !source_reader.exhausted {
                source_batch = source_reader.fetch_next().await?;
                source_idx = 0;
            }

            if target_idx >= target_batch.len() && !target_reader.exhausted {
                target_batch = target_reader.fetch_next().await?;
                target_idx = 0;
            }

            let source_exhausted = source_idx >= source_batch.len();
            let target_exhausted = target_idx >= target_batch.len();

            if source_exhausted && target_exhausted {
                break;
            }

            if source_exhausted {
                // The source has no rows left, so every remaining target row is an orphan.
                while target_idx < target_batch.len() {
                    orphans_batch.push(target_batch[target_idx].clone());
                    target_idx += 1;

                    if orphans_batch.len() >= batch_size {
                        total_deleted += self
                            .delete_orphan_batch(
                                &writer,
                                schema,
                                table,
                                primary_key_columns,
                                &orphans_batch,
                            )
                            .await?;
                        orphans_batch.clear();
                    }
                }

                if !target_reader.exhausted {
                    target_batch = target_reader.fetch_next().await?;
                    target_idx = 0;
                }
                continue;
            }

            if target_exhausted {
                // Rows left only in the source are missing from the target, not
                // orphans, so there is nothing more to delete.
                break;
            }

            let source_pk = &source_batch[source_idx];
            let target_pk = &target_batch[target_idx];
            comparisons += 1;

            match compare_pks(source_pk, target_pk) {
                Ordering::Equal => {
                    // Present on both sides: advance both cursors.
                    source_idx += 1;
                    target_idx += 1;
                }
                Ordering::Less => {
                    // Present only in the source; this deletion-only pass skips it.
                    source_idx += 1;
                }
                Ordering::Greater => {
                    // Present only in the target: an orphan to delete.
                    orphans_batch.push(target_pk.clone());
                    target_idx += 1;

                    if orphans_batch.len() >= batch_size {
                        total_deleted += self
                            .delete_orphan_batch(
                                &writer,
                                schema,
                                table,
                                primary_key_columns,
                                &orphans_batch,
                            )
                            .await?;
                        orphans_batch.clear();
                    }
                }
            }

            if comparisons.is_multiple_of(100_000) {
                tracing::info!(
                    "Reconciliation progress for {}.{}: {} comparisons, {} orphans found",
                    schema,
                    table,
                    comparisons,
                    total_deleted + orphans_batch.len() as u64
                );
            }
        }

        // Flush any orphans still buffered after the merge completes.
        if !orphans_batch.is_empty() {
            total_deleted += self
                .delete_orphan_batch(&writer, schema, table, primary_key_columns, &orphans_batch)
                .await?;
        }

        tracing::info!(
            "Completed reconciliation for {}.{}: {} comparisons, {} orphans deleted",
            schema,
            table,
            comparisons,
            total_deleted
        );

        Ok(total_deleted)
    }

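    /// Boxes a buffered batch of orphan keys as SQL parameters and hands them
    /// to the [`ChangeWriter`] for deletion, returning the deleted row count.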
    async fn delete_orphan_batch(
        &self,
        writer: &ChangeWriter<'_>,
        schema: &str,
        table: &str,
        primary_key_columns: &[String],
        orphans: &[Vec<String>],
    ) -> Result<u64> {
        if orphans.is_empty() {
            return Ok(0);
        }

        tracing::debug!(
            "Deleting batch of {} orphan rows from {}.{}",
            orphans.len(),
            schema,
            table
        );

        let pk_values: Vec<Vec<Box<dyn ToSql + Sync + Send>>> = orphans
            .iter()
            .map(|pk| {
                pk.iter()
                    .map(|v| Box::new(v.clone()) as Box<dyn ToSql + Sync + Send>)
                    .collect()
            })
            .collect();

        writer
            .delete_rows(schema, table, primary_key_columns, pk_values)
            .await
    }
}

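/// Orders composite keys column by column as text. Rust's `str` ordering is
/// byte-wise, which is why the batch queries force the matching `COLLATE "C"`
/// ordering on the Postgres side: the merge-join requires both to agree.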
fn compare_pks(a: &[String], b: &[String]) -> Ordering {
    // Compare column by column; the first non-equal column decides.
    for (av, bv) in a.iter().zip(b.iter()) {
        match av.cmp(bv) {
            Ordering::Equal => continue,
            other => return other,
        }
    }
    // All shared columns are equal: the shorter key sorts first.
    a.len().cmp(&b.len())
}

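/// Streams a table's primary keys in sorted batches via keyset pagination
/// (`WHERE (pk, ...) > (last, ...) ORDER BY ... LIMIT n`), which keeps each
/// page cheap on large tables where `OFFSET`-based paging degrades.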
struct PkBatchReader<'a> {
    client: &'a Client,
    schema: String,
    table: String,
    pk_columns: Vec<String>,
    batch_size: usize,
    /// Last key returned by the previous batch; `None` before the first fetch.
    last_pk: Option<Vec<String>>,
    /// Set once a batch comes back short of `batch_size`.
    pub exhausted: bool,
}

impl<'a> PkBatchReader<'a> {
    fn new(
        client: &'a Client,
        schema: &str,
        table: &str,
        pk_columns: &[String],
        batch_size: usize,
    ) -> Self {
        Self {
            client,
            schema: schema.to_string(),
            table: table.to_string(),
            pk_columns: pk_columns.to_vec(),
            batch_size,
            last_pk: None,
            exhausted: false,
        }
    }

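    /// Fetches the next sorted batch of primary keys strictly after `last_pk`.
    /// Marks the reader exhausted when a batch comes back short and returns an
    /// empty vector on subsequent calls.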
    async fn fetch_next(&mut self) -> Result<Vec<Vec<String>>> {
        if self.exhausted {
            return Ok(Vec::new());
        }

        // Cast each key column to text and force the byte-wise "C" collation so
        // that Postgres sorts exactly as `compare_pks` compares; the merge-join
        // in `reconcile_table_batched` is only correct if the orderings agree.
        let pk_cols_text: Vec<String> = self
            .pk_columns
            .iter()
            .map(|c| format!("\"{}\"::text COLLATE \"C\"", c))
            .collect();

        let query = if self.last_pk.is_some() {
            // Keyset pagination: resume strictly after the last key seen, using a
            // row-value comparison so composite keys page correctly.
            let params: Vec<String> = (1..=self.pk_columns.len())
                .map(|i| format!("${}", i))
                .collect();

            format!(
                "SELECT {} FROM \"{}\".\"{}\" WHERE ({}) > ({}) ORDER BY {} LIMIT {}",
                pk_cols_text.join(", "),
                self.schema,
                self.table,
                pk_cols_text.join(", "),
                params.join(", "),
                pk_cols_text.join(", "),
                self.batch_size
            )
        } else {
            format!(
                "SELECT {} FROM \"{}\".\"{}\" ORDER BY {} LIMIT {}",
                pk_cols_text.join(", "),
                self.schema,
                self.table,
                pk_cols_text.join(", "),
                self.batch_size
            )
        };

        let params: Vec<&(dyn ToSql + Sync)> = if let Some(ref last) = self.last_pk {
            last.iter().map(|s| s as &(dyn ToSql + Sync)).collect()
        } else {
            Vec::new()
        };

        let rows = self.client.query(&query, &params).await.with_context(|| {
            format!(
                "Failed to fetch PK batch from {}.{}",
                self.schema, self.table
            )
        })?;

        // A short batch means the final page has been read.
        if rows.len() < self.batch_size {
            self.exhausted = true;
        }

        let pks: Vec<Vec<String>> = rows
            .iter()
            .map(|row| {
                (0..self.pk_columns.len())
                    .map(|i| row.get::<_, String>(i))
                    .collect()
            })
            .collect();

        if let Some(last_row) = pks.last() {
            self.last_pk = Some(last_row.clone());
        }

        Ok(pks)
    }
}

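/// Options controlling an orphan-reconciliation run.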
#[derive(Debug, Clone)]
pub struct ReconcileConfig {
    /// Whether orphaned target rows should actually be deleted.
    pub delete_orphans: bool,
    /// Optional cap on the number of rows deleted per table; `None` means no cap.
    pub max_deletes: Option<usize>,
    /// Tables to exclude from reconciliation.
    pub skip_tables: Vec<String>,
}

impl Default for ReconcileConfig {
    fn default() -> Self {
        Self {
            delete_orphans: true,
            max_deletes: None,
            skip_tables: Vec::new(),
        }
    }
}

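/// Per-table outcome of a reconciliation run.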
#[derive(Debug, Clone)]
pub struct ReconcileResult {
    pub schema: String,
    pub table: String,
    pub source_count: i64,
    pub target_count: i64,
    pub orphaned_count: usize,
    pub deleted_count: u64,
}

impl ReconcileResult {
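    /// True when source and target hold the same number of rows and no
    /// orphans were found.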
    pub fn is_in_sync(&self) -> bool {
        self.source_count == self.target_count && self.orphaned_count == 0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reconcile_config_default() {
        let config = ReconcileConfig::default();
        assert!(config.delete_orphans);
        assert!(config.max_deletes.is_none());
        assert!(config.skip_tables.is_empty());
    }

    #[test]
    fn test_reconcile_result_in_sync() {
        let result = ReconcileResult {
            schema: "public".to_string(),
            table: "users".to_string(),
            source_count: 100,
            target_count: 100,
            orphaned_count: 0,
            deleted_count: 0,
        };
        assert!(result.is_in_sync());
    }

    #[test]
    fn test_reconcile_result_not_in_sync() {
        let result = ReconcileResult {
            schema: "public".to_string(),
            table: "users".to_string(),
            source_count: 100,
            target_count: 105,
            orphaned_count: 5,
            deleted_count: 0,
        };
        assert!(!result.is_in_sync());
    }
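
    // Sanity checks for the merge-join comparator. Keys compare as text
    // (byte-wise), not numerically, which is why the batch queries force the
    // matching "C" collation on the Postgres side.
    #[test]
    fn test_compare_pks_ordering() {
        assert_eq!(
            compare_pks(&["1".into(), "a".into()], &["1".into(), "a".into()]),
            Ordering::Equal
        );
        // The first differing column decides.
        assert_eq!(
            compare_pks(&["1".into(), "a".into()], &["1".into(), "b".into()]),
            Ordering::Less
        );
        // A strict prefix sorts before the longer key.
        assert_eq!(
            compare_pks(&["1".into()], &["1".into(), "a".into()]),
            Ordering::Less
        );
        // Byte-wise text order, not numeric order: "2" > "10".
        assert_eq!(compare_pks(&["2".into()], &["10".into()]), Ordering::Greater);
    }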
}