postgres_protocol/message/
frontend.rs

1//! Frontend message serialization.
2#![allow(missing_docs)]
3
4use byteorder::{BigEndian, ByteOrder};
5use bytes::{Buf, BufMut, BytesMut};
6use std::error::Error;
7use std::io;
8use std::marker;
9
10use crate::{write_nullable, FromUsize, IsNull, Oid};
11
12#[inline]
13fn write_body<F, E>(buf: &mut BytesMut, f: F) -> Result<(), E>
14where
15    F: FnOnce(&mut BytesMut) -> Result<(), E>,
16    E: From<io::Error>,
17{
18    let base = buf.len();
19    buf.extend_from_slice(&[0; 4]);
20
21    f(buf)?;
22
23    let size = i32::from_usize(buf.len() - base)?;
24    BigEndian::write_i32(&mut buf[base..], size);
25    Ok(())
26}
27
/// Error returned by `bind`.
///
/// Derives `Debug` so that a `Result<_, BindError>` can be `unwrap()`ed,
/// `expect()`ed, or printed with `{:?}` by callers.
#[derive(Debug)]
pub enum BindError {
    /// A boxed error, typically produced by the value serializer passed to `bind`.
    Conversion(Box<dyn Error + marker::Sync + Send>),
    /// An I/O-level encoding failure, e.g. a portal or statement name containing
    /// an embedded zero byte, or a length/count that does not fit its wire field.
    Serialization(io::Error),
}

/// Lets `?` lift boxed serializer errors into `BindError::Conversion`.
impl From<Box<dyn Error + marker::Sync + Send>> for BindError {
    #[inline]
    fn from(e: Box<dyn Error + marker::Sync + Send>) -> BindError {
        BindError::Conversion(e)
    }
}

