Skip to main content

sqlx_postgres/
arguments.rs

1use std::fmt::{self, Write};
2use std::ops::{Deref, DerefMut};
3use std::sync::Arc;
4
5use crate::encode::{Encode, IsNull};
6use crate::error::Error;
7use crate::types::Type;
8use crate::{PgConnection, PgTypeInfo, Postgres};
9
10pub(crate) use sqlx_core::arguments::Arguments;
11use sqlx_core::error::BoxDynError;
12
// TODO: `buf.patch_with(|| ...)` is a poor name; can we think of a better one? Maybe `buf.lazy(||)`?
// TODO: Extend the patch system to support dynamic lengths.
//       Considerations:
//          - The length-prefix offset needs to be back-tracked and updated.
//          - message::Bind needs to take a &PgArguments and use a `write` method instead of
//            referencing a buffer directly.
//          - The basic idea is that we write bytes from the buffer until we reach a position
//            that has a patch; we then apply the patch (which should write to a &mut Vec<u8>),
//            backtrack to update the length prefix, then keep writing until the next patch offset.
22
/// Buffer of encoded bind-parameter values for a PostgreSQL query.
///
/// Each value is stored as a 4-byte big-endian `i32` length prefix followed by the value's
/// bytes; a NULL is stored as a length of `-1` with no bytes following.
///
/// Besides the raw bytes, it tracks work that can only be finished once more information is
/// available: `patches` (which need the resolved parameter types) and type "holes" (which need
/// type OIDs looked up by name).
#[derive(Default, Debug, Clone)]
pub struct PgArgumentBuffer {
    // The encoded bytes (length prefixes + values).
    buffer: Vec<u8>,

    // Number of arguments
    count: usize,

    // Whenever an `Encode` impl needs to defer some work until after we resolve parameter types
    // it can use `patch_with`.
    //
    // This is currently only set up to be useful if there is a *fixed-size* slot that needs to be
    // tweaked based on the input type. However, that's the only use case we currently have.
    patches: Vec<Patch>,

    // Whenever an `Encode` impl encounters a `PgTypeInfo` object that does not have an OID,
    // it pushes a "hole" that must be patched later.
    //
    // The hole is a `usize` offset into the buffer; the type whose OID should be written there
    // is stored at the same index in `hole_types`.
    // This is done for Records and Arrays, because the OID is needed well before we are in an
    // async function that can just ask postgres.
    //
    hole_offsets: Vec<usize>,
    // Separate vector so that we don't have to generify or duplicate the logic in
    // `PgConnection::resolve_types()`. Always the same length as `hole_offsets`.
    hole_types: Vec<PgTypeInfo>,
}
49
/// A deferred edit to the encoded buffer, applied once the parameter types are known.
#[derive(Clone)]
struct Patch {
    // Offset into the buffer where the slice handed to `callback` begins.
    buf_offset: usize,
    // Index of the bind argument whose resolved type is handed to `callback`.
    arg_index: usize,
    // Receives the buffer from `buf_offset` to the end plus the argument's type.
    // `Arc` (rather than `Box`) is what lets `#[derive(Clone)]` work on this struct.
    // The `dyn Fn` type trips clippy's `type_complexity` lint; it is silenced here
    // instead of introducing a one-off type alias.
    #[allow(clippy::type_complexity)]
    callback: Arc<dyn Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync>,
}
57
58impl fmt::Debug for Patch {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        f.debug_struct("Patch")
61            .field("buf_offset", &self.buf_offset)
62            .field("arg_index", &self.arg_index)
63            .field("callback", &"<callback>")
64            .finish()
65    }
66}
67
/// Implementation of [`Arguments`] for PostgreSQL.
///
/// Holds one type entry per bound argument alongside the encoded argument bytes.
#[derive(Default, Debug, Clone)]
pub struct PgArguments {
    // Types of each bind parameter, in bind order (one entry per successful `add`).
    pub(crate) types: Vec<PgTypeInfo>,

