1mod cancel;
19mod connection;
20mod copy;
21mod cursor;
22mod io;
23pub mod io_backend;
24mod pipeline;
25mod pool;
26mod prepared;
27mod query;
28pub mod rls;
29mod row;
30mod stream;
31mod transaction;
32
33pub use connection::PgConnection;
34pub use connection::TlsConfig;
35pub(crate) use connection::{CANCEL_REQUEST_CODE, parse_affected_rows};
36pub use cancel::CancelToken;
37pub use io_backend::{IoBackend, backend_name, detect as detect_io_backend};
38pub use pool::{PgPool, PoolConfig, PoolStats, PooledConnection};
39pub use prepared::PreparedStatement;
40pub use rls::RlsContext;
41pub use row::QailRow;
42
43use qail_core::ast::Qail;
44use std::collections::HashMap;
45use std::sync::Arc;
46
/// Column metadata for one result set, built from its `RowDescription`.
///
/// Rows of the same result set share one instance through `Arc` (see
/// `PgRow::column_info`). All vectors are indexed by zero-based column position.
#[derive(Debug, Clone)]
pub struct ColumnInfo {
    /// Column name -> zero-based column position.
    pub name_to_index: HashMap<String, usize>,
    /// Type OID of each column.
    pub oids: Vec<u32>,
    /// Format code of each column as reported by the server
    /// (presumably PostgreSQL's 0 = text, 1 = binary — confirm against the decoder).
    pub formats: Vec<i16>,
}
53
54impl ColumnInfo {
55 pub fn from_fields(fields: &[crate::protocol::FieldDescription]) -> Self {
56 let mut name_to_index = HashMap::with_capacity(fields.len());
57 let mut oids = Vec::with_capacity(fields.len());
58 let mut formats = Vec::with_capacity(fields.len());
59
60 for (i, field) in fields.iter().enumerate() {
61 name_to_index.insert(field.name.clone(), i);
62 oids.push(field.type_oid);
63 formats.push(field.format);
64 }
65
66 Self {
67 name_to_index,
68 oids,
69 formats,
70 }
71 }
72}
73
/// One result row: the raw column values plus, optionally, shared column metadata.
pub struct PgRow {
    /// Raw value of each column exactly as received in the `DataRow` message;
    /// `None` presumably marks SQL NULL.
    pub columns: Vec<Option<Vec<u8>>>,
    /// Metadata from the result's `RowDescription`, when the fetch path kept it.
    /// Some paths (e.g. `fetch_all_fast`, `pipeline_fetch`, `stream_cmd`) always
    /// leave this `None`.
    pub column_info: Option<Arc<ColumnInfo>>,
}
79
/// Errors produced by the PostgreSQL driver.
#[derive(Debug)]
pub enum PgError {
    /// Connecting failed: bad `DATABASE_URL`, missing builder fields, timeouts, or
    /// errors reported by the connection layer.
    Connection(String),
    /// Client-side protocol violation, e.g. SQL containing a NUL byte.
    Protocol(String),
    /// Authentication error.
    Auth(String),
    /// The server rejected a statement (message of its `ErrorResponse`), or a
    /// driver call was used with an unsuitable command.
    Query(String),
    /// A single-row fetch returned no rows.
    NoRows,
    /// Transport-level I/O failure.
    Io(std::io::Error),
    /// A command or its parameters could not be encoded.
    Encode(String),
}

impl std::fmt::Display for PgError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PgError::Connection(e) => write!(f, "Connection error: {}", e),
            PgError::Protocol(e) => write!(f, "Protocol error: {}", e),
            PgError::Auth(e) => write!(f, "Auth error: {}", e),
            PgError::Query(e) => write!(f, "Query error: {}", e),
            PgError::NoRows => write!(f, "No rows returned"),
            PgError::Io(e) => write!(f, "I/O error: {}", e),
            PgError::Encode(e) => write!(f, "Encode error: {}", e),
        }
    }
}

