database_replicator/xmin/
reader.rs

// ABOUTME: XminReader for xmin-based sync - reads changed rows from the source PostgreSQL database.
// ABOUTME: Uses the xmin system column to detect rows modified since the last sync.
3
4use anyhow::{Context, Result};
5use tokio_postgres::{Client, Row};
6
/// Backwards jump in xmin that is treated as a transaction-ID wraparound.
///
/// `detect_wraparound` reports a wraparound when the previously recorded xmin
/// exceeds the current one by more than this amount. PostgreSQL transaction IDs
/// are 32-bit (2^32 is roughly 4.29 billion), so 2 billion is a little under half
/// of the ID space.
const WRAPAROUND_THRESHOLD: u32 = 2_000_000_000;
11
/// Outcome of comparing a stored xmin against the database's current transaction ID.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WraparoundCheck {
    /// The counter did not jump backwards; an incremental sync is safe.
    Normal,
    /// The counter jumped backwards far enough that a full table sync is required.
    WraparoundDetected,
}
20
21/// Detect if xmin wraparound has occurred.
22///
23/// PostgreSQL transaction IDs are 32-bit unsigned integers that wrap around
24/// after ~4 billion transactions. When this happens, new xmin values will be
25/// smaller than old ones by a large margin (> 2 billion).
26///
27/// # Arguments
28///
29/// * `old_xmin` - The previously recorded xmin value
30/// * `current_xmin` - The current database transaction ID
31///
32/// # Returns
33///
34/// `WraparoundCheck::WraparoundDetected` if wraparound occurred, `Normal` otherwise.
35pub fn detect_wraparound(old_xmin: u32, current_xmin: u32) -> WraparoundCheck {
36    // If current < old by more than half the 32-bit range, it's likely a wraparound
37    if old_xmin > current_xmin && (old_xmin - current_xmin) > WRAPAROUND_THRESHOLD {
38        tracing::warn!(
39            "xmin wraparound detected: old_xmin={}, current_xmin={}, delta={}",
40            old_xmin,
41            current_xmin,
42            old_xmin - current_xmin
43        );
44        WraparoundCheck::WraparoundDetected
45    } else {
46        WraparoundCheck::Normal
47    }
48}
49
50/// Reads changed rows from a PostgreSQL table using xmin-based change detection.
51///
52/// PostgreSQL's `xmin` system column contains the transaction ID that last modified
53/// each row. By tracking the maximum xmin seen, we can query for only rows that
54/// have been modified since the last sync.
55///
56/// **Warning:** xmin wraps around at 2^32 transactions. Use `detect_wraparound()`
57/// to check for this condition and trigger a full table sync when detected.
58pub struct XminReader<'a> {
59    client: &'a Client,
60}
61
62impl<'a> XminReader<'a> {
63    /// Create a new XminReader for the given PostgreSQL client connection.
64    pub fn new(client: &'a Client) -> Self {
65        Self { client }
66    }
67
68    /// Get the underlying database client.
69    pub fn client(&self) -> &Client {
70        self.client
71    }
72
73    /// Get the current transaction ID (xmin snapshot) from the database.
74    ///
75    /// This should be called at the start of a sync to establish the high-water mark.
76    pub async fn get_current_xmin(&self) -> Result<u32> {
77        let row = self
78            .client
79            .query_one("SELECT txid_current()::text::bigint", &[])
80            .await
81            .context("Failed to get current transaction ID")?;
82
83        let txid: i64 = row.get(0);
84        // xmin is stored as u32, txid_current() returns i64
85        // We mask to get the 32-bit xmin value
86        Ok((txid & 0xFFFFFFFF) as u32)
87    }
88
89    /// Read all rows from a table that have xmin greater than the given value.
90    ///
91    /// # Arguments
92    ///
93    /// * `schema` - The schema name (e.g., "public")
94    /// * `table` - The table name
95    /// * `columns` - Column names to select (pass empty slice to select all)
96    /// * `since_xmin` - Only return rows with xmin > this value (0 = all rows)
97    ///
98    /// # Returns
99    ///
100    /// A tuple of (rows, max_xmin) where max_xmin is the highest xmin seen in the result set.
101    pub async fn read_changes(
102        &self,
103        schema: &str,
104        table: &str,
105        columns: &[String],
106        since_xmin: u32,
107    ) -> Result<(Vec<Row>, u32)> {
108        let column_list = if columns.is_empty() {
109            "*".to_string()
110        } else {
111            columns
112                .iter()
113                .map(|c| format!("\"{}\"", c))
114                .collect::<Vec<_>>()
115                .join(", ")
116        };
117
118        // Query rows where xmin > since_xmin, including the xmin value
119        // Note: ORDER BY uses the casted value because xid type doesn't have ordering operators
120        let query = format!(
121            "SELECT {}, xmin::text::bigint as _xmin FROM \"{}\".\"{}\" WHERE xmin::text::bigint > $1 ORDER BY xmin::text::bigint",
122            column_list, schema, table
123        );
124
125        let rows = self
126            .client
127            .query(&query, &[&(since_xmin as i64)])
128            .await
129            .with_context(|| format!("Failed to read changes from {}.{}", schema, table))?;
130
131        // Find the max xmin in the result set
132        let max_xmin = rows
133            .iter()
134            .map(|row| {
135                let xmin: i64 = row.get("_xmin");
136                (xmin & 0xFFFFFFFF) as u32
137            })
138            .max()
139            .unwrap_or(since_xmin);
140
141        Ok((rows, max_xmin))
142    }
143
144    /// Read changes in batches to handle large tables efficiently.
145    ///
146    /// # Arguments
147    ///
148    /// * `schema` - The schema name
149    /// * `table` - The table name
150    /// * `columns` - Column names to select
151    /// * `since_xmin` - Only return rows with xmin > this value
152    /// * `batch_size` - Maximum rows per batch
153    ///
154    /// # Returns
155    ///
156    /// An iterator-like struct that yields batches of rows.
157    pub async fn read_changes_batched(
158        &self,
159        schema: &str,
160        table: &str,
161        columns: &[String],
162        since_xmin: u32,
163        batch_size: usize,
164    ) -> Result<BatchReader> {
165        Ok(BatchReader {
166            schema: schema.to_string(),
167            table: table.to_string(),
168            columns: columns.to_vec(),
169            current_xmin: since_xmin,
170            batch_size,
171            exhausted: false,
172        })
173    }
174
175    /// Execute a batched read query and return the next batch.
176    pub async fn fetch_batch(
177        &self,
178        batch_reader: &mut BatchReader,
179    ) -> Result<Option<(Vec<Row>, u32)>> {
180        if batch_reader.exhausted {
181            return Ok(None);
182        }
183
184        let column_list = if batch_reader.columns.is_empty() {
185            "*".to_string()
186        } else {
187            batch_reader
188                .columns
189                .iter()
190                .map(|c| format!("\"{}\"", c))
191                .collect::<Vec<_>>()
192                .join(", ")
193        };
194
195        let query = format!(
196            "SELECT {}, xmin::text::bigint as _xmin FROM \"{}\".\"{}\" \
197             WHERE xmin::text::bigint > $1 \
198             ORDER BY xmin::text::bigint \
199             LIMIT $2",
200            column_list, batch_reader.schema, batch_reader.table
201        );
202
203        let rows = self
204            .client
205            .query(
206                &query,
207                &[
208                    &(batch_reader.current_xmin as i64),
209                    &(batch_reader.batch_size as i64),
210                ],
211            )
212            .await
213            .with_context(|| {
214                format!(
215                    "Failed to read batch from {}.{}",
216                    batch_reader.schema, batch_reader.table
217                )
218            })?;
219
220        if rows.is_empty() {
221            batch_reader.exhausted = true;
222            return Ok(None);
223        }
224
225        // Update current_xmin to the max in this batch
226        let max_xmin = rows
227            .iter()
228            .map(|row| {
229                let xmin: i64 = row.get("_xmin");
230                (xmin & 0xFFFFFFFF) as u32
231            })
232            .max()
233            .unwrap_or(batch_reader.current_xmin);
234
235        // Mark as exhausted if we got fewer rows than batch_size
236        if rows.len() < batch_reader.batch_size {
237            batch_reader.exhausted = true;
238        }
239
240        batch_reader.current_xmin = max_xmin;
241
242        Ok(Some((rows, max_xmin)))
243    }
244
245    /// Get the estimated row count for changes since a given xmin.
246    ///
247    /// This uses EXPLAIN to estimate without actually scanning the table.
248    pub async fn estimate_changes(
249        &self,
250        schema: &str,
251        table: &str,
252        since_xmin: u32,
253    ) -> Result<i64> {
254        let query = format!(
255            "SELECT COUNT(*) FROM \"{}\".\"{}\" WHERE xmin::text::bigint > $1",
256            schema, table
257        );
258
259        let row = self
260            .client
261            .query_one(&query, &[&(since_xmin as i64)])
262            .await
263            .with_context(|| format!("Failed to count changes in {}.{}", schema, table))?;
264
265        let count: i64 = row.get(0);
266        Ok(count)
267    }
268
269    /// Get list of all tables in a schema.
270    pub async fn list_tables(&self, schema: &str) -> Result<Vec<String>> {
271        let rows = self
272            .client
273            .query(
274                "SELECT tablename FROM pg_tables WHERE schemaname = $1 ORDER BY tablename",
275                &[&schema],
276            )
277            .await
278            .with_context(|| format!("Failed to list tables in schema {}", schema))?;
279
280        Ok(rows.iter().map(|row| row.get(0)).collect())
281    }
282
283    /// Get column information for a table.
284    pub async fn get_columns(&self, schema: &str, table: &str) -> Result<Vec<ColumnInfo>> {
285        let rows = self
286            .client
287            .query(
288                "SELECT column_name, data_type, is_nullable, column_default
289                 FROM information_schema.columns
290                 WHERE table_schema = $1 AND table_name = $2
291                 ORDER BY ordinal_position",
292                &[&schema, &table],
293            )
294            .await
295            .with_context(|| format!("Failed to get columns for {}.{}", schema, table))?;
296
297        Ok(rows
298            .iter()
299            .map(|row| ColumnInfo {
300                name: row.get(0),
301                data_type: row.get(1),
302                is_nullable: row.get::<_, String>(2) == "YES",
303                has_default: row.get::<_, Option<String>>(3).is_some(),
304            })
305            .collect())
306    }
307
308    /// Get primary key columns for a table.
309    pub async fn get_primary_key(&self, schema: &str, table: &str) -> Result<Vec<String>> {
310        let rows = self
311            .client
312            .query(
313                "SELECT a.attname
314                 FROM pg_index i
315                 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
316                 JOIN pg_class c ON c.oid = i.indrelid
317                 JOIN pg_namespace n ON n.oid = c.relnamespace
318                 WHERE i.indisprimary
319                   AND n.nspname = $1
320                   AND c.relname = $2
321                 ORDER BY array_position(i.indkey, a.attnum)",
322                &[&schema, &table],
323            )
324            .await
325            .with_context(|| format!("Failed to get primary key for {}.{}", schema, table))?;
326
327        Ok(rows.iter().map(|row| row.get(0)).collect())
328    }
329
330    /// Read ALL rows from a table (full sync).
331    ///
332    /// This is used when xmin wraparound is detected and we need to resync
333    /// the entire table to ensure data consistency.
334    ///
335    /// # Arguments
336    ///
337    /// * `schema` - The schema name (e.g., "public")
338    /// * `table` - The table name
339    /// * `columns` - Column names to select (pass empty slice to select all)
340    ///
341    /// # Returns
342    ///
343    /// A tuple of (rows, max_xmin) where max_xmin is the highest xmin seen.
344    pub async fn read_all_rows(
345        &self,
346        schema: &str,
347        table: &str,
348        columns: &[String],
349    ) -> Result<(Vec<Row>, u32)> {
350        tracing::info!(
351            "Performing full table read for {}.{} (wraparound recovery)",
352            schema,
353            table
354        );
355
356        let column_list = if columns.is_empty() {
357            "*".to_string()
358        } else {
359            columns
360                .iter()
361                .map(|c| format!("\"{}\"", c))
362                .collect::<Vec<_>>()
363                .join(", ")
364        };
365
366        // Query ALL rows, including their xmin values
367        // Note: ORDER BY uses the casted value because xid type doesn't have ordering operators
368        let query = format!(
369            "SELECT {}, xmin::text::bigint as _xmin FROM \"{}\".\"{}\" ORDER BY xmin::text::bigint",
370            column_list, schema, table
371        );
372
373        let rows = self
374            .client
375            .query(&query, &[])
376            .await
377            .with_context(|| format!("Failed to read all rows from {}.{}", schema, table))?;
378
379        // Find the max xmin in the result set
380        let max_xmin = rows
381            .iter()
382            .map(|row| {
383                let xmin: i64 = row.get("_xmin");
384                (xmin & 0xFFFFFFFF) as u32
385            })
386            .max()
387            .unwrap_or(0);
388
389        tracing::info!(
390            "Full table read complete: {} rows, max_xmin={}",
391            rows.len(),
392            max_xmin
393        );
394
395        Ok((rows, max_xmin))
396    }
397
398    /// Check for wraparound and read changes accordingly.
399    ///
400    /// This is the recommended method for reading changes as it automatically
401    /// handles wraparound detection and triggers full table sync when needed.
402    ///
403    /// # Arguments
404    ///
405    /// * `schema` - The schema name
406    /// * `table` - The table name
407    /// * `columns` - Column names to select
408    /// * `since_xmin` - The last synced xmin value
409    ///
410    /// # Returns
411    ///
412    /// A tuple of (rows, max_xmin, was_full_sync) where was_full_sync indicates
413    /// if a full table sync was performed due to wraparound.
414    pub async fn read_changes_with_wraparound_check(
415        &self,
416        schema: &str,
417        table: &str,
418        columns: &[String],
419        since_xmin: u32,
420    ) -> Result<(Vec<Row>, u32, bool)> {
421        // Get current database xmin to check for wraparound
422        let current_xmin = self.get_current_xmin().await?;
423
424        // Check for wraparound
425        if detect_wraparound(since_xmin, current_xmin) == WraparoundCheck::WraparoundDetected {
426            // Wraparound detected - perform full table sync
427            let (rows, max_xmin) = self.read_all_rows(schema, table, columns).await?;
428            Ok((rows, max_xmin, true))
429        } else {
430            // Normal incremental sync
431            let (rows, max_xmin) = self
432                .read_changes(schema, table, columns, since_xmin)
433                .await?;
434            Ok((rows, max_xmin, false))
435        }
436    }
437}
438
/// Batch reader state for iterating over large result sets.
///
/// Created by `XminReader::read_changes_batched` and advanced by
/// `XminReader::fetch_batch`; it holds no connection of its own.
#[derive(Debug, Clone)]
pub struct BatchReader {
    /// Schema of the table being read.
    pub schema: String,
    /// Name of the table being read.
    pub table: String,
    /// Columns to select; an empty list selects every column.
    pub columns: Vec<String>,
    /// Cursor position: the next batch only contains rows with xmin greater than this.
    pub current_xmin: u32,
    /// Row limit applied to each batch query.
    pub batch_size: usize,
    /// Set once a batch query returned no rows or fewer rows than `batch_size`.
    pub exhausted: bool,
}
448
/// Metadata describing one column of a table, as reported by `information_schema.columns`.
#[derive(Clone, Debug)]
pub struct ColumnInfo {
    /// Column name.
    pub name: String,
    /// Data type name as reported by the database.
    pub data_type: String,
    /// `true` when the database reports the column as nullable ("YES").
    pub is_nullable: bool,
    /// `true` when the column has a default expression.
    pub has_default: bool,
}
457
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built `BatchReader` exposes exactly the values it was given.
    #[test]
    fn test_batch_reader_initial_state() {
        let reader = BatchReader {
            schema: "public".to_string(),
            table: "users".to_string(),
            columns: vec!["id".to_string(), "name".to_string()],
            current_xmin: 0,
            batch_size: 1000,
            exhausted: false,
        };

        assert_eq!(reader.schema, "public");
        assert_eq!(reader.table, "users");
        assert_eq!(reader.current_xmin, 0);
        assert!(!reader.exhausted);
    }

    /// `ColumnInfo` is a plain data holder.
    #[test]
    fn test_column_info() {
        let col = ColumnInfo {
            name: "id".to_string(),
            data_type: "integer".to_string(),
            is_nullable: false,
            has_default: true,
        };

        assert_eq!(col.name, "id");
        assert!(!col.is_nullable);
        assert!(col.has_default);
    }

    #[test]
    fn test_wraparound_detection_normal() {
        // Counter moved forward: no wraparound.
        assert_eq!(detect_wraparound(100, 200), WraparoundCheck::Normal);

        // Counter moved slightly backwards: within normal variation.
        assert_eq!(detect_wraparound(1000, 900), WraparoundCheck::Normal);

        // Both values low.
        assert_eq!(detect_wraparound(0, 100), WraparoundCheck::Normal);
    }

    #[test]
    fn test_wraparound_detection_wraparound() {
        // Old value near the top of the range (3.5B), current near 0:
        // delta is about 3.5B, well above the 2B threshold.
        assert_eq!(
            detect_wraparound(3_500_000_000, 100),
            WraparoundCheck::WraparoundDetected
        );

        // Old at 4B, current at 1M: delta is about 3.999B.
        assert_eq!(
            detect_wraparound(4_000_000_000, 1_000_000),
            WraparoundCheck::WraparoundDetected
        );

        // Delta of 2.1B: 100M above the threshold.
        assert_eq!(
            detect_wraparound(2_500_000_000, 400_000_000),
            WraparoundCheck::WraparoundDetected
        );
    }

    #[test]
    fn test_wraparound_detection_edge_cases() {
        // Old = 0 can never be ahead of current.
        assert_eq!(detect_wraparound(0, 1_000_000), WraparoundCheck::Normal);

        // Identical values: delta of zero.
        assert_eq!(detect_wraparound(1000, 1000), WraparoundCheck::Normal);

        // Delta exactly equal to the threshold (2_000_000_000): the comparison is
        // strict, so this is still normal.
        assert_eq!(detect_wraparound(2_000_000_001, 1), WraparoundCheck::Normal);

        // Delta one above the threshold (2_000_000_001): wraparound.
        assert_eq!(
            detect_wraparound(2_000_000_002, 1),
            WraparoundCheck::WraparoundDetected
        );
    }

    #[test]
    fn test_wraparound_detection_extremes() {
        // Largest possible backwards jump.
        assert_eq!(
            detect_wraparound(u32::MAX, 0),
            WraparoundCheck::WraparoundDetected
        );

        // Largest possible forward jump must not underflow or be misreported.
        assert_eq!(detect_wraparound(0, u32::MAX), WraparoundCheck::Normal);
    }
}