Skip to main content

pgwire/api/
copy.rs

1use async_trait::async_trait;
2use futures::sink::{Sink, SinkExt};
3use futures::stream::StreamExt;
4use std::fmt::Debug;
5
6use crate::error::{ErrorInfo, PgWireError, PgWireResult};
7use crate::messages::PgWireBackendMessage;
8use crate::messages::copy::{
9    CopyBothResponse, CopyData, CopyDone, CopyFail, CopyInResponse, CopyOutResponse,
10};
11
12use super::ClientInfo;
13use super::results::{CopyResponse, Tag};
14
15/// handler for copy messages
16#[async_trait]
17pub trait CopyHandler: Send + Sync {
18    async fn on_copy_data<C>(&self, _client: &mut C, _copy_data: CopyData) -> PgWireResult<()>
19    where
20        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
21        C::Error: Debug,
22        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>;
23
24    async fn on_copy_done<C>(&self, _client: &mut C, _done: CopyDone) -> PgWireResult<()>
25    where
26        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
27        C::Error: Debug,
28        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>;
29
30    async fn on_copy_fail<C>(&self, _client: &mut C, fail: CopyFail) -> PgWireError
31    where
32        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
33        C::Error: Debug,
34        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
35    {
36        PgWireError::UserError(Box::new(ErrorInfo::new(
37            "ERROR".to_owned(),
38            "XX000".to_owned(),
39            format!("COPY IN mode terminated by the user: {}", fail.message),
40        )))
41    }
42}
43
44pub async fn send_copy_in_response<C>(client: &mut C, resp: CopyResponse) -> PgWireResult<()>
45where
46    C: Sink<PgWireBackendMessage> + Unpin,
47    C::Error: Debug,
48    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
49{
50    let column_formats = resp.column_formats();
51    let resp = CopyInResponse::new(resp.format, resp.columns as i16, column_formats);
52    client
53        .send(PgWireBackendMessage::CopyInResponse(resp))
54        .await?;
55    Ok(())
56}
57
58pub async fn send_copy_out_response<C>(client: &mut C, resp: CopyResponse) -> PgWireResult<()>
59where
60    C: Sink<PgWireBackendMessage> + Unpin,
61    C::Error: Debug,
62    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
63{
64    let column_formats = resp.column_formats();
65    let CopyResponse {
66        format,
67        columns,
68        mut data_stream,
69    } = resp;
70    let copy_resp = CopyOutResponse::new(format, columns as i16, column_formats);
71    client
72        .send(PgWireBackendMessage::CopyOutResponse(copy_resp))
73        .await?;
74
75    let mut rows = 0;
76
77    while let Some(copy_data) = data_stream.next().await {
78        match copy_data {
79            Ok(data) => {
80                if !data.data.is_empty() {
81                    // do not count trailer
82                    if data.data.as_ref() != [0xFF, 0xFF] {
83                        rows += 1;
84                    }
85                    client.feed(PgWireBackendMessage::CopyData(data)).await?;
86                }
87            }
88            Err(e) => {
89                let copy_fail = CopyFail::new(format!("{}", e));
90                client
91                    .send(PgWireBackendMessage::CopyFail(copy_fail))
92                    .await?;
93                return Err(e);
94            }
95        }
96    }
97
98    let copy_done = CopyDone::new();
99    client
100        .send(PgWireBackendMessage::CopyDone(copy_done))
101        .await?;
102
103    let tag = Tag::new("COPY").with_rows(rows);
104    client
105        .send(PgWireBackendMessage::CommandComplete(tag.into()))
106        .await?;
107
108    Ok(())
109}
110
111pub async fn send_copy_both_response<C>(client: &mut C, resp: CopyResponse) -> PgWireResult<()>
112where
113    C: Sink<PgWireBackendMessage> + Unpin,
114    C::Error: Debug,
115    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
116{
117    let column_formats = resp.column_formats();
118    let CopyResponse {
119        format,
120        columns,
121        mut data_stream,
122    } = resp;
123    let copy_resp = CopyBothResponse::new(format, columns as i16, column_formats);
124    client
125        .send(PgWireBackendMessage::CopyBothResponse(copy_resp))
126        .await?;
127
128    let mut rows = 0;
129
130    while let Some(copy_data) = data_stream.next().await {
131        match copy_data {
132            Ok(data) => {
133                if !data.data.is_empty() {
134                    // do not count trailer
135                    if data.data.as_ref() != [0xFF, 0xFF] {
136                        rows += 1;
137                    }
138                    client.feed(PgWireBackendMessage::CopyData(data)).await?;
139                }
140            }
141            Err(e) => {
142                let copy_fail = CopyFail::new(format!("{}", e));
143                client
144                    .send(PgWireBackendMessage::CopyFail(copy_fail))
145                    .await?;
146                return Err(e);
147            }
148        }
149    }
150
151    let copy_done = CopyDone::new();
152    client
153        .send(PgWireBackendMessage::CopyDone(copy_done))
154        .await?;
155
156    let tag = Tag::new("COPY").with_rows(rows);
157    client
158        .send(PgWireBackendMessage::CommandComplete(tag.into()))
159        .await?;
160
161    Ok(())
162}
163
164#[async_trait]
165impl CopyHandler for super::NoopHandler {
166    async fn on_copy_data<C>(&self, _client: &mut C, _copy_data: CopyData) -> PgWireResult<()>
167    where
168        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
169        C::Error: Debug,
170        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
171    {
172        Err(PgWireError::UserError(Box::new(ErrorInfo::new(
173            "FATAL".to_owned(),
174            "08P01".to_owned(),
175            "This feature is not implemented.".to_string(),
176        ))))
177    }
178
179    async fn on_copy_done<C>(&self, _client: &mut C, _done: CopyDone) -> PgWireResult<()>
180    where
181        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
182        C::Error: Debug,
183        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
184    {
185        Err(PgWireError::UserError(Box::new(ErrorInfo::new(
186            "FATAL".to_owned(),
187            "08P01".to_owned(),
188            "This feature is not implemented.".to_string(),
189        ))))
190    }
191}