1use super::{PgConnection, PgError, PgResult};
6use crate::protocol::{BackendMessage, FrontendMessage};
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8
9const MAX_MESSAGE_SIZE: usize = 1024 * 1024 * 1024; impl PgConnection {
12 pub async fn send(&mut self, msg: FrontendMessage) -> PgResult<()> {
14 let bytes = msg.encode();
15 self.stream.write_all(&bytes).await?;
16 Ok(())
17 }
18
19 pub async fn recv(&mut self) -> PgResult<BackendMessage> {
21 loop {
22 if self.buffer.len() >= 5 {
24 let msg_len = u32::from_be_bytes([
25 self.buffer[1],
26 self.buffer[2],
27 self.buffer[3],
28 self.buffer[4],
29 ]) as usize;
30
31 if msg_len > MAX_MESSAGE_SIZE {
32 return Err(PgError::Protocol(format!(
33 "Message too large: {} bytes (max {})",
34 msg_len, MAX_MESSAGE_SIZE
35 )));
36 }
37
38 if self.buffer.len() > msg_len {
39 let msg_bytes = self.buffer.split_to(msg_len + 1);
41 let (msg, _) = BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
42 return Ok(msg);
43 }
44 }
45
46 if self.buffer.capacity() - self.buffer.len() < 65536 {
47 self.buffer.reserve(131072); }
49
50 let n = self.stream.read_buf(&mut self.buffer).await?;
51 if n == 0 {
52 return Err(PgError::Connection("Connection closed".to_string()));
53 }
54 }
55 }
56
57 pub async fn send_bytes(&mut self, bytes: &[u8]) -> PgResult<()> {
61 self.stream.write_all(bytes).await?;
62 self.stream.flush().await?;
63 Ok(())
64 }
65
66 #[inline]
71 pub fn buffer_bytes(&mut self, bytes: &[u8]) {
72 self.write_buf.extend_from_slice(bytes);
73 }
74
75 pub async fn flush_write_buf(&mut self) -> PgResult<()> {
78 if !self.write_buf.is_empty() {
79 self.stream.write_all(&self.write_buf).await?;
80 self.write_buf.clear();
81 self.stream.flush().await?;
82 }
83 Ok(())
84 }
85
86 #[inline]
90 pub(crate) async fn recv_msg_type_fast(&mut self) -> PgResult<u8> {
91 loop {
92 if self.buffer.len() >= 5 {
93 let msg_len = u32::from_be_bytes([
94 self.buffer[1],
95 self.buffer[2],
96 self.buffer[3],
97 self.buffer[4],
98 ]) as usize;
99
100 if msg_len > MAX_MESSAGE_SIZE {
101 return Err(PgError::Protocol(format!(
102 "Message too large: {} bytes (max {})",
103 msg_len, MAX_MESSAGE_SIZE
104 )));
105 }
106
107 if self.buffer.len() > msg_len {
108 let msg_type = self.buffer[0];
109
110 if msg_type == b'E' {
111 let msg_bytes = self.buffer.split_to(msg_len + 1);
112 let (msg, _) =
113 BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
114 if let BackendMessage::ErrorResponse(err) = msg {
115 return Err(PgError::Query(err.message));
116 }
117 }
118
119 let _ = self.buffer.split_to(msg_len + 1);
120 return Ok(msg_type);
121 }
122 }
123
124 if self.buffer.capacity() - self.buffer.len() < 65536 {
125 self.buffer.reserve(131072); }
127
128 let n = self.stream.read_buf(&mut self.buffer).await?;
129 if n == 0 {
130 return Err(PgError::Connection("Connection closed".to_string()));
131 }
132 }
133 }
134
135 #[inline]
141 pub(crate) async fn recv_with_data_fast(
142 &mut self,
143 ) -> PgResult<(u8, Option<Vec<Option<Vec<u8>>>>)> {
144 loop {
145 if self.buffer.len() >= 5 {
146 let msg_len = u32::from_be_bytes([
147 self.buffer[1],
148 self.buffer[2],
149 self.buffer[3],
150 self.buffer[4],
151 ]) as usize;
152
153 if msg_len > MAX_MESSAGE_SIZE {
154 return Err(PgError::Protocol(format!(
155 "Message too large: {} bytes (max {})",
156 msg_len, MAX_MESSAGE_SIZE
157 )));
158 }
159
160 if self.buffer.len() > msg_len {
161 let msg_type = self.buffer[0];
162
163 if msg_type == b'E' {
164 let msg_bytes = self.buffer.split_to(msg_len + 1);
165 let (msg, _) =
166 BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
167 if let BackendMessage::ErrorResponse(err) = msg {
168 return Err(PgError::Query(err.message));
169 }
170 }
171
172 if msg_type == b'D' {
174 let payload = &self.buffer[5..msg_len + 1];
175
176 if payload.len() >= 2 {
177 let column_count =
178 u16::from_be_bytes([payload[0], payload[1]]) as usize;
179 let mut columns = Vec::with_capacity(column_count);
180 let mut pos = 2;
181
182 for _ in 0..column_count {
183 if pos + 4 > payload.len() {
184 break;
185 }
186
187 let len = i32::from_be_bytes([
188 payload[pos],
189 payload[pos + 1],
190 payload[pos + 2],
191 payload[pos + 3],
192 ]);
193 pos += 4;
194
195 if len == -1 {
196 columns.push(None);
197 } else {
198 let len = len as usize;
199 if pos + len <= payload.len() {
200 columns.push(Some(payload[pos..pos + len].to_vec()));
201 pos += len;
202 }
203 }
204 }
205
206 let _ = self.buffer.split_to(msg_len + 1);
207 return Ok((msg_type, Some(columns)));
208 }
209 }
210
211 let _ = self.buffer.split_to(msg_len + 1);
213 return Ok((msg_type, None));
214 }
215 }
216
217 if self.buffer.capacity() - self.buffer.len() < 65536 {
218 self.buffer.reserve(131072);
219 }
220
221 let n = self.stream.read_buf(&mut self.buffer).await?;
222 if n == 0 {
223 return Err(PgError::Connection("Connection closed".to_string()));
224 }
225 }
226 }
227
228 #[inline]
234 pub(crate) async fn recv_data_zerocopy(
235 &mut self,
236 ) -> PgResult<(u8, Option<Vec<Option<bytes::Bytes>>>)> {
237 use bytes::Buf;
238
239 loop {
240 if self.buffer.len() >= 5 {
241 let msg_len = u32::from_be_bytes([
242 self.buffer[1],
243 self.buffer[2],
244 self.buffer[3],
245 self.buffer[4],
246 ]) as usize;
247
248 if msg_len > MAX_MESSAGE_SIZE {
249 return Err(PgError::Protocol(format!(
250 "Message too large: {} bytes (max {})",
251 msg_len, MAX_MESSAGE_SIZE
252 )));
253 }
254
255 if self.buffer.len() > msg_len {
256 let msg_type = self.buffer[0];
257
258 if msg_type == b'E' {
259 let msg_bytes = self.buffer.split_to(msg_len + 1);
260 let (msg, _) =
261 BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
262 if let BackendMessage::ErrorResponse(err) = msg {
263 return Err(PgError::Query(err.message));
264 }
265 }
266
267 if msg_type == b'D' {
269 let mut msg_bytes = self.buffer.split_to(msg_len + 1);
271
272 msg_bytes.advance(5);
274
275 if msg_bytes.len() >= 2 {
276 let column_count = msg_bytes.get_u16() as usize;
277 let mut columns = Vec::with_capacity(column_count);
278
279 for _ in 0..column_count {
280 if msg_bytes.remaining() < 4 {
281 break;
282 }
283
284 let len = msg_bytes.get_i32();
285
286 if len == -1 {
287 columns.push(None);
288 } else {
289 let len = len as usize;
290 if msg_bytes.remaining() >= len {
291 let col_data = msg_bytes.split_to(len).freeze();
292 columns.push(Some(col_data));
293 }
294 }
295 }
296
297 return Ok((msg_type, Some(columns)));
298 }
299 return Ok((msg_type, None));
300 }
301
302 let _ = self.buffer.split_to(msg_len + 1);
304 return Ok((msg_type, None));
305 }
306 }
307
308 if self.buffer.capacity() - self.buffer.len() < 65536 {
309 self.buffer.reserve(131072);
310 }
311
312 let n = self.stream.read_buf(&mut self.buffer).await?;
313 if n == 0 {
314 return Err(PgError::Connection("Connection closed".to_string()));
315 }
316 }
317 }
318
319 #[inline(always)]
323 pub(crate) async fn recv_data_ultra(
324 &mut self,
325 ) -> PgResult<(u8, Option<(bytes::Bytes, bytes::Bytes)>)> {
326 use bytes::Buf;
327
328 loop {
329 if self.buffer.len() >= 5 {
330 let msg_len = u32::from_be_bytes([
331 self.buffer[1],
332 self.buffer[2],
333 self.buffer[3],
334 self.buffer[4],
335 ]) as usize;
336
337 if msg_len > MAX_MESSAGE_SIZE {
338 return Err(PgError::Protocol(format!(
339 "Message too large: {} bytes (max {})",
340 msg_len, MAX_MESSAGE_SIZE
341 )));
342 }
343
344 if self.buffer.len() > msg_len {
345 let msg_type = self.buffer[0];
346
347 if msg_type == b'E' {
349 let msg_bytes = self.buffer.split_to(msg_len + 1);
350 let (msg, _) =
351 BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
352 if let BackendMessage::ErrorResponse(err) = msg {
353 return Err(PgError::Query(err.message));
354 }
355 }
356
357 if msg_type == b'D' {
358 let mut msg_bytes = self.buffer.split_to(msg_len + 1);
359 msg_bytes.advance(5); let _col_count = msg_bytes.get_u16();
363
364 let len0 = msg_bytes.get_i32();
365 let col0 = if len0 > 0 {
366 msg_bytes.split_to(len0 as usize).freeze()
367 } else {
368 bytes::Bytes::new()
369 };
370
371 let len1 = msg_bytes.get_i32();
372 let col1 = if len1 > 0 {
373 msg_bytes.split_to(len1 as usize).freeze()
374 } else {
375 bytes::Bytes::new()
376 };
377
378 return Ok((msg_type, Some((col0, col1))));
379 }
380
381 let _ = self.buffer.split_to(msg_len + 1);
383 return Ok((msg_type, None));
384 }
385 }
386
387 if self.buffer.capacity() - self.buffer.len() < 65536 {
388 self.buffer.reserve(131072);
389 }
390
391 let n = self.stream.read_buf(&mut self.buffer).await?;
392 if n == 0 {
393 return Err(PgError::Connection("Connection closed".to_string()));
394 }
395 }
396 }
397}