1use anyhow::{Context, Result};
5use tokio_postgres::{Client, Row};
6
/// Size of a backwards xmin jump (`old_xmin - current_xmin`) that must be exceeded
/// before `detect_wraparound` reports a wraparound.
const WRAPAROUND_THRESHOLD: u32 = 2_000_000_000;
11
/// Outcome of comparing a previously stored xmin with the current one.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WraparoundCheck {
    /// No wraparound was detected.
    Normal,
    /// The stored xmin is ahead of the current xmin by more than the wraparound threshold.
    WraparoundDetected,
}
20
21pub fn detect_wraparound(old_xmin: u32, current_xmin: u32) -> WraparoundCheck {
36 if old_xmin > current_xmin && (old_xmin - current_xmin) > WRAPAROUND_THRESHOLD {
38 tracing::warn!(
39 "xmin wraparound detected: old_xmin={}, current_xmin={}, delta={}",
40 old_xmin,
41 current_xmin,
42 old_xmin - current_xmin
43 );
44 WraparoundCheck::WraparoundDetected
45 } else {
46 WraparoundCheck::Normal
47 }
48}
49
/// Runs xmin-based queries through a borrowed [`Client`].
pub struct XminReader<'a> {
    /// Connection used for every query issued by this reader.
    client: &'a Client,
}
61
62impl<'a> XminReader<'a> {
63 pub fn new(client: &'a Client) -> Self {
65 Self { client }
66 }
67
68 pub fn client(&self) -> &Client {
70 self.client
71 }
72
73 pub async fn get_current_xmin(&self) -> Result<u32> {
77 let row = self
78 .client
79 .query_one("SELECT txid_current()::text::bigint", &[])
80 .await
81 .context("Failed to get current transaction ID")?;
82
83 let txid: i64 = row.get(0);
84 Ok((txid & 0xFFFFFFFF) as u32)
87 }
88
89 pub async fn read_changes(
102 &self,
103 schema: &str,
104 table: &str,
105 columns: &[String],
106 since_xmin: u32,
107 ) -> Result<(Vec<Row>, u32)> {
108 let column_list = if columns.is_empty() {
109 "*".to_string()
110 } else {
111 columns
112 .iter()
113 .map(|c| format!("\"{}\"", c))
114 .collect::<Vec<_>>()
115 .join(", ")
116 };
117
118 let query = format!(
121 "SELECT {}, xmin::text::bigint as _xmin FROM \"{}\".\"{}\" WHERE xmin::text::bigint > $1 ORDER BY xmin::text::bigint",
122 column_list, schema, table
123 );
124
125 let rows = self
126 .client
127 .query(&query, &[&(since_xmin as i64)])
128 .await
129 .with_context(|| format!("Failed to read changes from {}.{}", schema, table))?;
130
131 let max_xmin = rows
133 .iter()
134 .map(|row| {
135 let xmin: i64 = row.get("_xmin");
136 (xmin & 0xFFFFFFFF) as u32
137 })
138 .max()
139 .unwrap_or(since_xmin);
140
141 Ok((rows, max_xmin))
142 }
143
144 pub async fn read_changes_batched(
158 &self,
159 schema: &str,
160 table: &str,
161 columns: &[String],
162 since_xmin: u32,
163 batch_size: usize,
164 ) -> Result<BatchReader> {
165 Ok(BatchReader {
166 schema: schema.to_string(),
167 table: table.to_string(),
168 columns: columns.to_vec(),
169 current_xmin: since_xmin,
170 last_ctid: None,
171 batch_size,
172 exhausted: false,
173 })
174 }
175
176 pub async fn fetch_batch(
182 &self,
183 batch_reader: &mut BatchReader,
184 ) -> Result<Option<(Vec<Row>, u32)>> {
185 if batch_reader.exhausted {
186 return Ok(None);
187 }
188
189 let column_list = if batch_reader.columns.is_empty() {
190 "*".to_string()
191 } else {
192 batch_reader
193 .columns
194 .iter()
195 .map(|c| format!("\"{}\"", c))
196 .collect::<Vec<_>>()
197 .join(", ")
198 };
199
200 let (query, rows) = if let Some(ref last_ctid) = batch_reader.last_ctid {
203 let query = format!(
205 "SELECT {}, xmin::text::bigint as _xmin, ctid::text as _ctid \
206 FROM \"{}\".\"{}\" \
207 WHERE (xmin::text::bigint, ctid) > ($1, $2::tid) \
208 ORDER BY xmin::text::bigint, ctid \
209 LIMIT $3",
210 column_list, batch_reader.schema, batch_reader.table
211 );
212
213 let rows = self
214 .client
215 .query(
216 &query,
217 &[
218 &(batch_reader.current_xmin as i64),
219 &last_ctid,
220 &(batch_reader.batch_size as i64),
221 ],
222 )
223 .await
224 .with_context(|| {
225 format!(
226 "Failed to read batch from {}.{}",
227 batch_reader.schema, batch_reader.table
228 )
229 })?;
230 (query, rows)
231 } else {
232 let query = format!(
234 "SELECT {}, xmin::text::bigint as _xmin, ctid::text as _ctid \
235 FROM \"{}\".\"{}\" \
236 WHERE xmin::text::bigint > $1 \
237 ORDER BY xmin::text::bigint, ctid \
238 LIMIT $2",
239 column_list, batch_reader.schema, batch_reader.table
240 );
241
242 let rows = self
243 .client
244 .query(
245 &query,
246 &[
247 &(batch_reader.current_xmin as i64),
248 &(batch_reader.batch_size as i64),
249 ],
250 )
251 .await
252 .with_context(|| {
253 format!(
254 "Failed to read batch from {}.{}",
255 batch_reader.schema, batch_reader.table
256 )
257 })?;
258 (query, rows)
259 };
260
261 let _ = query;
263
264 if rows.is_empty() {
265 batch_reader.exhausted = true;
266 return Ok(None);
267 }
268
269 let last_row = rows.last().unwrap();
271 let last_xmin: i64 = last_row.get("_xmin");
272 let last_ctid: String = last_row.get("_ctid");
273
274 let max_xmin = (last_xmin & 0xFFFFFFFF) as u32;
275
276 if rows.len() < batch_reader.batch_size {
278 batch_reader.exhausted = true;
279 }
280
281 batch_reader.current_xmin = max_xmin;
282 batch_reader.last_ctid = Some(last_ctid);
283
284 Ok(Some((rows, max_xmin)))
285 }
286
287 pub async fn estimate_changes(
291 &self,
292 schema: &str,
293 table: &str,
294 since_xmin: u32,
295 ) -> Result<i64> {
296 let query = format!(
297 "SELECT COUNT(*) FROM \"{}\".\"{}\" WHERE xmin::text::bigint > $1",
298 schema, table
299 );
300
301 let row = self
302 .client
303 .query_one(&query, &[&(since_xmin as i64)])
304 .await
305 .with_context(|| format!("Failed to count changes in {}.{}", schema, table))?;
306
307 let count: i64 = row.get(0);
308 Ok(count)
309 }
310
311 pub async fn list_tables(&self, schema: &str) -> Result<Vec<String>> {
313 let rows = self
314 .client
315 .query(
316 "SELECT tablename FROM pg_tables WHERE schemaname = $1 ORDER BY tablename",
317 &[&schema],
318 )
319 .await
320 .with_context(|| format!("Failed to list tables in schema {}", schema))?;
321
322 Ok(rows.iter().map(|row| row.get(0)).collect())
323 }
324
325 pub async fn get_columns(&self, schema: &str, table: &str) -> Result<Vec<ColumnInfo>> {
327 let rows = self
328 .client
329 .query(
330 "SELECT column_name, data_type, is_nullable, column_default
331 FROM information_schema.columns
332 WHERE table_schema = $1 AND table_name = $2
333 ORDER BY ordinal_position",
334 &[&schema, &table],
335 )
336 .await
337 .with_context(|| format!("Failed to get columns for {}.{}", schema, table))?;
338
339 Ok(rows
340 .iter()
341 .map(|row| ColumnInfo {
342 name: row.get(0),
343 data_type: row.get(1),
344 is_nullable: row.get::<_, String>(2) == "YES",
345 has_default: row.get::<_, Option<String>>(3).is_some(),
346 })
347 .collect())
348 }
349
350 pub async fn get_primary_key(&self, schema: &str, table: &str) -> Result<Vec<String>> {
352 let rows = self
353 .client
354 .query(
355 "SELECT a.attname
356 FROM pg_index i
357 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
358 JOIN pg_class c ON c.oid = i.indrelid
359 JOIN pg_namespace n ON n.oid = c.relnamespace
360 WHERE i.indisprimary
361 AND n.nspname = $1
362 AND c.relname = $2
363 ORDER BY array_position(i.indkey, a.attnum)",
364 &[&schema, &table],
365 )
366 .await
367 .with_context(|| format!("Failed to get primary key for {}.{}", schema, table))?;
368
369 Ok(rows.iter().map(|row| row.get(0)).collect())
370 }
371
372 pub async fn read_all_rows(
387 &self,
388 schema: &str,
389 table: &str,
390 columns: &[String],
391 ) -> Result<(Vec<Row>, u32)> {
392 tracing::info!(
393 "Performing full table read for {}.{} (wraparound recovery)",
394 schema,
395 table
396 );
397
398 let column_list = if columns.is_empty() {
399 "*".to_string()
400 } else {
401 columns
402 .iter()
403 .map(|c| format!("\"{}\"", c))
404 .collect::<Vec<_>>()
405 .join(", ")
406 };
407
408 let query = format!(
411 "SELECT {}, xmin::text::bigint as _xmin FROM \"{}\".\"{}\" ORDER BY xmin::text::bigint",
412 column_list, schema, table
413 );
414
415 let rows = self
416 .client
417 .query(&query, &[])
418 .await
419 .with_context(|| format!("Failed to read all rows from {}.{}", schema, table))?;
420
421 let max_xmin = rows
423 .iter()
424 .map(|row| {
425 let xmin: i64 = row.get("_xmin");
426 (xmin & 0xFFFFFFFF) as u32
427 })
428 .max()
429 .unwrap_or(0);
430
431 tracing::info!(
432 "Full table read complete: {} rows, max_xmin={}",
433 rows.len(),
434 max_xmin
435 );
436
437 Ok((rows, max_xmin))
438 }
439
440 pub async fn read_changes_with_wraparound_check(
457 &self,
458 schema: &str,
459 table: &str,
460 columns: &[String],
461 since_xmin: u32,
462 ) -> Result<(Vec<Row>, u32, bool)> {
463 let current_xmin = self.get_current_xmin().await?;
465
466 if detect_wraparound(since_xmin, current_xmin) == WraparoundCheck::WraparoundDetected {
468 let (rows, max_xmin) = self.read_all_rows(schema, table, columns).await?;
470 Ok((rows, max_xmin, true))
471 } else {
472 let (rows, max_xmin) = self
474 .read_changes(schema, table, columns, since_xmin)
475 .await?;
476 Ok((rows, max_xmin, false))
477 }
478 }
479}
480
/// Cursor state for reading a table in batches with `XminReader::fetch_batch`.
#[derive(Debug, Clone)]
pub struct BatchReader {
    /// Schema of the table being read.
    pub schema: String,
    /// Name of the table being read.
    pub table: String,
    /// Columns to select; an empty list selects `*`.
    pub columns: Vec<String>,
    /// xmin position of the reader: the starting xmin at first, then the xmin of the
    /// last row returned by the most recent batch.
    pub current_xmin: u32,
    /// Text form of the ctid of the last row returned, or `None` before the first batch.
    pub last_ctid: Option<String>,
    /// Maximum number of rows requested per batch.
    pub batch_size: usize,
    /// Set once a batch returns fewer than `batch_size` rows; further fetches yield nothing.
    pub exhausted: bool,
}
496
/// Metadata about one table column, as produced by `XminReader::get_columns`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ColumnInfo {
    /// Column name.
    pub name: String,
    /// Data type name as reported by the catalog.
    pub data_type: String,
    /// Whether the catalog reports the column as nullable.
    pub is_nullable: bool,
    /// Whether the column has a default expression.
    pub has_default: bool,
}
505
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built `BatchReader` exposes exactly the values it was given.
    #[test]
    fn test_batch_reader_initial_state() {
        let columns = vec![String::from("id"), String::from("name")];
        let reader = BatchReader {
            schema: String::from("public"),
            table: String::from("users"),
            columns,
            current_xmin: 0,
            last_ctid: None,
            batch_size: 1000,
            exhausted: false,
        };

        assert_eq!(reader.schema, "public");
        assert_eq!(reader.table, "users");
        assert_eq!(reader.current_xmin, 0);
        assert!(reader.last_ctid.is_none());
        assert!(!reader.exhausted);
    }

    /// `ColumnInfo` fields are readable back unchanged.
    #[test]
    fn test_column_info() {
        let col = ColumnInfo {
            name: String::from("id"),
            data_type: String::from("integer"),
            is_nullable: false,
            has_default: true,
        };

        assert_eq!(col.name, "id");
        assert!(!col.is_nullable);
        assert!(col.has_default);
    }

    /// Forward progress, a small backwards step, and a start from zero are all normal.
    #[test]
    fn test_wraparound_detection_normal() {
        let cases = [(100, 200), (1000, 900), (0, 100)];
        for &(old_xmin, current_xmin) in cases.iter() {
            assert_eq!(
                detect_wraparound(old_xmin, current_xmin),
                WraparoundCheck::Normal
            );
        }
    }

    /// A backwards jump larger than the threshold is reported as wraparound.
    #[test]
    fn test_wraparound_detection_wraparound() {
        let cases = [
            (3_500_000_000, 100),
            (4_000_000_000, 1_000_000),
            (2_500_000_000, 400_000_000),
        ];
        for &(old_xmin, current_xmin) in cases.iter() {
            assert_eq!(
                detect_wraparound(old_xmin, current_xmin),
                WraparoundCheck::WraparoundDetected
            );
        }
    }

    /// Boundary behaviour: the delta must strictly exceed the threshold.
    #[test]
    fn test_wraparound_detection_edge_cases() {
        let cases = [
            (0, 1_000_000, WraparoundCheck::Normal),
            (1000, 1000, WraparoundCheck::Normal),
            // Delta of exactly 2_000_000_000 is still normal.
            (2_000_000_001, 1, WraparoundCheck::Normal),
            // One more crosses the threshold.
            (2_000_000_002, 1, WraparoundCheck::WraparoundDetected),
        ];
        for &(old_xmin, current_xmin, expected) in cases.iter() {
            assert_eq!(detect_wraparound(old_xmin, current_xmin), expected);
        }
    }
}