database_replicator/xmin/
writer.rs

1// ABOUTME: ChangeWriter for xmin-based sync - applies changes to target PostgreSQL
2// ABOUTME: Uses INSERT ... ON CONFLICT DO UPDATE for efficient upserts
3
4use anyhow::{Context, Result};
5use rust_decimal::Decimal;
6use tokio_postgres::types::ToSql;
7use tokio_postgres::{Client, Row};
8
9/// Writes changes to the target PostgreSQL database using upsert operations.
10///
11/// The ChangeWriter handles batched upserts within transactions for efficiency
12/// and atomicity. It dynamically builds INSERT ... ON CONFLICT DO UPDATE queries
13/// based on table schema.
14pub struct ChangeWriter<'a> {
15    client: &'a Client,
16}
17
18impl<'a> ChangeWriter<'a> {
19    /// Create a new ChangeWriter for the given PostgreSQL client connection.
20    pub fn new(client: &'a Client) -> Self {
21        Self { client }
22    }
23
24    /// Get a reference to the underlying client.
25    ///
26    /// Useful for callers that need to perform additional queries.
27    pub fn client(&self) -> &Client {
28        self.client
29    }
30
31    /// Apply a batch of rows to a table using upsert (INSERT ... ON CONFLICT DO UPDATE).
32    ///
33    /// Uses batching internally to stay within PostgreSQL's parameter limits.
34    /// Each batch is executed as a separate query (PostgreSQL auto-commits).
35    /// Automatically retries with smaller batches if "value too large" errors occur.
36    ///
37    /// # Arguments
38    ///
39    /// * `schema` - The schema name (e.g., "public")
40    /// * `table` - The table name
41    /// * `primary_key_columns` - Column names that form the primary key
42    /// * `all_columns` - All column names in the order they appear in `rows`
43    /// * `rows` - The rows to upsert, each row is a vector of values
44    ///
45    /// # Returns
46    ///
47    /// The number of rows affected.
48    pub async fn apply_batch(
49        &self,
50        schema: &str,
51        table: &str,
52        primary_key_columns: &[String],
53        all_columns: &[String],
54        rows: Vec<Vec<Box<dyn ToSql + Sync + Send>>>,
55    ) -> Result<u64> {
56        if rows.is_empty() {
57            return Ok(0);
58        }
59
60        // PostgreSQL has a limit of ~65535 parameters per query
61        // Calculate batch size based on number of columns, but cap at 100 rows
62        // to avoid "value too large to transmit" errors with large JSONB/TEXT columns
63        let params_per_row = all_columns.len();
64        let max_params = 65000; // Leave some margin
65        let param_based_batch_size = std::cmp::max(1, max_params / params_per_row);
66        let batch_size = std::cmp::min(param_based_batch_size, 100); // Cap at 100 rows
67
68        let mut total_affected = 0u64;
69
70        for chunk in rows.chunks(batch_size) {
71            let affected = self
72                .execute_upsert_batch_with_retry(
73                    schema,
74                    table,
75                    primary_key_columns,
76                    all_columns,
77                    chunk,
78                )
79                .await?;
80            total_affected += affected;
81        }
82
83        Ok(total_affected)
84    }
85
86    /// Execute upsert batch with automatic retry using smaller batches on "value too large" errors.
87    /// Uses iterative splitting instead of recursion to handle Rust's async limitations.
88    async fn execute_upsert_batch_with_retry(
89        &self,
90        schema: &str,
91        table: &str,
92        primary_key_columns: &[String],
93        all_columns: &[String],
94        rows: &[Vec<Box<dyn ToSql + Sync + Send>>],
95    ) -> Result<u64> {
96        // Try progressively smaller batch sizes until success
97        let mut current_batch_size = rows.len();
98        let mut total_affected = 0u64;
99        let mut offset = 0;
100
101        while offset < rows.len() {
102            let end = std::cmp::min(offset + current_batch_size, rows.len());
103            let chunk = &rows[offset..end];
104
105            match self
106                .execute_upsert_batch(schema, table, primary_key_columns, all_columns, chunk)
107                .await
108            {
109                Ok(affected) => {
110                    total_affected += affected;
111                    offset = end;
112                    // Reset batch size for next chunk
113                    current_batch_size = std::cmp::min(100, rows.len() - offset);
114                }
115                Err(e) => {
116                    let error_str = format!("{:?}", e);
117                    if error_str.contains("value too large to transmit") {
118                        if current_batch_size > 1 {
119                            // Halve the batch size and retry
120                            current_batch_size /= 2;
121                            tracing::warn!(
122                                "Batch too large for {}.{}, reducing to {} rows",
123                                schema,
124                                table,
125                                current_batch_size
126                            );
127                        } else {
128                            // Single row still too large - this is a data issue
129                            anyhow::bail!(
130                                "Single row too large to transmit for {}.{}. \
131                                 Consider reducing column sizes or using COPY protocol.",
132                                schema,
133                                table
134                            );
135                        }
136                    } else {
137                        return Err(e);
138                    }
139                }
140            }
141        }
142
143        Ok(total_affected)
144    }
145
146    /// Execute a single batch of upserts.
147    async fn execute_upsert_batch(
148        &self,
149        schema: &str,
150        table: &str,
151        primary_key_columns: &[String],
152        all_columns: &[String],
153        rows: &[Vec<Box<dyn ToSql + Sync + Send>>],
154    ) -> Result<u64> {
155        if rows.is_empty() {
156            return Ok(0);
157        }
158
159        let query = build_upsert_query(schema, table, primary_key_columns, all_columns, rows.len());
160
161        // Flatten all row values into a single params vector
162        let params: Vec<&(dyn ToSql + Sync)> = rows
163            .iter()
164            .flat_map(|row| row.iter().map(|v| v.as_ref() as &(dyn ToSql + Sync)))
165            .collect();
166
167        let affected = self
168            .client
169            .execute(&query, &params)
170            .await
171            .with_context(|| format!("Failed to upsert batch into {}.{}", schema, table))?;
172
173        Ok(affected)
174    }
175
176    /// Apply a single row using upsert.
177    ///
178    /// For single rows, this is more efficient than creating a batch.
179    pub async fn apply_row(
180        &self,
181        schema: &str,
182        table: &str,
183        primary_key_columns: &[String],
184        all_columns: &[String],
185        values: Vec<Box<dyn ToSql + Sync + Send>>,
186    ) -> Result<u64> {
187        let query = build_upsert_query(schema, table, primary_key_columns, all_columns, 1);
188
189        let params: Vec<&(dyn ToSql + Sync)> = values
190            .iter()
191            .map(|v| v.as_ref() as &(dyn ToSql + Sync))
192            .collect();
193
194        let affected = self
195            .client
196            .execute(&query, &params)
197            .await
198            .with_context(|| format!("Failed to upsert row into {}.{}", schema, table))?;
199
200        Ok(affected)
201    }
202
203    /// Delete rows by primary key values.
204    ///
205    /// Used by the reconciler to remove rows that no longer exist in source.
206    /// Executes deletes in batches to stay within PostgreSQL parameter limits.
207    pub async fn delete_rows(
208        &self,
209        schema: &str,
210        table: &str,
211        primary_key_columns: &[String],
212        pk_values: Vec<Vec<Box<dyn ToSql + Sync + Send>>>,
213    ) -> Result<u64> {
214        if pk_values.is_empty() {
215            return Ok(0);
216        }
217
218        let mut total_deleted = 0u64;
219
220        // Delete in batches
221        let batch_size = 1000;
222        for chunk in pk_values.chunks(batch_size) {
223            let deleted = self
224                .execute_delete_batch(schema, table, primary_key_columns, chunk)
225                .await?;
226            total_deleted += deleted;
227        }
228
229        Ok(total_deleted)
230    }
231
232    /// Execute a batch delete.
233    async fn execute_delete_batch(
234        &self,
235        schema: &str,
236        table: &str,
237        primary_key_columns: &[String],
238        pk_values: &[Vec<Box<dyn ToSql + Sync + Send>>],
239    ) -> Result<u64> {
240        if pk_values.is_empty() {
241            return Ok(0);
242        }
243
244        let query = build_delete_query(schema, table, primary_key_columns, pk_values.len());
245
246        let params: Vec<&(dyn ToSql + Sync)> = pk_values
247            .iter()
248            .flat_map(|row| row.iter().map(|v| v.as_ref() as &(dyn ToSql + Sync)))
249            .collect();
250
251        let deleted = self
252            .client
253            .execute(&query, &params)
254            .await
255            .with_context(|| format!("Failed to delete rows from {}.{}", schema, table))?;
256
257        Ok(deleted)
258    }
259}
260
/// Build an upsert query for the given table schema and batch size.
///
/// Generates a query like:
/// ```sql
/// INSERT INTO "schema"."table" ("col1", "col2", "col3")
/// VALUES ($1, $2, $3), ($4, $5, $6), ...
/// ON CONFLICT ("pk_col") DO UPDATE SET
///   "col2" = EXCLUDED."col2",
///   "col3" = EXCLUDED."col3"
/// ```
///
/// Details:
///
/// * **Quoting:** every identifier is double-quoted, and embedded `"` characters
///   are doubled, so reserved words and names containing quotes are handled.
/// * **Placeholders:** numbered row-major. Row `r` (0-based) uses
///   `$(r*n+1)` through `$(r*n+n)`, where `n = all_columns.len()`.
/// * **All-PK tables:** when every column is a primary-key column there is nothing
///   to update, and the conflict action is `DO NOTHING`.
fn build_upsert_query(
    schema: &str,
    table: &str,
    primary_key_columns: &[String],
    all_columns: &[String],
    num_rows: usize,
) -> String {
    /// Double-quote a PostgreSQL identifier, doubling any embedded `"`.
    fn quote_ident(ident: &str) -> String {
        format!("\"{}\"", ident.replace('"', "\"\""))
    }

    /// Quote each identifier and join them with ", ".
    fn quote_list(idents: &[String]) -> String {
        idents
            .iter()
            .map(String::as_str)
            .map(quote_ident)
            .collect::<Vec<_>>()
            .join(", ")
    }

    // VALUES placeholders: ($1, $2, $3), ($4, $5, $6), ...
    let num_cols = all_columns.len();
    let value_rows: Vec<String> = (0..num_rows)
        .map(|row_idx| {
            let placeholders: Vec<String> = (1..=num_cols)
                .map(|col| format!("${}", row_idx * num_cols + col))
                .collect();
            format!("({})", placeholders.join(", "))
        })
        .collect();

    // Overwrite every non-PK column with the incoming (EXCLUDED) value.
    let update_columns: Vec<String> = all_columns
        .iter()
        .filter(|c| !primary_key_columns.contains(c))
        .map(|c| {
            let quoted = quote_ident(c);
            format!("{} = EXCLUDED.{}", quoted, quoted)
        })
        .collect();

    let update_clause = if update_columns.is_empty() {
        // All columns are PKs - use DO NOTHING
        "DO NOTHING".to_string()
    } else {
        format!("DO UPDATE SET {}", update_columns.join(", "))
    };

    format!(
        "INSERT INTO {}.{} ({}) VALUES {} ON CONFLICT ({}) {}",
        quote_ident(schema),
        quote_ident(table),
        quote_list(all_columns),
        value_rows.join(", "),
        quote_list(primary_key_columns),
        update_clause
    )
}
321
/// Build a delete query for multiple rows by primary key.
///
/// For single-column PK:
/// ```sql
/// DELETE FROM "schema"."table" WHERE "id" IN ($1, $2, $3, ...)
/// ```
///
/// For composite PK:
/// ```sql
/// DELETE FROM "schema"."table" WHERE ("pk1", "pk2") IN (($1, $2), ($3, $4), ...)
/// ```
///
/// Details:
///
/// * **Quoting:** every identifier is double-quoted, and embedded `"` characters
///   are doubled.
/// * **Placeholders:** numbered row-major. Key `r` (0-based) uses `$(r*k+1)`
///   through `$(r*k+k)`, where `k = primary_key_columns.len()`.
fn build_delete_query(
    schema: &str,
    table: &str,
    primary_key_columns: &[String],
    num_rows: usize,
) -> String {
    /// Double-quote a PostgreSQL identifier, doubling any embedded `"`.
    fn quote_ident(ident: &str) -> String {
        format!("\"{}\"", ident.replace('"', "\"\""))
    }

    let target = format!("{}.{}", quote_ident(schema), quote_ident(table));

    if let [pk] = primary_key_columns {
        // Simple case: single-column primary key
        let placeholders: Vec<String> = (1..=num_rows).map(|i| format!("${}", i)).collect();
        return format!(
            "DELETE FROM {} WHERE {} IN ({})",
            target,
            quote_ident(pk),
            placeholders.join(", ")
        );
    }

    // Composite primary key: compare row constructors.
    let num_pk_cols = primary_key_columns.len();
    let pk_cols: Vec<String> = primary_key_columns
        .iter()
        .map(String::as_str)
        .map(quote_ident)
        .collect();

    let value_tuples: Vec<String> = (0..num_rows)
        .map(|row_idx| {
            let placeholders: Vec<String> = (1..=num_pk_cols)
                .map(|col| format!("${}", row_idx * num_pk_cols + col))
                .collect();
            format!("({})", placeholders.join(", "))
        })
        .collect();

    format!(
        "DELETE FROM {} WHERE ({}) IN ({})",
        target,
        pk_cols.join(", "),
        value_tuples.join(", ")
    )
}
378
379/// Extract column metadata from a PostgreSQL table.
380///
381/// Returns (column_name, data_type) pairs for all columns in the table.
382/// Uses `udt_name` from information_schema which includes array type info
383/// (e.g., `_text` for text[], `_int4` for integer[]).
384pub async fn get_table_columns(
385    client: &Client,
386    schema: &str,
387    table: &str,
388) -> Result<Vec<(String, String)>> {
389    let rows = client
390        .query(
391            "SELECT column_name, udt_name
392             FROM information_schema.columns
393             WHERE table_schema = $1 AND table_name = $2
394             ORDER BY ordinal_position",
395            &[&schema, &table],
396        )
397        .await
398        .with_context(|| format!("Failed to get columns for {}.{}", schema, table))?;
399
400    Ok(rows
401        .iter()
402        .map(|row| {
403            let name: String = row.get(0);
404            let dtype: String = row.get(1);
405            (name, dtype)
406        })
407        .collect())
408}
409
410/// Get primary key columns for a table.
411///
412/// Returns the column names that form the primary key constraint.
413pub async fn get_primary_key_columns(
414    client: &Client,
415    schema: &str,
416    table: &str,
417) -> Result<Vec<String>> {
418    let rows = client
419        .query(
420            "SELECT a.attname
421             FROM pg_index i
422             JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
423             JOIN pg_class c ON c.oid = i.indrelid
424             JOIN pg_namespace n ON n.oid = c.relnamespace
425             WHERE i.indisprimary
426               AND n.nspname = $1
427               AND c.relname = $2
428             ORDER BY array_position(i.indkey, a.attnum)",
429            &[&schema, &table],
430        )
431        .await
432        .with_context(|| format!("Failed to get primary key for {}.{}", schema, table))?;
433
434    Ok(rows.iter().map(|row| row.get(0)).collect())
435}
436
437/// Convert a tokio_postgres Row to a vector of boxed ToSql values.
438///
439/// This is a helper for extracting values from source rows to pass to ChangeWriter.
440/// The caller must know the column types to extract values correctly.
441pub fn row_to_values(
442    row: &Row,
443    column_types: &[(String, String)],
444) -> Vec<Box<dyn ToSql + Sync + Send>> {
445    column_types
446        .iter()
447        .enumerate()
448        .map(|(idx, (_name, dtype))| -> Box<dyn ToSql + Sync + Send> {
449            // Handle common PostgreSQL types
450            match dtype.as_str() {
451                "integer" | "int4" => {
452                    let val: Option<i32> = row.get(idx);
453                    Box::new(val)
454                }
455                "bigint" | "int8" => {
456                    let val: Option<i64> = row.get(idx);
457                    Box::new(val)
458                }
459                "smallint" | "int2" => {
460                    let val: Option<i16> = row.get(idx);
461                    Box::new(val)
462                }
463                "text" | "varchar" | "bpchar" | "char" | "character" | "name" | "citext" => {
464                    let val: Option<String> = row.get(idx);
465                    Box::new(val)
466                }
467                "boolean" | "bool" => {
468                    let val: Option<bool> = row.get(idx);
469                    Box::new(val)
470                }
471                "real" | "float4" => {
472                    let val: Option<f32> = row.get(idx);
473                    Box::new(val)
474                }
475                "double precision" | "float8" => {
476                    let val: Option<f64> = row.get(idx);
477                    Box::new(val)
478                }
479                "uuid" => {
480                    let val: Option<uuid::Uuid> = row.get(idx);
481                    Box::new(val)
482                }
483                "timestamp without time zone" | "timestamp" => {
484                    let val: Option<chrono::NaiveDateTime> = row.get(idx);
485                    Box::new(val)
486                }
487                "timestamp with time zone" | "timestamptz" => {
488                    let val: Option<chrono::DateTime<chrono::Utc>> = row.get(idx);
489                    Box::new(val)
490                }
491                "date" => {
492                    let val: Option<chrono::NaiveDate> = row.get(idx);
493                    Box::new(val)
494                }
495                "json" | "jsonb" => {
496                    let val: Option<serde_json::Value> = row.get(idx);
497                    Box::new(val)
498                }
499                "bytea" => {
500                    let val: Option<Vec<u8>> = row.get(idx);
501                    Box::new(val)
502                }
503                "numeric" | "decimal" => {
504                    // Use rust_decimal for proper numeric handling
505                    let val: Option<Decimal> = row.get(idx);
506                    Box::new(val)
507                }
508                // Array types (PostgreSQL udt_name uses underscore prefix for array types)
509                "_text" | "_varchar" | "_bpchar" | "_citext" => {
510                    let val: Option<Vec<String>> = row.get(idx);
511                    Box::new(val)
512                }
513                "_int4" => {
514                    let val: Option<Vec<i32>> = row.get(idx);
515                    Box::new(val)
516                }
517                "_int8" => {
518                    let val: Option<Vec<i64>> = row.get(idx);
519                    Box::new(val)
520                }
521                "_int2" => {
522                    let val: Option<Vec<i16>> = row.get(idx);
523                    Box::new(val)
524                }
525                "_float4" => {
526                    let val: Option<Vec<f32>> = row.get(idx);
527                    Box::new(val)
528                }
529                "_float8" => {
530                    let val: Option<Vec<f64>> = row.get(idx);
531                    Box::new(val)
532                }
533                "_bool" => {
534                    let val: Option<Vec<bool>> = row.get(idx);
535                    Box::new(val)
536                }
537                "_uuid" => {
538                    let val: Option<Vec<uuid::Uuid>> = row.get(idx);
539                    Box::new(val)
540                }
541                "_bytea" => {
542                    let val: Option<Vec<Vec<u8>>> = row.get(idx);
543                    Box::new(val)
544                }
545                "_numeric" => {
546                    let val: Option<Vec<Decimal>> = row.get(idx);
547                    Box::new(val)
548                }
549                "_jsonb" | "_json" => {
550                    let val: Option<Vec<serde_json::Value>> = row.get(idx);
551                    Box::new(val)
552                }
553                "_timestamp" => {
554                    let val: Option<Vec<chrono::NaiveDateTime>> = row.get(idx);
555                    Box::new(val)
556                }
557                "_timestamptz" => {
558                    let val: Option<Vec<chrono::DateTime<chrono::Utc>>> = row.get(idx);
559                    Box::new(val)
560                }
561                "_date" => {
562                    let val: Option<Vec<chrono::NaiveDate>> = row.get(idx);
563                    Box::new(val)
564                }
565                _ => {
566                    // For unknown types, try to get as string
567                    let val: Option<String> = row.try_get::<_, String>(idx).ok();
568                    Box::new(val)
569                }
570            }
571        })
572        .collect()
573}
574
#[cfg(test)]
mod tests {
    use super::*;

