postgres_proto_rs/network/tokio.rs

1use bytes::{BufMut, BytesMut};
2use log::trace;
3use std::mem;
4use tokio::io::{AsyncReadExt, AsyncWriteExt};
5
6use crate::errors::Error;
7use crate::messages::{backend::*, frontend::*};
8
9pub async fn read_startup<S>(stream: &mut S) -> Result<StartupMessageType, Error>
10where
11    S: tokio::io::AsyncRead + std::marker::Unpin,
12{
13    let len = match stream.read_i32().await {
14        Ok(len) => len,
15        Err(_) => return Err(Error::SocketIOError),
16    };
17
18    let code = match stream.read_i32().await {
19        Ok(len) => len,
20        Err(_) => return Err(Error::SocketIOError),
21    };
22
23    let mut message_bytes = vec![0u8; len as usize - 8];
24    match stream.read_exact(&mut message_bytes).await {
25        Ok(_) => {}
26        Err(_) => return Err(Error::SocketIOError),
27    }
28
29    let mut bytes_mut = BytesMut::with_capacity(len as usize + mem::size_of::<i32>());
30
31    bytes_mut.put_i32(len);
32    bytes_mut.put_i32(code);
33    bytes_mut.put_slice(&message_bytes);
34
35    trace!(
36        "F: Startup message: {:?}",
37        String::from_utf8_lossy(&bytes_mut)
38    );
39
40    StartupMessageType::new_from_bytes(code, bytes_mut)
41}
42
43pub async fn send_startup_message<S>(
44    stream: &mut S,
45    message: &StartupMessageType,
46) -> Result<(), Error>
47where
48    S: tokio::io::AsyncWrite + std::marker::Unpin,
49{
50    match stream.write(message.get_bytes()).await {
51        Ok(_) => Ok(()),
52        Err(_) => return Err(Error::SocketIOError),
53    }
54}
55
56pub async fn read_frontend_message<S>(stream: &mut S) -> Result<FrontendMessageType, Error>
57where
58    S: tokio::io::AsyncRead + std::marker::Unpin,
59{
60    let (msg_type, message_bytes) = read_message_bytes(stream).await?;
61
62    trace!(
63        "F: Code: {}\n Message: {:?}",
64        msg_type as char,
65        String::from_utf8_lossy(&message_bytes)
66    );
67    FrontendMessageType::new_from_bytes(msg_type, message_bytes)
68}
69
70pub async fn send_frontend_message<S>(
71    stream: &mut S,
72    message: &FrontendMessageType,
73) -> Result<(), Error>
74where
75    S: tokio::io::AsyncWrite + std::marker::Unpin,
76{
77    match stream.write(message.get_bytes()).await {
78        Ok(_) => Ok(()),
79        Err(_) => return Err(Error::SocketIOError),
80    }
81}
82
83pub async fn read_backend_message<S>(stream: &mut S) -> Result<BackendMessageType, Error>
84where
85    S: tokio::io::AsyncRead + std::marker::Unpin,
86{
87    let (msg_type, message_bytes) = read_message_bytes(stream).await?;
88
89    trace!(
90        "B: Code: {}\n Message: {:?}",
91        msg_type as char,
92        String::from_utf8_lossy(&message_bytes)
93    );
94
95    BackendMessageType::new_from_bytes(msg_type, message_bytes)
96}
97
98pub async fn send_backend_message<S>(
99    stream: &mut S,
100    message: &BackendMessageType,
101) -> Result<(), Error>
102where
103    S: tokio::io::AsyncWrite + std::marker::Unpin,
104{
105    match stream.write(message.get_bytes()).await {
106        Ok(_) => Ok(()),
107        Err(_) => return Err(Error::SocketIOError),
108    }
109}
110
111pub async fn read_message_bytes<S>(stream: &mut S) -> Result<(u8, BytesMut), Error>
112where
113    S: tokio::io::AsyncRead + std::marker::Unpin,
114{
115    let msg_type = match stream.read_u8().await {
116        Ok(msg_type) => msg_type,
117        Err(_) => return Err(Error::SocketIOError),
118    };
119
120    let len = match stream.read_i32().await {
121        Ok(len) => len,
122        Err(_) => return Err(Error::SocketIOError),
123    };
124
125    let mut message_body = vec![0u8; len as usize - 4];
126    match stream.read_exact(&mut message_body).await {
127        Ok(_) => {}
128        Err(_) => return Err(Error::SocketIOError),
129    }
130
131    let mut message_bytes = BytesMut::with_capacity(mem::size_of::<u8>() + len as usize);
132
133    message_bytes.put_u8(msg_type);
134    message_bytes.put_i32(len);
135    message_bytes.put_slice(&message_body);
136
137    Ok((msg_type, message_bytes))
138}