1use crate::ast::Span;
8use crate::ast::ddl::VectorMetric;
9use crate::ast::expr::{BinaryOp, Expr, ExprKind, Literal, UnaryOp};
10use crate::catalog::{Catalog, TableMetadata};
11use crate::planner::aggregate_expr::{AggregateExpr, AggregateFunction};
12use crate::planner::error::PlannerError;
13use crate::planner::typed_expr::{TypedExpr, TypedExprKind};
14use crate::planner::types::ResolvedType;
15
/// Resolves and validates the types of SQL expressions, function calls and
/// DML values for the planner.
///
/// Borrows a catalog for the lifetime `'a`; `?Sized` allows `C` to be an
/// unsized type.
pub struct TypeChecker<'a, C: Catalog + ?Sized> {
    /// Catalog supplied at construction; returned by `catalog()`.
    catalog: &'a C,
}
33
34impl<'a, C: Catalog + ?Sized> TypeChecker<'a, C> {
35 pub fn new(catalog: &'a C) -> Self {
37 Self { catalog }
38 }
39
40 pub fn catalog(&self) -> &'a C {
42 self.catalog
43 }
44
45 pub fn infer_type(
57 &self,
58 expr: &Expr,
59 table: &TableMetadata,
60 ) -> Result<TypedExpr, PlannerError> {
61 let span = expr.span;
62 match &expr.kind {
63 ExprKind::Literal(lit) => self.infer_literal_type(lit, span),
64
65 ExprKind::ColumnRef {
66 table: table_qualifier,
67 column,
68 } => {
69 if let Some(qualifier) = table_qualifier
71 && qualifier != &table.name
72 {
73 return Err(PlannerError::TableNotFound {
74 name: qualifier.clone(),
75 line: span.start.line,
76 column: span.start.column,
77 });
78 }
79 self.infer_column_ref_type(table, column, span)
80 }
81
82 ExprKind::BinaryOp { left, op, right } => {
83 self.infer_binary_op_type(left, *op, right, table, span)
84 }
85
86 ExprKind::UnaryOp { op, operand } => {
87 self.infer_unary_op_type(*op, operand, table, span)
88 }
89
90 ExprKind::FunctionCall {
91 name,
92 args,
93 distinct,
94 star,
95 } => self.infer_function_call_type(name, args, *distinct, *star, table, span),
96
97 ExprKind::Between {
98 expr,
99 low,
100 high,
101 negated,
102 } => self.infer_between_type(expr, low, high, *negated, table, span),
103
104 ExprKind::Like {
105 expr,
106 pattern,
107 escape,
108 negated,
109 } => self.infer_like_type(expr, pattern, escape.as_deref(), *negated, table, span),
110
111 ExprKind::InList {
112 expr,
113 list,
114 negated,
115 } => self.infer_in_list_type(expr, list, *negated, table, span),
116
117 ExprKind::IsNull { expr, negated } => {
118 self.infer_is_null_type(expr, *negated, table, span)
119 }
120
121 ExprKind::VectorLiteral(values) => self.infer_vector_literal_type(values, span),
122 }
123 }
124
125 fn infer_literal_type(&self, lit: &Literal, span: Span) -> Result<TypedExpr, PlannerError> {
127 let (kind, resolved_type) = match lit {
128 Literal::Number(s) => {
129 let resolved_type = if s.contains('.') || s.contains('e') || s.contains('E') {
131 ResolvedType::Double
132 } else {
133 if s.parse::<i32>().is_ok() {
135 ResolvedType::Integer
136 } else {
137 ResolvedType::BigInt
138 }
139 };
140 (TypedExprKind::Literal(lit.clone()), resolved_type)
141 }
142 Literal::String(_) => (TypedExprKind::Literal(lit.clone()), ResolvedType::Text),
143 Literal::Boolean(_) => (TypedExprKind::Literal(lit.clone()), ResolvedType::Boolean),
144 Literal::Null => (TypedExprKind::Literal(lit.clone()), ResolvedType::Null),
145 };
146
147 Ok(TypedExpr {
148 kind,
149 resolved_type,
150 span,
151 })
152 }
153
154 fn infer_column_ref_type(
156 &self,
157 table: &TableMetadata,
158 column_name: &str,
159 span: Span,
160 ) -> Result<TypedExpr, PlannerError> {
161 let (column_index, column) = table
163 .columns
164 .iter()
165 .enumerate()
166 .find(|(_, c)| c.name == column_name)
167 .ok_or_else(|| PlannerError::ColumnNotFound {
168 column: column_name.to_string(),
169 table: table.name.clone(),
170 line: span.start.line,
171 col: span.start.column,
172 })?;
173
174 Ok(TypedExpr {
175 kind: TypedExprKind::ColumnRef {
176 table: table.name.clone(),
177 column: column_name.to_string(),
178 column_index,
179 },
180 resolved_type: column.data_type.clone(),
181 span,
182 })
183 }
184
185 fn infer_binary_op_type(
187 &self,
188 left: &Expr,
189 op: BinaryOp,
190 right: &Expr,
191 table: &TableMetadata,
192 span: Span,
193 ) -> Result<TypedExpr, PlannerError> {
194 let left_typed = self.infer_type(left, table)?;
195 let right_typed = self.infer_type(right, table)?;
196
197 let result_type = self.check_binary_op(
198 op,
199 &left_typed.resolved_type,
200 &right_typed.resolved_type,
201 span,
202 )?;
203
204 Ok(TypedExpr {
205 kind: TypedExprKind::BinaryOp {
206 left: Box::new(left_typed),
207 op,
208 right: Box::new(right_typed),
209 },
210 resolved_type: result_type,
211 span,
212 })
213 }
214
215 pub fn check_binary_op(
227 &self,
228 op: BinaryOp,
229 left: &ResolvedType,
230 right: &ResolvedType,
231 span: Span,
232 ) -> Result<ResolvedType, PlannerError> {
233 use BinaryOp::*;
234 use ResolvedType::*;
235
236 match op {
237 Add | Sub | Mul | Div | Mod => {
239 let result = self.check_arithmetic_op(left, right, span)?;
240 Ok(result)
241 }
242
243 Eq | Neq | Lt | Gt | LtEq | GtEq => {
245 self.check_comparison_op(left, right, span)?;
246 Ok(Boolean)
247 }
248
249 And | Or => {
251 self.check_logical_op(left, right, span)?;
252 Ok(Boolean)
253 }
254
255 StringConcat => {
257 self.check_string_concat_op(left, right, span)?;
258 Ok(Text)
259 }
260 }
261 }
262
263 fn check_arithmetic_op(
265 &self,
266 left: &ResolvedType,
267 right: &ResolvedType,
268 span: Span,
269 ) -> Result<ResolvedType, PlannerError> {
270 use ResolvedType::*;
271
272 if matches!(left, Null) || matches!(right, Null) {
274 return Ok(Null);
275 }
276
277 match (left, right) {
279 (Integer, Integer) => Ok(Integer),
281 (Integer, BigInt) | (BigInt, Integer) | (BigInt, BigInt) => Ok(BigInt),
282 (Integer, Float) | (Float, Integer) | (Float, Float) => Ok(Float),
283 (Integer, Double)
284 | (Double, Integer)
285 | (BigInt, Float)
286 | (Float, BigInt)
287 | (BigInt, Double)
288 | (Double, BigInt)
289 | (Float, Double)
290 | (Double, Float)
291 | (Double, Double) => Ok(Double),
292
293 _ => Err(PlannerError::InvalidOperator {
294 op: "arithmetic".to_string(),
295 type_name: format!("{} and {}", left.type_name(), right.type_name()),
296 line: span.start.line,
297 column: span.start.column,
298 }),
299 }
300 }
301
302 fn check_comparison_op(
304 &self,
305 left: &ResolvedType,
306 right: &ResolvedType,
307 span: Span,
308 ) -> Result<(), PlannerError> {
309 use ResolvedType::*;
310
311 if matches!(left, Null) || matches!(right, Null) {
313 return Ok(());
314 }
315
316 let compatible = match (left, right) {
318 (a, b) if a == b => true,
320
321 (Integer | BigInt | Float | Double, Integer | BigInt | Float | Double) => true,
323
324 (Text, Text) => true,
326
327 (Boolean, Boolean) => true,
329
330 (Timestamp, Timestamp) => true,
332
333 (Vector { dimension: d1, .. }, Vector { dimension: d2, .. }) => d1 == d2,
335
336 _ => false,
337 };
338
339 if compatible {
340 Ok(())
341 } else {
342 Err(PlannerError::TypeMismatch {
343 expected: left.type_name().to_string(),
344 found: right.type_name().to_string(),
345 line: span.start.line,
346 column: span.start.column,
347 })
348 }
349 }
350
351 fn check_logical_op(
353 &self,
354 left: &ResolvedType,
355 right: &ResolvedType,
356 span: Span,
357 ) -> Result<(), PlannerError> {
358 use ResolvedType::*;
359
360 let left_ok = matches!(left, Boolean | Null);
362 let right_ok = matches!(right, Boolean | Null);
363
364 if !left_ok {
365 return Err(PlannerError::TypeMismatch {
366 expected: "Boolean".to_string(),
367 found: left.type_name().to_string(),
368 line: span.start.line,
369 column: span.start.column,
370 });
371 }
372
373 if !right_ok {
374 return Err(PlannerError::TypeMismatch {
375 expected: "Boolean".to_string(),
376 found: right.type_name().to_string(),
377 line: span.start.line,
378 column: span.start.column,
379 });
380 }
381
382 Ok(())
383 }
384
385 fn check_string_concat_op(
387 &self,
388 left: &ResolvedType,
389 right: &ResolvedType,
390 span: Span,
391 ) -> Result<(), PlannerError> {
392 use ResolvedType::*;
393
394 let left_ok = matches!(left, Text | Null);
396 let right_ok = matches!(right, Text | Null);
397
398 if !left_ok {
399 return Err(PlannerError::TypeMismatch {
400 expected: "Text".to_string(),
401 found: left.type_name().to_string(),
402 line: span.start.line,
403 column: span.start.column,
404 });
405 }
406
407 if !right_ok {
408 return Err(PlannerError::TypeMismatch {
409 expected: "Text".to_string(),
410 found: right.type_name().to_string(),
411 line: span.start.line,
412 column: span.start.column,
413 });
414 }
415
416 Ok(())
417 }
418
419 fn infer_unary_op_type(
421 &self,
422 op: UnaryOp,
423 operand: &Expr,
424 table: &TableMetadata,
425 span: Span,
426 ) -> Result<TypedExpr, PlannerError> {
427 let operand_typed = self.infer_type(operand, table)?;
428
429 let result_type = match op {
430 UnaryOp::Not => {
431 if !matches!(
433 operand_typed.resolved_type,
434 ResolvedType::Boolean | ResolvedType::Null
435 ) {
436 return Err(PlannerError::TypeMismatch {
437 expected: "Boolean".to_string(),
438 found: operand_typed.resolved_type.type_name().to_string(),
439 line: span.start.line,
440 column: span.start.column,
441 });
442 }
443 ResolvedType::Boolean
444 }
445 UnaryOp::Minus => {
446 match &operand_typed.resolved_type {
448 ResolvedType::Integer => ResolvedType::Integer,
449 ResolvedType::BigInt => ResolvedType::BigInt,
450 ResolvedType::Float => ResolvedType::Float,
451 ResolvedType::Double => ResolvedType::Double,
452 ResolvedType::Null => ResolvedType::Null,
453 other => {
454 return Err(PlannerError::InvalidOperator {
455 op: "unary minus".to_string(),
456 type_name: other.type_name().to_string(),
457 line: span.start.line,
458 column: span.start.column,
459 });
460 }
461 }
462 }
463 };
464
465 Ok(TypedExpr {
466 kind: TypedExprKind::UnaryOp {
467 op,
468 operand: Box::new(operand_typed),
469 },
470 resolved_type: result_type,
471 span,
472 })
473 }
474
475 fn infer_function_call_type(
477 &self,
478 name: &str,
479 args: &[Expr],
480 distinct: bool,
481 star: bool,
482 table: &TableMetadata,
483 span: Span,
484 ) -> Result<TypedExpr, PlannerError> {
485 let typed_args: Vec<TypedExpr> = args
487 .iter()
488 .map(|arg| self.infer_type(arg, table))
489 .collect::<Result<Vec<_>, _>>()?;
490
491 let result_type = self.check_function_call(name, &typed_args, distinct, star, span)?;
493
494 Ok(TypedExpr {
495 kind: TypedExprKind::FunctionCall {
496 name: name.to_string(),
497 args: typed_args,
498 distinct,
499 star,
500 },
501 resolved_type: result_type,
502 span,
503 })
504 }
505
506 fn infer_between_type(
508 &self,
509 expr: &Expr,
510 low: &Expr,
511 high: &Expr,
512 negated: bool,
513 table: &TableMetadata,
514 span: Span,
515 ) -> Result<TypedExpr, PlannerError> {
516 let expr_typed = self.infer_type(expr, table)?;
517 let low_typed = self.infer_type(low, table)?;
518 let high_typed = self.infer_type(high, table)?;
519
520 self.check_comparison_op(&expr_typed.resolved_type, &low_typed.resolved_type, span)?;
522 self.check_comparison_op(&expr_typed.resolved_type, &high_typed.resolved_type, span)?;
523
524 Ok(TypedExpr {
525 kind: TypedExprKind::Between {
526 expr: Box::new(expr_typed),
527 low: Box::new(low_typed),
528 high: Box::new(high_typed),
529 negated,
530 },
531 resolved_type: ResolvedType::Boolean,
532 span,
533 })
534 }
535
536 fn infer_like_type(
538 &self,
539 expr: &Expr,
540 pattern: &Expr,
541 escape: Option<&Expr>,
542 negated: bool,
543 table: &TableMetadata,
544 span: Span,
545 ) -> Result<TypedExpr, PlannerError> {
546 let expr_typed = self.infer_type(expr, table)?;
547 let pattern_typed = self.infer_type(pattern, table)?;
548
549 if !matches!(
551 expr_typed.resolved_type,
552 ResolvedType::Text | ResolvedType::Null
553 ) {
554 return Err(PlannerError::TypeMismatch {
555 expected: "Text".to_string(),
556 found: expr_typed.resolved_type.type_name().to_string(),
557 line: expr.span.start.line,
558 column: expr.span.start.column,
559 });
560 }
561
562 if !matches!(
564 pattern_typed.resolved_type,
565 ResolvedType::Text | ResolvedType::Null
566 ) {
567 return Err(PlannerError::TypeMismatch {
568 expected: "Text".to_string(),
569 found: pattern_typed.resolved_type.type_name().to_string(),
570 line: pattern.span.start.line,
571 column: pattern.span.start.column,
572 });
573 }
574
575 let escape_typed = if let Some(esc) = escape {
576 let typed = self.infer_type(esc, table)?;
577 if !matches!(typed.resolved_type, ResolvedType::Text | ResolvedType::Null) {
578 return Err(PlannerError::TypeMismatch {
579 expected: "Text".to_string(),
580 found: typed.resolved_type.type_name().to_string(),
581 line: esc.span.start.line,
582 column: esc.span.start.column,
583 });
584 }
585 Some(Box::new(typed))
586 } else {
587 None
588 };
589
590 Ok(TypedExpr {
591 kind: TypedExprKind::Like {
592 expr: Box::new(expr_typed),
593 pattern: Box::new(pattern_typed),
594 escape: escape_typed,
595 negated,
596 },
597 resolved_type: ResolvedType::Boolean,
598 span,
599 })
600 }
601
602 fn infer_in_list_type(
604 &self,
605 expr: &Expr,
606 list: &[Expr],
607 negated: bool,
608 table: &TableMetadata,
609 span: Span,
610 ) -> Result<TypedExpr, PlannerError> {
611 let expr_typed = self.infer_type(expr, table)?;
612
613 let typed_list: Vec<TypedExpr> = list
614 .iter()
615 .map(|item| {
616 let typed = self.infer_type(item, table)?;
617 self.check_comparison_op(
619 &expr_typed.resolved_type,
620 &typed.resolved_type,
621 item.span,
622 )?;
623 Ok(typed)
624 })
625 .collect::<Result<Vec<_>, PlannerError>>()?;
626
627 Ok(TypedExpr {
628 kind: TypedExprKind::InList {
629 expr: Box::new(expr_typed),
630 list: typed_list,
631 negated,
632 },
633 resolved_type: ResolvedType::Boolean,
634 span,
635 })
636 }
637
638 fn infer_is_null_type(
640 &self,
641 expr: &Expr,
642 negated: bool,
643 table: &TableMetadata,
644 span: Span,
645 ) -> Result<TypedExpr, PlannerError> {
646 let expr_typed = self.infer_type(expr, table)?;
647
648 Ok(TypedExpr {
649 kind: TypedExprKind::IsNull {
650 expr: Box::new(expr_typed),
651 negated,
652 },
653 resolved_type: ResolvedType::Boolean,
654 span,
655 })
656 }
657
658 fn infer_vector_literal_type(
660 &self,
661 values: &[f64],
662 span: Span,
663 ) -> Result<TypedExpr, PlannerError> {
664 Ok(TypedExpr {
665 kind: TypedExprKind::VectorLiteral(values.to_vec()),
666 resolved_type: ResolvedType::Vector {
667 dimension: values.len() as u32,
668 metric: VectorMetric::Cosine, },
670 span,
671 })
672 }
673
674 pub fn normalize_metric(&self, metric: &str, span: Span) -> Result<VectorMetric, PlannerError> {
686 match metric.to_lowercase().as_str() {
687 "cosine" => Ok(VectorMetric::Cosine),
688 "l2" => Ok(VectorMetric::L2),
689 "inner" => Ok(VectorMetric::Inner),
690 _ => Err(PlannerError::InvalidMetric {
691 value: metric.to_string(),
692 line: span.start.line,
693 column: span.start.column,
694 }),
695 }
696 }
697
698 pub fn check_function_call(
703 &self,
704 name: &str,
705 args: &[TypedExpr],
706 distinct: bool,
707 star: bool,
708 span: Span,
709 ) -> Result<ResolvedType, PlannerError> {
710 let lower_name = name.to_lowercase();
711
712 match lower_name.as_str() {
713 "count" => self.check_count(args, distinct, star, span),
714 "sum" => self.check_sum(args, distinct, star, span),
715 "total" => self.check_total(args, distinct, star, span),
716 "avg" => self.check_avg(args, distinct, star, span),
717 "min" => self.check_min_max(args, distinct, star, span),
718 "max" => self.check_min_max(args, distinct, star, span),
719 "group_concat" => self.check_group_concat(args, distinct, star, span),
720 "string_agg" => self.check_string_agg(args, distinct, star, span),
721 "vector_distance" => self.check_vector_distance(args, span),
722 "vector_similarity" => self.check_vector_similarity(args, span),
723 "vector_dims" => self.check_vector_dims(args, span),
724 "vector_norm" => self.check_vector_norm(args, span),
725 _ => {
727 Err(PlannerError::UnsupportedFeature {
729 feature: format!("function '{}'", name),
730 version: "future".to_string(),
731 line: span.start.line,
732 column: span.start.column,
733 })
734 }
735 }
736 }
737
738 pub fn validate_having_expr(
739 &self,
740 expr: &TypedExpr,
741 group_keys: &[TypedExpr],
742 aggregates: &[AggregateExpr],
743 ) -> Result<(), PlannerError> {
744 use std::collections::HashSet;
745
746 let group_key_indices: HashSet<usize> = group_keys
747 .iter()
748 .filter_map(|expr| match &expr.kind {
749 TypedExprKind::ColumnRef { column_index, .. } => Some(*column_index),
750 _ => None,
751 })
752 .collect();
753
754 let aggregate_signatures: HashSet<AggregateSignature> = aggregates
755 .iter()
756 .map(aggregate_signature_from_expr)
757 .collect();
758
759 fn walk(
760 expr: &TypedExpr,
761 group_key_indices: &HashSet<usize>,
762 aggregate_signatures: &HashSet<AggregateSignature>,
763 ) -> Result<(), PlannerError> {
764 match &expr.kind {
765 TypedExprKind::ColumnRef { column_index, .. } => {
766 if group_key_indices.contains(column_index) {
767 Ok(())
768 } else {
769 Err(PlannerError::invalid_expression(
770 "column in HAVING must be in GROUP BY or be aggregated".to_string(),
771 ))
772 }
773 }
774 TypedExprKind::FunctionCall {
775 name,
776 args,
777 distinct,
778 star,
779 } if is_aggregate_name(name) => {
780 let signature = aggregate_signature_from_call(name, args, *distinct, *star)?;
781 if aggregate_signatures.contains(&signature) {
782 Ok(())
783 } else {
784 Err(PlannerError::invalid_expression(
785 "aggregate in HAVING must appear in plan".to_string(),
786 ))
787 }
788 }
789 TypedExprKind::BinaryOp { left, right, .. } => {
790 walk(left, group_key_indices, aggregate_signatures)?;
791 walk(right, group_key_indices, aggregate_signatures)
792 }
793 TypedExprKind::UnaryOp { operand, .. } => {
794 walk(operand, group_key_indices, aggregate_signatures)
795 }
796 TypedExprKind::FunctionCall { args, .. } => {
797 for arg in args {
798 walk(arg, group_key_indices, aggregate_signatures)?;
799 }
800 Ok(())
801 }
802 TypedExprKind::Between {
803 expr, low, high, ..
804 } => {
805 walk(expr, group_key_indices, aggregate_signatures)?;
806 walk(low, group_key_indices, aggregate_signatures)?;
807 walk(high, group_key_indices, aggregate_signatures)
808 }
809 TypedExprKind::Like {
810 expr,
811 pattern,
812 escape,
813 ..
814 } => {
815 walk(expr, group_key_indices, aggregate_signatures)?;
816 walk(pattern, group_key_indices, aggregate_signatures)?;
817 if let Some(esc) = escape {
818 walk(esc, group_key_indices, aggregate_signatures)?;
819 }
820 Ok(())
821 }
822 TypedExprKind::InList { expr, list, .. } => {
823 walk(expr, group_key_indices, aggregate_signatures)?;
824 for item in list {
825 walk(item, group_key_indices, aggregate_signatures)?;
826 }
827 Ok(())
828 }
829 TypedExprKind::IsNull { expr, .. } => {
830 walk(expr, group_key_indices, aggregate_signatures)
831 }
832 _ => Ok(()),
833 }
834 }
835
836 walk(expr, &group_key_indices, &aggregate_signatures)
837 }
838
839 fn check_count(
840 &self,
841 args: &[TypedExpr],
842 distinct: bool,
843 star: bool,
844 span: Span,
845 ) -> Result<ResolvedType, PlannerError> {
846 if star {
847 if distinct {
848 return Err(PlannerError::unsupported_feature(
849 "COUNT(DISTINCT *)",
850 "future",
851 span,
852 ));
853 }
854 if !args.is_empty() {
855 return Err(PlannerError::type_mismatch(
856 "no arguments with COUNT(*)",
857 format!("{} arguments", args.len()),
858 span,
859 ));
860 }
861 return Ok(ResolvedType::BigInt);
862 }
863
864 if args.len() != 1 {
865 return Err(PlannerError::type_mismatch(
866 "1 argument",
867 format!("{} arguments", args.len()),
868 span,
869 ));
870 }
871
872 if distinct {
873 return Ok(ResolvedType::BigInt);
874 }
875
876 Ok(ResolvedType::BigInt)
877 }
878
879 fn check_sum(
880 &self,
881 args: &[TypedExpr],
882 distinct: bool,
883 star: bool,
884 span: Span,
885 ) -> Result<ResolvedType, PlannerError> {
886 if star {
887 return Err(PlannerError::type_mismatch(
888 "numeric argument",
889 "COUNT(*) style",
890 span,
891 ));
892 }
893 if distinct {
894 return Err(PlannerError::unsupported_feature(
895 "SUM(DISTINCT ...)",
896 "future",
897 span,
898 ));
899 }
900 let arg = self.require_single_arg(args, span)?;
901 if !is_numeric_type(&arg.resolved_type) && arg.resolved_type != ResolvedType::Null {
902 return Err(PlannerError::type_mismatch(
903 "numeric",
904 arg.resolved_type.type_name().to_string(),
905 arg.span,
906 ));
907 }
908 Ok(ResolvedType::Double)
909 }
910
911 fn check_total(
912 &self,
913 args: &[TypedExpr],
914 distinct: bool,
915 star: bool,
916 span: Span,
917 ) -> Result<ResolvedType, PlannerError> {
918 if star {
919 return Err(PlannerError::type_mismatch(
920 "numeric argument",
921 "COUNT(*) style",
922 span,
923 ));
924 }
925 if distinct {
926 return Err(PlannerError::unsupported_feature(
927 "TOTAL(DISTINCT ...)",
928 "future",
929 span,
930 ));
931 }
932 let arg = self.require_single_arg(args, span)?;
933 if !is_numeric_type(&arg.resolved_type) && arg.resolved_type != ResolvedType::Null {
934 return Err(PlannerError::type_mismatch(
935 "numeric",
936 arg.resolved_type.type_name().to_string(),
937 arg.span,
938 ));
939 }
940 Ok(ResolvedType::Double)
941 }
942
943 fn check_avg(
944 &self,
945 args: &[TypedExpr],
946 distinct: bool,
947 star: bool,
948 span: Span,
949 ) -> Result<ResolvedType, PlannerError> {
950 if star {
951 return Err(PlannerError::type_mismatch(
952 "numeric argument",
953 "COUNT(*) style",
954 span,
955 ));
956 }
957 if distinct {
958 return Err(PlannerError::unsupported_feature(
959 "AVG(DISTINCT ...)",
960 "future",
961 span,
962 ));
963 }
964 let arg = self.require_single_arg(args, span)?;
965 if !is_numeric_type(&arg.resolved_type) && arg.resolved_type != ResolvedType::Null {
966 return Err(PlannerError::type_mismatch(
967 "numeric",
968 arg.resolved_type.type_name().to_string(),
969 arg.span,
970 ));
971 }
972 Ok(ResolvedType::Double)
973 }
974
975 fn check_min_max(
976 &self,
977 args: &[TypedExpr],
978 distinct: bool,
979 star: bool,
980 span: Span,
981 ) -> Result<ResolvedType, PlannerError> {
982 if star {
983 return Err(PlannerError::type_mismatch(
984 "argument",
985 "COUNT(*) style",
986 span,
987 ));
988 }
989 if distinct {
990 return Err(PlannerError::unsupported_feature(
991 "MIN/MAX(DISTINCT ...)",
992 "future",
993 span,
994 ));
995 }
996 let arg = self.require_single_arg(args, span)?;
997 if matches!(arg.resolved_type, ResolvedType::Vector { .. }) {
998 return Err(PlannerError::type_mismatch(
999 "comparable",
1000 arg.resolved_type.type_name().to_string(),
1001 arg.span,
1002 ));
1003 }
1004 Ok(arg.resolved_type.clone())
1005 }
1006
1007 fn check_group_concat(
1008 &self,
1009 args: &[TypedExpr],
1010 distinct: bool,
1011 star: bool,
1012 span: Span,
1013 ) -> Result<ResolvedType, PlannerError> {
1014 if star {
1015 return Err(PlannerError::type_mismatch(
1016 "text argument",
1017 "COUNT(*) style",
1018 span,
1019 ));
1020 }
1021 if distinct {
1022 return Err(PlannerError::unsupported_feature(
1023 "GROUP_CONCAT(DISTINCT ...)",
1024 "future",
1025 span,
1026 ));
1027 }
1028 if args.is_empty() || args.len() > 2 {
1029 return Err(PlannerError::type_mismatch(
1030 "1 or 2 arguments",
1031 format!("{} arguments", args.len()),
1032 span,
1033 ));
1034 }
1035 if !matches!(
1036 args[0].resolved_type,
1037 ResolvedType::Text | ResolvedType::Null
1038 ) {
1039 return Err(PlannerError::type_mismatch(
1040 "Text",
1041 args[0].resolved_type.type_name().to_string(),
1042 args[0].span,
1043 ));
1044 }
1045 if args.len() == 2
1046 && !matches!(
1047 args[1].resolved_type,
1048 ResolvedType::Text | ResolvedType::Null
1049 )
1050 {
1051 return Err(PlannerError::type_mismatch(
1052 "Text",
1053 args[1].resolved_type.type_name().to_string(),
1054 args[1].span,
1055 ));
1056 }
1057 Ok(ResolvedType::Text)
1058 }
1059
1060 fn check_string_agg(
1061 &self,
1062 args: &[TypedExpr],
1063 distinct: bool,
1064 star: bool,
1065 span: Span,
1066 ) -> Result<ResolvedType, PlannerError> {
1067 if star {
1068 return Err(PlannerError::type_mismatch(
1069 "text argument",
1070 "COUNT(*) style",
1071 span,
1072 ));
1073 }
1074 if distinct {
1075 return Err(PlannerError::unsupported_feature(
1076 "STRING_AGG(DISTINCT ...)",
1077 "future",
1078 span,
1079 ));
1080 }
1081 if args.len() != 2 {
1082 return Err(PlannerError::type_mismatch(
1083 "2 arguments",
1084 format!("{} arguments", args.len()),
1085 span,
1086 ));
1087 }
1088 if !matches!(
1089 args[0].resolved_type,
1090 ResolvedType::Text | ResolvedType::Null
1091 ) {
1092 return Err(PlannerError::type_mismatch(
1093 "Text",
1094 args[0].resolved_type.type_name().to_string(),
1095 args[0].span,
1096 ));
1097 }
1098 if !matches!(
1099 args[1].resolved_type,
1100 ResolvedType::Text | ResolvedType::Null
1101 ) {
1102 return Err(PlannerError::type_mismatch(
1103 "Text",
1104 args[1].resolved_type.type_name().to_string(),
1105 args[1].span,
1106 ));
1107 }
1108 Ok(ResolvedType::Text)
1109 }
1110
1111 fn check_vector_dims(
1112 &self,
1113 args: &[TypedExpr],
1114 span: Span,
1115 ) -> Result<ResolvedType, PlannerError> {
1116 let arg = self.require_single_arg(args, span)?;
1117 if !matches!(
1118 arg.resolved_type,
1119 ResolvedType::Vector { .. } | ResolvedType::Null
1120 ) {
1121 return Err(PlannerError::type_mismatch(
1122 "Vector",
1123 arg.resolved_type.type_name().to_string(),
1124 arg.span,
1125 ));
1126 }
1127 Ok(ResolvedType::Integer)
1128 }
1129
1130 fn check_vector_norm(
1131 &self,
1132 args: &[TypedExpr],
1133 span: Span,
1134 ) -> Result<ResolvedType, PlannerError> {
1135 let arg = self.require_single_arg(args, span)?;
1136 if !matches!(
1137 arg.resolved_type,
1138 ResolvedType::Vector { .. } | ResolvedType::Null
1139 ) {
1140 return Err(PlannerError::type_mismatch(
1141 "Vector",
1142 arg.resolved_type.type_name().to_string(),
1143 arg.span,
1144 ));
1145 }
1146 Ok(ResolvedType::Double)
1147 }
1148
1149 fn require_single_arg<'b>(
1150 &self,
1151 args: &'b [TypedExpr],
1152 span: Span,
1153 ) -> Result<&'b TypedExpr, PlannerError> {
1154 if args.len() != 1 {
1155 return Err(PlannerError::type_mismatch(
1156 "1 argument",
1157 format!("{} arguments", args.len()),
1158 span,
1159 ));
1160 }
1161 Ok(&args[0])
1162 }
1163
1164 pub fn check_vector_distance(
1175 &self,
1176 args: &[TypedExpr],
1177 span: Span,
1178 ) -> Result<ResolvedType, PlannerError> {
1179 if args.len() != 3 {
1180 return Err(PlannerError::TypeMismatch {
1181 expected: "3 arguments".to_string(),
1182 found: format!("{} arguments", args.len()),
1183 line: span.start.line,
1184 column: span.start.column,
1185 });
1186 }
1187
1188 let col_dim = match &args[0].resolved_type {
1190 ResolvedType::Vector { dimension, .. } => *dimension,
1191 other => {
1192 return Err(PlannerError::TypeMismatch {
1193 expected: "Vector".to_string(),
1194 found: other.type_name().to_string(),
1195 line: args[0].span.start.line,
1196 column: args[0].span.start.column,
1197 });
1198 }
1199 };
1200
1201 let vec_dim = match &args[1].resolved_type {
1203 ResolvedType::Vector { dimension, .. } => *dimension,
1204 other => {
1205 return Err(PlannerError::TypeMismatch {
1206 expected: "Vector".to_string(),
1207 found: other.type_name().to_string(),
1208 line: args[1].span.start.line,
1209 column: args[1].span.start.column,
1210 });
1211 }
1212 };
1213
1214 self.check_vector_dimension(col_dim, vec_dim, args[1].span)?;
1216
1217 match &args[2].resolved_type {
1219 ResolvedType::Text => {
1220 if let TypedExprKind::Literal(Literal::String(s)) = &args[2].kind {
1222 self.normalize_metric(s, args[2].span)?;
1223 }
1224 }
1225 ResolvedType::Null => {
1226 return Err(PlannerError::TypeMismatch {
1228 expected: "Text (metric)".to_string(),
1229 found: "Null".to_string(),
1230 line: args[2].span.start.line,
1231 column: args[2].span.start.column,
1232 });
1233 }
1234 other => {
1235 return Err(PlannerError::TypeMismatch {
1236 expected: "Text (metric)".to_string(),
1237 found: other.type_name().to_string(),
1238 line: args[2].span.start.line,
1239 column: args[2].span.start.column,
1240 });
1241 }
1242 }
1243
1244 Ok(ResolvedType::Double)
1245 }
1246
1247 pub fn check_vector_similarity(
1253 &self,
1254 args: &[TypedExpr],
1255 span: Span,
1256 ) -> Result<ResolvedType, PlannerError> {
1257 self.check_vector_distance(args, span)
1259 }
1260
1261 pub fn check_vector_dimension(
1267 &self,
1268 expected: u32,
1269 found: u32,
1270 span: Span,
1271 ) -> Result<(), PlannerError> {
1272 if expected != found {
1273 Err(PlannerError::VectorDimensionMismatch {
1274 expected,
1275 found,
1276 line: span.start.line,
1277 column: span.start.column,
1278 })
1279 } else {
1280 Ok(())
1281 }
1282 }
1283
1284 pub fn check_insert_values(
1307 &self,
1308 table: &TableMetadata,
1309 columns: &[String],
1310 values: &[Vec<Expr>],
1311 span: Span,
1312 ) -> Result<Vec<Vec<TypedExpr>>, PlannerError> {
1313 let target_columns: Vec<&str> = if columns.is_empty() {
1315 table.column_names()
1316 } else {
1317 columns.iter().map(|s| s.as_str()).collect()
1318 };
1319
1320 let mut typed_rows = Vec::with_capacity(values.len());
1321
1322 for row in values {
1323 if row.len() != target_columns.len() {
1325 return Err(PlannerError::ColumnValueCountMismatch {
1326 columns: target_columns.len(),
1327 values: row.len(),
1328 line: span.start.line,
1329 column: span.start.column,
1330 });
1331 }
1332
1333 let mut typed_values = Vec::with_capacity(row.len());
1334
1335 for (value, col_name) in row.iter().zip(target_columns.iter()) {
1336 let col_meta =
1338 table
1339 .get_column(col_name)
1340 .ok_or_else(|| PlannerError::ColumnNotFound {
1341 column: col_name.to_string(),
1342 table: table.name.clone(),
1343 line: span.start.line,
1344 col: span.start.column,
1345 })?;
1346
1347 let typed_value = self.infer_type(value, table)?;
1349
1350 self.check_null_constraint(col_meta, &typed_value, value.span)?;
1352
1353 self.check_type_compatibility(
1355 &col_meta.data_type,
1356 &typed_value.resolved_type,
1357 value.span,
1358 )?;
1359
1360 if let (
1362 ResolvedType::Vector {
1363 dimension: expected_dim,
1364 ..
1365 },
1366 ResolvedType::Vector {
1367 dimension: actual_dim,
1368 ..
1369 },
1370 ) = (&col_meta.data_type, &typed_value.resolved_type)
1371 {
1372 self.check_vector_dimension(*expected_dim, *actual_dim, value.span)?;
1373 }
1374
1375 typed_values.push(typed_value);
1376 }
1377
1378 typed_rows.push(typed_values);
1379 }
1380
1381 Ok(typed_rows)
1382 }
1383
1384 pub fn check_assignment(
1395 &self,
1396 table: &TableMetadata,
1397 column: &str,
1398 value: &Expr,
1399 span: Span,
1400 ) -> Result<TypedExpr, PlannerError> {
1401 let col_meta = table
1403 .get_column(column)
1404 .ok_or_else(|| PlannerError::ColumnNotFound {
1405 column: column.to_string(),
1406 table: table.name.clone(),
1407 line: span.start.line,
1408 col: span.start.column,
1409 })?;
1410
1411 let typed_value = self.infer_type(value, table)?;
1413
1414 self.check_null_constraint(col_meta, &typed_value, value.span)?;
1416
1417 self.check_type_compatibility(&col_meta.data_type, &typed_value.resolved_type, value.span)?;
1419
1420 if let (
1422 ResolvedType::Vector {
1423 dimension: expected_dim,
1424 ..
1425 },
1426 ResolvedType::Vector {
1427 dimension: actual_dim,
1428 ..
1429 },
1430 ) = (&col_meta.data_type, &typed_value.resolved_type)
1431 {
1432 self.check_vector_dimension(*expected_dim, *actual_dim, value.span)?;
1433 }
1434
1435 Ok(typed_value)
1436 }
1437
1438 pub fn check_null_constraint(
1445 &self,
1446 column: &crate::catalog::ColumnMetadata,
1447 value: &TypedExpr,
1448 span: Span,
1449 ) -> Result<(), PlannerError> {
1450 if column.not_null && matches!(value.resolved_type, ResolvedType::Null) {
1451 Err(PlannerError::NullConstraintViolation {
1452 column: column.name.clone(),
1453 line: span.start.line,
1454 col: span.start.column,
1455 })
1456 } else {
1457 Ok(())
1458 }
1459 }
1460
1461 fn check_type_compatibility(
1469 &self,
1470 expected: &ResolvedType,
1471 actual: &ResolvedType,
1472 span: Span,
1473 ) -> Result<(), PlannerError> {
1474 if expected == actual {
1476 return Ok(());
1477 }
1478
1479 if actual.can_cast_to(expected) {
1481 return Ok(());
1482 }
1483
1484 if let (
1487 ResolvedType::Vector {
1488 dimension: d1,
1489 metric: _,
1490 },
1491 ResolvedType::Vector {
1492 dimension: d2,
1493 metric: _,
1494 },
1495 ) = (expected, actual)
1496 {
1497 if *d1 == *d2 {
1499 return Ok(());
1500 }
1501 }
1503
1504 Err(PlannerError::TypeMismatch {
1505 expected: expected.type_name().to_string(),
1506 found: actual.type_name().to_string(),
1507 line: span.start.line,
1508 column: span.start.column,
1509 })
1510 }
1511}
1512
1513fn is_numeric_type(ty: &ResolvedType) -> bool {
1514 matches!(
1515 ty,
1516 ResolvedType::Integer | ResolvedType::BigInt | ResolvedType::Float | ResolvedType::Double
1517 )
1518}
1519
/// Comparable identity of an aggregate call.
///
/// `validate_having_expr` uses it to match aggregates written in `HAVING`
/// against the aggregates the plan computes.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct AggregateSignature {
    /// Lower-cased function name, e.g. `"count"`.
    name: String,
    /// Whether the call used `DISTINCT`.
    distinct: bool,
    /// Whether the call was the `*` form (only `COUNT(*)` is accepted).
    star: bool,
    /// Debug rendering of the first argument's kind; `None` when the call has
    /// no argument.
    arg_key: Option<String>,
    /// Literal separator for `GROUP_CONCAT` / `STRING_AGG`, if one was given.
    separator: Option<String>,
}
1528
/// Returns `true` if `name` is one of the aggregate functions the planner
/// recognises. The comparison is ASCII case-insensitive.
fn is_aggregate_name(name: &str) -> bool {
    const AGGREGATE_NAMES: [&str; 8] = [
        "count",
        "sum",
        "total",
        "avg",
        "min",
        "max",
        "group_concat",
        "string_agg",
    ];
    AGGREGATE_NAMES
        .iter()
        .any(|candidate| candidate.eq_ignore_ascii_case(name))
}
1535
1536fn aggregate_signature_from_expr(expr: &AggregateExpr) -> AggregateSignature {
1537 let (name, separator, star, arg) = match &expr.function {
1538 AggregateFunction::Count => (
1539 "count".to_string(),
1540 None,
1541 expr.arg.is_none(),
1542 expr.arg.as_ref(),
1543 ),
1544 AggregateFunction::Sum => ("sum".to_string(), None, false, expr.arg.as_ref()),
1545 AggregateFunction::Total => ("total".to_string(), None, false, expr.arg.as_ref()),
1546 AggregateFunction::Avg => ("avg".to_string(), None, false, expr.arg.as_ref()),
1547 AggregateFunction::Min => ("min".to_string(), None, false, expr.arg.as_ref()),
1548 AggregateFunction::Max => ("max".to_string(), None, false, expr.arg.as_ref()),
1549 AggregateFunction::GroupConcat { separator } => (
1550 "group_concat".to_string(),
1551 separator.clone(),
1552 false,
1553 expr.arg.as_ref(),
1554 ),
1555 AggregateFunction::StringAgg { separator } => (
1556 "string_agg".to_string(),
1557 separator.clone(),
1558 false,
1559 expr.arg.as_ref(),
1560 ),
1561 };
1562 AggregateSignature {
1563 name,
1564 distinct: expr.distinct,
1565 star,
1566 arg_key: arg.map(typed_expr_signature),
1567 separator,
1568 }
1569}
1570
1571fn aggregate_signature_from_call(
1572 name: &str,
1573 args: &[TypedExpr],
1574 distinct: bool,
1575 star: bool,
1576) -> Result<AggregateSignature, PlannerError> {
1577 let separator = if name.eq_ignore_ascii_case("group_concat") && args.len() == 2 {
1578 if let TypedExprKind::Literal(Literal::String(value)) = &args[1].kind {
1579 Some(value.clone())
1580 } else {
1581 return Err(PlannerError::invalid_expression(
1582 "GROUP_CONCAT separator must be a string literal".to_string(),
1583 ));
1584 }
1585 } else if name.eq_ignore_ascii_case("string_agg") && args.len() == 2 {
1586 if let TypedExprKind::Literal(Literal::String(value)) = &args[1].kind {
1587 Some(value.clone())
1588 } else {
1589 return Err(PlannerError::invalid_expression(
1590 "STRING_AGG separator must be a string literal".to_string(),
1591 ));
1592 }
1593 } else {
1594 None
1595 };
1596 Ok(AggregateSignature {
1597 name: name.to_ascii_lowercase(),
1598 distinct,
1599 star,
1600 arg_key: args.first().map(typed_expr_signature),
1601 separator,
1602 })
1603}
1604
1605fn typed_expr_signature(expr: &TypedExpr) -> String {
1606 format!("{:?}", expr.kind)
1607}
1608
// Unit tests for this module. They are compiled only under `cfg(test)`, and
// their source is loaded from the file named by the `#[path]` attribute.
#[cfg(test)]
#[path = "type_checker/tests.rs"]
mod tests;