1use std::convert::{TryFrom, TryInto};
6use std::ops::Deref;
7use std::sync::Arc;
8
9use bytes::{BufMut, BytesMut};
10use snafu::OptionExt;
11use uuid::Uuid;
12
13use gel_errors::ParameterTypeMismatchError;
14use gel_errors::{ClientEncodingError, DescriptorMismatch, ProtocolError};
15use gel_errors::{Error, ErrorKind, InvalidReferenceError};
16
17use crate::codec::{self, build_codec, Codec};
18use crate::descriptors::TypePos;
19use crate::descriptors::{Descriptor, EnumerationTypeDescriptor};
20use crate::errors;
21use crate::features::ProtocolVersion;
22use crate::model::range;
23use crate::value::Value;
24
/// Writes query arguments into an output buffer.
///
/// `ctx` supplies the argument type descriptors received from the server, and
/// `buf` is the byte buffer the encoded arguments are appended to.
pub struct Encoder<'a> {
    pub ctx: &'a DescriptorContext<'a>,
    pub buf: &'a mut BytesMut,
}
29
/// A single value that can be passed as one query argument (one slot of the
/// positional argument list).
pub trait QueryArg: Send + Sync + Sized {
    /// Writes this value as one argument slot, including its length prefix.
    /// The impls in this module write a `-1` length for a missing value.
    fn encode_slot(&self, encoder: &mut Encoder) -> Result<(), Error>;
    /// Checks that this value is compatible with the type descriptor at `pos`.
    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error>;
    /// Converts this argument into a dynamic [`Value`].
    fn to_value(&self) -> Result<Value, Error>;
}
36
/// A scalar type that can be encoded as a query argument.
///
/// Every `ScalarArg` is also a [`QueryArg`] through the blanket impl in this
/// module, which wraps [`ScalarArg::encode`] in a length prefix.
pub trait ScalarArg: Send + Sync + Sized {
    /// Writes the raw value payload, without any length prefix.
    fn encode(&self, encoder: &mut Encoder) -> Result<(), Error>;
    /// Checks that the type descriptor at `pos` matches this scalar type.
    /// Unlike [`QueryArg::check_descriptor`], no value instance is needed.
    fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error>;
    /// Converts this scalar into a dynamic [`Value`].
    fn to_value(&self) -> Result<Value, Error>;
}
42
/// The complete argument list of a query, e.g. `()`, a tuple of
/// [`QueryArg`]s, or a dynamic [`Value`].
pub trait QueryArgs: Send + Sync {
    /// Writes the whole argument block into `encoder.buf`. The impls in this
    /// module check the arguments against, or encode them through, the
    /// server-provided input descriptors in `encoder.ctx`.
    fn encode(&self, encoder: &mut Encoder) -> Result<(), Error>;
}
51
52pub struct DescriptorContext<'a> {
53 #[allow(dead_code)]
54 pub(crate) proto: &'a ProtocolVersion,
55 pub(crate) root_pos: Option<TypePos>,
56 pub(crate) descriptors: &'a [Descriptor],
57}
58
59impl<'a> Encoder<'a> {
60 pub fn new(ctx: &'a DescriptorContext<'a>, buf: &'a mut BytesMut) -> Encoder<'a> {
61 Encoder { ctx, buf }
62 }
63 pub fn length_prefixed(
64 &mut self,
65 f: impl FnOnce(&mut Encoder) -> Result<(), Error>,
66 ) -> Result<(), Error> {
67 self.buf.reserve(4);
68 let pos = self.buf.len();
69 self.buf.put_u32(0); f(self)?;
72
73 let len = self.buf.len() - pos - 4;
74 self.buf[pos..pos + 4].copy_from_slice(
75 &u32::try_from(len)
76 .map_err(|_| ClientEncodingError::with_message("alias is too long"))?
77 .to_be_bytes(),
78 );
79
80 Ok(())
81 }
82}
83
84impl DescriptorContext<'_> {
85 pub fn get(&self, type_pos: TypePos) -> Result<&Descriptor, Error> {
86 self.descriptors
87 .get(type_pos.0 as usize)
88 .ok_or_else(|| ProtocolError::with_message("invalid type descriptor"))
89 }
90 pub fn build_codec(&self) -> Result<Arc<dyn Codec>, Error> {
91 build_codec(self.root_pos, self.descriptors)
92 .map_err(|e| ProtocolError::with_source(e).context("error decoding input codec"))
93 }
94 pub fn wrong_type(&self, descriptor: &Descriptor, expected: &str) -> Error {
95 DescriptorMismatch::with_message(format!(
96 "server returned unexpected type {descriptor:?} when client expected {expected}"
97 ))
98 }
99 pub fn field_number(&self, expected: usize, unexpected: usize) -> Error {
100 DescriptorMismatch::with_message(format!(
101 "expected {} fields, got {}",
102 expected, unexpected
103 ))
104 }
105}
106
107impl<T: ScalarArg> ScalarArg for &T {
108 fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> {
109 (*self).encode(encoder)
110 }
111
112 fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
113 T::check_descriptor(ctx, pos)
114 }
115
116 fn to_value(&self) -> Result<Value, Error> {
117 (*self).to_value()
118 }
119}
120
121impl QueryArgs for () {
122 fn encode(&self, enc: &mut Encoder) -> Result<(), Error> {
123 if enc.ctx.root_pos.is_some() {
124 if enc.ctx.proto.is_at_most(0, 11) {
125 let root = enc.ctx.root_pos.and_then(|p| enc.ctx.get(p).ok());
126 match root {
127 Some(Descriptor::Tuple(t))
128 if t.id == Uuid::from_u128(0xFF) && t.element_types.is_empty() => {}
129 _ => {
130 return Err(ParameterTypeMismatchError::with_message(
131 "query arguments expected",
132 ))
133 }
134 };
135 } else {
136 return Err(ParameterTypeMismatchError::with_message(
137 "query arguments expected",
138 ));
139 }
140 }
141 if enc.ctx.proto.is_at_most(0, 11) {
142 enc.buf.reserve(4);
143 enc.buf.put_u32(0);
144 }
145 Ok(())
146 }
147}
148
149impl QueryArg for Value {
150 fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
151 use Value::*;
152 match self {
153 Nothing => {
154 enc.buf.reserve(4);
155 enc.buf.put_i32(-1);
156 }
157 Uuid(v) => v.encode_slot(enc)?,
158 Str(v) => v.encode_slot(enc)?,
159 Bytes(v) => v.encode_slot(enc)?,
160 Int16(v) => v.encode_slot(enc)?,
161 Int32(v) => v.encode_slot(enc)?,
162 Int64(v) => v.encode_slot(enc)?,
163 Float32(v) => v.encode_slot(enc)?,
164 Float64(v) => v.encode_slot(enc)?,
165 BigInt(v) => v.encode_slot(enc)?,
166 ConfigMemory(v) => v.encode_slot(enc)?,
167 Decimal(v) => v.encode_slot(enc)?,
168 Bool(v) => v.encode_slot(enc)?,
169 Datetime(v) => v.encode_slot(enc)?,
170 LocalDatetime(v) => v.encode_slot(enc)?,
171 LocalDate(v) => v.encode_slot(enc)?,
172 LocalTime(v) => v.encode_slot(enc)?,
173 Duration(v) => v.encode_slot(enc)?,
174 RelativeDuration(v) => v.encode_slot(enc)?,
175 DateDuration(v) => v.encode_slot(enc)?,
176 Json(v) => v.encode_slot(enc)?,
177 Set(_) => {
178 return Err(ClientEncodingError::with_message(
179 "set cannot be query argument",
180 ))
181 }
182 Object { .. } => {
183 return Err(ClientEncodingError::with_message(
184 "object cannot be query argument",
185 ))
186 }
187 SparseObject(_) => {
188 return Err(ClientEncodingError::with_message(
189 "sparse object cannot be query argument",
190 ))
191 }
192 Tuple(_) => {
193 return Err(ClientEncodingError::with_message(
194 "tuple object cannot be query argument",
195 ))
196 }
197 NamedTuple { .. } => {
198 return Err(ClientEncodingError::with_message(
199 "named tuple object cannot be query argument",
200 ))
201 }
202 Array(v) => v.encode_slot(enc)?,
203 Enum(v) => v.encode_slot(enc)?,
204 Range(v) => v.encode_slot(enc)?,
205 Vector(v) => crate::model::VectorRef(v).encode_slot(enc)?,
206 PostGisGeometry(v) => v.encode_slot(enc)?,
207 PostGisGeography(v) => v.encode_slot(enc)?,
208 PostGisBox2d(v) => v.encode_slot(enc)?,
209 PostGisBox3d(v) => v.encode_slot(enc)?,
210 SQLRow { .. } => {
211 return Err(ClientEncodingError::with_message(
212 "SQL row cannot be query argument",
213 ))
214 }
215 }
216
217 Ok(())
218 }
219 fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
220 use Descriptor::*;
221 use Value::*;
222 let desc = ctx.get(pos)?.normalize_to_base(ctx)?;
223
224 match (self, desc) {
225 (Nothing, _) => Ok(()), (BigInt(_), BaseScalar(d)) if d.id == codec::STD_BIGINT => Ok(()),
227 (Bool(_), BaseScalar(d)) if d.id == codec::STD_BOOL => Ok(()),
228 (Bytes(_), BaseScalar(d)) if d.id == codec::STD_BYTES => Ok(()),
229 (ConfigMemory(_), BaseScalar(d)) if d.id == codec::CFG_MEMORY => Ok(()),
230 (DateDuration(_), BaseScalar(d)) if d.id == codec::CAL_DATE_DURATION => Ok(()),
231 (Datetime(_), BaseScalar(d)) if d.id == codec::STD_DATETIME => Ok(()),
232 (Decimal(_), BaseScalar(d)) if d.id == codec::STD_DECIMAL => Ok(()),
233 (Duration(_), BaseScalar(d)) if d.id == codec::STD_DURATION => Ok(()),
234 (Float32(_), BaseScalar(d)) if d.id == codec::STD_FLOAT32 => Ok(()),
235 (Float64(_), BaseScalar(d)) if d.id == codec::STD_FLOAT64 => Ok(()),
236 (Int16(_), BaseScalar(d)) if d.id == codec::STD_INT16 => Ok(()),
237 (Int32(_), BaseScalar(d)) if d.id == codec::STD_INT32 => Ok(()),
238 (Int64(_), BaseScalar(d)) if d.id == codec::STD_INT64 => Ok(()),
239 (Json(_), BaseScalar(d)) if d.id == codec::STD_JSON => Ok(()),
240 (LocalDate(_), BaseScalar(d)) if d.id == codec::CAL_LOCAL_DATE => Ok(()),
241 (LocalDatetime(_), BaseScalar(d)) if d.id == codec::CAL_LOCAL_DATETIME => Ok(()),
242 (LocalTime(_), BaseScalar(d)) if d.id == codec::CAL_LOCAL_TIME => Ok(()),
243 (RelativeDuration(_), BaseScalar(d)) if d.id == codec::CAL_RELATIVE_DURATION => Ok(()),
244 (Str(_), BaseScalar(d)) if d.id == codec::STD_STR => Ok(()),
245 (Uuid(_), BaseScalar(d)) if d.id == codec::STD_UUID => Ok(()),
246 (Enum(val), Enumeration(EnumerationTypeDescriptor { members, .. })) => {
247 let val = val.deref();
248 check_enum(val, &members)
249 }
250 (Vector(_), BaseScalar(d)) if d.id == codec::PGVECTOR_VECTOR => Ok(()),
251 (PostGisGeometry(_), BaseScalar(d)) if d.id == codec::POSTGIS_GEOMETRY => Ok(()),
252 (PostGisGeography(_), BaseScalar(d)) if d.id == codec::POSTGIS_GEOGRAPHY => Ok(()),
253 (PostGisBox2d(_), BaseScalar(d)) if d.id == codec::POSTGIS_BOX_2D => Ok(()),
254 (PostGisBox3d(_), BaseScalar(d)) if d.id == codec::POSTGIS_BOX_3D => Ok(()),
255 (_, desc) => Err(ctx.wrong_type(&desc, self.kind())),
257 }
258 }
259 fn to_value(&self) -> Result<Value, Error> {
260 Ok(self.clone())
261 }
262}
263
264pub(crate) fn check_enum(variant_name: &str, expected_members: &[String]) -> Result<(), Error> {
265 if expected_members.iter().any(|c| c == variant_name) {
266 Ok(())
267 } else {
268 let mut members = expected_members
269 .iter()
270 .map(|c| format!("'{c}'"))
271 .collect::<Vec<_>>();
272 members.sort_unstable();
273 let members = members.join(", ");
274 Err(InvalidReferenceError::with_message(format!(
275 "Expected one of: {members}, while enum value '{variant_name}' was provided"
276 )))
277 }
278}
279
280impl QueryArgs for Value {
281 fn encode(&self, enc: &mut Encoder) -> Result<(), Error> {
282 let codec = enc.ctx.build_codec()?;
283 codec
284 .encode(enc.buf, self)
285 .map_err(ClientEncodingError::with_source)
286 }
287}
288
289impl<T: ScalarArg> QueryArg for T {
290 fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
291 enc.buf.reserve(4);
292 let pos = enc.buf.len();
293 enc.buf.put_u32(0); ScalarArg::encode(self, enc)?;
295 let len = enc.buf.len() - pos - 4;
296 enc.buf[pos..pos + 4].copy_from_slice(
297 &i32::try_from(len)
298 .ok()
299 .context(errors::ElementTooLong)
300 .map_err(ClientEncodingError::with_source)?
301 .to_be_bytes(),
302 );
303 Ok(())
304 }
305 fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
306 T::check_descriptor(ctx, pos)
307 }
308 fn to_value(&self) -> Result<Value, Error> {
309 ScalarArg::to_value(self)
310 }
311}
312
313impl<T: ScalarArg> QueryArg for Option<T> {
314 fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
315 if let Some(val) = self {
316 QueryArg::encode_slot(val, enc)
317 } else {
318 enc.buf.reserve(4);
319 enc.buf.put_i32(-1);
320 Ok(())
321 }
322 }
323 fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
324 T::check_descriptor(ctx, pos)
325 }
326 fn to_value(&self) -> Result<Value, Error> {
327 match self.as_ref() {
328 Some(v) => ScalarArg::to_value(v),
329 None => Ok(Value::Nothing),
330 }
331 }
332}
333
334impl<T: ScalarArg> QueryArg for Vec<T> {
335 fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
336 enc.buf.reserve(8);
337 enc.length_prefixed(|enc| {
338 if self.is_empty() {
339 enc.buf.reserve(12);
340 enc.buf.put_u32(0); enc.buf.put_u32(0); enc.buf.put_u32(0); return Ok(());
344 }
345 enc.buf.reserve(20);
346 enc.buf.put_u32(1); enc.buf.put_u32(0); enc.buf.put_u32(0); enc.buf.put_u32(
350 self.len()
351 .try_into()
352 .map_err(|_| ClientEncodingError::with_message("array is too long"))?,
353 );
354 enc.buf.put_u32(1); for item in self {
356 enc.length_prefixed(|enc| item.encode(enc))?;
357 }
358 Ok(())
359 })
360 }
361 fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
362 let desc = ctx.get(pos)?;
363 if let Descriptor::Array(arr) = desc {
364 T::check_descriptor(ctx, arr.type_pos)
365 } else {
366 Err(ctx.wrong_type(desc, "array"))
367 }
368 }
369 fn to_value(&self) -> Result<Value, Error> {
370 Ok(Value::Array(
371 self.iter()
372 .map(|v| v.to_value())
373 .collect::<Result<_, _>>()?,
374 ))
375 }
376}
377
378impl QueryArg for Vec<Value> {
379 fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
380 enc.buf.reserve(8);
381 enc.length_prefixed(|enc| {
382 if self.is_empty() {
383 enc.buf.reserve(12);
384 enc.buf.put_u32(0); enc.buf.put_u32(0); enc.buf.put_u32(0); return Ok(());
388 }
389 enc.buf.reserve(20);
390 enc.buf.put_u32(1); enc.buf.put_u32(0); enc.buf.put_u32(0); enc.buf.put_u32(
394 self.len()
395 .try_into()
396 .map_err(|_| ClientEncodingError::with_message("array is too long"))?,
397 );
398 enc.buf.put_u32(1); for item in self {
400 enc.length_prefixed(|enc| item.encode(enc))?;
401 }
402 Ok(())
403 })
404 }
405 fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
406 let desc = ctx.get(pos)?;
407 if let Descriptor::Array(arr) = desc {
408 for val in self {
409 val.check_descriptor(ctx, arr.type_pos)?
410 }
411 Ok(())
412 } else {
413 Err(ctx.wrong_type(desc, "array"))
414 }
415 }
416 fn to_value(&self) -> Result<Value, Error> {
417 Ok(Value::Array(
418 self.iter()
419 .map(|v| v.to_value())
420 .collect::<Result<_, _>>()?,
421 ))
422 }
423}
424
425impl QueryArg for range::Range<Box<Value>> {
426 fn encode_slot(&self, encoder: &mut Encoder) -> Result<(), Error> {
427 encoder.length_prefixed(|encoder| {
428 let flags = if self.empty {
429 range::EMPTY
430 } else {
431 (if self.inc_lower { range::LB_INC } else { 0 })
432 | (if self.inc_upper { range::UB_INC } else { 0 })
433 | (if self.lower.is_none() {
434 range::LB_INF
435 } else {
436 0
437 })
438 | (if self.upper.is_none() {
439 range::UB_INF
440 } else {
441 0
442 })
443 };
444 encoder.buf.reserve(1);
445 encoder.buf.put_u8(flags as u8);
446
447 if let Some(lower) = &self.lower {
448 encoder.length_prefixed(|encoder| lower.encode(encoder))?
449 }
450
451 if let Some(upper) = &self.upper {
452 encoder.length_prefixed(|encoder| upper.encode(encoder))?;
453 }
454 Ok(())
455 })
456 }
457 fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
458 let desc = ctx.get(pos)?;
459 if let Descriptor::Range(rng) = desc {
460 self.lower
461 .as_ref()
462 .map(|v| v.check_descriptor(ctx, rng.type_pos))
463 .transpose()?;
464 self.upper
465 .as_ref()
466 .map(|v| v.check_descriptor(ctx, rng.type_pos))
467 .transpose()?;
468 Ok(())
469 } else {
470 Err(ctx.wrong_type(desc, "range"))
471 }
472 }
473 fn to_value(&self) -> Result<Value, Error> {
474 Ok(Value::Range(self.clone()))
475 }
476}
477
/// Implements [`QueryArgs`] for a tuple of `$count` positional [`QueryArg`]s.
///
/// The generated `encode` first validates the tuple against the root input
/// descriptor. On protocol >= 0.12 the root must be an `ObjectShape` whose
/// element names parse as "0", "1", ...; on protocol <= 0.11 it must be a
/// `Tuple`. It then writes the element count, followed by a zero `u32` and
/// one encoded slot for each element.
macro_rules! implement_tuple {
    ( $count:expr, $($name:ident,)+ ) => {
        impl<$($name:QueryArg),+> QueryArgs for ($($name,)+) {
            fn encode(&self, enc: &mut Encoder)
                -> Result<(), Error>
            {
                #![allow(non_snake_case)]
                // The type parameters (`T0`, `T1`, ...) double as binding
                // names below, hence the lint allowance above.

                // A missing root descriptor means the server expects no
                // arguments at all.
                let root_pos = enc.ctx.root_pos
                    .ok_or_else(|| DescriptorMismatch::with_message(
                        format!(
                            "provided {} positional arguments, \
                             but no arguments expected by the server",
                            $count)))?;
                let desc = enc.ctx.get(root_pos)?;
                match desc {
                    // Newer protocols: arguments are an object shape whose
                    // element names must be the positional indices.
                    Descriptor::ObjectShape(desc)
                        if enc.ctx.proto.is_at_least(0, 12)
                    => {
                        if desc.elements.len() != $count {
                            return Err(enc.ctx.field_number(
                                desc.elements.len(), $count));
                        }
                        let mut els = desc.elements.iter().enumerate();
                        let ($(ref $name,)+) = self;
                        $(
                            // The length check above guarantees one element
                            // per tuple member, so `unwrap` cannot fail.
                            let (idx, el) = els.next().unwrap();
                            if el.name.parse() != Ok(idx) {
                                return Err(DescriptorMismatch::with_message(
                                    format!("expected positional arguments, \
                                             got {} instead of {}",
                                        el.name, idx)));
                            }
                            $name.check_descriptor(enc.ctx, el.type_pos)?;
                        )+
                    }
                    // Legacy protocols: arguments are a plain tuple.
                    Descriptor::Tuple(desc) if enc.ctx.proto.is_at_most(0, 11)
                    => {
                        if desc.element_types.len() != $count {
                            return Err(enc.ctx.field_number(
                                desc.element_types.len(), $count));
                        }
                        let mut els = desc.element_types.iter();
                        let ($(ref $name,)+) = self;
                        $(
                            // Safe for the same reason as in the branch above.
                            let type_pos = els.next().unwrap();
                            $name.check_descriptor(enc.ctx, *type_pos)?;
                        )+
                    }
                    _ => return Err(enc.ctx.wrong_type(desc,
                        if enc.ctx.proto.is_at_least(0, 12) { "object" }
                        else { "tuple" }))
                }

                // Element count, then for each element a zero word
                // (presumably the protocol's reserved field) and its slot.
                enc.buf.reserve(4 + 8*$count);
                enc.buf.put_u32($count);
                let ($(ref $name,)+) = self;
                $(
                    enc.buf.reserve(8);
                    enc.buf.put_u32(0);
                    QueryArg::encode_slot($name, enc)?;
                )*
                Ok(())
            }
        }
    }
}
544
545implement_tuple! {1, T0, }
546implement_tuple! {2, T0, T1, }
547implement_tuple! {3, T0, T1, T2, }
548implement_tuple! {4, T0, T1, T2, T3, }
549implement_tuple! {5, T0, T1, T2, T3, T4, }
550implement_tuple! {6, T0, T1, T2, T3, T4, T5, }
551implement_tuple! {7, T0, T1, T2, T3, T4, T5, T6, }
552implement_tuple! {8, T0, T1, T2, T3, T4, T5, T6, T7, }
553implement_tuple! {9, T0, T1, T2, T3, T4, T5, T6, T7, T8, }
554implement_tuple! {10, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, }
555implement_tuple! {11, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, }
556implement_tuple! {12, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, }