// good_ormning/pg/query/expr.rs

1use {
2    chrono::FixedOffset,
3    quote::{
4        quote,
5        format_ident,
6        ToTokens,
7    },
8    samevariant::samevariant,
9    syn::Path,
10    std::{
11        collections::HashMap,
12        fmt::Display,
13        rc::Rc,
14    },
15    crate::{
16        pg::{
17            types::{
18                Type,
19                SimpleSimpleType,
20                SimpleType,
21                to_rust_types,
22            },
23            query::utils::QueryBody,
24            schema::{
25                field::{
26                    Field,
27                },
28            },
29            QueryResCount,
30        },
31        utils::{
32            Tokens,
33            Errs,
34            sanitize_ident,
35        },
36    },
37    super::{
38        utils::PgQueryCtx,
39        select::Select,
40    },
41};
42#[cfg(feature = "chrono")]
43use chrono::{
44    DateTime,
45    Utc,
46};
47#[cfg(feature = "jiff")]
48use jiff::{
49    Timestamp,
50};
51
52/// This is used for function expressions, to check the argument types and compute
53/// a result type from them.  See readme for details.
54#[derive(Clone)]
55pub struct ComputeType(Rc<dyn Fn(&mut PgQueryCtx, &rpds::Vector<String>, Vec<ExprType>) -> Option<Type>>);
56
57impl ComputeType {
58    pub fn new(
59        f: impl Fn(&mut PgQueryCtx, &rpds::Vector<String>, Vec<ExprType>) -> Option<Type> + 'static,
60    ) -> ComputeType {
61        return ComputeType(Rc::new(f));
62    }
63}
64
65impl std::fmt::Debug for ComputeType {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        return f.write_str("ComputeType");
68    }
69}
70
/// A SQL expression. `Expr::build` type-checks it and renders it to SQL text.
#[derive(Clone, Debug)]
pub enum Expr {
    /// A parenthesized, comma-separated list of expressions, rendered as
    /// `(a, b, c)`. Its type is the concatenation of the element types.
    LitArray(Vec<Expr>),
    /// The `null` literal. A null value needs a type for type checking purposes; it
    /// is always treated as an optional (nullable) value.
    LitNull(SimpleType),
    /// Rendered as `true` / `false`.
    LitBool(bool),
    /// An integer literal typed as `SimpleSimpleType::Auto`.
    LitAuto(i64),
    LitI32(i32),
    LitI64(i64),
    LitF32(f32),
    LitF64(f64),
    /// A string literal; single quotes are escaped by doubling them.
    LitString(String),
    /// A byte-array literal, rendered from its hex encoding.
    LitBytes(Vec<u8>),
    /// A timestamp literal rendered as a quoted RFC 3339 string.
    #[cfg(feature = "chrono")]
    LitUtcTimeChrono(DateTime<Utc>),
    /// A timestamp literal rendered as a quoted RFC 3339 string.
    #[cfg(feature = "chrono")]
    LitFixedOffsetTimeChrono(DateTime<FixedOffset>),
    /// A timestamp literal rendered as a quoted string (via `Timestamp`'s `Display`).
    #[cfg(feature = "jiff")]
    LitUtcTimeJiff(Timestamp),
    /// A query parameter. This will become a parameter to the generated Rust function
    /// with the specified `name` and `type_`. Using the same `name` more than once
    /// reuses one positional parameter (`$n`); the types must then agree.
    Param {
        name: String,
        type_: Type,
    },
    /// This evaluates to the value of a field in the query's main or joined tables.
    /// It is looked up in the current scope by the field's table id and field id,
    /// so if you've aliased tables or field names you'll need a `Field` whose ids
    /// match the aliases. For synthetic values like function results you may need a
    /// field whose table id is empty (`""`).
    Field(Field),
    /// `left op right`, rendered in parentheses.
    BinOp {
        left: Box<Expr>,
        op: BinOp,
        right: Box<Expr>,
    },
    /// This is the same as `BinOp` but allows chaining multiple expressions with the
    /// same operator. This can be useful if you have many successive `AND`s or similar.
    /// At least two expressions are required.
    BinOpChain {
        op: BinOp,
        exprs: Vec<Expr>,
    },
    /// `op right`, e.g. `not x`.
    PrefixOp {
        op: PrefixOp,
        right: Box<Expr>,
    },
    /// Represents a call to an SQL function, like `collate()`, rendered as
    /// `func(arg, ...)`. You must provide a helper to check and determine the type
    /// of the result since we don't have a table of functions and their return
    /// types at present.
    Call {
        func: String,
        args: Vec<Expr>,
        compute_type: ComputeType,
    },
    /// A sub SELECT query.
    Select(Box<Select>),
    /// This is a synthetic expression, saying to treat the result of the expression as
    /// having the specified type. Use this for casting between primitive types and
    /// Rust new-types for instance. No SQL cast is emitted: the inner expression's SQL
    /// is used unchanged, and the target type must be in the same general type family.
    Cast(Box<Expr>, Type),
}
132
/// Names an expression output or an in-scope value by table id and field id. It
/// is the key of the `scope` map consulted by `Expr::build`. Values not tied to a
/// table (literals, parameters, computed results) use an empty `table_id`.
#[derive(Clone, Hash, PartialEq, Eq, Debug)]
pub struct ExprValName {
    /// Id (or alias) of the table qualifying this name; `""` when unqualified.
    pub table_id: String,
    /// Field id; `""` for anonymous values.
    pub id: String,
}
138
139impl ExprValName {
140    pub(crate) fn local(name: String) -> Self {
141        ExprValName {
142            table_id: "".into(),
143            id: name,
144        }
145    }
146
147    pub(crate) fn empty() -> Self {
148        ExprValName {
149            table_id: "".into(),
150            id: "".into(),
151        }
152    }
153
154    pub(crate) fn field(f: &Field) -> Self {
155        ExprValName {
156            table_id: f.table.id.clone(),
157            id: f.id.clone(),
158        }
159    }
160
161    pub(crate) fn with_alias(&self, s: &str) -> ExprValName {
162        ExprValName {
163            table_id: s.into(),
164            id: self.id.clone(),
165        }
166    }
167}
168
169impl Display for ExprValName {
170    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171        Display::fmt(&format!("{}.{}", self.table_id, self.id), f)
172    }
173}
174
175pub struct ExprType(pub Vec<(ExprValName, Type)>);
176
177impl ExprType {
178    pub fn assert_scalar(&self, errs: &mut Errs, path: &rpds::Vector<String>) -> Option<(ExprValName, Type)> {
179        if self.0.len() != 1 {
180            errs.err(
181                path,
182                format!("Select outputs must be scalars, but got result with more than one field: {}", self.0.len()),
183            );
184            return None;
185        }
186        Some(self.0[0].clone())
187    }
188}
189
/// Coarse type families used to decide whether two operands can be compared.
///
/// `#[samevariant(GeneralTypePairs)]` derives the `GeneralTypePairs` helper, whose
/// `pairs`/`Nonmatching` are used by `check_general_same_type` to detect operands
/// from different families.
#[derive(Debug)]
#[samevariant(GeneralTypePairs)]
pub(crate) enum GeneralType {
    /// Booleans.
    Bool,
    /// Integers, floats and timestamps (see `general_type`).
    Numeric,
    /// Strings and byte arrays.
    Blob,
}
197
198pub(crate) fn general_type(t: &Type) -> GeneralType {
199    match t.type_.type_ {
200        SimpleSimpleType::Auto => GeneralType::Numeric,
201        SimpleSimpleType::I32 => GeneralType::Numeric,
202        SimpleSimpleType::I64 => GeneralType::Numeric,
203        SimpleSimpleType::F32 => GeneralType::Numeric,
204        SimpleSimpleType::F64 => GeneralType::Numeric,
205        SimpleSimpleType::Bool => GeneralType::Bool,
206        SimpleSimpleType::String => GeneralType::Blob,
207        SimpleSimpleType::Bytes => GeneralType::Blob,
208        #[cfg(feature = "chrono")]
209        SimpleSimpleType::UtcTimeChrono => GeneralType::Numeric,
210        #[cfg(feature = "chrono")]
211        SimpleSimpleType::FixedOffsetTimeChrono => GeneralType::Numeric,
212        #[cfg(feature = "jiff")]
213        SimpleSimpleType::UtcTimeJiff => GeneralType::Numeric,
214    }
215}
216
217pub fn check_general_same_type(ctx: &mut PgQueryCtx, path: &rpds::Vector<String>, left: &Type, right: &Type) {
218    if left.opt != right.opt {
219        ctx.errs.err(path, format!("Operator arms have differing optionality"));
220    }
221    match GeneralTypePairs::pairs(&general_type(left), &general_type(right)) {
222        GeneralTypePairs::Nonmatching(left, right) => {
223            ctx.errs.err(path, format!("Operator arms have incompatible types: {:?} and {:?}", left, right));
224        },
225        _ => { },
226    }
227}
228
229pub(crate) fn check_general_same(
230    ctx: &mut PgQueryCtx,
231    path: &rpds::Vector<String>,
232    left: &ExprType,
233    right: &ExprType,
234) {
235    if left.0.len() != right.0.len() {
236        ctx
237            .errs
238            .err(
239                path,
240                format!(
241                    "Operator arms record type lengths don't match: left has {} fields and right has {}",
242                    left.0.len(),
243                    right.0.len()
244                ),
245            );
246    } else if left.0.len() == 1 && right.0.len() == 1 {
247        check_general_same_type(ctx, path, &left.0[0].1, &left.0[0].1);
248    } else {
249        for (i, (left, right)) in left.0.iter().zip(right.0.iter()).enumerate() {
250            check_general_same_type(ctx, &path.push_back(format!("Record pair {}", i)), &left.1, &right.1);
251        }
252    }
253}
254
255pub(crate) fn check_same(
256    errs: &mut Errs,
257    path: &rpds::Vector<String>,
258    left: &ExprType,
259    right: &ExprType,
260) -> Option<Type> {
261    let left = match left.assert_scalar(errs, &path.push_back("Left".into())) {
262        Some(t) => t,
263        None => {
264            return None;
265        },
266    };
267    let right = match right.assert_scalar(errs, &path.push_back("Right".into())) {
268        Some(t) => t,
269        None => {
270            return None;
271        },
272    };
273    if left.1.opt != right.1.opt {
274        errs.err(
275            path,
276            format!(
277                "Expected same types, but left nullability is {} but right nullability is {}",
278                left.1.opt,
279                right.1.opt
280            ),
281        );
282    }
283    if left.1.type_.custom != right.1.type_.custom {
284        errs.err(
285            path,
286            format!(
287                "Expected same types, but left rust type is {:?} while right rust type is {:?}",
288                left.1.type_.custom,
289                right.1.type_.custom
290            ),
291        );
292    }
293    if left.1.type_.type_ != right.1.type_.type_ {
294        errs.err(
295            path,
296            format!(
297                "Expected same types, but left base type is {:?} while right base type is {:?}",
298                left.1.type_.type_,
299                right.1.type_.type_
300            ),
301        );
302    }
303    Some(left.1.clone())
304}
305
306pub(crate) fn check_bool(ctx: &mut PgQueryCtx, path: &rpds::Vector<String>, a: &ExprType) {
307    let t = match a.assert_scalar(&mut ctx.errs, path) {
308        Some(t) => t,
309        None => {
310            return;
311        },
312    };
313    if t.1.opt {
314        ctx.errs.err(path, format!("Expected bool type but is nullable: got {:?}", t));
315    }
316    if !matches!(t.1.type_.type_, SimpleSimpleType::Bool) {
317        ctx.errs.err(path, format!("Expected bool but type is non-bool: got {:?}", t.1.type_.type_));
318    }
319}
320
321pub(crate) fn check_assignable(errs: &mut Errs, path: &rpds::Vector<String>, a: &Type, b: &ExprType) {
322    check_same(errs, path, &ExprType(vec![(ExprValName::empty(), a.clone())]), b);
323}
324
325impl Expr {
326    pub(crate) fn build(
327        &self,
328        ctx: &mut PgQueryCtx,
329        path: &rpds::Vector<String>,
330        scope: &HashMap<ExprValName, Type>,
331    ) -> (ExprType, Tokens) {
332        macro_rules! empty_type{
333            ($o: expr, $t: expr) => {
334                (ExprType(vec![(ExprValName::empty(), Type {
335                    type_: SimpleType {
336                        type_: $t,
337                        custom: None,
338                    },
339                    opt: false,
340                })]), $o)
341            };
342        }
343
344        fn do_bin_op(
345            ctx: &mut PgQueryCtx,
346            path: &rpds::Vector<String>,
347            scope: &HashMap<ExprValName, Type>,
348            op: &BinOp,
349            exprs: &Vec<Expr>,
350        ) -> (ExprType, Tokens) {
351            if exprs.len() < 2 {
352                ctx.errs.err(path, format!("Binary ops must have at least two operands, but got {}", exprs.len()));
353            }
354            let mut res = vec![];
355            for (i, e) in exprs.iter().enumerate() {
356                res.push(e.build(ctx, &path.push_back(format!("Operand {}", i)), scope));
357            }
358            let t = match op {
359                BinOp::Plus | BinOp::Minus | BinOp::Multiply | BinOp::Divide => {
360                    let base = res.get(0).unwrap();
361                    let t =
362                        match check_same(
363                            &mut ctx.errs,
364                            &path.push_back(format!("Operands 0, 1")),
365                            &base.0,
366                            &res.get(0).unwrap().0,
367                        ) {
368                            Some(t) => t,
369                            None => {
370                                return (ExprType(vec![]), Tokens::new());
371                            },
372                        };
373                    for (i, res) in res.iter().enumerate().skip(2) {
374                        match check_same(
375                            &mut ctx.errs,
376                            &path.push_back(format!("Operands 0, {}", i)),
377                            &base.0,
378                            &res.0,
379                        ) {
380                            Some(_) => { },
381                            None => {
382                                return (ExprType(vec![]), Tokens::new());
383                            },
384                        };
385                    }
386                    t
387                },
388                BinOp::And | BinOp::Or => {
389                    for (i, res) in res.iter().enumerate() {
390                        check_bool(ctx, &path.push_back(format!("Operand {}", i)), &res.0);
391                    }
392                    Type {
393                        type_: SimpleType {
394                            type_: SimpleSimpleType::Bool,
395                            custom: None,
396                        },
397                        opt: false,
398                    }
399                },
400                BinOp::Equals |
401                BinOp::NotEquals |
402                BinOp::Is |
403                BinOp::IsNot |
404                BinOp::LessThan |
405                BinOp::LessThanEqualTo |
406                BinOp::GreaterThan |
407                BinOp::GreaterThanEqualTo => {
408                    let base = res.get(0).unwrap();
409                    check_general_same(
410                        ctx,
411                        &path.push_back(format!("Operands 0, 1")),
412                        &base.0,
413                        &res.get(1).unwrap().0,
414                    );
415                    for (i, res) in res.iter().enumerate().skip(2) {
416                        check_general_same(ctx, &path.push_back(format!("Operands 0, {}", i)), &base.0, &res.0);
417                    }
418                    Type {
419                        type_: SimpleType {
420                            type_: SimpleSimpleType::Bool,
421                            custom: None,
422                        },
423                        opt: false,
424                    }
425                },
426            };
427            let token = match op {
428                BinOp::Plus => "+",
429                BinOp::Minus => "-",
430                BinOp::Multiply => "*",
431                BinOp::Divide => "/",
432                BinOp::And => "and",
433                BinOp::Or => "or",
434                BinOp::Equals => "=",
435                BinOp::NotEquals => "!=",
436                BinOp::Is => "is",
437                BinOp::IsNot => "is not",
438                BinOp::LessThan => "<",
439                BinOp::LessThanEqualTo => "<=",
440                BinOp::GreaterThan => ">",
441                BinOp::GreaterThanEqualTo => ">=",
442            };
443            let mut out = Tokens::new();
444            out.s("(");
445            for (i, res) in res.iter().enumerate() {
446                if i > 0 {
447                    out.s(token);
448                }
449                out.s(&res.1.to_string());
450            }
451            out.s(")");
452            (ExprType(vec![(ExprValName::empty(), t)]), out)
453        }
454
455        match self {
456            Expr::LitArray(t) => {
457                let mut out = Tokens::new();
458                let mut child_types = vec![];
459                out.s("(");
460                for (i, child) in t.iter().enumerate() {
461                    if i > 0 {
462                        out.s(", ");
463                    }
464                    let (child_type, child_tokens) = child.build(ctx, path, scope);
465                    out.s(&child_tokens.to_string());
466                    child_types.extend(child_type.0);
467                }
468                out.s(")");
469                return (ExprType(child_types), out);
470            },
471            Expr::LitNull(t) => {
472                let mut out = Tokens::new();
473                out.s("null");
474                return (ExprType(vec![(ExprValName::empty(), Type {
475                    type_: t.clone(),
476                    opt: true,
477                })]), out);
478            },
479            Expr::LitBool(x) => {
480                let mut out = Tokens::new();
481                out.s(if *x {
482                    "true"
483                } else {
484                    "false"
485                });
486                return empty_type!(out, SimpleSimpleType::Bool);
487            },
488            Expr::LitAuto(x) => {
489                let mut out = Tokens::new();
490                out.s(&x.to_string());
491                return empty_type!(out, SimpleSimpleType::Auto);
492            },
493            Expr::LitI32(x) => {
494                let mut out = Tokens::new();
495                out.s(&x.to_string());
496                return empty_type!(out, SimpleSimpleType::I32);
497            },
498            Expr::LitI64(x) => {
499                let mut out = Tokens::new();
500                out.s(&x.to_string());
501                return empty_type!(out, SimpleSimpleType::I64);
502            },
503            Expr::LitF32(x) => {
504                let mut out = Tokens::new();
505                out.s(&x.to_string());
506                return empty_type!(out, SimpleSimpleType::F32);
507            },
508            Expr::LitF64(x) => {
509                let mut out = Tokens::new();
510                out.s(&x.to_string());
511                return empty_type!(out, SimpleSimpleType::F64);
512            },
513            Expr::LitString(x) => {
514                let mut out = Tokens::new();
515                out.s(&format!("'{}'", x.replace("'", "''")));
516                return empty_type!(out, SimpleSimpleType::String);
517            },
518            Expr::LitBytes(x) => {
519                let mut out = Tokens::new();
520                let h = hex::encode(&x);
521                out.s(&format!("x'{}'", h));
522                return empty_type!(out, SimpleSimpleType::Bytes);
523            },
524            #[cfg(feature = "chrono")]
525            Expr::LitUtcTimeChrono(d) => {
526                let mut out = Tokens::new();
527                let d = d.to_rfc3339();
528                out.s(&format!("'{}'", d));
529                return empty_type!(out, SimpleSimpleType::UtcTimeChrono);
530            },
531            #[cfg(feature = "chrono")]
532            Expr::LitFixedOffsetTimeChrono(d) => {
533                let mut out = Tokens::new();
534                let d = d.to_rfc3339();
535                out.s(&format!("'{}'", d));
536                return empty_type!(out, SimpleSimpleType::FixedOffsetTimeChrono);
537            },
538            #[cfg(feature = "jiff")]
539            Expr::LitUtcTimeJiff(d) => {
540                let mut out = Tokens::new();
541                let d = d.to_string();
542                out.s(&format!("'{}'", d));
543                return empty_type!(out, SimpleSimpleType::UtcTimeChrono);
544            },
545            Expr::Param { name: x, type_: t } => {
546                let path = path.push_back(format!("Param ({})", x));
547                let mut out = Tokens::new();
548                let mut errs = vec![];
549                let i = match ctx.rust_arg_lookup.entry(x.clone()) {
550                    std::collections::hash_map::Entry::Occupied(e) => {
551                        let (i, prev_t) = e.get();
552                        if t != prev_t {
553                            errs.push(
554                                format!("Parameter {} specified with multiple types: {:?}, {:?}", x, t, prev_t),
555                            );
556                        }
557                        *i
558                    },
559                    std::collections::hash_map::Entry::Vacant(e) => {
560                        let i = ctx.query_args.len();
561                        e.insert((i, t.clone()));
562                        let rust_types = to_rust_types(&t.type_.type_);
563                        let custom_trait_ident = rust_types.custom_trait;
564                        let rust_type = rust_types.arg_type;
565                        let ident = format_ident!("{}", sanitize_ident(x).1);
566                        let (mut rust_type, mut rust_forward) = if let Some(custom) = &t.type_.custom {
567                            let custom_ident = match syn::parse_str::<Path>(custom.as_str()) {
568                                Ok(p) => p,
569                                Err(e) => {
570                                    ctx.errs.err(&path, format!("Couldn't parse custom type {}: {:?}", custom, e));
571                                    return (ExprType(vec![]), Tokens::new());
572                                },
573                            }.to_token_stream();
574                            let forward =
575                                quote!(< #custom_ident as #custom_trait_ident < #custom_ident >>:: to_sql(& #ident));
576                            (quote!(& #custom_ident), forward)
577                        } else {
578                            (rust_type, quote!(#ident))
579                        };
580                        if t.opt {
581                            rust_type = quote!(Option < #rust_type >);
582                            rust_forward = quote!(#ident.map(| #ident | #rust_forward));
583                        }
584                        ctx.rust_args.push(quote!(#ident: #rust_type));
585                        ctx.query_args.push(quote!(#rust_forward));
586                        i
587                    },
588                };
589                for e in errs {
590                    ctx.errs.err(&path, e);
591                }
592                out.s(&format!("${}", i + 1));
593                return (ExprType(vec![(ExprValName::local(x.clone()), t.clone())]), out);
594            },
595            Expr::Field(x) => {
596                let name = ExprValName::field(x);
597                let t = match scope.get(&name) {
598                    Some(t) => t.clone(),
599                    None => {
600                        ctx
601                            .errs
602                            .err(
603                                path,
604                                format!(
605                                    "Expression references {} but this field isn't available here (available fields: {:?})",
606                                    x,
607                                    scope.iter().map(|e| e.0.to_string()).collect::<Vec<String>>()
608                                ),
609                            );
610                        return (ExprType(vec![]), Tokens::new());
611                    },
612                };
613                let mut out = Tokens::new();
614                out.id(&x.table.id).s(".").id(&x.id);
615                return (ExprType(vec![(name, t.clone())]), out);
616            },
617            Expr::BinOp { left, op, right } => {
618                return do_bin_op(
619                    ctx,
620                    &path.push_back(format!("Bin op {:?}", op)),
621                    scope,
622                    op,
623                    &vec![left.as_ref().clone(), right.as_ref().clone()],
624                );
625            },
626            Expr::BinOpChain { op, exprs } => {
627                return do_bin_op(ctx, &path.push_back(format!("Chain bin op {:?}", op)), scope, op, exprs);
628            },
629            Expr::PrefixOp { op, right } => {
630                let path = path.push_back(format!("Prefix op {:?}", op));
631                let mut out = Tokens::new();
632                let res = right.build(ctx, &path, scope);
633                let (op_text, op_type) = match op {
634                    PrefixOp::Not => {
635                        check_bool(ctx, &path, &res.0);
636                        ("not", SimpleSimpleType::Bool)
637                    },
638                };
639                out.s(op_text).s(&res.1.to_string());
640                return empty_type!(out, op_type);
641            },
642            Expr::Call { func, args, compute_type } => {
643                let mut types = vec![];
644                let mut out = Tokens::new();
645                out.s(func);
646                out.s("(");
647                for (i, arg) in args.iter().enumerate() {
648                    if i > 0 {
649                        out.s(",");
650                    }
651                    let (arg_type, tokens) =
652                        arg.build(ctx, &path.push_back(format!("Call [{}] arg {}", func, i)), scope);
653                    types.push(arg_type);
654                    out.s(&tokens.to_string());
655                }
656                out.s(")");
657                let type_ = match (compute_type.0)(ctx, path, types) {
658                    Some(t) => t,
659                    None => {
660                        return (ExprType(vec![]), Tokens::new());
661                    },
662                };
663                return (ExprType(vec![(ExprValName::empty(), type_)]), out);
664            },
665            Expr::Select(s) => {
666                let path = path.push_back(format!("Subselect"));
667                return s.build(ctx, &path, QueryResCount::Many);
668            },
669            Expr::Cast(e, t) => {
670                let path = path.push_back(format!("Cast"));
671                let out = e.build(ctx, &path, scope);
672                let got_t = match out.0.assert_scalar(&mut ctx.errs, &path) {
673                    Some(t) => t,
674                    None => {
675                        return (ExprType(vec![]), Tokens::new());
676                    },
677                };
678                check_general_same_type(ctx, &path, t, &got_t.1);
679                return (ExprType(vec![(got_t.0, t.clone())]), out.1);
680            },
681        };
682    }
683}
684
/// Binary SQL operators usable in `Expr::BinOp` and `Expr::BinOpChain`.
#[derive(Debug, Clone)]
pub enum BinOp {
    /// `+`; all operands must have exactly the same type.
    Plus,
    /// `-`; all operands must have exactly the same type.
    Minus,
    /// `*`; all operands must have exactly the same type.
    Multiply,
    /// `/`; all operands must have exactly the same type.
    Divide,
    /// `and`; operands must be non-nullable booleans.
    And,
    /// `or`; operands must be non-nullable booleans.
    Or,
    /// `=`
    Equals,
    /// `!=`
    NotEquals,
    /// `is`
    Is,
    /// `is not`
    IsNot,
    /// `<`
    LessThan,
    /// `<=`
    LessThanEqualTo,
    /// `>`
    GreaterThan,
    /// `>=`
    GreaterThanEqualTo,
}
702
/// Unary SQL operators usable in `Expr::PrefixOp`.
#[derive(Debug, Clone)]
pub enum PrefixOp {
    /// `not`; the operand must be a non-nullable boolean.
    Not,
}