    /// Turn string slices into the owned column-name lists the builders take.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Assert that every fragment occurs somewhere in `sql`.
    fn assert_contains_all(sql: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(sql.contains(*fragment), "expected {:?} in {:?}", fragment, sql);
        }
    }

    #[test]
    fn test_build_upsert_query_single_row() {
        let sql = build_upsert_query(
            "public",
            "users",
            &owned(&["id"]),
            &owned(&["id", "name", "email"]),
            1,
        );

        assert_contains_all(
            &sql,
            &[
                "INSERT INTO \"public\".\"users\"",
                "(\"id\", \"name\", \"email\")",
                "VALUES ($1, $2, $3)",
                "ON CONFLICT (\"id\")",
                "DO UPDATE SET",
                "\"name\" = EXCLUDED.\"name\"",
                "\"email\" = EXCLUDED.\"email\"",
            ],
        );
    }

    #[test]
    fn test_build_upsert_query_multiple_rows() {
        let sql = build_upsert_query("public", "users", &owned(&["id"]), &owned(&["id", "name"]), 3);

        assert_contains_all(&sql, &["($1, $2), ($3, $4), ($5, $6)"]);
    }

    #[test]
    fn test_build_upsert_query_composite_pk() {
        let sql = build_upsert_query(
            "public",
            "order_items",
            &owned(&["order_id", "item_id"]),
            &owned(&["order_id", "item_id", "quantity"]),
            1,
        );

        assert_contains_all(
            &sql,
            &[
                "ON CONFLICT (\"order_id\", \"item_id\")",
                "\"quantity\" = EXCLUDED.\"quantity\"",
            ],
        );
    }

    #[test]
    fn test_build_upsert_query_all_pk_columns() {
        // A table made only of key columns has nothing to update.
        let sql = build_upsert_query("public", "tags", &owned(&["id"]), &owned(&["id"]), 1);

        assert_contains_all(&sql, &["DO NOTHING"]);
        assert!(!sql.contains("DO UPDATE SET"));
    }

    #[test]
    fn test_build_delete_query_single_pk() {
        let sql = build_delete_query("public", "users", &owned(&["id"]), 3);

        assert_contains_all(
            &sql,
            &[
                "DELETE FROM \"public\".\"users\"",
                "WHERE \"id\" IN ($1, $2, $3)",
            ],
        );
    }

    #[test]
    fn test_build_delete_query_composite_pk() {
        let sql = build_delete_query("public", "order_items", &owned(&["order_id", "item_id"]), 2);

        assert_contains_all(
            &sql,
            &[
                "WHERE (\"order_id\", \"item_id\") IN",
                "($1, $2), ($3, $4)",
            ],
        );
    }
}