1mod cancel;
19mod connection;
20mod copy;
21mod cursor;
22mod io;
23pub mod io_backend;
24mod pipeline;
25mod pool;
26mod prepared;
27mod query;
28pub mod rls;
29pub mod branch_sql;
30mod row;
31mod stream;
32mod transaction;
33
34pub use connection::PgConnection;
35pub use connection::TlsConfig;
36pub(crate) use connection::{CANCEL_REQUEST_CODE, parse_affected_rows};
37pub use cancel::CancelToken;
38pub use io_backend::{IoBackend, backend_name, detect as detect_io_backend};
39pub use pool::{PgPool, PoolConfig, PoolStats, PooledConnection};
40pub use prepared::PreparedStatement;
41pub use rls::RlsContext;
42pub use row::QailRow;
43
44use qail_core::ast::Qail;
45use std::collections::HashMap;
46use std::sync::Arc;
47
48#[derive(Debug, Clone)]
49pub struct ColumnInfo {
50 pub name_to_index: HashMap<String, usize>,
51 pub oids: Vec<u32>,
52 pub formats: Vec<i16>,
53}
54
55impl ColumnInfo {
56 pub fn from_fields(fields: &[crate::protocol::FieldDescription]) -> Self {
57 let mut name_to_index = HashMap::with_capacity(fields.len());
58 let mut oids = Vec::with_capacity(fields.len());
59 let mut formats = Vec::with_capacity(fields.len());
60
61 for (i, field) in fields.iter().enumerate() {
62 name_to_index.insert(field.name.clone(), i);
63 oids.push(field.type_oid);
64 formats.push(field.format);
65 }
66
67 Self {
68 name_to_index,
69 oids,
70 formats,
71 }
72 }
73}
74
75pub struct PgRow {
77 pub columns: Vec<Option<Vec<u8>>>,
78 pub column_info: Option<Arc<ColumnInfo>>,
79}
80
/// Errors returned by the PostgreSQL driver.
#[derive(Debug)]
pub enum PgError {
    /// Connection setup failed: malformed `DATABASE_URL`, missing builder
    /// settings, connect timeout, and similar.
    Connection(String),
    /// Client-side protocol violation, e.g. SQL text containing a NUL byte.
    Protocol(String),
    /// Authentication failure.
    Auth(String),
    /// The server replied with an `ErrorResponse` (its message is carried here),
    /// or a request was rejected before being sent (e.g. invalid `copy_bulk` input).
    Query(String),
    /// A single-row fetch produced no rows.
    NoRows,
    /// Underlying socket / I/O failure; exposed through [`std::error::Error::source`].
    Io(std::io::Error),
    /// A request could not be encoded for the wire.
    Encode(String),
}

impl std::fmt::Display for PgError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PgError::Connection(e) => write!(f, "Connection error: {}", e),
            PgError::Protocol(e) => write!(f, "Protocol error: {}", e),
            PgError::Auth(e) => write!(f, "Auth error: {}", e),
            PgError::Query(e) => write!(f, "Query error: {}", e),
            PgError::NoRows => write!(f, "No rows returned"),
            PgError::Io(e) => write!(f, "I/O error: {}", e),
            PgError::Encode(e) => write!(f, "Encode error: {}", e),
        }
    }
}

impl std::error::Error for PgError {
    /// Returns the wrapped I/O error for [`PgError::Io`], so error-chain
    /// reporters can show the underlying cause; other variants carry only text.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            PgError::Io(e) => Some(e),
            _ => None,
        }
    }
}

impl From<std::io::Error> for PgError {
    /// Wraps an I/O error so `?` can be used on socket operations.
    fn from(e: std::io::Error) -> Self {
        PgError::Io(e)
    }
}

/// Result alias used throughout the driver.
pub type PgResult<T> = Result<T, PgError>;

