1use super::{PgConnection, PgError, PgResult};
6use crate::protocol::{BackendMessage, FrontendMessage};
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8
/// Largest accepted value of a backend message's length field (64 MiB);
/// anything larger is rejected as a protocol error.
const MAX_MESSAGE_SIZE: usize = 1 << 26;
/// Maximum time a single socket read may wait for data.
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(30_000);
14
15impl PgConnection {
16 pub async fn send(&mut self, msg: FrontendMessage) -> PgResult<()> {
18 let bytes = msg.encode();
19 self.stream.write_all(&bytes).await?;
20 Ok(())
21 }
22
23 pub async fn recv(&mut self) -> PgResult<BackendMessage> {
26 loop {
27 if self.buffer.len() >= 5 {
29 let msg_len = u32::from_be_bytes([
30 self.buffer[1],
31 self.buffer[2],
32 self.buffer[3],
33 self.buffer[4],
34 ]) as usize;
35
36 if msg_len > MAX_MESSAGE_SIZE {
37 return Err(PgError::Protocol(format!(
38 "Message too large: {} bytes (max {})",
39 msg_len, MAX_MESSAGE_SIZE
40 )));
41 }
42
43 if self.buffer.len() > msg_len {
44 let msg_bytes = self.buffer.split_to(msg_len + 1);
46 let (msg, _) = BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
47
48 if let BackendMessage::NotificationResponse { process_id, channel, payload } = msg {
50 self.notifications.push_back(
51 super::notification::Notification { process_id, channel, payload }
52 );
53 continue; }
55
56 return Ok(msg);
57 }
58 }
59
60
61 let n = self.read_with_timeout().await?;
62 if n == 0 {
63 return Err(PgError::Connection("Connection closed".to_string()));
64 }
65 }
66 }
67
68 #[inline]
73 pub(crate) async fn read_with_timeout(&mut self) -> PgResult<usize> {
74 if self.buffer.capacity() - self.buffer.len() < 65536 {
75 self.buffer.reserve(131072);
76 }
77
78 match tokio::time::timeout(
79 DEFAULT_READ_TIMEOUT,
80 self.stream.read_buf(&mut self.buffer),
81 ).await {
82 Ok(Ok(n)) => Ok(n),
83 Ok(Err(e)) => Err(PgError::Connection(format!("Read error: {}", e))),
84 Err(_) => Err(PgError::Connection(format!(
85 "Read timeout after {:?} — possible Slowloris attack or dead connection",
86 DEFAULT_READ_TIMEOUT
87 ))),
88 }
89 }
90
91 pub async fn send_bytes(&mut self, bytes: &[u8]) -> PgResult<()> {
95 self.stream.write_all(bytes).await?;
96 self.stream.flush().await?;
97 Ok(())
98 }
99
100 #[inline]
105 pub fn buffer_bytes(&mut self, bytes: &[u8]) {
106 self.write_buf.extend_from_slice(bytes);
107 }
108
109 pub async fn flush_write_buf(&mut self) -> PgResult<()> {
112 if !self.write_buf.is_empty() {
113 self.stream.write_all(&self.write_buf).await?;
114 self.write_buf.clear();
115 self.stream.flush().await?;
116 }
117 Ok(())
118 }
119
120 #[inline]
124 pub(crate) async fn recv_msg_type_fast(&mut self) -> PgResult<u8> {
125 loop {
126 if self.buffer.len() >= 5 {
127 let msg_len = u32::from_be_bytes([
128 self.buffer[1],
129 self.buffer[2],
130 self.buffer[3],
131 self.buffer[4],
132 ]) as usize;
133
134 if msg_len > MAX_MESSAGE_SIZE {
135 return Err(PgError::Protocol(format!(
136 "Message too large: {} bytes (max {})",
137 msg_len, MAX_MESSAGE_SIZE
138 )));
139 }
140
141 if self.buffer.len() > msg_len {
142 let msg_type = self.buffer[0];
143
144 if msg_type == b'E' {
145 let msg_bytes = self.buffer.split_to(msg_len + 1);
146 let (msg, _) =
147 BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
148 if let BackendMessage::ErrorResponse(err) = msg {
149 return Err(PgError::Query(err.message));
150 }
151 }
152
153 let _ = self.buffer.split_to(msg_len + 1);
154 return Ok(msg_type);
155 }
156 }
157
158
159 let n = self.read_with_timeout().await?;
160 if n == 0 {
161 return Err(PgError::Connection("Connection closed".to_string()));
162 }
163 }
164 }
165
166 #[inline]
172 pub(crate) async fn recv_with_data_fast(
173 &mut self,
174 ) -> PgResult<(u8, Option<Vec<Option<Vec<u8>>>>)> {
175 loop {
176 if self.buffer.len() >= 5 {
177 let msg_len = u32::from_be_bytes([
178 self.buffer[1],
179 self.buffer[2],
180 self.buffer[3],
181 self.buffer[4],
182 ]) as usize;
183
184 if msg_len > MAX_MESSAGE_SIZE {
185 return Err(PgError::Protocol(format!(
186 "Message too large: {} bytes (max {})",
187 msg_len, MAX_MESSAGE_SIZE
188 )));
189 }
190
191 if self.buffer.len() > msg_len {
192 let msg_type = self.buffer[0];
193
194 if msg_type == b'E' {
195 let msg_bytes = self.buffer.split_to(msg_len + 1);
196 let (msg, _) =
197 BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
198 if let BackendMessage::ErrorResponse(err) = msg {
199 return Err(PgError::Query(err.message));
200 }
201 }
202
203 if msg_type == b'D' {
205 let payload = &self.buffer[5..msg_len + 1];
206
207 if payload.len() >= 2 {
208 let column_count =
209 u16::from_be_bytes([payload[0], payload[1]]) as usize;
210 let mut columns = Vec::with_capacity(column_count);
211 let mut pos = 2;
212
213 for _ in 0..column_count {
214 if pos + 4 > payload.len() {
215 let _ = self.buffer.split_to(msg_len + 1);
216 return Err(PgError::Protocol("DataRow truncated: missing column length".into()));
217 }
218
219 let len = i32::from_be_bytes([
220 payload[pos],
221 payload[pos + 1],
222 payload[pos + 2],
223 payload[pos + 3],
224 ]);
225 pos += 4;
226
227 if len == -1 {
228 columns.push(None);
229 } else {
230 let len = len as usize;
231 if pos + len > payload.len() {
232 let _ = self.buffer.split_to(msg_len + 1);
233 return Err(PgError::Protocol("DataRow truncated: column data exceeds payload".into()));
234 }
235 columns.push(Some(payload[pos..pos + len].to_vec()));
236 pos += len;
237 }
238 }
239
240 let _ = self.buffer.split_to(msg_len + 1);
241 return Ok((msg_type, Some(columns)));
242 }
243 }
244
245 let _ = self.buffer.split_to(msg_len + 1);
247 return Ok((msg_type, None));
248 }
249 }
250
251
252 let n = self.read_with_timeout().await?;
253 if n == 0 {
254 return Err(PgError::Connection("Connection closed".to_string()));
255 }
256 }
257 }
258
259 #[inline]
265 pub(crate) async fn recv_data_zerocopy(
266 &mut self,
267 ) -> PgResult<(u8, Option<Vec<Option<bytes::Bytes>>>)> {
268 use bytes::Buf;
269
270 loop {
271 if self.buffer.len() >= 5 {
272 let msg_len = u32::from_be_bytes([
273 self.buffer[1],
274 self.buffer[2],
275 self.buffer[3],
276 self.buffer[4],
277 ]) as usize;
278
279 if msg_len > MAX_MESSAGE_SIZE {
280 return Err(PgError::Protocol(format!(
281 "Message too large: {} bytes (max {})",
282 msg_len, MAX_MESSAGE_SIZE
283 )));
284 }
285
286 if self.buffer.len() > msg_len {
287 let msg_type = self.buffer[0];
288
289 if msg_type == b'E' {
290 let msg_bytes = self.buffer.split_to(msg_len + 1);
291 let (msg, _) =
292 BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
293 if let BackendMessage::ErrorResponse(err) = msg {
294 return Err(PgError::Query(err.message));
295 }
296 }
297
298 if msg_type == b'D' {
300 let mut msg_bytes = self.buffer.split_to(msg_len + 1);
302
303 msg_bytes.advance(5);
305
306 if msg_bytes.len() >= 2 {
307 let column_count = msg_bytes.get_u16() as usize;
308 let mut columns = Vec::with_capacity(column_count);
309
310 for _ in 0..column_count {
311 if msg_bytes.remaining() < 4 {
312 return Err(PgError::Protocol("DataRow truncated: missing column length".into()));
313 }
314
315 let len = msg_bytes.get_i32();
316
317 if len == -1 {
318 columns.push(None);
319 } else {
320 let len = len as usize;
321 if msg_bytes.remaining() < len {
322 return Err(PgError::Protocol("DataRow truncated: column data exceeds payload".into()));
323 }
324 let col_data = msg_bytes.split_to(len).freeze();
325 columns.push(Some(col_data));
326 }
327 }
328
329 return Ok((msg_type, Some(columns)));
330 }
331 return Ok((msg_type, None));
332 }
333
334 let _ = self.buffer.split_to(msg_len + 1);
336 return Ok((msg_type, None));
337 }
338 }
339
340
341 let n = self.read_with_timeout().await?;
342 if n == 0 {
343 return Err(PgError::Connection("Connection closed".to_string()));
344 }
345 }
346 }
347
348 #[inline(always)]
352 pub(crate) async fn recv_data_ultra(
353 &mut self,
354 ) -> PgResult<(u8, Option<(bytes::Bytes, bytes::Bytes)>)> {
355 use bytes::Buf;
356
357 loop {
358 if self.buffer.len() >= 5 {
359 let msg_len = u32::from_be_bytes([
360 self.buffer[1],
361 self.buffer[2],
362 self.buffer[3],
363 self.buffer[4],
364 ]) as usize;
365
366 if msg_len > MAX_MESSAGE_SIZE {
367 return Err(PgError::Protocol(format!(
368 "Message too large: {} bytes (max {})",
369 msg_len, MAX_MESSAGE_SIZE
370 )));
371 }
372
373 if self.buffer.len() > msg_len {
374 let msg_type = self.buffer[0];
375
376 if msg_type == b'E' {
378 let msg_bytes = self.buffer.split_to(msg_len + 1);
379 let (msg, _) =
380 BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
381 if let BackendMessage::ErrorResponse(err) = msg {
382 return Err(PgError::Query(err.message));
383 }
384 }
385
386 if msg_type == b'D' {
387 let mut msg_bytes = self.buffer.split_to(msg_len + 1);
388 msg_bytes.advance(5); if msg_bytes.remaining() < 2 {
392 return Err(PgError::Protocol("DataRow ultra: too short for column count".into()));
393 }
394
395 let _col_count = msg_bytes.get_u16();
397
398 if msg_bytes.remaining() < 4 {
399 return Err(PgError::Protocol("DataRow ultra: truncated before col0 length".into()));
400 }
401 let len0 = msg_bytes.get_i32();
402 let col0 = if len0 > 0 {
403 let len0 = len0 as usize;
404 if msg_bytes.remaining() < len0 {
405 return Err(PgError::Protocol("DataRow ultra: col0 data exceeds payload".into()));
406 }
407 msg_bytes.split_to(len0).freeze()
408 } else {
409 bytes::Bytes::new()
410 };
411
412 if msg_bytes.remaining() < 4 {
413 return Err(PgError::Protocol("DataRow ultra: truncated before col1 length".into()));
414 }
415 let len1 = msg_bytes.get_i32();
416 let col1 = if len1 > 0 {
417 let len1 = len1 as usize;
418 if msg_bytes.remaining() < len1 {
419 return Err(PgError::Protocol("DataRow ultra: col1 data exceeds payload".into()));
420 }
421 msg_bytes.split_to(len1).freeze()
422 } else {
423 bytes::Bytes::new()
424 };
425
426 return Ok((msg_type, Some((col0, col1))));
427 }
428
429 let _ = self.buffer.split_to(msg_len + 1);
431 return Ok((msg_type, None));
432 }
433 }
434
435
436 let n = self.read_with_timeout().await?;
437 if n == 0 {
438 return Err(PgError::Connection("Connection closed".to_string()));
439 }
440 }
441 }
442}