1use anyhow::{Context, Result};
5use tokio_postgres::{Client, Row};
6
/// Largest amount by which a previously stored xmin may exceed the current xmin while
/// [`detect_wraparound`] still reports [`WraparoundCheck::Normal`]; any strictly larger
/// backwards gap is reported as a wraparound.
const WRAPAROUND_THRESHOLD: u32 = 2 * 1_000_000_000;
11
/// Result of comparing a previously stored xmin with the current one in [`detect_wraparound`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WraparoundCheck {
    /// The stored xmin is still usable as a lower bound for an incremental read.
    Normal,
    /// The stored xmin exceeds the current xmin by more than the wraparound threshold; the
    /// reader in this module responds by re-reading the whole table.
    WraparoundDetected,
}
20
21pub fn detect_wraparound(old_xmin: u32, current_xmin: u32) -> WraparoundCheck {
36 if old_xmin > current_xmin && (old_xmin - current_xmin) > WRAPAROUND_THRESHOLD {
38 tracing::warn!(
39 "xmin wraparound detected: old_xmin={}, current_xmin={}, delta={}",
40 old_xmin,
41 current_xmin,
42 old_xmin - current_xmin
43 );
44 WraparoundCheck::WraparoundDetected
45 } else {
46 WraparoundCheck::Normal
47 }
48}
49
50pub struct XminReader<'a> {
59 client: &'a Client,
60}
61
62impl<'a> XminReader<'a> {
63 pub fn new(client: &'a Client) -> Self {
65 Self { client }
66 }
67
68 pub fn client(&self) -> &Client {
70 self.client
71 }
72
73 pub async fn get_current_xmin(&self) -> Result<u32> {
77 let row = self
78 .client
79 .query_one("SELECT txid_current()::text::bigint", &[])
80 .await
81 .context("Failed to get current transaction ID")?;
82
83 let txid: i64 = row.get(0);
84 Ok((txid & 0xFFFFFFFF) as u32)
87 }
88
89 pub async fn read_changes(
102 &self,
103 schema: &str,
104 table: &str,
105 columns: &[String],
106 since_xmin: u32,
107 ) -> Result<(Vec<Row>, u32)> {
108 let column_list = if columns.is_empty() {
109 "*".to_string()
110 } else {
111 columns
112 .iter()
113 .map(|c| format!("\"{}\"", c))
114 .collect::<Vec<_>>()
115 .join(", ")
116 };
117
118 let query = format!(
121 "SELECT {}, xmin::text::bigint as _xmin FROM \"{}\".\"{}\" WHERE xmin::text::bigint > $1 ORDER BY xmin::text::bigint",
122 column_list, schema, table
123 );
124
125 let rows = self
126 .client
127 .query(&query, &[&(since_xmin as i64)])
128 .await
129 .with_context(|| format!("Failed to read changes from {}.{}", schema, table))?;
130
131 let max_xmin = rows
133 .iter()
134 .map(|row| {
135 let xmin: i64 = row.get("_xmin");
136 (xmin & 0xFFFFFFFF) as u32
137 })
138 .max()
139 .unwrap_or(since_xmin);
140
141 Ok((rows, max_xmin))
142 }
143
144 pub async fn read_changes_batched(
158 &self,
159 schema: &str,
160 table: &str,
161 columns: &[String],
162 since_xmin: u32,
163 batch_size: usize,
164 ) -> Result<BatchReader> {
165 Ok(BatchReader {
166 schema: schema.to_string(),
167 table: table.to_string(),
168 columns: columns.to_vec(),
169 current_xmin: since_xmin,
170 batch_size,
171 exhausted: false,
172 })
173 }
174
175 pub async fn fetch_batch(
177 &self,
178 batch_reader: &mut BatchReader,
179 ) -> Result<Option<(Vec<Row>, u32)>> {
180 if batch_reader.exhausted {
181 return Ok(None);
182 }
183
184 let column_list = if batch_reader.columns.is_empty() {
185 "*".to_string()
186 } else {
187 batch_reader
188 .columns
189 .iter()
190 .map(|c| format!("\"{}\"", c))
191 .collect::<Vec<_>>()
192 .join(", ")
193 };
194
195 let query = format!(
196 "SELECT {}, xmin::text::bigint as _xmin FROM \"{}\".\"{}\" \
197 WHERE xmin::text::bigint > $1 \
198 ORDER BY xmin::text::bigint \
199 LIMIT $2",
200 column_list, batch_reader.schema, batch_reader.table
201 );
202
203 let rows = self
204 .client
205 .query(
206 &query,
207 &[
208 &(batch_reader.current_xmin as i64),
209 &(batch_reader.batch_size as i64),
210 ],
211 )
212 .await
213 .with_context(|| {
214 format!(
215 "Failed to read batch from {}.{}",
216 batch_reader.schema, batch_reader.table
217 )
218 })?;
219
220 if rows.is_empty() {
221 batch_reader.exhausted = true;
222 return Ok(None);
223 }
224
225 let max_xmin = rows
227 .iter()
228 .map(|row| {
229 let xmin: i64 = row.get("_xmin");
230 (xmin & 0xFFFFFFFF) as u32
231 })
232 .max()
233 .unwrap_or(batch_reader.current_xmin);
234
235 if rows.len() < batch_reader.batch_size {
237 batch_reader.exhausted = true;
238 }
239
240 batch_reader.current_xmin = max_xmin;
241
242 Ok(Some((rows, max_xmin)))
243 }
244
245 pub async fn estimate_changes(
249 &self,
250 schema: &str,
251 table: &str,
252 since_xmin: u32,
253 ) -> Result<i64> {
254 let query = format!(
255 "SELECT COUNT(*) FROM \"{}\".\"{}\" WHERE xmin::text::bigint > $1",
256 schema, table
257 );
258
259 let row = self
260 .client
261 .query_one(&query, &[&(since_xmin as i64)])
262 .await
263 .with_context(|| format!("Failed to count changes in {}.{}", schema, table))?;
264
265 let count: i64 = row.get(0);
266 Ok(count)
267 }
268
269 pub async fn list_tables(&self, schema: &str) -> Result<Vec<String>> {
271 let rows = self
272 .client
273 .query(
274 "SELECT tablename FROM pg_tables WHERE schemaname = $1 ORDER BY tablename",
275 &[&schema],
276 )
277 .await
278 .with_context(|| format!("Failed to list tables in schema {}", schema))?;
279
280 Ok(rows.iter().map(|row| row.get(0)).collect())
281 }
282
283 pub async fn get_columns(&self, schema: &str, table: &str) -> Result<Vec<ColumnInfo>> {
285 let rows = self
286 .client
287 .query(
288 "SELECT column_name, data_type, is_nullable, column_default
289 FROM information_schema.columns
290 WHERE table_schema = $1 AND table_name = $2
291 ORDER BY ordinal_position",
292 &[&schema, &table],
293 )
294 .await
295 .with_context(|| format!("Failed to get columns for {}.{}", schema, table))?;
296
297 Ok(rows
298 .iter()
299 .map(|row| ColumnInfo {
300 name: row.get(0),
301 data_type: row.get(1),
302 is_nullable: row.get::<_, String>(2) == "YES",
303 has_default: row.get::<_, Option<String>>(3).is_some(),
304 })
305 .collect())
306 }
307
308 pub async fn get_primary_key(&self, schema: &str, table: &str) -> Result<Vec<String>> {
310 let rows = self
311 .client
312 .query(
313 "SELECT a.attname
314 FROM pg_index i
315 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
316 JOIN pg_class c ON c.oid = i.indrelid
317 JOIN pg_namespace n ON n.oid = c.relnamespace
318 WHERE i.indisprimary
319 AND n.nspname = $1
320 AND c.relname = $2
321 ORDER BY array_position(i.indkey, a.attnum)",
322 &[&schema, &table],
323 )
324 .await
325 .with_context(|| format!("Failed to get primary key for {}.{}", schema, table))?;
326
327 Ok(rows.iter().map(|row| row.get(0)).collect())
328 }
329
330 pub async fn read_all_rows(
345 &self,
346 schema: &str,
347 table: &str,
348 columns: &[String],
349 ) -> Result<(Vec<Row>, u32)> {
350 tracing::info!(
351 "Performing full table read for {}.{} (wraparound recovery)",
352 schema,
353 table
354 );
355
356 let column_list = if columns.is_empty() {
357 "*".to_string()
358 } else {
359 columns
360 .iter()
361 .map(|c| format!("\"{}\"", c))
362 .collect::<Vec<_>>()
363 .join(", ")
364 };
365
366 let query = format!(
369 "SELECT {}, xmin::text::bigint as _xmin FROM \"{}\".\"{}\" ORDER BY xmin::text::bigint",
370 column_list, schema, table
371 );
372
373 let rows = self
374 .client
375 .query(&query, &[])
376 .await
377 .with_context(|| format!("Failed to read all rows from {}.{}", schema, table))?;
378
379 let max_xmin = rows
381 .iter()
382 .map(|row| {
383 let xmin: i64 = row.get("_xmin");
384 (xmin & 0xFFFFFFFF) as u32
385 })
386 .max()
387 .unwrap_or(0);
388
389 tracing::info!(
390 "Full table read complete: {} rows, max_xmin={}",
391 rows.len(),
392 max_xmin
393 );
394
395 Ok((rows, max_xmin))
396 }
397
398 pub async fn read_changes_with_wraparound_check(
415 &self,
416 schema: &str,
417 table: &str,
418 columns: &[String],
419 since_xmin: u32,
420 ) -> Result<(Vec<Row>, u32, bool)> {
421 let current_xmin = self.get_current_xmin().await?;
423
424 if detect_wraparound(since_xmin, current_xmin) == WraparoundCheck::WraparoundDetected {
426 let (rows, max_xmin) = self.read_all_rows(schema, table, columns).await?;
428 Ok((rows, max_xmin, true))
429 } else {
430 let (rows, max_xmin) = self
432 .read_changes(schema, table, columns, since_xmin)
433 .await?;
434 Ok((rows, max_xmin, false))
435 }
436 }
437}
438
/// Cursor state for [`XminReader::fetch_batch`], normally created by
/// [`XminReader::read_changes_batched`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BatchReader {
    /// Schema of the table being read.
    pub schema: String,
    /// Name of the table being read.
    pub table: String,
    /// Columns to project; an empty list selects all columns.
    pub columns: Vec<String>,
    /// Watermark: the next batch only includes rows whose `xmin` is greater than this value.
    pub current_xmin: u32,
    /// Target number of rows per batch; a batch smaller than this marks the reader exhausted.
    pub batch_size: usize,
    /// Set once no further rows are expected; `fetch_batch` then returns `Ok(None)`.
    pub exhausted: bool,
}
448
/// Metadata for one table column, as produced by [`XminReader::get_columns`].
#[derive(Clone, Debug)]
pub struct ColumnInfo {
    /// Column name (`information_schema.columns.column_name`).
    pub name: String,
    /// Type name as reported by `information_schema.columns.data_type`.
    pub data_type: String,
    /// `true` when the catalog's `is_nullable` value is `YES`.
    pub is_nullable: bool,
    /// `true` when the catalog's `column_default` is non-NULL.
    pub has_default: bool,
}
457
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built reader keeps the values it was constructed with.
    #[test]
    fn test_batch_reader_initial_state() {
        let projected = vec![String::from("id"), String::from("name")];
        let reader = BatchReader {
            schema: String::from("public"),
            table: String::from("users"),
            columns: projected,
            current_xmin: 0,
            batch_size: 1000,
            exhausted: false,
        };

        assert_eq!(reader.schema, "public");
        assert_eq!(reader.table, "users");
        assert_eq!(reader.current_xmin, 0);
        assert!(!reader.exhausted);
    }

    /// Plain field round-trip for the column metadata struct.
    #[test]
    fn test_column_info() {
        let info = ColumnInfo {
            name: String::from("id"),
            data_type: String::from("integer"),
            is_nullable: false,
            has_default: true,
        };

        assert_eq!(info.name, "id");
        assert!(!info.is_nullable);
        assert!(info.has_default);
    }

    /// Forward moves and small backwards steps are not wraparounds.
    #[test]
    fn test_wraparound_detection_normal() {
        let cases: [(u32, u32); 3] = [(100, 200), (1000, 900), (0, 100)];
        for &(old, current) in cases.iter() {
            assert_eq!(
                detect_wraparound(old, current),
                WraparoundCheck::Normal,
                "old={} current={}",
                old,
                current
            );
        }
    }

    /// Large backwards jumps are reported as wraparounds.
    #[test]
    fn test_wraparound_detection_wraparound() {
        let cases: [(u32, u32); 3] = [
            (3_500_000_000, 100),
            (4_000_000_000, 1_000_000),
            (2_500_000_000, 400_000_000),
        ];
        for &(old, current) in cases.iter() {
            assert_eq!(
                detect_wraparound(old, current),
                WraparoundCheck::WraparoundDetected,
                "old={} current={}",
                old,
                current
            );
        }
    }

    /// Boundaries: equal values, and a gap of exactly the threshold vs. one past it.
    #[test]
    fn test_wraparound_detection_edge_cases() {
        let cases: [(u32, u32, WraparoundCheck); 4] = [
            (0, 1_000_000, WraparoundCheck::Normal),
            (1000, 1000, WraparoundCheck::Normal),
            (2_000_000_001, 1, WraparoundCheck::Normal),
            (2_000_000_002, 1, WraparoundCheck::WraparoundDetected),
        ];
        for &(old, current, expected) in cases.iter() {
            assert_eq!(
                detect_wraparound(old, current),
                expected,
                "old={} current={}",
                old,
                current
            );
        }
    }
}