1mod cancel;
19mod connection;
20mod copy;
21mod cursor;
22mod io;
23pub mod io_backend;
24mod pipeline;
25mod pool;
26mod prepared;
27mod query;
28mod row;
29mod stream;
30mod transaction;
31
32pub use connection::PgConnection;
33pub use connection::TlsConfig;
34pub(crate) use connection::{CANCEL_REQUEST_CODE, parse_affected_rows};
35pub use cancel::CancelToken;
36pub use io_backend::{IoBackend, backend_name, detect as detect_io_backend};
37pub use pool::{PgPool, PoolConfig, PoolStats, PooledConnection};
38pub use prepared::PreparedStatement;
39pub use row::QailRow;
40
41use qail_core::ast::Qail;
42use std::collections::HashMap;
43use std::sync::Arc;
44
/// Column metadata for one result set, built from the server's row description
/// (see `ColumnInfo::from_fields`).
#[derive(Debug, Clone)]
pub struct ColumnInfo {
    /// Column name -> zero-based column position. When several columns share a
    /// name, `from_fields` keeps the position of the last one.
    pub name_to_index: HashMap<String, usize>,
    /// Type OID of each column, indexed by column position.
    pub oids: Vec<u32>,
    /// Format code of each column, indexed by column position
    /// (presumably the PostgreSQL 0 = text / 1 = binary code — confirm).
    pub formats: Vec<i16>,
}
51
52impl ColumnInfo {
53 pub fn from_fields(fields: &[crate::protocol::FieldDescription]) -> Self {
54 let mut name_to_index = HashMap::with_capacity(fields.len());
55 let mut oids = Vec::with_capacity(fields.len());
56 let mut formats = Vec::with_capacity(fields.len());
57
58 for (i, field) in fields.iter().enumerate() {
59 name_to_index.insert(field.name.clone(), i);
60 oids.push(field.type_oid);
61 formats.push(field.format);
62 }
63
64 Self {
65 name_to_index,
66 oids,
67 formats,
68 }
69 }
70}
71
/// One row of a query result.
pub struct PgRow {
    /// Raw column values in result order, as received in the server's data-row
    /// message (presumably `None` marks SQL `NULL` — confirm against the decoder).
    pub columns: Vec<Option<Vec<u8>>>,
    /// Metadata shared by all rows of the same result set. It is `None` on code
    /// paths that do not keep a row description, e.g. the cached, fast,
    /// pipeline and cursor fetch paths in `PgDriver`.
    pub column_info: Option<Arc<ColumnInfo>>,
}
77
/// Errors returned by the PostgreSQL driver.
#[derive(Debug)]
pub enum PgError {
    /// Connecting or configuring a connection failed, e.g. a malformed
    /// `DATABASE_URL`, a missing builder setting, or a connect timeout.
    Connection(String),
    /// A protocol-level problem, e.g. SQL text containing a NUL byte.
    Protocol(String),
    /// Authentication error.
    Auth(String),
    /// The server reported an error for a statement, or a driver call was used
    /// with an unsuitable command (e.g. `copy_bulk` without an `Add` action).
    Query(String),
    /// A single row was required but the result was empty.
    NoRows,
    /// An underlying I/O operation failed.
    Io(std::io::Error),
    /// A request (SQL text or parameters) could not be encoded.
    Encode(String),
}

impl std::fmt::Display for PgError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PgError::Connection(e) => write!(f, "Connection error: {}", e),
            PgError::Protocol(e) => write!(f, "Protocol error: {}", e),
            PgError::Auth(e) => write!(f, "Auth error: {}", e),
            PgError::Query(e) => write!(f, "Query error: {}", e),
            PgError::NoRows => write!(f, "No rows returned"),
            PgError::Io(e) => write!(f, "I/O error: {}", e),
            PgError::Encode(e) => write!(f, "Encode error: {}", e),
        }
    }
}

impl std::error::Error for PgError {
    /// Exposes the wrapped [`std::io::Error`] so error-chain walkers
    /// (`source()` loops, reporting crates) can reach the root cause.
    /// The other variants only carry a message and have no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            PgError::Io(e) => Some(e),
            _ => None,
        }
    }
}

