1use {
2 chrono::FixedOffset,
3 quote::{
4 quote,
5 format_ident,
6 ToTokens,
7 },
8 samevariant::samevariant,
9 syn::Path,
10 std::{
11 collections::HashMap,
12 fmt::Display,
13 rc::Rc,
14 },
15 crate::{
16 pg::{
17 types::{
18 Type,
19 SimpleSimpleType,
20 SimpleType,
21 to_rust_types,
22 },
23 query::utils::QueryBody,
24 schema::{
25 field::{
26 Field,
27 },
28 },
29 QueryResCount,
30 },
31 utils::{
32 Tokens,
33 Errs,
34 sanitize_ident,
35 },
36 },
37 super::{
38 utils::PgQueryCtx,
39 select::Select,
40 },
41};
42#[cfg(feature = "chrono")]
43use chrono::{
44 DateTime,
45 Utc,
46};
47#[cfg(feature = "jiff")]
48use jiff::{
49 Timestamp,
50};
51
/// Shared, cloneable callback that computes the result type of an `Expr::Call`.
///
/// The closure receives the query context, the error path, and the already-built types of
/// the call arguments (in argument order). Returning `None` makes `Expr::build` yield an
/// empty result for the call.
#[derive(Clone)]
pub struct ComputeType(Rc<dyn Fn(&mut PgQueryCtx, &rpds::Vector<String>, Vec<ExprType>) -> Option<Type>>);
56
57impl ComputeType {
58 pub fn new(
59 f: impl Fn(&mut PgQueryCtx, &rpds::Vector<String>, Vec<ExprType>) -> Option<Type> + 'static,
60 ) -> ComputeType {
61 return ComputeType(Rc::new(f));
62 }
63}
64
65impl std::fmt::Debug for ComputeType {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 return f.write_str("ComputeType");
68 }
69}
70
/// A node of a query expression tree. `Expr::build` type-checks a node and renders it to
/// SQL text.
#[derive(Clone, Debug)]
pub enum Expr {
    /// Parenthesized, comma-separated list of child expressions; its type is the
    /// concatenation of the child types.
    LitArray(Vec<Expr>),
    /// The literal `null`, typed as the given simple type and marked optional.
    LitNull(SimpleType),
    /// `true` / `false`.
    LitBool(bool),
    /// Integer literal typed as `SimpleSimpleType::Auto`.
    LitAuto(i64),
    /// Integer literal typed as `SimpleSimpleType::I32`.
    LitI32(i32),
    /// Integer literal typed as `SimpleSimpleType::I64`.
    LitI64(i64),
    /// Float literal typed as `SimpleSimpleType::F32`.
    LitF32(f32),
    /// Float literal typed as `SimpleSimpleType::F64`.
    LitF64(f64),
    /// String literal; rendered single-quoted with embedded single quotes doubled.
    LitString(String),
    /// Byte literal; rendered as hex inside `x'...'`.
    LitBytes(Vec<u8>),
    /// UTC timestamp literal; rendered as a quoted RFC 3339 string.
    #[cfg(feature = "chrono")]
    LitUtcTimeChrono(DateTime<Utc>),
    /// Fixed-offset timestamp literal; rendered as a quoted RFC 3339 string.
    #[cfg(feature = "chrono")]
    LitFixedOffsetTimeChrono(DateTime<FixedOffset>),
    /// UTC timestamp literal; rendered as the quoted `to_string()` of the timestamp.
    #[cfg(feature = "jiff")]
    LitUtcTimeJiff(Timestamp),
    /// Query parameter. Rendered as a `$N` placeholder; the first use of a name also adds
    /// an argument to the generated Rust function. Reusing a name with a different type
    /// is reported as an error.
    Param {
        name: String,
        type_: Type,
    },
    /// Reference to a field; it must be present in the scope passed to `build`.
    Field(Field),
    /// Binary operation on exactly two operands.
    BinOp {
        left: Box<Expr>,
        op: BinOp,
        right: Box<Expr>,
    },
    /// The same binary operator applied between every adjacent pair of `exprs`
    /// (at least two operands are required).
    BinOpChain {
        op: BinOp,
        exprs: Vec<Expr>,
    },
    /// Prefix operator applied to `right`.
    PrefixOp {
        op: PrefixOp,
        right: Box<Expr>,
    },
    /// Function call rendered as `func(args...)`; `compute_type` derives the result type
    /// from the argument types.
    Call {
        func: String,
        args: Vec<Expr>,
        compute_type: ComputeType,
    },
    /// Sub-select; built with `QueryResCount::Many`.
    Select(Box<Select>),
    /// Re-types the inner scalar expression as the given type after checking the two share
    /// a general type class. `build` emits the inner expression's SQL unchanged.
    Cast(Box<Expr>, Type),
}
132
/// Name of a value produced by an expression, used as the key of the field scope.
///
/// Both parts are empty strings for anonymous values (see `ExprValName::empty`).
#[derive(Clone, Hash, PartialEq, Eq, Debug)]
pub struct ExprValName {
    /// Id of the table (or alias) the value belongs to; empty for local/anonymous values.
    pub table_id: String,
    /// Id of the field or parameter; empty for anonymous values.
    pub id: String,
}
138
139impl ExprValName {
140 pub(crate) fn local(name: String) -> Self {
141 ExprValName {
142 table_id: "".into(),
143 id: name,
144 }
145 }
146
147 pub(crate) fn empty() -> Self {
148 ExprValName {
149 table_id: "".into(),
150 id: "".into(),
151 }
152 }
153
154 pub(crate) fn field(f: &Field) -> Self {
155 ExprValName {
156 table_id: f.table.id.clone(),
157 id: f.id.clone(),
158 }
159 }
160
161 pub(crate) fn with_alias(&self, s: &str) -> ExprValName {
162 ExprValName {
163 table_id: s.into(),
164 id: self.id.clone(),
165 }
166 }
167}
168
169impl Display for ExprValName {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 Display::fmt(&format!("{}.{}", self.table_id, self.id), f)
172 }
173}
174
/// Result type of an expression: an ordered list of named, typed fields. Scalar
/// expressions have exactly one entry (see `assert_scalar`); the build code returns an
/// empty list after reporting an error.
pub struct ExprType(pub Vec<(ExprValName, Type)>);
176
177impl ExprType {
178 pub fn assert_scalar(&self, errs: &mut Errs, path: &rpds::Vector<String>) -> Option<(ExprValName, Type)> {
179 if self.0.len() != 1 {
180 errs.err(
181 path,
182 format!("Select outputs must be scalars, but got result with more than one field: {}", self.0.len()),
183 );
184 return None;
185 }
186 Some(self.0[0].clone())
187 }
188}
189
/// Coarse type class used for the loose compatibility checks on comparison operators and
/// casts (see `general_type` for the mapping from concrete types).
///
/// The `samevariant` attribute names the `GeneralTypePairs` type that
/// `check_general_same_type` uses to detect mismatched classes.
#[derive(Debug)]
#[samevariant(GeneralTypePairs)]
pub(crate) enum GeneralType {
    /// Boolean values.
    Bool,
    /// Integer, float and timestamp values.
    Numeric,
    /// String and byte values.
    Blob,
}
197
198pub(crate) fn general_type(t: &Type) -> GeneralType {
199 match t.type_.type_ {
200 SimpleSimpleType::Auto => GeneralType::Numeric,
201 SimpleSimpleType::I32 => GeneralType::Numeric,
202 SimpleSimpleType::I64 => GeneralType::Numeric,
203 SimpleSimpleType::F32 => GeneralType::Numeric,
204 SimpleSimpleType::F64 => GeneralType::Numeric,
205 SimpleSimpleType::Bool => GeneralType::Bool,
206 SimpleSimpleType::String => GeneralType::Blob,
207 SimpleSimpleType::Bytes => GeneralType::Blob,
208 #[cfg(feature = "chrono")]
209 SimpleSimpleType::UtcTimeChrono => GeneralType::Numeric,
210 #[cfg(feature = "chrono")]
211 SimpleSimpleType::FixedOffsetTimeChrono => GeneralType::Numeric,
212 #[cfg(feature = "jiff")]
213 SimpleSimpleType::UtcTimeJiff => GeneralType::Numeric,
214 }
215}
216
217pub fn check_general_same_type(ctx: &mut PgQueryCtx, path: &rpds::Vector<String>, left: &Type, right: &Type) {
218 if left.opt != right.opt {
219 ctx.errs.err(path, format!("Operator arms have differing optionality"));
220 }
221 match GeneralTypePairs::pairs(&general_type(left), &general_type(right)) {
222 GeneralTypePairs::Nonmatching(left, right) => {
223 ctx.errs.err(path, format!("Operator arms have incompatible types: {:?} and {:?}", left, right));
224 },
225 _ => { },
226 }
227}
228
229pub(crate) fn check_general_same(
230 ctx: &mut PgQueryCtx,
231 path: &rpds::Vector<String>,
232 left: &ExprType,
233 right: &ExprType,
234) {
235 if left.0.len() != right.0.len() {
236 ctx
237 .errs
238 .err(
239 path,
240 format!(
241 "Operator arms record type lengths don't match: left has {} fields and right has {}",
242 left.0.len(),
243 right.0.len()
244 ),
245 );
246 } else if left.0.len() == 1 && right.0.len() == 1 {
247 check_general_same_type(ctx, path, &left.0[0].1, &left.0[0].1);
248 } else {
249 for (i, (left, right)) in left.0.iter().zip(right.0.iter()).enumerate() {
250 check_general_same_type(ctx, &path.push_back(format!("Record pair {}", i)), &left.1, &right.1);
251 }
252 }
253}
254
255pub(crate) fn check_same(
256 errs: &mut Errs,
257 path: &rpds::Vector<String>,
258 left: &ExprType,
259 right: &ExprType,
260) -> Option<Type> {
261 let left = match left.assert_scalar(errs, &path.push_back("Left".into())) {
262 Some(t) => t,
263 None => {
264 return None;
265 },
266 };
267 let right = match right.assert_scalar(errs, &path.push_back("Right".into())) {
268 Some(t) => t,
269 None => {
270 return None;
271 },
272 };
273 if left.1.opt != right.1.opt {
274 errs.err(
275 path,
276 format!(
277 "Expected same types, but left nullability is {} but right nullability is {}",
278 left.1.opt,
279 right.1.opt
280 ),
281 );
282 }
283 if left.1.type_.custom != right.1.type_.custom {
284 errs.err(
285 path,
286 format!(
287 "Expected same types, but left rust type is {:?} while right rust type is {:?}",
288 left.1.type_.custom,
289 right.1.type_.custom
290 ),
291 );
292 }
293 if left.1.type_.type_ != right.1.type_.type_ {
294 errs.err(
295 path,
296 format!(
297 "Expected same types, but left base type is {:?} while right base type is {:?}",
298 left.1.type_.type_,
299 right.1.type_.type_
300 ),
301 );
302 }
303 Some(left.1.clone())
304}
305
306pub(crate) fn check_bool(ctx: &mut PgQueryCtx, path: &rpds::Vector<String>, a: &ExprType) {
307 let t = match a.assert_scalar(&mut ctx.errs, path) {
308 Some(t) => t,
309 None => {
310 return;
311 },
312 };
313 if t.1.opt {
314 ctx.errs.err(path, format!("Expected bool type but is nullable: got {:?}", t));
315 }
316 if !matches!(t.1.type_.type_, SimpleSimpleType::Bool) {
317 ctx.errs.err(path, format!("Expected bool but type is non-bool: got {:?}", t.1.type_.type_));
318 }
319}
320
321pub(crate) fn check_assignable(errs: &mut Errs, path: &rpds::Vector<String>, a: &Type, b: &ExprType) {
322 check_same(errs, path, &ExprType(vec![(ExprValName::empty(), a.clone())]), b);
323}
324
325impl Expr {
326 pub(crate) fn build(
327 &self,
328 ctx: &mut PgQueryCtx,
329 path: &rpds::Vector<String>,
330 scope: &HashMap<ExprValName, Type>,
331 ) -> (ExprType, Tokens) {
332 macro_rules! empty_type{
333 ($o: expr, $t: expr) => {
334 (ExprType(vec![(ExprValName::empty(), Type {
335 type_: SimpleType {
336 type_: $t,
337 custom: None,
338 },
339 opt: false,
340 })]), $o)
341 };
342 }
343
344 fn do_bin_op(
345 ctx: &mut PgQueryCtx,
346 path: &rpds::Vector<String>,
347 scope: &HashMap<ExprValName, Type>,
348 op: &BinOp,
349 exprs: &Vec<Expr>,
350 ) -> (ExprType, Tokens) {
351 if exprs.len() < 2 {
352 ctx.errs.err(path, format!("Binary ops must have at least two operands, but got {}", exprs.len()));
353 }
354 let mut res = vec![];
355 for (i, e) in exprs.iter().enumerate() {
356 res.push(e.build(ctx, &path.push_back(format!("Operand {}", i)), scope));
357 }
358 let t = match op {
359 BinOp::Plus | BinOp::Minus | BinOp::Multiply | BinOp::Divide => {
360 let base = res.get(0).unwrap();
361 let t =
362 match check_same(
363 &mut ctx.errs,
364 &path.push_back(format!("Operands 0, 1")),
365 &base.0,
366 &res.get(0).unwrap().0,
367 ) {
368 Some(t) => t,
369 None => {
370 return (ExprType(vec![]), Tokens::new());
371 },
372 };
373 for (i, res) in res.iter().enumerate().skip(2) {
374 match check_same(
375 &mut ctx.errs,
376 &path.push_back(format!("Operands 0, {}", i)),
377 &base.0,
378 &res.0,
379 ) {
380 Some(_) => { },
381 None => {
382 return (ExprType(vec![]), Tokens::new());
383 },
384 };
385 }
386 t
387 },
388 BinOp::And | BinOp::Or => {
389 for (i, res) in res.iter().enumerate() {
390 check_bool(ctx, &path.push_back(format!("Operand {}", i)), &res.0);
391 }
392 Type {
393 type_: SimpleType {
394 type_: SimpleSimpleType::Bool,
395 custom: None,
396 },
397 opt: false,
398 }
399 },
400 BinOp::Equals |
401 BinOp::NotEquals |
402 BinOp::Is |
403 BinOp::IsNot |
404 BinOp::LessThan |
405 BinOp::LessThanEqualTo |
406 BinOp::GreaterThan |
407 BinOp::GreaterThanEqualTo => {
408 let base = res.get(0).unwrap();
409 check_general_same(
410 ctx,
411 &path.push_back(format!("Operands 0, 1")),
412 &base.0,
413 &res.get(1).unwrap().0,
414 );
415 for (i, res) in res.iter().enumerate().skip(2) {
416 check_general_same(ctx, &path.push_back(format!("Operands 0, {}", i)), &base.0, &res.0);
417 }
418 Type {
419 type_: SimpleType {
420 type_: SimpleSimpleType::Bool,
421 custom: None,
422 },
423 opt: false,
424 }
425 },
426 };
427 let token = match op {
428 BinOp::Plus => "+",
429 BinOp::Minus => "-",
430 BinOp::Multiply => "*",
431 BinOp::Divide => "/",
432 BinOp::And => "and",
433 BinOp::Or => "or",
434 BinOp::Equals => "=",
435 BinOp::NotEquals => "!=",
436 BinOp::Is => "is",
437 BinOp::IsNot => "is not",
438 BinOp::LessThan => "<",
439 BinOp::LessThanEqualTo => "<=",
440 BinOp::GreaterThan => ">",
441 BinOp::GreaterThanEqualTo => ">=",
442 };
443 let mut out = Tokens::new();
444 out.s("(");
445 for (i, res) in res.iter().enumerate() {
446 if i > 0 {
447 out.s(token);
448 }
449 out.s(&res.1.to_string());
450 }
451 out.s(")");
452 (ExprType(vec![(ExprValName::empty(), t)]), out)
453 }
454
455 match self {
456 Expr::LitArray(t) => {
457 let mut out = Tokens::new();
458 let mut child_types = vec![];
459 out.s("(");
460 for (i, child) in t.iter().enumerate() {
461 if i > 0 {
462 out.s(", ");
463 }
464 let (child_type, child_tokens) = child.build(ctx, path, scope);
465 out.s(&child_tokens.to_string());
466 child_types.extend(child_type.0);
467 }
468 out.s(")");
469 return (ExprType(child_types), out);
470 },
471 Expr::LitNull(t) => {
472 let mut out = Tokens::new();
473 out.s("null");
474 return (ExprType(vec![(ExprValName::empty(), Type {
475 type_: t.clone(),
476 opt: true,
477 })]), out);
478 },
479 Expr::LitBool(x) => {
480 let mut out = Tokens::new();
481 out.s(if *x {
482 "true"
483 } else {
484 "false"
485 });
486 return empty_type!(out, SimpleSimpleType::Bool);
487 },
488 Expr::LitAuto(x) => {
489 let mut out = Tokens::new();
490 out.s(&x.to_string());
491 return empty_type!(out, SimpleSimpleType::Auto);
492 },
493 Expr::LitI32(x) => {
494 let mut out = Tokens::new();
495 out.s(&x.to_string());
496 return empty_type!(out, SimpleSimpleType::I32);
497 },
498 Expr::LitI64(x) => {
499 let mut out = Tokens::new();
500 out.s(&x.to_string());
501 return empty_type!(out, SimpleSimpleType::I64);
502 },
503 Expr::LitF32(x) => {
504 let mut out = Tokens::new();
505 out.s(&x.to_string());
506 return empty_type!(out, SimpleSimpleType::F32);
507 },
508 Expr::LitF64(x) => {
509 let mut out = Tokens::new();
510 out.s(&x.to_string());
511 return empty_type!(out, SimpleSimpleType::F64);
512 },
513 Expr::LitString(x) => {
514 let mut out = Tokens::new();
515 out.s(&format!("'{}'", x.replace("'", "''")));
516 return empty_type!(out, SimpleSimpleType::String);
517 },
518 Expr::LitBytes(x) => {
519 let mut out = Tokens::new();
520 let h = hex::encode(&x);
521 out.s(&format!("x'{}'", h));
522 return empty_type!(out, SimpleSimpleType::Bytes);
523 },
524 #[cfg(feature = "chrono")]
525 Expr::LitUtcTimeChrono(d) => {
526 let mut out = Tokens::new();
527 let d = d.to_rfc3339();
528 out.s(&format!("'{}'", d));
529 return empty_type!(out, SimpleSimpleType::UtcTimeChrono);
530 },
531 #[cfg(feature = "chrono")]
532 Expr::LitFixedOffsetTimeChrono(d) => {
533 let mut out = Tokens::new();
534 let d = d.to_rfc3339();
535 out.s(&format!("'{}'", d));
536 return empty_type!(out, SimpleSimpleType::FixedOffsetTimeChrono);
537 },
538 #[cfg(feature = "jiff")]
539 Expr::LitUtcTimeJiff(d) => {
540 let mut out = Tokens::new();
541 let d = d.to_string();
542 out.s(&format!("'{}'", d));
543 return empty_type!(out, SimpleSimpleType::UtcTimeChrono);
544 },
545 Expr::Param { name: x, type_: t } => {
546 let path = path.push_back(format!("Param ({})", x));
547 let mut out = Tokens::new();
548 let mut errs = vec![];
549 let i = match ctx.rust_arg_lookup.entry(x.clone()) {
550 std::collections::hash_map::Entry::Occupied(e) => {
551 let (i, prev_t) = e.get();
552 if t != prev_t {
553 errs.push(
554 format!("Parameter {} specified with multiple types: {:?}, {:?}", x, t, prev_t),
555 );
556 }
557 *i
558 },
559 std::collections::hash_map::Entry::Vacant(e) => {
560 let i = ctx.query_args.len();
561 e.insert((i, t.clone()));
562 let rust_types = to_rust_types(&t.type_.type_);
563 let custom_trait_ident = rust_types.custom_trait;
564 let rust_type = rust_types.arg_type;
565 let ident = format_ident!("{}", sanitize_ident(x).1);
566 let (mut rust_type, mut rust_forward) = if let Some(custom) = &t.type_.custom {
567 let custom_ident = match syn::parse_str::<Path>(custom.as_str()) {
568 Ok(p) => p,
569 Err(e) => {
570 ctx.errs.err(&path, format!("Couldn't parse custom type {}: {:?}", custom, e));
571 return (ExprType(vec![]), Tokens::new());
572 },
573 }.to_token_stream();
574 let forward =
575 quote!(< #custom_ident as #custom_trait_ident < #custom_ident >>:: to_sql(& #ident));
576 (quote!(& #custom_ident), forward)
577 } else {
578 (rust_type, quote!(#ident))
579 };
580 if t.opt {
581 rust_type = quote!(Option < #rust_type >);
582 rust_forward = quote!(#ident.map(| #ident | #rust_forward));
583 }
584 ctx.rust_args.push(quote!(#ident: #rust_type));
585 ctx.query_args.push(quote!(#rust_forward));
586 i
587 },
588 };
589 for e in errs {
590 ctx.errs.err(&path, e);
591 }
592 out.s(&format!("${}", i + 1));
593 return (ExprType(vec![(ExprValName::local(x.clone()), t.clone())]), out);
594 },
595 Expr::Field(x) => {
596 let name = ExprValName::field(x);
597 let t = match scope.get(&name) {
598 Some(t) => t.clone(),
599 None => {
600 ctx
601 .errs
602 .err(
603 path,
604 format!(
605 "Expression references {} but this field isn't available here (available fields: {:?})",
606 x,
607 scope.iter().map(|e| e.0.to_string()).collect::<Vec<String>>()
608 ),
609 );
610 return (ExprType(vec![]), Tokens::new());
611 },
612 };
613 let mut out = Tokens::new();
614 out.id(&x.table.id).s(".").id(&x.id);
615 return (ExprType(vec![(name, t.clone())]), out);
616 },
617 Expr::BinOp { left, op, right } => {
618 return do_bin_op(
619 ctx,
620 &path.push_back(format!("Bin op {:?}", op)),
621 scope,
622 op,
623 &vec![left.as_ref().clone(), right.as_ref().clone()],
624 );
625 },
626 Expr::BinOpChain { op, exprs } => {
627 return do_bin_op(ctx, &path.push_back(format!("Chain bin op {:?}", op)), scope, op, exprs);
628 },
629 Expr::PrefixOp { op, right } => {
630 let path = path.push_back(format!("Prefix op {:?}", op));
631 let mut out = Tokens::new();
632 let res = right.build(ctx, &path, scope);
633 let (op_text, op_type) = match op {
634 PrefixOp::Not => {
635 check_bool(ctx, &path, &res.0);
636 ("not", SimpleSimpleType::Bool)
637 },
638 };
639 out.s(op_text).s(&res.1.to_string());
640 return empty_type!(out, op_type);
641 },
642 Expr::Call { func, args, compute_type } => {
643 let mut types = vec![];
644 let mut out = Tokens::new();
645 out.s(func);
646 out.s("(");
647 for (i, arg) in args.iter().enumerate() {
648 if i > 0 {
649 out.s(",");
650 }
651 let (arg_type, tokens) =
652 arg.build(ctx, &path.push_back(format!("Call [{}] arg {}", func, i)), scope);
653 types.push(arg_type);
654 out.s(&tokens.to_string());
655 }
656 out.s(")");
657 let type_ = match (compute_type.0)(ctx, path, types) {
658 Some(t) => t,
659 None => {
660 return (ExprType(vec![]), Tokens::new());
661 },
662 };
663 return (ExprType(vec![(ExprValName::empty(), type_)]), out);
664 },
665 Expr::Select(s) => {
666 let path = path.push_back(format!("Subselect"));
667 return s.build(ctx, &path, QueryResCount::Many);
668 },
669 Expr::Cast(e, t) => {
670 let path = path.push_back(format!("Cast"));
671 let out = e.build(ctx, &path, scope);
672 let got_t = match out.0.assert_scalar(&mut ctx.errs, &path) {
673 Some(t) => t,
674 None => {
675 return (ExprType(vec![]), Tokens::new());
676 },
677 };
678 check_general_same_type(ctx, &path, t, &got_t.1);
679 return (ExprType(vec![(got_t.0, t.clone())]), out.1);
680 },
681 };
682 }
683}
684
/// Binary operators supported by `Expr::BinOp` and `Expr::BinOpChain`.
///
/// `Expr::build` treats them in three families: arithmetic operators require all operands
/// to have the same type, logic operators require non-null bool operands, and comparison
/// operators require operands of the same general type class and yield a non-null bool.
#[derive(Clone, Debug)]
pub enum BinOp {
    /// Arithmetic; rendered as `+`.
    Plus,
    /// Arithmetic; rendered as `-`.
    Minus,
    /// Arithmetic; rendered as `*`.
    Multiply,
    /// Arithmetic; rendered as `/`.
    Divide,
    /// Logic; rendered as `and`.
    And,
    /// Logic; rendered as `or`.
    Or,
    /// Comparison; rendered as `=`.
    Equals,
    /// Comparison; rendered as `!=`.
    NotEquals,
    /// Comparison; rendered as `is`.
    Is,
    /// Comparison; rendered as `is not`.
    IsNot,
    /// Comparison; rendered as `<`.
    LessThan,
    /// Comparison; rendered as `<=`.
    LessThanEqualTo,
    /// Comparison; rendered as `>`.
    GreaterThan,
    /// Comparison; rendered as `>=`.
    GreaterThanEqualTo,
}
702
/// Prefix operators supported by `Expr::PrefixOp`.
#[derive(Clone, Debug)]
pub enum PrefixOp {
    /// Logical negation; rendered as `not`. The operand must be a non-null bool.
    Not,
}