cdbc_pg/copy.rs
1use cdbc::error::{Error, Result};
2use cdbc::pool::{Pool, PoolConnection};
3use crate::connection::PgConnection;
4use crate::message::{
5 CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query,
6};
7use crate::Postgres;
8use bytes::{BufMut, Bytes};
9use smallvec::alloc::borrow::Cow;
10use std::convert::TryFrom;
11use std::ops::{Deref, DerefMut};
12use cdbc::io::chan_stream::ChanStream;
13use std::io::Write;
14
15impl PgConnection {
16 /// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data
17 /// to Postgres. This is a more efficient way to import data into Postgres as compared to
18 /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
19 ///
20 /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
21 /// returned.
22 ///
23 /// Command examples and accepted formats for `COPY` data are shown here:
24 /// https://www.postgresql.org/docs/current/sql-copy.html
25 ///
26 /// ### Note
27 /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
28 /// will return an error the next time it is used.
29 pub fn copy_in_raw(&mut self, statement: &str) -> Result<PgCopyIn<&mut Self>> {
30 PgCopyIn::begin(self, statement)
31 }
32
33 /// Issue a `COPY TO STDOUT` statement and transition the connection to streaming data
34 /// from Postgres. This is a more efficient way to export data from Postgres but
35 /// arrives in chunks of one of a few data formats (text/CSV/binary).
36 ///
37 /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
38 /// an error is returned.
39 ///
40 /// Note that once this process has begun, unless you read the stream to completion,
41 /// it can only be canceled in two ways:
42 ///
43 /// 1. by closing the connection, or:
44 /// 2. by using another connection to kill the server process that is sending the data as shown
45 /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
46 ///
47 /// If you don't read the stream to completion, the next time the connection is used it will
48 /// need to read and discard all the remaining queued data, which could take some time.
49 ///
50 /// Command examples and accepted formats for `COPY` data are shown here:
51 /// https://www.postgresql.org/docs/current/sql-copy.html
52 #[allow(clippy::needless_lifetimes)]
53 pub fn copy_out_raw<'c>(
54 &'c mut self,
55 statement: &str,
56 ) -> Result<ChanStream<Bytes>> {
57 pg_begin_copy_out(self, statement)
58 }
59}
60
61pub trait CopyRaw{
62 fn copy_in_raw(&self, statement: &str) -> Result<PgCopyIn<PoolConnection<Postgres>>>;
63 fn copy_out_raw(&self, statement: &str) -> Result<ChanStream<Bytes>>;
64}
65
66impl CopyRaw for Pool<Postgres> {
67 /// Issue a `COPY FROM STDIN` statement and begin streaming data to Postgres.
68 /// This is a more efficient way to import data into Postgres as compared to
69 /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
70 ///
71 /// A single connection will be checked out for the duration.
72 ///
73 /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
74 /// returned.
75 ///
76 /// Command examples and accepted formats for `COPY` data are shown here:
77 /// https://www.postgresql.org/docs/current/sql-copy.html
78 ///
79 /// ### Note
80 /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
81 /// will return an error the next time it is used.
82 fn copy_in_raw(&self, statement: &str) -> Result<PgCopyIn<PoolConnection<Postgres>>> {
83 PgCopyIn::begin(self.acquire()?, statement)
84 }
85
86 /// Issue a `COPY TO STDOUT` statement and begin streaming data
87 /// from Postgres. This is a more efficient way to export data from Postgres but
88 /// arrives in chunks of one of a few data formats (text/CSV/binary).
89 ///
90 /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
91 /// an error is returned.
92 ///
93 /// Note that once this process has begun, unless you read the stream to completion,
94 /// it can only be canceled in two ways:
95 ///
96 /// 1. by closing the connection, or:
97 /// 2. by using another connection to kill the server process that is sending the data as shown
98 /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
99 ///
100 /// If you don't read the stream to completion, the next time the connection is used it will
101 /// need to read and discard all the remaining queued data, which could take some time.
102 ///
103 /// Command examples and accepted formats for `COPY` data are shown here:
104 /// https://www.postgresql.org/docs/current/sql-copy.html
105 fn copy_out_raw(&self, statement: &str) -> Result<ChanStream<Bytes>> {
106 pg_begin_copy_out(self.acquire()?, statement)
107 }
108}
109
110/// A connection in streaming `COPY FROM STDIN` mode.
111///
112/// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw].
113///
114/// ### Note
115/// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
116/// will return an error the next time it is used.
117#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"]
118pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
119 conn: Option<C>,
120 response: CopyResponse,
121}
122
123impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
124 fn begin(mut conn: C, statement: &str) -> Result<Self> {
125 conn.wait_until_ready()?;
126 conn.stream.send(Query(statement))?;
127
128 let response: CopyResponse = conn
129 .stream
130 .recv_expect(MessageFormat::CopyInResponse)
131 ?;
132
133 Ok(PgCopyIn {
134 conn: Some(conn),
135 response,
136 })
137 }
138
139 /// Returns `true` if Postgres is expecting data in text or CSV format.
140 pub fn is_textual(&self) -> bool {
141 self.response.format == 0
142 }
143
144 /// Returns the number of columns expected in the input.
145 pub fn num_columns(&self) -> usize {
146 assert_eq!(
147 self.response.num_columns as usize,
148 self.response.format_codes.len(),
149 "num_columns does not match format_codes.len()"
150 );
151 self.response.format_codes.len()
152 }
153
154 /// Check if a column is expecting data in text format (`true`) or binary format (`false`).
155 ///
156 /// ### Panics
157 /// If `column` is out of range according to [`.num_columns()`][Self::num_columns].
158 pub fn column_is_textual(&self, column: usize) -> bool {
159 self.response.format_codes[column] == 0
160 }
161
162 /// Send a chunk of `COPY` data.
163 ///
164 /// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead.
165 pub fn send(&mut self, data: impl Deref<Target = [u8]>) -> Result<&mut Self> {
166 self.conn
167 .as_deref_mut()
168 .expect("send_data: conn taken")
169 .stream
170 .send(CopyData(data))
171 ?;
172
173 Ok(self)
174 }
175
176 /// Copy data directly from `source` to the database without requiring an intermediate buffer.
177 ///
178 /// `source` will be read to the end.
179 ///
180 /// ### Note
181 /// You must still call either [Self::finish] or [Self::abort] to complete the process.
182 pub fn read_from(&mut self, mut source: impl std::io::Read) -> Result<&mut Self> {
183 // this is a separate guard from WriteAndFlush so we can reuse the buffer without zeroing
184 struct BufGuard<'s>(&'s mut Vec<u8>);
185
186 impl Drop for BufGuard<'_> {
187 fn drop(&mut self) {
188 self.0.clear()
189 }
190 }
191
192 let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken");
193
194 // flush any existing messages in the buffer and clear it
195 conn.stream.flush()?;
196
197 {
198 let buf_stream = &mut *conn.stream;
199 let stream = &mut buf_stream.stream;
200
201 // ensures the buffer isn't left in an inconsistent state
202 let mut guard = BufGuard(&mut buf_stream.wbuf);
203
204 let buf: &mut Vec<u8> = &mut guard.0;
205 buf.push(b'd'); // CopyData format code
206 buf.resize(5, 0); // reserve space for the length
207
208 loop {
209 let read = match () {
210 _ => {
211 // should be a no-op unless len != capacity
212 buf.resize(buf.capacity(), 0);
213 source.read(&mut buf[5..])?
214 }
215 };
216
217 if read == 0 {
218 break;
219 }
220
221 let read32 = u32::try_from(read)
222 .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?;
223
224 (&mut buf[1..]).put_u32(read32 + 4);
225
226 stream.write_all(&buf[..read + 5])?;
227 stream.flush()?;
228 }
229 }
230
231 Ok(self)
232 }
233
234 /// Signal that the `COPY` process should be aborted and any data received should be discarded.
235 ///
236 /// The given message can be used for indicating the reason for the abort in the database logs.
237 ///
238 /// The server is expected to respond with an error, so only _unexpected_ errors are returned.
239 pub fn abort(mut self, msg: impl Into<String>) -> Result<()> {
240 let mut conn = self
241 .conn
242 .take()
243 .expect("PgCopyIn::fail_with: conn taken illegally");
244
245 conn.stream.send(CopyFail::new(msg))?;
246
247 match conn.stream.recv() {
248 Ok(msg) => Err(err_protocol!(
249 "fail_with: expected ErrorResponse, got: {:?}",
250 msg.format
251 )),
252 Err(Error::Database(e)) => {
253 match e.code() {
254 Some(Cow::Borrowed("57014")) => {
255 // postgres abort received error code
256 conn.stream
257 .recv_expect(MessageFormat::ReadyForQuery)
258 ?;
259 Ok(())
260 }
261 _ => Err(Error::Database(e)),
262 }
263 }
264 Err(e) => Err(e),
265 }
266 }
267
268 /// Signal that the `COPY` process is complete.
269 ///
270 /// The number of rows affected is returned.
271 pub fn finish(mut self) -> Result<u64> {
272 let mut conn = self
273 .conn
274 .take()
275 .expect("CopyWriter::finish: conn taken illegally");
276
277 conn.stream.send(CopyDone)?;
278 let cc: CommandComplete = conn
279 .stream
280 .recv_expect(MessageFormat::CommandComplete)
281 ?;
282
283 conn.stream
284 .recv_expect(MessageFormat::ReadyForQuery)
285 ?;
286
287 Ok(cc.rows_affected())
288 }
289}
290
291impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> {
292 fn drop(&mut self) {
293 if let Some(mut conn) = self.conn.take() {
294 conn.stream.write(CopyFail::new(
295 "PgCopyIn dropped without calling finish() or fail()",
296 ));
297 }
298 }
299}
300
301fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
302 mut conn: C,
303 statement: &str,
304) -> Result<ChanStream<Bytes>> {
305 conn.wait_until_ready()?;
306 conn.stream.send(Query(statement))?;
307
308 let _: CopyResponse = conn
309 .stream
310 .recv_expect(MessageFormat::CopyOutResponse)
311 ?;
312
313 let stream = chan_stream! {
314 loop {
315 let msg = conn.stream.recv()?;
316 match msg.format {
317 MessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
318 MessageFormat::CopyDone => {
319 let _ = msg.decode::<CopyDone>()?;
320 conn.stream.recv_expect(MessageFormat::CommandComplete)?;
321 conn.stream.recv_expect(MessageFormat::ReadyForQuery)?;
322 return Ok(())
323 },
324 _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format))
325 }
326 }
327 };
328
329 Ok(stream)
330}