Skip to main content

qail_pg/driver/
copy.rs

//! COPY protocol methods for PostgreSQL bulk operations.
//!
//! Provides `COPY ... FROM STDIN` bulk loading and `COPY ... TO STDOUT` export on [`PgConnection`].
3
4use super::{PgConnection, PgError, PgResult, parse_affected_rows};
5use crate::protocol::{AstEncoder, BackendMessage, PgEncoder};
6use bytes::BytesMut;
7use qail_core::ast::{Action, Qail};
8use tokio::io::AsyncWriteExt;
9
/// Render `ident` as a double-quoted SQL identifier so it cannot break out of
/// identifier position.
///
/// NUL bytes are dropped, because PostgreSQL cannot store them in identifiers.
/// Every embedded `"` is doubled, which is standard SQL quoting.
fn quote_ident(ident: &str) -> String {
    let mut quoted = String::with_capacity(ident.len() + 2);
    quoted.push('"');
    for ch in ident.chars().filter(|&c| c != '\0') {
        if ch == '"' {
            // Escape by doubling: `a"b` becomes `"a""b"`.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
15
16impl PgConnection {
17    /// **Fast** bulk insert using COPY protocol with zero-allocation encoding.
18    /// Encodes all rows into a single buffer and writes with one syscall.
19    /// ~2x faster than `copy_in_internal` due to batched I/O.
20    pub(crate) async fn copy_in_fast(
21        &mut self,
22        table: &str,
23        columns: &[String],
24        rows: &[Vec<qail_core::ast::Value>],
25    ) -> PgResult<u64> {
26        use crate::protocol::encode_copy_batch;
27
28        let cols: Vec<String> = columns.iter().map(|c| quote_ident(c)).collect();
29        let sql = format!("COPY {} ({}) FROM STDIN", quote_ident(table), cols.join(", "));
30
31        // Send COPY command
32        let bytes = PgEncoder::encode_query_string(&sql);
33        self.stream.write_all(&bytes).await?;
34
35        // Wait for CopyInResponse
36        loop {
37            let msg = self.recv().await?;
38            match msg {
39                BackendMessage::CopyInResponse { .. } => break,
40                BackendMessage::ErrorResponse(err) => {
41                    return Err(PgError::Query(err.message));
42                }
43                _ => {}
44            }
45        }
46
47        // Encode ALL rows into a single buffer (zero-allocation per value)
48        let batch_data = encode_copy_batch(rows);
49
50        // Single write for entire batch!
51        self.send_copy_data(&batch_data).await?;
52
53        // Send CopyDone
54        self.send_copy_done().await?;
55
56        // Wait for CommandComplete
57        let mut affected = 0u64;
58        loop {
59            let msg = self.recv().await?;
60            match msg {
61                BackendMessage::CommandComplete(tag) => {
62                    affected = parse_affected_rows(&tag);
63                }
64                BackendMessage::ReadyForQuery(_) => {
65                    return Ok(affected);
66                }
67                BackendMessage::ErrorResponse(err) => {
68                    return Err(PgError::Query(err.message));
69                }
70                _ => {}
71            }
72        }
73    }
74
75    /// **Fastest** bulk insert using COPY protocol with pre-encoded data.
76    /// Accepts raw COPY text format bytes, no encoding needed.
77    /// Use when caller has already encoded rows to COPY format.
78    /// # Format
79    /// Data should be tab-separated rows with newlines:
80    /// `1\thello\t3.14\n2\tworld\t2.71\n`
81    pub async fn copy_in_raw(
82        &mut self,
83        table: &str,
84        columns: &[String],
85        data: &[u8],
86    ) -> PgResult<u64> {
87        let cols: Vec<String> = columns.iter().map(|c| quote_ident(c)).collect();
88        let sql = format!("COPY {} ({}) FROM STDIN", quote_ident(table), cols.join(", "));
89
90        // Send COPY command
91        let bytes = PgEncoder::encode_query_string(&sql);
92        self.stream.write_all(&bytes).await?;
93
94        // Wait for CopyInResponse
95        loop {
96            let msg = self.recv().await?;
97            match msg {
98                BackendMessage::CopyInResponse { .. } => break,
99                BackendMessage::ErrorResponse(err) => {
100                    return Err(PgError::Query(err.message));
101                }
102                _ => {}
103            }
104        }
105
106        // Single write - data is already encoded!
107        self.send_copy_data(data).await?;
108
109        // Send CopyDone
110        self.send_copy_done().await?;
111
112        // Wait for CommandComplete
113        let mut affected = 0u64;
114        loop {
115            let msg = self.recv().await?;
116            match msg {
117                BackendMessage::CommandComplete(tag) => {
118                    affected = parse_affected_rows(&tag);
119                }
120                BackendMessage::ReadyForQuery(_) => {
121                    return Ok(affected);
122                }
123                BackendMessage::ErrorResponse(err) => {
124                    return Err(PgError::Query(err.message));
125                }
126                _ => {}
127            }
128        }
129    }
130
131    /// Send CopyData message (raw bytes).
132    async fn send_copy_data(&mut self, data: &[u8]) -> PgResult<()> {
133        // CopyData: 'd' + length + data
134        let len = (data.len() + 4) as i32;
135        let mut buf = BytesMut::with_capacity(1 + 4 + data.len());
136        buf.extend_from_slice(b"d");
137        buf.extend_from_slice(&len.to_be_bytes());
138        buf.extend_from_slice(data);
139        self.stream.write_all(&buf).await?;
140        Ok(())
141    }
142
143    async fn send_copy_done(&mut self) -> PgResult<()> {
144        // CopyDone: 'c' + length (4)
145        self.stream.write_all(&[b'c', 0, 0, 0, 4]).await?;
146        Ok(())
147    }
148
149    /// Export data using COPY TO STDOUT (AST-native).
150    /// Takes a Qail::Export and returns rows as Vec<Vec<String>>.
151    /// # Example
152    /// ```ignore
153    /// let cmd = Qail::export("users")
154    ///     .columns(["id", "name"])
155    ///     .filter("active", true);
156    /// let rows = conn.copy_export(&cmd).await?;
157    /// ```
158    pub async fn copy_export(&mut self, cmd: &Qail) -> PgResult<Vec<Vec<String>>> {
159
160        if cmd.action != Action::Export {
161            return Err(PgError::Query(
162                "copy_export requires Qail::Export action".to_string(),
163            ));
164        }
165
166        // Encode command to SQL using AST encoder
167        let (sql, _params) = AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
168
169        // Send COPY command
170        let bytes = PgEncoder::encode_query_string(&sql);
171        self.stream.write_all(&bytes).await?;
172
173        // Wait for CopyOutResponse
174        loop {
175            let msg = self.recv().await?;
176            match msg {
177                BackendMessage::CopyOutResponse { .. } => break,
178                BackendMessage::ErrorResponse(err) => {
179                    return Err(PgError::Query(err.message));
180                }
181                _ => {}
182            }
183        }
184
185        // Receive CopyData messages until CopyDone
186        let mut rows = Vec::new();
187        loop {
188            let msg = self.recv().await?;
189            match msg {
190                BackendMessage::CopyData(data) => {
191                    let line = String::from_utf8_lossy(&data);
192                    let line = line.trim_end_matches('\n');
193                    let cols: Vec<String> = line.split('\t').map(|s| s.to_string()).collect();
194                    rows.push(cols);
195                }
196                BackendMessage::CopyDone => {}
197                BackendMessage::CommandComplete(_) => {}
198                BackendMessage::ReadyForQuery(_) => {
199                    return Ok(rows);
200                }
201                BackendMessage::ErrorResponse(err) => {
202                    return Err(PgError::Query(err.message));
203                }
204                _ => {}
205            }
206        }
207    }
208
209    /// Export data using raw COPY TO STDOUT, returning raw bytes.
210    /// Format: tab-separated values, newline-terminated rows.
211    /// Suitable for direct re-import via copy_in_raw.
212    ///
213    /// # Safety
214    /// `pub(crate)` — not exposed externally because callers pass raw SQL.
215    /// External code should use `copy_export()` with the AST encoder instead.
216    pub(crate) async fn copy_out_raw(&mut self, sql: &str) -> PgResult<Vec<u8>> {
217        // Send COPY command
218        let bytes = PgEncoder::encode_query_string(sql);
219        self.stream.write_all(&bytes).await?;
220
221        // Wait for CopyOutResponse
222        loop {
223            let msg = self.recv().await?;
224            match msg {
225                BackendMessage::CopyOutResponse { .. } => break,
226                BackendMessage::ErrorResponse(err) => {
227                    return Err(PgError::Query(err.message));
228                }
229                _ => {}
230            }
231        }
232
233        // Receive CopyData messages until CopyDone
234        let mut data = Vec::new();
235        loop {
236            let msg = self.recv().await?;
237            match msg {
238                BackendMessage::CopyData(chunk) => {
239                    data.extend_from_slice(&chunk);
240                }
241                BackendMessage::CopyDone => {}
242                BackendMessage::CommandComplete(_) => {}
243                BackendMessage::ReadyForQuery(_) => {
244                    return Ok(data);
245                }
246                BackendMessage::ErrorResponse(err) => {
247                    return Err(PgError::Query(err.message));
248                }
249                _ => {}
250            }
251        }
252    }
253}