1mod cancel;
20mod connection;
21mod copy;
22mod cursor;
23mod io;
24pub mod io_backend;
25mod pipeline;
26mod pool;
27mod prepared;
28mod query;
29pub mod rls;
30pub mod explain;
31pub mod branch_sql;
32mod row;
33mod stream;
34mod transaction;
35pub mod notification;
36
37pub use connection::PgConnection;
38pub use connection::TlsConfig;
39pub(crate) use connection::{CANCEL_REQUEST_CODE, parse_affected_rows};
40pub use cancel::CancelToken;
41pub use io_backend::{IoBackend, backend_name, detect as detect_io_backend};
42pub use pool::{PgPool, PoolConfig, PoolStats, PooledConnection};
43pub use prepared::PreparedStatement;
44pub use rls::RlsContext;
45pub use row::QailRow;
46pub use notification::Notification;
47
48use qail_core::ast::Qail;
49use std::collections::HashMap;
50use std::sync::Arc;
51
/// Column metadata shared by every row of one result set.
///
/// All three collections describe columns by position: `name_to_index`
/// maps a name to a position, and `oids` / `formats` are indexed by it.
#[derive(Debug, Clone)]
pub struct ColumnInfo {
    /// Column name -> zero-based column position.
    pub name_to_index: HashMap<String, usize>,
    /// Type OID of each column, by position.
    pub oids: Vec<u32>,
    /// Format code of each column, by position, as reported in `RowDescription`.
    pub formats: Vec<i16>,
}
65
66impl ColumnInfo {
67 pub fn from_fields(fields: &[crate::protocol::FieldDescription]) -> Self {
70 let mut name_to_index = HashMap::with_capacity(fields.len());
71 let mut oids = Vec::with_capacity(fields.len());
72 let mut formats = Vec::with_capacity(fields.len());
73
74 for (i, field) in fields.iter().enumerate() {
75 name_to_index.insert(field.name.clone(), i);
76 oids.push(field.type_oid);
77 formats.push(field.format);
78 }
79
80 Self {
81 name_to_index,
82 oids,
83 formats,
84 }
85 }
86}
87
88pub struct PgRow {
90 pub columns: Vec<Option<Vec<u8>>>,
92 pub column_info: Option<Arc<ColumnInfo>>,
94}
95
/// Every error the driver can report.
#[derive(Debug)]
pub enum PgError {
    /// Connecting or configuring a connection failed; also used for invalid
    /// connection URLs and missing builder fields.
    Connection(String),
    /// Invalid protocol-level data, e.g. SQL containing a NUL byte.
    Protocol(String),
    /// Authentication failed.
    Auth(String),
    /// The server rejected a statement (`ErrorResponse` message), or an API
    /// precondition such as those of `copy_bulk` was not met.
    Query(String),
    /// A single row was requested but the result was empty.
    NoRows,
    /// An underlying I/O error.
    Io(std::io::Error),
    /// A command could not be encoded into SQL / protocol messages.
    Encode(String),
    /// An operation exceeded its time limit; the text names the operation.
    Timeout(String),
    /// No pooled connection was available.
    PoolExhausted {
        /// The pool's connection limit.
        max: usize,
    },
    /// The connection pool has been closed.
    PoolClosed,
}
123
124impl std::fmt::Display for PgError {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 match self {
127 PgError::Connection(e) => write!(f, "Connection error: {}", e),
128 PgError::Protocol(e) => write!(f, "Protocol error: {}", e),
129 PgError::Auth(e) => write!(f, "Auth error: {}", e),
130 PgError::Query(e) => write!(f, "Query error: {}", e),
131 PgError::NoRows => write!(f, "No rows returned"),
132 PgError::Io(e) => write!(f, "I/O error: {}", e),
133 PgError::Encode(e) => write!(f, "Encode error: {}", e),
134 PgError::Timeout(ctx) => write!(f, "Timeout: {}", ctx),
135 PgError::PoolExhausted { max } => write!(f, "Pool exhausted ({} max connections)", max),
136 PgError::PoolClosed => write!(f, "Connection pool is closed"),
137 }
138 }
139}
140
141impl std::error::Error for PgError {
142 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
143 match self {
144 PgError::Io(e) => Some(e),
145 _ => None,
146 }
147 }
148}
149
150impl From<std::io::Error> for PgError {
151 fn from(e: std::io::Error) -> Self {
152 PgError::Io(e)
153 }
154}
155
156pub type PgResult<T> = Result<T, PgError>;
158
/// A fully text-decoded result set (see `PgDriver::query_ast`).
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Column names in result order.
    pub columns: Vec<String>,
    /// Each row's values as lossily-decoded UTF-8 text; `None` where the value was NULL.
    pub rows: Vec<Vec<Option<String>>>,
}
167
168pub struct PgDriver {
170 #[allow(dead_code)]
171 connection: PgConnection,
172 rls_context: Option<RlsContext>,
174}
175
176impl PgDriver {
177 pub fn new(connection: PgConnection) -> Self {
179 Self { connection, rls_context: None }
180 }
181
182 pub fn builder() -> PgDriverBuilder {
195 PgDriverBuilder::new()
196 }
197
198 pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
207 let connection = PgConnection::connect(host, port, user, database).await?;
208 Ok(Self::new(connection))
209 }
210
211 pub async fn connect_with_password(
213 host: &str,
214 port: u16,
215 user: &str,
216 database: &str,
217 password: &str,
218 ) -> PgResult<Self> {
219 let connection =
220 PgConnection::connect_with_password(host, port, user, database, Some(password)).await?;
221 Ok(Self::new(connection))
222 }
223
224 pub async fn connect_env() -> PgResult<Self> {
235 let url = std::env::var("DATABASE_URL")
236 .map_err(|_| PgError::Connection("DATABASE_URL environment variable not set".to_string()))?;
237 Self::connect_url(&url).await
238 }
239
240 pub async fn connect_url(url: &str) -> PgResult<Self> {
250 let (host, port, user, database, password) = Self::parse_database_url(url)?;
251
252 if let Some(pwd) = password {
253 Self::connect_with_password(&host, port, &user, &database, &pwd).await
254 } else {
255 Self::connect(&host, port, &user, &database).await
256 }
257 }
258
259 fn parse_database_url(url: &str) -> PgResult<(String, u16, String, String, Option<String>)> {
266 let after_scheme = url.split("://").nth(1)
268 .ok_or_else(|| PgError::Connection("Invalid DATABASE_URL: missing scheme".to_string()))?;
269
270 let (auth_part, host_db_part) = if let Some(at_pos) = after_scheme.rfind('@') {
272 (Some(&after_scheme[..at_pos]), &after_scheme[at_pos + 1..])
273 } else {
274 (None, after_scheme)
275 };
276
277 let (user, password) = if let Some(auth) = auth_part {
279 let parts: Vec<&str> = auth.splitn(2, ':').collect();
280 if parts.len() == 2 {
281 (
283 Self::percent_decode(parts[0]),
284 Some(Self::percent_decode(parts[1])),
285 )
286 } else {
287 (Self::percent_decode(parts[0]), None)
288 }
289 } else {
290 return Err(PgError::Connection("Invalid DATABASE_URL: missing user".to_string()));
291 };
292
293 let (host_port, database) = if let Some(slash_pos) = host_db_part.find('/') {
295 (&host_db_part[..slash_pos], host_db_part[slash_pos + 1..].to_string())
296 } else {
297 return Err(PgError::Connection("Invalid DATABASE_URL: missing database name".to_string()));
298 };
299
300 let (host, port) = if let Some(colon_pos) = host_port.rfind(':') {
302 let port_str = &host_port[colon_pos + 1..];
303 let port = port_str.parse::<u16>()
304 .map_err(|_| PgError::Connection(format!("Invalid port: {}", port_str)))?;
305 (host_port[..colon_pos].to_string(), port)
306 } else {
307 (host_port.to_string(), 5432) };
309
310 Ok((host, port, user, database, password))
311 }
312
/// Decodes `%XX` escapes in a URL component.
///
/// Escapes are decoded to raw bytes and the result is read as UTF-8, so
/// multi-byte characters such as `%C3%A9` ("é") survive intact; any
/// resulting invalid UTF-8 is replaced with U+FFFD. A `%` not followed by
/// two hex digits (including sign characters like `+`) is kept literally and
/// scanning resumes at the next character. `+` is NOT translated to a space.
fn percent_decode(s: &str) -> String {
    let bytes = s.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;

    while i < bytes.len() {
        // Only an exact "%HH" with two ASCII hex digits counts as an escape.
        let escaped = match bytes.get(i..i + 3) {
            Some(&[b'%', hi, lo]) => {
                match ((hi as char).to_digit(16), (lo as char).to_digit(16)) {
                    (Some(hi), Some(lo)) => Some((hi * 16 + lo) as u8),
                    _ => None,
                }
            }
            _ => None,
        };

        match escaped {
            Some(byte) => {
                decoded.push(byte);
                i += 3;
            }
            None => {
                decoded.push(bytes[i]);
                i += 1;
            }
        }
    }

    String::from_utf8_lossy(&decoded).into_owned()
}
343
344 pub async fn connect_with_timeout(
355 host: &str,
356 port: u16,
357 user: &str,
358 database: &str,
359 password: &str,
360 timeout: std::time::Duration,
361 ) -> PgResult<Self> {
362 tokio::time::timeout(
363 timeout,
364 Self::connect_with_password(host, port, user, database, password),
365 )
366 .await
367 .map_err(|_| PgError::Timeout(format!("connection after {:?}", timeout)))?
368 }
369 pub fn clear_cache(&mut self) {
373 self.connection.stmt_cache.clear();
374 self.connection.prepared_statements.clear();
375 }
376
377 pub fn cache_stats(&self) -> (usize, usize) {
380 (self.connection.stmt_cache.len(), self.connection.stmt_cache.cap().get())
381 }
382
383 pub async fn fetch_all(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
389 self.fetch_all_cached(cmd).await
391 }
392
393 pub async fn fetch_typed<T: row::QailRow>(&mut self, cmd: &Qail) -> PgResult<Vec<T>> {
401 let rows = self.fetch_all(cmd).await?;
402 Ok(rows.iter().map(T::from_row).collect())
403 }
404
405 pub async fn fetch_one_typed<T: row::QailRow>(&mut self, cmd: &Qail) -> PgResult<Option<T>> {
408 let rows = self.fetch_all(cmd).await?;
409 Ok(rows.first().map(T::from_row))
410 }
411
412 pub async fn fetch_all_uncached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
418 use crate::protocol::AstEncoder;
419
420 AstEncoder::encode_cmd_reuse_into(
421 cmd,
422 &mut self.connection.sql_buf,
423 &mut self.connection.params_buf,
424 &mut self.connection.write_buf,
425 )
426 .map_err(|e| PgError::Encode(e.to_string()))?;
427
428 self.connection.flush_write_buf().await?;
429
430 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
431 let mut column_info: Option<Arc<ColumnInfo>> = None;
432
433 let mut error: Option<PgError> = None;
434
435 loop {
436 let msg = self.connection.recv().await?;
437 match msg {
438 crate::protocol::BackendMessage::ParseComplete
439 | crate::protocol::BackendMessage::BindComplete => {}
440 crate::protocol::BackendMessage::RowDescription(fields) => {
441 column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
442 }
443 crate::protocol::BackendMessage::DataRow(data) => {
444 if error.is_none() {
445 rows.push(PgRow {
446 columns: data,
447 column_info: column_info.clone(),
448 });
449 }
450 }
451 crate::protocol::BackendMessage::CommandComplete(_) => {}
452 crate::protocol::BackendMessage::ReadyForQuery(_) => {
453 if let Some(err) = error {
454 return Err(err);
455 }
456 return Ok(rows);
457 }
458 crate::protocol::BackendMessage::ErrorResponse(err) => {
459 if error.is_none() {
460 error = Some(PgError::Query(err.message));
461 }
462 }
463 _ => {}
464 }
465 }
466 }
467
468 pub async fn fetch_all_fast(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
472 use crate::protocol::AstEncoder;
473
474 AstEncoder::encode_cmd_reuse_into(
475 cmd,
476 &mut self.connection.sql_buf,
477 &mut self.connection.params_buf,
478 &mut self.connection.write_buf,
479 )
480 .map_err(|e| PgError::Encode(e.to_string()))?;
481
482 self.connection.flush_write_buf().await?;
483
484 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
486 let mut error: Option<PgError> = None;
487
488 loop {
489 let res = self.connection.recv_with_data_fast().await;
490 match res {
491 Ok((msg_type, data)) => {
492 match msg_type {
493 b'D' => {
494 if error.is_none() && let Some(columns) = data {
496 rows.push(PgRow {
497 columns,
498 column_info: None, });
500 }
501 }
502 b'Z' => {
503 if let Some(err) = error {
505 return Err(err);
506 }
507 return Ok(rows);
508 }
509 _ => {} }
511 }
512 Err(e) => {
513 if error.is_none() {
522 error = Some(e);
523 }
524 }
529 }
530 }
531 }
532
533 pub async fn fetch_one(&mut self, cmd: &Qail) -> PgResult<PgRow> {
535 let rows = self.fetch_all(cmd).await?;
536 rows.into_iter().next().ok_or(PgError::NoRows)
537 }
538
539 pub async fn fetch_all_cached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
548 use crate::protocol::AstEncoder;
549 use std::collections::hash_map::DefaultHasher;
550 use std::hash::{Hash, Hasher};
551
552 self.connection.sql_buf.clear();
553 self.connection.params_buf.clear();
554
555 match cmd.action {
557 qail_core::ast::Action::Get | qail_core::ast::Action::With => {
558 crate::protocol::ast_encoder::dml::encode_select(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
559 }
560 qail_core::ast::Action::Add => {
561 crate::protocol::ast_encoder::dml::encode_insert(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
562 }
563 qail_core::ast::Action::Set => {
564 crate::protocol::ast_encoder::dml::encode_update(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
565 }
566 qail_core::ast::Action::Del => {
567 crate::protocol::ast_encoder::dml::encode_delete(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
568 }
569 _ => {
570 let (sql, params) = AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
572 let raw_rows = self.connection.query_cached(&sql, ¶ms).await?;
573 return Ok(raw_rows.into_iter().map(|data| PgRow { columns: data, column_info: None }).collect());
574 }
575 }
576
577 let mut hasher = DefaultHasher::new();
578 self.connection.sql_buf.hash(&mut hasher);
579 let sql_hash = hasher.finish();
580
581 let is_cache_miss = !self.connection.stmt_cache.contains(&sql_hash);
582
583 self.connection.write_buf.clear();
585
586 let stmt_name = if let Some(name) = self.connection.stmt_cache.get(&sql_hash) {
587 name.clone()
588 } else {
589 let name = format!("qail_{:x}", sql_hash);
590
591 self.connection.evict_prepared_if_full();
593
594 let sql_str = std::str::from_utf8(&self.connection.sql_buf).unwrap_or("");
595
596 use crate::protocol::PgEncoder;
598 let parse_msg = PgEncoder::encode_parse(&name, sql_str, &[]);
599 let describe_msg = PgEncoder::encode_describe(false, &name);
600 self.connection.write_buf.extend_from_slice(&parse_msg);
601 self.connection.write_buf.extend_from_slice(&describe_msg);
602
603 self.connection.stmt_cache.put(sql_hash, name.clone());
604 self.connection.prepared_statements.insert(name.clone(), sql_str.to_string());
605
606 name
607 };
608
609 use crate::protocol::PgEncoder;
611 PgEncoder::encode_bind_to(&mut self.connection.write_buf, &stmt_name, &self.connection.params_buf)
612 .map_err(|e| PgError::Encode(e.to_string()))?;
613 PgEncoder::encode_execute_to(&mut self.connection.write_buf);
614 PgEncoder::encode_sync_to(&mut self.connection.write_buf);
615
616 self.connection.flush_write_buf().await?;
618
619 let cached_column_info = self.connection.column_info_cache.get(&sql_hash).cloned();
621
622 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
623 let mut column_info: Option<Arc<ColumnInfo>> = cached_column_info;
624 let mut error: Option<PgError> = None;
625
626 loop {
627 let msg = self.connection.recv().await?;
628 match msg {
629 crate::protocol::BackendMessage::ParseComplete
630 | crate::protocol::BackendMessage::BindComplete => {}
631 crate::protocol::BackendMessage::ParameterDescription(_) => {
632 }
634 crate::protocol::BackendMessage::RowDescription(fields) => {
635 let info = Arc::new(ColumnInfo::from_fields(&fields));
637 if is_cache_miss {
638 self.connection.column_info_cache.insert(sql_hash, info.clone());
639 }
640 column_info = Some(info);
641 }
642 crate::protocol::BackendMessage::DataRow(data) => {
643 if error.is_none() {
644 rows.push(PgRow {
645 columns: data,
646 column_info: column_info.clone(),
647 });
648 }
649 }
650 crate::protocol::BackendMessage::CommandComplete(_) => {}
651 crate::protocol::BackendMessage::NoData => {
652 }
654 crate::protocol::BackendMessage::ReadyForQuery(_) => {
655 if let Some(err) = error {
656 return Err(err);
657 }
658 return Ok(rows);
659 }
660 crate::protocol::BackendMessage::ErrorResponse(err) => {
661 if error.is_none() {
662 error = Some(PgError::Query(err.message));
663 self.connection.stmt_cache.clear();
666 self.connection.prepared_statements.clear();
667 self.connection.column_info_cache.clear();
668 }
669 }
670 _ => {}
671 }
672 }
673 }
674
675 pub async fn execute(&mut self, cmd: &Qail) -> PgResult<u64> {
677 use crate::protocol::AstEncoder;
678
679 let wire_bytes = AstEncoder::encode_cmd_reuse(
680 cmd,
681 &mut self.connection.sql_buf,
682 &mut self.connection.params_buf,
683 )
684 .map_err(|e| PgError::Encode(e.to_string()))?;
685
686 self.connection.send_bytes(&wire_bytes).await?;
687
688 let mut affected = 0u64;
689 let mut error: Option<PgError> = None;
690
691 loop {
692 let msg = self.connection.recv().await?;
693 match msg {
694 crate::protocol::BackendMessage::ParseComplete
695 | crate::protocol::BackendMessage::BindComplete => {}
696 crate::protocol::BackendMessage::RowDescription(_) => {}
697 crate::protocol::BackendMessage::DataRow(_) => {}
698 crate::protocol::BackendMessage::CommandComplete(tag) => {
699 if error.is_none() && let Some(n) = tag.split_whitespace().last() {
700 affected = n.parse().unwrap_or(0);
701 }
702 }
703 crate::protocol::BackendMessage::ReadyForQuery(_) => {
704 if let Some(err) = error {
705 return Err(err);
706 }
707 return Ok(affected);
708 }
709 crate::protocol::BackendMessage::ErrorResponse(err) => {
710 if error.is_none() {
711 error = Some(PgError::Query(err.message));
712 }
713 }
714 _ => {}
715 }
716 }
717 }
718
719 pub async fn query_ast(&mut self, cmd: &Qail) -> PgResult<QueryResult> {
723 use crate::protocol::AstEncoder;
724
725 let wire_bytes = AstEncoder::encode_cmd_reuse(
726 cmd,
727 &mut self.connection.sql_buf,
728 &mut self.connection.params_buf,
729 )
730 .map_err(|e| PgError::Encode(e.to_string()))?;
731
732 self.connection.send_bytes(&wire_bytes).await?;
733
734 let mut columns: Vec<String> = Vec::new();
735 let mut rows: Vec<Vec<Option<String>>> = Vec::new();
736 let mut error: Option<PgError> = None;
737
738 loop {
739 let msg = self.connection.recv().await?;
740 match msg {
741 crate::protocol::BackendMessage::ParseComplete
742 | crate::protocol::BackendMessage::BindComplete => {}
743 crate::protocol::BackendMessage::RowDescription(fields) => {
744 columns = fields.into_iter().map(|f| f.name).collect();
745 }
746 crate::protocol::BackendMessage::DataRow(data) => {
747 if error.is_none() {
748 let row: Vec<Option<String>> = data
749 .into_iter()
750 .map(|col| col.map(|bytes| String::from_utf8_lossy(&bytes).to_string()))
751 .collect();
752 rows.push(row);
753 }
754 }
755 crate::protocol::BackendMessage::CommandComplete(_) => {}
756 crate::protocol::BackendMessage::NoData => {}
757 crate::protocol::BackendMessage::ReadyForQuery(_) => {
758 if let Some(err) = error {
759 return Err(err);
760 }
761 return Ok(QueryResult { columns, rows });
762 }
763 crate::protocol::BackendMessage::ErrorResponse(err) => {
764 if error.is_none() {
765 error = Some(PgError::Query(err.message));
766 }
767 }
768 _ => {}
769 }
770 }
771 }
772
773 pub async fn begin(&mut self) -> PgResult<()> {
777 self.connection.begin_transaction().await
778 }
779
780 pub async fn commit(&mut self) -> PgResult<()> {
782 self.connection.commit().await
783 }
784
785 pub async fn rollback(&mut self) -> PgResult<()> {
787 self.connection.rollback().await
788 }
789
790 pub async fn savepoint(&mut self, name: &str) -> PgResult<()> {
803 self.connection.savepoint(name).await
804 }
805
806 pub async fn rollback_to(&mut self, name: &str) -> PgResult<()> {
810 self.connection.rollback_to(name).await
811 }
812
813 pub async fn release_savepoint(&mut self, name: &str) -> PgResult<()> {
816 self.connection.release_savepoint(name).await
817 }
818
819 pub async fn execute_batch(&mut self, cmds: &[Qail]) -> PgResult<Vec<u64>> {
833 self.begin().await?;
834 let mut results = Vec::with_capacity(cmds.len());
835 for cmd in cmds {
836 match self.execute(cmd).await {
837 Ok(n) => results.push(n),
838 Err(e) => {
839 self.rollback().await?;
840 return Err(e);
841 }
842 }
843 }
844 self.commit().await?;
845 Ok(results)
846 }
847
848 pub async fn set_statement_timeout(&mut self, ms: u32) -> PgResult<()> {
856 self.execute_raw(&format!("SET statement_timeout = {}", ms))
857 .await
858 }
859
860 pub async fn reset_statement_timeout(&mut self) -> PgResult<()> {
862 self.execute_raw("RESET statement_timeout").await
863 }
864
865 pub async fn set_rls_context(&mut self, ctx: rls::RlsContext) -> PgResult<()> {
883 let sql = rls::context_to_sql(&ctx);
884 self.execute_raw(&sql).await?;
885 self.rls_context = Some(ctx);
886 Ok(())
887 }
888
889 pub async fn clear_rls_context(&mut self) -> PgResult<()> {
894 self.execute_raw(rls::reset_sql()).await?;
895 self.rls_context = None;
896 Ok(())
897 }
898
899 pub fn rls_context(&self) -> Option<&rls::RlsContext> {
901 self.rls_context.as_ref()
902 }
903
904 pub async fn pipeline_batch(&mut self, cmds: &[Qail]) -> PgResult<usize> {
916 self.connection.pipeline_ast_fast(cmds).await
917 }
918
919 pub async fn pipeline_fetch(&mut self, cmds: &[Qail]) -> PgResult<Vec<Vec<PgRow>>> {
921 let raw_results = self.connection.pipeline_ast(cmds).await?;
922
923 let results: Vec<Vec<PgRow>> = raw_results
924 .into_iter()
925 .map(|rows| {
926 rows.into_iter()
927 .map(|columns| PgRow {
928 columns,
929 column_info: None,
930 })
931 .collect()
932 })
933 .collect();
934
935 Ok(results)
936 }
937
938 pub async fn prepare(&mut self, sql: &str) -> PgResult<PreparedStatement> {
940 self.connection.prepare(sql).await
941 }
942
943 pub async fn pipeline_prepared_fast(
945 &mut self,
946 stmt: &PreparedStatement,
947 params_batch: &[Vec<Option<Vec<u8>>>],
948 ) -> PgResult<usize> {
949 self.connection
950 .pipeline_prepared_fast(stmt, params_batch)
951 .await
952 }
953
954 pub async fn execute_raw(&mut self, sql: &str) -> PgResult<()> {
961 if sql.as_bytes().contains(&0) {
963 return Err(crate::PgError::Protocol(
964 "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
965 ));
966 }
967 self.connection.execute_simple(sql).await
968 }
969
970 pub async fn fetch_raw(&mut self, sql: &str) -> PgResult<Vec<PgRow>> {
974 if sql.as_bytes().contains(&0) {
975 return Err(crate::PgError::Protocol(
976 "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
977 ));
978 }
979
980 use tokio::io::AsyncWriteExt;
981 use crate::protocol::PgEncoder;
982
983 let msg = PgEncoder::encode_query_string(sql);
985 self.connection.stream.write_all(&msg).await?;
986
987 let mut rows: Vec<PgRow> = Vec::new();
988 let mut column_info: Option<std::sync::Arc<ColumnInfo>> = None;
989
990
991 let mut error: Option<PgError> = None;
992
993 loop {
994 let msg = self.connection.recv().await?;
995 match msg {
996 crate::protocol::BackendMessage::RowDescription(fields) => {
997 column_info = Some(std::sync::Arc::new(ColumnInfo::from_fields(&fields)));
998 }
999 crate::protocol::BackendMessage::DataRow(data) => {
1000 if error.is_none() {
1001 rows.push(PgRow {
1002 columns: data,
1003 column_info: column_info.clone(),
1004 });
1005 }
1006 }
1007 crate::protocol::BackendMessage::CommandComplete(_) => {}
1008 crate::protocol::BackendMessage::ReadyForQuery(_) => {
1009 if let Some(err) = error {
1010 return Err(err);
1011 }
1012 return Ok(rows);
1013 }
1014 crate::protocol::BackendMessage::ErrorResponse(err) => {
1015 if error.is_none() {
1016 error = Some(PgError::Query(err.message));
1017 }
1018 }
1019 _ => {}
1020 }
1021 }
1022 }
1023
1024 pub async fn copy_bulk(
1040 &mut self,
1041 cmd: &Qail,
1042 rows: &[Vec<qail_core::ast::Value>],
1043 ) -> PgResult<u64> {
1044 use qail_core::ast::Action;
1045
1046
1047 if cmd.action != Action::Add {
1048 return Err(PgError::Query(
1049 "copy_bulk requires Qail::Add action".to_string(),
1050 ));
1051 }
1052
1053 let table = &cmd.table;
1054
1055 let columns: Vec<String> = cmd
1056 .columns
1057 .iter()
1058 .filter_map(|expr| {
1059 use qail_core::ast::Expr;
1060 match expr {
1061 Expr::Named(name) => Some(name.clone()),
1062 Expr::Aliased { name, .. } => Some(name.clone()),
1063 Expr::Star => None, _ => None,
1065 }
1066 })
1067 .collect();
1068
1069 if columns.is_empty() {
1070 return Err(PgError::Query(
1071 "copy_bulk requires columns in Qail".to_string(),
1072 ));
1073 }
1074
1075 self.connection.copy_in_fast(table, &columns, rows).await
1077 }
1078
1079 pub async fn copy_bulk_bytes(&mut self, cmd: &Qail, data: &[u8]) -> PgResult<u64> {
1092 use qail_core::ast::Action;
1093
1094 if cmd.action != Action::Add {
1095 return Err(PgError::Query(
1096 "copy_bulk_bytes requires Qail::Add action".to_string(),
1097 ));
1098 }
1099
1100 let table = &cmd.table;
1101 let columns: Vec<String> = cmd
1102 .columns
1103 .iter()
1104 .filter_map(|expr| {
1105 use qail_core::ast::Expr;
1106 match expr {
1107 Expr::Named(name) => Some(name.clone()),
1108 Expr::Aliased { name, .. } => Some(name.clone()),
1109 _ => None,
1110 }
1111 })
1112 .collect();
1113
1114 if columns.is_empty() {
1115 return Err(PgError::Query(
1116 "copy_bulk_bytes requires columns in Qail".to_string(),
1117 ));
1118 }
1119
1120 self.connection.copy_in_raw(table, &columns, data).await
1122 }
1123
1124 pub async fn copy_export_table(
1132 &mut self,
1133 table: &str,
1134 columns: &[String],
1135 ) -> PgResult<Vec<u8>> {
1136 let cols = columns.join(", ");
1137 let sql = format!("COPY {} ({}) TO STDOUT", table, cols);
1138
1139 self.connection.copy_out_raw(&sql).await
1140 }
1141
1142 pub async fn stream_cmd(
1156 &mut self,
1157 cmd: &Qail,
1158 batch_size: usize,
1159 ) -> PgResult<Vec<Vec<PgRow>>> {
1160 use std::sync::atomic::{AtomicU64, Ordering};
1161 static CURSOR_ID: AtomicU64 = AtomicU64::new(0);
1162
1163 let cursor_name = format!("qail_cursor_{}", CURSOR_ID.fetch_add(1, Ordering::SeqCst));
1164
1165 use crate::protocol::AstEncoder;
1167 let mut sql_buf = bytes::BytesMut::with_capacity(256);
1168 let mut params: Vec<Option<Vec<u8>>> = Vec::new();
1169 AstEncoder::encode_select_sql(cmd, &mut sql_buf, &mut params);
1170 let sql = String::from_utf8_lossy(&sql_buf).to_string();
1171
1172 self.connection.begin_transaction().await?;
1174
1175 self.connection.declare_cursor(&cursor_name, &sql).await?;
1177
1178 let mut all_batches = Vec::new();
1180 while let Some(rows) = self
1181 .connection
1182 .fetch_cursor(&cursor_name, batch_size)
1183 .await?
1184 {
1185 let pg_rows: Vec<PgRow> = rows
1186 .into_iter()
1187 .map(|cols| PgRow {
1188 columns: cols,
1189 column_info: None,
1190 })
1191 .collect();
1192 all_batches.push(pg_rows);
1193 }
1194
1195 self.connection.close_cursor(&cursor_name).await?;
1196 self.connection.commit().await?;
1197
1198 Ok(all_batches)
1199 }
1200}
1201
/// Step-by-step configuration for [`PgDriver`]; every option starts unset.
///
/// `user` and `database` are required by `connect`. `host` falls back to
/// `127.0.0.1` and `port` to 5432. `password` and `timeout` are optional.
#[derive(Default)]
pub struct PgDriverBuilder {
    /// Server host name or address.
    host: Option<String>,
    /// Server port.
    port: Option<u16>,
    /// Role to log in as (required).
    user: Option<String>,
    /// Database to connect to (required).
    database: Option<String>,
    /// Password; when unset, connection is attempted without one.
    password: Option<String>,
    /// Upper bound on the whole connection attempt.
    timeout: Option<std::time::Duration>,
}
1227
1228impl PgDriverBuilder {
1229 pub fn new() -> Self {
1231 Self::default()
1232 }
1233
1234 pub fn host(mut self, host: impl Into<String>) -> Self {
1236 self.host = Some(host.into());
1237 self
1238 }
1239
1240 pub fn port(mut self, port: u16) -> Self {
1242 self.port = Some(port);
1243 self
1244 }
1245
1246 pub fn user(mut self, user: impl Into<String>) -> Self {
1248 self.user = Some(user.into());
1249 self
1250 }
1251
1252 pub fn database(mut self, database: impl Into<String>) -> Self {
1254 self.database = Some(database.into());
1255 self
1256 }
1257
1258 pub fn password(mut self, password: impl Into<String>) -> Self {
1260 self.password = Some(password.into());
1261 self
1262 }
1263
1264 pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
1266 self.timeout = Some(timeout);
1267 self
1268 }
1269
1270 pub async fn connect(self) -> PgResult<PgDriver> {
1272 let host = self.host.as_deref().unwrap_or("127.0.0.1");
1273 let port = self.port.unwrap_or(5432);
1274 let user = self.user.as_deref().ok_or_else(|| {
1275 PgError::Connection("User is required".to_string())
1276 })?;
1277 let database = self.database.as_deref().ok_or_else(|| {
1278 PgError::Connection("Database is required".to_string())
1279 })?;
1280
1281 match (self.password.as_deref(), self.timeout) {
1282 (Some(password), Some(timeout)) => {
1283 PgDriver::connect_with_timeout(host, port, user, database, password, timeout).await
1284 }
1285 (Some(password), None) => {
1286 PgDriver::connect_with_password(host, port, user, database, password).await
1287 }
1288 (None, Some(timeout)) => {
1289 tokio::time::timeout(
1290 timeout,
1291 PgDriver::connect(host, port, user, database),
1292 )
1293 .await
1294 .map_err(|_| PgError::Timeout(format!("connection after {:?}", timeout)))?
1295 }
1296 (None, None) => {
1297 PgDriver::connect(host, port, user, database).await
1298 }
1299 }
1300 }
1301}