1use crate::ast::{BinOp, Expr, FieldRef, UnaryOp};
32use reddb_types::cast_catalog::{can_implicit_cast, CastContext};
33use reddb_types::types::{DataType, TypeCategory, Value};
34
35#[derive(Debug, Clone)]
39pub enum TypeError {
40 UnknownColumn { table: String, column: String },
42 OperatorMismatch {
45 op: BinOp,
46 lhs: DataType,
47 rhs: DataType,
48 },
49 UnaryMismatch { op: UnaryOp, operand: DataType },
51 InvalidCast { src: DataType, target: DataType },
54 CaseBranchMismatch { first: DataType, other: DataType },
56 InListMismatch { target: DataType, element: DataType },
58}
59
60impl std::fmt::Display for TypeError {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 match self {
63 Self::UnknownColumn { table, column } => {
64 if table.is_empty() {
65 write!(f, "unknown column `{column}`")
66 } else {
67 write!(f, "unknown column `{table}.{column}`")
68 }
69 }
70 Self::OperatorMismatch { op, lhs, rhs } => {
71 write!(
72 f,
73 "operator `{op:?}` cannot apply to `{lhs:?}` and `{rhs:?}`"
74 )
75 }
76 Self::UnaryMismatch { op, operand } => {
77 write!(f, "unary `{op:?}` cannot apply to `{operand:?}`")
78 }
79 Self::InvalidCast { src, target } => {
80 write!(f, "no cast from `{src:?}` to `{target:?}`")
81 }
82 Self::CaseBranchMismatch { first, other } => {
83 write!(
84 f,
85 "CASE branches disagree on type: `{first:?}` vs `{other:?}`"
86 )
87 }
88 Self::InListMismatch { target, element } => {
89 write!(
90 f,
91 "IN list element `{element:?}` is incompatible with target `{target:?}`"
92 )
93 }
94 }
95 }
96}
97
98impl std::error::Error for TypeError {}
99
100#[derive(Debug, Clone)]
108pub struct TypedExpr {
109 pub kind: TypedExprKind,
110 pub ty: DataType,
111}
112
113#[derive(Debug, Clone)]
114pub enum TypedExprKind {
115 Literal(Value),
116 Column(FieldRef),
117 UnaryOp {
118 op: UnaryOp,
119 operand: Box<TypedExpr>,
120 },
121 BinaryOp {
122 op: BinOp,
123 lhs: Box<TypedExpr>,
124 rhs: Box<TypedExpr>,
125 },
126 Cast {
127 inner: Box<TypedExpr>,
128 },
129 FunctionCall {
130 name: String,
131 args: Vec<TypedExpr>,
132 },
133 Case {
134 branches: Vec<(TypedExpr, TypedExpr)>,
135 else_: Option<Box<TypedExpr>>,
136 },
137 IsNull {
138 operand: Box<TypedExpr>,
139 negated: bool,
140 },
141 InList {
142 target: Box<TypedExpr>,
143 values: Vec<TypedExpr>,
144 negated: bool,
145 },
146 Between {
147 target: Box<TypedExpr>,
148 low: Box<TypedExpr>,
149 high: Box<TypedExpr>,
150 negated: bool,
151 },
152}
153
154pub trait Scope {
159 fn lookup(&self, table: &str, column: &str) -> Option<DataType>;
160}
161
162impl<F> Scope for F
163where
164 F: Fn(&str, &str) -> Option<DataType>,
165{
166 fn lookup(&self, table: &str, column: &str) -> Option<DataType> {
167 self(table, column)
168 }
169}
170
171pub fn type_expr(expr: &Expr, scope: &dyn Scope) -> Result<TypedExpr, TypeError> {
173 match expr {
174 Expr::Literal { value, .. } => Ok(TypedExpr {
175 ty: literal_type(value),
176 kind: TypedExprKind::Literal(value.clone()),
177 }),
178 Expr::Column { field, .. } => {
179 let (table, column) = match field {
180 FieldRef::TableColumn { table, column } => (table.as_str(), column.as_str()),
181 FieldRef::NodeProperty { alias, property } => (alias.as_str(), property.as_str()),
182 FieldRef::EdgeProperty { alias, property } => (alias.as_str(), property.as_str()),
183 FieldRef::NodeId { .. } => ("", ""),
184 };
185 let ty = scope
186 .lookup(table, column)
187 .ok_or(TypeError::UnknownColumn {
188 table: table.to_string(),
189 column: column.to_string(),
190 })?;
191 Ok(TypedExpr {
192 ty,
193 kind: TypedExprKind::Column(field.clone()),
194 })
195 }
196 Expr::Parameter { .. } => {
197 Ok(TypedExpr {
201 ty: DataType::Nullable,
202 kind: TypedExprKind::Literal(Value::Null),
203 })
204 }
205 Expr::UnaryOp { op, operand, .. } => {
206 let inner = type_expr(operand, scope)?;
207 let ty = unary_result_type(*op, inner.ty)?;
208 Ok(TypedExpr {
209 ty,
210 kind: TypedExprKind::UnaryOp {
211 op: *op,
212 operand: Box::new(inner),
213 },
214 })
215 }
216 Expr::BinaryOp { op, lhs, rhs, .. } => {
217 let l = type_expr(lhs, scope)?;
218 let r = type_expr(rhs, scope)?;
219 let ty = binop_result_type(*op, l.ty, r.ty)?;
220 Ok(TypedExpr {
221 ty,
222 kind: TypedExprKind::BinaryOp {
223 op: *op,
224 lhs: Box::new(l),
225 rhs: Box::new(r),
226 },
227 })
228 }
229 Expr::Cast { inner, target, .. } => {
230 let inner_typed = type_expr(inner, scope)?;
231 if !reddb_types::cast_catalog::can_explicit_cast(inner_typed.ty, *target) {
234 return Err(TypeError::InvalidCast {
235 src: inner_typed.ty,
236 target: *target,
237 });
238 }
239 Ok(TypedExpr {
240 ty: *target,
241 kind: TypedExprKind::Cast {
242 inner: Box::new(inner_typed),
243 },
244 })
245 }
246 Expr::FunctionCall { name, args, .. } => {
247 let typed_args = args
248 .iter()
249 .map(|a| type_expr(a, scope))
250 .collect::<Result<Vec<_>, _>>()?;
251 let arg_dt: Vec<DataType> = typed_args.iter().map(|t| t.ty).collect();
259 let return_ty = resolve_function_return_type(name, &arg_dt);
260 Ok(TypedExpr {
261 ty: return_ty,
262 kind: TypedExprKind::FunctionCall {
263 name: name.clone(),
264 args: typed_args,
265 },
266 })
267 }
268 Expr::Case {
269 branches, else_, ..
270 } => {
271 let mut typed_branches = Vec::with_capacity(branches.len());
272 let mut result_ty: Option<DataType> = None;
273 for (cond, val) in branches {
274 let cond_typed = type_expr(cond, scope)?;
275 let val_typed = type_expr(val, scope)?;
276 let prev_ty = result_ty;
277 result_ty = merge_compatible_type(result_ty, val_typed.ty).map_err(|_| {
278 TypeError::CaseBranchMismatch {
279 first: prev_ty.unwrap_or(val_typed.ty),
280 other: val_typed.ty,
281 }
282 })?;
283 typed_branches.push((cond_typed, val_typed));
284 }
285 let typed_else = if let Some(else_expr) = else_ {
286 let e = type_expr(else_expr, scope)?;
287 let prev_ty = result_ty;
288 result_ty = merge_compatible_type(result_ty, e.ty).map_err(|_| {
289 TypeError::CaseBranchMismatch {
290 first: prev_ty.unwrap_or(e.ty),
291 other: e.ty,
292 }
293 })?;
294 Some(Box::new(e))
295 } else {
296 None
297 };
298 let ty = result_ty.unwrap_or(DataType::Nullable);
299 Ok(TypedExpr {
300 ty,
301 kind: TypedExprKind::Case {
302 branches: typed_branches,
303 else_: typed_else,
304 },
305 })
306 }
307 Expr::IsNull {
308 operand, negated, ..
309 } => {
310 let inner = type_expr(operand, scope)?;
311 Ok(TypedExpr {
312 ty: DataType::Boolean,
313 kind: TypedExprKind::IsNull {
314 operand: Box::new(inner),
315 negated: *negated,
316 },
317 })
318 }
319 Expr::InList {
320 target,
321 values,
322 negated,
323 ..
324 } => {
325 let target_typed = type_expr(target, scope)?;
326 let mut typed_values = Vec::with_capacity(values.len());
327 for v in values {
328 let vt = type_expr(v, scope)?;
329 if vt.ty != target_typed.ty && !can_implicit_cast(vt.ty, target_typed.ty) {
330 return Err(TypeError::InListMismatch {
331 target: target_typed.ty,
332 element: vt.ty,
333 });
334 }
335 typed_values.push(vt);
336 }
337 Ok(TypedExpr {
338 ty: DataType::Boolean,
339 kind: TypedExprKind::InList {
340 target: Box::new(target_typed),
341 values: typed_values,
342 negated: *negated,
343 },
344 })
345 }
346 Expr::Between {
347 target,
348 low,
349 high,
350 negated,
351 ..
352 } => {
353 let target_typed = type_expr(target, scope)?;
354 let low_typed = type_expr(low, scope)?;
355 let high_typed = type_expr(high, scope)?;
356 for bound in &[&low_typed, &high_typed] {
358 if bound.ty != target_typed.ty && !can_implicit_cast(bound.ty, target_typed.ty) {
359 return Err(TypeError::OperatorMismatch {
360 op: BinOp::Ge,
361 lhs: target_typed.ty,
362 rhs: bound.ty,
363 });
364 }
365 }
366 Ok(TypedExpr {
367 ty: DataType::Boolean,
368 kind: TypedExprKind::Between {
369 target: Box::new(target_typed),
370 low: Box::new(low_typed),
371 high: Box::new(high_typed),
372 negated: *negated,
373 },
374 })
375 }
376 Expr::Subquery { .. } => Ok(TypedExpr {
377 ty: DataType::Nullable,
378 kind: TypedExprKind::Literal(Value::Null),
379 }),
380 Expr::WindowFunctionCall { .. } => Ok(TypedExpr {
384 ty: DataType::Nullable,
385 kind: TypedExprKind::Literal(Value::Null),
386 }),
387 }
388}
389
390fn literal_type(v: &Value) -> DataType {
395 match v {
396 Value::Null => DataType::Nullable,
397 Value::Boolean(_) => DataType::Boolean,
398 Value::Integer(_) => DataType::Integer,
399 Value::UnsignedInteger(_) => DataType::UnsignedInteger,
400 Value::Float(_) => DataType::Float,
401 Value::BigInt(_) => DataType::BigInt,
402 Value::Decimal(_) => DataType::Decimal,
403 Value::Text(_) => DataType::Text,
404 Value::Blob(_) => DataType::Blob,
405 Value::Timestamp(_) => DataType::Timestamp,
406 Value::TimestampMs(_) => DataType::TimestampMs,
407 Value::Duration(_) => DataType::Duration,
408 Value::Date(_) => DataType::Date,
409 Value::Time(_) => DataType::Time,
410 Value::IpAddr(_) => DataType::IpAddr,
411 Value::Ipv4(_) => DataType::Ipv4,
412 Value::Ipv6(_) => DataType::Ipv6,
413 Value::Subnet(_, _) => DataType::Subnet,
414 Value::Cidr(_, _) => DataType::Cidr,
415 Value::MacAddr(_) => DataType::MacAddr,
416 Value::Port(_) => DataType::Port,
417 Value::Latitude(_) => DataType::Latitude,
418 Value::Longitude(_) => DataType::Longitude,
419 Value::GeoPoint(_, _) => DataType::GeoPoint,
420 Value::Country2(_) => DataType::Country2,
421 Value::Country3(_) => DataType::Country3,
422 Value::Lang2(_) => DataType::Lang2,
423 Value::Lang5(_) => DataType::Lang5,
424 Value::Currency(_) => DataType::Currency,
425 Value::AssetCode(_) => DataType::AssetCode,
426 Value::Money { .. } => DataType::Money,
427 Value::Color(_) => DataType::Color,
428 Value::ColorAlpha(_) => DataType::ColorAlpha,
429 Value::Email(_) => DataType::Email,
430 Value::Url(_) => DataType::Url,
431 Value::Phone(_) => DataType::Phone,
432 Value::Semver(_) => DataType::Semver,
433 Value::Uuid(_) => DataType::Uuid,
434 Value::Vector(_) => DataType::Vector,
435 Value::Array(_) => DataType::Array,
436 Value::Json(_) => DataType::Json,
437 Value::EnumValue(_) => DataType::Enum,
438 Value::NodeRef(_) => DataType::NodeRef,
439 Value::EdgeRef(_) => DataType::EdgeRef,
440 Value::VectorRef(_, _) => DataType::VectorRef,
441 Value::RowRef(_, _) => DataType::RowRef,
442 Value::KeyRef(_, _) => DataType::KeyRef,
443 Value::DocRef(_, _) => DataType::DocRef,
444 Value::TableRef(_) => DataType::TableRef,
445 Value::PageRef(_) => DataType::PageRef,
446 Value::Secret(_) => DataType::Secret,
447 Value::Password(_) => DataType::Password,
448 }
449}
450
451fn resolve_function_return_type(name: &str, arg_types: &[DataType]) -> DataType {
452 let upper = name.to_ascii_uppercase();
453 match upper.as_str() {
454 "CONCAT" | "CONCAT_WS" | "QUOTE_LITERAL" => DataType::Text,
458 "MONEY" => DataType::Money,
459 "MONEY_ASSET" => DataType::AssetCode,
460 "MONEY_MINOR" => DataType::BigInt,
461 "MONEY_SCALE" => DataType::Integer,
462 "COALESCE" => resolve_coalesce_return_type(arg_types),
465 _ => reddb_types::function_catalog::resolve(name, arg_types)
466 .map(|entry| entry.return_type)
467 .unwrap_or(DataType::Nullable),
468 }
469}
470
471fn resolve_coalesce_return_type(arg_types: &[DataType]) -> DataType {
472 let mut resolved: Option<DataType> = None;
473
474 for &arg_ty in arg_types {
475 match merge_compatible_type(resolved, arg_ty) {
476 Ok(next) => resolved = next,
477 Err(_) => return DataType::Nullable,
478 }
479 }
480
481 resolved.unwrap_or(DataType::Nullable)
482}
483
484fn merge_compatible_type(
485 current: Option<DataType>,
486 next: DataType,
487) -> Result<Option<DataType>, ()> {
488 if next == DataType::Nullable {
489 return Ok(current);
490 }
491
492 match current {
493 None => Ok(Some(next)),
494 Some(prev) if prev == next => Ok(Some(prev)),
495 Some(prev) if can_implicit_cast(next, prev) => Ok(Some(prev)),
496 Some(prev) if can_implicit_cast(prev, next) => Ok(Some(next)),
497 Some(_) => Err(()),
498 }
499}
500
501fn unary_result_type(op: UnaryOp, operand: DataType) -> Result<DataType, TypeError> {
504 match op {
505 UnaryOp::Neg if operand.category() == TypeCategory::Numeric => Ok(operand),
506 UnaryOp::Not if operand == DataType::Boolean => Ok(DataType::Boolean),
507 _ => Err(TypeError::UnaryMismatch { op, operand }),
508 }
509}
510
511fn binop_result_type(op: BinOp, lhs: DataType, rhs: DataType) -> Result<DataType, TypeError> {
523 use BinOp::*;
524 match op {
525 And | Or => {
526 if lhs == DataType::Boolean && rhs == DataType::Boolean {
527 Ok(DataType::Boolean)
528 } else {
529 Err(TypeError::OperatorMismatch { op, lhs, rhs })
530 }
531 }
532 Eq | Ne | Lt | Le | Gt | Ge => {
533 if lhs == rhs {
537 return Ok(DataType::Boolean);
538 }
539 if lhs.category() == rhs.category()
540 && (can_implicit_cast(lhs, rhs) || can_implicit_cast(rhs, lhs))
541 {
542 return Ok(DataType::Boolean);
543 }
544 Err(TypeError::OperatorMismatch { op, lhs, rhs })
545 }
546 Add | Sub | Mul | Div | Mod => {
547 if lhs.category() != TypeCategory::Numeric || rhs.category() != TypeCategory::Numeric {
548 return Err(TypeError::OperatorMismatch { op, lhs, rhs });
549 }
550 if lhs == DataType::Float || rhs == DataType::Float {
554 Ok(DataType::Float)
555 } else if lhs == DataType::Decimal || rhs == DataType::Decimal {
556 Ok(DataType::Decimal)
557 } else if lhs == DataType::BigInt || rhs == DataType::BigInt {
558 Ok(DataType::BigInt)
559 } else {
560 Ok(DataType::Integer)
561 }
562 }
563 Concat => {
564 if lhs == DataType::Text && rhs == DataType::Text {
565 Ok(DataType::Text)
566 } else {
567 Err(TypeError::OperatorMismatch { op, lhs, rhs })
568 }
569 }
570 }
571}
572
/// Returns `CastContext::Explicit`. Not called anywhere in the visible code.
///
/// NOTE(review): this appears to exist only so the `CastContext` import stays
/// referenced (hence `allow(dead_code)`). Confirm before removing either.
#[allow(dead_code)]
fn _ctx_explicit() -> CastContext {
    CastContext::Explicit
}
580
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Span;
    use crate::lexer::Position;
    use std::net::{IpAddr, Ipv4Addr};
    use std::sync::Arc;

    // --- AST construction helpers (spans are irrelevant to typing) ---

    fn span() -> Span {
        Span {
            start: Position::default(),
            end: Position::default(),
        }
    }

    fn lit(value: Value) -> Expr {
        Expr::Literal {
            value,
            span: span(),
        }
    }

    fn col(table: &str, column: &str) -> Expr {
        Expr::Column {
            field: FieldRef::column(table, column),
            span: span(),
        }
    }

    fn bin(op: BinOp, lhs: Expr, rhs: Expr) -> Expr {
        Expr::BinaryOp {
            op,
            lhs: Box::new(lhs),
            rhs: Box::new(rhs),
            span: span(),
        }
    }

    fn unary(op: UnaryOp, operand: Expr) -> Expr {
        Expr::UnaryOp {
            op,
            operand: Box::new(operand),
            span: span(),
        }
    }

    // --- Scopes: a small fixed schema, and one that knows nothing ---

    fn scope(table: &str, column: &str) -> Option<DataType> {
        match (table, column) {
            ("", "age") => Some(DataType::Integer),
            ("", "active") => Some(DataType::Boolean),
            ("users", "name") => Some(DataType::Text),
            ("n", "score") => Some(DataType::Float),
            _ => None,
        }
    }

    fn no_scope(_: &str, _: &str) -> Option<DataType> {
        None
    }

    #[test]
    fn literal_values_map_to_declared_types() {
        let values = vec![
            (Value::Null, DataType::Nullable),
            (Value::Boolean(true), DataType::Boolean),
            (Value::Integer(1), DataType::Integer),
            (Value::UnsignedInteger(1), DataType::UnsignedInteger),
            (Value::Float(1.0), DataType::Float),
            (Value::BigInt(1), DataType::BigInt),
            (Value::Decimal(100), DataType::Decimal),
            (Value::Text(Arc::from("x")), DataType::Text),
            (Value::Blob(vec![1, 2]), DataType::Blob),
            (Value::Timestamp(1), DataType::Timestamp),
            (Value::TimestampMs(1), DataType::TimestampMs),
            (Value::Duration(1), DataType::Duration),
            (Value::Date(1), DataType::Date),
            (Value::Time(1), DataType::Time),
            (
                Value::IpAddr(IpAddr::V4(Ipv4Addr::LOCALHOST)),
                DataType::IpAddr,
            ),
            (Value::Ipv4(0x7f00_0001), DataType::Ipv4),
            (Value::Ipv6([0; 16]), DataType::Ipv6),
            (Value::Subnet(0, 24), DataType::Subnet),
            (Value::Cidr(0, 24), DataType::Cidr),
            (Value::MacAddr([1, 2, 3, 4, 5, 6]), DataType::MacAddr),
            (Value::Port(5432), DataType::Port),
            (Value::Latitude(1), DataType::Latitude),
            (Value::Longitude(1), DataType::Longitude),
            (Value::GeoPoint(1, 2), DataType::GeoPoint),
            (Value::Country2(*b"BR"), DataType::Country2),
            (Value::Country3(*b"BRA"), DataType::Country3),
            (Value::Lang2(*b"pt"), DataType::Lang2),
            (Value::Lang5(*b"pt-BR"), DataType::Lang5),
            (Value::Currency(*b"BRL"), DataType::Currency),
            (Value::AssetCode("BTC".to_string()), DataType::AssetCode),
            (
                Value::Money {
                    asset_code: "BRL".to_string(),
                    minor_units: 123,
                    scale: 2,
                },
                DataType::Money,
            ),
            (Value::Color([1, 2, 3]), DataType::Color),
            (Value::ColorAlpha([1, 2, 3, 4]), DataType::ColorAlpha),
            (Value::Email("a@example.com".to_string()), DataType::Email),
            (Value::Url("https://example.com".to_string()), DataType::Url),
            (Value::Phone(5511999999999), DataType::Phone),
            (Value::Semver(1_002_003), DataType::Semver),
            (Value::Uuid([1; 16]), DataType::Uuid),
            (Value::Vector(vec![1.0, 2.0]), DataType::Vector),
            (Value::Array(vec![Value::Integer(1)]), DataType::Array),
            (Value::Json(br#"{"x":1}"#.to_vec()), DataType::Json),
            (Value::EnumValue(1), DataType::Enum),
            (Value::NodeRef("n1".to_string()), DataType::NodeRef),
            (Value::EdgeRef("e1".to_string()), DataType::EdgeRef),
            (Value::VectorRef("vecs".to_string(), 1), DataType::VectorRef),
            (Value::RowRef("rows".to_string(), 1), DataType::RowRef),
            (
                Value::KeyRef("kv".to_string(), "k".to_string()),
                DataType::KeyRef,
            ),
            (Value::DocRef("docs".to_string(), 1), DataType::DocRef),
            (Value::TableRef("users".to_string()), DataType::TableRef),
            (Value::PageRef(7), DataType::PageRef),
            (Value::Secret(vec![1, 2, 3]), DataType::Secret),
            (Value::Password("argon2".to_string()), DataType::Password),
        ];

        for (value, expected) in values {
            let typed = type_expr(&lit(value), &no_scope).unwrap();
            assert_eq!(typed.ty, expected);
            assert!(matches!(typed.kind, TypedExprKind::Literal(_)));
        }
    }

    #[test]
    fn column_lookup_preserves_field_ref_and_reports_unknowns() {
        let typed = type_expr(&col("users", "name"), &scope).unwrap();
        assert_eq!(typed.ty, DataType::Text);
        assert!(matches!(
            typed.kind,
            TypedExprKind::Column(FieldRef::TableColumn { table, column })
                if table == "users" && column == "name"
        ));

        let err = type_expr(&col("", "missing"), &scope).unwrap_err();
        assert!(matches!(
            err,
            TypeError::UnknownColumn { ref table, ref column }
                if table.is_empty() && column == "missing"
        ));
        assert_eq!(err.to_string(), "unknown column `missing`");
    }

    #[test]
    fn arithmetic_logical_and_unary_ops_return_expected_types() {
        let add = bin(BinOp::Add, lit(Value::Integer(1)), lit(Value::Float(2.0)));
        assert_eq!(type_expr(&add, &scope).unwrap().ty, DataType::Float);

        let and = bin(BinOp::And, col("", "active"), lit(Value::Boolean(false)));
        assert_eq!(type_expr(&and, &scope).unwrap().ty, DataType::Boolean);

        let neg = unary(UnaryOp::Neg, col("", "age"));
        assert_eq!(type_expr(&neg, &scope).unwrap().ty, DataType::Integer);

        let not = unary(UnaryOp::Not, col("", "active"));
        assert_eq!(type_expr(&not, &scope).unwrap().ty, DataType::Boolean);
    }

    #[test]
    fn operator_mismatches_are_reported() {
        let bad_and = bin(
            BinOp::And,
            lit(Value::Boolean(true)),
            lit(Value::Integer(1)),
        );
        assert!(matches!(
            type_expr(&bad_and, &scope).unwrap_err(),
            TypeError::OperatorMismatch {
                op: BinOp::And,
                lhs: DataType::Boolean,
                rhs: DataType::Integer,
            }
        ));

        let bad_neg = unary(UnaryOp::Neg, lit(Value::Text(Arc::from("x"))));
        assert!(matches!(
            type_expr(&bad_neg, &scope).unwrap_err(),
            TypeError::UnaryMismatch {
                op: UnaryOp::Neg,
                operand: DataType::Text,
            }
        ));
    }

    #[test]
    fn casts_functions_and_parameters_have_stable_types() {
        let cast = Expr::Cast {
            inner: Box::new(lit(Value::Integer(1))),
            target: DataType::Text,
            span: span(),
        };
        assert_eq!(type_expr(&cast, &scope).unwrap().ty, DataType::Text);

        let concat = Expr::FunctionCall {
            name: "concat".to_string(),
            args: vec![lit(Value::Text(Arc::from("a"))), lit(Value::Integer(1))],
            span: span(),
        };
        assert_eq!(type_expr(&concat, &scope).unwrap().ty, DataType::Text);

        let money_minor = Expr::FunctionCall {
            name: "money_minor".to_string(),
            args: vec![lit(Value::Money {
                asset_code: "BRL".to_string(),
                minor_units: 10,
                scale: 2,
            })],
            span: span(),
        };
        assert_eq!(
            type_expr(&money_minor, &scope).unwrap().ty,
            DataType::BigInt
        );

        let coalesce = Expr::FunctionCall {
            name: "coalesce".to_string(),
            args: vec![
                lit(Value::Null),
                lit(Value::Integer(1)),
                lit(Value::Float(2.0)),
            ],
            span: span(),
        };
        assert_eq!(type_expr(&coalesce, &scope).unwrap().ty, DataType::Integer);

        let unknown = Expr::FunctionCall {
            name: "not_a_function".to_string(),
            args: Vec::new(),
            span: span(),
        };
        assert_eq!(type_expr(&unknown, &scope).unwrap().ty, DataType::Nullable);

        let parameter = Expr::Parameter {
            index: 1,
            span: span(),
        };
        assert_eq!(
            type_expr(&parameter, &scope).unwrap().ty,
            DataType::Nullable
        );
    }

    #[test]
    fn invalid_casts_case_branches_and_lists_are_errors() {
        let bad_cast = Expr::Cast {
            inner: Box::new(lit(Value::Blob(vec![1]))),
            target: DataType::Money,
            span: span(),
        };
        assert!(matches!(
            type_expr(&bad_cast, &scope).unwrap_err(),
            TypeError::InvalidCast {
                src: DataType::Blob,
                target: DataType::Money,
            }
        ));

        let case = Expr::Case {
            branches: vec![(
                lit(Value::Boolean(true)),
                lit(Value::Text(Arc::from("text"))),
            )],
            else_: Some(Box::new(lit(Value::Integer(1)))),
            span: span(),
        };
        assert!(matches!(
            type_expr(&case, &scope).unwrap_err(),
            TypeError::CaseBranchMismatch {
                first: DataType::Text,
                other: DataType::Integer,
            }
        ));

        let in_list = Expr::InList {
            target: Box::new(lit(Value::Integer(1))),
            values: vec![lit(Value::Text(Arc::from("x")))],
            negated: false,
            span: span(),
        };
        assert!(matches!(
            type_expr(&in_list, &scope).unwrap_err(),
            TypeError::InListMismatch {
                target: DataType::Integer,
                element: DataType::Text,
            }
        ));
    }

    #[test]
    fn predicates_return_boolean_when_bounds_and_values_are_compatible() {
        let is_null = Expr::IsNull {
            operand: Box::new(col("", "age")),
            negated: true,
            span: span(),
        };
        assert_eq!(type_expr(&is_null, &scope).unwrap().ty, DataType::Boolean);

        let in_list = Expr::InList {
            target: Box::new(col("", "age")),
            values: vec![lit(Value::Integer(1)), lit(Value::Integer(2))],
            negated: false,
            span: span(),
        };
        assert_eq!(type_expr(&in_list, &scope).unwrap().ty, DataType::Boolean);

        let between = Expr::Between {
            target: Box::new(col("", "age")),
            low: Box::new(lit(Value::Integer(1))),
            high: Box::new(lit(Value::Integer(9))),
            negated: false,
            span: span(),
        };
        assert_eq!(type_expr(&between, &scope).unwrap().ty, DataType::Boolean);

        let bad_between = Expr::Between {
            target: Box::new(col("", "age")),
            low: Box::new(lit(Value::Text(Arc::from("low")))),
            high: Box::new(lit(Value::Integer(9))),
            negated: false,
            span: span(),
        };
        assert!(matches!(
            type_expr(&bad_between, &scope).unwrap_err(),
            TypeError::OperatorMismatch {
                op: BinOp::Ge,
                lhs: DataType::Integer,
                rhs: DataType::Text,
            }
        ));
    }
}