Skip to main content

sentinel_driver/copy/
mod.rs

1pub mod binary;
2pub mod text;
3
4use crate::connection::stream::PgConnection;
5use crate::error::{Error, Result};
6use crate::protocol::backend::{BackendMessage, CopyFormat};
7use crate::protocol::frontend;
8use crate::row::parse_command_tag;
9
10/// A COPY IN operation — streaming data to the server.
11///
12/// Created by sending a `COPY ... FROM STDIN` query.
13/// Write rows via `write_raw()` or the format-specific helpers,
14/// then call `finish()` to complete.
15pub struct CopyIn<'a> {
16    conn: &'a mut PgConnection,
17    format: CopyFormat,
18    column_count: usize,
19    finished: bool,
20}
21
22impl<'a> CopyIn<'a> {
23    pub(crate) fn new(conn: &'a mut PgConnection, format: CopyFormat, column_count: usize) -> Self {
24        Self {
25            conn,
26            format,
27            column_count,
28            finished: false,
29        }
30    }
31
32    /// The COPY format (Text or Binary).
33    pub fn format(&self) -> CopyFormat {
34        self.format
35    }
36
37    /// Number of columns expected per row.
38    pub fn column_count(&self) -> usize {
39        self.column_count
40    }
41
42    /// Write raw COPY data. The data must be in the correct format
43    /// (text or binary) as negotiated with the server.
44    pub async fn write_raw(&mut self, data: &[u8]) -> Result<()> {
45        frontend::copy_data(self.conn.write_buf(), data);
46        self.conn.send().await
47    }
48
49    /// Finish the COPY operation and return the number of rows inserted.
50    pub async fn finish(mut self) -> Result<u64> {
51        self.finished = true;
52
53        frontend::copy_done(self.conn.write_buf());
54        self.conn.send().await?;
55
56        // Expect CommandComplete then ReadyForQuery
57        let rows = loop {
58            match self.conn.recv().await? {
59                BackendMessage::CommandComplete { tag } => {
60                    break parse_command_tag(&tag).rows_affected;
61                }
62                BackendMessage::ErrorResponse { fields } => {
63                    // Drain until ReadyForQuery
64                    drain_until_ready(self.conn).await.ok();
65                    return Err(Error::server(
66                        fields.severity,
67                        fields.code,
68                        fields.message,
69                        fields.detail,
70                        fields.hint,
71                        fields.position,
72                    ));
73                }
74                _ => {}
75            }
76        };
77
78        // Wait for ReadyForQuery
79        drain_until_ready(self.conn).await?;
80
81        Ok(rows)
82    }
83
84    /// Abort the COPY operation with an error message.
85    pub async fn abort(mut self, message: &str) -> Result<()> {
86        self.finished = true;
87
88        frontend::copy_fail(self.conn.write_buf(), message);
89        self.conn.send().await?;
90
91        // Server will send ErrorResponse + ReadyForQuery
92        drain_until_ready(self.conn).await.ok();
93
94        Ok(())
95    }
96}
97
98impl Drop for CopyIn<'_> {
99    fn drop(&mut self) {
100        if !self.finished {
101            // Can't do async in drop — just write CopyFail to buffer.
102            // The next operation on the connection will flush it.
103            frontend::copy_fail(
104                self.conn.write_buf(),
105                "COPY IN aborted: dropped without finish",
106            );
107        }
108    }
109}
110
111/// A COPY OUT operation — streaming data from the server.
112///
113/// Created by sending a `COPY ... TO STDOUT` query.
114/// Read rows via `read_raw()` until it returns `None`.
115pub struct CopyOut<'a> {
116    conn: &'a mut PgConnection,
117    format: CopyFormat,
118    done: bool,
119}
120
121impl<'a> CopyOut<'a> {
122    pub(crate) fn new(conn: &'a mut PgConnection, format: CopyFormat) -> Self {
123        Self {
124            conn,
125            format,
126            done: false,
127        }
128    }
129
130    /// The COPY format (Text or Binary).
131    pub fn format(&self) -> CopyFormat {
132        self.format
133    }
134
135    /// Read the next chunk of COPY data.
136    ///
137    /// Returns `None` when the COPY operation is complete.
138    pub async fn read_raw(&mut self) -> Result<Option<bytes::Bytes>> {
139        if self.done {
140            return Ok(None);
141        }
142
143        loop {
144            match self.conn.recv().await? {
145                BackendMessage::CopyData { data } => {
146                    return Ok(Some(data));
147                }
148                BackendMessage::CopyDone => {
149                    self.done = true;
150                    // Expect CommandComplete + ReadyForQuery
151                    drain_until_ready(self.conn).await?;
152                    return Ok(None);
153                }
154                BackendMessage::ErrorResponse { fields } => {
155                    self.done = true;
156                    drain_until_ready(self.conn).await.ok();
157                    return Err(Error::server(
158                        fields.severity,
159                        fields.code,
160                        fields.message,
161                        fields.detail,
162                        fields.hint,
163                        fields.position,
164                    ));
165                }
166                _ => {}
167            }
168        }
169    }
170}
171
172/// Start a COPY IN operation by sending the COPY query.
173pub(crate) async fn start_copy_in(
174    conn: &mut PgConnection,
175    sql: &str,
176) -> Result<(CopyFormat, usize)> {
177    frontend::query(conn.write_buf(), sql);
178    conn.send().await?;
179
180    loop {
181        match conn.recv().await? {
182            BackendMessage::CopyInResponse {
183                format,
184                column_formats,
185            } => {
186                return Ok((format, column_formats.len()));
187            }
188            BackendMessage::ErrorResponse { fields } => {
189                drain_until_ready(conn).await.ok();
190                return Err(Error::server(
191                    fields.severity,
192                    fields.code,
193                    fields.message,
194                    fields.detail,
195                    fields.hint,
196                    fields.position,
197                ));
198            }
199            _ => {}
200        }
201    }
202}
203
204/// Start a COPY OUT operation by sending the COPY query.
205pub(crate) async fn start_copy_out(conn: &mut PgConnection, sql: &str) -> Result<CopyFormat> {
206    frontend::query(conn.write_buf(), sql);
207    conn.send().await?;
208
209    loop {
210        match conn.recv().await? {
211            BackendMessage::CopyOutResponse { format, .. } => {
212                return Ok(format);
213            }
214            BackendMessage::ErrorResponse { fields } => {
215                drain_until_ready(conn).await.ok();
216                return Err(Error::server(
217                    fields.severity,
218                    fields.code,
219                    fields.message,
220                    fields.detail,
221                    fields.hint,
222                    fields.position,
223                ));
224            }
225            _ => {}
226        }
227    }
228}
229
230/// Drain messages until ReadyForQuery.
231async fn drain_until_ready(conn: &mut PgConnection) -> Result<()> {
232    loop {
233        if let BackendMessage::ReadyForQuery { .. } = conn.recv().await? {
234            return Ok(());
235        }
236    }
237}