impl From<std::io::Error> for PgError {
    /// Wraps an I/O error so `?` can be used on I/O calls in driver code.
    fn from(e: std::io::Error) -> Self {
        PgError::Io(e)
    }
}

/// Result alias used throughout the driver.
pub type PgResult<T> = Result<T, PgError>;
116
117pub struct PgDriver {
119 #[allow(dead_code)]
120 connection: PgConnection,
121}
122
123impl PgDriver {
124 pub fn new(connection: PgConnection) -> Self {
126 Self { connection }
127 }
128
129 pub fn builder() -> PgDriverBuilder {
142 PgDriverBuilder::new()
143 }
144
145 pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
147 let connection = PgConnection::connect(host, port, user, database).await?;
148 Ok(Self::new(connection))
149 }
150
151 pub async fn connect_with_password(
153 host: &str,
154 port: u16,
155 user: &str,
156 database: &str,
157 password: &str,
158 ) -> PgResult<Self> {
159 let connection =
160 PgConnection::connect_with_password(host, port, user, database, Some(password)).await?;
161 Ok(Self::new(connection))
162 }
163
164 pub async fn connect_env() -> PgResult<Self> {
175 let url = std::env::var("DATABASE_URL")
176 .map_err(|_| PgError::Connection("DATABASE_URL environment variable not set".to_string()))?;
177 Self::connect_url(&url).await
178 }
179
180 pub async fn connect_url(url: &str) -> PgResult<Self> {
190 let (host, port, user, database, password) = Self::parse_database_url(url)?;
191
192 if let Some(pwd) = password {
193 Self::connect_with_password(&host, port, &user, &database, &pwd).await
194 } else {
195 Self::connect(&host, port, &user, &database).await
196 }
197 }
198
199 fn parse_database_url(url: &str) -> PgResult<(String, u16, String, String, Option<String>)> {
206 let after_scheme = url.split("://").nth(1)
208 .ok_or_else(|| PgError::Connection("Invalid DATABASE_URL: missing scheme".to_string()))?;
209
210 let (auth_part, host_db_part) = if let Some(at_pos) = after_scheme.rfind('@') {
212 (Some(&after_scheme[..at_pos]), &after_scheme[at_pos + 1..])
213 } else {
214 (None, after_scheme)
215 };
216
217 let (user, password) = if let Some(auth) = auth_part {
219 let parts: Vec<&str> = auth.splitn(2, ':').collect();
220 if parts.len() == 2 {
221 (
223 Self::percent_decode(parts[0]),
224 Some(Self::percent_decode(parts[1])),
225 )
226 } else {
227 (Self::percent_decode(parts[0]), None)
228 }
229 } else {
230 return Err(PgError::Connection("Invalid DATABASE_URL: missing user".to_string()));
231 };
232
233 let (host_port, database) = if let Some(slash_pos) = host_db_part.find('/') {
235 (&host_db_part[..slash_pos], host_db_part[slash_pos + 1..].to_string())
236 } else {
237 return Err(PgError::Connection("Invalid DATABASE_URL: missing database name".to_string()));
238 };
239
240 let (host, port) = if let Some(colon_pos) = host_port.rfind(':') {
242 let port_str = &host_port[colon_pos + 1..];
243 let port = port_str.parse::<u16>()
244 .map_err(|_| PgError::Connection(format!("Invalid port: {}", port_str)))?;
245 (host_port[..colon_pos].to_string(), port)
246 } else {
247 (host_port.to_string(), 5432) };
249
250 Ok((host, port, user, database, password))
251 }
252
253 fn percent_decode(s: &str) -> String {
256 let mut result = String::with_capacity(s.len());
257 let mut chars = s.chars().peekable();
258
259 while let Some(c) = chars.next() {
260 if c == '%' {
261 let hex: String = chars.by_ref().take(2).collect();
263 if hex.len() == 2 {
264 if let Ok(byte) = u8::from_str_radix(&hex, 16) {
265 result.push(byte as char);
266 continue;
267 }
268 }
269 result.push('%');
271 result.push_str(&hex);
272 } else if c == '+' {
273 result.push('+');
276 } else {
277 result.push(c);
278 }
279 }
280
281 result
282 }
283
284 pub async fn connect_with_timeout(
295 host: &str,
296 port: u16,
297 user: &str,
298 database: &str,
299 password: &str,
300 timeout: std::time::Duration,
301 ) -> PgResult<Self> {
302 tokio::time::timeout(
303 timeout,
304 Self::connect_with_password(host, port, user, database, password),
305 )
306 .await
307 .map_err(|_| PgError::Connection(format!("Connection timeout after {:?}", timeout)))?
308 }
309 pub fn clear_cache(&mut self) {
313 self.connection.stmt_cache.clear();
314 self.connection.prepared_statements.clear();
315 }
316
317 pub fn cache_stats(&self) -> (usize, usize) {
320 (self.connection.stmt_cache.len(), self.connection.stmt_cache.cap().get())
321 }
322
323 pub async fn fetch_all(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
329 self.fetch_all_cached(cmd).await
331 }
332
333 pub async fn fetch_typed<T: row::QailRow>(&mut self, cmd: &Qail) -> PgResult<Vec<T>> {
341 let rows = self.fetch_all(cmd).await?;
342 Ok(rows.iter().map(T::from_row).collect())
343 }
344
345 pub async fn fetch_one_typed<T: row::QailRow>(&mut self, cmd: &Qail) -> PgResult<Option<T>> {
348 let rows = self.fetch_all(cmd).await?;
349 Ok(rows.first().map(T::from_row))
350 }
351
352 pub async fn fetch_all_uncached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
356 use crate::protocol::AstEncoder;
357
358 let wire_bytes = AstEncoder::encode_cmd_reuse(
359 cmd,
360 &mut self.connection.sql_buf,
361 &mut self.connection.params_buf,
362 );
363
364 self.connection.send_bytes(&wire_bytes).await?;
365
366 let mut rows: Vec<PgRow> = Vec::new();
367 let mut column_info: Option<Arc<ColumnInfo>> = None;
368
369 let mut error: Option<PgError> = None;
370
371 loop {
372 let msg = self.connection.recv().await?;
373 match msg {
374 crate::protocol::BackendMessage::ParseComplete
375 | crate::protocol::BackendMessage::BindComplete => {}
376 crate::protocol::BackendMessage::RowDescription(fields) => {
377 column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
378 }
379 crate::protocol::BackendMessage::DataRow(data) => {
380 if error.is_none() {
381 rows.push(PgRow {
382 columns: data,
383 column_info: column_info.clone(),
384 });
385 }
386 }
387 crate::protocol::BackendMessage::CommandComplete(_) => {}
388 crate::protocol::BackendMessage::ReadyForQuery(_) => {
389 if let Some(err) = error {
390 return Err(err);
391 }
392 return Ok(rows);
393 }
394 crate::protocol::BackendMessage::ErrorResponse(err) => {
395 if error.is_none() {
396 error = Some(PgError::Query(err.message));
397 }
398 }
399 _ => {}
400 }
401 }
402 }
403
404 pub async fn fetch_all_fast(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
408 use crate::protocol::AstEncoder;
409
410 let wire_bytes = AstEncoder::encode_cmd_reuse(
411 cmd,
412 &mut self.connection.sql_buf,
413 &mut self.connection.params_buf,
414 );
415
416 self.connection.send_bytes(&wire_bytes).await?;
417
418 let mut rows: Vec<PgRow> = Vec::new();
420 let mut error: Option<PgError> = None;
421
422 loop {
423 let res = self.connection.recv_with_data_fast().await;
424 match res {
425 Ok((msg_type, data)) => {
426 match msg_type {
427 b'D' => {
428 if error.is_none() && let Some(columns) = data {
430 rows.push(PgRow {
431 columns,
432 column_info: None, });
434 }
435 }
436 b'Z' => {
437 if let Some(err) = error {
439 return Err(err);
440 }
441 return Ok(rows);
442 }
443 _ => {} }
445 }
446 Err(e) => {
447 if error.is_none() {
456 error = Some(e);
457 }
458 }
463 }
464 }
465 }
466
467 pub async fn fetch_one(&mut self, cmd: &Qail) -> PgResult<PgRow> {
469 let rows = self.fetch_all(cmd).await?;
470 rows.into_iter().next().ok_or(PgError::NoRows)
471 }
472
473 pub async fn fetch_all_cached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
478 use crate::protocol::AstEncoder;
479 use std::collections::hash_map::DefaultHasher;
480 use std::hash::{Hash, Hasher};
481
482 self.connection.sql_buf.clear();
483 self.connection.params_buf.clear();
484
485 match cmd.action {
487 qail_core::ast::Action::Get | qail_core::ast::Action::With => {
488 crate::protocol::ast_encoder::dml::encode_select(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
489 }
490 qail_core::ast::Action::Add => {
491 crate::protocol::ast_encoder::dml::encode_insert(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
492 }
493 qail_core::ast::Action::Set => {
494 crate::protocol::ast_encoder::dml::encode_update(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
495 }
496 qail_core::ast::Action::Del => {
497 crate::protocol::ast_encoder::dml::encode_delete(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
498 }
499 _ => {
500 let (sql, params) = AstEncoder::encode_cmd_sql(cmd);
502 let raw_rows = self.connection.query_cached(&sql, ¶ms).await?;
503 return Ok(raw_rows.into_iter().map(|data| PgRow { columns: data, column_info: None }).collect());
504 }
505 }
506
507 let mut hasher = DefaultHasher::new();
508 self.connection.sql_buf.hash(&mut hasher);
509 let sql_hash = hasher.finish();
510
511 let stmt_name = if let Some(name) = self.connection.stmt_cache.get(&sql_hash) {
512 name.clone()
513 } else {
514 let name = format!("qail_{:x}", sql_hash);
515
516 use crate::protocol::PgEncoder;
517 use tokio::io::AsyncWriteExt;
518
519 let sql_str = std::str::from_utf8(&self.connection.sql_buf).unwrap_or("");
520 let parse_msg = PgEncoder::encode_parse(&name, sql_str, &[]);
521 self.connection.stream.write_all(&parse_msg).await?;
522
523 self.connection.stmt_cache.put(sql_hash, name.clone());
524 self.connection.prepared_statements.insert(name.clone(), sql_str.to_string());
525
526 name
527 };
528
529 use crate::protocol::PgEncoder;
531 use tokio::io::AsyncWriteExt;
532
533 let mut buf = bytes::BytesMut::with_capacity(128);
534 PgEncoder::encode_bind_to(&mut buf, &stmt_name, &self.connection.params_buf)
535 .map_err(|e| PgError::Encode(e.to_string()))?;
536 PgEncoder::encode_execute_to(&mut buf);
537 PgEncoder::encode_sync_to(&mut buf);
538 self.connection.stream.write_all(&buf).await?;
539
540 let mut rows: Vec<PgRow> = Vec::new();
541 let mut error: Option<PgError> = None;
542
543 loop {
544 let msg = self.connection.recv().await?;
545 match msg {
546 crate::protocol::BackendMessage::ParseComplete
547 | crate::protocol::BackendMessage::BindComplete => {}
548 crate::protocol::BackendMessage::RowDescription(_) => {}
549 crate::protocol::BackendMessage::DataRow(data) => {
550 if error.is_none() {
551 rows.push(PgRow {
552 columns: data,
553 column_info: None,
554 });
555 }
556 }
557 crate::protocol::BackendMessage::CommandComplete(_) => {}
558 crate::protocol::BackendMessage::ReadyForQuery(_) => {
559 if let Some(err) = error {
560 return Err(err);
561 }
562 return Ok(rows);
563 }
564 crate::protocol::BackendMessage::ErrorResponse(err) => {
565 if error.is_none() {
566 error = Some(PgError::Query(err.message));
567 self.connection.stmt_cache.clear();
570 self.connection.prepared_statements.clear();
571 }
572 }
573 _ => {}
574 }
575 }
576 }
577
578 pub async fn execute(&mut self, cmd: &Qail) -> PgResult<u64> {
580 use crate::protocol::AstEncoder;
581
582 let wire_bytes = AstEncoder::encode_cmd_reuse(
583 cmd,
584 &mut self.connection.sql_buf,
585 &mut self.connection.params_buf,
586 );
587
588 self.connection.send_bytes(&wire_bytes).await?;
589
590 let mut affected = 0u64;
591 let mut error: Option<PgError> = None;
592
593 loop {
594 let msg = self.connection.recv().await?;
595 match msg {
596 crate::protocol::BackendMessage::ParseComplete
597 | crate::protocol::BackendMessage::BindComplete => {}
598 crate::protocol::BackendMessage::RowDescription(_) => {}
599 crate::protocol::BackendMessage::DataRow(_) => {}
600 crate::protocol::BackendMessage::CommandComplete(tag) => {
601 if error.is_none() && let Some(n) = tag.split_whitespace().last() {
602 affected = n.parse().unwrap_or(0);
603 }
604 }
605 crate::protocol::BackendMessage::ReadyForQuery(_) => {
606 if let Some(err) = error {
607 return Err(err);
608 }
609 return Ok(affected);
610 }
611 crate::protocol::BackendMessage::ErrorResponse(err) => {
612 if error.is_none() {
613 error = Some(PgError::Query(err.message));
614 }
615 }
616 _ => {}
617 }
618 }
619 }
620
621 pub async fn begin(&mut self) -> PgResult<()> {
625 self.connection.begin_transaction().await
626 }
627
628 pub async fn commit(&mut self) -> PgResult<()> {
630 self.connection.commit().await
631 }
632
633 pub async fn rollback(&mut self) -> PgResult<()> {
635 self.connection.rollback().await
636 }
637
638 pub async fn savepoint(&mut self, name: &str) -> PgResult<()> {
651 self.connection.savepoint(name).await
652 }
653
654 pub async fn rollback_to(&mut self, name: &str) -> PgResult<()> {
658 self.connection.rollback_to(name).await
659 }
660
661 pub async fn release_savepoint(&mut self, name: &str) -> PgResult<()> {
664 self.connection.release_savepoint(name).await
665 }
666
667 pub async fn execute_batch(&mut self, cmds: &[Qail]) -> PgResult<Vec<u64>> {
681 self.begin().await?;
682 let mut results = Vec::with_capacity(cmds.len());
683 for cmd in cmds {
684 match self.execute(cmd).await {
685 Ok(n) => results.push(n),
686 Err(e) => {
687 self.rollback().await?;
688 return Err(e);
689 }
690 }
691 }
692 self.commit().await?;
693 Ok(results)
694 }
695
696 pub async fn set_statement_timeout(&mut self, ms: u32) -> PgResult<()> {
704 self.execute_raw(&format!("SET statement_timeout = {}", ms))
705 .await
706 }
707
708 pub async fn reset_statement_timeout(&mut self) -> PgResult<()> {
710 self.execute_raw("RESET statement_timeout").await
711 }
712
713 pub async fn pipeline_batch(&mut self, cmds: &[Qail]) -> PgResult<usize> {
725 self.connection.pipeline_ast_fast(cmds).await
726 }
727
728 pub async fn pipeline_fetch(&mut self, cmds: &[Qail]) -> PgResult<Vec<Vec<PgRow>>> {
730 let raw_results = self.connection.pipeline_ast(cmds).await?;
731
732 let results: Vec<Vec<PgRow>> = raw_results
733 .into_iter()
734 .map(|rows| {
735 rows.into_iter()
736 .map(|columns| PgRow {
737 columns,
738 column_info: None,
739 })
740 .collect()
741 })
742 .collect();
743
744 Ok(results)
745 }
746
747 pub async fn prepare(&mut self, sql: &str) -> PgResult<PreparedStatement> {
749 self.connection.prepare(sql).await
750 }
751
752 pub async fn pipeline_prepared_fast(
754 &mut self,
755 stmt: &PreparedStatement,
756 params_batch: &[Vec<Option<Vec<u8>>>],
757 ) -> PgResult<usize> {
758 self.connection
759 .pipeline_prepared_fast(stmt, params_batch)
760 .await
761 }
762
763 pub async fn execute_raw(&mut self, sql: &str) -> PgResult<()> {
770 if sql.as_bytes().contains(&0) {
772 return Err(crate::PgError::Protocol(
773 "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
774 ));
775 }
776 self.connection.execute_simple(sql).await
777 }
778
779 pub async fn fetch_raw(&mut self, sql: &str) -> PgResult<Vec<PgRow>> {
783 if sql.as_bytes().contains(&0) {
784 return Err(crate::PgError::Protocol(
785 "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
786 ));
787 }
788
789 use tokio::io::AsyncWriteExt;
790 use crate::protocol::PgEncoder;
791
792 let msg = PgEncoder::encode_query_string(sql);
794 self.connection.stream.write_all(&msg).await?;
795
796 let mut rows: Vec<PgRow> = Vec::new();
797 let mut column_info: Option<std::sync::Arc<ColumnInfo>> = None;
798
799
800 let mut error: Option<PgError> = None;
801
802 loop {
803 let msg = self.connection.recv().await?;
804 match msg {
805 crate::protocol::BackendMessage::RowDescription(fields) => {
806 column_info = Some(std::sync::Arc::new(ColumnInfo::from_fields(&fields)));
807 }
808 crate::protocol::BackendMessage::DataRow(data) => {
809 if error.is_none() {
810 rows.push(PgRow {
811 columns: data,
812 column_info: column_info.clone(),
813 });
814 }
815 }
816 crate::protocol::BackendMessage::CommandComplete(_) => {}
817 crate::protocol::BackendMessage::ReadyForQuery(_) => {
818 if let Some(err) = error {
819 return Err(err);
820 }
821 return Ok(rows);
822 }
823 crate::protocol::BackendMessage::ErrorResponse(err) => {
824 if error.is_none() {
825 error = Some(PgError::Query(err.message));
826 }
827 }
828 _ => {}
829 }
830 }
831 }
832
833 pub async fn copy_bulk(
849 &mut self,
850 cmd: &Qail,
851 rows: &[Vec<qail_core::ast::Value>],
852 ) -> PgResult<u64> {
853 use qail_core::ast::Action;
854
855
856 if cmd.action != Action::Add {
857 return Err(PgError::Query(
858 "copy_bulk requires Qail::Add action".to_string(),
859 ));
860 }
861
862 let table = &cmd.table;
863
864 let columns: Vec<String> = cmd
865 .columns
866 .iter()
867 .filter_map(|expr| {
868 use qail_core::ast::Expr;
869 match expr {
870 Expr::Named(name) => Some(name.clone()),
871 Expr::Aliased { name, .. } => Some(name.clone()),
872 Expr::Star => None, _ => None,
874 }
875 })
876 .collect();
877
878 if columns.is_empty() {
879 return Err(PgError::Query(
880 "copy_bulk requires columns in Qail".to_string(),
881 ));
882 }
883
884 self.connection.copy_in_fast(table, &columns, rows).await
886 }
887
888 pub async fn copy_bulk_bytes(&mut self, cmd: &Qail, data: &[u8]) -> PgResult<u64> {
901 use qail_core::ast::Action;
902
903 if cmd.action != Action::Add {
904 return Err(PgError::Query(
905 "copy_bulk_bytes requires Qail::Add action".to_string(),
906 ));
907 }
908
909 let table = &cmd.table;
910 let columns: Vec<String> = cmd
911 .columns
912 .iter()
913 .filter_map(|expr| {
914 use qail_core::ast::Expr;
915 match expr {
916 Expr::Named(name) => Some(name.clone()),
917 Expr::Aliased { name, .. } => Some(name.clone()),
918 _ => None,
919 }
920 })
921 .collect();
922
923 if columns.is_empty() {
924 return Err(PgError::Query(
925 "copy_bulk_bytes requires columns in Qail".to_string(),
926 ));
927 }
928
929 self.connection.copy_in_raw(table, &columns, data).await
931 }
932
933 pub async fn copy_export_table(
941 &mut self,
942 table: &str,
943 columns: &[String],
944 ) -> PgResult<Vec<u8>> {
945 let cols = columns.join(", ");
946 let sql = format!("COPY {} ({}) TO STDOUT", table, cols);
947
948 self.connection.copy_out_raw(&sql).await
949 }
950
951 pub async fn stream_cmd(
965 &mut self,
966 cmd: &Qail,
967 batch_size: usize,
968 ) -> PgResult<Vec<Vec<PgRow>>> {
969 use std::sync::atomic::{AtomicU64, Ordering};
970 static CURSOR_ID: AtomicU64 = AtomicU64::new(0);
971
972 let cursor_name = format!("qail_cursor_{}", CURSOR_ID.fetch_add(1, Ordering::SeqCst));
973
974 use crate::protocol::AstEncoder;
976 let mut sql_buf = bytes::BytesMut::with_capacity(256);
977 let mut params: Vec<Option<Vec<u8>>> = Vec::new();
978 AstEncoder::encode_select_sql(cmd, &mut sql_buf, &mut params);
979 let sql = String::from_utf8_lossy(&sql_buf).to_string();
980
981 self.connection.begin_transaction().await?;
983
984 self.connection.declare_cursor(&cursor_name, &sql).await?;
986
987 let mut all_batches = Vec::new();
989 while let Some(rows) = self
990 .connection
991 .fetch_cursor(&cursor_name, batch_size)
992 .await?
993 {
994 let pg_rows: Vec<PgRow> = rows
995 .into_iter()
996 .map(|cols| PgRow {
997 columns: cols,
998 column_info: None,
999 })
1000 .collect();
1001 all_batches.push(pg_rows);
1002 }
1003
1004 self.connection.close_cursor(&cursor_name).await?;
1005 self.connection.commit().await?;
1006
1007 Ok(all_batches)
1008 }
1009}
1010
/// Step-by-step configuration for a [`PgDriver`] connection; finish with `connect()`.
///
/// Unset options fall back to defaults when connecting: host `127.0.0.1` and
/// port `5432`. `user` and `database` are required. `password` and `timeout`
/// are optional.
#[derive(Default)]
pub struct PgDriverBuilder {
    // Server host name or address.
    host: Option<String>,
    // Server TCP port.
    port: Option<u16>,
    // Login role.
    user: Option<String>,
    // Database to connect to.
    database: Option<String>,
    // Password; when unset, the password-less connect path is used.
    password: Option<String>,
    // Upper bound on the whole connect attempt.
    timeout: Option<std::time::Duration>,
}
1036
1037impl PgDriverBuilder {
1038 pub fn new() -> Self {
1040 Self::default()
1041 }
1042
1043 pub fn host(mut self, host: impl Into<String>) -> Self {
1045 self.host = Some(host.into());
1046 self
1047 }
1048
1049 pub fn port(mut self, port: u16) -> Self {
1051 self.port = Some(port);
1052 self
1053 }
1054
1055 pub fn user(mut self, user: impl Into<String>) -> Self {
1057 self.user = Some(user.into());
1058 self
1059 }
1060
1061 pub fn database(mut self, database: impl Into<String>) -> Self {
1063 self.database = Some(database.into());
1064 self
1065 }
1066
1067 pub fn password(mut self, password: impl Into<String>) -> Self {
1069 self.password = Some(password.into());
1070 self
1071 }
1072
1073 pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
1075 self.timeout = Some(timeout);
1076 self
1077 }
1078
1079 pub async fn connect(self) -> PgResult<PgDriver> {
1081 let host = self.host.as_deref().unwrap_or("127.0.0.1");
1082 let port = self.port.unwrap_or(5432);
1083 let user = self.user.as_deref().ok_or_else(|| {
1084 PgError::Connection("User is required".to_string())
1085 })?;
1086 let database = self.database.as_deref().ok_or_else(|| {
1087 PgError::Connection("Database is required".to_string())
1088 })?;
1089
1090 match (self.password.as_deref(), self.timeout) {
1091 (Some(password), Some(timeout)) => {
1092 PgDriver::connect_with_timeout(host, port, user, database, password, timeout).await
1093 }
1094 (Some(password), None) => {
1095 PgDriver::connect_with_password(host, port, user, database, password).await
1096 }
1097 (None, Some(timeout)) => {
1098 tokio::time::timeout(
1099 timeout,
1100 PgDriver::connect(host, port, user, database),
1101 )
1102 .await
1103 .map_err(|_| PgError::Connection(format!("Connection timeout after {:?}", timeout)))?
1104 }
1105 (None, None) => {
1106 PgDriver::connect(host, port, user, database).await
1107 }
1108 }
1109 }
1110}