postgres_protocol/message/
frontend.rs1#![allow(missing_docs)]
3
4use byteorder::{BigEndian, ByteOrder};
5use bytes::{Buf, BufMut, BytesMut};
6use std::error::Error;
7use std::io;
8use std::marker;
9
10use crate::{write_nullable, FromUsize, IsNull, Oid};
11
12#[inline]
13fn write_body<F, E>(buf: &mut BytesMut, f: F) -> Result<(), E>
14where
15 F: FnOnce(&mut BytesMut) -> Result<(), E>,
16 E: From<io::Error>,
17{
18 let base = buf.len();
19 buf.extend_from_slice(&[0; 4]);
20
21 f(buf)?;
22
23 let size = i32::from_usize(buf.len() - base)?;
24 BigEndian::write_i32(&mut buf[base..], size);
25 Ok(())
26}
27
/// Error produced by [`bind`] while encoding a Bind message.
#[derive(Debug)]
pub enum BindError {
    /// A parameter value could not be converted; typically wraps the error returned by the
    /// serializer passed to `bind`.
    Conversion(Box<dyn Error + marker::Sync + Send>),
    /// The message itself could not be encoded, e.g. a portal or statement name containing an
    /// embedded NUL byte.
    Serialization(io::Error),
}
32
33impl From<Box<dyn Error + marker::Sync + Send>> for BindError {
34 #[inline]
35 fn from(e: Box<dyn Error + marker::Sync + Send>) -> BindError {
36 BindError::Conversion(e)
37 }
38}
39
40impl From<io::Error> for BindError {
41 #[inline]
42 fn from(e: io::Error) -> BindError {
43 BindError::Serialization(e)
44 }
45}
46
47#[inline]
48pub fn bind<I, J, F, T, K>(
49 portal: &str,
50 statement: &str,
51 formats: I,
52 values: J,
53 mut serializer: F,
54 result_formats: K,
55 buf: &mut BytesMut,
56) -> Result<(), BindError>
57where
58 I: IntoIterator<Item = i16>,
59 J: IntoIterator<Item = T>,
60 F: FnMut(T, &mut BytesMut) -> Result<IsNull, Box<dyn Error + marker::Sync + Send>>,
61 K: IntoIterator<Item = i16>,
62{
63 buf.put_u8(b'B');
64
65 write_body(buf, |buf| {
66 write_cstr(portal.as_bytes(), buf)?;
67 write_cstr(statement.as_bytes(), buf)?;
68 write_counted(
69 formats,
70 |f, buf| {
71 buf.put_i16(f);
72 Ok::<_, io::Error>(())
73 },
74 buf,
75 )?;
76 write_counted(
77 values,
78 |v, buf| write_nullable(|buf| serializer(v, buf), buf),
79 buf,
80 )?;
81 write_counted(
82 result_formats,
83 |f, buf| {
84 buf.put_i16(f);
85 Ok::<_, io::Error>(())
86 },
87 buf,
88 )?;
89
90 Ok(())
91 })
92}
93
94#[inline]
95fn write_counted<I, T, F, E>(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E>
96where
97 I: IntoIterator<Item = T>,
98 F: FnMut(T, &mut BytesMut) -> Result<(), E>,
99 E: From<io::Error>,
100{
101 let base = buf.len();
102 buf.extend_from_slice(&[0; 2]);
103 let mut count = 0;
104 for item in items {
105 serializer(item, buf)?;
106 count += 1;
107 }
108 let count = i16::from_usize(count)?;
109 BigEndian::write_i16(&mut buf[base..], count);
110
111 Ok(())
112}
113
114#[inline]
115pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) {
116 write_body(buf, |buf| {
117 buf.put_i32(80_877_102);
118 buf.put_i32(process_id);
119 buf.put_i32(secret_key);
120 Ok::<_, io::Error>(())
121 })
122 .unwrap();
123}
124
125#[inline]
126pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
127 buf.put_u8(b'C');
128 write_body(buf, |buf| {
129 buf.put_u8(variant);
130 write_cstr(name.as_bytes(), buf)
131 })
132}
133
/// A CopyData (`'d'`) message whose payload is a buffer of type `T`.
///
/// The `impl` that constructs it requires `T` to implement [`Buf`]. `CopyData::new`
/// precomputes the length prefix, and `CopyData::write` encodes the message.
#[derive(Debug)]
pub struct CopyData<T> {
    /// Payload written after the message header.
    buf: T,
    /// Value of the length prefix: payload size plus the 4 bytes of the prefix itself.
    len: i32,
}
138
139impl<T> CopyData<T>
140where
141 T: Buf,
142{
143 pub fn new(buf: T) -> io::Result<CopyData<T>> {
144 let len = buf
145 .remaining()
146 .checked_add(4)
147 .and_then(|l| i32::try_from(l).ok())
148 .ok_or_else(|| {
149 io::Error::new(io::ErrorKind::InvalidInput, "message length overflow")
150 })?;
151
152 Ok(CopyData { buf, len })
153 }
154
155 pub fn write(self, out: &mut BytesMut) {
156 out.put_u8(b'd');
157 out.put_i32(self.len);
158 out.put(self.buf);
159 }
160}
161
162#[inline]
163pub fn copy_done(buf: &mut BytesMut) {
164 buf.put_u8(b'c');
165 write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
166}
167
168#[inline]
169pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> {
170 buf.put_u8(b'f');
171 write_body(buf, |buf| write_cstr(message.as_bytes(), buf))
172}
173
174#[inline]
175pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
176 buf.put_u8(b'D');
177 write_body(buf, |buf| {
178 buf.put_u8(variant);
179 write_cstr(name.as_bytes(), buf)
180 })
181}
182
183#[inline]
184pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> io::Result<()> {
185 buf.put_u8(b'E');
186 write_body(buf, |buf| {
187 write_cstr(portal.as_bytes(), buf)?;
188 buf.put_i32(max_rows);
189 Ok(())
190 })
191}
192
193#[inline]
194pub fn parse<I>(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()>
195where
196 I: IntoIterator<Item = Oid>,
197{
198 buf.put_u8(b'P');
199 write_body(buf, |buf| {
200 write_cstr(name.as_bytes(), buf)?;
201 write_cstr(query.as_bytes(), buf)?;
202 write_counted(
203 param_types,
204 |t, buf| {
205 buf.put_u32(t);
206 Ok::<_, io::Error>(())
207 },
208 buf,
209 )?;
210 Ok(())
211 })
212}
213
214#[inline]
215pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> {
216 buf.put_u8(b'p');
217 write_body(buf, |buf| write_cstr(password, buf))
218}
219
220#[inline]
221pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> {
222 buf.put_u8(b'Q');
223 write_body(buf, |buf| write_cstr(query.as_bytes(), buf))
224}
225
226#[inline]
227pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
228 buf.put_u8(b'p');
229 write_body(buf, |buf| {
230 write_cstr(mechanism.as_bytes(), buf)?;
231 let len = i32::from_usize(data.len())?;
232 buf.put_i32(len);
233 buf.put_slice(data);
234 Ok(())
235 })
236}
237
238#[inline]
239pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
240 buf.put_u8(b'p');
241 write_body(buf, |buf| {
242 buf.put_slice(data);
243 Ok(())
244 })
245}
246
247#[inline]
248pub fn ssl_request(buf: &mut BytesMut) {
249 write_body(buf, |buf| {
250 buf.put_i32(80_877_103);
251 Ok::<_, io::Error>(())
252 })
253 .unwrap();
254}
255
256#[inline]
257pub fn startup_message<'a, I>(parameters: I, buf: &mut BytesMut) -> io::Result<()>
258where
259 I: IntoIterator<Item = (&'a str, &'a str)>,
260{
261 write_body(buf, |buf| {
262 buf.put_i32(0x00_03_00_00);
264 for (key, value) in parameters {
265 write_cstr(key.as_bytes(), buf)?;
266 write_cstr(value.as_bytes(), buf)?;
267 }
268 buf.put_u8(0);
269 Ok(())
270 })
271}
272
273#[inline]
274pub fn flush(buf: &mut BytesMut) {
275 buf.put_u8(b'H');
276 write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
277}
278
279#[inline]
280pub fn sync(buf: &mut BytesMut) {
281 buf.put_u8(b'S');
282 write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
283}
284
285#[inline]
286pub fn terminate(buf: &mut BytesMut) {
287 buf.put_u8(b'X');
288 write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
289}
290
291#[inline]
292fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
293 if s.contains(&0) {
294 return Err(io::Error::new(
295 io::ErrorKind::InvalidInput,
296 "string contains embedded null",
297 ));
298 }
299 buf.put_slice(s);
300 buf.put_u8(0);
301 Ok(())
302}