use async_trait::async_trait;
use futures::sink::{Sink, SinkExt};
use futures::stream::{Stream, StreamExt};
use std::fmt::Debug;

use crate::error::{ErrorInfo, PgWireError, PgWireResult};
use crate::messages::PgWireBackendMessage;
use crate::messages::copy::{
    CopyBothResponse, CopyData, CopyDone, CopyFail, CopyInResponse, CopyOutResponse,
};

use super::ClientInfo;
use super::results::{CopyResponse, Tag};
14
15#[async_trait]
17pub trait CopyHandler: Send + Sync {
18 async fn on_copy_data<C>(&self, _client: &mut C, _copy_data: CopyData) -> PgWireResult<()>
19 where
20 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
21 C::Error: Debug,
22 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>;
23
24 async fn on_copy_done<C>(&self, _client: &mut C, _done: CopyDone) -> PgWireResult<()>
25 where
26 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
27 C::Error: Debug,
28 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>;
29
30 async fn on_copy_fail<C>(&self, _client: &mut C, fail: CopyFail) -> PgWireError
31 where
32 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
33 C::Error: Debug,
34 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
35 {
36 PgWireError::UserError(Box::new(ErrorInfo::new(
37 "ERROR".to_owned(),
38 "XX000".to_owned(),
39 format!("COPY IN mode terminated by the user: {}", fail.message),
40 )))
41 }
42}
43
44pub async fn send_copy_in_response<C>(client: &mut C, resp: CopyResponse) -> PgWireResult<()>
45where
46 C: Sink<PgWireBackendMessage> + Unpin,
47 C::Error: Debug,
48 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
49{
50 let column_formats = resp.column_formats();
51 let resp = CopyInResponse::new(resp.format, resp.columns as i16, column_formats);
52 client
53 .send(PgWireBackendMessage::CopyInResponse(resp))
54 .await?;
55 Ok(())
56}
57
58pub async fn send_copy_out_response<C>(client: &mut C, resp: CopyResponse) -> PgWireResult<()>
59where
60 C: Sink<PgWireBackendMessage> + Unpin,
61 C::Error: Debug,
62 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
63{
64 let column_formats = resp.column_formats();
65 let CopyResponse {
66 format,
67 columns,
68 mut data_stream,
69 } = resp;
70 let copy_resp = CopyOutResponse::new(format, columns as i16, column_formats);
71 client
72 .send(PgWireBackendMessage::CopyOutResponse(copy_resp))
73 .await?;
74
75 let mut rows = 0;
76
77 while let Some(copy_data) = data_stream.next().await {
78 match copy_data {
79 Ok(data) => {
80 if !data.data.is_empty() {
81 if data.data.as_ref() != [0xFF, 0xFF] {
83 rows += 1;
84 }
85 client.feed(PgWireBackendMessage::CopyData(data)).await?;
86 }
87 }
88 Err(e) => {
89 let copy_fail = CopyFail::new(format!("{}", e));
90 client
91 .send(PgWireBackendMessage::CopyFail(copy_fail))
92 .await?;
93 return Err(e);
94 }
95 }
96 }
97
98 let copy_done = CopyDone::new();
99 client
100 .send(PgWireBackendMessage::CopyDone(copy_done))
101 .await?;
102
103 let tag = Tag::new("COPY").with_rows(rows);
104 client
105 .send(PgWireBackendMessage::CommandComplete(tag.into()))
106 .await?;
107
108 Ok(())
109}
110
111pub async fn send_copy_both_response<C>(client: &mut C, resp: CopyResponse) -> PgWireResult<()>
112where
113 C: Sink<PgWireBackendMessage> + Unpin,
114 C::Error: Debug,
115 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
116{
117 let column_formats = resp.column_formats();
118 let CopyResponse {
119 format,
120 columns,
121 mut data_stream,
122 } = resp;
123 let copy_resp = CopyBothResponse::new(format, columns as i16, column_formats);
124 client
125 .send(PgWireBackendMessage::CopyBothResponse(copy_resp))
126 .await?;
127
128 let mut rows = 0;
129
130 while let Some(copy_data) = data_stream.next().await {
131 match copy_data {
132 Ok(data) => {
133 if !data.data.is_empty() {
134 if data.data.as_ref() != [0xFF, 0xFF] {
136 rows += 1;
137 }
138 client.feed(PgWireBackendMessage::CopyData(data)).await?;
139 }
140 }
141 Err(e) => {
142 let copy_fail = CopyFail::new(format!("{}", e));
143 client
144 .send(PgWireBackendMessage::CopyFail(copy_fail))
145 .await?;
146 return Err(e);
147 }
148 }
149 }
150
151 let copy_done = CopyDone::new();
152 client
153 .send(PgWireBackendMessage::CopyDone(copy_done))
154 .await?;
155
156 let tag = Tag::new("COPY").with_rows(rows);
157 client
158 .send(PgWireBackendMessage::CommandComplete(tag.into()))
159 .await?;
160
161 Ok(())
162}
163
164#[async_trait]
165impl CopyHandler for super::NoopHandler {
166 async fn on_copy_data<C>(&self, _client: &mut C, _copy_data: CopyData) -> PgWireResult<()>
167 where
168 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
169 C::Error: Debug,
170 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
171 {
172 Err(PgWireError::UserError(Box::new(ErrorInfo::new(
173 "FATAL".to_owned(),
174 "08P01".to_owned(),
175 "This feature is not implemented.".to_string(),
176 ))))
177 }
178
179 async fn on_copy_done<C>(&self, _client: &mut C, _done: CopyDone) -> PgWireResult<()>
180 where
181 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
182 C::Error: Debug,
183 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
184 {
185 Err(PgWireError::UserError(Box::new(ErrorInfo::new(
186 "FATAL".to_owned(),
187 "08P01".to_owned(),
188 "This feature is not implemented.".to_string(),
189 ))))
190 }
191}