use anyhow::{Context, Result};
use rust_decimal::Decimal;
use tokio_postgres::types::ToSql;
use tokio_postgres::{Client, Row};

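/// Applies change batches (upserts and deletes) to a target PostgreSQL
/// database over a borrowed [`Client`].
///
/// A minimal usage sketch; the table layout and values below are
/// illustrative, not part of this module:
///
/// ```ignore
/// let writer = ChangeWriter::new(&client);
/// let pk = vec!["id".to_string()];
/// let cols = vec!["id".to_string(), "name".to_string()];
/// let rows: Vec<Vec<Box<dyn ToSql + Sync + Send>>> =
///     vec![vec![Box::new(1i32), Box::new("alice".to_string())]];
/// let affected = writer.apply_batch("public", "users", &pk, &cols, rows).await?;
/// ```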
pub struct ChangeWriter<'a> {
    client: &'a Client,
}

impl<'a> ChangeWriter<'a> {
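    /// Wraps an existing connection; the writer itself holds no other state.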
    pub fn new(client: &'a Client) -> Self {
        Self { client }
    }

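    /// Returns the underlying client, for callers that need to run
    /// additional statements on the same connection.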
    pub fn client(&self) -> &Client {
        self.client
    }

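    /// Upserts `rows` into `schema.table`, chunking so that no single
    /// statement exceeds PostgreSQL's bind-parameter limit, and returns the
    /// total number of rows the server reported affected. Each inner `Vec`
    /// must hold one value per entry in `all_columns`, in the same order.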
    pub async fn apply_batch(
        &self,
        schema: &str,
        table: &str,
        primary_key_columns: &[String],
        all_columns: &[String],
        rows: Vec<Vec<Box<dyn ToSql + Sync + Send>>>,
    ) -> Result<u64> {
        if rows.is_empty() {
            return Ok(0);
        }

        let params_per_row = all_columns.len();
64 let max_params = 65000; let param_based_batch_size = std::cmp::max(1, max_params / params_per_row);
66 let batch_size = std::cmp::min(param_based_batch_size, 100); let mut total_affected = 0u64;
69
        for chunk in rows.chunks(batch_size) {
            let affected = self
                .execute_upsert_batch_with_retry(
                    schema,
                    table,
                    primary_key_columns,
                    all_columns,
                    chunk,
                )
                .await?;
            total_affected += affected;
        }

        Ok(total_affected)
    }

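    /// Executes one chunk of rows, halving the batch size and retrying from
    /// the same offset whenever the error mentions "value too large to
    /// transmit". Bails out if even a single row cannot be transmitted.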
    async fn execute_upsert_batch_with_retry(
        &self,
        schema: &str,
        table: &str,
        primary_key_columns: &[String],
        all_columns: &[String],
        rows: &[Vec<Box<dyn ToSql + Sync + Send>>],
    ) -> Result<u64> {
        let mut current_batch_size = rows.len();
        let mut total_affected = 0u64;
        let mut offset = 0;

        while offset < rows.len() {
            let end = std::cmp::min(offset + current_batch_size, rows.len());
            let chunk = &rows[offset..end];

            match self
                .execute_upsert_batch(schema, table, primary_key_columns, all_columns, chunk)
                .await
            {
                Ok(affected) => {
                    total_affected += affected;
                    offset = end;
                    // On success, restore the batch size for the remaining rows.
                    current_batch_size = std::cmp::min(100, rows.len() - offset);
                }
                Err(e) => {
                    let error_str = format!("{:?}", e);
                    if error_str.contains("value too large to transmit") {
                        if current_batch_size > 1 {
                            // Halve the batch and retry from the same offset.
                            current_batch_size /= 2;
                            tracing::warn!(
                                "Batch too large for {}.{}, reducing to {} rows",
                                schema,
                                table,
                                current_batch_size
                            );
                        } else {
                            anyhow::bail!(
                                "Single row too large to transmit for {}.{}. \
                                 Consider reducing column sizes or using COPY protocol.",
                                schema,
                                table
                            );
                        }
                    } else {
                        return Err(e);
                    }
                }
            }
        }

        Ok(total_affected)
    }

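    /// Builds and executes a single multi-row `INSERT ... ON CONFLICT`
    /// statement for `rows`, binding the values in row-major order to match
    /// the `$n` placeholder numbering.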
    async fn execute_upsert_batch(
        &self,
        schema: &str,
        table: &str,
        primary_key_columns: &[String],
        all_columns: &[String],
        rows: &[Vec<Box<dyn ToSql + Sync + Send>>],
    ) -> Result<u64> {
        if rows.is_empty() {
            return Ok(0);
        }

        let query = build_upsert_query(schema, table, primary_key_columns, all_columns, rows.len());

        let params: Vec<&(dyn ToSql + Sync)> = rows
            .iter()
            .flat_map(|row| row.iter().map(|v| v.as_ref() as &(dyn ToSql + Sync)))
            .collect();

        let affected = self
            .client
            .execute(&query, &params)
            .await
            .with_context(|| format!("Failed to upsert batch into {}.{}", schema, table))?;

        Ok(affected)
    }

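    /// Upserts a single row. Equivalent to a one-row `apply_batch`, minus
    /// the chunking and retry machinery.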
    pub async fn apply_row(
        &self,
        schema: &str,
        table: &str,
        primary_key_columns: &[String],
        all_columns: &[String],
        values: Vec<Box<dyn ToSql + Sync + Send>>,
    ) -> Result<u64> {
        let query = build_upsert_query(schema, table, primary_key_columns, all_columns, 1);

        let params: Vec<&(dyn ToSql + Sync)> = values
            .iter()
            .map(|v| v.as_ref() as &(dyn ToSql + Sync))
            .collect();

        let affected = self
            .client
            .execute(&query, &params)
            .await
            .with_context(|| format!("Failed to upsert row into {}.{}", schema, table))?;

        Ok(affected)
    }

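    /// Deletes rows identified by their primary-key values, issuing one
    /// `DELETE` per batch of up to 1000 keys, and returns the total number
    /// of rows deleted.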
    pub async fn delete_rows(
        &self,
        schema: &str,
        table: &str,
        primary_key_columns: &[String],
        pk_values: Vec<Vec<Box<dyn ToSql + Sync + Send>>>,
    ) -> Result<u64> {
        if pk_values.is_empty() {
            return Ok(0);
        }

        let mut total_deleted = 0u64;

        let batch_size = 1000;
        for chunk in pk_values.chunks(batch_size) {
            let deleted = self
                .execute_delete_batch(schema, table, primary_key_columns, chunk)
                .await?;
            total_deleted += deleted;
        }

        Ok(total_deleted)
    }

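    /// Builds and executes one batched `DELETE`, binding the key tuples in
    /// row-major order to match the placeholder numbering.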
    async fn execute_delete_batch(
        &self,
        schema: &str,
        table: &str,
        primary_key_columns: &[String],
        pk_values: &[Vec<Box<dyn ToSql + Sync + Send>>],
    ) -> Result<u64> {
        if pk_values.is_empty() {
            return Ok(0);
        }

        let query = build_delete_query(schema, table, primary_key_columns, pk_values.len());

        let params: Vec<&(dyn ToSql + Sync)> = pk_values
            .iter()
            .flat_map(|row| row.iter().map(|v| v.as_ref() as &(dyn ToSql + Sync)))
            .collect();

        let deleted = self
            .client
            .execute(&query, &params)
            .await
            .with_context(|| format!("Failed to delete rows from {}.{}", schema, table))?;

        Ok(deleted)
    }
}

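/// Builds a multi-row upsert of the form
/// `INSERT INTO "s"."t" (...) VALUES ($1, ...), ... ON CONFLICT (pk) DO UPDATE SET ...`,
/// degrading to `DO NOTHING` when every column belongs to the primary key
/// and there is nothing to update. Identifiers are double-quoted as-is;
/// quotes embedded in names are not escaped here.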
fn build_upsert_query(
    schema: &str,
    table: &str,
    primary_key_columns: &[String],
    all_columns: &[String],
    num_rows: usize,
) -> String {
    let quoted_columns: Vec<String> = all_columns.iter().map(|c| format!("\"{}\"", c)).collect();

    let quoted_pk_columns: Vec<String> = primary_key_columns
        .iter()
        .map(|c| format!("\"{}\"", c))
        .collect();

    let num_cols = all_columns.len();
    let value_rows: Vec<String> = (0..num_rows)
        .map(|row_idx| {
            let placeholders: Vec<String> = (0..num_cols)
                .map(|col_idx| format!("${}", row_idx * num_cols + col_idx + 1))
                .collect();
            format!("({})", placeholders.join(", "))
        })
        .collect();

    let update_columns: Vec<String> = all_columns
        .iter()
        .filter(|c| !primary_key_columns.contains(c))
        .map(|c| format!("\"{}\" = EXCLUDED.\"{}\"", c, c))
        .collect();

    let update_clause = if update_columns.is_empty() {
        "DO NOTHING".to_string()
    } else {
        format!("DO UPDATE SET {}", update_columns.join(", "))
    };

    format!(
        "INSERT INTO \"{}\".\"{}\" ({}) VALUES {} ON CONFLICT ({}) {}",
        schema,
        table,
        quoted_columns.join(", "),
        value_rows.join(", "),
        quoted_pk_columns.join(", "),
        update_clause
    )
}

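/// Builds a batched `DELETE`: `WHERE pk IN ($1, ...)` for a single-column
/// key, or a row-value list `WHERE (a, b) IN (($1, $2), ...)` for a
/// composite key.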
fn build_delete_query(
    schema: &str,
    table: &str,
    primary_key_columns: &[String],
    num_rows: usize,
) -> String {
    let num_pk_cols = primary_key_columns.len();

    if num_pk_cols == 1 {
        let pk_col = format!("\"{}\"", primary_key_columns[0]);
        let placeholders: Vec<String> = (1..=num_rows).map(|i| format!("${}", i)).collect();

        format!(
            "DELETE FROM \"{}\".\"{}\" WHERE {} IN ({})",
            schema,
            table,
            pk_col,
            placeholders.join(", ")
        )
    } else {
        let pk_cols: Vec<String> = primary_key_columns
            .iter()
            .map(|c| format!("\"{}\"", c))
            .collect();

        let value_tuples: Vec<String> = (0..num_rows)
            .map(|row_idx| {
                let placeholders: Vec<String> = (0..num_pk_cols)
                    .map(|col_idx| format!("${}", row_idx * num_pk_cols + col_idx + 1))
                    .collect();
                format!("({})", placeholders.join(", "))
            })
            .collect();

        format!(
            "DELETE FROM \"{}\".\"{}\" WHERE ({}) IN ({})",
            schema,
            table,
            pk_cols.join(", "),
            value_tuples.join(", ")
        )
    }
}

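/// Returns `(column_name, udt_name)` pairs for `schema.table` in ordinal
/// order, via `information_schema.columns`. The `udt_name` values (e.g.
/// `int4`, or `_text` for a text array) are what `row_to_values` matches on.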
pub async fn get_table_columns(
    client: &Client,
    schema: &str,
    table: &str,
) -> Result<Vec<(String, String)>> {
    let rows = client
        .query(
            "SELECT column_name, udt_name
             FROM information_schema.columns
             WHERE table_schema = $1 AND table_name = $2
             ORDER BY ordinal_position",
            &[&schema, &table],
        )
        .await
        .with_context(|| format!("Failed to get columns for {}.{}", schema, table))?;

    Ok(rows
        .iter()
        .map(|row| {
            let name: String = row.get(0);
            let dtype: String = row.get(1);
            (name, dtype)
        })
        .collect())
}

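/// Returns the primary-key column names of `schema.table` in key order,
/// via `pg_index.indisprimary`; returns an empty list if the table has no
/// primary key.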
pub async fn get_primary_key_columns(
    client: &Client,
    schema: &str,
    table: &str,
) -> Result<Vec<String>> {
    let rows = client
        .query(
            "SELECT a.attname
             FROM pg_index i
             JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
             JOIN pg_class c ON c.oid = i.indrelid
             JOIN pg_namespace n ON n.oid = c.relnamespace
             WHERE i.indisprimary
               AND n.nspname = $1
               AND c.relname = $2
             ORDER BY array_position(i.indkey, a.attnum)",
            &[&schema, &table],
        )
        .await
        .with_context(|| format!("Failed to get primary key for {}.{}", schema, table))?;

    Ok(rows.iter().map(|row| row.get(0)).collect())
}

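/// Converts a source [`Row`] into boxed `ToSql` values, selecting a
/// concrete Rust type from each column's `udt_name`. Types without an
/// explicit arm fall back to a textual read, and values that cannot be
/// read as text are silently written as NULL, so unsupported types degrade
/// rather than fail the whole conversion.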
pub fn row_to_values(
    row: &Row,
    column_types: &[(String, String)],
) -> Vec<Box<dyn ToSql + Sync + Send>> {
    column_types
        .iter()
        .enumerate()
        .map(|(idx, (_name, dtype))| -> Box<dyn ToSql + Sync + Send> {
            match dtype.as_str() {
                "integer" | "int4" => {
                    let val: Option<i32> = row.get(idx);
                    Box::new(val)
                }
                "bigint" | "int8" => {
                    let val: Option<i64> = row.get(idx);
                    Box::new(val)
                }
                "smallint" | "int2" => {
                    let val: Option<i16> = row.get(idx);
                    Box::new(val)
                }
                "text" | "varchar" | "bpchar" | "char" | "character" | "name" | "citext" => {
                    let val: Option<String> = row.get(idx);
                    Box::new(val)
                }
                "boolean" | "bool" => {
                    let val: Option<bool> = row.get(idx);
                    Box::new(val)
                }
                "real" | "float4" => {
                    let val: Option<f32> = row.get(idx);
                    Box::new(val)
                }
                "double precision" | "float8" => {
                    let val: Option<f64> = row.get(idx);
                    Box::new(val)
                }
                "uuid" => {
                    let val: Option<uuid::Uuid> = row.get(idx);
                    Box::new(val)
                }
                "timestamp without time zone" | "timestamp" => {
                    let val: Option<chrono::NaiveDateTime> = row.get(idx);
                    Box::new(val)
                }
                "timestamp with time zone" | "timestamptz" => {
                    let val: Option<chrono::DateTime<chrono::Utc>> = row.get(idx);
                    Box::new(val)
                }
                "date" => {
                    let val: Option<chrono::NaiveDate> = row.get(idx);
                    Box::new(val)
                }
                "json" | "jsonb" => {
                    let val: Option<serde_json::Value> = row.get(idx);
                    Box::new(val)
                }
                "bytea" => {
                    let val: Option<Vec<u8>> = row.get(idx);
                    Box::new(val)
                }
                "numeric" | "decimal" => {
                    let val: Option<Decimal> = row.get(idx);
                    Box::new(val)
                }
                // Array types: udt_name prefixes the element type with "_".
                "_text" | "_varchar" | "_bpchar" | "_citext" => {
                    let val: Option<Vec<String>> = row.get(idx);
                    Box::new(val)
                }
                "_int4" => {
                    let val: Option<Vec<i32>> = row.get(idx);
                    Box::new(val)
                }
                "_int8" => {
                    let val: Option<Vec<i64>> = row.get(idx);
                    Box::new(val)
                }
                "_int2" => {
                    let val: Option<Vec<i16>> = row.get(idx);
                    Box::new(val)
                }
                "_float4" => {
                    let val: Option<Vec<f32>> = row.get(idx);
                    Box::new(val)
                }
                "_float8" => {
                    let val: Option<Vec<f64>> = row.get(idx);
                    Box::new(val)
                }
                "_bool" => {
                    let val: Option<Vec<bool>> = row.get(idx);
                    Box::new(val)
                }
                "_uuid" => {
                    let val: Option<Vec<uuid::Uuid>> = row.get(idx);
                    Box::new(val)
                }
                "_bytea" => {
                    let val: Option<Vec<Vec<u8>>> = row.get(idx);
                    Box::new(val)
                }
                "_numeric" => {
                    let val: Option<Vec<Decimal>> = row.get(idx);
                    Box::new(val)
                }
                "_jsonb" | "_json" => {
                    let val: Option<Vec<serde_json::Value>> = row.get(idx);
                    Box::new(val)
                }
                "_timestamp" => {
                    let val: Option<Vec<chrono::NaiveDateTime>> = row.get(idx);
                    Box::new(val)
                }
                "_timestamptz" => {
                    let val: Option<Vec<chrono::DateTime<chrono::Utc>>> = row.get(idx);
                    Box::new(val)
                }
                "_date" => {
                    let val: Option<Vec<chrono::NaiveDate>> = row.get(idx);
                    Box::new(val)
                }
                _ => {
                    // Fallback: read as text; NULLs and values that cannot
                    // be decoded as text both become None (written as NULL).
                    let val: Option<String> = row.try_get::<_, String>(idx).ok();
                    Box::new(val)
                }
            }
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_build_upsert_query_single_row() {
        let query = build_upsert_query(
            "public",
            "users",
            &["id".to_string()],
            &["id".to_string(), "name".to_string(), "email".to_string()],
            1,
        );

        assert!(query.contains("INSERT INTO \"public\".\"users\""));
        assert!(query.contains("(\"id\", \"name\", \"email\")"));
        assert!(query.contains("VALUES ($1, $2, $3)"));
        assert!(query.contains("ON CONFLICT (\"id\")"));
        assert!(query.contains("DO UPDATE SET"));
        assert!(query.contains("\"name\" = EXCLUDED.\"name\""));
        assert!(query.contains("\"email\" = EXCLUDED.\"email\""));
    }

    #[test]
    fn test_build_upsert_query_multiple_rows() {
        let query = build_upsert_query(
            "public",
            "users",
            &["id".to_string()],
            &["id".to_string(), "name".to_string()],
            3,
        );

        assert!(query.contains("($1, $2), ($3, $4), ($5, $6)"));
    }

    #[test]
    fn test_build_upsert_query_composite_pk() {
        let query = build_upsert_query(
            "public",
            "order_items",
            &["order_id".to_string(), "item_id".to_string()],
            &[
                "order_id".to_string(),
                "item_id".to_string(),
                "quantity".to_string(),
            ],
            1,
        );

        assert!(query.contains("ON CONFLICT (\"order_id\", \"item_id\")"));
        assert!(query.contains("\"quantity\" = EXCLUDED.\"quantity\""));
    }

    #[test]
    fn test_build_upsert_query_all_pk_columns() {
        let query = build_upsert_query(
            "public",
            "tags",
            &["id".to_string()],
            &["id".to_string()],
            1,
        );

        assert!(query.contains("DO NOTHING"));
        assert!(!query.contains("DO UPDATE SET"));
    }

    #[test]
    fn test_build_delete_query_single_pk() {
        let query = build_delete_query("public", "users", &["id".to_string()], 3);

        assert!(query.contains("DELETE FROM \"public\".\"users\""));
        assert!(query.contains("WHERE \"id\" IN ($1, $2, $3)"));
    }

    #[test]
    fn test_build_delete_query_composite_pk() {
        let query = build_delete_query(
            "public",
            "order_items",
            &["order_id".to_string(), "item_id".to_string()],
            2,
        );

        assert!(query.contains("WHERE (\"order_id\", \"item_id\") IN"));
        assert!(query.contains("($1, $2), ($3, $4)"));
    }
}