sqlx_postgres/
arguments.rs1use std::fmt::{self, Write};
2use std::ops::{Deref, DerefMut};
3use std::sync::Arc;
4
5use crate::encode::{Encode, IsNull};
6use crate::error::Error;
7use crate::types::Type;
8use crate::{PgConnection, PgTypeInfo, Postgres};
9
10pub(crate) use sqlx_core::arguments::Arguments;
11use sqlx_core::error::BoxDynError;
12
13#[derive(Default, Debug, Clone)]
24pub struct PgArgumentBuffer {
25 buffer: Vec<u8>,
26
27 count: usize,
29
30 patches: Vec<Patch>,
36
37 hole_offsets: Vec<usize>,
45 hole_types: Vec<PgTypeInfo>,
48}
49
50#[derive(Clone)]
51struct Patch {
52 buf_offset: usize,
53 arg_index: usize,
54 #[allow(clippy::type_complexity)]
55 callback: Arc<dyn Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync>,
56}
57
58impl fmt::Debug for Patch {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 f.debug_struct("Patch")
61 .field("buf_offset", &self.buf_offset)
62 .field("arg_index", &self.arg_index)
63 .field("callback", &"<callback>")
64 .finish()
65 }
66}
67
68#[derive(Default, Debug, Clone)]
70pub struct PgArguments {
71 pub(crate) types: Vec<PgTypeInfo>,
73
74 pub(crate) buffer: PgArgumentBuffer,
76}
77
78impl PgArguments {
79 pub(crate) fn add<'q, T>(&mut self, value: T) -> Result<(), BoxDynError>
80 where
81 T: Encode<'q, Postgres> + Type<Postgres>,
82 {
83 let type_info = value.produces().unwrap_or_else(T::type_info);
84
85 let buffer_snapshot = self.buffer.snapshot();
86
87 if let Err(error) = self.buffer.encode(value) {
89 self.buffer.reset_to_snapshot(buffer_snapshot);
92 return Err(error);
93 };
94
95 self.types.push(type_info);
97 self.buffer.count += 1;
99
100 Ok(())
101 }
102
103 pub(crate) async fn apply_patches(
106 &mut self,
107 conn: &mut PgConnection,
108 parameters: &[PgTypeInfo],
109 ) -> Result<(), Error> {
110 let PgArgumentBuffer {
111 ref patches,
112 ref hole_types,
113 ref hole_offsets,
114 ref mut buffer,
115 ..
116 } = self.buffer;
117
118 for patch in patches {
119 let buf = &mut buffer[patch.buf_offset..];
120 let ty = ¶meters[patch.arg_index];
121
122 (patch.callback)(buf, ty);
123 }
124
125 let resolved_holes = conn.resolve_types(hole_types).await?;
126
127 for (&offset, oid) in hole_offsets.iter().zip(resolved_holes) {
128 buffer[offset..][..4].copy_from_slice(&oid.0.to_be_bytes());
129 }
130
131 Ok(())
132 }
133}
134
135impl Arguments for PgArguments {
136 type Database = Postgres;
137
138 fn reserve(&mut self, additional: usize, size: usize) {
139 self.types.reserve(additional);
140 self.buffer.reserve(size);
141 }
142
143 fn add<'t, T>(&mut self, value: T) -> Result<(), BoxDynError>
144 where
145 T: Encode<'t, Self::Database> + Type<Self::Database>,
146 {
147 self.add(value)
148 }
149
150 fn format_placeholder<W: Write>(&self, writer: &mut W) -> fmt::Result {
151 write!(writer, "${}", self.buffer.count)
152 }
153
154 #[inline(always)]
155 fn len(&self) -> usize {
156 self.buffer.count
157 }
158}
159
160impl PgArgumentBuffer {
161 pub(crate) fn encode<'q, T>(&mut self, value: T) -> Result<(), BoxDynError>
162 where
163 T: Encode<'q, Postgres>,
164 {
165 value_size_int4_checked(value.size_hint())?;
167
168 let offset = self.len();
170
171 self.extend(&[0; 4]);
172
173 let len = if let IsNull::No = value.encode(self)? {
175 value_size_int4_checked(self.len() - offset - 4)?
177 } else {
178 debug_assert_eq!(self.len(), offset + 4);
181 -1_i32
182 };
183
184 self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes());
187
188 Ok(())
189 }
190
191 #[cfg_attr(not(feature = "json"), expect(dead_code))]
193 pub(crate) fn patch_with<F>(&mut self, callback: F)
194 where
195 F: Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync,
196 {
197 let offset = self.len();
198 let arg_index = self.count;
199
200 self.patches.push(Patch {
201 buf_offset: offset,
202 arg_index,
203 callback: Arc::new(callback),
204 });
205 }
206
207 pub(crate) fn push_hole(&mut self, type_info: PgTypeInfo) {
210 let offset = self.len();
211
212 self.extend_from_slice(&0_u32.to_be_bytes());
213 self.hole_offsets.push(offset);
214 self.hole_types.push(type_info);
215 }
216
217 fn snapshot(&self) -> PgArgumentBufferSnapshot {
218 let Self {
219 buffer,
220 count,
221 patches,
222 hole_offsets,
223 ..
224 } = self;
225
226 PgArgumentBufferSnapshot {
227 buffer_length: buffer.len(),
228 count: *count,
229 patches_length: patches.len(),
230 type_holes_length: hole_offsets.len(),
231 }
232 }
233
234 fn reset_to_snapshot(
235 &mut self,
236 PgArgumentBufferSnapshot {
237 buffer_length,
238 count,
239 patches_length,
240 type_holes_length,
241 }: PgArgumentBufferSnapshot,
242 ) {
243 self.buffer.truncate(buffer_length);
244 self.count = count;
245 self.patches.truncate(patches_length);
246 self.hole_offsets.truncate(type_holes_length);
247 self.hole_types.truncate(type_holes_length);
248 }
249}
250
/// Lengths of a `PgArgumentBuffer`'s growable parts at one point in time, used to roll the
/// buffer back after a failed encode.
struct PgArgumentBufferSnapshot {
    /// Length of the encoded byte buffer.
    buffer_length: usize,
    /// Argument count.
    count: usize,
    /// Number of registered patches.
    patches_length: usize,
    /// Number of type holes (shared by the offset and type lists).
    type_holes_length: usize,
}
257
258impl Deref for PgArgumentBuffer {
259 type Target = Vec<u8>;
260
261 #[inline]
262 fn deref(&self) -> &Self::Target {
263 &self.buffer
264 }
265}
266
267impl DerefMut for PgArgumentBuffer {
268 #[inline]
269 fn deref_mut(&mut self) -> &mut Self::Target {
270 &mut self.buffer
271 }
272}
273
/// Converts a byte length into the `i32` used for length prefixes in the binary protocol
/// encoding.
///
/// # Errors
///
/// Returns a human-readable message if `size` is larger than `i32::MAX`.
pub(crate) fn value_size_int4_checked(size: usize) -> Result<i32, String> {
    match i32::try_from(size) {
        Ok(len) => Ok(len),
        Err(_) => Err(format!(
            "value size would overflow in the binary protocol encoding: {size} > {}",
            i32::MAX
        )),
    }
}