// postgres_proto_rs/network/standard.rs

1use bytes::{BufMut, BytesMut};
2use log::trace;
3use std::mem;
4
5use crate::errors::Error;
6use crate::messages::{backend::*, frontend::*};
7
8pub fn read_startup<S>(stream: &mut S) -> Result<StartupMessageType, Error>
9where
10    S: std::io::Read + std::marker::Unpin,
11{
12    let mut buf = [0; 4];
13    match stream.read_exact(&mut buf) {
14        Ok(_) => {}
15        Err(_) => return Err(Error::SocketIOError),
16    }
17    let len = i32::from_be_bytes(buf);
18
19    let mut buf = [0; 4];
20    match stream.read_exact(&mut buf) {
21        Ok(_) => {}
22        Err(_) => return Err(Error::SocketIOError),
23    }
24    let code = i32::from_be_bytes(buf);
25
26    let mut message_bytes = vec![0u8; len as usize - 8];
27    match stream.read_exact(&mut message_bytes) {
28        Ok(_) => {}
29        Err(_) => return Err(Error::SocketIOError),
30    }
31
32    let mut bytes_mut = BytesMut::with_capacity(len as usize + mem::size_of::<i32>());
33
34    bytes_mut.put_i32(len);
35
36    bytes_mut.put_i32(code);
37
38    bytes_mut.put_slice(&message_bytes);
39
40    trace!(
41        "F: Startup message: {:?}",
42        String::from_utf8_lossy(&bytes_mut)
43    );
44
45    StartupMessageType::new_from_bytes(code, bytes_mut)
46}
47
48pub fn send_startup_message<S>(
49    stream: &mut S,
50    message: &StartupMessageType,
51) -> Result<(), std::io::Error>
52where
53    S: std::io::Write + std::marker::Unpin,
54{
55    match stream.write(&message.get_bytes()) {
56        Ok(_) => Ok(()),
57        Err(err) => return Err(err),
58    }
59}
60
61pub fn read_frontend_message<S>(stream: &mut S) -> Result<FrontendMessageType, Error>
62where
63    S: std::io::Read + std::marker::Unpin,
64{
65    let (msg_type, message_bytes) = read_message_bytes(stream)?;
66
67    trace!(
68        "F: Code: {}\n Message: {:?}",
69        msg_type as char,
70        String::from_utf8_lossy(&message_bytes)
71    );
72    FrontendMessageType::new_from_bytes(msg_type, message_bytes)
73}
74
75pub fn send_frontend_message<S>(
76    stream: &mut S,
77    message: &FrontendMessageType,
78) -> Result<(), std::io::Error>
79where
80    S: std::io::Write + std::marker::Unpin,
81{
82    match stream.write(&message.get_bytes()) {
83        Ok(_) => Ok(()),
84        Err(err) => return Err(err),
85    }
86}
87
88pub fn read_backend_message<S>(stream: &mut S) -> Result<BackendMessageType, Error>
89where
90    S: std::io::Read + std::marker::Unpin,
91{
92    let (msg_type, message_bytes) = read_message_bytes(stream)?;
93
94    trace!(
95        "B: Code: {}\n Message: {:?}",
96        msg_type as char,
97        String::from_utf8_lossy(&message_bytes)
98    );
99
100    BackendMessageType::new_from_bytes(msg_type, message_bytes)
101}
102
103pub fn send_backend_message<S>(
104    stream: &mut S,
105    message: &BackendMessageType,
106) -> Result<(), std::io::Error>
107where
108    S: std::io::Write + std::marker::Unpin,
109{
110    match stream.write(&message.get_bytes()) {
111        Ok(_) => Ok(()),
112        Err(err) => Err(err),
113    }
114}
115
116pub fn read_message_bytes<S>(stream: &mut S) -> Result<(u8, BytesMut), Error>
117where
118    S: std::io::Read + std::marker::Unpin,
119{
120    let mut buf = [0; 1];
121    match stream.read_exact(&mut buf) {
122        Ok(_) => {}
123        Err(_) => return Err(Error::SocketIOError),
124    }
125    let msg_type = u8::from_be_bytes(buf);
126
127    let mut buf = [0; 4];
128    match stream.read_exact(&mut buf) {
129        Ok(_) => {}
130        Err(_) => return Err(Error::SocketIOError),
131    }
132    let len = i32::from_be_bytes(buf);
133
134    let mut message_body = vec![0u8; len as usize - 4];
135    match stream.read_exact(&mut message_body) {
136        Ok(_) => {}
137        Err(_) => return Err(Error::SocketIOError),
138    }
139
140    let mut message_bytes = BytesMut::with_capacity(mem::size_of::<u8>() + len as usize);
141
142    message_bytes.put_u8(msg_type);
143    message_bytes.put_i32(len);
144    message_bytes.put_slice(&message_body);
145
146    Ok((msg_type, message_bytes))
147}