database_replicator/xmin/
reader.rs

1// ABOUTME: XminReader for xmin-based sync - reads changed rows from source PostgreSQL
2// ABOUTME: Uses xmin system column to detect rows modified since last sync
3
4use anyhow::{Context, Result};
5use tokio_postgres::{Client, Row};
6
/// Size of a backwards jump in xmin that is treated as wraparound.
///
/// Transaction IDs are 32 bits wide (roughly 4 billion values), so this is
/// about half of the range. When the previously recorded xmin exceeds the
/// current one by strictly more than this amount, the counter is assumed to
/// have wrapped rather than merely moved back a little.
const WRAPAROUND_THRESHOLD: u32 = 2_000_000_000;
11
/// Outcome of comparing a stored xmin against the database's current one.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WraparoundCheck {
    /// The counter has not wrapped; an incremental sync is safe.
    Normal,
    /// The counter wrapped since the stored value; a full table sync is required.
    WraparoundDetected,
}
20
21/// Detect if xmin wraparound has occurred.
22///
23/// PostgreSQL transaction IDs are 32-bit unsigned integers that wrap around
24/// after ~4 billion transactions. When this happens, new xmin values will be
25/// smaller than old ones by a large margin (> 2 billion).
26///
27/// # Arguments
28///
29/// * `old_xmin` - The previously recorded xmin value
30/// * `current_xmin` - The current database transaction ID
31///
32/// # Returns
33///
34/// `WraparoundCheck::WraparoundDetected` if wraparound occurred, `Normal` otherwise.
35pub fn detect_wraparound(old_xmin: u32, current_xmin: u32) -> WraparoundCheck {
36    // If current < old by more than half the 32-bit range, it's likely a wraparound
37    if old_xmin > current_xmin && (old_xmin - current_xmin) > WRAPAROUND_THRESHOLD {
38        tracing::warn!(
39            "xmin wraparound detected: old_xmin={}, current_xmin={}, delta={}",
40            old_xmin,
41            current_xmin,
42            old_xmin - current_xmin
43        );
44        WraparoundCheck::WraparoundDetected
45    } else {
46        WraparoundCheck::Normal
47    }
48}
49
/// Check that a string is a well-formed ctid literal of the form "(page,tuple)".
///
/// ctid is a PostgreSQL system column giving the physical location of a row:
/// a block (page) number followed by a tuple index within that block.
/// Examples: "(0,1)", "(123,45)", "(0,100)".
///
/// The value is later inlined into SQL text, so validation is deliberately
/// strict to prevent injection:
///
/// * both fields must consist solely of ASCII digits (no sign, no other
///   characters; `str::parse` alone would also accept a leading `+`),
/// * the page must fit in 32 bits and the tuple index in 16 bits, matching
///   the widths PostgreSQL uses for block numbers and tuple offsets.
///
/// Whitespace around the whole literal and around each field is tolerated.
fn is_valid_ctid(s: &str) -> bool {
    // Digits only; rejects empty fields, signs and anything non-numeric.
    fn is_decimal(field: &str) -> bool {
        !field.is_empty() && field.bytes().all(|b| b.is_ascii_digit())
    }

    let inner = match s
        .trim()
        .strip_prefix('(')
        .and_then(|rest| rest.strip_suffix(')'))
    {
        Some(inner) => inner,
        None => return false,
    };

    // Splitting at the first comma is enough: any additional comma ends up in
    // the tuple field and fails the digits-only check.
    let (page, tuple) = match inner.split_once(',') {
        Some((page, tuple)) => (page.trim(), tuple.trim()),
        None => return false,
    };

    is_decimal(page)
        && is_decimal(tuple)
        && page.parse::<u32>().is_ok()
        && tuple.parse::<u16>().is_ok()
}
70
71/// Reads changed rows from a PostgreSQL table using xmin-based change detection.
72///
73/// PostgreSQL's `xmin` system column contains the transaction ID that last modified
74/// each row. By tracking the maximum xmin seen, we can query for only rows that
75/// have been modified since the last sync.
76///
77/// **Warning:** xmin wraps around at 2^32 transactions. Use `detect_wraparound()`
78/// to check for this condition and trigger a full table sync when detected.
79pub struct XminReader<'a> {
80    client: &'a Client,
81}
82
83impl<'a> XminReader<'a> {
84    /// Create a new XminReader for the given PostgreSQL client connection.
85    pub fn new(client: &'a Client) -> Self {
86        Self { client }
87    }
88
89    /// Get the underlying database client.
90    pub fn client(&self) -> &Client {
91        self.client
92    }
93
94    /// Get the current transaction ID (xmin snapshot) from the database.
95    ///
96    /// This should be called at the start of a sync to establish the high-water mark.
97    pub async fn get_current_xmin(&self) -> Result<u32> {
98        let row = self
99            .client
100            .query_one("SELECT txid_current()::text::bigint", &[])
101            .await
102            .context("Failed to get current transaction ID")?;
103
104        let txid: i64 = row.get(0);
105        // xmin is stored as u32, txid_current() returns i64
106        // We mask to get the 32-bit xmin value
107        Ok((txid & 0xFFFFFFFF) as u32)
108    }
109
110    /// Read all rows from a table that have xmin greater than the given value.
111    ///
112    /// # Arguments
113    ///
114    /// * `schema` - The schema name (e.g., "public")
115    /// * `table` - The table name
116    /// * `columns` - Column names to select (pass empty slice to select all)
117    /// * `since_xmin` - Only return rows with xmin > this value (0 = all rows)
118    ///
119    /// # Returns
120    ///
121    /// A tuple of (rows, max_xmin) where max_xmin is the highest xmin seen in the result set.
122    pub async fn read_changes(
123        &self,
124        schema: &str,
125        table: &str,
126        columns: &[String],
127        since_xmin: u32,
128    ) -> Result<(Vec<Row>, u32)> {
129        let column_list = if columns.is_empty() {
130            "*".to_string()
131        } else {
132            columns
133                .iter()
134                .map(|c| format!("\"{}\"", c))
135                .collect::<Vec<_>>()
136                .join(", ")
137        };
138
139        // Query rows where xmin > since_xmin, including the xmin value
140        // Note: ORDER BY uses the casted value because xid type doesn't have ordering operators
141        let query = format!(
142            "SELECT {}, xmin::text::bigint as _xmin FROM \"{}\".\"{}\" WHERE xmin::text::bigint > $1 ORDER BY xmin::text::bigint",
143            column_list, schema, table
144        );
145
146        let rows = self
147            .client
148            .query(&query, &[&(since_xmin as i64)])
149            .await
150            .with_context(|| format!("Failed to read changes from {}.{}", schema, table))?;
151
152        // Find the max xmin in the result set
153        let max_xmin = rows
154            .iter()
155            .map(|row| {
156                let xmin: i64 = row.get("_xmin");
157                (xmin & 0xFFFFFFFF) as u32
158            })
159            .max()
160            .unwrap_or(since_xmin);
161
162        Ok((rows, max_xmin))
163    }
164
165    /// Read changes in batches to handle large tables efficiently.
166    ///
167    /// # Arguments
168    ///
169    /// * `schema` - The schema name
170    /// * `table` - The table name
171    /// * `columns` - Column names to select
172    /// * `since_xmin` - Only return rows with xmin > this value
173    /// * `batch_size` - Maximum rows per batch
174    ///
175    /// # Returns
176    ///
177    /// An iterator-like struct that yields batches of rows.
178    pub async fn read_changes_batched(
179        &self,
180        schema: &str,
181        table: &str,
182        columns: &[String],
183        since_xmin: u32,
184        batch_size: usize,
185    ) -> Result<BatchReader> {
186        Ok(BatchReader {
187            schema: schema.to_string(),
188            table: table.to_string(),
189            columns: columns.to_vec(),
190            current_xmin: since_xmin,
191            last_ctid: None,
192            batch_size,
193            exhausted: false,
194        })
195    }
196
197    /// Execute a batched read query and return the next batch.
198    ///
199    /// Uses (xmin, ctid) as the pagination key to correctly handle cases where
200    /// many rows share the same xmin (e.g., bulk inserts in a single transaction).
201    /// Without ctid tie-breaking, rows with duplicate xmin values would be skipped.
202    pub async fn fetch_batch(
203        &self,
204        batch_reader: &mut BatchReader,
205    ) -> Result<Option<(Vec<Row>, u32)>> {
206        if batch_reader.exhausted {
207            return Ok(None);
208        }
209
210        let column_list = if batch_reader.columns.is_empty() {
211            "*".to_string()
212        } else {
213            batch_reader
214                .columns
215                .iter()
216                .map(|c| format!("\"{}\"", c))
217                .collect::<Vec<_>>()
218                .join(", ")
219        };
220
221        // Use (xmin, ctid) as compound pagination key to handle duplicate xmin values.
222        // ctid is the physical tuple location and provides a stable tie-breaker.
223        let (query, rows) = if let Some(ref last_ctid) = batch_reader.last_ctid {
224            // Validate ctid format for safety before inlining in query.
225            // ctid format is "(page,tuple)" e.g., "(0,1)" or "(123,45)"
226            if !is_valid_ctid(last_ctid) {
227                anyhow::bail!("Invalid ctid format: {}", last_ctid);
228            }
229
230            // Subsequent batches: use compound (xmin, ctid) > ($1, 'ctid'::tid) filter
231            // Note: ctid must be inlined because tokio-postgres can't serialize String to tid type
232            let query = format!(
233                "SELECT {}, xmin::text::bigint as _xmin, ctid::text as _ctid \
234                 FROM \"{}\".\"{}\" \
235                 WHERE (xmin::text::bigint, ctid) > ($1, '{}'::tid) \
236                 ORDER BY xmin::text::bigint, ctid \
237                 LIMIT $2",
238                column_list, batch_reader.schema, batch_reader.table, last_ctid
239            );
240
241            let rows = self
242                .client
243                .query(
244                    &query,
245                    &[
246                        &(batch_reader.current_xmin as i64),
247                        &(batch_reader.batch_size as i64),
248                    ],
249                )
250                .await
251                .with_context(|| {
252                    format!(
253                        "Failed to read batch from {}.{}",
254                        batch_reader.schema, batch_reader.table
255                    )
256                })?;
257            (query, rows)
258        } else {
259            // First batch: simple xmin > $1 filter
260            let query = format!(
261                "SELECT {}, xmin::text::bigint as _xmin, ctid::text as _ctid \
262                 FROM \"{}\".\"{}\" \
263                 WHERE xmin::text::bigint > $1 \
264                 ORDER BY xmin::text::bigint, ctid \
265                 LIMIT $2",
266                column_list, batch_reader.schema, batch_reader.table
267            );
268
269            let rows = self
270                .client
271                .query(
272                    &query,
273                    &[
274                        &(batch_reader.current_xmin as i64),
275                        &(batch_reader.batch_size as i64),
276                    ],
277                )
278                .await
279                .with_context(|| {
280                    format!(
281                        "Failed to read batch from {}.{}",
282                        batch_reader.schema, batch_reader.table
283                    )
284                })?;
285            (query, rows)
286        };
287
288        // Suppress unused variable warning - query is useful for debugging
289        let _ = query;
290
291        if rows.is_empty() {
292            batch_reader.exhausted = true;
293            return Ok(None);
294        }
295
296        // Get xmin and ctid from the last row for next iteration's pagination
297        let last_row = rows.last().unwrap();
298        let last_xmin: i64 = last_row.get("_xmin");
299        let last_ctid: String = last_row.get("_ctid");
300
301        let max_xmin = (last_xmin & 0xFFFFFFFF) as u32;
302
303        // Mark as exhausted if we got fewer rows than batch_size
304        if rows.len() < batch_reader.batch_size {
305            batch_reader.exhausted = true;
306        }
307
308        batch_reader.current_xmin = max_xmin;
309        batch_reader.last_ctid = Some(last_ctid);
310
311        Ok(Some((rows, max_xmin)))
312    }
313
314    /// Get the estimated row count for changes since a given xmin.
315    ///
316    /// This uses EXPLAIN to estimate without actually scanning the table.
317    pub async fn estimate_changes(
318        &self,
319        schema: &str,
320        table: &str,
321        since_xmin: u32,
322    ) -> Result<i64> {
323        let query = format!(
324            "SELECT COUNT(*) FROM \"{}\".\"{}\" WHERE xmin::text::bigint > $1",
325            schema, table
326        );
327
328        let row = self
329            .client
330            .query_one(&query, &[&(since_xmin as i64)])
331            .await
332            .with_context(|| format!("Failed to count changes in {}.{}", schema, table))?;
333
334        let count: i64 = row.get(0);
335        Ok(count)
336    }
337
338    /// Get list of all tables in a schema.
339    pub async fn list_tables(&self, schema: &str) -> Result<Vec<String>> {
340        let rows = self
341            .client
342            .query(
343                "SELECT tablename FROM pg_tables WHERE schemaname = $1 ORDER BY tablename",
344                &[&schema],
345            )
346            .await
347            .with_context(|| format!("Failed to list tables in schema {}", schema))?;
348
349        Ok(rows.iter().map(|row| row.get(0)).collect())
350    }
351
352    /// Get column information for a table.
353    pub async fn get_columns(&self, schema: &str, table: &str) -> Result<Vec<ColumnInfo>> {
354        let rows = self
355            .client
356            .query(
357                "SELECT column_name, data_type, is_nullable, column_default
358                 FROM information_schema.columns
359                 WHERE table_schema = $1 AND table_name = $2
360                 ORDER BY ordinal_position",
361                &[&schema, &table],
362            )
363            .await
364            .with_context(|| format!("Failed to get columns for {}.{}", schema, table))?;
365
366        Ok(rows
367            .iter()
368            .map(|row| ColumnInfo {
369                name: row.get(0),
370                data_type: row.get(1),
371                is_nullable: row.get::<_, String>(2) == "YES",
372                has_default: row.get::<_, Option<String>>(3).is_some(),
373            })
374            .collect())
375    }
376
377    /// Get primary key columns for a table.
378    pub async fn get_primary_key(&self, schema: &str, table: &str) -> Result<Vec<String>> {
379        let rows = self
380            .client
381            .query(
382                "SELECT a.attname
383                 FROM pg_index i
384                 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
385                 JOIN pg_class c ON c.oid = i.indrelid
386                 JOIN pg_namespace n ON n.oid = c.relnamespace
387                 WHERE i.indisprimary
388                   AND n.nspname = $1
389                   AND c.relname = $2
390                 ORDER BY array_position(i.indkey, a.attnum)",
391                &[&schema, &table],
392            )
393            .await
394            .with_context(|| format!("Failed to get primary key for {}.{}", schema, table))?;
395
396        Ok(rows.iter().map(|row| row.get(0)).collect())
397    }
398
399    /// Read ALL rows from a table (full sync).
400    ///
401    /// This is used when xmin wraparound is detected and we need to resync
402    /// the entire table to ensure data consistency.
403    ///
404    /// # Arguments
405    ///
406    /// * `schema` - The schema name (e.g., "public")
407    /// * `table` - The table name
408    /// * `columns` - Column names to select (pass empty slice to select all)
409    ///
410    /// # Returns
411    ///
412    /// A tuple of (rows, max_xmin) where max_xmin is the highest xmin seen.
413    pub async fn read_all_rows(
414        &self,
415        schema: &str,
416        table: &str,
417        columns: &[String],
418    ) -> Result<(Vec<Row>, u32)> {
419        tracing::info!(
420            "Performing full table read for {}.{} (wraparound recovery)",
421            schema,
422            table
423        );
424
425        let column_list = if columns.is_empty() {
426            "*".to_string()
427        } else {
428            columns
429                .iter()
430                .map(|c| format!("\"{}\"", c))
431                .collect::<Vec<_>>()
432                .join(", ")
433        };
434
435        // Query ALL rows, including their xmin values
436        // Note: ORDER BY uses the casted value because xid type doesn't have ordering operators
437        let query = format!(
438            "SELECT {}, xmin::text::bigint as _xmin FROM \"{}\".\"{}\" ORDER BY xmin::text::bigint",
439            column_list, schema, table
440        );
441
442        let rows = self
443            .client
444            .query(&query, &[])
445            .await
446            .with_context(|| format!("Failed to read all rows from {}.{}", schema, table))?;
447
448        // Find the max xmin in the result set
449        let max_xmin = rows
450            .iter()
451            .map(|row| {
452                let xmin: i64 = row.get("_xmin");
453                (xmin & 0xFFFFFFFF) as u32
454            })
455            .max()
456            .unwrap_or(0);
457
458        tracing::info!(
459            "Full table read complete: {} rows, max_xmin={}",
460            rows.len(),
461            max_xmin
462        );
463
464        Ok((rows, max_xmin))
465    }
466
467    /// Check for wraparound and read changes accordingly.
468    ///
469    /// This is the recommended method for reading changes as it automatically
470    /// handles wraparound detection and triggers full table sync when needed.
471    ///
472    /// # Arguments
473    ///
474    /// * `schema` - The schema name
475    /// * `table` - The table name
476    /// * `columns` - Column names to select
477    /// * `since_xmin` - The last synced xmin value
478    ///
479    /// # Returns
480    ///
481    /// A tuple of (rows, max_xmin, was_full_sync) where was_full_sync indicates
482    /// if a full table sync was performed due to wraparound.
483    pub async fn read_changes_with_wraparound_check(
484        &self,
485        schema: &str,
486        table: &str,
487        columns: &[String],
488        since_xmin: u32,
489    ) -> Result<(Vec<Row>, u32, bool)> {
490        // Get current database xmin to check for wraparound
491        let current_xmin = self.get_current_xmin().await?;
492
493        // Check for wraparound
494        if detect_wraparound(since_xmin, current_xmin) == WraparoundCheck::WraparoundDetected {
495            // Wraparound detected - perform full table sync
496            let (rows, max_xmin) = self.read_all_rows(schema, table, columns).await?;
497            Ok((rows, max_xmin, true))
498        } else {
499            // Normal incremental sync
500            let (rows, max_xmin) = self
501                .read_changes(schema, table, columns, since_xmin)
502                .await?;
503            Ok((rows, max_xmin, false))
504        }
505    }
506}
507
/// Batch reader state for iterating over large result sets.
///
/// Uses (xmin, ctid) as the pagination key to handle cases where many rows
/// share the same xmin (e.g., bulk inserts in a single transaction).
#[derive(Debug, Clone)]
pub struct BatchReader {
    /// Schema of the table being read.
    pub schema: String,
    /// Name of the table being read.
    pub table: String,
    /// Columns to select; an empty list selects all columns.
    pub columns: Vec<String>,
    /// xmin part of the pagination key: initially the sync's starting xmin,
    /// afterwards the xmin of the last row returned.
    pub current_xmin: u32,
    /// Last seen ctid for tie-breaking when multiple rows have same xmin.
    /// Format: "(page,tuple)" e.g., "(0,1)". `None` until a batch was read.
    pub last_ctid: Option<String>,
    /// Maximum number of rows per batch.
    pub batch_size: usize,
    /// Set once no further batches are available.
    pub exhausted: bool,
}
523
/// Metadata describing a single column of a table.
#[derive(Clone, Debug)]
pub struct ColumnInfo {
    /// Column name.
    pub name: String,
    /// Data type name as reported by the database.
    pub data_type: String,
    /// Whether the column accepts NULL values.
    pub is_nullable: bool,
    /// Whether a default expression is defined for the column.
    pub has_default: bool,
}
532
// Unit tests for the pure helpers and plain data types in this module;
// nothing here needs a database connection.
#[cfg(test)]
mod tests {
    use super::*;