impl std::error::Error for PgError {
    /// Exposes the wrapped `io::Error` so error-chain reporters can walk to it.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            PgError::Io(e) => Some(e),
            PgError::Connection(_)
            | PgError::Protocol(_)
            | PgError::Auth(_)
            | PgError::Query(_)
            | PgError::NoRows
            | PgError::Encode(_) => None,
        }
    }
}

impl From<std::io::Error> for PgError {
    /// Wraps transport errors so `?` works on `std::io` results.
    fn from(e: std::io::Error) -> Self {
        PgError::Io(e)
    }
}

/// Result alias used throughout the driver.
pub type PgResult<T> = Result<T, PgError>;
118
119pub struct PgDriver {
121 #[allow(dead_code)]
122 connection: PgConnection,
123 rls_context: Option<RlsContext>,
125}
126
127impl PgDriver {
128 pub fn new(connection: PgConnection) -> Self {
130 Self { connection, rls_context: None }
131 }
132
133 pub fn builder() -> PgDriverBuilder {
146 PgDriverBuilder::new()
147 }
148
149 pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
151 let connection = PgConnection::connect(host, port, user, database).await?;
152 Ok(Self::new(connection))
153 }
154
155 pub async fn connect_with_password(
157 host: &str,
158 port: u16,
159 user: &str,
160 database: &str,
161 password: &str,
162 ) -> PgResult<Self> {
163 let connection =
164 PgConnection::connect_with_password(host, port, user, database, Some(password)).await?;
165 Ok(Self::new(connection))
166 }
167
168 pub async fn connect_env() -> PgResult<Self> {
179 let url = std::env::var("DATABASE_URL")
180 .map_err(|_| PgError::Connection("DATABASE_URL environment variable not set".to_string()))?;
181 Self::connect_url(&url).await
182 }
183
184 pub async fn connect_url(url: &str) -> PgResult<Self> {
194 let (host, port, user, database, password) = Self::parse_database_url(url)?;
195
196 if let Some(pwd) = password {
197 Self::connect_with_password(&host, port, &user, &database, &pwd).await
198 } else {
199 Self::connect(&host, port, &user, &database).await
200 }
201 }
202
203 fn parse_database_url(url: &str) -> PgResult<(String, u16, String, String, Option<String>)> {
210 let after_scheme = url.split("://").nth(1)
212 .ok_or_else(|| PgError::Connection("Invalid DATABASE_URL: missing scheme".to_string()))?;
213
214 let (auth_part, host_db_part) = if let Some(at_pos) = after_scheme.rfind('@') {
216 (Some(&after_scheme[..at_pos]), &after_scheme[at_pos + 1..])
217 } else {
218 (None, after_scheme)
219 };
220
221 let (user, password) = if let Some(auth) = auth_part {
223 let parts: Vec<&str> = auth.splitn(2, ':').collect();
224 if parts.len() == 2 {
225 (
227 Self::percent_decode(parts[0]),
228 Some(Self::percent_decode(parts[1])),
229 )
230 } else {
231 (Self::percent_decode(parts[0]), None)
232 }
233 } else {
234 return Err(PgError::Connection("Invalid DATABASE_URL: missing user".to_string()));
235 };
236
237 let (host_port, database) = if let Some(slash_pos) = host_db_part.find('/') {
239 (&host_db_part[..slash_pos], host_db_part[slash_pos + 1..].to_string())
240 } else {
241 return Err(PgError::Connection("Invalid DATABASE_URL: missing database name".to_string()));
242 };
243
244 let (host, port) = if let Some(colon_pos) = host_port.rfind(':') {
246 let port_str = &host_port[colon_pos + 1..];
247 let port = port_str.parse::<u16>()
248 .map_err(|_| PgError::Connection(format!("Invalid port: {}", port_str)))?;
249 (host_port[..colon_pos].to_string(), port)
250 } else {
251 (host_port.to_string(), 5432) };
253
254 Ok((host, port, user, database, password))
255 }
256
257 fn percent_decode(s: &str) -> String {
260 let mut result = String::with_capacity(s.len());
261 let mut chars = s.chars().peekable();
262
263 while let Some(c) = chars.next() {
264 if c == '%' {
265 let hex: String = chars.by_ref().take(2).collect();
267 if hex.len() == 2
268 && let Ok(byte) = u8::from_str_radix(&hex, 16)
269 {
270 result.push(byte as char);
271 continue;
272 }
273 result.push('%');
275 result.push_str(&hex);
276 } else if c == '+' {
277 result.push('+');
280 } else {
281 result.push(c);
282 }
283 }
284
285 result
286 }
287
288 pub async fn connect_with_timeout(
299 host: &str,
300 port: u16,
301 user: &str,
302 database: &str,
303 password: &str,
304 timeout: std::time::Duration,
305 ) -> PgResult<Self> {
306 tokio::time::timeout(
307 timeout,
308 Self::connect_with_password(host, port, user, database, password),
309 )
310 .await
311 .map_err(|_| PgError::Connection(format!("Connection timeout after {:?}", timeout)))?
312 }
313 pub fn clear_cache(&mut self) {
317 self.connection.stmt_cache.clear();
318 self.connection.prepared_statements.clear();
319 }
320
321 pub fn cache_stats(&self) -> (usize, usize) {
324 (self.connection.stmt_cache.len(), self.connection.stmt_cache.cap().get())
325 }
326
327 pub async fn fetch_all(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
333 self.fetch_all_cached(cmd).await
335 }
336
337 pub async fn fetch_typed<T: row::QailRow>(&mut self, cmd: &Qail) -> PgResult<Vec<T>> {
345 let rows = self.fetch_all(cmd).await?;
346 Ok(rows.iter().map(T::from_row).collect())
347 }
348
349 pub async fn fetch_one_typed<T: row::QailRow>(&mut self, cmd: &Qail) -> PgResult<Option<T>> {
352 let rows = self.fetch_all(cmd).await?;
353 Ok(rows.first().map(T::from_row))
354 }
355
356 pub async fn fetch_all_uncached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
360 use crate::protocol::AstEncoder;
361
362 let wire_bytes = AstEncoder::encode_cmd_reuse(
363 cmd,
364 &mut self.connection.sql_buf,
365 &mut self.connection.params_buf,
366 );
367
368 self.connection.send_bytes(&wire_bytes).await?;
369
370 let mut rows: Vec<PgRow> = Vec::new();
371 let mut column_info: Option<Arc<ColumnInfo>> = None;
372
373 let mut error: Option<PgError> = None;
374
375 loop {
376 let msg = self.connection.recv().await?;
377 match msg {
378 crate::protocol::BackendMessage::ParseComplete
379 | crate::protocol::BackendMessage::BindComplete => {}
380 crate::protocol::BackendMessage::RowDescription(fields) => {
381 column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
382 }
383 crate::protocol::BackendMessage::DataRow(data) => {
384 if error.is_none() {
385 rows.push(PgRow {
386 columns: data,
387 column_info: column_info.clone(),
388 });
389 }
390 }
391 crate::protocol::BackendMessage::CommandComplete(_) => {}
392 crate::protocol::BackendMessage::ReadyForQuery(_) => {
393 if let Some(err) = error {
394 return Err(err);
395 }
396 return Ok(rows);
397 }
398 crate::protocol::BackendMessage::ErrorResponse(err) => {
399 if error.is_none() {
400 error = Some(PgError::Query(err.message));
401 }
402 }
403 _ => {}
404 }
405 }
406 }
407
408 pub async fn fetch_all_fast(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
412 use crate::protocol::AstEncoder;
413
414 let wire_bytes = AstEncoder::encode_cmd_reuse(
415 cmd,
416 &mut self.connection.sql_buf,
417 &mut self.connection.params_buf,
418 );
419
420 self.connection.send_bytes(&wire_bytes).await?;
421
422 let mut rows: Vec<PgRow> = Vec::new();
424 let mut error: Option<PgError> = None;
425
426 loop {
427 let res = self.connection.recv_with_data_fast().await;
428 match res {
429 Ok((msg_type, data)) => {
430 match msg_type {
431 b'D' => {
432 if error.is_none() && let Some(columns) = data {
434 rows.push(PgRow {
435 columns,
436 column_info: None, });
438 }
439 }
440 b'Z' => {
441 if let Some(err) = error {
443 return Err(err);
444 }
445 return Ok(rows);
446 }
447 _ => {} }
449 }
450 Err(e) => {
451 if error.is_none() {
460 error = Some(e);
461 }
462 }
467 }
468 }
469 }
470
471 pub async fn fetch_one(&mut self, cmd: &Qail) -> PgResult<PgRow> {
473 let rows = self.fetch_all(cmd).await?;
474 rows.into_iter().next().ok_or(PgError::NoRows)
475 }
476
477 pub async fn fetch_all_cached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
484 use crate::protocol::AstEncoder;
485 use std::collections::hash_map::DefaultHasher;
486 use std::hash::{Hash, Hasher};
487
488 self.connection.sql_buf.clear();
489 self.connection.params_buf.clear();
490
491 match cmd.action {
493 qail_core::ast::Action::Get | qail_core::ast::Action::With => {
494 crate::protocol::ast_encoder::dml::encode_select(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
495 }
496 qail_core::ast::Action::Add => {
497 crate::protocol::ast_encoder::dml::encode_insert(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
498 }
499 qail_core::ast::Action::Set => {
500 crate::protocol::ast_encoder::dml::encode_update(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
501 }
502 qail_core::ast::Action::Del => {
503 crate::protocol::ast_encoder::dml::encode_delete(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
504 }
505 _ => {
506 let (sql, params) = AstEncoder::encode_cmd_sql(cmd);
508 let raw_rows = self.connection.query_cached(&sql, ¶ms).await?;
509 return Ok(raw_rows.into_iter().map(|data| PgRow { columns: data, column_info: None }).collect());
510 }
511 }
512
513 let mut hasher = DefaultHasher::new();
514 self.connection.sql_buf.hash(&mut hasher);
515 let sql_hash = hasher.finish();
516
517 let is_cache_miss = !self.connection.stmt_cache.contains(&sql_hash);
518
519 let stmt_name = if let Some(name) = self.connection.stmt_cache.get(&sql_hash) {
520 name.clone()
521 } else {
522 let name = format!("qail_{:x}", sql_hash);
523
524 use crate::protocol::PgEncoder;
525 use tokio::io::AsyncWriteExt;
526
527 let sql_str = std::str::from_utf8(&self.connection.sql_buf).unwrap_or("");
528
529 let parse_msg = PgEncoder::encode_parse(&name, sql_str, &[]);
531 let describe_msg = PgEncoder::encode_describe(false, &name);
532 self.connection.stream.write_all(&parse_msg).await?;
533 self.connection.stream.write_all(&describe_msg).await?;
534
535 self.connection.stmt_cache.put(sql_hash, name.clone());
536 self.connection.prepared_statements.insert(name.clone(), sql_str.to_string());
537
538 name
539 };
540
541 use crate::protocol::PgEncoder;
543 use tokio::io::AsyncWriteExt;
544
545 let mut buf = bytes::BytesMut::with_capacity(128);
546 PgEncoder::encode_bind_to(&mut buf, &stmt_name, &self.connection.params_buf)
547 .map_err(|e| PgError::Encode(e.to_string()))?;
548 PgEncoder::encode_execute_to(&mut buf);
549 PgEncoder::encode_sync_to(&mut buf);
550 self.connection.stream.write_all(&buf).await?;
551
552 let cached_column_info = self.connection.column_info_cache.get(&sql_hash).cloned();
554
555 let mut rows: Vec<PgRow> = Vec::new();
556 let mut column_info: Option<Arc<ColumnInfo>> = cached_column_info;
557 let mut error: Option<PgError> = None;
558
559 loop {
560 let msg = self.connection.recv().await?;
561 match msg {
562 crate::protocol::BackendMessage::ParseComplete
563 | crate::protocol::BackendMessage::BindComplete => {}
564 crate::protocol::BackendMessage::ParameterDescription(_) => {
565 }
567 crate::protocol::BackendMessage::RowDescription(fields) => {
568 let info = Arc::new(ColumnInfo::from_fields(&fields));
570 if is_cache_miss {
571 self.connection.column_info_cache.insert(sql_hash, info.clone());
572 }
573 column_info = Some(info);
574 }
575 crate::protocol::BackendMessage::DataRow(data) => {
576 if error.is_none() {
577 rows.push(PgRow {
578 columns: data,
579 column_info: column_info.clone(),
580 });
581 }
582 }
583 crate::protocol::BackendMessage::CommandComplete(_) => {}
584 crate::protocol::BackendMessage::NoData => {
585 }
587 crate::protocol::BackendMessage::ReadyForQuery(_) => {
588 if let Some(err) = error {
589 return Err(err);
590 }
591 return Ok(rows);
592 }
593 crate::protocol::BackendMessage::ErrorResponse(err) => {
594 if error.is_none() {
595 error = Some(PgError::Query(err.message));
596 self.connection.stmt_cache.clear();
599 self.connection.prepared_statements.clear();
600 self.connection.column_info_cache.clear();
601 }
602 }
603 _ => {}
604 }
605 }
606 }
607
608 pub async fn execute(&mut self, cmd: &Qail) -> PgResult<u64> {
610 use crate::protocol::AstEncoder;
611
612 let wire_bytes = AstEncoder::encode_cmd_reuse(
613 cmd,
614 &mut self.connection.sql_buf,
615 &mut self.connection.params_buf,
616 );
617
618 self.connection.send_bytes(&wire_bytes).await?;
619
620 let mut affected = 0u64;
621 let mut error: Option<PgError> = None;
622
623 loop {
624 let msg = self.connection.recv().await?;
625 match msg {
626 crate::protocol::BackendMessage::ParseComplete
627 | crate::protocol::BackendMessage::BindComplete => {}
628 crate::protocol::BackendMessage::RowDescription(_) => {}
629 crate::protocol::BackendMessage::DataRow(_) => {}
630 crate::protocol::BackendMessage::CommandComplete(tag) => {
631 if error.is_none() && let Some(n) = tag.split_whitespace().last() {
632 affected = n.parse().unwrap_or(0);
633 }
634 }
635 crate::protocol::BackendMessage::ReadyForQuery(_) => {
636 if let Some(err) = error {
637 return Err(err);
638 }
639 return Ok(affected);
640 }
641 crate::protocol::BackendMessage::ErrorResponse(err) => {
642 if error.is_none() {
643 error = Some(PgError::Query(err.message));
644 }
645 }
646 _ => {}
647 }
648 }
649 }
650
651 pub async fn begin(&mut self) -> PgResult<()> {
655 self.connection.begin_transaction().await
656 }
657
658 pub async fn commit(&mut self) -> PgResult<()> {
660 self.connection.commit().await
661 }
662
663 pub async fn rollback(&mut self) -> PgResult<()> {
665 self.connection.rollback().await
666 }
667
668 pub async fn savepoint(&mut self, name: &str) -> PgResult<()> {
681 self.connection.savepoint(name).await
682 }
683
684 pub async fn rollback_to(&mut self, name: &str) -> PgResult<()> {
688 self.connection.rollback_to(name).await
689 }
690
691 pub async fn release_savepoint(&mut self, name: &str) -> PgResult<()> {
694 self.connection.release_savepoint(name).await
695 }
696
697 pub async fn execute_batch(&mut self, cmds: &[Qail]) -> PgResult<Vec<u64>> {
711 self.begin().await?;
712 let mut results = Vec::with_capacity(cmds.len());
713 for cmd in cmds {
714 match self.execute(cmd).await {
715 Ok(n) => results.push(n),
716 Err(e) => {
717 self.rollback().await?;
718 return Err(e);
719 }
720 }
721 }
722 self.commit().await?;
723 Ok(results)
724 }
725
726 pub async fn set_statement_timeout(&mut self, ms: u32) -> PgResult<()> {
734 self.execute_raw(&format!("SET statement_timeout = {}", ms))
735 .await
736 }
737
738 pub async fn reset_statement_timeout(&mut self) -> PgResult<()> {
740 self.execute_raw("RESET statement_timeout").await
741 }
742
743 pub async fn set_rls_context(&mut self, ctx: rls::RlsContext) -> PgResult<()> {
761 let sql = rls::context_to_sql(&ctx);
762 self.execute_raw(&sql).await?;
763 self.rls_context = Some(ctx);
764 Ok(())
765 }
766
767 pub async fn clear_rls_context(&mut self) -> PgResult<()> {
772 self.execute_raw(rls::reset_sql()).await?;
773 self.rls_context = None;
774 Ok(())
775 }
776
777 pub fn rls_context(&self) -> Option<&rls::RlsContext> {
779 self.rls_context.as_ref()
780 }
781
782 pub async fn pipeline_batch(&mut self, cmds: &[Qail]) -> PgResult<usize> {
794 self.connection.pipeline_ast_fast(cmds).await
795 }
796
797 pub async fn pipeline_fetch(&mut self, cmds: &[Qail]) -> PgResult<Vec<Vec<PgRow>>> {
799 let raw_results = self.connection.pipeline_ast(cmds).await?;
800
801 let results: Vec<Vec<PgRow>> = raw_results
802 .into_iter()
803 .map(|rows| {
804 rows.into_iter()
805 .map(|columns| PgRow {
806 columns,
807 column_info: None,
808 })
809 .collect()
810 })
811 .collect();
812
813 Ok(results)
814 }
815
816 pub async fn prepare(&mut self, sql: &str) -> PgResult<PreparedStatement> {
818 self.connection.prepare(sql).await
819 }
820
821 pub async fn pipeline_prepared_fast(
823 &mut self,
824 stmt: &PreparedStatement,
825 params_batch: &[Vec<Option<Vec<u8>>>],
826 ) -> PgResult<usize> {
827 self.connection
828 .pipeline_prepared_fast(stmt, params_batch)
829 .await
830 }
831
832 pub async fn execute_raw(&mut self, sql: &str) -> PgResult<()> {
839 if sql.as_bytes().contains(&0) {
841 return Err(crate::PgError::Protocol(
842 "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
843 ));
844 }
845 self.connection.execute_simple(sql).await
846 }
847
848 pub async fn fetch_raw(&mut self, sql: &str) -> PgResult<Vec<PgRow>> {
852 if sql.as_bytes().contains(&0) {
853 return Err(crate::PgError::Protocol(
854 "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
855 ));
856 }
857
858 use tokio::io::AsyncWriteExt;
859 use crate::protocol::PgEncoder;
860
861 let msg = PgEncoder::encode_query_string(sql);
863 self.connection.stream.write_all(&msg).await?;
864
865 let mut rows: Vec<PgRow> = Vec::new();
866 let mut column_info: Option<std::sync::Arc<ColumnInfo>> = None;
867
868
869 let mut error: Option<PgError> = None;
870
871 loop {
872 let msg = self.connection.recv().await?;
873 match msg {
874 crate::protocol::BackendMessage::RowDescription(fields) => {
875 column_info = Some(std::sync::Arc::new(ColumnInfo::from_fields(&fields)));
876 }
877 crate::protocol::BackendMessage::DataRow(data) => {
878 if error.is_none() {
879 rows.push(PgRow {
880 columns: data,
881 column_info: column_info.clone(),
882 });
883 }
884 }
885 crate::protocol::BackendMessage::CommandComplete(_) => {}
886 crate::protocol::BackendMessage::ReadyForQuery(_) => {
887 if let Some(err) = error {
888 return Err(err);
889 }
890 return Ok(rows);
891 }
892 crate::protocol::BackendMessage::ErrorResponse(err) => {
893 if error.is_none() {
894 error = Some(PgError::Query(err.message));
895 }
896 }
897 _ => {}
898 }
899 }
900 }
901
902 pub async fn copy_bulk(
918 &mut self,
919 cmd: &Qail,
920 rows: &[Vec<qail_core::ast::Value>],
921 ) -> PgResult<u64> {
922 use qail_core::ast::Action;
923
924
925 if cmd.action != Action::Add {
926 return Err(PgError::Query(
927 "copy_bulk requires Qail::Add action".to_string(),
928 ));
929 }
930
931 let table = &cmd.table;
932
933 let columns: Vec<String> = cmd
934 .columns
935 .iter()
936 .filter_map(|expr| {
937 use qail_core::ast::Expr;
938 match expr {
939 Expr::Named(name) => Some(name.clone()),
940 Expr::Aliased { name, .. } => Some(name.clone()),
941 Expr::Star => None, _ => None,
943 }
944 })
945 .collect();
946
947 if columns.is_empty() {
948 return Err(PgError::Query(
949 "copy_bulk requires columns in Qail".to_string(),
950 ));
951 }
952
953 self.connection.copy_in_fast(table, &columns, rows).await
955 }
956
957 pub async fn copy_bulk_bytes(&mut self, cmd: &Qail, data: &[u8]) -> PgResult<u64> {
970 use qail_core::ast::Action;
971
972 if cmd.action != Action::Add {
973 return Err(PgError::Query(
974 "copy_bulk_bytes requires Qail::Add action".to_string(),
975 ));
976 }
977
978 let table = &cmd.table;
979 let columns: Vec<String> = cmd
980 .columns
981 .iter()
982 .filter_map(|expr| {
983 use qail_core::ast::Expr;
984 match expr {
985 Expr::Named(name) => Some(name.clone()),
986 Expr::Aliased { name, .. } => Some(name.clone()),
987 _ => None,
988 }
989 })
990 .collect();
991
992 if columns.is_empty() {
993 return Err(PgError::Query(
994 "copy_bulk_bytes requires columns in Qail".to_string(),
995 ));
996 }
997
998 self.connection.copy_in_raw(table, &columns, data).await
1000 }
1001
1002 pub async fn copy_export_table(
1010 &mut self,
1011 table: &str,
1012 columns: &[String],
1013 ) -> PgResult<Vec<u8>> {
1014 let cols = columns.join(", ");
1015 let sql = format!("COPY {} ({}) TO STDOUT", table, cols);
1016
1017 self.connection.copy_out_raw(&sql).await
1018 }
1019
1020 pub async fn stream_cmd(
1034 &mut self,
1035 cmd: &Qail,
1036 batch_size: usize,
1037 ) -> PgResult<Vec<Vec<PgRow>>> {
1038 use std::sync::atomic::{AtomicU64, Ordering};
1039 static CURSOR_ID: AtomicU64 = AtomicU64::new(0);
1040
1041 let cursor_name = format!("qail_cursor_{}", CURSOR_ID.fetch_add(1, Ordering::SeqCst));
1042
1043 use crate::protocol::AstEncoder;
1045 let mut sql_buf = bytes::BytesMut::with_capacity(256);
1046 let mut params: Vec<Option<Vec<u8>>> = Vec::new();
1047 AstEncoder::encode_select_sql(cmd, &mut sql_buf, &mut params);
1048 let sql = String::from_utf8_lossy(&sql_buf).to_string();
1049
1050 self.connection.begin_transaction().await?;
1052
1053 self.connection.declare_cursor(&cursor_name, &sql).await?;
1055
1056 let mut all_batches = Vec::new();
1058 while let Some(rows) = self
1059 .connection
1060 .fetch_cursor(&cursor_name, batch_size)
1061 .await?
1062 {
1063 let pg_rows: Vec<PgRow> = rows
1064 .into_iter()
1065 .map(|cols| PgRow {
1066 columns: cols,
1067 column_info: None,
1068 })
1069 .collect();
1070 all_batches.push(pg_rows);
1071 }
1072
1073 self.connection.close_cursor(&cursor_name).await?;
1074 self.connection.commit().await?;
1075
1076 Ok(all_batches)
1077 }
1078}
1079
/// Step-by-step configuration for a [`PgDriver`] connection; finish with `connect()`.
///
/// Every option starts unset (`Default`).
#[derive(Default)]
pub struct PgDriverBuilder {
    /// Server host; when unset, `connect` uses "127.0.0.1".
    host: Option<String>,
    /// Server port; when unset, `connect` uses 5432.
    port: Option<u16>,
    /// Role name; `connect` fails without it.
    user: Option<String>,
    /// Database name; `connect` fails without it.
    database: Option<String>,
    /// Password; when unset, `connect` uses the password-less connect path.
    // NOTE(review): held in plain text — keep `Debug` off this type.
    password: Option<String>,
    /// Limit for the whole connection attempt; when unset, no timeout is applied.
    timeout: Option<std::time::Duration>,
}
1105
1106impl PgDriverBuilder {
1107 pub fn new() -> Self {
1109 Self::default()
1110 }
1111
1112 pub fn host(mut self, host: impl Into<String>) -> Self {
1114 self.host = Some(host.into());
1115 self
1116 }
1117
1118 pub fn port(mut self, port: u16) -> Self {
1120 self.port = Some(port);
1121 self
1122 }
1123
1124 pub fn user(mut self, user: impl Into<String>) -> Self {
1126 self.user = Some(user.into());
1127 self
1128 }
1129
1130 pub fn database(mut self, database: impl Into<String>) -> Self {
1132 self.database = Some(database.into());
1133 self
1134 }
1135
1136 pub fn password(mut self, password: impl Into<String>) -> Self {
1138 self.password = Some(password.into());
1139 self
1140 }
1141
1142 pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
1144 self.timeout = Some(timeout);
1145 self
1146 }
1147
1148 pub async fn connect(self) -> PgResult<PgDriver> {
1150 let host = self.host.as_deref().unwrap_or("127.0.0.1");
1151 let port = self.port.unwrap_or(5432);
1152 let user = self.user.as_deref().ok_or_else(|| {
1153 PgError::Connection("User is required".to_string())
1154 })?;
1155 let database = self.database.as_deref().ok_or_else(|| {
1156 PgError::Connection("Database is required".to_string())
1157 })?;
1158
1159 match (self.password.as_deref(), self.timeout) {
1160 (Some(password), Some(timeout)) => {
1161 PgDriver::connect_with_timeout(host, port, user, database, password, timeout).await
1162 }
1163 (Some(password), None) => {
1164 PgDriver::connect_with_password(host, port, user, database, password).await
1165 }
1166 (None, Some(timeout)) => {
1167 tokio::time::timeout(
1168 timeout,
1169 PgDriver::connect(host, port, user, database),
1170 )
1171 .await
1172 .map_err(|_| PgError::Connection(format!("Connection timeout after {:?}", timeout)))?
1173 }
1174 (None, None) => {
1175 PgDriver::connect(host, port, user, database).await
1176 }
1177 }
1178 }
1179}