1mod cancel;
20mod connection;
21mod copy;
22mod cursor;
23mod io;
24pub mod io_backend;
25mod pipeline;
26mod pool;
27mod prepared;
28mod query;
29pub mod rls;
30pub mod explain;
31pub mod branch_sql;
32mod row;
33mod stream;
34mod transaction;
35pub mod notification;
36
37pub use connection::PgConnection;
38pub use connection::TlsConfig;
39pub(crate) use connection::{CANCEL_REQUEST_CODE, parse_affected_rows};
40pub use cancel::CancelToken;
41pub use io_backend::{IoBackend, backend_name, detect as detect_io_backend};
42pub use pool::{PgPool, PoolConfig, PoolStats, PooledConnection};
43pub use prepared::PreparedStatement;
44pub use rls::RlsContext;
45pub use row::QailRow;
46pub use notification::Notification;
47
48use qail_core::ast::Qail;
49use std::collections::HashMap;
50use std::sync::Arc;
51
52#[derive(Debug, Clone)]
53pub struct ColumnInfo {
54 pub name_to_index: HashMap<String, usize>,
55 pub oids: Vec<u32>,
56 pub formats: Vec<i16>,
57}
58
59impl ColumnInfo {
60 pub fn from_fields(fields: &[crate::protocol::FieldDescription]) -> Self {
61 let mut name_to_index = HashMap::with_capacity(fields.len());
62 let mut oids = Vec::with_capacity(fields.len());
63 let mut formats = Vec::with_capacity(fields.len());
64
65 for (i, field) in fields.iter().enumerate() {
66 name_to_index.insert(field.name.clone(), i);
67 oids.push(field.type_oid);
68 formats.push(field.format);
69 }
70
71 Self {
72 name_to_index,
73 oids,
74 formats,
75 }
76 }
77}
78
79pub struct PgRow {
81 pub columns: Vec<Option<Vec<u8>>>,
82 pub column_info: Option<Arc<ColumnInfo>>,
83}
84
/// Errors produced by the PostgreSQL driver.
#[derive(Debug)]
pub enum PgError {
    /// Failure establishing a connection, including bad `DATABASE_URL` input.
    Connection(String),
    /// Unexpected protocol data, or input the protocol cannot carry (e.g. NUL bytes in SQL).
    Protocol(String),
    /// Authentication failure.
    Auth(String),
    /// Error reported by the server for a query, or a command rejected before
    /// execution (e.g. a wrong action passed to `copy_bulk`).
    Query(String),
    /// A single row was requested but none was returned.
    NoRows,
    /// Underlying I/O failure.
    Io(std::io::Error),
    /// Failure encoding a command into SQL / wire bytes.
    Encode(String),
}

impl std::fmt::Display for PgError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PgError::Connection(e) => write!(f, "Connection error: {}", e),
            PgError::Protocol(e) => write!(f, "Protocol error: {}", e),
            PgError::Auth(e) => write!(f, "Auth error: {}", e),
            PgError::Query(e) => write!(f, "Query error: {}", e),
            PgError::NoRows => write!(f, "No rows returned"),
            PgError::Io(e) => write!(f, "I/O error: {}", e),
            PgError::Encode(e) => write!(f, "Encode error: {}", e),
        }
    }
}

impl std::error::Error for PgError {
    /// Exposes the wrapped `std::io::Error` so error-chain walkers (logging,
    /// `anyhow`, etc.) can reach the root cause; other variants carry only text.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            PgError::Io(e) => Some(e),
            _ => None,
        }
    }
}

impl From<std::io::Error> for PgError {
    fn from(e: std::io::Error) -> Self {
        PgError::Io(e)
    }
}

/// Result alias used throughout the driver.
pub type PgResult<T> = Result<T, PgError>;