    // A hand-built BatchReader exposes exactly the state it was given.
    #[test]
    fn test_batch_reader_initial_state() {
        let reader = BatchReader {
            schema: "public".to_string(),
            table: "users".to_string(),
            columns: vec!["id".to_string(), "name".to_string()],
            current_xmin: 0,
            last_ctid: None,
            batch_size: 1000,
            exhausted: false,
        };

        assert_eq!(reader.schema, "public");
        assert_eq!(reader.table, "users");
        assert_eq!(reader.current_xmin, 0);
        assert!(reader.last_ctid.is_none());
        assert!(!reader.exhausted);
    }

    // ColumnInfo is a plain record; fields read back as written.
    #[test]
    fn test_column_info() {
        let col = ColumnInfo {
            name: "id".to_string(),
            data_type: "integer".to_string(),
            is_nullable: false,
            has_default: true,
        };

        assert_eq!(col.name, "id");
        assert!(!col.is_nullable);
        assert!(col.has_default);
    }

    #[test]
    fn test_wraparound_detection_normal() {
        // Normal case: current > old (no wraparound)
        assert_eq!(detect_wraparound(100, 200), WraparoundCheck::Normal);

        // Normal case: current slightly less than old (normal variation)
        assert_eq!(detect_wraparound(1000, 900), WraparoundCheck::Normal);

        // Normal case: both at low values
        assert_eq!(detect_wraparound(0, 100), WraparoundCheck::Normal);
    }