    // Buffer of encoded bind parameters
    pub(crate) buffer: PgArgumentBuffer,
}
77
78impl PgArguments {
79    pub(crate) fn add<'q, T>(&mut self, value: T) -> Result<(), BoxDynError>
80    where
81        T: Encode<'q, Postgres> + Type<Postgres>,
82    {
83        let type_info = value.produces().unwrap_or_else(T::type_info);
84
85        let buffer_snapshot = self.buffer.snapshot();
86
87        // encode the value into our buffer
88        if let Err(error) = self.buffer.encode(value) {
89            // reset the value buffer to its previous value if encoding failed,
90            // so we don't leave a half-encoded value behind
91            self.buffer.reset_to_snapshot(buffer_snapshot);
92            return Err(error);
93        };
94
95        // remember the type information for this value
96        self.types.push(type_info);
97        // increment the number of arguments we are tracking
98        self.buffer.count += 1;
99
100        Ok(())
101    }
102
103    // Apply patches
104    // This should only go out and ask postgres if we have not seen the type name yet
105    pub(crate) async fn apply_patches(
106        &mut self,
107        conn: &mut PgConnection,
108        parameters: &[PgTypeInfo],
109    ) -> Result<(), Error> {
110        let PgArgumentBuffer {
111            ref patches,
112            ref hole_types,
113            ref hole_offsets,
114            ref mut buffer,
115            ..
116        } = self.buffer;
117
118        for patch in patches {
119            let buf = &mut buffer[patch.buf_offset..];
120            let ty = &parameters[patch.arg_index];
121
122            (patch.callback)(buf, ty);
123        }
124
125        let resolved_holes = conn.resolve_types(hole_types).await?;
126
127        for (&offset, oid) in hole_offsets.iter().zip(resolved_holes) {
128            buffer[offset..][..4].copy_from_slice(&oid.0.to_be_bytes());
129        }
130
131        Ok(())
132    }
133}
134
135impl Arguments for PgArguments {
136    type Database = Postgres;
137
138    fn reserve(&mut self, additional: usize, size: usize) {
139        self.types.reserve(additional);
140        self.buffer.reserve(size);
141    }
142
143    fn add<'t, T>(&mut self, value: T) -> Result<(), BoxDynError>
144    where
145        T: Encode<'t, Self::Database> + Type<Self::Database>,
146    {
147        self.add(value)
148    }
149
150    fn format_placeholder<W: Write>(&self, writer: &mut W) -> fmt::Result {
151        write!(writer, "${}", self.buffer.count)
152    }
153
154    #[inline(always)]
155    fn len(&self) -> usize {
156        self.buffer.count
157    }
158}
159
160impl PgArgumentBuffer {
161    pub(crate) fn encode<'q, T>(&mut self, value: T) -> Result<(), BoxDynError>
162    where
163        T: Encode<'q, Postgres>,
164    {
165        // Won't catch everything but is a good sanity check
166        value_size_int4_checked(value.size_hint())?;
167
168        // reserve space to write the prefixed length of the value
169        let offset = self.len();
170
171        self.extend(&[0; 4]);
172
173        // encode the value into our buffer
174        let len = if let IsNull::No = value.encode(self)? {
175            // Ensure that the value size does not overflow i32
176            value_size_int4_checked(self.len() - offset - 4)?
177        } else {
178            // Write a -1 to indicate NULL
179            // NOTE: It is illegal for [encode] to write any data
180            debug_assert_eq!(self.len(), offset + 4);
181            -1_i32
182        };
183
184        // write the len to the beginning of the value
185        // (offset + 4) cannot overflow because it would have failed at `self.extend()`.
186        self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes());
187
188        Ok(())
189    }
190
191    // Adds a callback to be invoked later when we know the parameter type
192    #[cfg_attr(not(feature = "json"), expect(dead_code))]
193    pub(crate) fn patch_with<F>(&mut self, callback: F)
194    where
195        F: Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync,
196    {
197        let offset = self.len();
198        let arg_index = self.count;
199
200        self.patches.push(Patch {
201            buf_offset: offset,
202            arg_index,
203            callback: Arc::new(callback),
204        });
205    }
206
207    // Extends the inner buffer by enough space to have an OID
208    // Remembers where the OID goes and type name for the OID
209    pub(crate) fn push_hole(&mut self, type_info: PgTypeInfo) {
210        let offset = self.len();
211
212        self.extend_from_slice(&0_u32.to_be_bytes());
213        self.hole_offsets.push(offset);
214        self.hole_types.push(type_info);
215    }
216
217    fn snapshot(&self) -> PgArgumentBufferSnapshot {
218        let Self {
219            buffer,
220            count,
221            patches,
222            hole_offsets,
223            ..
224        } = self;
225
226        PgArgumentBufferSnapshot {
227            buffer_length: buffer.len(),
228            count: *count,
229            patches_length: patches.len(),
230            type_holes_length: hole_offsets.len(),
231        }
232    }
233
234    fn reset_to_snapshot(
235        &mut self,
236        PgArgumentBufferSnapshot {
237            buffer_length,
238            count,
239            patches_length,
240            type_holes_length,
241        }: PgArgumentBufferSnapshot,
242    ) {
243        self.buffer.truncate(buffer_length);
244        self.count = count;
245        self.patches.truncate(patches_length);
246        self.hole_offsets.truncate(type_holes_length);
247        self.hole_types.truncate(type_holes_length);
248    }
249}
250
/// Lengths of a `PgArgumentBuffer`'s internal collections at one point in time.
///
/// Used by `PgArguments::add` to roll back a failed encode.
struct PgArgumentBufferSnapshot {
    // Length of the encoded byte buffer.
    buffer_length: usize,
    // Argument count.
    count: usize,
    // Number of registered patches.
    patches_length: usize,
    // Number of type holes; applies to both `hole_offsets` and `hole_types`,
    // which are always pushed and truncated together.
    type_holes_length: usize,
}
257
258impl Deref for PgArgumentBuffer {
259    type Target = Vec<u8>;
260
261    #[inline]
262    fn deref(&self) -> &Self::Target {
263        &self.buffer
264    }
265}
266
267impl DerefMut for PgArgumentBuffer {
268    #[inline]
269    fn deref_mut(&mut self) -> &mut Self::Target {
270        &mut self.buffer
271    }
272}
273
/// Converts a byte length to the `i32` used by the PostgreSQL binary protocol's
/// length prefix.
///
/// # Errors
///
/// Returns a descriptive message if `size` exceeds `i32::MAX`.
pub(crate) fn value_size_int4_checked(size: usize) -> Result<i32, String> {
    match i32::try_from(size) {
        Ok(len) => Ok(len),
        Err(_) => Err(format!(
            "value size would overflow in the binary protocol encoding: {size} > {}",
            i32::MAX
        )),
    }
}