/// Lets `?` lift encoding errors into `BindError::Serialization`.
impl From<io::Error> for BindError {
    #[inline]
    fn from(e: io::Error) -> BindError {
        BindError::Serialization(e)
    }
}
46
47#[inline]
48pub fn bind<I, J, F, T, K>(
49    portal: &str,
50    statement: &str,
51    formats: I,
52    values: J,
53    mut serializer: F,
54    result_formats: K,
55    buf: &mut BytesMut,
56) -> Result<(), BindError>
57where
58    I: IntoIterator<Item = i16>,
59    J: IntoIterator<Item = T>,
60    F: FnMut(T, &mut BytesMut) -> Result<IsNull, Box<dyn Error + marker::Sync + Send>>,
61    K: IntoIterator<Item = i16>,
62{
63    buf.put_u8(b'B');
64
65    write_body(buf, |buf| {
66        write_cstr(portal.as_bytes(), buf)?;
67        write_cstr(statement.as_bytes(), buf)?;
68        write_counted(
69            formats,
70            |f, buf| {
71                buf.put_i16(f);
72                Ok::<_, io::Error>(())
73            },
74            buf,
75        )?;
76        write_counted(
77            values,
78            |v, buf| write_nullable(|buf| serializer(v, buf), buf),
79            buf,
80        )?;
81        write_counted(
82            result_formats,
83            |f, buf| {
84                buf.put_i16(f);
85                Ok::<_, io::Error>(())
86            },
87            buf,
88        )?;
89
90        Ok(())
91    })
92}
93
94#[inline]
95fn write_counted<I, T, F, E>(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E>
96where
97    I: IntoIterator<Item = T>,
98    F: FnMut(T, &mut BytesMut) -> Result<(), E>,
99    E: From<io::Error>,
100{
101    let base = buf.len();
102    buf.extend_from_slice(&[0; 2]);
103    let mut count = 0;
104    for item in items {
105        serializer(item, buf)?;
106        count += 1;
107    }
108    let count = i16::from_usize(count)?;
109    BigEndian::write_i16(&mut buf[base..], count);
110
111    Ok(())
112}
113
114#[inline]
115pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) {
116    write_body(buf, |buf| {
117        buf.put_i32(80_877_102);
118        buf.put_i32(process_id);
119        buf.put_i32(secret_key);
120        Ok::<_, io::Error>(())
121    })
122    .unwrap();
123}
124
125#[inline]
126pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
127    buf.put_u8(b'C');
128    write_body(buf, |buf| {
129        buf.put_u8(variant);
130        write_cstr(name.as_bytes(), buf)
131    })
132}
133
134pub struct CopyData<T> {
135    buf: T,
136    len: i32,
137}
138
139impl<T> CopyData<T>
140where
141    T: Buf,
142{
143    pub fn new(buf: T) -> io::Result<CopyData<T>> {
144        let len = buf
145            .remaining()
146            .checked_add(4)
147            .and_then(|l| i32::try_from(l).ok())
148            .ok_or_else(|| {
149                io::Error::new(io::ErrorKind::InvalidInput, "message length overflow")
150            })?;
151
152        Ok(CopyData { buf, len })
153    }
154
155    pub fn write(self, out: &mut BytesMut) {
156        out.put_u8(b'd');
157        out.put_i32(self.len);
158        out.put(self.buf);
159    }
160}
161
162#[inline]
163pub fn copy_done(buf: &mut BytesMut) {
164    buf.put_u8(b'c');
165    write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
166}
167
168#[inline]
169pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> {
170    buf.put_u8(b'f');
171    write_body(buf, |buf| write_cstr(message.as_bytes(), buf))
172}
173
174#[inline]
175pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
176    buf.put_u8(b'D');
177    write_body(buf, |buf| {
178        buf.put_u8(variant);
179        write_cstr(name.as_bytes(), buf)
180    })
181}
182
183#[inline]
184pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> io::Result<()> {
185    buf.put_u8(b'E');
186    write_body(buf, |buf| {
187        write_cstr(portal.as_bytes(), buf)?;
188        buf.put_i32(max_rows);
189        Ok(())
190    })
191}
192
193#[inline]
194pub fn parse<I>(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()>
195where
196    I: IntoIterator<Item = Oid>,
197{
198    buf.put_u8(b'P');
199    write_body(buf, |buf| {
200        write_cstr(name.as_bytes(), buf)?;
201        write_cstr(query.as_bytes(), buf)?;
202        write_counted(
203            param_types,
204            |t, buf| {
205                buf.put_u32(t);
206                Ok::<_, io::Error>(())
207            },
208            buf,
209        )?;
210        Ok(())
211    })
212}
213
214#[inline]
215pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> {
216    buf.put_u8(b'p');
217    write_body(buf, |buf| write_cstr(password, buf))
218}
219
220#[inline]
221pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> {
222    buf.put_u8(b'Q');
223    write_body(buf, |buf| write_cstr(query.as_bytes(), buf))
224}
225
226#[inline]
227pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
228    buf.put_u8(b'p');
229    write_body(buf, |buf| {
230        write_cstr(mechanism.as_bytes(), buf)?;
231        let len = i32::from_usize(data.len())?;
232        buf.put_i32(len);
233        buf.put_slice(data);
234        Ok(())
235    })
236}
237
238#[inline]
239pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
240    buf.put_u8(b'p');
241    write_body(buf, |buf| {
242        buf.put_slice(data);
243        Ok(())
244    })
245}
246
247#[inline]
248pub fn ssl_request(buf: &mut BytesMut) {
249    write_body(buf, |buf| {
250        buf.put_i32(80_877_103);
251        Ok::<_, io::Error>(())
252    })
253    .unwrap();
254}
255
256#[inline]
257pub fn startup_message<'a, I>(parameters: I, buf: &mut BytesMut) -> io::Result<()>
258where
259    I: IntoIterator<Item = (&'a str, &'a str)>,
260{
261    write_body(buf, |buf| {
262        // postgres protocol version 3.0(196608) in bigger-endian
263        buf.put_i32(0x00_03_00_00);
264        for (key, value) in parameters {
265            write_cstr(key.as_bytes(), buf)?;
266            write_cstr(value.as_bytes(), buf)?;
267        }
268        buf.put_u8(0);
269        Ok(())
270    })
271}
272
273#[inline]
274pub fn flush(buf: &mut BytesMut) {
275    buf.put_u8(b'H');
276    write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
277}
278
279#[inline]
280pub fn sync(buf: &mut BytesMut) {
281    buf.put_u8(b'S');
282    write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
283}
284
285#[inline]
286pub fn terminate(buf: &mut BytesMut) {
287    buf.put_u8(b'X');
288    write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
289}
290
291#[inline]
292fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
293    if s.contains(&0) {
294        return Err(io::Error::new(
295            io::ErrorKind::InvalidInput,
296            "string contains embedded null",
297        ));
298    }
299    buf.put_slice(s);
300    buf.put_u8(0);
301    Ok(())
302}