/// Text-decoded query result produced by `PgDriver::query_ast`.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Column names in the order given by the server's row description.
    pub columns: Vec<String>,
    /// One entry per data row; each cell is the column bytes decoded as lossy
    /// UTF-8, and `None` presumably marks SQL NULL.
    pub rows: Vec<Vec<Option<String>>>,
}
128
129pub struct PgDriver {
131 #[allow(dead_code)]
132 connection: PgConnection,
133 rls_context: Option<RlsContext>,
135}
136
137impl PgDriver {
138 pub fn new(connection: PgConnection) -> Self {
140 Self { connection, rls_context: None }
141 }
142
143 pub fn builder() -> PgDriverBuilder {
156 PgDriverBuilder::new()
157 }
158
159 pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
161 let connection = PgConnection::connect(host, port, user, database).await?;
162 Ok(Self::new(connection))
163 }
164
165 pub async fn connect_with_password(
167 host: &str,
168 port: u16,
169 user: &str,
170 database: &str,
171 password: &str,
172 ) -> PgResult<Self> {
173 let connection =
174 PgConnection::connect_with_password(host, port, user, database, Some(password)).await?;
175 Ok(Self::new(connection))
176 }
177
178 pub async fn connect_env() -> PgResult<Self> {
189 let url = std::env::var("DATABASE_URL")
190 .map_err(|_| PgError::Connection("DATABASE_URL environment variable not set".to_string()))?;
191 Self::connect_url(&url).await
192 }
193
194 pub async fn connect_url(url: &str) -> PgResult<Self> {
204 let (host, port, user, database, password) = Self::parse_database_url(url)?;
205
206 if let Some(pwd) = password {
207 Self::connect_with_password(&host, port, &user, &database, &pwd).await
208 } else {
209 Self::connect(&host, port, &user, &database).await
210 }
211 }
212
213 fn parse_database_url(url: &str) -> PgResult<(String, u16, String, String, Option<String>)> {
220 let after_scheme = url.split("://").nth(1)
222 .ok_or_else(|| PgError::Connection("Invalid DATABASE_URL: missing scheme".to_string()))?;
223
224 let (auth_part, host_db_part) = if let Some(at_pos) = after_scheme.rfind('@') {
226 (Some(&after_scheme[..at_pos]), &after_scheme[at_pos + 1..])
227 } else {
228 (None, after_scheme)
229 };
230
231 let (user, password) = if let Some(auth) = auth_part {
233 let parts: Vec<&str> = auth.splitn(2, ':').collect();
234 if parts.len() == 2 {
235 (
237 Self::percent_decode(parts[0]),
238 Some(Self::percent_decode(parts[1])),
239 )
240 } else {
241 (Self::percent_decode(parts[0]), None)
242 }
243 } else {
244 return Err(PgError::Connection("Invalid DATABASE_URL: missing user".to_string()));
245 };
246
247 let (host_port, database) = if let Some(slash_pos) = host_db_part.find('/') {
249 (&host_db_part[..slash_pos], host_db_part[slash_pos + 1..].to_string())
250 } else {
251 return Err(PgError::Connection("Invalid DATABASE_URL: missing database name".to_string()));
252 };
253
254 let (host, port) = if let Some(colon_pos) = host_port.rfind(':') {
256 let port_str = &host_port[colon_pos + 1..];
257 let port = port_str.parse::<u16>()
258 .map_err(|_| PgError::Connection(format!("Invalid port: {}", port_str)))?;
259 (host_port[..colon_pos].to_string(), port)
260 } else {
261 (host_port.to_string(), 5432) };
263
264 Ok((host, port, user, database, password))
265 }
266
267 fn percent_decode(s: &str) -> String {
270 let mut result = String::with_capacity(s.len());
271 let mut chars = s.chars().peekable();
272
273 while let Some(c) = chars.next() {
274 if c == '%' {
275 let hex: String = chars.by_ref().take(2).collect();
277 if hex.len() == 2
278 && let Ok(byte) = u8::from_str_radix(&hex, 16)
279 {
280 result.push(byte as char);
281 continue;
282 }
283 result.push('%');
285 result.push_str(&hex);
286 } else if c == '+' {
287 result.push('+');
290 } else {
291 result.push(c);
292 }
293 }
294
295 result
296 }
297
298 pub async fn connect_with_timeout(
309 host: &str,
310 port: u16,
311 user: &str,
312 database: &str,
313 password: &str,
314 timeout: std::time::Duration,
315 ) -> PgResult<Self> {
316 tokio::time::timeout(
317 timeout,
318 Self::connect_with_password(host, port, user, database, password),
319 )
320 .await
321 .map_err(|_| PgError::Connection(format!("Connection timeout after {:?}", timeout)))?
322 }
323 pub fn clear_cache(&mut self) {
327 self.connection.stmt_cache.clear();
328 self.connection.prepared_statements.clear();
329 }
330
331 pub fn cache_stats(&self) -> (usize, usize) {
334 (self.connection.stmt_cache.len(), self.connection.stmt_cache.cap().get())
335 }
336
337 pub async fn fetch_all(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
343 self.fetch_all_cached(cmd).await
345 }
346
347 pub async fn fetch_typed<T: row::QailRow>(&mut self, cmd: &Qail) -> PgResult<Vec<T>> {
355 let rows = self.fetch_all(cmd).await?;
356 Ok(rows.iter().map(T::from_row).collect())
357 }
358
359 pub async fn fetch_one_typed<T: row::QailRow>(&mut self, cmd: &Qail) -> PgResult<Option<T>> {
362 let rows = self.fetch_all(cmd).await?;
363 Ok(rows.first().map(T::from_row))
364 }
365
366 pub async fn fetch_all_uncached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
372 use crate::protocol::AstEncoder;
373
374 AstEncoder::encode_cmd_reuse_into(
375 cmd,
376 &mut self.connection.sql_buf,
377 &mut self.connection.params_buf,
378 &mut self.connection.write_buf,
379 );
380
381 self.connection.flush_write_buf().await?;
382
383 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
384 let mut column_info: Option<Arc<ColumnInfo>> = None;
385
386 let mut error: Option<PgError> = None;
387
388 loop {
389 let msg = self.connection.recv().await?;
390 match msg {
391 crate::protocol::BackendMessage::ParseComplete
392 | crate::protocol::BackendMessage::BindComplete => {}
393 crate::protocol::BackendMessage::RowDescription(fields) => {
394 column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
395 }
396 crate::protocol::BackendMessage::DataRow(data) => {
397 if error.is_none() {
398 rows.push(PgRow {
399 columns: data,
400 column_info: column_info.clone(),
401 });
402 }
403 }
404 crate::protocol::BackendMessage::CommandComplete(_) => {}
405 crate::protocol::BackendMessage::ReadyForQuery(_) => {
406 if let Some(err) = error {
407 return Err(err);
408 }
409 return Ok(rows);
410 }
411 crate::protocol::BackendMessage::ErrorResponse(err) => {
412 if error.is_none() {
413 error = Some(PgError::Query(err.message));
414 }
415 }
416 _ => {}
417 }
418 }
419 }
420
421 pub async fn fetch_all_fast(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
425 use crate::protocol::AstEncoder;
426
427 AstEncoder::encode_cmd_reuse_into(
428 cmd,
429 &mut self.connection.sql_buf,
430 &mut self.connection.params_buf,
431 &mut self.connection.write_buf,
432 );
433
434 self.connection.flush_write_buf().await?;
435
436 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
438 let mut error: Option<PgError> = None;
439
440 loop {
441 let res = self.connection.recv_with_data_fast().await;
442 match res {
443 Ok((msg_type, data)) => {
444 match msg_type {
445 b'D' => {
446 if error.is_none() && let Some(columns) = data {
448 rows.push(PgRow {
449 columns,
450 column_info: None, });
452 }
453 }
454 b'Z' => {
455 if let Some(err) = error {
457 return Err(err);
458 }
459 return Ok(rows);
460 }
461 _ => {} }
463 }
464 Err(e) => {
465 if error.is_none() {
474 error = Some(e);
475 }
476 }
481 }
482 }
483 }
484
485 pub async fn fetch_one(&mut self, cmd: &Qail) -> PgResult<PgRow> {
487 let rows = self.fetch_all(cmd).await?;
488 rows.into_iter().next().ok_or(PgError::NoRows)
489 }
490
491 pub async fn fetch_all_cached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
500 use crate::protocol::AstEncoder;
501 use std::collections::hash_map::DefaultHasher;
502 use std::hash::{Hash, Hasher};
503
504 self.connection.sql_buf.clear();
505 self.connection.params_buf.clear();
506
507 match cmd.action {
509 qail_core::ast::Action::Get | qail_core::ast::Action::With => {
510 crate::protocol::ast_encoder::dml::encode_select(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
511 }
512 qail_core::ast::Action::Add => {
513 crate::protocol::ast_encoder::dml::encode_insert(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
514 }
515 qail_core::ast::Action::Set => {
516 crate::protocol::ast_encoder::dml::encode_update(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
517 }
518 qail_core::ast::Action::Del => {
519 crate::protocol::ast_encoder::dml::encode_delete(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
520 }
521 _ => {
522 let (sql, params) = AstEncoder::encode_cmd_sql(cmd);
524 let raw_rows = self.connection.query_cached(&sql, ¶ms).await?;
525 return Ok(raw_rows.into_iter().map(|data| PgRow { columns: data, column_info: None }).collect());
526 }
527 }
528
529 let mut hasher = DefaultHasher::new();
530 self.connection.sql_buf.hash(&mut hasher);
531 let sql_hash = hasher.finish();
532
533 let is_cache_miss = !self.connection.stmt_cache.contains(&sql_hash);
534
535 self.connection.write_buf.clear();
537
538 let stmt_name = if let Some(name) = self.connection.stmt_cache.get(&sql_hash) {
539 name.clone()
540 } else {
541 let name = format!("qail_{:x}", sql_hash);
542
543 let sql_str = std::str::from_utf8(&self.connection.sql_buf).unwrap_or("");
544
545 use crate::protocol::PgEncoder;
547 let parse_msg = PgEncoder::encode_parse(&name, sql_str, &[]);
548 let describe_msg = PgEncoder::encode_describe(false, &name);
549 self.connection.write_buf.extend_from_slice(&parse_msg);
550 self.connection.write_buf.extend_from_slice(&describe_msg);
551
552 self.connection.stmt_cache.put(sql_hash, name.clone());
553 self.connection.prepared_statements.insert(name.clone(), sql_str.to_string());
554
555 name
556 };
557
558 use crate::protocol::PgEncoder;
560 PgEncoder::encode_bind_to(&mut self.connection.write_buf, &stmt_name, &self.connection.params_buf)
561 .map_err(|e| PgError::Encode(e.to_string()))?;
562 PgEncoder::encode_execute_to(&mut self.connection.write_buf);
563 PgEncoder::encode_sync_to(&mut self.connection.write_buf);
564
565 self.connection.flush_write_buf().await?;
567
568 let cached_column_info = self.connection.column_info_cache.get(&sql_hash).cloned();
570
571 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
572 let mut column_info: Option<Arc<ColumnInfo>> = cached_column_info;
573 let mut error: Option<PgError> = None;
574
575 loop {
576 let msg = self.connection.recv().await?;
577 match msg {
578 crate::protocol::BackendMessage::ParseComplete
579 | crate::protocol::BackendMessage::BindComplete => {}
580 crate::protocol::BackendMessage::ParameterDescription(_) => {
581 }
583 crate::protocol::BackendMessage::RowDescription(fields) => {
584 let info = Arc::new(ColumnInfo::from_fields(&fields));
586 if is_cache_miss {
587 self.connection.column_info_cache.insert(sql_hash, info.clone());
588 }
589 column_info = Some(info);
590 }
591 crate::protocol::BackendMessage::DataRow(data) => {
592 if error.is_none() {
593 rows.push(PgRow {
594 columns: data,
595 column_info: column_info.clone(),
596 });
597 }
598 }
599 crate::protocol::BackendMessage::CommandComplete(_) => {}
600 crate::protocol::BackendMessage::NoData => {
601 }
603 crate::protocol::BackendMessage::ReadyForQuery(_) => {
604 if let Some(err) = error {
605 return Err(err);
606 }
607 return Ok(rows);
608 }
609 crate::protocol::BackendMessage::ErrorResponse(err) => {
610 if error.is_none() {
611 error = Some(PgError::Query(err.message));
612 self.connection.stmt_cache.clear();
615 self.connection.prepared_statements.clear();
616 self.connection.column_info_cache.clear();
617 }
618 }
619 _ => {}
620 }
621 }
622 }
623
624 pub async fn execute(&mut self, cmd: &Qail) -> PgResult<u64> {
626 use crate::protocol::AstEncoder;
627
628 let wire_bytes = AstEncoder::encode_cmd_reuse(
629 cmd,
630 &mut self.connection.sql_buf,
631 &mut self.connection.params_buf,
632 );
633
634 self.connection.send_bytes(&wire_bytes).await?;
635
636 let mut affected = 0u64;
637 let mut error: Option<PgError> = None;
638
639 loop {
640 let msg = self.connection.recv().await?;
641 match msg {
642 crate::protocol::BackendMessage::ParseComplete
643 | crate::protocol::BackendMessage::BindComplete => {}
644 crate::protocol::BackendMessage::RowDescription(_) => {}
645 crate::protocol::BackendMessage::DataRow(_) => {}
646 crate::protocol::BackendMessage::CommandComplete(tag) => {
647 if error.is_none() && let Some(n) = tag.split_whitespace().last() {
648 affected = n.parse().unwrap_or(0);
649 }
650 }
651 crate::protocol::BackendMessage::ReadyForQuery(_) => {
652 if let Some(err) = error {
653 return Err(err);
654 }
655 return Ok(affected);
656 }
657 crate::protocol::BackendMessage::ErrorResponse(err) => {
658 if error.is_none() {
659 error = Some(PgError::Query(err.message));
660 }
661 }
662 _ => {}
663 }
664 }
665 }
666
667 pub async fn query_ast(&mut self, cmd: &Qail) -> PgResult<QueryResult> {
671 use crate::protocol::AstEncoder;
672
673 let wire_bytes = AstEncoder::encode_cmd_reuse(
674 cmd,
675 &mut self.connection.sql_buf,
676 &mut self.connection.params_buf,
677 );
678
679 self.connection.send_bytes(&wire_bytes).await?;
680
681 let mut columns: Vec<String> = Vec::new();
682 let mut rows: Vec<Vec<Option<String>>> = Vec::new();
683 let mut error: Option<PgError> = None;
684
685 loop {
686 let msg = self.connection.recv().await?;
687 match msg {
688 crate::protocol::BackendMessage::ParseComplete
689 | crate::protocol::BackendMessage::BindComplete => {}
690 crate::protocol::BackendMessage::RowDescription(fields) => {
691 columns = fields.into_iter().map(|f| f.name).collect();
692 }
693 crate::protocol::BackendMessage::DataRow(data) => {
694 if error.is_none() {
695 let row: Vec<Option<String>> = data
696 .into_iter()
697 .map(|col| col.map(|bytes| String::from_utf8_lossy(&bytes).to_string()))
698 .collect();
699 rows.push(row);
700 }
701 }
702 crate::protocol::BackendMessage::CommandComplete(_) => {}
703 crate::protocol::BackendMessage::NoData => {}
704 crate::protocol::BackendMessage::ReadyForQuery(_) => {
705 if let Some(err) = error {
706 return Err(err);
707 }
708 return Ok(QueryResult { columns, rows });
709 }
710 crate::protocol::BackendMessage::ErrorResponse(err) => {
711 if error.is_none() {
712 error = Some(PgError::Query(err.message));
713 }
714 }
715 _ => {}
716 }
717 }
718 }
719
720 pub async fn begin(&mut self) -> PgResult<()> {
724 self.connection.begin_transaction().await
725 }
726
727 pub async fn commit(&mut self) -> PgResult<()> {
729 self.connection.commit().await
730 }
731
732 pub async fn rollback(&mut self) -> PgResult<()> {
734 self.connection.rollback().await
735 }
736
737 pub async fn savepoint(&mut self, name: &str) -> PgResult<()> {
750 self.connection.savepoint(name).await
751 }
752
753 pub async fn rollback_to(&mut self, name: &str) -> PgResult<()> {
757 self.connection.rollback_to(name).await
758 }
759
760 pub async fn release_savepoint(&mut self, name: &str) -> PgResult<()> {
763 self.connection.release_savepoint(name).await
764 }
765
766 pub async fn execute_batch(&mut self, cmds: &[Qail]) -> PgResult<Vec<u64>> {
780 self.begin().await?;
781 let mut results = Vec::with_capacity(cmds.len());
782 for cmd in cmds {
783 match self.execute(cmd).await {
784 Ok(n) => results.push(n),
785 Err(e) => {
786 self.rollback().await?;
787 return Err(e);
788 }
789 }
790 }
791 self.commit().await?;
792 Ok(results)
793 }
794
795 pub async fn set_statement_timeout(&mut self, ms: u32) -> PgResult<()> {
803 self.execute_raw(&format!("SET statement_timeout = {}", ms))
804 .await
805 }
806
807 pub async fn reset_statement_timeout(&mut self) -> PgResult<()> {
809 self.execute_raw("RESET statement_timeout").await
810 }
811
812 pub async fn set_rls_context(&mut self, ctx: rls::RlsContext) -> PgResult<()> {
830 let sql = rls::context_to_sql(&ctx);
831 self.execute_raw(&sql).await?;
832 self.rls_context = Some(ctx);
833 Ok(())
834 }
835
836 pub async fn clear_rls_context(&mut self) -> PgResult<()> {
841 self.execute_raw(rls::reset_sql()).await?;
842 self.rls_context = None;
843 Ok(())
844 }
845
846 pub fn rls_context(&self) -> Option<&rls::RlsContext> {
848 self.rls_context.as_ref()
849 }
850
851 pub async fn pipeline_batch(&mut self, cmds: &[Qail]) -> PgResult<usize> {
863 self.connection.pipeline_ast_fast(cmds).await
864 }
865
866 pub async fn pipeline_fetch(&mut self, cmds: &[Qail]) -> PgResult<Vec<Vec<PgRow>>> {
868 let raw_results = self.connection.pipeline_ast(cmds).await?;
869
870 let results: Vec<Vec<PgRow>> = raw_results
871 .into_iter()
872 .map(|rows| {
873 rows.into_iter()
874 .map(|columns| PgRow {
875 columns,
876 column_info: None,
877 })
878 .collect()
879 })
880 .collect();
881
882 Ok(results)
883 }
884
885 pub async fn prepare(&mut self, sql: &str) -> PgResult<PreparedStatement> {
887 self.connection.prepare(sql).await
888 }
889
890 pub async fn pipeline_prepared_fast(
892 &mut self,
893 stmt: &PreparedStatement,
894 params_batch: &[Vec<Option<Vec<u8>>>],
895 ) -> PgResult<usize> {
896 self.connection
897 .pipeline_prepared_fast(stmt, params_batch)
898 .await
899 }
900
901 pub async fn execute_raw(&mut self, sql: &str) -> PgResult<()> {
908 if sql.as_bytes().contains(&0) {
910 return Err(crate::PgError::Protocol(
911 "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
912 ));
913 }
914 self.connection.execute_simple(sql).await
915 }
916
917 pub async fn fetch_raw(&mut self, sql: &str) -> PgResult<Vec<PgRow>> {
921 if sql.as_bytes().contains(&0) {
922 return Err(crate::PgError::Protocol(
923 "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
924 ));
925 }
926
927 use tokio::io::AsyncWriteExt;
928 use crate::protocol::PgEncoder;
929
930 let msg = PgEncoder::encode_query_string(sql);
932 self.connection.stream.write_all(&msg).await?;
933
934 let mut rows: Vec<PgRow> = Vec::new();
935 let mut column_info: Option<std::sync::Arc<ColumnInfo>> = None;
936
937
938 let mut error: Option<PgError> = None;
939
940 loop {
941 let msg = self.connection.recv().await?;
942 match msg {
943 crate::protocol::BackendMessage::RowDescription(fields) => {
944 column_info = Some(std::sync::Arc::new(ColumnInfo::from_fields(&fields)));
945 }
946 crate::protocol::BackendMessage::DataRow(data) => {
947 if error.is_none() {
948 rows.push(PgRow {
949 columns: data,
950 column_info: column_info.clone(),
951 });
952 }
953 }
954 crate::protocol::BackendMessage::CommandComplete(_) => {}
955 crate::protocol::BackendMessage::ReadyForQuery(_) => {
956 if let Some(err) = error {
957 return Err(err);
958 }
959 return Ok(rows);
960 }
961 crate::protocol::BackendMessage::ErrorResponse(err) => {
962 if error.is_none() {
963 error = Some(PgError::Query(err.message));
964 }
965 }
966 _ => {}
967 }
968 }
969 }
970
971 pub async fn copy_bulk(
987 &mut self,
988 cmd: &Qail,
989 rows: &[Vec<qail_core::ast::Value>],
990 ) -> PgResult<u64> {
991 use qail_core::ast::Action;
992
993
994 if cmd.action != Action::Add {
995 return Err(PgError::Query(
996 "copy_bulk requires Qail::Add action".to_string(),
997 ));
998 }
999
1000 let table = &cmd.table;
1001
1002 let columns: Vec<String> = cmd
1003 .columns
1004 .iter()
1005 .filter_map(|expr| {
1006 use qail_core::ast::Expr;
1007 match expr {
1008 Expr::Named(name) => Some(name.clone()),
1009 Expr::Aliased { name, .. } => Some(name.clone()),
1010 Expr::Star => None, _ => None,
1012 }
1013 })
1014 .collect();
1015
1016 if columns.is_empty() {
1017 return Err(PgError::Query(
1018 "copy_bulk requires columns in Qail".to_string(),
1019 ));
1020 }
1021
1022 self.connection.copy_in_fast(table, &columns, rows).await
1024 }
1025
1026 pub async fn copy_bulk_bytes(&mut self, cmd: &Qail, data: &[u8]) -> PgResult<u64> {
1039 use qail_core::ast::Action;
1040
1041 if cmd.action != Action::Add {
1042 return Err(PgError::Query(
1043 "copy_bulk_bytes requires Qail::Add action".to_string(),
1044 ));
1045 }
1046
1047 let table = &cmd.table;
1048 let columns: Vec<String> = cmd
1049 .columns
1050 .iter()
1051 .filter_map(|expr| {
1052 use qail_core::ast::Expr;
1053 match expr {
1054 Expr::Named(name) => Some(name.clone()),
1055 Expr::Aliased { name, .. } => Some(name.clone()),
1056 _ => None,
1057 }
1058 })
1059 .collect();
1060
1061 if columns.is_empty() {
1062 return Err(PgError::Query(
1063 "copy_bulk_bytes requires columns in Qail".to_string(),
1064 ));
1065 }
1066
1067 self.connection.copy_in_raw(table, &columns, data).await
1069 }
1070
1071 pub async fn copy_export_table(
1079 &mut self,
1080 table: &str,
1081 columns: &[String],
1082 ) -> PgResult<Vec<u8>> {
1083 let cols = columns.join(", ");
1084 let sql = format!("COPY {} ({}) TO STDOUT", table, cols);
1085
1086 self.connection.copy_out_raw(&sql).await
1087 }
1088
1089 pub async fn stream_cmd(
1103 &mut self,
1104 cmd: &Qail,
1105 batch_size: usize,
1106 ) -> PgResult<Vec<Vec<PgRow>>> {
1107 use std::sync::atomic::{AtomicU64, Ordering};
1108 static CURSOR_ID: AtomicU64 = AtomicU64::new(0);
1109
1110 let cursor_name = format!("qail_cursor_{}", CURSOR_ID.fetch_add(1, Ordering::SeqCst));
1111
1112 use crate::protocol::AstEncoder;
1114 let mut sql_buf = bytes::BytesMut::with_capacity(256);
1115 let mut params: Vec<Option<Vec<u8>>> = Vec::new();
1116 AstEncoder::encode_select_sql(cmd, &mut sql_buf, &mut params);
1117 let sql = String::from_utf8_lossy(&sql_buf).to_string();
1118
1119 self.connection.begin_transaction().await?;
1121
1122 self.connection.declare_cursor(&cursor_name, &sql).await?;
1124
1125 let mut all_batches = Vec::new();
1127 while let Some(rows) = self
1128 .connection
1129 .fetch_cursor(&cursor_name, batch_size)
1130 .await?
1131 {
1132 let pg_rows: Vec<PgRow> = rows
1133 .into_iter()
1134 .map(|cols| PgRow {
1135 columns: cols,
1136 column_info: None,
1137 })
1138 .collect();
1139 all_batches.push(pg_rows);
1140 }
1141
1142 self.connection.close_cursor(&cursor_name).await?;
1143 self.connection.commit().await?;
1144
1145 Ok(all_batches)
1146 }
1147}
1148
1149#[derive(Default)]
1166pub struct PgDriverBuilder {
1167 host: Option<String>,
1168 port: Option<u16>,
1169 user: Option<String>,
1170 database: Option<String>,
1171 password: Option<String>,
1172 timeout: Option<std::time::Duration>,
1173}
1174
1175impl PgDriverBuilder {
1176 pub fn new() -> Self {
1178 Self::default()
1179 }
1180
1181 pub fn host(mut self, host: impl Into<String>) -> Self {
1183 self.host = Some(host.into());
1184 self
1185 }
1186
1187 pub fn port(mut self, port: u16) -> Self {
1189 self.port = Some(port);
1190 self
1191 }
1192
1193 pub fn user(mut self, user: impl Into<String>) -> Self {
1195 self.user = Some(user.into());
1196 self
1197 }
1198
1199 pub fn database(mut self, database: impl Into<String>) -> Self {
1201 self.database = Some(database.into());
1202 self
1203 }
1204
1205 pub fn password(mut self, password: impl Into<String>) -> Self {
1207 self.password = Some(password.into());
1208 self
1209 }
1210
1211 pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
1213 self.timeout = Some(timeout);
1214 self
1215 }
1216
1217 pub async fn connect(self) -> PgResult<PgDriver> {
1219 let host = self.host.as_deref().unwrap_or("127.0.0.1");
1220 let port = self.port.unwrap_or(5432);
1221 let user = self.user.as_deref().ok_or_else(|| {
1222 PgError::Connection("User is required".to_string())
1223 })?;
1224 let database = self.database.as_deref().ok_or_else(|| {
1225 PgError::Connection("Database is required".to_string())
1226 })?;
1227
1228 match (self.password.as_deref(), self.timeout) {
1229 (Some(password), Some(timeout)) => {
1230 PgDriver::connect_with_timeout(host, port, user, database, password, timeout).await
1231 }
1232 (Some(password), None) => {
1233 PgDriver::connect_with_password(host, port, user, database, password).await
1234 }
1235 (None, Some(timeout)) => {
1236 tokio::time::timeout(
1237 timeout,
1238 PgDriver::connect(host, port, user, database),
1239 )
1240 .await
1241 .map_err(|_| PgError::Connection(format!("Connection timeout after {:?}", timeout)))?
1242 }
1243 (None, None) => {
1244 PgDriver::connect(host, port, user, database).await
1245 }
1246 }
1247 }
1248}