1use crate::connection::Authenticated;
2use crate::errors::ClientError;
3use crate::query::QueryFailed;
4use crate::{Client, Connection, DatabaseStream, Query, Result};
5use std::borrow::BorrowMut;
6use std::io::Read;
7
8pub struct Response<T, HasInfo>
38where
39 T: DatabaseStream,
40{
41 query: Query<T, HasInfo>,
42 info_prefix: Option<Vec<u8>>,
43 info_complete: bool,
44 is_ok: bool,
45 result_complete: bool,
46}
47
48impl<T, HasInfo> Response<T, HasInfo>
49where
50 T: DatabaseStream,
51{
52 pub(crate) fn new(query: Query<T, HasInfo>) -> Self {
53 Self {
54 query,
55 info_prefix: None,
56 info_complete: false,
57 is_ok: false,
58 result_complete: false,
59 }
60 }
61
62 pub fn close(mut self) -> Result<Query<T, HasInfo>> {
81 let mut buf = [0u8; 4096];
82
83 while !self.result_complete && self.read(&mut buf)? > 0 {}
84
85 if !self.result_complete {
86 panic!("Unexpected end of stream.");
87 }
88
89 match self.is_ok {
90 true => Ok(self.query),
91 false => {
92 let info_suffix = if !self.info_complete {
93 Some(self.connection().read_string()?)
94 } else {
95 None
96 };
97
98 let mut info = String::from_utf8(self.info_prefix.unwrap_or_default())?;
99
100 if let Some(info_suffix) = info_suffix {
101 info.push_str(info_suffix.as_str());
102 }
103
104 Err(ClientError::QueryFailed(QueryFailed::new(info)))
105 }
106 }
107 }
108
109 fn connection(&mut self) -> &mut Connection<T, Authenticated> {
110 let client: &mut Client<T> = self.query.borrow_mut();
111 client.borrow_mut()
112 }
113}
114
115impl<T, HasInfo> Read for Response<T, HasInfo>
116where
117 T: DatabaseStream,
118{
119 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
120 if self.result_complete {
121 return Ok(0);
122 }
123
124 let size = self.connection().read(buf)?;
125 let mut escape = false;
126 let mut shift = 0usize;
127 let mut position: Option<usize> = None;
128
129 for i in 0..size {
130 if buf[i] == 0xFF && !escape {
131 escape = true;
132 shift += 1;
133 continue;
134 }
135 if buf[i] == 0 && !escape {
136 position = Some(i);
137 break;
138 }
139
140 escape = false;
141 buf[i - shift] = buf[i];
142 }
143
144 if let Some(position) = position {
145 if size > position + 1 {
146 self.result_complete = true;
147 self.is_ok = match buf[..size][position + 1] {
148 0 => true,
149 1 => false,
150 other => panic!("Invalid status byte \"{}\"", other),
151 };
152 if self.is_ok {
153 self.info_complete = true;
154 } else {
155 self.info_prefix = match buf[position + 2..size].iter().position(|&b| b == 0) {
156 Some(length) => {
157 self.info_complete = true;
158 Some(buf[position + 2..position + 2 + length].to_vec())
159 }
160 None => Some(buf[position + 2..size].to_vec()),
161 };
162 }
163 }
164
165 return Ok(position - shift);
166 }
167
168 Ok(size - shift)
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ClientError;

    /// A short result is read without its terminator.
    #[test]
    fn test_reading_result_from_response() {
        let client = Client::new(Connection::from_str("result\0".to_owned()));
        let mut response = Response::new(Query::without_info("1".to_owned(), client));

        let mut text = String::new();
        response.read_to_string(&mut text).unwrap();

        assert_eq!("result".to_owned(), text);
    }

    /// A longer result is assembled correctly across several `read` calls.
    #[test]
    fn test_reading_result_from_response_on_multiple_read_calls() {
        let payload = "result".repeat(10);
        let client = Client::new(Connection::from_str("result".repeat(10) + "\0"));
        let mut response = Response::new(Query::without_info("1".to_owned(), client));

        let mut text = String::new();
        response.read_to_string(&mut text).unwrap();
        assert_eq!(payload, text);

        response.close().expect("Operation must succeed.");
    }

    /// Escaped `\0` and `\xFF` bytes are returned as literals.
    #[test]
    fn test_reading_result_from_response_with_some_escape_bytes() {
        let client = Client::new(Connection::from_bytes(&[0xFFu8, 0, 1, 6, 9, 0xFF, 0xFF, 3, 0]));
        let mut response = Response::new(Query::without_info("1".to_owned(), client));

        let mut bytes: Vec<u8> = Vec::new();
        response.read_to_end(&mut bytes).unwrap();
        assert_eq!(vec![0u8, 1, 6, 9, 0xFF, 3], bytes);

        response.close().expect("Operation must succeed.");
    }

    /// A result made only of escaped `\0` bytes is fully unescaped.
    #[test]
    fn test_reading_result_from_response_with_only_escape_bytes() {
        let mut stream = [0xFFu8, 0].repeat(10);
        stream.push(0);
        let client = Client::new(Connection::from_bytes(&stream));
        let mut response = Response::new(Query::without_info("1".to_owned(), client));

        let mut bytes: Vec<u8> = Vec::new();
        response.read_to_end(&mut bytes).unwrap();
        assert_eq!(vec![0u8; 10], bytes);

        response.close().expect("Operation must succeed.");
    }

    /// A failure status makes `close` return the server's error message.
    #[test]
    fn test_reading_error_from_response() {
        let expected_error = "Stopped at ., 1/1:\n[XPST0008] Undeclared variable: $x.";
        let stream = format!("partial_result\0\u{1}{}\0", expected_error);
        let client = Client::new(Connection::from_str(stream));
        let response = Response::new(Query::without_info("1".to_owned(), client));

        let failure = response.close().err().unwrap();

        assert!(matches!(
            failure,
            ClientError::QueryFailed(q) if q.raw() == expected_error
        ));
    }

    /// An error message longer than one read chunk is reassembled by `close`.
    #[test]
    fn test_reading_error_from_response_on_multiple_read_calls() {
        let expected_error = "Stopped at ., 1/1:\n[XPST0008] ".to_owned() + &"error".repeat(5000);
        let stream = format!("partial_result\0\u{1}{}\0", expected_error);
        let client = Client::new(Connection::from_str(stream));
        let response = Response::new(Query::without_info("1".to_owned(), client));

        let failure = response.close().err().unwrap();

        assert!(matches!(
            failure,
            ClientError::QueryFailed(q) if q.raw() == expected_error
        ));
    }

    /// A status byte other than `0` or `1` panics.
    #[test]
    #[should_panic]
    fn test_reading_panics_on_invalid_status_byte() {
        let client = Client::new(Connection::from_str(
            "partial_result\0\u{2}test_error\0".to_owned(),
        ));
        let mut response = Response::new(Query::without_info("1".to_owned(), client));

        let _ = response.read(&mut [0u8; 27]);
    }

    /// A stream that ends before the terminator makes `close` panic.
    #[test]
    #[should_panic]
    fn test_reading_panics_on_incomplete_result() {
        let client = Client::new(Connection::from_str("partial_result".to_owned()));
        let response = Response::new(Query::without_info("1".to_owned(), client));

        let _ = response.close();
    }
}