    #[test]
    fn test_wraparound_detection_wraparound() {
        // Wraparound case: old is near max (3.5B), current is near 0
        // Delta = 3.5B - 100 = 3.5B > 2B threshold
        assert_eq!(
            detect_wraparound(3_500_000_000, 100),
            WraparoundCheck::WraparoundDetected
        );

        // Wraparound case: old at 4B, current at 1M
        assert_eq!(
            detect_wraparound(4_000_000_000, 1_000_000),
            WraparoundCheck::WraparoundDetected
        );

        // Moderately over the threshold: delta = 2.5B - 0.4B = 2.1B > 2B
        assert_eq!(
            detect_wraparound(2_500_000_000, 400_000_000),
            WraparoundCheck::WraparoundDetected
        );
    }

    #[test]
    fn test_wraparound_detection_edge_cases() {
        // Edge case: old = 0, current = anything (should be normal)
        assert_eq!(detect_wraparound(0, 1_000_000), WraparoundCheck::Normal);

        // Edge case: same values
        assert_eq!(detect_wraparound(1000, 1000), WraparoundCheck::Normal);

        // Edge case: delta is exactly the threshold (2_000_000_000); the
        // comparison is strict, so this is still normal
        assert_eq!(detect_wraparound(2_000_000_001, 1), WraparoundCheck::Normal);

        // Edge case: delta is threshold + 1, the smallest value that is detected
        assert_eq!(
            detect_wraparound(2_000_000_002, 1),
            WraparoundCheck::WraparoundDetected
        );
    }

    #[test]
    fn test_is_valid_ctid() {
        // Valid ctid formats
        assert!(is_valid_ctid("(0,1)"));
        assert!(is_valid_ctid("(123,45)"));
        assert!(is_valid_ctid("(0,100)"));
        assert!(is_valid_ctid("(999999,65535)"));
        assert!(is_valid_ctid(" (0,1) ")); // Whitespace trimmed

        // Invalid formats
        assert!(!is_valid_ctid("0,1")); // Missing parentheses
        assert!(!is_valid_ctid("(0,1")); // Missing closing paren
        assert!(!is_valid_ctid("0,1)")); // Missing opening paren
        assert!(!is_valid_ctid("(0)")); // Missing tuple index
        assert!(!is_valid_ctid("(0,1,2)")); // Too many parts
        assert!(!is_valid_ctid("(a,1)")); // Non-numeric page
        assert!(!is_valid_ctid("(0,b)")); // Non-numeric tuple
        assert!(!is_valid_ctid("")); // Empty string
        assert!(!is_valid_ctid("()")); // Empty parens
        assert!(!is_valid_ctid("(-1,1)")); // Negative page (fails unsigned parsing)
    }
}