1use std::convert::{TryFrom, TryInto};
6use std::ops::Deref;
7use std::sync::Arc;
8
9use bytes::{BufMut, BytesMut};
10use snafu::OptionExt;
11use uuid::Uuid;
12
13use gel_errors::ParameterTypeMismatchError;
14use gel_errors::{ClientEncodingError, DescriptorMismatch, ProtocolError};
15use gel_errors::{Error, ErrorKind, InvalidReferenceError};
16
17use crate::codec::{self, build_codec, Codec};
18use crate::descriptors::TypePos;
19use crate::descriptors::{Descriptor, EnumerationTypeDescriptor};
20use crate::errors;
21use crate::features::ProtocolVersion;
22use crate::model::range;
23use crate::value::Value;
24
/// Writes encoded query arguments into a byte buffer, consulting a
/// [`DescriptorContext`] for the input type descriptors.
pub struct Encoder<'a> {
    /// Input descriptors (and protocol version) the arguments are checked against.
    pub ctx: &'a DescriptorContext<'a>,
    /// Output buffer; encoded bytes are appended to it.
    pub buf: &'a mut BytesMut,
}
29
/// A single query argument occupying one argument slot.
///
/// Implementations in this module write a length-prefixed encoding into the
/// slot (or a `-1` length for a missing value).
pub trait QueryArg: Send + Sync + Sized {
    /// Encodes this argument into its slot, including the length prefix.
    fn encode_slot(&self, encoder: &mut Encoder) -> Result<(), Error>;
    /// Verifies that this argument is compatible with the descriptor at `pos`.
    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error>;
    /// Converts this argument into a dynamic [`Value`].
    fn to_value(&self) -> Result<Value, Error>;
}
36
/// A scalar value that can be sent as a query argument.
///
/// `encode` writes only the payload; the blanket `QueryArg` implementation
/// in this module adds the length prefix around it.
pub trait ScalarArg: Send + Sync + Sized {
    /// Appends the payload bytes (without a length prefix).
    fn encode(&self, encoder: &mut Encoder) -> Result<(), Error>;
    /// Type-level check that this scalar type matches the descriptor at `pos`
    /// (no value is needed).
    fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error>;
    /// Converts this scalar into a dynamic [`Value`].
    fn to_value(&self) -> Result<Value, Error>;
}
42
/// The complete set of arguments for one query (e.g. `()`, a tuple of
/// [`QueryArg`]s, or a dynamic [`Value`]), encoded as a whole.
pub trait QueryArgs: Send + Sync {
    /// Validates the arguments against the input descriptor and appends
    /// their encoding to `encoder.buf`.
    fn encode(&self, encoder: &mut Encoder) -> Result<(), Error>;
}
51
/// Input type descriptors for a query, used to validate and encode arguments.
pub struct DescriptorContext<'a> {
    /// Protocol version; callers use it to choose between the legacy (<= 0.11)
    /// tuple encoding and the newer object-shape encoding.
    #[allow(dead_code)]
    pub(crate) proto: &'a ProtocolVersion,
    /// Position of the root input descriptor; `None` when the server expects
    /// no arguments.
    pub(crate) root_pos: Option<TypePos>,
    /// Descriptor table, indexed by `TypePos`.
    pub(crate) descriptors: &'a [Descriptor],
}
58
59impl<'a> Encoder<'a> {
60 pub fn new(ctx: &'a DescriptorContext<'a>, buf: &'a mut BytesMut) -> Encoder<'a> {
61 Encoder { ctx, buf }
62 }
63 pub fn length_prefixed(
64 &mut self,
65 f: impl FnOnce(&mut Encoder) -> Result<(), Error>,
66 ) -> Result<(), Error> {
67 self.buf.reserve(4);
68 let pos = self.buf.len();
69 self.buf.put_u32(0); f(self)?;
72
73 let len = self.buf.len() - pos - 4;
74 self.buf[pos..pos + 4].copy_from_slice(
75 &u32::try_from(len)
76 .map_err(|_| ClientEncodingError::with_message("alias is too long"))?
77 .to_be_bytes(),
78 );
79
80 Ok(())
81 }
82}
83
84impl DescriptorContext<'_> {
85 pub fn get(&self, type_pos: TypePos) -> Result<&Descriptor, Error> {
86 self.descriptors
87 .get(type_pos.0 as usize)
88 .ok_or_else(|| ProtocolError::with_message("invalid type descriptor"))
89 }
90 pub fn build_codec(&self) -> Result<Arc<dyn Codec>, Error> {
91 build_codec(self.root_pos, self.descriptors)
92 .map_err(|e| ProtocolError::with_source(e).context("error decoding input codec"))
93 }
94 pub fn wrong_type(&self, descriptor: &Descriptor, expected: &str) -> Error {
95 DescriptorMismatch::with_message(format!(
96 "server returned unexpected type {descriptor:?} when client expected {expected}"
97 ))
98 }
99 pub fn field_number(&self, expected: usize, unexpected: usize) -> Error {
100 DescriptorMismatch::with_message(format!("expected {expected} fields, got {unexpected}"))
101 }
102}
103
104impl<T: ScalarArg> ScalarArg for &T {
105 fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> {
106 (*self).encode(encoder)
107 }
108
109 fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
110 T::check_descriptor(ctx, pos)
111 }
112
113 fn to_value(&self) -> Result<Value, Error> {
114 (*self).to_value()
115 }
116}
117
118impl QueryArgs for () {
119 fn encode(&self, enc: &mut Encoder) -> Result<(), Error> {
120 if enc.ctx.root_pos.is_some() {
121 if enc.ctx.proto.is_at_most(0, 11) {
122 let root = enc.ctx.root_pos.and_then(|p| enc.ctx.get(p).ok());
123 match root {
124 Some(Descriptor::Tuple(t))
125 if t.id == Uuid::from_u128(0xFF) && t.element_types.is_empty() => {}
126 _ => {
127 return Err(ParameterTypeMismatchError::with_message(
128 "query arguments expected",
129 ))
130 }
131 };
132 } else {
133 return Err(ParameterTypeMismatchError::with_message(
134 "query arguments expected",
135 ));
136 }
137 }
138 if enc.ctx.proto.is_at_most(0, 11) {
139 enc.buf.reserve(4);
140 enc.buf.put_u32(0);
141 }
142 Ok(())
143 }
144}
145
146impl QueryArg for Value {
147 fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
148 use Value::*;
149 match self {
150 Nothing => {
151 enc.buf.reserve(4);
152 enc.buf.put_i32(-1);
153 }
154 Uuid(v) => v.encode_slot(enc)?,
155 Str(v) => v.encode_slot(enc)?,
156 Bytes(v) => v.encode_slot(enc)?,
157 Int16(v) => v.encode_slot(enc)?,
158 Int32(v) => v.encode_slot(enc)?,
159 Int64(v) => v.encode_slot(enc)?,
160 Float32(v) => v.encode_slot(enc)?,
161 Float64(v) => v.encode_slot(enc)?,
162 BigInt(v) => v.encode_slot(enc)?,
163 ConfigMemory(v) => v.encode_slot(enc)?,
164 Decimal(v) => v.encode_slot(enc)?,
165 Bool(v) => v.encode_slot(enc)?,
166 Datetime(v) => v.encode_slot(enc)?,
167 LocalDatetime(v) => v.encode_slot(enc)?,
168 LocalDate(v) => v.encode_slot(enc)?,
169 LocalTime(v) => v.encode_slot(enc)?,
170 Duration(v) => v.encode_slot(enc)?,
171 RelativeDuration(v) => v.encode_slot(enc)?,
172 DateDuration(v) => v.encode_slot(enc)?,
173 Json(v) => v.encode_slot(enc)?,
174 Set(_) => {
175 return Err(ClientEncodingError::with_message(
176 "set cannot be query argument",
177 ))
178 }
179 Object { .. } => {
180 return Err(ClientEncodingError::with_message(
181 "object cannot be query argument",
182 ))
183 }
184 SparseObject(_) => {
185 return Err(ClientEncodingError::with_message(
186 "sparse object cannot be query argument",
187 ))
188 }
189 Tuple(_) => {
190 return Err(ClientEncodingError::with_message(
191 "tuple object cannot be query argument",
192 ))
193 }
194 NamedTuple { .. } => {
195 return Err(ClientEncodingError::with_message(
196 "named tuple object cannot be query argument",
197 ))
198 }
199 Array(_) => {
200 return Err(ClientEncodingError::with_message(
201 "array cannot be query argument",
202 ))
203 }
204 Enum(v) => v.encode_slot(enc)?,
205 Range(v) => v.encode_slot(enc)?,
206 Vector(v) => crate::model::VectorRef(v).encode_slot(enc)?,
207 PostGisGeometry(v) => v.encode_slot(enc)?,
208 PostGisGeography(v) => v.encode_slot(enc)?,
209 PostGisBox2d(v) => v.encode_slot(enc)?,
210 PostGisBox3d(v) => v.encode_slot(enc)?,
211 SQLRow { .. } => {
212 return Err(ClientEncodingError::with_message(
213 "SQL row cannot be query argument",
214 ))
215 }
216 }
217
218 Ok(())
219 }
220 fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
221 use Descriptor::*;
222 use Value::*;
223 let desc = ctx.get(pos)?.normalize_to_base(ctx)?;
224
225 match (self, desc) {
226 (Nothing, _) => Ok(()), (BigInt(_), BaseScalar(d)) if d.id == codec::STD_BIGINT => Ok(()),
228 (Bool(_), BaseScalar(d)) if d.id == codec::STD_BOOL => Ok(()),
229 (Bytes(_), BaseScalar(d)) if d.id == codec::STD_BYTES => Ok(()),
230 (ConfigMemory(_), BaseScalar(d)) if d.id == codec::CFG_MEMORY => Ok(()),
231 (DateDuration(_), BaseScalar(d)) if d.id == codec::CAL_DATE_DURATION => Ok(()),
232 (Datetime(_), BaseScalar(d)) if d.id == codec::STD_DATETIME => Ok(()),
233 (Decimal(_), BaseScalar(d)) if d.id == codec::STD_DECIMAL => Ok(()),
234 (Duration(_), BaseScalar(d)) if d.id == codec::STD_DURATION => Ok(()),
235 (Float32(_), BaseScalar(d)) if d.id == codec::STD_FLOAT32 => Ok(()),
236 (Float64(_), BaseScalar(d)) if d.id == codec::STD_FLOAT64 => Ok(()),
237 (Int16(_), BaseScalar(d)) if d.id == codec::STD_INT16 => Ok(()),
238 (Int32(_), BaseScalar(d)) if d.id == codec::STD_INT32 => Ok(()),
239 (Int64(_), BaseScalar(d)) if d.id == codec::STD_INT64 => Ok(()),
240 (Json(_), BaseScalar(d)) if d.id == codec::STD_JSON => Ok(()),
241 (LocalDate(_), BaseScalar(d)) if d.id == codec::CAL_LOCAL_DATE => Ok(()),
242 (LocalDatetime(_), BaseScalar(d)) if d.id == codec::CAL_LOCAL_DATETIME => Ok(()),
243 (LocalTime(_), BaseScalar(d)) if d.id == codec::CAL_LOCAL_TIME => Ok(()),
244 (RelativeDuration(_), BaseScalar(d)) if d.id == codec::CAL_RELATIVE_DURATION => Ok(()),
245 (Str(_), BaseScalar(d)) if d.id == codec::STD_STR => Ok(()),
246 (Uuid(_), BaseScalar(d)) if d.id == codec::STD_UUID => Ok(()),
247 (Enum(val), Enumeration(EnumerationTypeDescriptor { members, .. })) => {
248 let val = val.deref();
249 check_enum(val, &members)
250 }
251 (Vector(_), BaseScalar(d)) if d.id == codec::PGVECTOR_VECTOR => Ok(()),
252 (PostGisGeometry(_), BaseScalar(d)) if d.id == codec::POSTGIS_GEOMETRY => Ok(()),
253 (PostGisGeography(_), BaseScalar(d)) if d.id == codec::POSTGIS_GEOGRAPHY => Ok(()),
254 (PostGisBox2d(_), BaseScalar(d)) if d.id == codec::POSTGIS_BOX_2D => Ok(()),
255 (PostGisBox3d(_), BaseScalar(d)) if d.id == codec::POSTGIS_BOX_3D => Ok(()),
256 (_, desc) => Err(ctx.wrong_type(&desc, self.kind())),
258 }
259 }
260 fn to_value(&self) -> Result<Value, Error> {
261 Ok(self.clone())
262 }
263}
264
265pub(crate) fn check_enum(variant_name: &str, expected_members: &[String]) -> Result<(), Error> {
266 if expected_members.iter().any(|c| c == variant_name) {
267 Ok(())
268 } else {
269 let mut members = expected_members
270 .iter()
271 .map(|c| format!("'{c}'"))
272 .collect::<Vec<_>>();
273 members.sort_unstable();
274 let members = members.join(", ");
275 Err(InvalidReferenceError::with_message(format!(
276 "Expected one of: {members}, while enum value '{variant_name}' was provided"
277 )))
278 }
279}
280
281impl QueryArgs for Value {
282 fn encode(&self, enc: &mut Encoder) -> Result<(), Error> {
283 let codec = enc.ctx.build_codec()?;
284 codec
285 .encode(enc.buf, self)
286 .map_err(ClientEncodingError::with_source)
287 }
288}
289
290impl<T: ScalarArg> QueryArg for T {
291 fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
292 enc.buf.reserve(4);
293 let pos = enc.buf.len();
294 enc.buf.put_u32(0); ScalarArg::encode(self, enc)?;
296 let len = enc.buf.len() - pos - 4;
297 enc.buf[pos..pos + 4].copy_from_slice(
298 &i32::try_from(len)
299 .ok()
300 .context(errors::ElementTooLong)
301 .map_err(ClientEncodingError::with_source)?
302 .to_be_bytes(),
303 );
304 Ok(())
305 }
306 fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
307 T::check_descriptor(ctx, pos)
308 }
309 fn to_value(&self) -> Result<Value, Error> {
310 ScalarArg::to_value(self)
311 }
312}
313
314impl<T: ScalarArg> QueryArg for Option<T> {
315 fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
316 if let Some(val) = self {
317 QueryArg::encode_slot(val, enc)
318 } else {
319 enc.buf.reserve(4);
320 enc.buf.put_i32(-1);
321 Ok(())
322 }
323 }
324 fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
325 T::check_descriptor(ctx, pos)
326 }
327 fn to_value(&self) -> Result<Value, Error> {
328 match self.as_ref() {
329 Some(v) => ScalarArg::to_value(v),
330 None => Ok(Value::Nothing),
331 }
332 }
333}
334
335impl<T: ScalarArg> QueryArg for Vec<T> {
336 fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
337 enc.buf.reserve(8);
338 enc.length_prefixed(|enc| {
339 if self.is_empty() {
340 enc.buf.reserve(12);
341 enc.buf.put_u32(0); enc.buf.put_u32(0); enc.buf.put_u32(0); return Ok(());
345 }
346 enc.buf.reserve(20);
347 enc.buf.put_u32(1); enc.buf.put_u32(0); enc.buf.put_u32(0); enc.buf.put_u32(
351 self.len()
352 .try_into()
353 .map_err(|_| ClientEncodingError::with_message("array is too long"))?,
354 );
355 enc.buf.put_u32(1); for item in self {
357 enc.length_prefixed(|enc| item.encode(enc))?;
358 }
359 Ok(())
360 })
361 }
362 fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
363 let desc = ctx.get(pos)?;
364 if let Descriptor::Array(arr) = desc {
365 T::check_descriptor(ctx, arr.type_pos)
366 } else {
367 Err(ctx.wrong_type(desc, "array"))
368 }
369 }
370 fn to_value(&self) -> Result<Value, Error> {
371 Ok(Value::Array(
372 self.iter()
373 .map(|v| v.to_value())
374 .collect::<Result<_, _>>()?,
375 ))
376 }
377}
378
379impl QueryArg for Vec<Value> {
380 fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
381 enc.buf.reserve(8);
382 enc.length_prefixed(|enc| {
383 if self.is_empty() {
384 enc.buf.reserve(12);
385 enc.buf.put_u32(0); enc.buf.put_u32(0); enc.buf.put_u32(0); return Ok(());
389 }
390 enc.buf.reserve(20);
391 enc.buf.put_u32(1); enc.buf.put_u32(0); enc.buf.put_u32(0); enc.buf.put_u32(
395 self.len()
396 .try_into()
397 .map_err(|_| ClientEncodingError::with_message("array is too long"))?,
398 );
399 enc.buf.put_u32(1); for item in self {
401 enc.length_prefixed(|enc| item.encode(enc))?;
402 }
403 Ok(())
404 })
405 }
406 fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
407 let desc = ctx.get(pos)?;
408 if let Descriptor::Array(arr) = desc {
409 for val in self {
410 val.check_descriptor(ctx, arr.type_pos)?
411 }
412 Ok(())
413 } else {
414 Err(ctx.wrong_type(desc, "array"))
415 }
416 }
417 fn to_value(&self) -> Result<Value, Error> {
418 Ok(Value::Array(
419 self.iter()
420 .map(|v| v.to_value())
421 .collect::<Result<_, _>>()?,
422 ))
423 }
424}
425
426impl QueryArg for range::Range<Box<Value>> {
427 fn encode_slot(&self, encoder: &mut Encoder) -> Result<(), Error> {
428 encoder.length_prefixed(|encoder| {
429 let flags = if self.empty {
430 range::EMPTY
431 } else {
432 (if self.inc_lower { range::LB_INC } else { 0 })
433 | (if self.inc_upper { range::UB_INC } else { 0 })
434 | (if self.lower.is_none() {
435 range::LB_INF
436 } else {
437 0
438 })
439 | (if self.upper.is_none() {
440 range::UB_INF
441 } else {
442 0
443 })
444 };
445 encoder.buf.reserve(1);
446 encoder.buf.put_u8(flags as u8);
447
448 if let Some(lower) = &self.lower {
449 encoder.length_prefixed(|encoder| lower.encode(encoder))?
450 }
451
452 if let Some(upper) = &self.upper {
453 encoder.length_prefixed(|encoder| upper.encode(encoder))?;
454 }
455 Ok(())
456 })
457 }
458 fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> {
459 let desc = ctx.get(pos)?;
460 if let Descriptor::Range(rng) = desc {
461 self.lower
462 .as_ref()
463 .map(|v| v.check_descriptor(ctx, rng.type_pos))
464 .transpose()?;
465 self.upper
466 .as_ref()
467 .map(|v| v.check_descriptor(ctx, rng.type_pos))
468 .transpose()?;
469 Ok(())
470 } else {
471 Err(ctx.wrong_type(desc, "range"))
472 }
473 }
474 fn to_value(&self) -> Result<Value, Error> {
475 Ok(Value::Range(self.clone()))
476 }
477}
478
/// Implements [`QueryArgs`] for a tuple of `$count` positional arguments.
///
/// The root input descriptor is validated before anything is written:
/// on protocol >= 0.12 it must be an `ObjectShape` with exactly `$count`
/// elements named "0", "1", ... in order; on protocol <= 0.11 it must be a
/// `Tuple` with `$count` element types. Each argument is then checked with
/// [`QueryArg::check_descriptor`]. The output is the element count followed,
/// for each element, by a 4-byte zero (presumably the protocol's reserved
/// word — confirm) and the element's slot encoding.
macro_rules! implement_tuple {
    ( $count:expr, $($name:ident,)+ ) => {
        impl<$($name:QueryArg),+> QueryArgs for ($($name,)+) {
            fn encode(&self, enc: &mut Encoder)
                -> Result<(), Error>
            {
                #![allow(non_snake_case)]
                // No root descriptor means the server expects no arguments.
                let root_pos = enc.ctx.root_pos
                    .ok_or_else(|| DescriptorMismatch::with_message(
                        format!(
                            "provided {} positional arguments, \
                            but no arguments expected by the server",
                            $count)))?;
                let desc = enc.ctx.get(root_pos)?;
                match desc {
                    Descriptor::ObjectShape(desc)
                        if enc.ctx.proto.is_at_least(0, 12)
                    => {
                        if desc.elements.len() != $count {
                            return Err(enc.ctx.field_number(
                                desc.elements.len(), $count));
                        }
                        let mut els = desc.elements.iter().enumerate();
                        let ($(ref $name,)+) = self;
                        $(
                            // Cannot fail: the element count was checked above.
                            let (idx, el) = els.next().unwrap();
                            // Positional arguments are named by their index.
                            if el.name.parse() != Ok(idx) {
                                return Err(DescriptorMismatch::with_message(
                                    format!("expected positional arguments, \
                                        got {} instead of {}",
                                        el.name, idx)));
                            }
                            $name.check_descriptor(enc.ctx, el.type_pos)?;
                        )+
                    }
                    Descriptor::Tuple(desc) if enc.ctx.proto.is_at_most(0, 11)
                    => {
                        if desc.element_types.len() != $count {
                            return Err(enc.ctx.field_number(
                                desc.element_types.len(), $count));
                        }
                        let mut els = desc.element_types.iter();
                        let ($(ref $name,)+) = self;
                        $(
                            // Cannot fail: the element count was checked above.
                            let type_pos = els.next().unwrap();
                            $name.check_descriptor(enc.ctx, *type_pos)?;
                        )+
                    }
                    _ => return Err(enc.ctx.wrong_type(desc,
                        if enc.ctx.proto.is_at_least(0, 12) { "object" }
                        else { "tuple" }))
                }

                // Element count, then (zero word + slot) per element.
                enc.buf.reserve(4 + 8*$count);
                enc.buf.put_u32($count);
                let ($(ref $name,)+) = self;
                $(
                    enc.buf.reserve(8);
                    enc.buf.put_u32(0);
                    QueryArg::encode_slot($name, enc)?;
                )*
                Ok(())
            }
        }
    }
}

// Positional argument tuples of arity 1 through 12.
implement_tuple! {1, T0, }
implement_tuple! {2, T0, T1, }
implement_tuple! {3, T0, T1, T2, }
implement_tuple! {4, T0, T1, T2, T3, }
implement_tuple! {5, T0, T1, T2, T3, T4, }
implement_tuple! {6, T0, T1, T2, T3, T4, T5, }
implement_tuple! {7, T0, T1, T2, T3, T4, T5, T6, }
implement_tuple! {8, T0, T1, T2, T3, T4, T5, T6, T7, }
implement_tuple! {9, T0, T1, T2, T3, T4, T5, T6, T7, T8, }
implement_tuple! {10, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, }
implement_tuple! {11, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, }
implement_tuple! {12, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, }