// database_replicator/xmin/reader.rs

// ABOUTME: XminReader for xmin-based sync - reads changed rows from the source PostgreSQL database
// ABOUTME: Uses the xmin system column to detect rows modified since the last sync

4use anyhow::{Context, Result};
5use tokio_postgres::{Client, Row};

/// Size of a backwards xmin jump above which we assume a 32-bit wraparound
/// happened, rather than ordinary variation between observations.
///
/// xmin is a 32-bit transaction ID (the full `u32` range, about 4.29 billion
/// values), so two billion is roughly half of that space.
const WRAPAROUND_THRESHOLD: u32 = 2 * 1_000_000_000;

/// Outcome of `detect_wraparound`: whether an incremental sync may continue.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WraparoundCheck {
    /// xmin did not wrap; an incremental (xmin > last) sync is safe.
    Normal,
    /// xmin wrapped around; the whole table must be re-synced.
    WraparoundDetected,
}

21/// Detect if xmin wraparound has occurred.
22///
23/// PostgreSQL transaction IDs are 32-bit unsigned integers that wrap around
24/// after ~4 billion transactions. When this happens, new xmin values will be
25/// smaller than old ones by a large margin (> 2 billion).
26///
27/// # Arguments
28///
29/// * `old_xmin` - The previously recorded xmin value
30/// * `current_xmin` - The current database transaction ID
31///
32/// # Returns
33///
34/// `WraparoundCheck::WraparoundDetected` if wraparound occurred, `Normal` otherwise.
35pub fn detect_wraparound(old_xmin: u32, current_xmin: u32) -> WraparoundCheck {
36    // If current < old by more than half the 32-bit range, it's likely a wraparound
37    if old_xmin > current_xmin && (old_xmin - current_xmin) > WRAPAROUND_THRESHOLD {
38        tracing::warn!(
39            "xmin wraparound detected: old_xmin={}, current_xmin={}, delta={}",
40            old_xmin,
41            current_xmin,
42            old_xmin - current_xmin
43        );
44        WraparoundCheck::WraparoundDetected
45    } else {
46        WraparoundCheck::Normal
47    }
48}
49
50/// Reads changed rows from a PostgreSQL table using xmin-based change detection.
51///
52/// PostgreSQL's `xmin` system column contains the transaction ID that last modified
53/// each row. By tracking the maximum xmin seen, we can query for only rows that
54/// have been modified since the last sync.
55///
56/// **Warning:** xmin wraps around at 2^32 transactions. Use `detect_wraparound()`
57/// to check for this condition and trigger a full table sync when detected.
58pub struct XminReader<'a> {
59    client: &'a Client,
60}
61
62impl<'a> XminReader<'a> {
63    /// Create a new XminReader for the given PostgreSQL client connection.
64    pub fn new(client: &'a Client) -> Self {
65        Self { client }
66    }
67
68    /// Get the underlying database client.
69    pub fn client(&self) -> &Client {
70        self.client
71    }
72
73    /// Get the current transaction ID (xmin snapshot) from the database.
74    ///
75    /// This should be called at the start of a sync to establish the high-water mark.
76    pub async fn get_current_xmin(&self) -> Result<u32> {
77        let row = self
78            .client
79            .query_one("SELECT txid_current()::text::bigint", &[])
80            .await
81            .context("Failed to get current transaction ID")?;
82
83        let txid: i64 = row.get(0);
84        // xmin is stored as u32, txid_current() returns i64
85        // We mask to get the 32-bit xmin value
86        Ok((txid & 0xFFFFFFFF) as u32)
87    }
88
89    /// Read all rows from a table that have xmin greater than the given value.
90    ///
91    /// # Arguments
92    ///
93    /// * `schema` - The schema name (e.g., "public")
94    /// * `table` - The table name
95    /// * `columns` - Column names to select (pass empty slice to select all)
96    /// * `since_xmin` - Only return rows with xmin > this value (0 = all rows)
97    ///
98    /// # Returns
99    ///
100    /// A tuple of (rows, max_xmin) where max_xmin is the highest xmin seen in the result set.
101    pub async fn read_changes(
102        &self,
103        schema: &str,
104        table: &str,
105        columns: &[String],
106        since_xmin: u32,
107    ) -> Result<(Vec<Row>, u32)> {
108        let column_list = if columns.is_empty() {
109            "*".to_string()
110        } else {
111            columns
112                .iter()
113                .map(|c| format!("\"{}\"", c))
114                .collect::<Vec<_>>()
115                .join(", ")
116        };
117
118        // Query rows where xmin > since_xmin, including the xmin value
119        // Note: ORDER BY uses the casted value because xid type doesn't have ordering operators
120        let query = format!(
121            "SELECT {}, xmin::text::bigint as _xmin FROM \"{}\".\"{}\" WHERE xmin::text::bigint > $1 ORDER BY xmin::text::bigint",
122            column_list, schema, table
123        );
124
125        let rows = self
126            .client
127            .query(&query, &[&(since_xmin as i64)])
128            .await
129            .with_context(|| format!("Failed to read changes from {}.{}", schema, table))?;
130
131        // Find the max xmin in the result set
132        let max_xmin = rows
133            .iter()
134            .map(|row| {
135                let xmin: i64 = row.get("_xmin");
136                (xmin & 0xFFFFFFFF) as u32
137            })
138            .max()
139            .unwrap_or(since_xmin);
140
141        Ok((rows, max_xmin))
142    }
143
144    /// Read changes in batches to handle large tables efficiently.
145    ///
146    /// # Arguments
147    ///
148    /// * `schema` - The schema name
149    /// * `table` - The table name
150    /// * `columns` - Column names to select
151    /// * `since_xmin` - Only return rows with xmin > this value
152    /// * `batch_size` - Maximum rows per batch
153    ///
154    /// # Returns
155    ///
156    /// An iterator-like struct that yields batches of rows.
157    pub async fn read_changes_batched(
158        &self,
159        schema: &str,
160        table: &str,
161        columns: &[String],
162        since_xmin: u32,
163        batch_size: usize,
164    ) -> Result<BatchReader> {
165        Ok(BatchReader {
166            schema: schema.to_string(),
167            table: table.to_string(),
168            columns: columns.to_vec(),
169            current_xmin: since_xmin,
170            last_ctid: None,
171            batch_size,
172            exhausted: false,
173        })
174    }
175
176    /// Execute a batched read query and return the next batch.
177    ///
178    /// Uses (xmin, ctid) as the pagination key to correctly handle cases where
179    /// many rows share the same xmin (e.g., bulk inserts in a single transaction).
180    /// Without ctid tie-breaking, rows with duplicate xmin values would be skipped.
181    pub async fn fetch_batch(
182        &self,
183        batch_reader: &mut BatchReader,
184    ) -> Result<Option<(Vec<Row>, u32)>> {
185        if batch_reader.exhausted {
186            return Ok(None);
187        }
188
189        let column_list = if batch_reader.columns.is_empty() {
190            "*".to_string()
191        } else {
192            batch_reader
193                .columns
194                .iter()
195                .map(|c| format!("\"{}\"", c))
196                .collect::<Vec<_>>()
197                .join(", ")
198        };
199
200        // Use (xmin, ctid) as compound pagination key to handle duplicate xmin values.
201        // ctid is the physical tuple location and provides a stable tie-breaker.
202        let (query, rows) = if let Some(ref last_ctid) = batch_reader.last_ctid {
203            // Subsequent batches: use compound (xmin, ctid) > ($1, $2) filter
204            let query = format!(
205                "SELECT {}, xmin::text::bigint as _xmin, ctid::text as _ctid \
206                 FROM \"{}\".\"{}\" \
207                 WHERE (xmin::text::bigint, ctid) > ($1, $2::tid) \
208                 ORDER BY xmin::text::bigint, ctid \
209                 LIMIT $3",
210                column_list, batch_reader.schema, batch_reader.table
211            );
212
213            let rows = self
214                .client
215                .query(
216                    &query,
217                    &[
218                        &(batch_reader.current_xmin as i64),
219                        &last_ctid,
220                        &(batch_reader.batch_size as i64),
221                    ],
222                )
223                .await
224                .with_context(|| {
225                    format!(
226                        "Failed to read batch from {}.{}",
227                        batch_reader.schema, batch_reader.table
228                    )
229                })?;
230            (query, rows)
231        } else {
232            // First batch: simple xmin > $1 filter
233            let query = format!(
234                "SELECT {}, xmin::text::bigint as _xmin, ctid::text as _ctid \
235                 FROM \"{}\".\"{}\" \
236                 WHERE xmin::text::bigint > $1 \
237                 ORDER BY xmin::text::bigint, ctid \
238                 LIMIT $2",
239                column_list, batch_reader.schema, batch_reader.table
240            );
241
242            let rows = self
243                .client
244                .query(
245                    &query,
246                    &[
247                        &(batch_reader.current_xmin as i64),
248                        &(batch_reader.batch_size as i64),
249                    ],
250                )
251                .await
252                .with_context(|| {
253                    format!(
254                        "Failed to read batch from {}.{}",
255                        batch_reader.schema, batch_reader.table
256                    )
257                })?;
258            (query, rows)
259        };
260
261        // Suppress unused variable warning - query is useful for debugging
262        let _ = query;
263
264        if rows.is_empty() {
265            batch_reader.exhausted = true;
266            return Ok(None);
267        }
268
269        // Get xmin and ctid from the last row for next iteration's pagination
270        let last_row = rows.last().unwrap();
271        let last_xmin: i64 = last_row.get("_xmin");
272        let last_ctid: String = last_row.get("_ctid");
273
274        let max_xmin = (last_xmin & 0xFFFFFFFF) as u32;
275
276        // Mark as exhausted if we got fewer rows than batch_size
277        if rows.len() < batch_reader.batch_size {
278            batch_reader.exhausted = true;
279        }
280
281        batch_reader.current_xmin = max_xmin;
282        batch_reader.last_ctid = Some(last_ctid);
283
284        Ok(Some((rows, max_xmin)))
285    }
286
287    /// Get the estimated row count for changes since a given xmin.
288    ///
289    /// This uses EXPLAIN to estimate without actually scanning the table.
290    pub async fn estimate_changes(
291        &self,
292        schema: &str,
293        table: &str,
294        since_xmin: u32,
295    ) -> Result<i64> {
296        let query = format!(
297            "SELECT COUNT(*) FROM \"{}\".\"{}\" WHERE xmin::text::bigint > $1",
298            schema, table
299        );
300
301        let row = self
302            .client
303            .query_one(&query, &[&(since_xmin as i64)])
304            .await
305            .with_context(|| format!("Failed to count changes in {}.{}", schema, table))?;
306
307        let count: i64 = row.get(0);
308        Ok(count)
309    }
310
311    /// Get list of all tables in a schema.
312    pub async fn list_tables(&self, schema: &str) -> Result<Vec<String>> {
313        let rows = self
314            .client
315            .query(
316                "SELECT tablename FROM pg_tables WHERE schemaname = $1 ORDER BY tablename",
317                &[&schema],
318            )
319            .await
320            .with_context(|| format!("Failed to list tables in schema {}", schema))?;
321
322        Ok(rows.iter().map(|row| row.get(0)).collect())
323    }
324
325    /// Get column information for a table.
326    pub async fn get_columns(&self, schema: &str, table: &str) -> Result<Vec<ColumnInfo>> {
327        let rows = self
328            .client
329            .query(
330                "SELECT column_name, data_type, is_nullable, column_default
331                 FROM information_schema.columns
332                 WHERE table_schema = $1 AND table_name = $2
333                 ORDER BY ordinal_position",
334                &[&schema, &table],
335            )
336            .await
337            .with_context(|| format!("Failed to get columns for {}.{}", schema, table))?;
338
339        Ok(rows
340            .iter()
341            .map(|row| ColumnInfo {
342                name: row.get(0),
343                data_type: row.get(1),
344                is_nullable: row.get::<_, String>(2) == "YES",
345                has_default: row.get::<_, Option<String>>(3).is_some(),
346            })
347            .collect())
348    }
349
350    /// Get primary key columns for a table.
351    pub async fn get_primary_key(&self, schema: &str, table: &str) -> Result<Vec<String>> {
352        let rows = self
353            .client
354            .query(
355                "SELECT a.attname
356                 FROM pg_index i
357                 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
358                 JOIN pg_class c ON c.oid = i.indrelid
359                 JOIN pg_namespace n ON n.oid = c.relnamespace
360                 WHERE i.indisprimary
361                   AND n.nspname = $1
362                   AND c.relname = $2
363                 ORDER BY array_position(i.indkey, a.attnum)",
364                &[&schema, &table],
365            )
366            .await
367            .with_context(|| format!("Failed to get primary key for {}.{}", schema, table))?;
368
369        Ok(rows.iter().map(|row| row.get(0)).collect())
370    }
371
372    /// Read ALL rows from a table (full sync).
373    ///
374    /// This is used when xmin wraparound is detected and we need to resync
375    /// the entire table to ensure data consistency.
376    ///
377    /// # Arguments
378    ///
379    /// * `schema` - The schema name (e.g., "public")
380    /// * `table` - The table name
381    /// * `columns` - Column names to select (pass empty slice to select all)
382    ///
383    /// # Returns
384    ///
385    /// A tuple of (rows, max_xmin) where max_xmin is the highest xmin seen.
386    pub async fn read_all_rows(
387        &self,
388        schema: &str,
389        table: &str,
390        columns: &[String],
391    ) -> Result<(Vec<Row>, u32)> {
392        tracing::info!(
393            "Performing full table read for {}.{} (wraparound recovery)",
394            schema,
395            table
396        );
397
398        let column_list = if columns.is_empty() {
399            "*".to_string()
400        } else {
401            columns
402                .iter()
403                .map(|c| format!("\"{}\"", c))
404                .collect::<Vec<_>>()
405                .join(", ")
406        };
407
408        // Query ALL rows, including their xmin values
409        // Note: ORDER BY uses the casted value because xid type doesn't have ordering operators
410        let query = format!(
411            "SELECT {}, xmin::text::bigint as _xmin FROM \"{}\".\"{}\" ORDER BY xmin::text::bigint",
412            column_list, schema, table
413        );
414
415        let rows = self
416            .client
417            .query(&query, &[])
418            .await
419            .with_context(|| format!("Failed to read all rows from {}.{}", schema, table))?;
420
421        // Find the max xmin in the result set
422        let max_xmin = rows
423            .iter()
424            .map(|row| {
425                let xmin: i64 = row.get("_xmin");
426                (xmin & 0xFFFFFFFF) as u32
427            })
428            .max()
429            .unwrap_or(0);
430
431        tracing::info!(
432            "Full table read complete: {} rows, max_xmin={}",
433            rows.len(),
434            max_xmin
435        );
436
437        Ok((rows, max_xmin))
438    }
439
440    /// Check for wraparound and read changes accordingly.
441    ///
442    /// This is the recommended method for reading changes as it automatically
443    /// handles wraparound detection and triggers full table sync when needed.
444    ///
445    /// # Arguments
446    ///
447    /// * `schema` - The schema name
448    /// * `table` - The table name
449    /// * `columns` - Column names to select
450    /// * `since_xmin` - The last synced xmin value
451    ///
452    /// # Returns
453    ///
454    /// A tuple of (rows, max_xmin, was_full_sync) where was_full_sync indicates
455    /// if a full table sync was performed due to wraparound.
456    pub async fn read_changes_with_wraparound_check(
457        &self,
458        schema: &str,
459        table: &str,
460        columns: &[String],
461        since_xmin: u32,
462    ) -> Result<(Vec<Row>, u32, bool)> {
463        // Get current database xmin to check for wraparound
464        let current_xmin = self.get_current_xmin().await?;
465
466        // Check for wraparound
467        if detect_wraparound(since_xmin, current_xmin) == WraparoundCheck::WraparoundDetected {
468            // Wraparound detected - perform full table sync
469            let (rows, max_xmin) = self.read_all_rows(schema, table, columns).await?;
470            Ok((rows, max_xmin, true))
471        } else {
472            // Normal incremental sync
473            let (rows, max_xmin) = self
474                .read_changes(schema, table, columns, since_xmin)
475                .await?;
476            Ok((rows, max_xmin, false))
477        }
478    }
479}
480
/// Cursor state for paging through changed rows in batches.
///
/// Created by `XminReader::read_changes_batched` and advanced by
/// `XminReader::fetch_batch`. Uses (xmin, ctid) as the pagination key to handle
/// cases where many rows share the same xmin (e.g., bulk inserts in a single
/// transaction).
#[derive(Debug, Clone)]
pub struct BatchReader {
    /// Schema of the table being read.
    pub schema: String,
    /// Name of the table being read.
    pub table: String,
    /// Columns to select; an empty list selects all columns.
    pub columns: Vec<String>,
    /// xmin of the last row returned so far (initially the starting `since_xmin`).
    pub current_xmin: u32,
    /// Last seen ctid for tie-breaking when multiple rows have same xmin.
    /// Format: "(page,tuple)" e.g., "(0,1)". `None` until a batch has been fetched.
    pub last_ctid: Option<String>,
    /// Maximum number of rows requested per batch.
    pub batch_size: usize,
    /// Set once no further rows remain; later fetches return `None`.
    pub exhausted: bool,
}

