postgres_proto_rs/network/
tokio.rs1use bytes::{BufMut, BytesMut};
2use log::trace;
3use std::mem;
4use tokio::io::{AsyncReadExt, AsyncWriteExt};
5
6use crate::errors::Error;
7use crate::messages::{backend::*, frontend::*};
8
9pub async fn read_startup<S>(stream: &mut S) -> Result<StartupMessageType, Error>
10where
11 S: tokio::io::AsyncRead + std::marker::Unpin,
12{
13 let len = match stream.read_i32().await {
14 Ok(len) => len,
15 Err(_) => return Err(Error::SocketIOError),
16 };
17
18 let code = match stream.read_i32().await {
19 Ok(len) => len,
20 Err(_) => return Err(Error::SocketIOError),
21 };
22
23 let mut message_bytes = vec![0u8; len as usize - 8];
24 match stream.read_exact(&mut message_bytes).await {
25 Ok(_) => {}
26 Err(_) => return Err(Error::SocketIOError),
27 }
28
29 let mut bytes_mut = BytesMut::with_capacity(len as usize + mem::size_of::<i32>());
30
31 bytes_mut.put_i32(len);
32 bytes_mut.put_i32(code);
33 bytes_mut.put_slice(&message_bytes);
34
35 trace!(
36 "F: Startup message: {:?}",
37 String::from_utf8_lossy(&bytes_mut)
38 );
39
40 StartupMessageType::new_from_bytes(code, bytes_mut)
41}
42
43pub async fn send_startup_message<S>(
44 stream: &mut S,
45 message: &StartupMessageType,
46) -> Result<(), Error>
47where
48 S: tokio::io::AsyncWrite + std::marker::Unpin,
49{
50 match stream.write(message.get_bytes()).await {
51 Ok(_) => Ok(()),
52 Err(_) => return Err(Error::SocketIOError),
53 }
54}
55
56pub async fn read_frontend_message<S>(stream: &mut S) -> Result<FrontendMessageType, Error>
57where
58 S: tokio::io::AsyncRead + std::marker::Unpin,
59{
60 let (msg_type, message_bytes) = read_message_bytes(stream).await?;
61
62 trace!(
63 "F: Code: {}\n Message: {:?}",
64 msg_type as char,
65 String::from_utf8_lossy(&message_bytes)
66 );
67 FrontendMessageType::new_from_bytes(msg_type, message_bytes)
68}
69
70pub async fn send_frontend_message<S>(
71 stream: &mut S,
72 message: &FrontendMessageType,
73) -> Result<(), Error>
74where
75 S: tokio::io::AsyncWrite + std::marker::Unpin,
76{
77 match stream.write(message.get_bytes()).await {
78 Ok(_) => Ok(()),
79 Err(_) => return Err(Error::SocketIOError),
80 }
81}
82
83pub async fn read_backend_message<S>(stream: &mut S) -> Result<BackendMessageType, Error>
84where
85 S: tokio::io::AsyncRead + std::marker::Unpin,
86{
87 let (msg_type, message_bytes) = read_message_bytes(stream).await?;
88
89 trace!(
90 "B: Code: {}\n Message: {:?}",
91 msg_type as char,
92 String::from_utf8_lossy(&message_bytes)
93 );
94
95 BackendMessageType::new_from_bytes(msg_type, message_bytes)
96}
97
98pub async fn send_backend_message<S>(
99 stream: &mut S,
100 message: &BackendMessageType,
101) -> Result<(), Error>
102where
103 S: tokio::io::AsyncWrite + std::marker::Unpin,
104{
105 match stream.write(message.get_bytes()).await {
106 Ok(_) => Ok(()),
107 Err(_) => return Err(Error::SocketIOError),
108 }
109}
110
111pub async fn read_message_bytes<S>(stream: &mut S) -> Result<(u8, BytesMut), Error>
112where
113 S: tokio::io::AsyncRead + std::marker::Unpin,
114{
115 let msg_type = match stream.read_u8().await {
116 Ok(msg_type) => msg_type,
117 Err(_) => return Err(Error::SocketIOError),
118 };
119
120 let len = match stream.read_i32().await {
121 Ok(len) => len,
122 Err(_) => return Err(Error::SocketIOError),
123 };
124
125 let mut message_body = vec![0u8; len as usize - 4];
126 match stream.read_exact(&mut message_body).await {
127 Ok(_) => {}
128 Err(_) => return Err(Error::SocketIOError),
129 }
130
131 let mut message_bytes = BytesMut::with_capacity(mem::size_of::<u8>() + len as usize);
132
133 message_bytes.put_u8(msg_type);
134 message_bytes.put_i32(len);
135 message_bytes.put_slice(&message_body);
136
137 Ok((msg_type, message_bytes))
138}