cdbc_pg/
arguments.rs

1use std::ops::{Deref, DerefMut};
2
3use cdbc::arguments::Arguments;
4use cdbc::encode::{Encode, IsNull};
5use cdbc::error::Error;
6use cdbc::utils::ustr::UStr;
7use crate::{PgConnection, PgTypeInfo, Postgres};
8use cdbc::types::Type;
9
10// TODO: buf.patch(|| ...) is a poor name, can we think of a better name? Maybe `buf.lazy(||)` ?
11// TODO: Extend the patch system to support dynamic lengths
12//       Considerations:
13//          - The prefixed-len offset needs to be back-tracked and updated
14//          - message::Bind needs to take a &PgArguments and use a `write` method instead of
15//            referencing a buffer directly
16//          - The basic idea is that we write bytes for the buffer until we get somewhere
17//            that has a patch, we then apply the patch which should write to &mut Vec<u8>,
18//            backtrack and update the prefixed-len, then write until the next patch offset
19
/// Byte buffer that accumulates the encoded bind parameters of a statement,
/// together with the bookkeeping needed to fix up parts of it later.
///
/// Each parameter is stored as a 4-byte big-endian length prefix followed by
/// the encoded value (see `encode`). The type dereferences to the inner
/// `Vec<u8>`, so `Encode` implementations can write to it like a plain vector.
#[derive(Default)]
pub struct PgArgumentBuffer {
    // Raw bytes of every parameter encoded so far
    buffer: Vec<u8>,

    // Number of arguments that have been fully encoded; while a value is being
    // encoded this is therefore the index of that value
    count: usize,

    // Whenever an `Encode` impl needs to defer some work until after we resolve parameter types
    // it can use `patch`.
    //
    // This is currently only set up to be useful if there is a *fixed-size* slot that needs to be
    // tweaked from the input type. However, that's the only use case we currently have.
    //
    // When patches are applied, the callback receives the buffer from `offset` to its end and
    // the resolved type of the argument at `argument index`.
    //
    patches: Vec<(
        usize, // offset
        usize, // argument index
        Box<dyn Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync>,
    )>,

    // Whenever an `Encode` impl encounters a `PgTypeInfo` object that does not have an OID
    // it pushes a "hole" that must be patched later.
    //
    // The hole is a `usize` offset into the buffer, paired with the type name that should be
    // resolved. Four zero bytes are written at that offset as a placeholder and are later
    // overwritten with the big-endian OID.
    //
    // This is done for Records and Arrays as the OID is needed well before we have a
    // connection at hand and can just ask postgres.
    //
    type_holes: Vec<(usize, UStr)>, // Vec<{ offset, type_name }>
}
48
/// Implementation of [`Arguments`] for PostgreSQL.
///
/// Collects bind parameters for a statement: the type of every parameter is
/// recorded in `types`, and the encoded values are appended to `buffer` in
/// the same order.
#[derive(Default)]
pub struct PgArguments {
    // Types of each bind parameter, one entry per added value, in insertion order
    pub(crate) types: Vec<PgTypeInfo>,

    // Buffer of encoded bind parameters (length-prefixed, in insertion order)
    pub(crate) buffer: PgArgumentBuffer,
}
58
59impl PgArguments {
60    pub(crate) fn add<'q, T>(&mut self, value: T)
61    where
62        T: Encode<'q, Postgres> + Type<Postgres>,
63    {
64        // remember the type information for this value
65        self.types
66            .push(value.produces().unwrap_or_else(T::type_info));
67
68        // encode the value into our buffer
69        self.buffer.encode(value);
70
71        // increment the number of arguments we are tracking
72        self.buffer.count += 1;
73    }
74
75    // Apply patches
76    // This should only go out and ask postgres if we have not seen the type name yet
77    pub(crate) fn apply_patches(
78        &mut self,
79        conn: &mut PgConnection,
80        parameters: &[PgTypeInfo],
81    ) -> Result<(), Error> {
82        let PgArgumentBuffer {
83            ref patches,
84            ref type_holes,
85            ref mut buffer,
86            ..
87        } = self.buffer;
88
89        for (offset, ty, callback) in patches {
90            let buf = &mut buffer[*offset..];
91            let ty = &parameters[*ty];
92
93            callback(buf, ty);
94        }
95
96        for (offset, name) in type_holes {
97            let oid = conn.fetch_type_id_by_name(&*name)?;
98            buffer[*offset..(*offset + 4)].copy_from_slice(&oid.to_be_bytes());
99        }
100
101        Ok(())
102    }
103}
104
105impl<'q> Arguments<'q> for PgArguments {
106    type Database = Postgres;
107
108    fn reserve(&mut self, additional: usize, size: usize) {
109        self.types.reserve(additional);
110        self.buffer.reserve(size);
111    }
112
113    fn add<T>(&mut self, value: T)
114    where
115        T: Encode<'q, Self::Database> + Type<Self::Database>,
116    {
117        self.add(value)
118    }
119}
120
121impl PgArgumentBuffer {
122    pub(crate) fn encode<'q, T>(&mut self, value: T)
123    where
124        T: Encode<'q, Postgres>,
125    {
126        // reserve space to write the prefixed length of the value
127        let offset = self.len();
128        self.extend(&[0; 4]);
129
130        // encode the value into our buffer
131        let len = if let IsNull::No = value.encode(self) {
132            (self.len() - offset - 4) as i32
133        } else {
134            // Write a -1 to indicate NULL
135            // NOTE: It is illegal for [encode] to write any data
136            debug_assert_eq!(self.len(), offset + 4);
137            -1_i32
138        };
139
140        // write the len to the beginning of the value
141        self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes());
142    }
143
144    // Adds a callback to be invoked later when we know the parameter type
145    #[allow(dead_code)]
146    pub(crate) fn patch<F>(&mut self, callback: F)
147    where
148        F: Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync,
149    {
150        let offset = self.len();
151        let index = self.count;
152
153        self.patches.push((offset, index, Box::new(callback)));
154    }
155
156    // Extends the inner buffer by enough space to have an OID
157    // Remembers where the OID goes and type name for the OID
158    pub(crate) fn patch_type_by_name(&mut self, type_name: &UStr) {
159        let offset = self.len();
160
161        self.extend_from_slice(&0_u32.to_be_bytes());
162        self.type_holes.push((offset, type_name.clone()));
163    }
164}
165
166impl Deref for PgArgumentBuffer {
167    type Target = Vec<u8>;
168
169    #[inline]
170    fn deref(&self) -> &Self::Target {
171        &self.buffer
172    }
173}
174
175impl DerefMut for PgArgumentBuffer {
176    #[inline]
177    fn deref_mut(&mut self) -> &mut Self::Target {
178        &mut self.buffer
179    }
180}