/// Text-decoded result of `PgDriver::query_ast`.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Column names in result order.
    pub columns: Vec<String>,
    /// One entry per row; values are decoded lossily as UTF-8, `None` where the
    /// raw column value was `None`.
    pub rows: Vec<Vec<Option<String>>>,
}
132
133pub struct PgDriver {
135 #[allow(dead_code)]
136 connection: PgConnection,
137 rls_context: Option<RlsContext>,
139}
140
141impl PgDriver {
142 pub fn new(connection: PgConnection) -> Self {
144 Self { connection, rls_context: None }
145 }
146
147 pub fn builder() -> PgDriverBuilder {
160 PgDriverBuilder::new()
161 }
162
163 pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
165 let connection = PgConnection::connect(host, port, user, database).await?;
166 Ok(Self::new(connection))
167 }
168
169 pub async fn connect_with_password(
171 host: &str,
172 port: u16,
173 user: &str,
174 database: &str,
175 password: &str,
176 ) -> PgResult<Self> {
177 let connection =
178 PgConnection::connect_with_password(host, port, user, database, Some(password)).await?;
179 Ok(Self::new(connection))
180 }
181
182 pub async fn connect_env() -> PgResult<Self> {
193 let url = std::env::var("DATABASE_URL")
194 .map_err(|_| PgError::Connection("DATABASE_URL environment variable not set".to_string()))?;
195 Self::connect_url(&url).await
196 }
197
198 pub async fn connect_url(url: &str) -> PgResult<Self> {
208 let (host, port, user, database, password) = Self::parse_database_url(url)?;
209
210 if let Some(pwd) = password {
211 Self::connect_with_password(&host, port, &user, &database, &pwd).await
212 } else {
213 Self::connect(&host, port, &user, &database).await
214 }
215 }
216
217 fn parse_database_url(url: &str) -> PgResult<(String, u16, String, String, Option<String>)> {
224 let after_scheme = url.split("://").nth(1)
226 .ok_or_else(|| PgError::Connection("Invalid DATABASE_URL: missing scheme".to_string()))?;
227
228 let (auth_part, host_db_part) = if let Some(at_pos) = after_scheme.rfind('@') {
230 (Some(&after_scheme[..at_pos]), &after_scheme[at_pos + 1..])
231 } else {
232 (None, after_scheme)
233 };
234
235 let (user, password) = if let Some(auth) = auth_part {
237 let parts: Vec<&str> = auth.splitn(2, ':').collect();
238 if parts.len() == 2 {
239 (
241 Self::percent_decode(parts[0]),
242 Some(Self::percent_decode(parts[1])),
243 )
244 } else {
245 (Self::percent_decode(parts[0]), None)
246 }
247 } else {
248 return Err(PgError::Connection("Invalid DATABASE_URL: missing user".to_string()));
249 };
250
251 let (host_port, database) = if let Some(slash_pos) = host_db_part.find('/') {
253 (&host_db_part[..slash_pos], host_db_part[slash_pos + 1..].to_string())
254 } else {
255 return Err(PgError::Connection("Invalid DATABASE_URL: missing database name".to_string()));
256 };
257
258 let (host, port) = if let Some(colon_pos) = host_port.rfind(':') {
260 let port_str = &host_port[colon_pos + 1..];
261 let port = port_str.parse::<u16>()
262 .map_err(|_| PgError::Connection(format!("Invalid port: {}", port_str)))?;
263 (host_port[..colon_pos].to_string(), port)
264 } else {
265 (host_port.to_string(), 5432) };
267
268 Ok((host, port, user, database, password))
269 }
270
271 fn percent_decode(s: &str) -> String {
274 let mut result = String::with_capacity(s.len());
275 let mut chars = s.chars().peekable();
276
277 while let Some(c) = chars.next() {
278 if c == '%' {
279 let hex: String = chars.by_ref().take(2).collect();
281 if hex.len() == 2
282 && let Ok(byte) = u8::from_str_radix(&hex, 16)
283 {
284 result.push(byte as char);
285 continue;
286 }
287 result.push('%');
289 result.push_str(&hex);
290 } else if c == '+' {
291 result.push('+');
294 } else {
295 result.push(c);
296 }
297 }
298
299 result
300 }
301
302 pub async fn connect_with_timeout(
313 host: &str,
314 port: u16,
315 user: &str,
316 database: &str,
317 password: &str,
318 timeout: std::time::Duration,
319 ) -> PgResult<Self> {
320 tokio::time::timeout(
321 timeout,
322 Self::connect_with_password(host, port, user, database, password),
323 )
324 .await
325 .map_err(|_| PgError::Connection(format!("Connection timeout after {:?}", timeout)))?
326 }
327 pub fn clear_cache(&mut self) {
331 self.connection.stmt_cache.clear();
332 self.connection.prepared_statements.clear();
333 }
334
335 pub fn cache_stats(&self) -> (usize, usize) {
338 (self.connection.stmt_cache.len(), self.connection.stmt_cache.cap().get())
339 }
340
341 pub async fn fetch_all(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
347 self.fetch_all_cached(cmd).await
349 }
350
351 pub async fn fetch_typed<T: row::QailRow>(&mut self, cmd: &Qail) -> PgResult<Vec<T>> {
359 let rows = self.fetch_all(cmd).await?;
360 Ok(rows.iter().map(T::from_row).collect())
361 }
362
363 pub async fn fetch_one_typed<T: row::QailRow>(&mut self, cmd: &Qail) -> PgResult<Option<T>> {
366 let rows = self.fetch_all(cmd).await?;
367 Ok(rows.first().map(T::from_row))
368 }
369
370 pub async fn fetch_all_uncached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
376 use crate::protocol::AstEncoder;
377
378 AstEncoder::encode_cmd_reuse_into(
379 cmd,
380 &mut self.connection.sql_buf,
381 &mut self.connection.params_buf,
382 &mut self.connection.write_buf,
383 )
384 .map_err(|e| PgError::Encode(e.to_string()))?;
385
386 self.connection.flush_write_buf().await?;
387
388 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
389 let mut column_info: Option<Arc<ColumnInfo>> = None;
390
391 let mut error: Option<PgError> = None;
392
393 loop {
394 let msg = self.connection.recv().await?;
395 match msg {
396 crate::protocol::BackendMessage::ParseComplete
397 | crate::protocol::BackendMessage::BindComplete => {}
398 crate::protocol::BackendMessage::RowDescription(fields) => {
399 column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
400 }
401 crate::protocol::BackendMessage::DataRow(data) => {
402 if error.is_none() {
403 rows.push(PgRow {
404 columns: data,
405 column_info: column_info.clone(),
406 });
407 }
408 }
409 crate::protocol::BackendMessage::CommandComplete(_) => {}
410 crate::protocol::BackendMessage::ReadyForQuery(_) => {
411 if let Some(err) = error {
412 return Err(err);
413 }
414 return Ok(rows);
415 }
416 crate::protocol::BackendMessage::ErrorResponse(err) => {
417 if error.is_none() {
418 error = Some(PgError::Query(err.message));
419 }
420 }
421 _ => {}
422 }
423 }
424 }
425
426 pub async fn fetch_all_fast(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
430 use crate::protocol::AstEncoder;
431
432 AstEncoder::encode_cmd_reuse_into(
433 cmd,
434 &mut self.connection.sql_buf,
435 &mut self.connection.params_buf,
436 &mut self.connection.write_buf,
437 )
438 .map_err(|e| PgError::Encode(e.to_string()))?;
439
440 self.connection.flush_write_buf().await?;
441
442 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
444 let mut error: Option<PgError> = None;
445
446 loop {
447 let res = self.connection.recv_with_data_fast().await;
448 match res {
449 Ok((msg_type, data)) => {
450 match msg_type {
451 b'D' => {
452 if error.is_none() && let Some(columns) = data {
454 rows.push(PgRow {
455 columns,
456 column_info: None, });
458 }
459 }
460 b'Z' => {
461 if let Some(err) = error {
463 return Err(err);
464 }
465 return Ok(rows);
466 }
467 _ => {} }
469 }
470 Err(e) => {
471 if error.is_none() {
480 error = Some(e);
481 }
482 }
487 }
488 }
489 }
490
491 pub async fn fetch_one(&mut self, cmd: &Qail) -> PgResult<PgRow> {
493 let rows = self.fetch_all(cmd).await?;
494 rows.into_iter().next().ok_or(PgError::NoRows)
495 }
496
497 pub async fn fetch_all_cached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
506 use crate::protocol::AstEncoder;
507 use std::collections::hash_map::DefaultHasher;
508 use std::hash::{Hash, Hasher};
509
510 self.connection.sql_buf.clear();
511 self.connection.params_buf.clear();
512
513 match cmd.action {
515 qail_core::ast::Action::Get | qail_core::ast::Action::With => {
516 crate::protocol::ast_encoder::dml::encode_select(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
517 }
518 qail_core::ast::Action::Add => {
519 crate::protocol::ast_encoder::dml::encode_insert(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
520 }
521 qail_core::ast::Action::Set => {
522 crate::protocol::ast_encoder::dml::encode_update(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
523 }
524 qail_core::ast::Action::Del => {
525 crate::protocol::ast_encoder::dml::encode_delete(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
526 }
527 _ => {
528 let (sql, params) = AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
530 let raw_rows = self.connection.query_cached(&sql, ¶ms).await?;
531 return Ok(raw_rows.into_iter().map(|data| PgRow { columns: data, column_info: None }).collect());
532 }
533 }
534
535 let mut hasher = DefaultHasher::new();
536 self.connection.sql_buf.hash(&mut hasher);
537 let sql_hash = hasher.finish();
538
539 let is_cache_miss = !self.connection.stmt_cache.contains(&sql_hash);
540
541 self.connection.write_buf.clear();
543
544 let stmt_name = if let Some(name) = self.connection.stmt_cache.get(&sql_hash) {
545 name.clone()
546 } else {
547 let name = format!("qail_{:x}", sql_hash);
548
549 self.connection.evict_prepared_if_full();
551
552 let sql_str = std::str::from_utf8(&self.connection.sql_buf).unwrap_or("");
553
554 use crate::protocol::PgEncoder;
556 let parse_msg = PgEncoder::encode_parse(&name, sql_str, &[]);
557 let describe_msg = PgEncoder::encode_describe(false, &name);
558 self.connection.write_buf.extend_from_slice(&parse_msg);
559 self.connection.write_buf.extend_from_slice(&describe_msg);
560
561 self.connection.stmt_cache.put(sql_hash, name.clone());
562 self.connection.prepared_statements.insert(name.clone(), sql_str.to_string());
563
564 name
565 };
566
567 use crate::protocol::PgEncoder;
569 PgEncoder::encode_bind_to(&mut self.connection.write_buf, &stmt_name, &self.connection.params_buf)
570 .map_err(|e| PgError::Encode(e.to_string()))?;
571 PgEncoder::encode_execute_to(&mut self.connection.write_buf);
572 PgEncoder::encode_sync_to(&mut self.connection.write_buf);
573
574 self.connection.flush_write_buf().await?;
576
577 let cached_column_info = self.connection.column_info_cache.get(&sql_hash).cloned();
579
580 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
581 let mut column_info: Option<Arc<ColumnInfo>> = cached_column_info;
582 let mut error: Option<PgError> = None;
583
584 loop {
585 let msg = self.connection.recv().await?;
586 match msg {
587 crate::protocol::BackendMessage::ParseComplete
588 | crate::protocol::BackendMessage::BindComplete => {}
589 crate::protocol::BackendMessage::ParameterDescription(_) => {
590 }
592 crate::protocol::BackendMessage::RowDescription(fields) => {
593 let info = Arc::new(ColumnInfo::from_fields(&fields));
595 if is_cache_miss {
596 self.connection.column_info_cache.insert(sql_hash, info.clone());
597 }
598 column_info = Some(info);
599 }
600 crate::protocol::BackendMessage::DataRow(data) => {
601 if error.is_none() {
602 rows.push(PgRow {
603 columns: data,
604 column_info: column_info.clone(),
605 });
606 }
607 }
608 crate::protocol::BackendMessage::CommandComplete(_) => {}
609 crate::protocol::BackendMessage::NoData => {
610 }
612 crate::protocol::BackendMessage::ReadyForQuery(_) => {
613 if let Some(err) = error {
614 return Err(err);
615 }
616 return Ok(rows);
617 }
618 crate::protocol::BackendMessage::ErrorResponse(err) => {
619 if error.is_none() {
620 error = Some(PgError::Query(err.message));
621 self.connection.stmt_cache.clear();
624 self.connection.prepared_statements.clear();
625 self.connection.column_info_cache.clear();
626 }
627 }
628 _ => {}
629 }
630 }
631 }
632
633 pub async fn execute(&mut self, cmd: &Qail) -> PgResult<u64> {
635 use crate::protocol::AstEncoder;
636
637 let wire_bytes = AstEncoder::encode_cmd_reuse(
638 cmd,
639 &mut self.connection.sql_buf,
640 &mut self.connection.params_buf,
641 )
642 .map_err(|e| PgError::Encode(e.to_string()))?;
643
644 self.connection.send_bytes(&wire_bytes).await?;
645
646 let mut affected = 0u64;
647 let mut error: Option<PgError> = None;
648
649 loop {
650 let msg = self.connection.recv().await?;
651 match msg {
652 crate::protocol::BackendMessage::ParseComplete
653 | crate::protocol::BackendMessage::BindComplete => {}
654 crate::protocol::BackendMessage::RowDescription(_) => {}
655 crate::protocol::BackendMessage::DataRow(_) => {}
656 crate::protocol::BackendMessage::CommandComplete(tag) => {
657 if error.is_none() && let Some(n) = tag.split_whitespace().last() {
658 affected = n.parse().unwrap_or(0);
659 }
660 }
661 crate::protocol::BackendMessage::ReadyForQuery(_) => {
662 if let Some(err) = error {
663 return Err(err);
664 }
665 return Ok(affected);
666 }
667 crate::protocol::BackendMessage::ErrorResponse(err) => {
668 if error.is_none() {
669 error = Some(PgError::Query(err.message));
670 }
671 }
672 _ => {}
673 }
674 }
675 }
676
677 pub async fn query_ast(&mut self, cmd: &Qail) -> PgResult<QueryResult> {
681 use crate::protocol::AstEncoder;
682
683 let wire_bytes = AstEncoder::encode_cmd_reuse(
684 cmd,
685 &mut self.connection.sql_buf,
686 &mut self.connection.params_buf,
687 )
688 .map_err(|e| PgError::Encode(e.to_string()))?;
689
690 self.connection.send_bytes(&wire_bytes).await?;
691
692 let mut columns: Vec<String> = Vec::new();
693 let mut rows: Vec<Vec<Option<String>>> = Vec::new();
694 let mut error: Option<PgError> = None;
695
696 loop {
697 let msg = self.connection.recv().await?;
698 match msg {
699 crate::protocol::BackendMessage::ParseComplete
700 | crate::protocol::BackendMessage::BindComplete => {}
701 crate::protocol::BackendMessage::RowDescription(fields) => {
702 columns = fields.into_iter().map(|f| f.name).collect();
703 }
704 crate::protocol::BackendMessage::DataRow(data) => {
705 if error.is_none() {
706 let row: Vec<Option<String>> = data
707 .into_iter()
708 .map(|col| col.map(|bytes| String::from_utf8_lossy(&bytes).to_string()))
709 .collect();
710 rows.push(row);
711 }
712 }
713 crate::protocol::BackendMessage::CommandComplete(_) => {}
714 crate::protocol::BackendMessage::NoData => {}
715 crate::protocol::BackendMessage::ReadyForQuery(_) => {
716 if let Some(err) = error {
717 return Err(err);
718 }
719 return Ok(QueryResult { columns, rows });
720 }
721 crate::protocol::BackendMessage::ErrorResponse(err) => {
722 if error.is_none() {
723 error = Some(PgError::Query(err.message));
724 }
725 }
726 _ => {}
727 }
728 }
729 }
730
731 pub async fn begin(&mut self) -> PgResult<()> {
735 self.connection.begin_transaction().await
736 }
737
738 pub async fn commit(&mut self) -> PgResult<()> {
740 self.connection.commit().await
741 }
742
743 pub async fn rollback(&mut self) -> PgResult<()> {
745 self.connection.rollback().await
746 }
747
748 pub async fn savepoint(&mut self, name: &str) -> PgResult<()> {
761 self.connection.savepoint(name).await
762 }
763
764 pub async fn rollback_to(&mut self, name: &str) -> PgResult<()> {
768 self.connection.rollback_to(name).await
769 }
770
771 pub async fn release_savepoint(&mut self, name: &str) -> PgResult<()> {
774 self.connection.release_savepoint(name).await
775 }
776
777 pub async fn execute_batch(&mut self, cmds: &[Qail]) -> PgResult<Vec<u64>> {
791 self.begin().await?;
792 let mut results = Vec::with_capacity(cmds.len());
793 for cmd in cmds {
794 match self.execute(cmd).await {
795 Ok(n) => results.push(n),
796 Err(e) => {
797 self.rollback().await?;
798 return Err(e);
799 }
800 }
801 }
802 self.commit().await?;
803 Ok(results)
804 }
805
806 pub async fn set_statement_timeout(&mut self, ms: u32) -> PgResult<()> {
814 self.execute_raw(&format!("SET statement_timeout = {}", ms))
815 .await
816 }
817
818 pub async fn reset_statement_timeout(&mut self) -> PgResult<()> {
820 self.execute_raw("RESET statement_timeout").await
821 }
822
823 pub async fn set_rls_context(&mut self, ctx: rls::RlsContext) -> PgResult<()> {
841 let sql = rls::context_to_sql(&ctx);
842 self.execute_raw(&sql).await?;
843 self.rls_context = Some(ctx);
844 Ok(())
845 }
846
847 pub async fn clear_rls_context(&mut self) -> PgResult<()> {
852 self.execute_raw(rls::reset_sql()).await?;
853 self.rls_context = None;
854 Ok(())
855 }
856
857 pub fn rls_context(&self) -> Option<&rls::RlsContext> {
859 self.rls_context.as_ref()
860 }
861
862 pub async fn pipeline_batch(&mut self, cmds: &[Qail]) -> PgResult<usize> {
874 self.connection.pipeline_ast_fast(cmds).await
875 }
876
877 pub async fn pipeline_fetch(&mut self, cmds: &[Qail]) -> PgResult<Vec<Vec<PgRow>>> {
879 let raw_results = self.connection.pipeline_ast(cmds).await?;
880
881 let results: Vec<Vec<PgRow>> = raw_results
882 .into_iter()
883 .map(|rows| {
884 rows.into_iter()
885 .map(|columns| PgRow {
886 columns,
887 column_info: None,
888 })
889 .collect()
890 })
891 .collect();
892
893 Ok(results)
894 }
895
896 pub async fn prepare(&mut self, sql: &str) -> PgResult<PreparedStatement> {
898 self.connection.prepare(sql).await
899 }
900
901 pub async fn pipeline_prepared_fast(
903 &mut self,
904 stmt: &PreparedStatement,
905 params_batch: &[Vec<Option<Vec<u8>>>],
906 ) -> PgResult<usize> {
907 self.connection
908 .pipeline_prepared_fast(stmt, params_batch)
909 .await
910 }
911
912 pub async fn execute_raw(&mut self, sql: &str) -> PgResult<()> {
919 if sql.as_bytes().contains(&0) {
921 return Err(crate::PgError::Protocol(
922 "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
923 ));
924 }
925 self.connection.execute_simple(sql).await
926 }
927
928 pub async fn fetch_raw(&mut self, sql: &str) -> PgResult<Vec<PgRow>> {
932 if sql.as_bytes().contains(&0) {
933 return Err(crate::PgError::Protocol(
934 "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
935 ));
936 }
937
938 use tokio::io::AsyncWriteExt;
939 use crate::protocol::PgEncoder;
940
941 let msg = PgEncoder::encode_query_string(sql);
943 self.connection.stream.write_all(&msg).await?;
944
945 let mut rows: Vec<PgRow> = Vec::new();
946 let mut column_info: Option<std::sync::Arc<ColumnInfo>> = None;
947
948
949 let mut error: Option<PgError> = None;
950
951 loop {
952 let msg = self.connection.recv().await?;
953 match msg {
954 crate::protocol::BackendMessage::RowDescription(fields) => {
955 column_info = Some(std::sync::Arc::new(ColumnInfo::from_fields(&fields)));
956 }
957 crate::protocol::BackendMessage::DataRow(data) => {
958 if error.is_none() {
959 rows.push(PgRow {
960 columns: data,
961 column_info: column_info.clone(),
962 });
963 }
964 }
965 crate::protocol::BackendMessage::CommandComplete(_) => {}
966 crate::protocol::BackendMessage::ReadyForQuery(_) => {
967 if let Some(err) = error {
968 return Err(err);
969 }
970 return Ok(rows);
971 }
972 crate::protocol::BackendMessage::ErrorResponse(err) => {
973 if error.is_none() {
974 error = Some(PgError::Query(err.message));
975 }
976 }
977 _ => {}
978 }
979 }
980 }
981
982 pub async fn copy_bulk(
998 &mut self,
999 cmd: &Qail,
1000 rows: &[Vec<qail_core::ast::Value>],
1001 ) -> PgResult<u64> {
1002 use qail_core::ast::Action;
1003
1004
1005 if cmd.action != Action::Add {
1006 return Err(PgError::Query(
1007 "copy_bulk requires Qail::Add action".to_string(),
1008 ));
1009 }
1010
1011 let table = &cmd.table;
1012
1013 let columns: Vec<String> = cmd
1014 .columns
1015 .iter()
1016 .filter_map(|expr| {
1017 use qail_core::ast::Expr;
1018 match expr {
1019 Expr::Named(name) => Some(name.clone()),
1020 Expr::Aliased { name, .. } => Some(name.clone()),
1021 Expr::Star => None, _ => None,
1023 }
1024 })
1025 .collect();
1026
1027 if columns.is_empty() {
1028 return Err(PgError::Query(
1029 "copy_bulk requires columns in Qail".to_string(),
1030 ));
1031 }
1032
1033 self.connection.copy_in_fast(table, &columns, rows).await
1035 }
1036
1037 pub async fn copy_bulk_bytes(&mut self, cmd: &Qail, data: &[u8]) -> PgResult<u64> {
1050 use qail_core::ast::Action;
1051
1052 if cmd.action != Action::Add {
1053 return Err(PgError::Query(
1054 "copy_bulk_bytes requires Qail::Add action".to_string(),
1055 ));
1056 }
1057
1058 let table = &cmd.table;
1059 let columns: Vec<String> = cmd
1060 .columns
1061 .iter()
1062 .filter_map(|expr| {
1063 use qail_core::ast::Expr;
1064 match expr {
1065 Expr::Named(name) => Some(name.clone()),
1066 Expr::Aliased { name, .. } => Some(name.clone()),
1067 _ => None,
1068 }
1069 })
1070 .collect();
1071
1072 if columns.is_empty() {
1073 return Err(PgError::Query(
1074 "copy_bulk_bytes requires columns in Qail".to_string(),
1075 ));
1076 }
1077
1078 self.connection.copy_in_raw(table, &columns, data).await
1080 }
1081
1082 pub async fn copy_export_table(
1090 &mut self,
1091 table: &str,
1092 columns: &[String],
1093 ) -> PgResult<Vec<u8>> {
1094 let cols = columns.join(", ");
1095 let sql = format!("COPY {} ({}) TO STDOUT", table, cols);
1096
1097 self.connection.copy_out_raw(&sql).await
1098 }
1099
1100 pub async fn stream_cmd(
1114 &mut self,
1115 cmd: &Qail,
1116 batch_size: usize,
1117 ) -> PgResult<Vec<Vec<PgRow>>> {
1118 use std::sync::atomic::{AtomicU64, Ordering};
1119 static CURSOR_ID: AtomicU64 = AtomicU64::new(0);
1120
1121 let cursor_name = format!("qail_cursor_{}", CURSOR_ID.fetch_add(1, Ordering::SeqCst));
1122
1123 use crate::protocol::AstEncoder;
1125 let mut sql_buf = bytes::BytesMut::with_capacity(256);
1126 let mut params: Vec<Option<Vec<u8>>> = Vec::new();
1127 AstEncoder::encode_select_sql(cmd, &mut sql_buf, &mut params);
1128 let sql = String::from_utf8_lossy(&sql_buf).to_string();
1129
1130 self.connection.begin_transaction().await?;
1132
1133 self.connection.declare_cursor(&cursor_name, &sql).await?;
1135
1136 let mut all_batches = Vec::new();
1138 while let Some(rows) = self
1139 .connection
1140 .fetch_cursor(&cursor_name, batch_size)
1141 .await?
1142 {
1143 let pg_rows: Vec<PgRow> = rows
1144 .into_iter()
1145 .map(|cols| PgRow {
1146 columns: cols,
1147 column_info: None,
1148 })
1149 .collect();
1150 all_batches.push(pg_rows);
1151 }
1152
1153 self.connection.close_cursor(&cursor_name).await?;
1154 self.connection.commit().await?;
1155
1156 Ok(all_batches)
1157 }
1158}
1159
1160#[derive(Default)]
1177pub struct PgDriverBuilder {
1178 host: Option<String>,
1179 port: Option<u16>,
1180 user: Option<String>,
1181 database: Option<String>,
1182 password: Option<String>,
1183 timeout: Option<std::time::Duration>,
1184}
1185
1186impl PgDriverBuilder {
1187 pub fn new() -> Self {
1189 Self::default()
1190 }
1191
1192 pub fn host(mut self, host: impl Into<String>) -> Self {
1194 self.host = Some(host.into());
1195 self
1196 }
1197
1198 pub fn port(mut self, port: u16) -> Self {
1200 self.port = Some(port);
1201 self
1202 }
1203
1204 pub fn user(mut self, user: impl Into<String>) -> Self {
1206 self.user = Some(user.into());
1207 self
1208 }
1209
1210 pub fn database(mut self, database: impl Into<String>) -> Self {
1212 self.database = Some(database.into());
1213 self
1214 }
1215
1216 pub fn password(mut self, password: impl Into<String>) -> Self {
1218 self.password = Some(password.into());
1219 self
1220 }
1221
1222 pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
1224 self.timeout = Some(timeout);
1225 self
1226 }
1227
1228 pub async fn connect(self) -> PgResult<PgDriver> {
1230 let host = self.host.as_deref().unwrap_or("127.0.0.1");
1231 let port = self.port.unwrap_or(5432);
1232 let user = self.user.as_deref().ok_or_else(|| {
1233 PgError::Connection("User is required".to_string())
1234 })?;
1235 let database = self.database.as_deref().ok_or_else(|| {
1236 PgError::Connection("Database is required".to_string())
1237 })?;
1238
1239 match (self.password.as_deref(), self.timeout) {
1240 (Some(password), Some(timeout)) => {
1241 PgDriver::connect_with_timeout(host, port, user, database, password, timeout).await
1242 }
1243 (Some(password), None) => {
1244 PgDriver::connect_with_password(host, port, user, database, password).await
1245 }
1246 (None, Some(timeout)) => {
1247 tokio::time::timeout(
1248 timeout,
1249 PgDriver::connect(host, port, user, database),
1250 )
1251 .await
1252 .map_err(|_| PgError::Connection(format!("Connection timeout after {:?}", timeout)))?
1253 }
1254 (None, None) => {
1255 PgDriver::connect(host, port, user, database).await
1256 }
1257 }
1258 }
1259}