1use super::{PgConnection, PgError, PgResult};
6use crate::protocol::{BackendMessage, FrontendMessage};
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8
9const MAX_MESSAGE_SIZE: usize = 1024 * 1024 * 1024; impl PgConnection {
12 pub async fn send(&mut self, msg: FrontendMessage) -> PgResult<()> {
14 let bytes = msg.encode();
15 self.stream.write_all(&bytes).await?;
16 Ok(())
17 }
18
19 pub async fn recv(&mut self) -> PgResult<BackendMessage> {
21 loop {
22 if self.buffer.len() >= 5 {
24 let msg_len = u32::from_be_bytes([
25 self.buffer[1],
26 self.buffer[2],
27 self.buffer[3],
28 self.buffer[4],
29 ]) as usize;
30
31 if msg_len > MAX_MESSAGE_SIZE {
32 return Err(PgError::Protocol(format!(
33 "Message too large: {} bytes (max {})",
34 msg_len, MAX_MESSAGE_SIZE
35 )));
36 }
37
38 if self.buffer.len() > msg_len {
39 let msg_bytes = self.buffer.split_to(msg_len + 1);
41 let (msg, _) = BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
42 return Ok(msg);
43 }
44 }
45
46 if self.buffer.capacity() - self.buffer.len() < 65536 {
47 self.buffer.reserve(131072); }
49
50 let n = self.stream.read_buf(&mut self.buffer).await?;
51 if n == 0 {
52 return Err(PgError::Connection("Connection closed".to_string()));
53 }
54 }
55 }
56
57 pub async fn send_bytes(&mut self, bytes: &[u8]) -> PgResult<()> {
59 self.stream.write_all(bytes).await?;
60 self.stream.flush().await?; Ok(())
62 }
63
64 #[inline]
69 pub fn buffer_bytes(&mut self, bytes: &[u8]) {
70 self.write_buf.extend_from_slice(bytes);
71 }
72
73 pub async fn flush_write_buf(&mut self) -> PgResult<()> {
76 if !self.write_buf.is_empty() {
77 self.stream.write_all(&self.write_buf).await?;
78 self.write_buf.clear();
79 }
80 Ok(())
81 }
82
83 #[inline]
87 pub(crate) async fn recv_msg_type_fast(&mut self) -> PgResult<u8> {
88 loop {
89 if self.buffer.len() >= 5 {
90 let msg_len = u32::from_be_bytes([
91 self.buffer[1],
92 self.buffer[2],
93 self.buffer[3],
94 self.buffer[4],
95 ]) as usize;
96
97 if msg_len > MAX_MESSAGE_SIZE {
98 return Err(PgError::Protocol(format!(
99 "Message too large: {} bytes (max {})",
100 msg_len, MAX_MESSAGE_SIZE
101 )));
102 }
103
104 if self.buffer.len() > msg_len {
105 let msg_type = self.buffer[0];
106
107 if msg_type == b'E' {
108 let msg_bytes = self.buffer.split_to(msg_len + 1);
109 let (msg, _) =
110 BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
111 if let BackendMessage::ErrorResponse(err) = msg {
112 return Err(PgError::Query(err.message));
113 }
114 }
115
116 let _ = self.buffer.split_to(msg_len + 1);
117 return Ok(msg_type);
118 }
119 }
120
121 if self.buffer.capacity() - self.buffer.len() < 65536 {
122 self.buffer.reserve(131072); }
124
125 let n = self.stream.read_buf(&mut self.buffer).await?;
126 if n == 0 {
127 return Err(PgError::Connection("Connection closed".to_string()));
128 }
129 }
130 }
131
132 #[inline]
138 pub(crate) async fn recv_with_data_fast(
139 &mut self,
140 ) -> PgResult<(u8, Option<Vec<Option<Vec<u8>>>>)> {
141 loop {
142 if self.buffer.len() >= 5 {
143 let msg_len = u32::from_be_bytes([
144 self.buffer[1],
145 self.buffer[2],
146 self.buffer[3],
147 self.buffer[4],
148 ]) as usize;
149
150 if msg_len > MAX_MESSAGE_SIZE {
151 return Err(PgError::Protocol(format!(
152 "Message too large: {} bytes (max {})",
153 msg_len, MAX_MESSAGE_SIZE
154 )));
155 }
156
157 if self.buffer.len() > msg_len {
158 let msg_type = self.buffer[0];
159
160 if msg_type == b'E' {
161 let msg_bytes = self.buffer.split_to(msg_len + 1);
162 let (msg, _) =
163 BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
164 if let BackendMessage::ErrorResponse(err) = msg {
165 return Err(PgError::Query(err.message));
166 }
167 }
168
169 if msg_type == b'D' {
171 let payload = &self.buffer[5..msg_len + 1];
172
173 if payload.len() >= 2 {
174 let column_count =
175 u16::from_be_bytes([payload[0], payload[1]]) as usize;
176 let mut columns = Vec::with_capacity(column_count);
177 let mut pos = 2;
178
179 for _ in 0..column_count {
180 if pos + 4 > payload.len() {
181 break;
182 }
183
184 let len = i32::from_be_bytes([
185 payload[pos],
186 payload[pos + 1],
187 payload[pos + 2],
188 payload[pos + 3],
189 ]);
190 pos += 4;
191
192 if len == -1 {
193 columns.push(None);
194 } else {
195 let len = len as usize;
196 if pos + len <= payload.len() {
197 columns.push(Some(payload[pos..pos + len].to_vec()));
198 pos += len;
199 }
200 }
201 }
202
203 let _ = self.buffer.split_to(msg_len + 1);
204 return Ok((msg_type, Some(columns)));
205 }
206 }
207
208 let _ = self.buffer.split_to(msg_len + 1);
210 return Ok((msg_type, None));
211 }
212 }
213
214 if self.buffer.capacity() - self.buffer.len() < 65536 {
215 self.buffer.reserve(131072);
216 }
217
218 let n = self.stream.read_buf(&mut self.buffer).await?;
219 if n == 0 {
220 return Err(PgError::Connection("Connection closed".to_string()));
221 }
222 }
223 }
224
225 #[inline]
231 pub(crate) async fn recv_data_zerocopy(
232 &mut self,
233 ) -> PgResult<(u8, Option<Vec<Option<bytes::Bytes>>>)> {
234 use bytes::Buf;
235
236 loop {
237 if self.buffer.len() >= 5 {
238 let msg_len = u32::from_be_bytes([
239 self.buffer[1],
240 self.buffer[2],
241 self.buffer[3],
242 self.buffer[4],
243 ]) as usize;
244
245 if msg_len > MAX_MESSAGE_SIZE {
246 return Err(PgError::Protocol(format!(
247 "Message too large: {} bytes (max {})",
248 msg_len, MAX_MESSAGE_SIZE
249 )));
250 }
251
252 if self.buffer.len() > msg_len {
253 let msg_type = self.buffer[0];
254
255 if msg_type == b'E' {
256 let msg_bytes = self.buffer.split_to(msg_len + 1);
257 let (msg, _) =
258 BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
259 if let BackendMessage::ErrorResponse(err) = msg {
260 return Err(PgError::Query(err.message));
261 }
262 }
263
264 if msg_type == b'D' {
266 let mut msg_bytes = self.buffer.split_to(msg_len + 1);
268
269 msg_bytes.advance(5);
271
272 if msg_bytes.len() >= 2 {
273 let column_count = msg_bytes.get_u16() as usize;
274 let mut columns = Vec::with_capacity(column_count);
275
276 for _ in 0..column_count {
277 if msg_bytes.remaining() < 4 {
278 break;
279 }
280
281 let len = msg_bytes.get_i32();
282
283 if len == -1 {
284 columns.push(None);
285 } else {
286 let len = len as usize;
287 if msg_bytes.remaining() >= len {
288 let col_data = msg_bytes.split_to(len).freeze();
289 columns.push(Some(col_data));
290 }
291 }
292 }
293
294 return Ok((msg_type, Some(columns)));
295 }
296 return Ok((msg_type, None));
297 }
298
299 let _ = self.buffer.split_to(msg_len + 1);
301 return Ok((msg_type, None));
302 }
303 }
304
305 if self.buffer.capacity() - self.buffer.len() < 65536 {
306 self.buffer.reserve(131072);
307 }
308
309 let n = self.stream.read_buf(&mut self.buffer).await?;
310 if n == 0 {
311 return Err(PgError::Connection("Connection closed".to_string()));
312 }
313 }
314 }
315
316 #[inline(always)]
320 pub(crate) async fn recv_data_ultra(
321 &mut self,
322 ) -> PgResult<(u8, Option<(bytes::Bytes, bytes::Bytes)>)> {
323 use bytes::Buf;
324
325 loop {
326 if self.buffer.len() >= 5 {
327 let msg_len = u32::from_be_bytes([
328 self.buffer[1],
329 self.buffer[2],
330 self.buffer[3],
331 self.buffer[4],
332 ]) as usize;
333
334 if msg_len > MAX_MESSAGE_SIZE {
335 return Err(PgError::Protocol(format!(
336 "Message too large: {} bytes (max {})",
337 msg_len, MAX_MESSAGE_SIZE
338 )));
339 }
340
341 if self.buffer.len() > msg_len {
342 let msg_type = self.buffer[0];
343
344 if msg_type == b'E' {
346 let msg_bytes = self.buffer.split_to(msg_len + 1);
347 let (msg, _) =
348 BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
349 if let BackendMessage::ErrorResponse(err) = msg {
350 return Err(PgError::Query(err.message));
351 }
352 }
353
354 if msg_type == b'D' {
355 let mut msg_bytes = self.buffer.split_to(msg_len + 1);
356 msg_bytes.advance(5); let _col_count = msg_bytes.get_u16();
360
361 let len0 = msg_bytes.get_i32();
362 let col0 = if len0 > 0 {
363 msg_bytes.split_to(len0 as usize).freeze()
364 } else {
365 bytes::Bytes::new()
366 };
367
368 let len1 = msg_bytes.get_i32();
369 let col1 = if len1 > 0 {
370 msg_bytes.split_to(len1 as usize).freeze()
371 } else {
372 bytes::Bytes::new()
373 };
374
375 return Ok((msg_type, Some((col0, col1))));
376 }
377
378 let _ = self.buffer.split_to(msg_len + 1);
380 return Ok((msg_type, None));
381 }
382 }
383
384 if self.buffer.capacity() - self.buffer.len() < 65536 {
385 self.buffer.reserve(131072);
386 }
387
388 let n = self.stream.read_buf(&mut self.buffer).await?;
389 if n == 0 {
390 return Err(PgError::Connection("Connection closed".to_string()));
391 }
392 }
393 }
394}