gel_protocol/
query_arg.rs

/*!
Contains the [QueryArg], [ScalarArg] and [QueryArgs] traits, which check query
arguments against the server's input type descriptors and encode them.
*/
4
5use std::convert::{TryFrom, TryInto};
6use std::ops::Deref;
7use std::sync::Arc;
8
9use bytes::{BufMut, BytesMut};
10use snafu::OptionExt;
11use uuid::Uuid;
12
13use gel_errors::ParameterTypeMismatchError;
14use gel_errors::{ClientEncodingError, DescriptorMismatch, ProtocolError};
15use gel_errors::{Error, ErrorKind, InvalidReferenceError};
16
17use crate::codec::{self, build_codec, Codec};
18use crate::descriptors::TypePos;
19use crate::descriptors::{Descriptor, EnumerationTypeDescriptor};
20use crate::errors;
21use crate::features::ProtocolVersion;
22use crate::model::range;
23use crate::value::Value;
24
/// Output sink for query arguments.
///
/// Pairs the descriptor context consulted while encoding with the buffer the
/// encoded bytes are appended to.
pub struct Encoder<'a> {
    /// Descriptors (and protocol version) describing the expected input.
    pub ctx: &'a DescriptorContext<'a>,
    /// Destination buffer; encoders append to it.
    pub buf: &'a mut BytesMut,
}
29
/// A single argument for a query.
///
/// Implementors can validate themselves against a server-provided input type
/// descriptor and write their encoded form into an [`Encoder`].
pub trait QueryArg: Send + Sync + Sized {
    /// Writes the complete slot for this argument into `encoder.buf`.
    ///
    /// The implementations in this module emit a big-endian length header
    /// followed by the payload, or a bare `-1` header when there is no value.
    fn encode_slot(&self, encoder: &mut Encoder) -> Result<(), Error>;
    /// Checks that this argument is compatible with the descriptor at `pos`.
    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error>;
    /// Converts the argument into a dynamically typed [`Value`].
    fn to_value(&self) -> Result<Value, Error>;
}
36
/// A scalar query argument.
///
/// Every `ScalarArg` is also a [`QueryArg`]: a blanket impl in this module wraps
/// the output of [`ScalarArg::encode`] in a length-prefixed slot.
pub trait ScalarArg: Send + Sync + Sized {
    /// Writes only the value's payload (no length header) into `encoder.buf`.
    fn encode(&self, encoder: &mut Encoder) -> Result<(), Error>;
    /// Checks that the descriptor at `pos` matches this type.
    ///
    /// This is an associated function (no `self`), so the check depends only on
    /// the type, not on a particular value.
    fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error>;
    /// Converts the value into a dynamically typed [`Value`].
    fn to_value(&self) -> Result<Value, Error>;
}
42
/// A tuple of query arguments.
///
/// This trait is implemented for tuples of up to twelve elements. It can also be
/// derived for a structure, in which case the structure is treated as a named
/// tuple (i.e. the query should use named arguments rather than positional
/// ones).
pub trait QueryArgs: Send + Sync {
    /// Encodes the complete set of arguments into `encoder.buf`, according to
    /// the root input descriptor available through `encoder.ctx`.
    fn encode(&self, encoder: &mut Encoder) -> Result<(), Error>;
}
51
/// Input type descriptors received from the server, used to check and encode
/// query arguments.
pub struct DescriptorContext<'a> {
    /// Negotiated protocol version; selects between the legacy (<= 0.11) tuple
    /// argument layout and the newer object-shape layout.
    // NOTE(review): `proto` is read by the encoders in this module
    // (`is_at_most` / `is_at_least`), so this `allow` looks stale — confirm
    // before removing it.
    #[allow(dead_code)]
    pub(crate) proto: &'a ProtocolVersion,
    /// Position of the root input descriptor; `None` when the server expects no
    /// arguments.
    pub(crate) root_pos: Option<TypePos>,
    /// All descriptors, indexed by `TypePos`.
    pub(crate) descriptors: &'a [Descriptor],
}
58
59impl<'a> Encoder<'a> {
60    pub fn new(ctx: &'a DescriptorContext<'a>, buf: &'a mut BytesMut) -> Encoder<'a> {
61        Encoder { ctx, buf }
62    }
63    pub fn length_prefixed(
64        &mut self,
65        f: impl FnOnce(&mut Encoder) -> Result<(), Error>,
66    ) -> Result<(), Error> {
67        self.buf.reserve(4);
68        let pos = self.buf.len();
69        self.buf.put_u32(0); // replaced after serializing a value
70                             //
71        f(self)?;
72
73        let len = self.buf.len() - pos - 4;
74        self.buf[pos..pos + 4].copy_from_slice(
75            &u32::try_from(len)
76                .map_err(|_| ClientEncodingError::with_message("alias is too long"))?
77                .to_be_bytes(),
78        );
79
80        Ok(())
81    }
82}
83
84impl DescriptorContext<'_> {
85    pub fn get(&self, type_pos: TypePos) -> Result<&Descriptor, Error> {
86        self.descriptors
87            .get(type_pos.0 as usize)
88            .ok_or_else(|| ProtocolError::with_message("invalid type descriptor"))
89    }
90    pub fn build_codec(&self) -> Result<Arc<dyn Codec>, Error> {
91        build_codec(self.root_pos, self.descriptors)
92            .map_err(|e| ProtocolError::with_source(e).context("error decoding input codec"))
93    }
94    pub fn wrong_type(&self, descriptor: &Descriptor, expected: &str) -> Error {
95        DescriptorMismatch::with_message(format!(
96            "server returned unexpected type {descriptor:?} when client expected {expected}"
97        ))
98    }
99    pub fn field_number(&self, expected: usize, unexpected: usize) -> Error {
100        DescriptorMismatch::with_message(format!(
101            "expected {} fields, got {}",
102            expected, unexpected
103        ))
104    }
105}
106
107impl<T: ScalarArg> ScalarArg for &T {
108    fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> {
109        (*self).encode(encoder)
110    }
111
112    fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
113        T::check_descriptor(ctx, pos)
114    }
115
116    fn to_value(&self) -> Result<Value, Error> {
117        (*self).to_value()
118    }
119}
120
121impl QueryArgs for () {
122    fn encode(&self, enc: &mut Encoder) -> Result<(), Error> {
123        if enc.ctx.root_pos.is_some() {
124            if enc.ctx.proto.is_at_most(0, 11) {
125                let root = enc.ctx.root_pos.and_then(|p| enc.ctx.get(p).ok());
126                match root {
127                    Some(Descriptor::Tuple(t))
128                        if t.id == Uuid::from_u128(0xFF) && t.element_types.is_empty() => {}
129                    _ => {
130                        return Err(ParameterTypeMismatchError::with_message(
131                            "query arguments expected",
132                        ))
133                    }
134                };
135            } else {
136                return Err(ParameterTypeMismatchError::with_message(
137                    "query arguments expected",
138                ));
139            }
140        }
141        if enc.ctx.proto.is_at_most(0, 11) {
142            enc.buf.reserve(4);
143            enc.buf.put_u32(0);
144        }
145        Ok(())
146    }
147}
148
149impl QueryArg for Value {
150    fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
151        use Value::*;
152        match self {
153            Nothing => {
154                enc.buf.reserve(4);
155                enc.buf.put_i32(-1);
156            }
157            Uuid(v) => v.encode_slot(enc)?,
158            Str(v) => v.encode_slot(enc)?,
159            Bytes(v) => v.encode_slot(enc)?,
160            Int16(v) => v.encode_slot(enc)?,
161            Int32(v) => v.encode_slot(enc)?,
162            Int64(v) => v.encode_slot(enc)?,
163            Float32(v) => v.encode_slot(enc)?,
164            Float64(v) => v.encode_slot(enc)?,
165            BigInt(v) => v.encode_slot(enc)?,
166            ConfigMemory(v) => v.encode_slot(enc)?,
167            Decimal(v) => v.encode_slot(enc)?,
168            Bool(v) => v.encode_slot(enc)?,
169            Datetime(v) => v.encode_slot(enc)?,
170            LocalDatetime(v) => v.encode_slot(enc)?,
171            LocalDate(v) => v.encode_slot(enc)?,
172            LocalTime(v) => v.encode_slot(enc)?,
173            Duration(v) => v.encode_slot(enc)?,
174            RelativeDuration(v) => v.encode_slot(enc)?,
175            DateDuration(v) => v.encode_slot(enc)?,
176            Json(v) => v.encode_slot(enc)?,
177            Set(_) => {
178                return Err(ClientEncodingError::with_message(
179                    "set cannot be query argument",
180                ))
181            }
182            Object { .. } => {
183                return Err(ClientEncodingError::with_message(
184                    "object cannot be query argument",
185                ))
186            }
187            SparseObject(_) => {
188                return Err(ClientEncodingError::with_message(
189                    "sparse object cannot be query argument",
190                ))
191            }
192            Tuple(_) => {
193                return Err(ClientEncodingError::with_message(
194                    "tuple object cannot be query argument",
195                ))
196            }
197            NamedTuple { .. } => {
198                return Err(ClientEncodingError::with_message(
199                    "named tuple object cannot be query argument",
200                ))
201            }
202            Array(v) => v.encode_slot(enc)?,
203            Enum(v) => v.encode_slot(enc)?,
204            Range(v) => v.encode_slot(enc)?,
205            Vector(v) => crate::model::VectorRef(v).encode_slot(enc)?,
206            PostGisGeometry(v) => v.encode_slot(enc)?,
207            PostGisGeography(v) => v.encode_slot(enc)?,
208            PostGisBox2d(v) => v.encode_slot(enc)?,
209            PostGisBox3d(v) => v.encode_slot(enc)?,
210            SQLRow { .. } => {
211                return Err(ClientEncodingError::with_message(
212                    "SQL row cannot be query argument",
213                ))
214            }
215        }
216
217        Ok(())
218    }
219    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
220        use Descriptor::*;
221        use Value::*;
222        let desc = ctx.get(pos)?.normalize_to_base(ctx)?;
223
224        match (self, desc) {
225            (Nothing, _) => Ok(()), // any descriptor works
226            (BigInt(_), BaseScalar(d)) if d.id == codec::STD_BIGINT => Ok(()),
227            (Bool(_), BaseScalar(d)) if d.id == codec::STD_BOOL => Ok(()),
228            (Bytes(_), BaseScalar(d)) if d.id == codec::STD_BYTES => Ok(()),
229            (ConfigMemory(_), BaseScalar(d)) if d.id == codec::CFG_MEMORY => Ok(()),
230            (DateDuration(_), BaseScalar(d)) if d.id == codec::CAL_DATE_DURATION => Ok(()),
231            (Datetime(_), BaseScalar(d)) if d.id == codec::STD_DATETIME => Ok(()),
232            (Decimal(_), BaseScalar(d)) if d.id == codec::STD_DECIMAL => Ok(()),
233            (Duration(_), BaseScalar(d)) if d.id == codec::STD_DURATION => Ok(()),
234            (Float32(_), BaseScalar(d)) if d.id == codec::STD_FLOAT32 => Ok(()),
235            (Float64(_), BaseScalar(d)) if d.id == codec::STD_FLOAT64 => Ok(()),
236            (Int16(_), BaseScalar(d)) if d.id == codec::STD_INT16 => Ok(()),
237            (Int32(_), BaseScalar(d)) if d.id == codec::STD_INT32 => Ok(()),
238            (Int64(_), BaseScalar(d)) if d.id == codec::STD_INT64 => Ok(()),
239            (Json(_), BaseScalar(d)) if d.id == codec::STD_JSON => Ok(()),
240            (LocalDate(_), BaseScalar(d)) if d.id == codec::CAL_LOCAL_DATE => Ok(()),
241            (LocalDatetime(_), BaseScalar(d)) if d.id == codec::CAL_LOCAL_DATETIME => Ok(()),
242            (LocalTime(_), BaseScalar(d)) if d.id == codec::CAL_LOCAL_TIME => Ok(()),
243            (RelativeDuration(_), BaseScalar(d)) if d.id == codec::CAL_RELATIVE_DURATION => Ok(()),
244            (Str(_), BaseScalar(d)) if d.id == codec::STD_STR => Ok(()),
245            (Uuid(_), BaseScalar(d)) if d.id == codec::STD_UUID => Ok(()),
246            (Enum(val), Enumeration(EnumerationTypeDescriptor { members, .. })) => {
247                let val = val.deref();
248                check_enum(val, &members)
249            }
250            (Vector(_), BaseScalar(d)) if d.id == codec::PGVECTOR_VECTOR => Ok(()),
251            (PostGisGeometry(_), BaseScalar(d)) if d.id == codec::POSTGIS_GEOMETRY => Ok(()),
252            (PostGisGeography(_), BaseScalar(d)) if d.id == codec::POSTGIS_GEOGRAPHY => Ok(()),
253            (PostGisBox2d(_), BaseScalar(d)) if d.id == codec::POSTGIS_BOX_2D => Ok(()),
254            (PostGisBox3d(_), BaseScalar(d)) if d.id == codec::POSTGIS_BOX_3D => Ok(()),
255            // TODO(tailhook) all types
256            (_, desc) => Err(ctx.wrong_type(&desc, self.kind())),
257        }
258    }
259    fn to_value(&self) -> Result<Value, Error> {
260        Ok(self.clone())
261    }
262}
263
264pub(crate) fn check_enum(variant_name: &str, expected_members: &[String]) -> Result<(), Error> {
265    if expected_members.iter().any(|c| c == variant_name) {
266        Ok(())
267    } else {
268        let mut members = expected_members
269            .iter()
270            .map(|c| format!("'{c}'"))
271            .collect::<Vec<_>>();
272        members.sort_unstable();
273        let members = members.join(", ");
274        Err(InvalidReferenceError::with_message(format!(
275            "Expected one of: {members}, while enum value '{variant_name}' was provided"
276        )))
277    }
278}
279
280impl QueryArgs for Value {
281    fn encode(&self, enc: &mut Encoder) -> Result<(), Error> {
282        let codec = enc.ctx.build_codec()?;
283        codec
284            .encode(enc.buf, self)
285            .map_err(ClientEncodingError::with_source)
286    }
287}
288
289impl<T: ScalarArg> QueryArg for T {
290    fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
291        enc.buf.reserve(4);
292        let pos = enc.buf.len();
293        enc.buf.put_u32(0); // will fill after encoding
294        ScalarArg::encode(self, enc)?;
295        let len = enc.buf.len() - pos - 4;
296        enc.buf[pos..pos + 4].copy_from_slice(
297            &i32::try_from(len)
298                .ok()
299                .context(errors::ElementTooLong)
300                .map_err(ClientEncodingError::with_source)?
301                .to_be_bytes(),
302        );
303        Ok(())
304    }
305    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
306        T::check_descriptor(ctx, pos)
307    }
308    fn to_value(&self) -> Result<Value, Error> {
309        ScalarArg::to_value(self)
310    }
311}
312
313impl<T: ScalarArg> QueryArg for Option<T> {
314    fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
315        if let Some(val) = self {
316            QueryArg::encode_slot(val, enc)
317        } else {
318            enc.buf.reserve(4);
319            enc.buf.put_i32(-1);
320            Ok(())
321        }
322    }
323    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
324        T::check_descriptor(ctx, pos)
325    }
326    fn to_value(&self) -> Result<Value, Error> {
327        match self.as_ref() {
328            Some(v) => ScalarArg::to_value(v),
329            None => Ok(Value::Nothing),
330        }
331    }
332}
333
334impl<T: ScalarArg> QueryArg for Vec<T> {
335    fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
336        enc.buf.reserve(8);
337        enc.length_prefixed(|enc| {
338            if self.is_empty() {
339                enc.buf.reserve(12);
340                enc.buf.put_u32(0); // ndims
341                enc.buf.put_u32(0); // reserved0
342                enc.buf.put_u32(0); // reserved1
343                return Ok(());
344            }
345            enc.buf.reserve(20);
346            enc.buf.put_u32(1); // ndims
347            enc.buf.put_u32(0); // reserved0
348            enc.buf.put_u32(0); // reserved1
349            enc.buf.put_u32(
350                self.len()
351                    .try_into()
352                    .map_err(|_| ClientEncodingError::with_message("array is too long"))?,
353            );
354            enc.buf.put_u32(1); // lower
355            for item in self {
356                enc.length_prefixed(|enc| item.encode(enc))?;
357            }
358            Ok(())
359        })
360    }
361    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
362        let desc = ctx.get(pos)?;
363        if let Descriptor::Array(arr) = desc {
364            T::check_descriptor(ctx, arr.type_pos)
365        } else {
366            Err(ctx.wrong_type(desc, "array"))
367        }
368    }
369    fn to_value(&self) -> Result<Value, Error> {
370        Ok(Value::Array(
371            self.iter()
372                .map(|v| v.to_value())
373                .collect::<Result<_, _>>()?,
374        ))
375    }
376}
377
378impl QueryArg for Vec<Value> {
379    fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
380        enc.buf.reserve(8);
381        enc.length_prefixed(|enc| {
382            if self.is_empty() {
383                enc.buf.reserve(12);
384                enc.buf.put_u32(0); // ndims
385                enc.buf.put_u32(0); // reserved0
386                enc.buf.put_u32(0); // reserved1
387                return Ok(());
388            }
389            enc.buf.reserve(20);
390            enc.buf.put_u32(1); // ndims
391            enc.buf.put_u32(0); // reserved0
392            enc.buf.put_u32(0); // reserved1
393            enc.buf.put_u32(
394                self.len()
395                    .try_into()
396                    .map_err(|_| ClientEncodingError::with_message("array is too long"))?,
397            );
398            enc.buf.put_u32(1); // lower
399            for item in self {
400                enc.length_prefixed(|enc| item.encode(enc))?;
401            }
402            Ok(())
403        })
404    }
405    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
406        let desc = ctx.get(pos)?;
407        if let Descriptor::Array(arr) = desc {
408            for val in self {
409                val.check_descriptor(ctx, arr.type_pos)?
410            }
411            Ok(())
412        } else {
413            Err(ctx.wrong_type(desc, "array"))
414        }
415    }
416    fn to_value(&self) -> Result<Value, Error> {
417        Ok(Value::Array(
418            self.iter()
419                .map(|v| v.to_value())
420                .collect::<Result<_, _>>()?,
421        ))
422    }
423}
424
425impl QueryArg for range::Range<Box<Value>> {
426    fn encode_slot(&self, encoder: &mut Encoder) -> Result<(), Error> {
427        encoder.length_prefixed(|encoder| {
428            let flags = if self.empty {
429                range::EMPTY
430            } else {
431                (if self.inc_lower { range::LB_INC } else { 0 })
432                    | (if self.inc_upper { range::UB_INC } else { 0 })
433                    | (if self.lower.is_none() {
434                        range::LB_INF
435                    } else {
436                        0
437                    })
438                    | (if self.upper.is_none() {
439                        range::UB_INF
440                    } else {
441                        0
442                    })
443            };
444            encoder.buf.reserve(1);
445            encoder.buf.put_u8(flags as u8);
446
447            if let Some(lower) = &self.lower {
448                encoder.length_prefixed(|encoder| lower.encode(encoder))?
449            }
450
451            if let Some(upper) = &self.upper {
452                encoder.length_prefixed(|encoder| upper.encode(encoder))?;
453            }
454            Ok(())
455        })
456    }
457    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
458        let desc = ctx.get(pos)?;
459        if let Descriptor::Range(rng) = desc {
460            self.lower
461                .as_ref()
462                .map(|v| v.check_descriptor(ctx, rng.type_pos))
463                .transpose()?;
464            self.upper
465                .as_ref()
466                .map(|v| v.check_descriptor(ctx, rng.type_pos))
467                .transpose()?;
468            Ok(())
469        } else {
470            Err(ctx.wrong_type(desc, "range"))
471        }
472    }
473    fn to_value(&self) -> Result<Value, Error> {
474        Ok(Value::Range(self.clone()))
475    }
476}
477
// Generates `QueryArgs` for a tuple of `$count` positional arguments whose type
// parameters are listed as `$name, ...`.
//
// `encode` validates every element against the root input descriptor before it
// writes anything to the buffer:
// * protocol >= 0.12: the root must be an `ObjectShape` with exactly `$count`
//   elements named "0", "1", ... in order;
// * protocol <= 0.11: the root must be a `Tuple` with exactly `$count` elements.
// It then writes the element count and, for each argument, a zero `u32` word
// followed by the argument's slot (`QueryArg::encode_slot`).
// NOTE(review): the zero word is presumably the protocol's per-element
// reserved field — confirm against the wire-format spec.
macro_rules! implement_tuple {
    ( $count:expr, $($name:ident,)+ ) => {
        impl<$($name:QueryArg),+> QueryArgs for ($($name,)+) {
            fn encode(&self, enc: &mut Encoder)
                -> Result<(), Error>
            {
                // The type parameter names (`T0`, `T1`, ...) are reused as
                // local variable names below, hence the lint allowance.
                #![allow(non_snake_case)]
                let root_pos = enc.ctx.root_pos
                    .ok_or_else(|| DescriptorMismatch::with_message(
                        format!(
                            "provided {} positional arguments, \
                             but no arguments expected by the server",
                             $count)))?;
                let desc = enc.ctx.get(root_pos)?;
                match desc {
                    Descriptor::ObjectShape(desc)
                    if enc.ctx.proto.is_at_least(0, 12)
                    => {
                        if desc.elements.len() != $count {
                            return Err(enc.ctx.field_number(
                                desc.elements.len(), $count));
                        }
                        let mut els = desc.elements.iter().enumerate();
                        let ($(ref $name,)+) = self;
                        $(
                            // Cannot fail: the length check above guarantees
                            // one descriptor element per tuple field.
                            let (idx, el) = els.next().unwrap();
                            // Positional arguments are named by their index.
                            if el.name.parse() != Ok(idx) {
                                return Err(DescriptorMismatch::with_message(
                                    format!("expected positional arguments, \
                                             got {} instead of {}",
                                             el.name, idx)));
                            }
                            $name.check_descriptor(enc.ctx, el.type_pos)?;
                        )+
                    }
                    Descriptor::Tuple(desc) if enc.ctx.proto.is_at_most(0, 11)
                    => {
                        if desc.element_types.len() != $count {
                            return Err(enc.ctx.field_number(
                                desc.element_types.len(), $count));
                        }
                        let mut els = desc.element_types.iter();
                        let ($(ref $name,)+) = self;
                        $(
                            // Cannot fail: length checked above.
                            let type_pos = els.next().unwrap();
                            $name.check_descriptor(enc.ctx, *type_pos)?;
                        )+
                    }
                    _ => return Err(enc.ctx.wrong_type(desc,
                        if enc.ctx.proto.is_at_least(0, 12) { "object" }
                        else { "tuple" }))
                }

                // All elements validated; now write the argument block.
                enc.buf.reserve(4 + 8*$count);
                enc.buf.put_u32($count);
                let ($(ref $name,)+) = self;
                $(
                    enc.buf.reserve(8);
                    enc.buf.put_u32(0);
                    QueryArg::encode_slot($name, enc)?;
                )*
                Ok(())
            }
        }
    }
}

// Tuples of one through twelve arguments, as promised by the `QueryArgs` docs.
implement_tuple! {1, T0, }
implement_tuple! {2, T0, T1, }
implement_tuple! {3, T0, T1, T2, }
implement_tuple! {4, T0, T1, T2, T3, }
implement_tuple! {5, T0, T1, T2, T3, T4, }
implement_tuple! {6, T0, T1, T2, T3, T4, T5, }
implement_tuple! {7, T0, T1, T2, T3, T4, T5, T6, }
implement_tuple! {8, T0, T1, T2, T3, T4, T5, T6, T7, }
implement_tuple! {9, T0, T1, T2, T3, T4, T5, T6, T7, T8, }
implement_tuple! {10, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, }
implement_tuple! {11, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, }
implement_tuple! {12, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, }