1mod cancel;
19mod connection;
20mod copy;
21mod cursor;
22mod io;
23pub mod io_backend;
24mod pipeline;
25mod pool;
26mod prepared;
27mod query;
28pub mod rls;
29mod row;
30mod stream;
31mod transaction;
32
33pub use connection::PgConnection;
34pub use connection::TlsConfig;
35pub(crate) use connection::{CANCEL_REQUEST_CODE, parse_affected_rows};
36pub use cancel::CancelToken;
37pub use io_backend::{IoBackend, backend_name, detect as detect_io_backend};
38pub use pool::{PgPool, PoolConfig, PoolStats, PooledConnection};
39pub use prepared::PreparedStatement;
40pub use rls::RlsContext;
41pub use row::QailRow;
42
43use qail_core::ast::Qail;
44use std::collections::HashMap;
45use std::sync::Arc;
46
47#[derive(Debug, Clone)]
48pub struct ColumnInfo {
49 pub name_to_index: HashMap<String, usize>,
50 pub oids: Vec<u32>,
51 pub formats: Vec<i16>,
52}
53
54impl ColumnInfo {
55 pub fn from_fields(fields: &[crate::protocol::FieldDescription]) -> Self {
56 let mut name_to_index = HashMap::with_capacity(fields.len());
57 let mut oids = Vec::with_capacity(fields.len());
58 let mut formats = Vec::with_capacity(fields.len());
59
60 for (i, field) in fields.iter().enumerate() {
61 name_to_index.insert(field.name.clone(), i);
62 oids.push(field.type_oid);
63 formats.push(field.format);
64 }
65
66 Self {
67 name_to_index,
68 oids,
69 formats,
70 }
71 }
72}
73
74pub struct PgRow {
76 pub columns: Vec<Option<Vec<u8>>>,
77 pub column_info: Option<Arc<ColumnInfo>>,
78}
79
/// Errors produced by the PostgreSQL driver.
#[derive(Debug)]
pub enum PgError {
    /// Connection setup failure (including URL parsing problems and connect
    /// timeouts).
    Connection(String),
    /// Protocol-level problem, e.g. input that cannot be sent on the wire.
    Protocol(String),
    /// Authentication failure.
    Auth(String),
    /// Query failure, typically the message of a server `ErrorResponse`.
    Query(String),
    /// A single row was requested but none was returned.
    NoRows,
    /// Underlying I/O failure.
    Io(std::io::Error),
    /// Failure encoding a message or its parameters.
    Encode(String),
}

impl std::fmt::Display for PgError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PgError::Connection(e) => write!(f, "Connection error: {}", e),
            PgError::Protocol(e) => write!(f, "Protocol error: {}", e),
            PgError::Auth(e) => write!(f, "Auth error: {}", e),
            PgError::Query(e) => write!(f, "Query error: {}", e),
            PgError::NoRows => write!(f, "No rows returned"),
            PgError::Io(e) => write!(f, "I/O error: {}", e),
            PgError::Encode(e) => write!(f, "Encode error: {}", e),
        }
    }
}

impl std::error::Error for PgError {
    /// Exposes the wrapped `std::io::Error` so callers and error reporters
    /// can walk the cause chain; the string-carrying variants have no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            PgError::Io(e) => Some(e),
            _ => None,
        }
    }
}

impl From<std::io::Error> for PgError {
    fn from(e: std::io::Error) -> Self {
        PgError::Io(e)
    }
}

