1use crate::connection::{
2 AsyncConnection, BulkInsert, ConnectOptions, ExecutionSummary, ForeignKey, QueryResult,
3 SchemaInfo, StatementResult,
4};
5use crate::error::SqlError;
6use crate::stream::BoxRowStream;
7use crate::url::DatabaseUrl;
8use crate::value::{ColumnInfo, Row, TypeHint, Value};
9use async_trait::async_trait;
10use bytes::Bytes;
11use futures_util::sink::SinkExt;
12use secrecy::ExposeSecret;
13use std::sync::Arc;
14use tokio_postgres::types::Type;
15
/// PostgreSQL implementation of `AsyncConnection`, wrapping a `tokio_postgres` client.
///
/// Only the client half is stored here: `connect` / `connect_with_stream`
/// spawn the matching connection future onto the tokio runtime.
pub struct PostgresConnection {
    /// Handle used to issue every query, COPY and catalog lookup.
    client: tokio_postgres::Client,
}
19
20#[async_trait]
21impl AsyncConnection for PostgresConnection {
22 async fn execute(&mut self, sql: &str) -> Result<ExecutionSummary, SqlError> {
23 let rows_affected = self
24 .client
25 .execute(sql, &[])
26 .await
27 .map_err(|e| SqlError::QueryFailed(e.to_string()))?;
28 Ok(ExecutionSummary {
29 rows_affected: Some(rows_affected),
30 command_tag: None,
31 })
32 }
33
34 async fn query(&mut self, sql: &str) -> Result<QueryResult, SqlError> {
35 let rows = self
36 .client
37 .query(sql, &[])
38 .await
39 .map_err(|e| SqlError::QueryFailed(e.to_string()))?;
40 if rows.is_empty() {
41 return Ok(QueryResult {
42 columns: Vec::new(),
43 rows: Vec::new(),
44 });
45 }
46 let first = &rows[0];
47 let columns: Vec<ColumnInfo> = first
48 .columns()
49 .iter()
50 .map(|c| ColumnInfo {
51 name: c.name().to_string(),
52 type_hint: pg_type_to_hint(c.type_()),
53 nullable: true,
54 })
55 .collect();
56
57 let data_rows: Vec<Row> = rows
58 .iter()
59 .map(|row| {
60 (0..columns.len())
61 .map(|i| pg_to_value(row, i, row.columns()[i].type_()))
62 .collect()
63 })
64 .collect();
65
66 Ok(QueryResult {
67 columns,
68 rows: data_rows,
69 })
70 }
71
72 async fn query_stream(
82 &mut self,
83 sql: &str,
84 ) -> Result<(Vec<ColumnInfo>, BoxRowStream<'_>), SqlError> {
85 use futures_util::stream::TryStreamExt;
86 let statement = self
90 .client
91 .prepare(sql)
92 .await
93 .map_err(|e| SqlError::QueryFailed(e.to_string()))?;
94 let columns: Vec<ColumnInfo> = statement
95 .columns()
96 .iter()
97 .map(|c| ColumnInfo {
98 name: c.name().to_string(),
99 type_hint: pg_type_to_hint(c.type_()),
100 nullable: true,
101 })
102 .collect();
103 let ncols = columns.len();
104
105 let params: [&(dyn tokio_postgres::types::ToSql + Sync); 0] = [];
108 let row_stream = self
109 .client
110 .query_raw(&statement, params)
111 .await
112 .map_err(|e| SqlError::QueryFailed(e.to_string()))?;
113
114 let mapped = row_stream
115 .map_ok(move |row| {
116 (0..ncols)
117 .map(|i| pg_to_value(&row, i, row.columns()[i].type_()))
118 .collect::<Row>()
119 })
120 .map_err(|e| SqlError::QueryFailed(e.to_string()));
121 Ok((columns, Box::pin(mapped)))
122 }
123
124 async fn execute_multi(&mut self, sql: &str) -> Result<Vec<StatementResult>, SqlError> {
125 let msgs = self
126 .client
127 .simple_query(sql)
128 .await
129 .map_err(|e| SqlError::QueryFailed(e.to_string()))?;
130
131 let mut results = Vec::new();
132 let mut current_columns: Vec<ColumnInfo> = Vec::new();
133 let mut current_rows: Vec<Row> = Vec::new();
134
135 for msg in msgs {
136 use tokio_postgres::SimpleQueryMessage;
137 match msg {
138 SimpleQueryMessage::Row(row) => {
139 if current_columns.is_empty() {
140 current_columns = row
141 .columns()
142 .iter()
143 .map(|c| ColumnInfo {
144 name: c.name().to_string(),
145 type_hint: TypeHint::Other,
146 nullable: true,
147 })
148 .collect();
149 }
150 let values: Vec<Value> = (0..row.len())
151 .map(|i| match row.get(i) {
152 Some(s) => Value::String(s.to_string()),
153 None => Value::Null,
154 })
155 .collect();
156 current_rows.push(values);
157 }
158 SimpleQueryMessage::CommandComplete(n) => {
159 if !current_columns.is_empty() {
160 results.push(StatementResult::Query(QueryResult {
161 columns: std::mem::take(&mut current_columns),
162 rows: std::mem::take(&mut current_rows),
163 }));
164 } else {
165 results.push(StatementResult::Summary(ExecutionSummary {
166 rows_affected: Some(n),
167 command_tag: None,
168 }));
169 }
170 }
171 _ => {}
172 }
173 }
174
175 if !current_columns.is_empty() {
176 results.push(StatementResult::Query(QueryResult {
177 columns: std::mem::take(&mut current_columns),
178 rows: std::mem::take(&mut current_rows),
179 }));
180 }
181
182 Ok(results)
183 }
184
185 async fn ping(&mut self) -> Result<(), SqlError> {
186 self.client
187 .execute("SELECT 1", &[])
188 .await
189 .map_err(|e| SqlError::ConnectionFailed(e.to_string()))?;
190 Ok(())
191 }
192
193 async fn list_tables(&mut self, schema: Option<&str>) -> Result<Vec<String>, SqlError> {
194 let schema = schema.unwrap_or("public");
195 let rows = self
196 .client
197 .query(
198 "SELECT table_name FROM information_schema.tables WHERE table_schema = $1 AND table_type = 'BASE TABLE' ORDER BY table_name",
199 &[&schema,
200 ],
201 )
202 .await
203 .map_err(|e| SqlError::QueryFailed(e.to_string()))?;
204 let names = rows
205 .into_iter()
206 .map(|row| row.get::<_, String>(0))
207 .collect();
208 Ok(names)
209 }
210
211 async fn list_schemas(&mut self) -> Result<Vec<SchemaInfo>, SqlError> {
212 let rows = self
217 .client
218 .query(
219 "SELECT schema_name, schema_name = current_schema() AS is_default FROM information_schema.schemata ORDER BY schema_name",
220 &[],
221 )
222 .await
223 .map_err(|e| SqlError::QueryFailed(e.to_string()))?;
224 let schemas = rows
225 .into_iter()
226 .map(|row| SchemaInfo {
227 name: row.get::<_, String>(0),
228 is_default: row.try_get::<_, bool>(1).unwrap_or(false),
229 })
230 .collect();
231 Ok(schemas)
232 }
233
234 async fn describe_table(
235 &mut self,
236 schema: Option<&str>,
237 table: &str,
238 ) -> Result<QueryResult, SqlError> {
239 let schema = schema.unwrap_or("public");
240 let rows = self
241 .client
242 .query(
243 "SELECT column_name, data_type, is_nullable, column_default, numeric_precision, numeric_scale FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2 ORDER BY ordinal_position",
244 &[&schema,
245 &table,
246 ],
247 )
248 .await
249 .map_err(|e| SqlError::QueryFailed(e.to_string()))?;
250
251 let columns = vec![
252 ColumnInfo {
253 name: "column_name".to_string(),
254 type_hint: TypeHint::String,
255 nullable: true,
256 },
257 ColumnInfo {
258 name: "data_type".to_string(),
259 type_hint: TypeHint::String,
260 nullable: true,
261 },
262 ColumnInfo {
263 name: "is_nullable".to_string(),
264 type_hint: TypeHint::String,
265 nullable: true,
266 },
267 ColumnInfo {
268 name: "column_default".to_string(),
269 type_hint: TypeHint::String,
270 nullable: true,
271 },
272 ColumnInfo {
273 name: "numeric_precision".to_string(),
274 type_hint: TypeHint::Int64,
275 nullable: true,
276 },
277 ColumnInfo {
278 name: "numeric_scale".to_string(),
279 type_hint: TypeHint::Int64,
280 nullable: true,
281 },
282 ];
283
284 let data_rows: Vec<Row> = rows
285 .iter()
286 .map(|row| {
287 vec![
288 row.try_get::<_, Option<String>>("column_name")
289 .unwrap_or(None)
290 .map(Value::String)
291 .unwrap_or(Value::Null),
292 row.try_get::<_, Option<String>>("data_type")
293 .unwrap_or(None)
294 .map(Value::String)
295 .unwrap_or(Value::Null),
296 row.try_get::<_, Option<String>>("is_nullable")
297 .unwrap_or(None)
298 .map(Value::String)
299 .unwrap_or(Value::Null),
300 row.try_get::<_, Option<String>>("column_default")
301 .unwrap_or(None)
302 .map(Value::String)
303 .unwrap_or(Value::Null),
304 row.try_get::<_, Option<i32>>("numeric_precision")
305 .unwrap_or(None)
306 .map(|v| Value::Int64(i64::from(v)))
307 .unwrap_or(Value::Null),
308 row.try_get::<_, Option<i32>>("numeric_scale")
309 .unwrap_or(None)
310 .map(|v| Value::Int64(i64::from(v)))
311 .unwrap_or(Value::Null),
312 ]
313 })
314 .collect();
315
316 Ok(QueryResult {
317 columns,
318 rows: data_rows,
319 })
320 }
321
322 async fn primary_key(
323 &mut self,
324 schema: Option<&str>,
325 table: &str,
326 ) -> Result<Vec<String>, SqlError> {
327 let schema = schema.unwrap_or("public");
328 let sql = "SELECT a.attname \
331 FROM pg_index i \
332 JOIN pg_class c ON c.oid = i.indrelid \
333 JOIN pg_namespace n ON n.oid = c.relnamespace \
334 JOIN unnest(i.indkey) WITH ORDINALITY AS k(attnum, ord) ON true \
335 JOIN pg_attribute a ON a.attrelid = c.oid AND a.attnum = k.attnum \
336 WHERE i.indisprimary AND n.nspname = $1 AND c.relname = $2 \
337 ORDER BY k.ord";
338 let rows = self
339 .client
340 .query(sql, &[&schema, &table])
341 .await
342 .map_err(|e| SqlError::QueryFailed(e.to_string()))?;
343 Ok(rows.into_iter().map(|r| r.get::<_, String>(0)).collect())
344 }
345
346 async fn list_foreign_keys(
347 &mut self,
348 schema: Option<&str>,
349 ) -> Result<Vec<ForeignKey>, SqlError> {
350 let schema = schema.unwrap_or("public");
351 let sql = "SELECT c.conname, \
354 cl_child.relname AS child_table, \
355 a_child.attname AS child_col, \
356 cl_parent.relname AS parent_table, \
357 a_parent.attname AS parent_col, \
358 c.confdeltype, \
359 k.ord \
360 FROM pg_constraint c \
361 JOIN pg_class cl_child ON cl_child.oid = c.conrelid \
362 JOIN pg_namespace n_child ON n_child.oid = cl_child.relnamespace \
363 JOIN pg_class cl_parent ON cl_parent.oid = c.confrelid \
364 JOIN pg_namespace n_parent ON n_parent.oid = cl_parent.relnamespace \
365 JOIN unnest(c.conkey) WITH ORDINALITY AS k(attnum, ord) ON true \
366 JOIN pg_attribute a_child ON a_child.attrelid = cl_child.oid AND a_child.attnum = k.attnum \
367 JOIN unnest(c.confkey) WITH ORDINALITY AS kp(attnum, ord) ON kp.ord = k.ord \
368 JOIN pg_attribute a_parent ON a_parent.attrelid = cl_parent.oid AND a_parent.attnum = kp.attnum \
369 WHERE c.contype = 'f' AND n_child.nspname = $1 \
370 ORDER BY c.conname, k.ord";
371 let rows = self
372 .client
373 .query(sql, &[&schema])
374 .await
375 .map_err(|e| SqlError::QueryFailed(e.to_string()))?;
376 let mut map: indexmap::IndexMap<String, ForeignKey> = indexmap::IndexMap::new();
377 for row in rows {
378 let conname: String = row.get(0);
379 let child_table: String = row.get(1);
380 let child_col: String = row.get(2);
381 let parent_table: String = row.get(3);
382 let parent_col: String = row.get(4);
383 let confdeltype: i8 = row.get(5);
384 let on_delete = pg_confdeltype(confdeltype);
385 let entry = map.entry(conname).or_insert_with(|| ForeignKey {
386 child_table: child_table.clone(),
387 child_columns: Vec::new(),
388 parent_table: parent_table.clone(),
389 parent_columns: Vec::new(),
390 on_delete,
391 });
392 entry.child_columns.push(child_col);
393 entry.parent_columns.push(parent_col);
394 }
395 Ok(map.into_values().collect())
396 }
397
398 async fn bulk_insert_rows(&mut self, target: BulkInsert<'_>) -> Result<usize, SqlError> {
399 if target.rows.is_empty() {
400 return Ok(0);
401 }
402 let table = crate::copy::quote_identifier(target.table, crate::backend::Backend::Postgres);
407 let cols = target
408 .columns
409 .iter()
410 .map(|c| crate::copy::quote_identifier(&c.name, crate::backend::Backend::Postgres))
411 .collect::<Vec<_>>()
412 .join(", ");
413 match target.copy_format {
414 crate::copy::CopyFormat::Text => {
415 let stmt = format!("COPY {table} ({cols}) FROM STDIN WITH (FORMAT TEXT)");
416 let sink = self
417 .client
418 .copy_in::<_, Bytes>(stmt.as_str())
419 .await
420 .map_err(|e| pg_text_copy::classify_copy_error(&e))?;
421 tokio::pin!(sink);
422
423 let hints: Vec<TypeHint> = target.columns.iter().map(|c| c.type_hint).collect();
428 for row in target.rows {
429 let buf = pg_text_copy::encode_row(row, &hints)?;
430 sink.send(buf)
431 .await
432 .map_err(|e| SqlError::QueryFailed(format!("COPY send: {e}")))?;
433 }
434
435 let rows = sink
436 .as_mut()
437 .finish()
438 .await
439 .map_err(|e| SqlError::QueryFailed(format!("COPY finish: {e}")))?;
440 Ok(rows as usize)
441 }
442 crate::copy::CopyFormat::Binary => {
443 pg_binary_copy::run(&mut self.client, &table, &cols, &target).await
444 }
445 }
446 }
447}
448
449mod pg_text_copy {
466 use crate::error::SqlError;
467 use crate::value::{TypeHint, Value};
468 use bytes::Bytes;
469
470 pub fn encode_row(row: &[Value], hints: &[TypeHint]) -> Result<Bytes, SqlError> {
476 let mut buf = String::with_capacity(row.len() * 12 + 1);
478 for (i, value) in row.iter().enumerate() {
479 if i > 0 {
480 buf.push('\t');
481 }
482 let hint = hints.get(i).copied().unwrap_or(TypeHint::Other);
483 encode_value(&mut buf, value, hint)?;
484 }
485 buf.push('\n');
486 Ok(Bytes::from(buf.into_bytes()))
487 }
488
489 fn encode_value(out: &mut String, v: &Value, hint: TypeHint) -> Result<(), SqlError> {
490 match v {
491 Value::Null => out.push_str("\\N"),
492 Value::Bool(b) => out.push(if *b { 't' } else { 'f' }),
493 Value::Int64(n) => {
494 use std::fmt::Write;
495 let _ = write!(out, "{n}");
496 }
497 Value::Float64(f) => {
498 if f.is_nan() {
499 out.push_str("NaN");
500 } else if f.is_infinite() {
501 out.push_str(if *f > 0.0 { "Infinity" } else { "-Infinity" });
502 } else {
503 use std::fmt::Write;
504 let _ = write!(out, "{f}");
505 }
506 }
507 Value::Decimal(s) => push_escaped(out, s),
508 Value::String(s) => push_escaped(out, s),
509 Value::Bytes(b) => {
510 out.push_str("\\\\x");
511 use std::fmt::Write;
512 for byte in b {
513 let _ = write!(out, "{byte:02x}");
514 }
515 }
516 Value::Date(d) => {
517 use std::fmt::Write;
518 let _ = write!(out, "{d}");
519 }
520 Value::Time(t) => {
521 use std::fmt::Write;
522 let _ = write!(out, "{t}");
523 }
524 Value::DateTime(dt) => {
525 use std::fmt::Write;
529 let _ = write!(out, "{dt}");
530 }
531 Value::DateTimeTz(dt) => {
532 out.push_str(&dt.to_rfc3339());
534 }
535 Value::Json(j) => {
536 let rendered = serde_json::to_string(j)
537 .map_err(|e| SqlError::QueryFailed(format!("PG bulk: JSON serialize: {e}")))?;
538 push_escaped(out, &rendered);
539 }
540 Value::Uuid(s) => push_escaped(out, s),
541 Value::Array(a) => {
542 let _ = hint; let rendered = serde_json::to_string(a)
548 .map_err(|e| SqlError::QueryFailed(format!("PG bulk: array serialize: {e}")))?;
549 push_escaped(out, &rendered);
550 }
551 }
552 Ok(())
553 }
554
555 fn push_escaped(out: &mut String, s: &str) {
558 for ch in s.chars() {
559 match ch {
560 '\\' => out.push_str("\\\\"),
561 '\t' => out.push_str("\\t"),
562 '\n' => out.push_str("\\n"),
563 '\r' => out.push_str("\\r"),
564 '\0' => {
565 out.push_str("\\x00");
569 }
570 other => out.push(other),
571 }
572 }
573 }
574
575 pub fn classify_copy_error(e: &tokio_postgres::Error) -> SqlError {
591 use tokio_postgres::error::SqlState;
592 if let Some(code) = e.code()
593 && *code == SqlState::WRONG_OBJECT_TYPE
594 {
595 return SqlError::BulkUnavailable(format!("PG rejected COPY: {e}"));
596 }
597 SqlError::QueryFailed(format!("COPY setup: {e}"))
598 }
599
600 #[cfg(test)]
601 mod tests {
602 use super::*;
603 use chrono::{NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc};
604
605 fn enc1(v: Value, hint: TypeHint) -> String {
606 let bytes = encode_row(&[v], &[hint]).expect("encode_row");
607 let s = std::str::from_utf8(&bytes).unwrap().to_string();
609 assert!(s.ends_with('\n'));
610 s.trim_end_matches('\n').to_string()
611 }
612
613 #[test]
614 fn encode_null_is_backslash_n() {
615 assert_eq!(enc1(Value::Null, TypeHint::Null), "\\N");
616 }
617
618 #[test]
619 fn encode_bool_is_t_or_f() {
620 assert_eq!(enc1(Value::Bool(true), TypeHint::Bool), "t");
621 assert_eq!(enc1(Value::Bool(false), TypeHint::Bool), "f");
622 }
623
624 #[test]
625 fn encode_int_and_float() {
626 assert_eq!(enc1(Value::Int64(42), TypeHint::Int64), "42");
627 assert_eq!(enc1(Value::Int64(-7), TypeHint::Int64), "-7");
628 assert_eq!(enc1(Value::Float64(1.5), TypeHint::Float64), "1.5");
629 }
630
631 #[test]
632 fn encode_float_nan_and_inf() {
633 assert_eq!(enc1(Value::Float64(f64::NAN), TypeHint::Float64), "NaN");
634 assert_eq!(
635 enc1(Value::Float64(f64::INFINITY), TypeHint::Float64),
636 "Infinity"
637 );
638 assert_eq!(
639 enc1(Value::Float64(f64::NEG_INFINITY), TypeHint::Float64),
640 "-Infinity"
641 );
642 }
643
644 #[test]
645 fn encode_string_escapes_backslash_first() {
646 assert_eq!(
651 enc1(Value::String("\\.\n".into()), TypeHint::String),
652 "\\\\.\\n"
653 );
654 }
655
656 #[test]
657 fn encode_string_escapes_tab_cr_lf() {
658 assert_eq!(
659 enc1(Value::String("a\tb\nc\rd".into()), TypeHint::String),
660 "a\\tb\\nc\\rd"
661 );
662 }
663
664 #[test]
665 fn encode_string_passes_through_normal_chars() {
666 assert_eq!(
667 enc1(Value::String("héllo, world 🐈".into()), TypeHint::String),
668 "héllo, world 🐈"
669 );
670 }
671
672 #[test]
673 fn encode_string_replaces_nul() {
674 assert_eq!(
678 enc1(Value::String("a\0b".into()), TypeHint::String),
679 "a\\x00b"
680 );
681 }
682
683 #[test]
684 fn encode_bytes_is_hex_with_double_backslash_x() {
685 assert_eq!(
690 enc1(Value::Bytes(vec![0xDE, 0xAD, 0xBE, 0xEF]), TypeHint::Bytes),
691 "\\\\xdeadbeef"
692 );
693 }
694
695 #[test]
696 fn encode_date_time_datetime() {
697 let d = NaiveDate::from_ymd_opt(2026, 5, 14).unwrap();
698 let t = NaiveTime::from_hms_opt(12, 34, 56).unwrap();
699 let dt = NaiveDateTime::new(d, t);
700 assert_eq!(enc1(Value::Date(d), TypeHint::Date), "2026-05-14");
701 assert_eq!(enc1(Value::Time(t), TypeHint::Time), "12:34:56");
702 assert_eq!(
703 enc1(Value::DateTime(dt), TypeHint::DateTime),
704 "2026-05-14 12:34:56"
705 );
706 }
707
708 #[test]
709 fn encode_datetimetz_is_rfc3339() {
710 let dt = Utc.with_ymd_and_hms(2026, 5, 14, 12, 34, 56).unwrap();
711 assert_eq!(
712 enc1(Value::DateTimeTz(dt), TypeHint::DateTimeTz),
713 "2026-05-14T12:34:56+00:00"
714 );
715 }
716
717 #[test]
718 fn encode_json_is_compact_with_escapes() {
719 let j = serde_json::json!({"role": "admin", "active": true});
720 let encoded = enc1(Value::Json(j), TypeHint::Json);
722 assert!(encoded.contains("\"role\":\"admin\""));
726 assert!(encoded.contains("\"active\":true"));
727 }
728
729 #[test]
730 fn encode_uuid_passes_through() {
731 assert_eq!(
732 enc1(
733 Value::Uuid("550e8400-e29b-41d4-a716-446655440000".into()),
734 TypeHint::Uuid
735 ),
736 "550e8400-e29b-41d4-a716-446655440000"
737 );
738 }
739
740 #[test]
741 fn encode_array_is_compact_json() {
742 let a = Value::Array(vec![Value::Int64(1), Value::Int64(2), Value::Int64(3)]);
743 assert_eq!(enc1(a, TypeHint::Array), "[1,2,3]");
744 }
745
746 #[test]
747 fn encode_decimal_passes_through_with_escapes() {
748 assert_eq!(
749 enc1(Value::Decimal("99.5".into()), TypeHint::Decimal),
750 "99.5"
751 );
752 }
753
754 #[test]
755 fn encode_row_with_multiple_cells_uses_tab_separator() {
756 let row = vec![
757 Value::Int64(1),
758 Value::String("Alice".into()),
759 Value::Null,
760 Value::Bool(true),
761 ];
762 let hints = vec![
763 TypeHint::Int64,
764 TypeHint::String,
765 TypeHint::Null,
766 TypeHint::Bool,
767 ];
768 let bytes = encode_row(&row, &hints).unwrap();
769 assert_eq!(std::str::from_utf8(&bytes).unwrap(), "1\tAlice\t\\N\tt\n");
770 }
771
772 #[test]
773 fn encode_row_empty_row_is_just_newline() {
774 let bytes = encode_row(&[], &[]).unwrap();
779 assert_eq!(std::str::from_utf8(&bytes).unwrap(), "\n");
780 }
781 }
782}
783
784mod pg_binary_copy {
800 use super::pg_text_copy;
801 use crate::connection::BulkInsert;
802 use crate::error::SqlError;
803 use crate::value::{TypeHint, Value};
804 use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
805 use rust_decimal::Decimal;
806 use std::str::FromStr;
807 use tokio_postgres::Client;
808 use tokio_postgres::binary_copy::BinaryCopyInWriter;
809 use tokio_postgres::types::{ToSql, Type};
810 use uuid::Uuid;
811
812 pub async fn run(
816 client: &mut Client,
817 table: &str,
818 cols: &str,
819 target: &BulkInsert<'_>,
820 ) -> Result<usize, SqlError> {
821 let types: Vec<Type> = target
822 .columns
823 .iter()
824 .map(|c| pg_type_for_hint(c.type_hint))
825 .collect::<Result<_, _>>()?;
826
827 let stmt = format!("COPY {table} ({cols}) FROM STDIN WITH (FORMAT BINARY)");
828 let sink = client
829 .copy_in::<_, bytes::Bytes>(stmt.as_str())
830 .await
831 .map_err(|e| pg_text_copy::classify_copy_error(&e))?;
832 let writer = BinaryCopyInWriter::new(sink, &types);
833 tokio::pin!(writer);
834
835 let hints: Vec<TypeHint> = target.columns.iter().map(|c| c.type_hint).collect();
836 for row in target.rows {
837 let cells: Vec<PgBinaryBind> = row
841 .iter()
842 .zip(hints.iter())
843 .map(|(v, h)| value_to_pg_binary_bind(v, *h))
844 .collect::<Result<_, _>>()?;
845 let refs: Vec<&(dyn ToSql + Sync)> =
846 cells.iter().map(PgBinaryBind::as_to_sql).collect();
847 writer
848 .as_mut()
849 .write(&refs)
850 .await
851 .map_err(|e| SqlError::QueryFailed(format!("BINARY COPY write: {e}")))?;
852 }
853
854 let rows = writer
855 .as_mut()
856 .finish()
857 .await
858 .map_err(|e| SqlError::QueryFailed(format!("BINARY COPY finish: {e}")))?;
859 Ok(rows as usize)
860 }
861
862 pub(super) fn pg_type_for_hint(hint: TypeHint) -> Result<Type, SqlError> {
866 Ok(match hint {
867 TypeHint::Bool => Type::BOOL,
868 TypeHint::Int64 => Type::INT8,
869 TypeHint::Float64 => Type::FLOAT8,
870 TypeHint::Decimal => Type::NUMERIC,
871 TypeHint::String => Type::TEXT,
872 TypeHint::Bytes => Type::BYTEA,
873 TypeHint::Date => Type::DATE,
874 TypeHint::Time => Type::TIME,
875 TypeHint::DateTime => Type::TIMESTAMP,
876 TypeHint::DateTimeTz => Type::TIMESTAMPTZ,
877 TypeHint::Json => Type::JSONB,
878 TypeHint::Uuid => Type::UUID,
879 TypeHint::Array => Type::JSONB,
880 TypeHint::Null | TypeHint::Other => {
881 return Err(SqlError::BulkUnavailable(format!(
882 "PG binary COPY: cannot bind a column with TypeHint::{hint:?} \
883 (no concrete PG type to declare). Re-run with \
884 --copy-format text or --bulk-native off."
885 )));
886 }
887 })
888 }
889
890 #[derive(Debug)]
894 pub(super) enum PgBinaryBind {
895 Bool(Option<bool>),
896 Int8(Option<i64>),
897 Float8(Option<f64>),
898 Numeric(Option<Decimal>),
899 Text(Option<String>),
900 Bytea(Option<Vec<u8>>),
901 Date(Option<NaiveDate>),
902 Time(Option<NaiveTime>),
903 Timestamp(Option<NaiveDateTime>),
904 TimestampTz(Option<DateTime<Utc>>),
905 Json(Option<serde_json::Value>),
906 Uuid(Option<Uuid>),
907 }
908
909 impl PgBinaryBind {
910 pub(super) fn as_to_sql(&self) -> &(dyn ToSql + Sync) {
911 match self {
912 Self::Bool(v) => v,
913 Self::Int8(v) => v,
914 Self::Float8(v) => v,
915 Self::Numeric(v) => v,
916 Self::Text(v) => v,
917 Self::Bytea(v) => v,
918 Self::Date(v) => v,
919 Self::Time(v) => v,
920 Self::Timestamp(v) => v,
921 Self::TimestampTz(v) => v,
922 Self::Json(v) => v,
923 Self::Uuid(v) => v,
924 }
925 }
926 }
927
928 pub(super) fn value_to_pg_binary_bind(
933 v: &Value,
934 hint: TypeHint,
935 ) -> Result<PgBinaryBind, SqlError> {
936 Ok(match (v, hint) {
937 (Value::Null, _) => null_bind_for_hint(hint)?,
938 (Value::Bool(b), _) => PgBinaryBind::Bool(Some(*b)),
939 (Value::Int64(n), _) => PgBinaryBind::Int8(Some(*n)),
940 (Value::Float64(f), _) => PgBinaryBind::Float8(Some(*f)),
941 (Value::Decimal(s), _) => PgBinaryBind::Numeric(Some(parse_decimal(s)?)),
942 (Value::String(s), TypeHint::Uuid) => {
943 PgBinaryBind::Uuid(Some(Uuid::parse_str(s).map_err(|e| {
944 SqlError::QueryFailed(format!("PG binary COPY: bad UUID '{s}': {e}"))
945 })?))
946 }
947 (Value::String(s), _) => PgBinaryBind::Text(Some(s.clone())),
948 (Value::Bytes(b), _) => PgBinaryBind::Bytea(Some(b.clone())),
949 (Value::Date(d), _) => PgBinaryBind::Date(Some(*d)),
950 (Value::Time(t), _) => PgBinaryBind::Time(Some(*t)),
951 (Value::DateTime(dt), _) => PgBinaryBind::Timestamp(Some(*dt)),
952 (Value::DateTimeTz(dt), _) => PgBinaryBind::TimestampTz(Some(*dt)),
953 (Value::Json(j), _) => PgBinaryBind::Json(Some(j.clone())),
954 (Value::Array(arr), _) => {
955 let json = serde_json::to_value(arr).map_err(|e| {
960 SqlError::QueryFailed(format!("PG binary COPY: array serialize: {e}"))
961 })?;
962 PgBinaryBind::Json(Some(json))
963 }
964 (Value::Uuid(s), _) => PgBinaryBind::Uuid(Some(Uuid::parse_str(s).map_err(|e| {
965 SqlError::QueryFailed(format!("PG binary COPY: bad UUID '{s}': {e}"))
966 })?)),
967 })
968 }
969
970 fn null_bind_for_hint(hint: TypeHint) -> Result<PgBinaryBind, SqlError> {
971 Ok(match hint {
972 TypeHint::Bool => PgBinaryBind::Bool(None),
973 TypeHint::Int64 => PgBinaryBind::Int8(None),
974 TypeHint::Float64 => PgBinaryBind::Float8(None),
975 TypeHint::Decimal => PgBinaryBind::Numeric(None),
976 TypeHint::String => PgBinaryBind::Text(None),
977 TypeHint::Bytes => PgBinaryBind::Bytea(None),
978 TypeHint::Date => PgBinaryBind::Date(None),
979 TypeHint::Time => PgBinaryBind::Time(None),
980 TypeHint::DateTime => PgBinaryBind::Timestamp(None),
981 TypeHint::DateTimeTz => PgBinaryBind::TimestampTz(None),
982 TypeHint::Json | TypeHint::Array => PgBinaryBind::Json(None),
983 TypeHint::Uuid => PgBinaryBind::Uuid(None),
984 TypeHint::Null | TypeHint::Other => {
985 return Err(SqlError::BulkUnavailable(format!(
986 "PG binary COPY: cannot type-encode NULL for TypeHint::{hint:?}"
987 )));
988 }
989 })
990 }
991
992 fn parse_decimal(s: &str) -> Result<Decimal, SqlError> {
993 Decimal::from_str(s).map_err(|e| {
994 SqlError::QueryFailed(format!("PG binary COPY: invalid NUMERIC '{s}': {e}"))
995 })
996 }
997
998 #[cfg(test)]
999 mod tests {
1000 use super::*;
1001
1002 #[test]
1003 fn pg_type_for_hint_maps_canonical_dest_types() {
1004 assert_eq!(pg_type_for_hint(TypeHint::Bool).unwrap(), Type::BOOL);
1005 assert_eq!(pg_type_for_hint(TypeHint::Int64).unwrap(), Type::INT8);
1006 assert_eq!(pg_type_for_hint(TypeHint::Float64).unwrap(), Type::FLOAT8);
1007 assert_eq!(pg_type_for_hint(TypeHint::Decimal).unwrap(), Type::NUMERIC);
1008 assert_eq!(pg_type_for_hint(TypeHint::String).unwrap(), Type::TEXT);
1009 assert_eq!(pg_type_for_hint(TypeHint::Bytes).unwrap(), Type::BYTEA);
1010 assert_eq!(pg_type_for_hint(TypeHint::Date).unwrap(), Type::DATE);
1011 assert_eq!(pg_type_for_hint(TypeHint::Time).unwrap(), Type::TIME);
1012 assert_eq!(
1013 pg_type_for_hint(TypeHint::DateTime).unwrap(),
1014 Type::TIMESTAMP
1015 );
1016 assert_eq!(
1017 pg_type_for_hint(TypeHint::DateTimeTz).unwrap(),
1018 Type::TIMESTAMPTZ
1019 );
1020 assert_eq!(pg_type_for_hint(TypeHint::Json).unwrap(), Type::JSONB);
1021 assert_eq!(pg_type_for_hint(TypeHint::Uuid).unwrap(), Type::UUID);
1022 assert_eq!(pg_type_for_hint(TypeHint::Array).unwrap(), Type::JSONB);
1023 }
1024
1025 #[test]
1026 fn pg_type_for_hint_other_falls_back_via_bulk_unavailable() {
1027 let err = pg_type_for_hint(TypeHint::Other).unwrap_err();
1028 assert!(matches!(err, SqlError::BulkUnavailable(_)));
1029 let err = pg_type_for_hint(TypeHint::Null).unwrap_err();
1030 assert!(matches!(err, SqlError::BulkUnavailable(_)));
1031 }
1032
1033 #[test]
1034 fn null_bind_picks_typed_none_per_hint() {
1035 assert!(matches!(
1036 null_bind_for_hint(TypeHint::Bool).unwrap(),
1037 PgBinaryBind::Bool(None)
1038 ));
1039 assert!(matches!(
1040 null_bind_for_hint(TypeHint::Int64).unwrap(),
1041 PgBinaryBind::Int8(None)
1042 ));
1043 assert!(matches!(
1044 null_bind_for_hint(TypeHint::Json).unwrap(),
1045 PgBinaryBind::Json(None)
1046 ));
1047 assert!(matches!(
1048 null_bind_for_hint(TypeHint::Uuid).unwrap(),
1049 PgBinaryBind::Uuid(None)
1050 ));
1051 }
1052
1053 #[test]
1054 fn null_bind_array_collapses_to_json_none() {
1055 assert!(matches!(
1057 null_bind_for_hint(TypeHint::Array).unwrap(),
1058 PgBinaryBind::Json(None)
1059 ));
1060 }
1061
1062 #[test]
1063 fn value_to_bind_routes_int_bool_string_null() {
1064 assert!(matches!(
1065 value_to_pg_binary_bind(&Value::Int64(42), TypeHint::Int64).unwrap(),
1066 PgBinaryBind::Int8(Some(42))
1067 ));
1068 assert!(matches!(
1069 value_to_pg_binary_bind(&Value::Bool(true), TypeHint::Bool).unwrap(),
1070 PgBinaryBind::Bool(Some(true))
1071 ));
1072 match value_to_pg_binary_bind(&Value::String("hi".into()), TypeHint::String).unwrap() {
1073 PgBinaryBind::Text(Some(s)) => assert_eq!(s, "hi"),
1074 _ => panic!("expected Text"),
1075 }
1076 assert!(matches!(
1077 value_to_pg_binary_bind(&Value::Null, TypeHint::Int64).unwrap(),
1078 PgBinaryBind::Int8(None)
1079 ));
1080 }
1081
1082 #[test]
1083 fn value_to_bind_decimal_roundtrips_through_rust_decimal() {
1084 match value_to_pg_binary_bind(&Value::Decimal("99.5".into()), TypeHint::Decimal)
1085 .unwrap()
1086 {
1087 PgBinaryBind::Numeric(Some(d)) => assert_eq!(d.to_string(), "99.5"),
1088 _ => panic!("expected Numeric"),
1089 }
1090 let err =
1094 value_to_pg_binary_bind(&Value::Decimal("not-a-number".into()), TypeHint::Decimal)
1095 .unwrap_err();
1096 assert!(matches!(err, SqlError::QueryFailed(_)));
1097 }
1098
1099 #[test]
1100 fn value_to_bind_string_to_uuid_when_dest_is_uuid() {
1101 let bind = value_to_pg_binary_bind(
1102 &Value::String("00112233-4455-6677-8899-aabbccddeeff".into()),
1103 TypeHint::Uuid,
1104 )
1105 .unwrap();
1106 match bind {
1107 PgBinaryBind::Uuid(Some(u)) => {
1108 assert_eq!(u.to_string(), "00112233-4455-6677-8899-aabbccddeeff")
1109 }
1110 _ => panic!("expected Uuid"),
1111 }
1112 }
1113
1114 #[test]
1115 fn value_to_bind_array_collapses_to_json() {
1116 let arr = vec![Value::String("a".into()), Value::String("b".into())];
1117 let bind = value_to_pg_binary_bind(&Value::Array(arr), TypeHint::Array).unwrap();
1118 match bind {
1119 PgBinaryBind::Json(Some(v)) => {
1120 assert_eq!(v, serde_json::json!(["a", "b"]));
1121 }
1122 _ => panic!("expected Json"),
1123 }
1124 }
1125 }
1126}
1127
1128pub(crate) async fn connect(
1129 url: &DatabaseUrl,
1130 opts: &ConnectOptions,
1131) -> Result<PostgresConnection, SqlError> {
1132 let mut config = match url.raw().parse::<tokio_postgres::Config>() {
1133 Ok(cfg) => cfg,
1134 Err(_) => build_config_from_url(url)?,
1135 };
1136 if let Some(pwd) = opts.effective_password(url) {
1138 config.password(pwd.expose_secret());
1139 }
1140
1141 let tls_connector = build_tls_connector(opts)
1142 .await
1143 .map_err(SqlError::TlsError)?;
1144
1145 let (client, connection) = config
1146 .connect(tls_connector)
1147 .await
1148 .map_err(|e| SqlError::ConnectionFailed(e.to_string()))?;
1149
1150 tokio::spawn(async move {
1151 if let Err(e) = connection.await {
1152 eprintln!("[ferrule] Postgres background connection error: {}", e);
1153 }
1154 });
1155
1156 Ok(PostgresConnection { client })
1157}
1158
1159pub(crate) async fn connect_with_stream<S>(
1169 url: &DatabaseUrl,
1170 opts: &ConnectOptions,
1171 stream: S,
1172) -> Result<PostgresConnection, SqlError>
1173where
1174 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
1175{
1176 use tokio_postgres::tls::MakeTlsConnect;
1177
1178 let mut config = match url.raw().parse::<tokio_postgres::Config>() {
1179 Ok(cfg) => cfg,
1180 Err(_) => build_config_from_url(url)?,
1181 };
1182 if let Some(pwd) = opts.effective_password(url) {
1184 config.password(pwd.expose_secret());
1185 }
1186
1187 let mut make_tls = build_tls_connector(opts)
1188 .await
1189 .map_err(SqlError::TlsError)?;
1190 let hostname = url.host().unwrap_or("localhost");
1191 let tls = <tokio_postgres_rustls::MakeRustlsConnect as MakeTlsConnect<S>>::make_tls_connect(
1192 &mut make_tls,
1193 hostname,
1194 )
1195 .map_err(|e| SqlError::TlsError(format!("make_tls_connect failed: {e:?}")))?;
1196
1197 let (client, connection) = config
1198 .connect_raw(stream, tls)
1199 .await
1200 .map_err(|e| SqlError::ConnectionFailed(e.to_string()))?;
1201
1202 tokio::spawn(async move {
1203 if let Err(e) = connection.await {
1204 eprintln!("[ferrule] Postgres background connection error: {}", e);
1205 }
1206 });
1207
1208 Ok(PostgresConnection { client })
1209}
1210
1211fn build_config_from_url(url: &DatabaseUrl) -> Result<tokio_postgres::Config, SqlError> {
1212 let mut config = tokio_postgres::Config::new();
1213 if let Some(host) = url.host() {
1214 config.host(host);
1215 } else {
1216 config.host("localhost");
1217 }
1218 config.port(url.port().unwrap_or(5432));
1219 if !url.username().is_empty() {
1220 config.user(url.username());
1221 }
1222 if let Some(pwd) = url.password() {
1223 config.password(pwd.expose_secret());
1224 }
1225 if !url.database().is_empty() {
1226 config.dbname(url.database());
1227 }
1228 Ok(config)
1229}
1230
1231async fn build_tls_connector(
1232 opts: &ConnectOptions,
1233) -> Result<tokio_postgres_rustls::MakeRustlsConnect, String> {
1234 use rustls::{ClientConfig, RootCertStore};
1235
1236 let mut root_store = RootCertStore::empty();
1237 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
1238
1239 let config = if opts.insecure {
1240 let verifier = Arc::new(InsecureVerifier);
1241 ClientConfig::builder()
1242 .dangerous()
1243 .with_custom_certificate_verifier(verifier)
1244 .with_no_client_auth()
1245 } else {
1246 ClientConfig::builder()
1247 .with_root_certificates(root_store)
1248 .with_no_client_auth()
1249 };
1250
1251 Ok(tokio_postgres_rustls::MakeRustlsConnect::new(config))
1252}
1253
/// rustls certificate verifier that trusts everything.
///
/// `build_tls_connector` installs it when `opts.insecure` is set. Every server
/// certificate and every handshake signature is accepted without checking, so
/// the connection is encrypted but the server's identity is NOT authenticated.
/// This should only ever be reachable behind an explicit opt-in.
#[derive(Debug)]
struct InsecureVerifier;

impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
    /// Accepts any certificate. No chain, host name, OCSP or validity-period
    /// check is performed; all inputs are ignored.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature without verifying it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature without verifying it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Signature schemes this verifier claims to handle. rustls treats the list
    /// as priority order. Because the methods above never verify anything, the
    /// list only influences what is negotiated with the server.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA256,
            rustls::SignatureScheme::RSA_PSS_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA512,
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::RSA_PKCS1_SHA384,
            rustls::SignatureScheme::RSA_PKCS1_SHA512,
            rustls::SignatureScheme::ED25519,
        ]
    }
}
1303
/// Translates a Postgres referential-action code, as stored in
/// `pg_constraint.confdeltype`, into its SQL keyword.
///
/// Returns `None` for any code other than `a`, `r`, `c`, `n` or `d`.
fn pg_confdeltype(c: i8) -> Option<String> {
    let keyword = match c as u8 {
        b'a' => "NO ACTION",
        b'r' => "RESTRICT",
        b'c' => "CASCADE",
        b'n' => "SET NULL",
        b'd' => "SET DEFAULT",
        _ => return None,
    };
    Some(keyword.to_owned())
}
1315
1316fn pg_type_to_hint(ty: &Type) -> TypeHint {
1317 match ty {
1318 &Type::BOOL => TypeHint::Bool,
1319 &Type::INT2 | &Type::INT4 | &Type::INT8 => TypeHint::Int64,
1320 &Type::FLOAT4 | &Type::FLOAT8 => TypeHint::Float64,
1321 &Type::NUMERIC => TypeHint::Decimal,
1322 &Type::TEXT | &Type::VARCHAR | &Type::BPCHAR | &Type::NAME => TypeHint::String,
1323 &Type::BYTEA => TypeHint::Bytes,
1324 &Type::DATE => TypeHint::Date,
1325 &Type::TIME => TypeHint::Time,
1326 &Type::TIMESTAMP => TypeHint::DateTime,
1327 &Type::TIMESTAMPTZ => TypeHint::DateTimeTz,
1328 &Type::JSON | &Type::JSONB => TypeHint::Json,
1329 &Type::UUID => TypeHint::Uuid,
1330 _ if ty.name().starts_with('_') => TypeHint::Array,
1331 _ => TypeHint::Other,
1332 }
1333}
1334
1335fn pg_to_value(row: &tokio_postgres::Row, col: usize, pg_type: &Type) -> Value {
1336 use tokio_postgres::types::Type;
1337
1338 match pg_type {
1340 &Type::BOOL => row
1341 .try_get::<_, Option<bool>>(col)
1342 .unwrap_or(None)
1343 .map(Value::Bool)
1344 .unwrap_or(Value::Null),
1345 &Type::INT2 => row
1346 .try_get::<_, Option<i16>>(col)
1347 .unwrap_or(None)
1348 .map(|v| Value::Int64(i64::from(v)))
1349 .unwrap_or(Value::Null),
1350 &Type::INT4 => row
1351 .try_get::<_, Option<i32>>(col)
1352 .unwrap_or(None)
1353 .map(|v| Value::Int64(i64::from(v)))
1354 .unwrap_or(Value::Null),
1355 &Type::INT8 => row
1356 .try_get::<_, Option<i64>>(col)
1357 .unwrap_or(None)
1358 .map(Value::Int64)
1359 .unwrap_or(Value::Null),
1360 &Type::FLOAT4 => row
1361 .try_get::<_, Option<f32>>(col)
1362 .unwrap_or(None)
1363 .map(|v| Value::Float64(f64::from(v)))
1364 .unwrap_or(Value::Null),
1365 &Type::FLOAT8 => row
1366 .try_get::<_, Option<f64>>(col)
1367 .unwrap_or(None)
1368 .map(Value::Float64)
1369 .unwrap_or(Value::Null),
1370 &Type::NUMERIC => row
1371 .try_get::<_, Option<rust_decimal::Decimal>>(col)
1372 .unwrap_or(None)
1373 .map(|d| Value::Decimal(d.to_string()))
1374 .unwrap_or(Value::Null),
1375 &Type::TEXT | &Type::VARCHAR | &Type::BPCHAR | &Type::NAME => row
1376 .try_get::<_, Option<String>>(col)
1377 .unwrap_or(None)
1378 .map(Value::String)
1379 .unwrap_or(Value::Null),
1380 &Type::BYTEA => row
1381 .try_get::<_, Option<Vec<u8>>>(col)
1382 .unwrap_or(None)
1383 .map(Value::Bytes)
1384 .unwrap_or(Value::Null),
1385 &Type::DATE => row
1386 .try_get::<_, Option<chrono::NaiveDate>>(col)
1387 .unwrap_or(None)
1388 .map(Value::Date)
1389 .unwrap_or(Value::Null),
1390 &Type::TIME => row
1391 .try_get::<_, Option<chrono::NaiveTime>>(col)
1392 .unwrap_or(None)
1393 .map(Value::Time)
1394 .unwrap_or(Value::Null),
1395 &Type::TIMESTAMP => row
1396 .try_get::<_, Option<chrono::NaiveDateTime>>(col)
1397 .unwrap_or(None)
1398 .map(Value::DateTime)
1399 .unwrap_or(Value::Null),
1400 &Type::TIMESTAMPTZ => row
1401 .try_get::<_, Option<chrono::DateTime<chrono::Utc>>>(col)
1402 .unwrap_or(None)
1403 .map(Value::DateTimeTz)
1404 .unwrap_or(Value::Null),
1405 &Type::JSON | &Type::JSONB => row
1406 .try_get::<_, Option<serde_json::Value>>(col)
1407 .unwrap_or(None)
1408 .map(Value::Json)
1409 .unwrap_or(Value::Null),
1410 &Type::UUID => row
1411 .try_get::<_, Option<uuid::Uuid>>(col)
1412 .unwrap_or(None)
1413 .map(|u| Value::Uuid(u.to_string()))
1414 .unwrap_or(Value::Null),
1415 _ => row
1417 .try_get::<_, Option<String>>(col)
1418 .unwrap_or(None)
1419 .map(Value::String)
1420 .unwrap_or(Value::Null),
1421 }
1422}
1423
1424#[cfg(test)]
1425mod tests {
1426 use super::*;
1427
1428 const TEST_POSTGRES_URL: &str =
1432 "postgres://ferrule:ferrule@127.0.0.1:15432/ferrule?sslmode=disable";
1433
1434 fn try_connect() -> Option<Box<dyn crate::Connection>> {
1435 let url = DatabaseUrl::parse(TEST_POSTGRES_URL).ok()?;
1436 let conn = crate::connect(&url, &ConnectOptions::default(), None).ok()?;
1437 Some(conn)
1438 }
1439
1440 #[test]
1441 fn test_postgres_ping() {
1442 let Some(mut conn) = try_connect() else {
1443 eprintln!("Postgres test container not available, skipping test_postgres_ping");
1444 return;
1445 };
1446 conn.ping().expect("ping should succeed");
1447 }
1448
1449 #[test]
1450 fn test_postgres_query() {
1451 let Some(mut conn) = try_connect() else {
1452 eprintln!("Postgres test container not available, skipping test_postgres_query");
1453 return;
1454 };
1455 let result = conn
1456 .query("SELECT * FROM test_users")
1457 .expect("query should succeed");
1458 assert!(!result.columns.is_empty(), "should have columns");
1459 assert!(!result.rows.is_empty(), "should have rows");
1460 }
1461
1462 #[test]
1463 fn test_postgres_execute() {
1464 let Some(mut conn) = try_connect() else {
1465 eprintln!("Postgres test container not available, skipping test_postgres_execute");
1466 return;
1467 };
1468 let summary = conn
1469 .execute("INSERT INTO test_users (name, age) VALUES ('TestUser', 99)")
1470 .expect("execute should succeed");
1471 assert!(
1472 summary.rows_affected.is_some_and(|n| n > 0),
1473 "should have affected rows"
1474 );
1475 }
1476
1477 #[test]
1478 fn test_postgres_list_tables() {
1479 let Some(mut conn) = try_connect() else {
1480 eprintln!("Postgres test container not available, skipping test_postgres_list_tables");
1481 return;
1482 };
1483 let tables = conn.list_tables(None).expect("list_tables should succeed");
1484 assert!(
1485 tables.contains(&"test_users".to_string()),
1486 "should contain test_users, got: {tables:?}"
1487 );
1488 }
1489
1490 #[test]
1491 fn test_postgres_list_schemas() {
1492 let Some(mut conn) = try_connect() else {
1493 eprintln!("Postgres test container not available, skipping test_postgres_list_schemas");
1494 return;
1495 };
1496 let schemas = conn.list_schemas().expect("list_schemas should succeed");
1497 assert!(
1498 schemas.iter().any(|s| s.name == "public"),
1499 "should contain public, got: {schemas:?}"
1500 );
1501 let defaults = schemas.iter().filter(|s| s.is_default).count();
1502 assert_eq!(
1503 defaults, 1,
1504 "exactly one schema should be flagged is_default, got: {schemas:?}"
1505 );
1506 }
1507
1508 #[test]
1509 fn test_postgres_describe_table() {
1510 let Some(mut conn) = try_connect() else {
1511 eprintln!(
1512 "Postgres test container not available, skipping test_postgres_describe_table"
1513 );
1514 return;
1515 };
1516 let result = conn
1517 .describe_table(None, "test_users")
1518 .expect("describe_table should succeed");
1519 assert_eq!(result.columns.len(), 6, "should return 6 metadata columns");
1520 let col_names: Vec<String> = result.columns.iter().map(|c| c.name.clone()).collect();
1521 assert_eq!(
1522 col_names,
1523 vec![
1524 "column_name",
1525 "data_type",
1526 "is_nullable",
1527 "column_default",
1528 "numeric_precision",
1529 "numeric_scale",
1530 ]
1531 );
1532 assert!(
1534 result.rows.len() >= 6,
1535 "expected at least 6 rows, got {}",
1536 result.rows.len()
1537 );
1538 }
1539
1540 #[test]
1541 fn test_postgres_type_mapping() {
1542 let Some(mut conn) = try_connect() else {
1543 eprintln!("Postgres test container not available, skipping test_postgres_type_mapping");
1544 return;
1545 };
1546 let result = conn
1547 .query(
1548 "SELECT name, age, score, active, meta, uid FROM test_users \
1549 WHERE name = 'Alice'",
1550 )
1551 .expect("query should succeed");
1552 assert_eq!(result.rows.len(), 1, "expected exactly Alice");
1553 let row = &result.rows[0];
1554 assert!(matches!(row[0], Value::String(_)), "name should be String");
1555 assert!(matches!(row[1], Value::Int64(_)), "age should be Int64");
1556 assert!(
1557 matches!(row[2], Value::Decimal(_) | Value::Float64(_)),
1558 "score (NUMERIC) should be Decimal or Float64"
1559 );
1560 assert!(matches!(row[3], Value::Bool(_)), "active should be Bool");
1561 assert!(
1562 matches!(row[4], Value::Json(_)),
1563 "meta (JSONB) should be Json"
1564 );
1565 assert!(matches!(row[5], Value::Uuid(_)), "uid should be Uuid");
1566 }
1567
1568 #[test]
1569 fn test_postgres_timestamptz_mapping() {
1570 let Some(mut conn) = try_connect() else {
1571 eprintln!(
1572 "Postgres test container not available, skipping test_postgres_timestamptz_mapping"
1573 );
1574 return;
1575 };
1576 let result = conn
1577 .query("SELECT created_at FROM test_users WHERE name = 'Alice'")
1578 .expect("query should succeed");
1579 assert_eq!(result.rows.len(), 1);
1580 assert!(
1581 matches!(result.rows[0][0], Value::DateTimeTz(_)),
1582 "created_at (TIMESTAMPTZ) should be DateTimeTz, got {:?}",
1583 result.rows[0][0]
1584 );
1585 }
1586
    /// Round-trips five rows through `bulk_insert_rows` using text-format COPY,
    /// then reads selected rows back to check that escaping and NULLs survive.
    #[test]
    fn test_postgres_bulk_insert_rows_round_trip() {
        let Some(mut conn) = try_connect() else {
            eprintln!(
                "Postgres test container not available, skipping test_postgres_bulk_insert_rows_round_trip"
            );
            return;
        };

        // Per-process scratch table so concurrent test runs do not collide.
        let pid = std::process::id();
        let table = format!("ferrule_bulk_test_{pid}");
        let _ = conn.execute(&format!("DROP TABLE IF EXISTS {table}"));
        conn.execute(&format!(
            "CREATE TABLE {table} (\
             id BIGINT, \
             name TEXT, \
             active BOOLEAN, \
             score DOUBLE PRECISION, \
             meta JSONB, \
             tricky TEXT\
             )"
        ))
        .expect("CREATE TABLE");

        // Column metadata passed to bulk_insert_rows, in table order.
        let columns = vec![
            ColumnInfo {
                name: "id".into(),
                type_hint: TypeHint::Int64,
                nullable: false,
            },
            ColumnInfo {
                name: "name".into(),
                type_hint: TypeHint::String,
                nullable: true,
            },
            ColumnInfo {
                name: "active".into(),
                type_hint: TypeHint::Bool,
                nullable: true,
            },
            ColumnInfo {
                name: "score".into(),
                type_hint: TypeHint::Float64,
                nullable: true,
            },
            ColumnInfo {
                name: "meta".into(),
                type_hint: TypeHint::Json,
                nullable: true,
            },
            ColumnInfo {
                name: "tricky".into(),
                type_hint: TypeHint::String,
                nullable: true,
            },
        ];

        // Row 1: plain values. Row 2: embedded comma.
        // Row 3: backslash, tab and newline in `name`, plus a literal `\.`
        //        (the COPY text end-of-data marker if left unescaped).
        // Row 4: all NULLs except the id.
        // Row 5: infinite float, JSON array and non-ASCII text.
        let rows: Vec<Row> = vec![
            vec![
                Value::Int64(1),
                Value::String("Alice".into()),
                Value::Bool(true),
                Value::Float64(99.5),
                Value::Json(serde_json::json!({"role": "admin"})),
                Value::String("plain".into()),
            ],
            vec![
                Value::Int64(2),
                Value::String("Bob".into()),
                Value::Bool(false),
                Value::Float64(88.25),
                Value::Json(serde_json::json!({"role": "user"})),
                Value::String("comma,sep".into()),
            ],
            vec![
                Value::Int64(3),
                Value::String("Esc\\\t\nape".into()),
                Value::Bool(true),
                Value::Float64(0.0),
                Value::Json(serde_json::Value::Null),
                Value::String("\\.".into()),
            ],
            vec![
                Value::Int64(4),
                Value::Null,
                Value::Null,
                Value::Null,
                Value::Null,
                Value::Null,
            ],
            vec![
                Value::Int64(5),
                Value::String("nan-and-inf".into()),
                Value::Bool(true),
                Value::Float64(f64::INFINITY),
                Value::Json(serde_json::json!([1, 2, 3])),
                Value::String("héllo 🐈".into()),
            ],
        ];

        let n = conn
            .bulk_insert_rows(BulkInsert {
                table: &table,
                columns: &columns,
                rows: &rows,
                copy_format: crate::copy::CopyFormat::Text,
            })
            .expect("bulk_insert_rows");
        assert_eq!(n, 5, "bulk should return rows-accepted = 5");

        let count = conn
            .query(&format!("SELECT count(*)::bigint FROM {table}"))
            .unwrap();
        assert!(matches!(count.rows[0][0], Value::Int64(5)));

        // The escape-heavy row must come back byte-for-byte.
        let r3 = conn
            .query(&format!("SELECT name, tricky FROM {table} WHERE id = 3"))
            .unwrap();
        assert_eq!(r3.rows.len(), 1);
        if let Value::String(name) = &r3.rows[0][0] {
            assert_eq!(
                name, "Esc\\\t\nape",
                "row 3 name should round-trip with raw bytes"
            );
        } else {
            panic!("row 3 name should be String, got {:?}", r3.rows[0][0]);
        }
        if let Value::String(tricky) = &r3.rows[0][1] {
            assert_eq!(
                tricky, "\\.",
                "row 3 tricky should be literal backslash-dot"
            );
        } else {
            panic!("row 3 tricky should be String, got {:?}", r3.rows[0][1]);
        }

        // NULLs written through COPY must read back as SQL NULL.
        let r4 = conn
            .query(&format!("SELECT name, active FROM {table} WHERE id = 4"))
            .unwrap();
        assert!(matches!(r4.rows[0][0], Value::Null));
        assert!(matches!(r4.rows[0][1], Value::Null));

        conn.execute(&format!("DROP TABLE {table}"))
            .expect("DROP TABLE");
    }
1741
1742 #[test]
1743 fn test_postgres_primary_key() {
1744 let Some(mut conn) = try_connect() else {
1745 eprintln!("Postgres test container not available, skipping test_postgres_primary_key");
1746 return;
1747 };
1748 let pk = conn.primary_key(None, "test_users").expect("primary_key");
1750 assert_eq!(pk, vec!["id".to_string()]);
1751 }
1752
1753 #[test]
1754 fn test_postgres_list_foreign_keys() {
1755 let Some(mut conn) = try_connect() else {
1756 eprintln!(
1757 "Postgres test container not available, skipping test_postgres_list_foreign_keys"
1758 );
1759 return;
1760 };
1761 let pid = std::process::id();
1762 let child = format!("ferrule_fk_test_orders_{pid}");
1763 let _ = conn.execute(&format!("DROP TABLE IF EXISTS {child}"));
1764 conn.execute(&format!(
1765 "CREATE TABLE {child} (\
1766 id SERIAL PRIMARY KEY, \
1767 user_id INT REFERENCES test_users(id) ON DELETE CASCADE\
1768 )"
1769 ))
1770 .expect("CREATE TABLE");
1771
1772 let fks = conn.list_foreign_keys(None).expect("list_foreign_keys");
1773 let matching: Vec<_> = fks.iter().filter(|fk| fk.child_table == child).collect();
1774 assert_eq!(matching.len(), 1, "expected 1 FK from {child}, got {fks:?}");
1775 let fk = matching[0];
1776 assert_eq!(fk.child_columns, vec!["user_id".to_string()]);
1777 assert_eq!(fk.parent_table, "test_users");
1778 assert_eq!(fk.parent_columns, vec!["id".to_string()]);
1779 assert_eq!(fk.on_delete.as_deref(), Some("CASCADE"));
1780
1781 conn.execute(&format!("DROP TABLE {child}"))
1782 .expect("DROP TABLE");
1783 }
1784
    /// Postgres-to-Postgres `copy_rows` with the `IfExists::Skip` and then
    /// `IfExists::Upsert` conflict policies.
    ///
    /// The destination starts with a row whose key collides with source id 1.
    /// `Skip` must keep that row and insert id 2; `Upsert` must then overwrite
    /// it with the source values.
    #[test]
    fn test_postgres_copy_skip_then_upsert() {
        use crate::backend::Backend;
        use crate::copy::{CopyOptions, CopySource, IfExists, copy_rows};

        // Two independent connections to the same container: copy source and
        // copy destination.
        let (Some(mut src), Some(mut dst)) = (try_connect(), try_connect()) else {
            eprintln!(
                "Postgres test container not available, skipping test_postgres_copy_skip_then_upsert"
            );
            return;
        };

        let pid = std::process::id();
        let src_table = format!("ferrule_pg_skip_src_{pid}");
        let dst_table = format!("ferrule_pg_skip_dst_{pid}");
        let _ = src.execute(&format!("DROP TABLE IF EXISTS {src_table}"));
        let _ = dst.execute(&format!("DROP TABLE IF EXISTS {dst_table}"));
        src.execute(&format!(
            "CREATE TABLE {src_table} (id INT PRIMARY KEY, name TEXT, val INT)"
        ))
        .expect("CREATE src");
        dst.execute(&format!(
            "CREATE TABLE {dst_table} (id INT PRIMARY KEY, name TEXT, val INT)"
        ))
        .expect("CREATE dst");
        src.execute(&format!(
            "INSERT INTO {src_table} VALUES (1, 'new-1', 10), (2, 'new-2', 20)"
        ))
        .expect("seed src");
        // Pre-existing destination row whose primary key collides with source id 1.
        dst.execute(&format!("INSERT INTO {dst_table} VALUES (1, 'old-1', 99)"))
            .expect("seed dst");

        // Pass 1: Skip -- the conflicting row must be left untouched.
        let opts = CopyOptions {
            source: CopySource::Query {
                sql: format!("SELECT * FROM {src_table} ORDER BY id"),
                into: dst_table.clone(),
            },
            if_exists: IfExists::Skip,
            ..Default::default()
        };
        copy_rows(
            &mut src,
            Backend::Postgres,
            &mut dst,
            Backend::Postgres,
            &opts,
        )
        .expect("copy_rows skip");

        let out = dst
            .query(&format!(
                "SELECT id, name, val FROM {dst_table} ORDER BY id"
            ))
            .expect("verify skip");
        assert_eq!(out.rows.len(), 2);
        assert!(matches!(&out.rows[0][1], Value::String(s) if s == "old-1"));
        assert!(matches!(&out.rows[1][1], Value::String(s) if s == "new-2"));

        // Pass 2: Upsert -- the conflicting row must now take the source values.
        let opts = CopyOptions {
            source: CopySource::Query {
                sql: format!("SELECT * FROM {src_table} ORDER BY id"),
                into: dst_table.clone(),
            },
            if_exists: IfExists::Upsert,
            ..Default::default()
        };
        copy_rows(
            &mut src,
            Backend::Postgres,
            &mut dst,
            Backend::Postgres,
            &opts,
        )
        .expect("copy_rows upsert");

        let out = dst
            .query(&format!(
                "SELECT id, name, val FROM {dst_table} ORDER BY id"
            ))
            .expect("verify upsert");
        assert_eq!(out.rows.len(), 2);
        assert!(matches!(&out.rows[0][1], Value::String(s) if s == "new-1"));
        assert!(matches!(&out.rows[0][2], Value::Int64(10)));
        assert!(matches!(&out.rows[1][1], Value::String(s) if s == "new-2"));

        // Best-effort cleanup.
        let _ = src.execute(&format!("DROP TABLE {src_table}"));
        let _ = dst.execute(&format!("DROP TABLE {dst_table}"));
    }
1879
    /// `copy_all_tables` from Postgres into a fresh SQLite file, with
    /// `create_table` enabled, for a parent table and a child table that
    /// references it.
    ///
    /// The SQLite side enforces foreign keys, so the copy presumably has to
    /// load the parent before the child (TODO confirm that is the intent).
    #[cfg(feature = "sqlite")]
    #[test]
    fn test_postgres_to_sqlite_all_tables_round_trip() {
        use crate::backend::Backend;
        use crate::copy::{AllTablesOptions, copy_all_tables};

        let Some(mut src) = try_connect() else {
            eprintln!(
                "Postgres test container not available, skipping test_postgres_to_sqlite_all_tables_round_trip"
            );
            return;
        };

        let pid = std::process::id();
        let parent = format!("ferrule_all_parent_{pid}");
        let child = format!("ferrule_all_child_{pid}");
        // Drop the child first: it holds the FK into the parent.
        let _ = src.execute(&format!("DROP TABLE IF EXISTS {child}"));
        let _ = src.execute(&format!("DROP TABLE IF EXISTS {parent}"));
        src.execute(&format!(
            "CREATE TABLE {parent} (id INT PRIMARY KEY, name TEXT)"
        ))
        .expect("CREATE parent");
        src.execute(&format!(
            "CREATE TABLE {child} (id INT PRIMARY KEY, \
             parent_id INT REFERENCES {parent}(id), \
             note TEXT)"
        ))
        .expect("CREATE child");
        src.execute(&format!(
            "INSERT INTO {parent} VALUES (1, 'one'), (2, 'two')"
        ))
        .expect("seed parent");
        src.execute(&format!(
            "INSERT INTO {child} VALUES (10, 1, 'first'), (11, 2, 'second')"
        ))
        .expect("seed child");

        // Fresh per-process SQLite file as the destination.
        let dst_path = std::env::temp_dir().join(format!("ferrule-pg-all-tables-{pid}.db"));
        let _ = std::fs::remove_file(&dst_path);
        let dst_url = DatabaseUrl::parse(&format!("sqlite://{}", dst_path.display())).unwrap();
        let mut dst =
            crate::connect(&dst_url, &ConnectOptions::default(), None).expect("connect sqlite dst");
        dst.execute("PRAGMA foreign_keys = ON").unwrap();

        // The include glob matches exactly this run's parent and child tables.
        let opts = AllTablesOptions {
            include: vec![format!("ferrule_all_*_{pid}")],
            create_table: true,
            ..Default::default()
        };
        let copied = copy_all_tables(
            &mut src,
            Backend::Postgres,
            &mut dst,
            Backend::Sqlite,
            &opts,
        )
        .expect("copy_all_tables PG -> SQLite");
        assert_eq!(copied, 4, "2 parent rows + 2 child rows expected");

        let p = dst
            .query(&format!("SELECT count(*) FROM {parent}"))
            .expect("verify parent");
        let c = dst
            .query(&format!("SELECT count(*) FROM {child}"))
            .expect("verify child");
        assert!(matches!(&p.rows[0][0], Value::Int64(2)));
        assert!(matches!(&c.rows[0][0], Value::Int64(2)));

        // Best-effort cleanup of both the Postgres tables and the SQLite file.
        let _ = src.execute(&format!("DROP TABLE {child}"));
        let _ = src.execute(&format!("DROP TABLE {parent}"));
        let _ = std::fs::remove_file(&dst_path);
    }
1958
    /// Postgres-to-Postgres `copy_rows` with `BulkMode::On` and
    /// `CopyFormat::Binary`. It copies one row with a value in each of twelve
    /// column types plus one row that is NULL everywhere except the boolean.
    ///
    /// On read-back, several columns are cast to `::text`, so their checks
    /// compare strings rather than typed values.
    #[test]
    fn test_postgres_binary_copy_round_trip_all_value_variants() {
        use crate::backend::Backend;
        use crate::copy::{BulkMode, CopyFormat, CopyOptions, CopySource, copy_rows};

        let (Some(mut src), Some(mut dst)) = (try_connect(), try_connect()) else {
            eprintln!(
                "Postgres test container not available, skipping test_postgres_binary_copy_round_trip_all_value_variants"
            );
            return;
        };

        let pid = std::process::id();
        let src_table = format!("ferrule_pg_bin_src_{pid}");
        let dst_table = format!("ferrule_pg_bin_dst_{pid}");
        let _ = src.execute(&format!("DROP TABLE IF EXISTS {src_table}"));
        let _ = dst.execute(&format!("DROP TABLE IF EXISTS {dst_table}"));
        // Same DDL for both sides; the destination name is substituted below.
        let create = format!(
            "CREATE TABLE {src_table} (\
             b BOOLEAN, \
             i BIGINT, \
             f DOUBLE PRECISION, \
             n NUMERIC, \
             t TEXT, \
             by BYTEA, \
             d DATE, \
             tm TIME, \
             dt TIMESTAMP, \
             dttz TIMESTAMPTZ, \
             j JSONB, \
             u UUID\
             )"
        );
        src.execute(&create).expect("CREATE src");
        dst.execute(&create.replace(&src_table, &dst_table))
            .expect("CREATE dst");
        // Row 1: one value per column. Row 2: `false` then NULL everywhere.
        // `{{`/`}}` are format! escapes for literal JSON braces.
        src.execute(&format!(
            "INSERT INTO {src_table} VALUES (\
             true, 42, 2.5, 99.5, 'hello', '\\xdeadbeef', \
             DATE '2024-05-14', TIME '12:34:56', \
             TIMESTAMP '2024-05-14 12:34:56', \
             TIMESTAMPTZ '2024-05-14 12:34:56+00', \
             '{{\"k\":\"v\"}}'::jsonb, \
             '00112233-4455-6677-8899-aabbccddeeff'::uuid\
             ), (\
             false, NULL, NULL, NULL, NULL, NULL, \
             NULL, NULL, NULL, NULL, NULL, NULL\
             )"
        ))
        .expect("seed src");

        // `ORDER BY i NULLS LAST` keeps the fully populated row first.
        let opts = CopyOptions {
            source: CopySource::Query {
                sql: format!("SELECT * FROM {src_table} ORDER BY i NULLS LAST"),
                into: dst_table.clone(),
            },
            bulk_mode: BulkMode::On,
            copy_format: CopyFormat::Binary,
            ..Default::default()
        };
        let copied = copy_rows(
            &mut src,
            Backend::Postgres,
            &mut dst,
            Backend::Postgres,
            &opts,
        )
        .expect("copy_rows binary COPY");
        assert_eq!(copied, 2);

        let out = dst
            .query(&format!(
                "SELECT b, i, f, n::text, t, by, d::text, tm::text, dt::text, \
                 dttz::text, j::text, u::text \
                 FROM {dst_table} ORDER BY i NULLS LAST"
            ))
            .expect("read back");
        assert_eq!(out.rows.len(), 2);
        // Spot-check the populated row.
        let r0 = &out.rows[0];
        assert!(matches!(&r0[0], Value::Bool(true)));
        assert!(matches!(&r0[1], Value::Int64(42)));
        match &r0[2] {
            Value::Float64(f) => assert!((f - 2.5).abs() < 1e-9),
            other => panic!("expected Float64(2.5), got {other:?}"),
        }
        match &r0[3] {
            Value::String(s) => assert_eq!(s, "99.5"),
            other => panic!("expected NUMERIC text 99.5, got {other:?}"),
        }
        assert!(matches!(&r0[4], Value::String(s) if s == "hello"));
        assert!(matches!(&r0[5], Value::Bytes(b) if b == &vec![0xde, 0xad, 0xbe, 0xef]));
        assert!(matches!(&r0[11], Value::String(s) if s == "00112233-4455-6677-8899-aabbccddeeff"));

        // The second row must keep `false` and come back NULL everywhere else.
        let r1 = &out.rows[1];
        assert!(matches!(&r1[0], Value::Bool(false)));
        for col in &r1[1..] {
            assert!(matches!(col, Value::Null), "expected NULL, got {col:?}");
        }

        // Best-effort cleanup.
        let _ = src.execute(&format!("DROP TABLE {src_table}"));
        let _ = dst.execute(&format!("DROP TABLE {dst_table}"));
    }
2078
2079 #[test]
2085 fn test_postgres_cursor_streams_in_bounded_batches() {
2086 let Some(mut conn) = try_connect() else {
2087 eprintln!(
2088 "Postgres test container not available, skipping test_postgres_cursor_streams_in_bounded_batches"
2089 );
2090 return;
2091 };
2092 const TOTAL: i64 = 50_000;
2093 const BATCH: usize = 256;
2094 let sql = format!("SELECT i, i * 2 AS doubled FROM generate_series(1, {TOTAL}) AS g(i)");
2095 let mut cursor = conn.query_cursor(&sql).expect("open pg cursor");
2096 assert_eq!(cursor.columns().len(), 2);
2097 let mut total = 0u64;
2098 let mut batches = 0u64;
2099 loop {
2100 let batch = cursor.next_batch(BATCH).expect("pull pg batch");
2101 if batch.is_empty() {
2102 break;
2103 }
2104 assert!(batch.len() <= BATCH);
2105 total += batch.len() as u64;
2106 batches += 1;
2107 }
2108 assert_eq!(total, TOTAL as u64);
2109 assert_eq!(batches, (TOTAL as u64).div_ceil(BATCH as u64));
2110 }
2111
2112 #[test]
2115 fn test_postgres_write_rows_round_trip() {
2116 let Some(mut conn) = try_connect() else {
2117 eprintln!(
2118 "Postgres test container not available, skipping test_postgres_write_rows_round_trip"
2119 );
2120 return;
2121 };
2122 let _ = conn.execute("DROP TABLE IF EXISTS ferrule_write_test");
2123 conn.execute("CREATE TABLE ferrule_write_test (id INT PRIMARY KEY, name TEXT)")
2124 .expect("create write table");
2125 let columns = vec![
2126 crate::value::ColumnInfo {
2127 name: "id".into(),
2128 type_hint: TypeHint::Int64,
2129 nullable: false,
2130 },
2131 crate::value::ColumnInfo {
2132 name: "name".into(),
2133 type_hint: TypeHint::String,
2134 nullable: true,
2135 },
2136 ];
2137 let rows: Vec<crate::value::Row> = (1..=3000)
2138 .map(|i| vec![Value::Int64(i), Value::String(format!("n{i}"))])
2139 .collect();
2140 let opts = crate::write::WriteOptions {
2141 batch_size: 500,
2142 ..Default::default()
2143 };
2144 let report = crate::write::write_rows(
2145 &mut *conn,
2146 crate::Backend::Postgres,
2147 "ferrule_write_test",
2148 &columns,
2149 rows,
2150 &opts,
2151 )
2152 .expect("write_rows");
2153 assert_eq!(report.rows_written, 3000);
2154 assert!(report.is_complete());
2155 let back = conn
2156 .query("SELECT COUNT(*) FROM ferrule_write_test")
2157 .expect("count");
2158 assert!(matches!(back.rows[0][0], Value::Int64(3000)));
2159 let _ = conn.execute("DROP TABLE ferrule_write_test");
2160 }
2161
2162 #[test]
2165 fn test_postgres_write_rows_partial_failure() {
2166 let Some(mut conn) = try_connect() else {
2167 eprintln!(
2168 "Postgres test container not available, skipping test_postgres_write_rows_partial_failure"
2169 );
2170 return;
2171 };
2172 let _ = conn.execute("DROP TABLE IF EXISTS ferrule_write_pf");
2173 conn.execute("CREATE TABLE ferrule_write_pf (id INT PRIMARY KEY)")
2174 .expect("create");
2175 conn.execute("INSERT INTO ferrule_write_pf VALUES (5)")
2176 .expect("seed");
2177 let columns = vec![crate::value::ColumnInfo {
2178 name: "id".into(),
2179 type_hint: TypeHint::Int64,
2180 nullable: false,
2181 }];
2182 let rows: Vec<crate::value::Row> = (1..=8).map(|i| vec![Value::Int64(i)]).collect();
2184 let opts = crate::write::WriteOptions {
2185 batch_size: 4,
2186 ..Default::default()
2187 };
2188 let report = crate::write::write_rows(
2189 &mut *conn,
2190 crate::Backend::Postgres,
2191 "ferrule_write_pf",
2192 &columns,
2193 rows,
2194 &opts,
2195 )
2196 .expect("write_rows");
2197 assert_eq!(report.rows_written, 4);
2198 assert_eq!(report.rejected_batches.len(), 1);
2199 assert_eq!(report.rejected_batches[0].batch_index, 1);
2200 let _ = conn.execute("DROP TABLE ferrule_write_pf");
2201 }
2202}