/// Metadata describing one column of a table.
#[derive(Clone, Debug)]
pub struct ColumnInfo {
    /// Column name.
    pub name: String,
    /// SQL data type name, as a string.
    pub data_type: String,
    /// Whether the column accepts NULL values.
    pub is_nullable: bool,
    /// Whether the column has a default expression.
    pub has_default: bool,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_reader_initial_state() {
        let reader = BatchReader {
            schema: "public".to_string(),
            table: "users".to_string(),
            columns: vec!["id".to_string(), "name".to_string()],
            current_xmin: 0,
            last_ctid: None,
            batch_size: 1000,
            exhausted: false,
        };

        assert_eq!(reader.schema, "public");
        assert_eq!(reader.table, "users");
        assert_eq!(reader.current_xmin, 0);
        assert!(reader.last_ctid.is_none());
        assert!(!reader.exhausted);
    }

    #[test]
    fn test_column_info() {
        let col = ColumnInfo {
            name: "id".to_string(),
            data_type: "integer".to_string(),
            is_nullable: false,
            has_default: true,
        };

        assert_eq!(col.name, "id");
        assert!(!col.is_nullable);
        assert!(col.has_default);
    }

    #[test]
    fn test_wraparound_detection_normal() {
        // current > old: xmin moved forward, no wraparound
        assert_eq!(detect_wraparound(100, 200), WraparoundCheck::Normal);

        // current slightly below old: a small backwards step is not treated as wraparound
        assert_eq!(detect_wraparound(1000, 900), WraparoundCheck::Normal);

        // both values small
        assert_eq!(detect_wraparound(0, 100), WraparoundCheck::Normal);
    }