/// Result alias used throughout the driver.
pub type PgResult<T> = Result<T, PgError>;
118
/// Text-decoded result set returned by `PgDriver::query_ast`.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Column names in result order.
    pub columns: Vec<String>,
    /// Row values decoded as lossy UTF-8 text; `None` where the server sent
    /// no value for a column.
    pub rows: Vec<Vec<Option<String>>>,
}
127
128pub struct PgDriver {
130 #[allow(dead_code)]
131 connection: PgConnection,
132 rls_context: Option<RlsContext>,
134}
135
136impl PgDriver {
137 pub fn new(connection: PgConnection) -> Self {
139 Self { connection, rls_context: None }
140 }
141
142 pub fn builder() -> PgDriverBuilder {
155 PgDriverBuilder::new()
156 }
157
158 pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
160 let connection = PgConnection::connect(host, port, user, database).await?;
161 Ok(Self::new(connection))
162 }
163
164 pub async fn connect_with_password(
166 host: &str,
167 port: u16,
168 user: &str,
169 database: &str,
170 password: &str,
171 ) -> PgResult<Self> {
172 let connection =
173 PgConnection::connect_with_password(host, port, user, database, Some(password)).await?;
174 Ok(Self::new(connection))
175 }
176
177 pub async fn connect_env() -> PgResult<Self> {
188 let url = std::env::var("DATABASE_URL")
189 .map_err(|_| PgError::Connection("DATABASE_URL environment variable not set".to_string()))?;
190 Self::connect_url(&url).await
191 }
192
193 pub async fn connect_url(url: &str) -> PgResult<Self> {
203 let (host, port, user, database, password) = Self::parse_database_url(url)?;
204
205 if let Some(pwd) = password {
206 Self::connect_with_password(&host, port, &user, &database, &pwd).await
207 } else {
208 Self::connect(&host, port, &user, &database).await
209 }
210 }
211
212 fn parse_database_url(url: &str) -> PgResult<(String, u16, String, String, Option<String>)> {
219 let after_scheme = url.split("://").nth(1)
221 .ok_or_else(|| PgError::Connection("Invalid DATABASE_URL: missing scheme".to_string()))?;
222
223 let (auth_part, host_db_part) = if let Some(at_pos) = after_scheme.rfind('@') {
225 (Some(&after_scheme[..at_pos]), &after_scheme[at_pos + 1..])
226 } else {
227 (None, after_scheme)
228 };
229
230 let (user, password) = if let Some(auth) = auth_part {
232 let parts: Vec<&str> = auth.splitn(2, ':').collect();
233 if parts.len() == 2 {
234 (
236 Self::percent_decode(parts[0]),
237 Some(Self::percent_decode(parts[1])),
238 )
239 } else {
240 (Self::percent_decode(parts[0]), None)
241 }
242 } else {
243 return Err(PgError::Connection("Invalid DATABASE_URL: missing user".to_string()));
244 };
245
246 let (host_port, database) = if let Some(slash_pos) = host_db_part.find('/') {
248 (&host_db_part[..slash_pos], host_db_part[slash_pos + 1..].to_string())
249 } else {
250 return Err(PgError::Connection("Invalid DATABASE_URL: missing database name".to_string()));
251 };
252
253 let (host, port) = if let Some(colon_pos) = host_port.rfind(':') {
255 let port_str = &host_port[colon_pos + 1..];
256 let port = port_str.parse::<u16>()
257 .map_err(|_| PgError::Connection(format!("Invalid port: {}", port_str)))?;
258 (host_port[..colon_pos].to_string(), port)
259 } else {
260 (host_port.to_string(), 5432) };
262
263 Ok((host, port, user, database, password))
264 }
265
266 fn percent_decode(s: &str) -> String {
269 let mut result = String::with_capacity(s.len());
270 let mut chars = s.chars().peekable();
271
272 while let Some(c) = chars.next() {
273 if c == '%' {
274 let hex: String = chars.by_ref().take(2).collect();
276 if hex.len() == 2
277 && let Ok(byte) = u8::from_str_radix(&hex, 16)
278 {
279 result.push(byte as char);
280 continue;
281 }
282 result.push('%');
284 result.push_str(&hex);
285 } else if c == '+' {
286 result.push('+');
289 } else {
290 result.push(c);
291 }
292 }
293
294 result
295 }
296
297 pub async fn connect_with_timeout(
308 host: &str,
309 port: u16,
310 user: &str,
311 database: &str,
312 password: &str,
313 timeout: std::time::Duration,
314 ) -> PgResult<Self> {
315 tokio::time::timeout(
316 timeout,
317 Self::connect_with_password(host, port, user, database, password),
318 )
319 .await
320 .map_err(|_| PgError::Connection(format!("Connection timeout after {:?}", timeout)))?
321 }
322 pub fn clear_cache(&mut self) {
326 self.connection.stmt_cache.clear();
327 self.connection.prepared_statements.clear();
328 }
329
330 pub fn cache_stats(&self) -> (usize, usize) {
333 (self.connection.stmt_cache.len(), self.connection.stmt_cache.cap().get())
334 }
335
336 pub async fn fetch_all(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
342 self.fetch_all_cached(cmd).await
344 }
345
346 pub async fn fetch_typed<T: row::QailRow>(&mut self, cmd: &Qail) -> PgResult<Vec<T>> {
354 let rows = self.fetch_all(cmd).await?;
355 Ok(rows.iter().map(T::from_row).collect())
356 }
357
358 pub async fn fetch_one_typed<T: row::QailRow>(&mut self, cmd: &Qail) -> PgResult<Option<T>> {
361 let rows = self.fetch_all(cmd).await?;
362 Ok(rows.first().map(T::from_row))
363 }
364
365 pub async fn fetch_all_uncached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
369 use crate::protocol::AstEncoder;
370
371 let wire_bytes = AstEncoder::encode_cmd_reuse(
372 cmd,
373 &mut self.connection.sql_buf,
374 &mut self.connection.params_buf,
375 );
376
377 self.connection.send_bytes(&wire_bytes).await?;
378
379 let mut rows: Vec<PgRow> = Vec::new();
380 let mut column_info: Option<Arc<ColumnInfo>> = None;
381
382 let mut error: Option<PgError> = None;
383
384 loop {
385 let msg = self.connection.recv().await?;
386 match msg {
387 crate::protocol::BackendMessage::ParseComplete
388 | crate::protocol::BackendMessage::BindComplete => {}
389 crate::protocol::BackendMessage::RowDescription(fields) => {
390 column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
391 }
392 crate::protocol::BackendMessage::DataRow(data) => {
393 if error.is_none() {
394 rows.push(PgRow {
395 columns: data,
396 column_info: column_info.clone(),
397 });
398 }
399 }
400 crate::protocol::BackendMessage::CommandComplete(_) => {}
401 crate::protocol::BackendMessage::ReadyForQuery(_) => {
402 if let Some(err) = error {
403 return Err(err);
404 }
405 return Ok(rows);
406 }
407 crate::protocol::BackendMessage::ErrorResponse(err) => {
408 if error.is_none() {
409 error = Some(PgError::Query(err.message));
410 }
411 }
412 _ => {}
413 }
414 }
415 }
416
417 pub async fn fetch_all_fast(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
421 use crate::protocol::AstEncoder;
422
423 let wire_bytes = AstEncoder::encode_cmd_reuse(
424 cmd,
425 &mut self.connection.sql_buf,
426 &mut self.connection.params_buf,
427 );
428
429 self.connection.send_bytes(&wire_bytes).await?;
430
431 let mut rows: Vec<PgRow> = Vec::new();
433 let mut error: Option<PgError> = None;
434
435 loop {
436 let res = self.connection.recv_with_data_fast().await;
437 match res {
438 Ok((msg_type, data)) => {
439 match msg_type {
440 b'D' => {
441 if error.is_none() && let Some(columns) = data {
443 rows.push(PgRow {
444 columns,
445 column_info: None, });
447 }
448 }
449 b'Z' => {
450 if let Some(err) = error {
452 return Err(err);
453 }
454 return Ok(rows);
455 }
456 _ => {} }
458 }
459 Err(e) => {
460 if error.is_none() {
469 error = Some(e);
470 }
471 }
476 }
477 }
478 }
479
480 pub async fn fetch_one(&mut self, cmd: &Qail) -> PgResult<PgRow> {
482 let rows = self.fetch_all(cmd).await?;
483 rows.into_iter().next().ok_or(PgError::NoRows)
484 }
485
486 pub async fn fetch_all_cached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
493 use crate::protocol::AstEncoder;
494 use std::collections::hash_map::DefaultHasher;
495 use std::hash::{Hash, Hasher};
496
497 self.connection.sql_buf.clear();
498 self.connection.params_buf.clear();
499
500 match cmd.action {
502 qail_core::ast::Action::Get | qail_core::ast::Action::With => {
503 crate::protocol::ast_encoder::dml::encode_select(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
504 }
505 qail_core::ast::Action::Add => {
506 crate::protocol::ast_encoder::dml::encode_insert(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
507 }
508 qail_core::ast::Action::Set => {
509 crate::protocol::ast_encoder::dml::encode_update(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
510 }
511 qail_core::ast::Action::Del => {
512 crate::protocol::ast_encoder::dml::encode_delete(cmd, &mut self.connection.sql_buf, &mut self.connection.params_buf).ok();
513 }
514 _ => {
515 let (sql, params) = AstEncoder::encode_cmd_sql(cmd);
517 let raw_rows = self.connection.query_cached(&sql, ¶ms).await?;
518 return Ok(raw_rows.into_iter().map(|data| PgRow { columns: data, column_info: None }).collect());
519 }
520 }
521
522 let mut hasher = DefaultHasher::new();
523 self.connection.sql_buf.hash(&mut hasher);
524 let sql_hash = hasher.finish();
525
526 let is_cache_miss = !self.connection.stmt_cache.contains(&sql_hash);
527
528 let stmt_name = if let Some(name) = self.connection.stmt_cache.get(&sql_hash) {
529 name.clone()
530 } else {
531 let name = format!("qail_{:x}", sql_hash);
532
533 use crate::protocol::PgEncoder;
534 use tokio::io::AsyncWriteExt;
535
536 let sql_str = std::str::from_utf8(&self.connection.sql_buf).unwrap_or("");
537
538 let parse_msg = PgEncoder::encode_parse(&name, sql_str, &[]);
540 let describe_msg = PgEncoder::encode_describe(false, &name);
541 self.connection.stream.write_all(&parse_msg).await?;
542 self.connection.stream.write_all(&describe_msg).await?;
543
544 self.connection.stmt_cache.put(sql_hash, name.clone());
545 self.connection.prepared_statements.insert(name.clone(), sql_str.to_string());
546
547 name
548 };
549
550 use crate::protocol::PgEncoder;
552 use tokio::io::AsyncWriteExt;
553
554 let mut buf = bytes::BytesMut::with_capacity(128);
555 PgEncoder::encode_bind_to(&mut buf, &stmt_name, &self.connection.params_buf)
556 .map_err(|e| PgError::Encode(e.to_string()))?;
557 PgEncoder::encode_execute_to(&mut buf);
558 PgEncoder::encode_sync_to(&mut buf);
559 self.connection.stream.write_all(&buf).await?;
560
561 let cached_column_info = self.connection.column_info_cache.get(&sql_hash).cloned();
563
564 let mut rows: Vec<PgRow> = Vec::new();
565 let mut column_info: Option<Arc<ColumnInfo>> = cached_column_info;
566 let mut error: Option<PgError> = None;
567
568 loop {
569 let msg = self.connection.recv().await?;
570 match msg {
571 crate::protocol::BackendMessage::ParseComplete
572 | crate::protocol::BackendMessage::BindComplete => {}
573 crate::protocol::BackendMessage::ParameterDescription(_) => {
574 }
576 crate::protocol::BackendMessage::RowDescription(fields) => {
577 let info = Arc::new(ColumnInfo::from_fields(&fields));
579 if is_cache_miss {
580 self.connection.column_info_cache.insert(sql_hash, info.clone());
581 }
582 column_info = Some(info);
583 }
584 crate::protocol::BackendMessage::DataRow(data) => {
585 if error.is_none() {
586 rows.push(PgRow {
587 columns: data,
588 column_info: column_info.clone(),
589 });
590 }
591 }
592 crate::protocol::BackendMessage::CommandComplete(_) => {}
593 crate::protocol::BackendMessage::NoData => {
594 }
596 crate::protocol::BackendMessage::ReadyForQuery(_) => {
597 if let Some(err) = error {
598 return Err(err);
599 }
600 return Ok(rows);
601 }
602 crate::protocol::BackendMessage::ErrorResponse(err) => {
603 if error.is_none() {
604 error = Some(PgError::Query(err.message));
605 self.connection.stmt_cache.clear();
608 self.connection.prepared_statements.clear();
609 self.connection.column_info_cache.clear();
610 }
611 }
612 _ => {}
613 }
614 }
615 }
616
617 pub async fn execute(&mut self, cmd: &Qail) -> PgResult<u64> {
619 use crate::protocol::AstEncoder;
620
621 let wire_bytes = AstEncoder::encode_cmd_reuse(
622 cmd,
623 &mut self.connection.sql_buf,
624 &mut self.connection.params_buf,
625 );
626
627 self.connection.send_bytes(&wire_bytes).await?;
628
629 let mut affected = 0u64;
630 let mut error: Option<PgError> = None;
631
632 loop {
633 let msg = self.connection.recv().await?;
634 match msg {
635 crate::protocol::BackendMessage::ParseComplete
636 | crate::protocol::BackendMessage::BindComplete => {}
637 crate::protocol::BackendMessage::RowDescription(_) => {}
638 crate::protocol::BackendMessage::DataRow(_) => {}
639 crate::protocol::BackendMessage::CommandComplete(tag) => {
640 if error.is_none() && let Some(n) = tag.split_whitespace().last() {
641 affected = n.parse().unwrap_or(0);
642 }
643 }
644 crate::protocol::BackendMessage::ReadyForQuery(_) => {
645 if let Some(err) = error {
646 return Err(err);
647 }
648 return Ok(affected);
649 }
650 crate::protocol::BackendMessage::ErrorResponse(err) => {
651 if error.is_none() {
652 error = Some(PgError::Query(err.message));
653 }
654 }
655 _ => {}
656 }
657 }
658 }
659
660 pub async fn query_ast(&mut self, cmd: &Qail) -> PgResult<QueryResult> {
664 use crate::protocol::AstEncoder;
665
666 let wire_bytes = AstEncoder::encode_cmd_reuse(
667 cmd,
668 &mut self.connection.sql_buf,
669 &mut self.connection.params_buf,
670 );
671
672 self.connection.send_bytes(&wire_bytes).await?;
673
674 let mut columns: Vec<String> = Vec::new();
675 let mut rows: Vec<Vec<Option<String>>> = Vec::new();
676 let mut error: Option<PgError> = None;
677
678 loop {
679 let msg = self.connection.recv().await?;
680 match msg {
681 crate::protocol::BackendMessage::ParseComplete
682 | crate::protocol::BackendMessage::BindComplete => {}
683 crate::protocol::BackendMessage::RowDescription(fields) => {
684 columns = fields.into_iter().map(|f| f.name).collect();
685 }
686 crate::protocol::BackendMessage::DataRow(data) => {
687 if error.is_none() {
688 let row: Vec<Option<String>> = data
689 .into_iter()
690 .map(|col| col.map(|bytes| String::from_utf8_lossy(&bytes).to_string()))
691 .collect();
692 rows.push(row);
693 }
694 }
695 crate::protocol::BackendMessage::CommandComplete(_) => {}
696 crate::protocol::BackendMessage::NoData => {}
697 crate::protocol::BackendMessage::ReadyForQuery(_) => {
698 if let Some(err) = error {
699 return Err(err);
700 }
701 return Ok(QueryResult { columns, rows });
702 }
703 crate::protocol::BackendMessage::ErrorResponse(err) => {
704 if error.is_none() {
705 error = Some(PgError::Query(err.message));
706 }
707 }
708 _ => {}
709 }
710 }
711 }
712
713 pub async fn begin(&mut self) -> PgResult<()> {
717 self.connection.begin_transaction().await
718 }
719
720 pub async fn commit(&mut self) -> PgResult<()> {
722 self.connection.commit().await
723 }
724
725 pub async fn rollback(&mut self) -> PgResult<()> {
727 self.connection.rollback().await
728 }
729
730 pub async fn savepoint(&mut self, name: &str) -> PgResult<()> {
743 self.connection.savepoint(name).await
744 }
745
746 pub async fn rollback_to(&mut self, name: &str) -> PgResult<()> {
750 self.connection.rollback_to(name).await
751 }
752
753 pub async fn release_savepoint(&mut self, name: &str) -> PgResult<()> {
756 self.connection.release_savepoint(name).await
757 }
758
759 pub async fn execute_batch(&mut self, cmds: &[Qail]) -> PgResult<Vec<u64>> {
773 self.begin().await?;
774 let mut results = Vec::with_capacity(cmds.len());
775 for cmd in cmds {
776 match self.execute(cmd).await {
777 Ok(n) => results.push(n),
778 Err(e) => {
779 self.rollback().await?;
780 return Err(e);
781 }
782 }
783 }
784 self.commit().await?;
785 Ok(results)
786 }
787
788 pub async fn set_statement_timeout(&mut self, ms: u32) -> PgResult<()> {
796 self.execute_raw(&format!("SET statement_timeout = {}", ms))
797 .await
798 }
799
800 pub async fn reset_statement_timeout(&mut self) -> PgResult<()> {
802 self.execute_raw("RESET statement_timeout").await
803 }
804
805 pub async fn set_rls_context(&mut self, ctx: rls::RlsContext) -> PgResult<()> {
823 let sql = rls::context_to_sql(&ctx);
824 self.execute_raw(&sql).await?;
825 self.rls_context = Some(ctx);
826 Ok(())
827 }
828
829 pub async fn clear_rls_context(&mut self) -> PgResult<()> {
834 self.execute_raw(rls::reset_sql()).await?;
835 self.rls_context = None;
836 Ok(())
837 }
838
839 pub fn rls_context(&self) -> Option<&rls::RlsContext> {
841 self.rls_context.as_ref()
842 }
843
844 pub async fn pipeline_batch(&mut self, cmds: &[Qail]) -> PgResult<usize> {
856 self.connection.pipeline_ast_fast(cmds).await
857 }
858
859 pub async fn pipeline_fetch(&mut self, cmds: &[Qail]) -> PgResult<Vec<Vec<PgRow>>> {
861 let raw_results = self.connection.pipeline_ast(cmds).await?;
862
863 let results: Vec<Vec<PgRow>> = raw_results
864 .into_iter()
865 .map(|rows| {
866 rows.into_iter()
867 .map(|columns| PgRow {
868 columns,
869 column_info: None,
870 })
871 .collect()
872 })
873 .collect();
874
875 Ok(results)
876 }
877
878 pub async fn prepare(&mut self, sql: &str) -> PgResult<PreparedStatement> {
880 self.connection.prepare(sql).await
881 }
882
883 pub async fn pipeline_prepared_fast(
885 &mut self,
886 stmt: &PreparedStatement,
887 params_batch: &[Vec<Option<Vec<u8>>>],
888 ) -> PgResult<usize> {
889 self.connection
890 .pipeline_prepared_fast(stmt, params_batch)
891 .await
892 }
893
894 pub async fn execute_raw(&mut self, sql: &str) -> PgResult<()> {
901 if sql.as_bytes().contains(&0) {
903 return Err(crate::PgError::Protocol(
904 "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
905 ));
906 }
907 self.connection.execute_simple(sql).await
908 }
909
910 pub async fn fetch_raw(&mut self, sql: &str) -> PgResult<Vec<PgRow>> {
914 if sql.as_bytes().contains(&0) {
915 return Err(crate::PgError::Protocol(
916 "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
917 ));
918 }
919
920 use tokio::io::AsyncWriteExt;
921 use crate::protocol::PgEncoder;
922
923 let msg = PgEncoder::encode_query_string(sql);
925 self.connection.stream.write_all(&msg).await?;
926
927 let mut rows: Vec<PgRow> = Vec::new();
928 let mut column_info: Option<std::sync::Arc<ColumnInfo>> = None;
929
930
931 let mut error: Option<PgError> = None;
932
933 loop {
934 let msg = self.connection.recv().await?;
935 match msg {
936 crate::protocol::BackendMessage::RowDescription(fields) => {
937 column_info = Some(std::sync::Arc::new(ColumnInfo::from_fields(&fields)));
938 }
939 crate::protocol::BackendMessage::DataRow(data) => {
940 if error.is_none() {
941 rows.push(PgRow {
942 columns: data,
943 column_info: column_info.clone(),
944 });
945 }
946 }
947 crate::protocol::BackendMessage::CommandComplete(_) => {}
948 crate::protocol::BackendMessage::ReadyForQuery(_) => {
949 if let Some(err) = error {
950 return Err(err);
951 }
952 return Ok(rows);
953 }
954 crate::protocol::BackendMessage::ErrorResponse(err) => {
955 if error.is_none() {
956 error = Some(PgError::Query(err.message));
957 }
958 }
959 _ => {}
960 }
961 }
962 }
963
964 pub async fn copy_bulk(
980 &mut self,
981 cmd: &Qail,
982 rows: &[Vec<qail_core::ast::Value>],
983 ) -> PgResult<u64> {
984 use qail_core::ast::Action;
985
986
987 if cmd.action != Action::Add {
988 return Err(PgError::Query(
989 "copy_bulk requires Qail::Add action".to_string(),
990 ));
991 }
992
993 let table = &cmd.table;
994
995 let columns: Vec<String> = cmd
996 .columns
997 .iter()
998 .filter_map(|expr| {
999 use qail_core::ast::Expr;
1000 match expr {
1001 Expr::Named(name) => Some(name.clone()),
1002 Expr::Aliased { name, .. } => Some(name.clone()),
1003 Expr::Star => None, _ => None,
1005 }
1006 })
1007 .collect();
1008
1009 if columns.is_empty() {
1010 return Err(PgError::Query(
1011 "copy_bulk requires columns in Qail".to_string(),
1012 ));
1013 }
1014
1015 self.connection.copy_in_fast(table, &columns, rows).await
1017 }
1018
1019 pub async fn copy_bulk_bytes(&mut self, cmd: &Qail, data: &[u8]) -> PgResult<u64> {
1032 use qail_core::ast::Action;
1033
1034 if cmd.action != Action::Add {
1035 return Err(PgError::Query(
1036 "copy_bulk_bytes requires Qail::Add action".to_string(),
1037 ));
1038 }
1039
1040 let table = &cmd.table;
1041 let columns: Vec<String> = cmd
1042 .columns
1043 .iter()
1044 .filter_map(|expr| {
1045 use qail_core::ast::Expr;
1046 match expr {
1047 Expr::Named(name) => Some(name.clone()),
1048 Expr::Aliased { name, .. } => Some(name.clone()),
1049 _ => None,
1050 }
1051 })
1052 .collect();
1053
1054 if columns.is_empty() {
1055 return Err(PgError::Query(
1056 "copy_bulk_bytes requires columns in Qail".to_string(),
1057 ));
1058 }
1059
1060 self.connection.copy_in_raw(table, &columns, data).await
1062 }
1063
1064 pub async fn copy_export_table(
1072 &mut self,
1073 table: &str,
1074 columns: &[String],
1075 ) -> PgResult<Vec<u8>> {
1076 let cols = columns.join(", ");
1077 let sql = format!("COPY {} ({}) TO STDOUT", table, cols);
1078
1079 self.connection.copy_out_raw(&sql).await
1080 }
1081
1082 pub async fn stream_cmd(
1096 &mut self,
1097 cmd: &Qail,
1098 batch_size: usize,
1099 ) -> PgResult<Vec<Vec<PgRow>>> {
1100 use std::sync::atomic::{AtomicU64, Ordering};
1101 static CURSOR_ID: AtomicU64 = AtomicU64::new(0);
1102
1103 let cursor_name = format!("qail_cursor_{}", CURSOR_ID.fetch_add(1, Ordering::SeqCst));
1104
1105 use crate::protocol::AstEncoder;
1107 let mut sql_buf = bytes::BytesMut::with_capacity(256);
1108 let mut params: Vec<Option<Vec<u8>>> = Vec::new();
1109 AstEncoder::encode_select_sql(cmd, &mut sql_buf, &mut params);
1110 let sql = String::from_utf8_lossy(&sql_buf).to_string();
1111
1112 self.connection.begin_transaction().await?;
1114
1115 self.connection.declare_cursor(&cursor_name, &sql).await?;
1117
1118 let mut all_batches = Vec::new();
1120 while let Some(rows) = self
1121 .connection
1122 .fetch_cursor(&cursor_name, batch_size)
1123 .await?
1124 {
1125 let pg_rows: Vec<PgRow> = rows
1126 .into_iter()
1127 .map(|cols| PgRow {
1128 columns: cols,
1129 column_info: None,
1130 })
1131 .collect();
1132 all_batches.push(pg_rows);
1133 }
1134
1135 self.connection.close_cursor(&cursor_name).await?;
1136 self.connection.commit().await?;
1137
1138 Ok(all_batches)
1139 }
1140}
1141
1142#[derive(Default)]
1159pub struct PgDriverBuilder {
1160 host: Option<String>,
1161 port: Option<u16>,
1162 user: Option<String>,
1163 database: Option<String>,
1164 password: Option<String>,
1165 timeout: Option<std::time::Duration>,
1166}
1167
1168impl PgDriverBuilder {
1169 pub fn new() -> Self {
1171 Self::default()
1172 }
1173
1174 pub fn host(mut self, host: impl Into<String>) -> Self {
1176 self.host = Some(host.into());
1177 self
1178 }
1179
1180 pub fn port(mut self, port: u16) -> Self {
1182 self.port = Some(port);
1183 self
1184 }
1185
1186 pub fn user(mut self, user: impl Into<String>) -> Self {
1188 self.user = Some(user.into());
1189 self
1190 }
1191
1192 pub fn database(mut self, database: impl Into<String>) -> Self {
1194 self.database = Some(database.into());
1195 self
1196 }
1197
1198 pub fn password(mut self, password: impl Into<String>) -> Self {
1200 self.password = Some(password.into());
1201 self
1202 }
1203
1204 pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
1206 self.timeout = Some(timeout);
1207 self
1208 }
1209
1210 pub async fn connect(self) -> PgResult<PgDriver> {
1212 let host = self.host.as_deref().unwrap_or("127.0.0.1");
1213 let port = self.port.unwrap_or(5432);
1214 let user = self.user.as_deref().ok_or_else(|| {
1215 PgError::Connection("User is required".to_string())
1216 })?;
1217 let database = self.database.as_deref().ok_or_else(|| {
1218 PgError::Connection("Database is required".to_string())
1219 })?;
1220
1221 match (self.password.as_deref(), self.timeout) {
1222 (Some(password), Some(timeout)) => {
1223 PgDriver::connect_with_timeout(host, port, user, database, password, timeout).await
1224 }
1225 (Some(password), None) => {
1226 PgDriver::connect_with_password(host, port, user, database, password).await
1227 }
1228 (None, Some(timeout)) => {
1229 tokio::time::timeout(
1230 timeout,
1231 PgDriver::connect(host, port, user, database),
1232 )
1233 .await
1234 .map_err(|_| PgError::Connection(format!("Connection timeout after {:?}", timeout)))?
1235 }
1236 (None, None) => {
1237 PgDriver::connect(host, port, user, database).await
1238 }
1239 }
1240 }
1241}