gel_protocol/
query_arg.rs

/*!
Contains the [`QueryArg`] and [`QueryArgs`] traits.
*/
4
5use std::convert::{TryFrom, TryInto};
6use std::ops::Deref;
7use std::sync::Arc;
8
9use bytes::{BufMut, BytesMut};
10use snafu::OptionExt;
11use uuid::Uuid;
12
13use gel_errors::ParameterTypeMismatchError;
14use gel_errors::{ClientEncodingError, DescriptorMismatch, ProtocolError};
15use gel_errors::{Error, ErrorKind, InvalidReferenceError};
16
17use crate::codec::{self, build_codec, Codec};
18use crate::descriptors::TypePos;
19use crate::descriptors::{Descriptor, EnumerationTypeDescriptor};
20use crate::errors;
21use crate::features::ProtocolVersion;
22use crate::model::range;
23use crate::value::Value;
24
25pub struct Encoder<'a> {
26    pub ctx: &'a DescriptorContext<'a>,
27    pub buf: &'a mut BytesMut,
28}
29
30/// A single argument for a query.
31pub trait QueryArg: Send + Sync + Sized {
32    fn encode_slot(&self, encoder: &mut Encoder) -> Result<(), Error>;
33    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error>;
34    fn to_value(&self) -> Result<Value, Error>;
35}
36
37pub trait ScalarArg: Send + Sync + Sized {
38    fn encode(&self, encoder: &mut Encoder) -> Result<(), Error>;
39    fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error>;
40    fn to_value(&self) -> Result<Value, Error>;
41}
42
43/// A tuple of query arguments.
44///
45/// This trait is implemented for tuples of sizes up to twelve. You can derive
46/// it for a structure in this case it's treated as a named tuple (i.e. query
47/// should include named arguments rather than numeric ones).
48pub trait QueryArgs: Send + Sync {
49    fn encode(&self, encoder: &mut Encoder) -> Result<(), Error>;
50}
51
52pub struct DescriptorContext<'a> {
53    #[allow(dead_code)]
54    pub(crate) proto: &'a ProtocolVersion,
55    pub(crate) root_pos: Option<TypePos>,
56    pub(crate) descriptors: &'a [Descriptor],
57}
58
59impl<'a> Encoder<'a> {
60    pub fn new(ctx: &'a DescriptorContext<'a>, buf: &'a mut BytesMut) -> Encoder<'a> {
61        Encoder { ctx, buf }
62    }
63    pub fn length_prefixed(
64        &mut self,
65        f: impl FnOnce(&mut Encoder) -> Result<(), Error>,
66    ) -> Result<(), Error> {
67        self.buf.reserve(4);
68        let pos = self.buf.len();
69        self.buf.put_u32(0); // replaced after serializing a value
70                             //
71        f(self)?;
72
73        let len = self.buf.len() - pos - 4;
74        self.buf[pos..pos + 4].copy_from_slice(
75            &u32::try_from(len)
76                .map_err(|_| ClientEncodingError::with_message("alias is too long"))?
77                .to_be_bytes(),
78        );
79
80        Ok(())
81    }
82}
83
84impl DescriptorContext<'_> {
85    pub fn get(&self, type_pos: TypePos) -> Result<&Descriptor, Error> {
86        self.descriptors
87            .get(type_pos.0 as usize)
88            .ok_or_else(|| ProtocolError::with_message("invalid type descriptor"))
89    }
90    pub fn build_codec(&self) -> Result<Arc<dyn Codec>, Error> {
91        build_codec(self.root_pos, self.descriptors)
92            .map_err(|e| ProtocolError::with_source(e).context("error decoding input codec"))
93    }
94    pub fn wrong_type(&self, descriptor: &Descriptor, expected: &str) -> Error {
95        DescriptorMismatch::with_message(format!(
96            "server returned unexpected type {descriptor:?} when client expected {expected}"
97        ))
98    }
99    pub fn field_number(&self, expected: usize, unexpected: usize) -> Error {
100        DescriptorMismatch::with_message(format!("expected {expected} fields, got {unexpected}"))
101    }
102}
103
104impl<T: ScalarArg> ScalarArg for &T {
105    fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> {
106        (*self).encode(encoder)
107    }
108
109    fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
110        T::check_descriptor(ctx, pos)
111    }
112
113    fn to_value(&self) -> Result<Value, Error> {
114        (*self).to_value()
115    }
116}
117
118impl QueryArgs for () {
119    fn encode(&self, enc: &mut Encoder) -> Result<(), Error> {
120        if enc.ctx.root_pos.is_some() {
121            if enc.ctx.proto.is_at_most(0, 11) {
122                let root = enc.ctx.root_pos.and_then(|p| enc.ctx.get(p).ok());
123                match root {
124                    Some(Descriptor::Tuple(t))
125                        if t.id == Uuid::from_u128(0xFF) && t.element_types.is_empty() => {}
126                    _ => {
127                        return Err(ParameterTypeMismatchError::with_message(
128                            "query arguments expected",
129                        ))
130                    }
131                };
132            } else {
133                return Err(ParameterTypeMismatchError::with_message(
134                    "query arguments expected",
135                ));
136            }
137        }
138        if enc.ctx.proto.is_at_most(0, 11) {
139            enc.buf.reserve(4);
140            enc.buf.put_u32(0);
141        }
142        Ok(())
143    }
144}
145
146impl QueryArg for Value {
147    fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
148        use Value::*;
149        match self {
150            Nothing => {
151                enc.buf.reserve(4);
152                enc.buf.put_i32(-1);
153            }
154            Uuid(v) => v.encode_slot(enc)?,
155            Str(v) => v.encode_slot(enc)?,
156            Bytes(v) => v.encode_slot(enc)?,
157            Int16(v) => v.encode_slot(enc)?,
158            Int32(v) => v.encode_slot(enc)?,
159            Int64(v) => v.encode_slot(enc)?,
160            Float32(v) => v.encode_slot(enc)?,
161            Float64(v) => v.encode_slot(enc)?,
162            BigInt(v) => v.encode_slot(enc)?,
163            ConfigMemory(v) => v.encode_slot(enc)?,
164            Decimal(v) => v.encode_slot(enc)?,
165            Bool(v) => v.encode_slot(enc)?,
166            Datetime(v) => v.encode_slot(enc)?,
167            LocalDatetime(v) => v.encode_slot(enc)?,
168            LocalDate(v) => v.encode_slot(enc)?,
169            LocalTime(v) => v.encode_slot(enc)?,
170            Duration(v) => v.encode_slot(enc)?,
171            RelativeDuration(v) => v.encode_slot(enc)?,
172            DateDuration(v) => v.encode_slot(enc)?,
173            Json(v) => v.encode_slot(enc)?,
174            Set(_) => {
175                return Err(ClientEncodingError::with_message(
176                    "set cannot be query argument",
177                ))
178            }
179            Object { .. } => {
180                return Err(ClientEncodingError::with_message(
181                    "object cannot be query argument",
182                ))
183            }
184            SparseObject(_) => {
185                return Err(ClientEncodingError::with_message(
186                    "sparse object cannot be query argument",
187                ))
188            }
189            Tuple(_) => {
190                return Err(ClientEncodingError::with_message(
191                    "tuple object cannot be query argument",
192                ))
193            }
194            NamedTuple { .. } => {
195                return Err(ClientEncodingError::with_message(
196                    "named tuple object cannot be query argument",
197                ))
198            }
199            Array(_) => {
200                return Err(ClientEncodingError::with_message(
201                    "array cannot be query argument",
202                ))
203            }
204            Enum(v) => v.encode_slot(enc)?,
205            Range(v) => v.encode_slot(enc)?,
206            Vector(v) => crate::model::VectorRef(v).encode_slot(enc)?,
207            PostGisGeometry(v) => v.encode_slot(enc)?,
208            PostGisGeography(v) => v.encode_slot(enc)?,
209            PostGisBox2d(v) => v.encode_slot(enc)?,
210            PostGisBox3d(v) => v.encode_slot(enc)?,
211            SQLRow { .. } => {
212                return Err(ClientEncodingError::with_message(
213                    "SQL row cannot be query argument",
214                ))
215            }
216        }
217
218        Ok(())
219    }
220    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
221        use Descriptor::*;
222        use Value::*;
223        let desc = ctx.get(pos)?.normalize_to_base(ctx)?;
224
225        match (self, desc) {
226            (Nothing, _) => Ok(()), // any descriptor works
227            (BigInt(_), BaseScalar(d)) if d.id == codec::STD_BIGINT => Ok(()),
228            (Bool(_), BaseScalar(d)) if d.id == codec::STD_BOOL => Ok(()),
229            (Bytes(_), BaseScalar(d)) if d.id == codec::STD_BYTES => Ok(()),
230            (ConfigMemory(_), BaseScalar(d)) if d.id == codec::CFG_MEMORY => Ok(()),
231            (DateDuration(_), BaseScalar(d)) if d.id == codec::CAL_DATE_DURATION => Ok(()),
232            (Datetime(_), BaseScalar(d)) if d.id == codec::STD_DATETIME => Ok(()),
233            (Decimal(_), BaseScalar(d)) if d.id == codec::STD_DECIMAL => Ok(()),
234            (Duration(_), BaseScalar(d)) if d.id == codec::STD_DURATION => Ok(()),
235            (Float32(_), BaseScalar(d)) if d.id == codec::STD_FLOAT32 => Ok(()),
236            (Float64(_), BaseScalar(d)) if d.id == codec::STD_FLOAT64 => Ok(()),
237            (Int16(_), BaseScalar(d)) if d.id == codec::STD_INT16 => Ok(()),
238            (Int32(_), BaseScalar(d)) if d.id == codec::STD_INT32 => Ok(()),
239            (Int64(_), BaseScalar(d)) if d.id == codec::STD_INT64 => Ok(()),
240            (Json(_), BaseScalar(d)) if d.id == codec::STD_JSON => Ok(()),
241            (LocalDate(_), BaseScalar(d)) if d.id == codec::CAL_LOCAL_DATE => Ok(()),
242            (LocalDatetime(_), BaseScalar(d)) if d.id == codec::CAL_LOCAL_DATETIME => Ok(()),
243            (LocalTime(_), BaseScalar(d)) if d.id == codec::CAL_LOCAL_TIME => Ok(()),
244            (RelativeDuration(_), BaseScalar(d)) if d.id == codec::CAL_RELATIVE_DURATION => Ok(()),
245            (Str(_), BaseScalar(d)) if d.id == codec::STD_STR => Ok(()),
246            (Uuid(_), BaseScalar(d)) if d.id == codec::STD_UUID => Ok(()),
247            (Enum(val), Enumeration(EnumerationTypeDescriptor { members, .. })) => {
248                let val = val.deref();
249                check_enum(val, &members)
250            }
251            (Vector(_), BaseScalar(d)) if d.id == codec::PGVECTOR_VECTOR => Ok(()),
252            (PostGisGeometry(_), BaseScalar(d)) if d.id == codec::POSTGIS_GEOMETRY => Ok(()),
253            (PostGisGeography(_), BaseScalar(d)) if d.id == codec::POSTGIS_GEOGRAPHY => Ok(()),
254            (PostGisBox2d(_), BaseScalar(d)) if d.id == codec::POSTGIS_BOX_2D => Ok(()),
255            (PostGisBox3d(_), BaseScalar(d)) if d.id == codec::POSTGIS_BOX_3D => Ok(()),
256            // TODO(tailhook) all types
257            (_, desc) => Err(ctx.wrong_type(&desc, self.kind())),
258        }
259    }
260    fn to_value(&self) -> Result<Value, Error> {
261        Ok(self.clone())
262    }
263}
264
265pub(crate) fn check_enum(variant_name: &str, expected_members: &[String]) -> Result<(), Error> {
266    if expected_members.iter().any(|c| c == variant_name) {
267        Ok(())
268    } else {
269        let mut members = expected_members
270            .iter()
271            .map(|c| format!("'{c}'"))
272            .collect::<Vec<_>>();
273        members.sort_unstable();
274        let members = members.join(", ");
275        Err(InvalidReferenceError::with_message(format!(
276            "Expected one of: {members}, while enum value '{variant_name}' was provided"
277        )))
278    }
279}
280
281impl QueryArgs for Value {
282    fn encode(&self, enc: &mut Encoder) -> Result<(), Error> {
283        let codec = enc.ctx.build_codec()?;
284        codec
285            .encode(enc.buf, self)
286            .map_err(ClientEncodingError::with_source)
287    }
288}
289
290impl<T: ScalarArg> QueryArg for T {
291    fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
292        enc.buf.reserve(4);
293        let pos = enc.buf.len();
294        enc.buf.put_u32(0); // will fill after encoding
295        ScalarArg::encode(self, enc)?;
296        let len = enc.buf.len() - pos - 4;
297        enc.buf[pos..pos + 4].copy_from_slice(
298            &i32::try_from(len)
299                .ok()
300                .context(errors::ElementTooLong)
301                .map_err(ClientEncodingError::with_source)?
302                .to_be_bytes(),
303        );
304        Ok(())
305    }
306    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
307        T::check_descriptor(ctx, pos)
308    }
309    fn to_value(&self) -> Result<Value, Error> {
310        ScalarArg::to_value(self)
311    }
312}
313
314impl<T: ScalarArg> QueryArg for Option<T> {
315    fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
316        if let Some(val) = self {
317            QueryArg::encode_slot(val, enc)
318        } else {
319            enc.buf.reserve(4);
320            enc.buf.put_i32(-1);
321            Ok(())
322        }
323    }
324    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
325        T::check_descriptor(ctx, pos)
326    }
327    fn to_value(&self) -> Result<Value, Error> {
328        match self.as_ref() {
329            Some(v) => ScalarArg::to_value(v),
330            None => Ok(Value::Nothing),
331        }
332    }
333}
334
335impl<T: ScalarArg> QueryArg for Vec<T> {
336    fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
337        enc.buf.reserve(8);
338        enc.length_prefixed(|enc| {
339            if self.is_empty() {
340                enc.buf.reserve(12);
341                enc.buf.put_u32(0); // ndims
342                enc.buf.put_u32(0); // reserved0
343                enc.buf.put_u32(0); // reserved1
344                return Ok(());
345            }
346            enc.buf.reserve(20);
347            enc.buf.put_u32(1); // ndims
348            enc.buf.put_u32(0); // reserved0
349            enc.buf.put_u32(0); // reserved1
350            enc.buf.put_u32(
351                self.len()
352                    .try_into()
353                    .map_err(|_| ClientEncodingError::with_message("array is too long"))?,
354            );
355            enc.buf.put_u32(1); // lower
356            for item in self {
357                enc.length_prefixed(|enc| item.encode(enc))?;
358            }
359            Ok(())
360        })
361    }
362    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
363        let desc = ctx.get(pos)?;
364        if let Descriptor::Array(arr) = desc {
365            T::check_descriptor(ctx, arr.type_pos)
366        } else {
367            Err(ctx.wrong_type(desc, "array"))
368        }
369    }
370    fn to_value(&self) -> Result<Value, Error> {
371        Ok(Value::Array(
372            self.iter()
373                .map(|v| v.to_value())
374                .collect::<Result<_, _>>()?,
375        ))
376    }
377}
378
379impl QueryArg for Vec<Value> {
380    fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
381        enc.buf.reserve(8);
382        enc.length_prefixed(|enc| {
383            if self.is_empty() {
384                enc.buf.reserve(12);
385                enc.buf.put_u32(0); // ndims
386                enc.buf.put_u32(0); // reserved0
387                enc.buf.put_u32(0); // reserved1
388                return Ok(());
389            }
390            enc.buf.reserve(20);
391            enc.buf.put_u32(1); // ndims
392            enc.buf.put_u32(0); // reserved0
393            enc.buf.put_u32(0); // reserved1
394            enc.buf.put_u32(
395                self.len()
396                    .try_into()
397                    .map_err(|_| ClientEncodingError::with_message("array is too long"))?,
398            );
399            enc.buf.put_u32(1); // lower
400            for item in self {
401                enc.length_prefixed(|enc| item.encode(enc))?;
402            }
403            Ok(())
404        })
405    }
406    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
407        let desc = ctx.get(pos)?;
408        if let Descriptor::Array(arr) = desc {
409            for val in self {
410                val.check_descriptor(ctx, arr.type_pos)?
411            }
412            Ok(())
413        } else {
414            Err(ctx.wrong_type(desc, "array"))
415        }
416    }
417    fn to_value(&self) -> Result<Value, Error> {
418        Ok(Value::Array(
419            self.iter()
420                .map(|v| v.to_value())
421                .collect::<Result<_, _>>()?,
422        ))
423    }
424}
425
426impl QueryArg for range::Range<Box<Value>> {
427    fn encode_slot(&self, encoder: &mut Encoder) -> Result<(), Error> {
428        encoder.length_prefixed(|encoder| {
429            let flags = if self.empty {
430                range::EMPTY
431            } else {
432                (if self.inc_lower { range::LB_INC } else { 0 })
433                    | (if self.inc_upper { range::UB_INC } else { 0 })
434                    | (if self.lower.is_none() {
435                        range::LB_INF
436                    } else {
437                        0
438                    })
439                    | (if self.upper.is_none() {
440                        range::UB_INF
441                    } else {
442                        0
443                    })
444            };
445            encoder.buf.reserve(1);
446            encoder.buf.put_u8(flags as u8);
447
448            if let Some(lower) = &self.lower {
449                encoder.length_prefixed(|encoder| lower.encode(encoder))?
450            }
451
452            if let Some(upper) = &self.upper {
453                encoder.length_prefixed(|encoder| upper.encode(encoder))?;
454            }
455            Ok(())
456        })
457    }
458    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
459        let desc = ctx.get(pos)?;
460        if let Descriptor::Range(rng) = desc {
461            self.lower
462                .as_ref()
463                .map(|v| v.check_descriptor(ctx, rng.type_pos))
464                .transpose()?;
465            self.upper
466                .as_ref()
467                .map(|v| v.check_descriptor(ctx, rng.type_pos))
468                .transpose()?;
469            Ok(())
470        } else {
471            Err(ctx.wrong_type(desc, "range"))
472        }
473    }
474    fn to_value(&self) -> Result<Value, Error> {
475        Ok(Value::Range(self.clone()))
476    }
477}
478
479macro_rules! implement_tuple {
480    ( $count:expr, $($name:ident,)+ ) => {
481        impl<$($name:QueryArg),+> QueryArgs for ($($name,)+) {
482            fn encode(&self, enc: &mut Encoder)
483                -> Result<(), Error>
484            {
485                #![allow(non_snake_case)]
486                let root_pos = enc.ctx.root_pos
487                    .ok_or_else(|| DescriptorMismatch::with_message(
488                        format!(
489                            "provided {} positional arguments, \
490                             but no arguments expected by the server",
491                             $count)))?;
492                let desc = enc.ctx.get(root_pos)?;
493                match desc {
494                    Descriptor::ObjectShape(desc)
495                    if enc.ctx.proto.is_at_least(0, 12)
496                    => {
497                        if desc.elements.len() != $count {
498                            return Err(enc.ctx.field_number(
499                                desc.elements.len(), $count));
500                        }
501                        let mut els = desc.elements.iter().enumerate();
502                        let ($(ref $name,)+) = self;
503                        $(
504                            let (idx, el) = els.next().unwrap();
505                            if el.name.parse() != Ok(idx) {
506                                return Err(DescriptorMismatch::with_message(
507                                    format!("expected positional arguments, \
508                                             got {} instead of {}",
509                                             el.name, idx)));
510                            }
511                            $name.check_descriptor(enc.ctx, el.type_pos)?;
512                        )+
513                    }
514                    Descriptor::Tuple(desc) if enc.ctx.proto.is_at_most(0, 11)
515                    => {
516                        if desc.element_types.len() != $count {
517                            return Err(enc.ctx.field_number(
518                                desc.element_types.len(), $count));
519                        }
520                        let mut els = desc.element_types.iter();
521                        let ($(ref $name,)+) = self;
522                        $(
523                            let type_pos = els.next().unwrap();
524                            $name.check_descriptor(enc.ctx, *type_pos)?;
525                        )+
526                    }
527                    _ => return Err(enc.ctx.wrong_type(desc,
528                        if enc.ctx.proto.is_at_least(0, 12) { "object" }
529                        else { "tuple" }))
530                }
531
532                enc.buf.reserve(4 + 8*$count);
533                enc.buf.put_u32($count);
534                let ($(ref $name,)+) = self;
535                $(
536                    enc.buf.reserve(8);
537                    enc.buf.put_u32(0);
538                    QueryArg::encode_slot($name, enc)?;
539                )*
540                Ok(())
541            }
542        }
543    }
544}
545
546implement_tuple! {1, T0, }
547implement_tuple! {2, T0, T1, }
548implement_tuple! {3, T0, T1, T2, }
549implement_tuple! {4, T0, T1, T2, T3, }
550implement_tuple! {5, T0, T1, T2, T3, T4, }
551implement_tuple! {6, T0, T1, T2, T3, T4, T5, }
552implement_tuple! {7, T0, T1, T2, T3, T4, T5, T6, }
553implement_tuple! {8, T0, T1, T2, T3, T4, T5, T6, T7, }
554implement_tuple! {9, T0, T1, T2, T3, T4, T5, T6, T7, T8, }
555implement_tuple! {10, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, }
556implement_tuple! {11, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, }
557implement_tuple! {12, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, }