    #[test]
    fn test_wraparound_detection_wraparound() {
        // old near the top of the range (3.5B), current near 0:
        // delta = 3_499_999_900 > 2B threshold
        assert_eq!(
            detect_wraparound(3_500_000_000, 100),
            WraparoundCheck::WraparoundDetected
        );

        // old at 4B, current at 1M
        assert_eq!(
            detect_wraparound(4_000_000_000, 1_000_000),
            WraparoundCheck::WraparoundDetected
        );

        // delta = 2_100_000_000, somewhat above the 2B threshold
        assert_eq!(
            detect_wraparound(2_500_000_000, 400_000_000),
            WraparoundCheck::WraparoundDetected
        );

        // largest possible backwards jump
        assert_eq!(
            detect_wraparound(u32::MAX, 0),
            WraparoundCheck::WraparoundDetected
        );
    }

    #[test]
    fn test_wraparound_detection_edge_cases() {
        // old = 0: current can never be below it
        assert_eq!(detect_wraparound(0, 1_000_000), WraparoundCheck::Normal);

        // same values: delta is 0
        assert_eq!(detect_wraparound(1000, 1000), WraparoundCheck::Normal);

        // delta exactly equal to the threshold (2_000_000_000) is NOT a wraparound:
        // the comparison is strictly greater-than
        assert_eq!(detect_wraparound(2_000_000_001, 1), WraparoundCheck::Normal);
        assert_eq!(
            detect_wraparound(WRAPAROUND_THRESHOLD, 0),
            WraparoundCheck::Normal
        );

        // delta one above the threshold
        assert_eq!(
            detect_wraparound(2_000_000_002, 1),
            WraparoundCheck::WraparoundDetected
        );
        assert_eq!(
            detect_wraparound(WRAPAROUND_THRESHOLD + 1, 0),
            WraparoundCheck::WraparoundDetected
        );
    }
}