1use crate::ast::{BinaryOperator, Expr, Literal, TemporalKeyword, UnaryOperator, Visitor};
7use std::fmt;
8
/// Static type assigned to an expression by [`TypeInferenceVisitor`].
///
/// Type errors are represented in-band by [`InferredType::Error`], so inference returns
/// problems as ordinary values instead of failing.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum InferredType {
    /// Integer value (e.g. an integer literal).
    Integer,
    /// Floating-point value (e.g. a float literal).
    Float,
    /// String value.
    String,
    /// Boolean value (boolean literals, comparisons, logical operators).
    Boolean,
    /// The `null` value.
    Null,
    /// Array whose elements unify to the boxed element type.
    Array(Box<InferredType>),
    /// Object value; field types are not tracked.
    Object,
    /// Calendar date.
    Date,
    /// Date with a time component.
    DateTime,
    /// Length of time, e.g. the difference between two dates.
    Duration,
    /// Type could not be determined (identifiers, field accesses, lambdas, ...).
    /// Unifies with any other type in [`InferredType::common_type`].
    Unknown,
    /// Some numeric type whose exact kind (integer or float) is not known.
    Numeric,
    /// Type error carrying a human-readable description.
    Error(String),
}

impl fmt::Display for InferredType {
    /// Writes the lowercase, user-facing spelling of the type, e.g. `integer`, `[float]`
    /// or `error(<message>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Array(element) => return write!(f, "[{}]", element),
            Self::Error(message) => return write!(f, "error({})", message),
            Self::Integer => "integer",
            Self::Float => "float",
            Self::String => "string",
            Self::Boolean => "boolean",
            Self::Null => "null",
            Self::Object => "object",
            Self::Date => "date",
            Self::DateTime => "datetime",
            Self::Duration => "duration",
            Self::Unknown => "unknown",
            // The generic numeric type is spelled "number" in diagnostics.
            Self::Numeric => "number",
        };
        f.write_str(name)
    }
}

impl InferredType {
    /// Returns `true` for `Integer`, `Float` and the generic `Numeric` type.
    pub fn is_numeric(&self) -> bool {
        matches!(self, Self::Integer | Self::Float | Self::Numeric)
    }

    /// Returns `true` for the primitive single-value types `Integer`, `Float`, `String`
    /// and `Boolean`.
    pub fn is_scalar(&self) -> bool {
        matches!(
            self,
            Self::Integer | Self::Float | Self::String | Self::Boolean
        )
    }

    /// Returns `true` if this is an [`InferredType::Error`].
    pub fn is_error(&self) -> bool {
        matches!(self, Self::Error(_))
    }

    /// Returns the type that both `a` and `b` can be treated as, e.g. for the two branches
    /// of an `if` or the elements of an array.
    ///
    /// Rules, in order of precedence:
    /// * an existing `Error` operand is returned unchanged;
    /// * identical types unify to themselves;
    /// * `Unknown` unifies with anything, yielding the other type;
    /// * `Integer` and `Float` unify to `Float`;
    /// * `Numeric` unifies with any numeric type, yielding that type;
    /// * arrays unify element-wise;
    /// * anything else is a type-mismatch `Error`.
    pub fn common_type(a: &InferredType, b: &InferredType) -> InferredType {
        match (a, b) {
            // Propagate an earlier error instead of nesting it in a new mismatch message.
            (err @ InferredType::Error(_), _) | (_, err @ InferredType::Error(_)) => err.clone(),

            (t1, t2) if t1 == t2 => t1.clone(),

            // Must come before the `Numeric` arm: otherwise `Numeric` vs `Unknown` would be
            // reported as a mismatch even though `Unknown` unifies with everything.
            (InferredType::Unknown, t) | (t, InferredType::Unknown) => t.clone(),

            (InferredType::Integer, InferredType::Float)
            | (InferredType::Float, InferredType::Integer) => InferredType::Float,

            (InferredType::Numeric, other) | (other, InferredType::Numeric) => {
                if other.is_numeric() {
                    other.clone()
                } else {
                    InferredType::Error(format!(
                        "Type mismatch: cannot unify numeric and {}",
                        other
                    ))
                }
            }

            (InferredType::Array(x), InferredType::Array(y)) => {
                InferredType::Array(Box::new(Self::common_type(x, y)))
            }

            (a, b) => InferredType::Error(format!("Type mismatch: cannot unify {} and {}", a, b)),
        }
    }
}
125
/// Stateless pass that infers an [`InferredType`] for an [`Expr`] tree.
///
/// No variable environment is kept, so identifiers infer as [`InferredType::Unknown`].
/// Because the struct carries no data it is freely `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct TypeInferenceVisitor;
132
133impl TypeInferenceVisitor {
134 pub fn new() -> Self {
136 TypeInferenceVisitor
137 }
138
139 pub fn infer(&self, expr: &Expr) -> InferredType {
141 Self::infer_expr(expr)
142 }
143
144 fn infer_expr(expr: &Expr) -> InferredType {
146 match expr {
147 Expr::Literal(lit) => match lit {
148 Literal::Integer(_) => InferredType::Integer,
149 Literal::Float(_) => InferredType::Float,
150 Literal::Boolean(_) => InferredType::Boolean,
151 },
152 Expr::Null => InferredType::Null,
153 Expr::Identifier(_) => InferredType::Unknown,
154 Expr::String(_) => InferredType::String,
155 Expr::FieldAccess { .. } => InferredType::Unknown,
156 Expr::BinaryOp { op, left, right } => Self::infer_binary_op(*op, left, right),
157 Expr::UnaryOp { op, operand } => Self::infer_unary_op(*op, operand),
158 Expr::FunctionCall { name, args } => Self::infer_function_call(name, args),
159 Expr::Lambda { .. } => InferredType::Unknown,
160 Expr::Let { body, .. } => Self::infer_expr(body),
161 Expr::If {
162 then_branch,
163 else_branch,
164 ..
165 } => {
166 let then_type = Self::infer_expr(then_branch);
167 let else_type = Self::infer_expr(else_branch);
168 InferredType::common_type(&then_type, &else_type)
169 }
170 Expr::Array(elements) => {
171 if elements.is_empty() {
172 InferredType::Array(Box::new(InferredType::Unknown))
173 } else {
174 let first_type = Self::infer_expr(&elements[0]);
175 let mut common = first_type;
176 for elem in &elements[1..] {
177 let elem_type = Self::infer_expr(elem);
178 common = InferredType::common_type(&common, &elem_type);
179 if common.is_error() {
180 break;
181 }
182 }
183 InferredType::Array(Box::new(common))
184 }
185 }
186 Expr::Object(_) => InferredType::Object,
187 Expr::Pipe { functions, .. } => {
188 if functions.is_empty() {
189 InferredType::Unknown
190 } else {
191 Self::infer_expr(functions.last().unwrap())
192 }
193 }
194 Expr::Alternative {
195 primary,
196 alternative,
197 } => {
198 let primary_type = Self::infer_expr(primary);
199 let alt_type = Self::infer_expr(alternative);
200 InferredType::common_type(&primary_type, &alt_type)
201 }
202 Expr::Guard { body, .. } => Self::infer_expr(body),
203 Expr::Date(_) => InferredType::Date,
204 Expr::DateTime(_) => InferredType::DateTime,
205 Expr::Duration(_) => InferredType::Duration,
206 Expr::TemporalKeyword(keyword) => match keyword {
207 TemporalKeyword::Now => InferredType::DateTime,
208 TemporalKeyword::Today | TemporalKeyword::Tomorrow | TemporalKeyword::Yesterday => {
209 InferredType::Date
210 }
211 _ => InferredType::Date, },
213 }
214 }
215
216 fn infer_binary_op(op: BinaryOperator, left: &Expr, right: &Expr) -> InferredType {
217 let left_type = Self::infer_expr(left);
218 let right_type = Self::infer_expr(right);
219
220 match op {
221 BinaryOperator::Add => match (&left_type, &right_type) {
222 (InferredType::Integer, InferredType::Integer) => InferredType::Integer,
223 (InferredType::Float, InferredType::Float) => InferredType::Float,
224 (InferredType::Integer, InferredType::Float)
225 | (InferredType::Float, InferredType::Integer) => InferredType::Float,
226 (InferredType::String, InferredType::String) => InferredType::String,
227 (InferredType::Date, InferredType::Duration)
229 | (InferredType::Duration, InferredType::Date) => InferredType::Date,
230 (InferredType::DateTime, InferredType::Duration)
231 | (InferredType::Duration, InferredType::DateTime) => InferredType::DateTime,
232 (InferredType::Duration, InferredType::Duration) => InferredType::Duration,
233 (InferredType::Unknown, t) | (t, InferredType::Unknown) => t.clone(),
235 _ => InferredType::Error(format!("Cannot add {} and {}", left_type, right_type)),
236 },
237 BinaryOperator::Sub => match (&left_type, &right_type) {
238 (InferredType::Integer, InferredType::Integer) => InferredType::Integer,
239 (InferredType::Float, InferredType::Float) => InferredType::Float,
240 (InferredType::Integer, InferredType::Float)
241 | (InferredType::Float, InferredType::Integer) => InferredType::Float,
242 (InferredType::Date, InferredType::Duration) => InferredType::Date,
244 (InferredType::DateTime, InferredType::Duration) => InferredType::DateTime,
245 (InferredType::Date, InferredType::Date) => InferredType::Duration,
246 (InferredType::DateTime, InferredType::DateTime) => InferredType::Duration,
247 (InferredType::Duration, InferredType::Duration) => InferredType::Duration,
248 (InferredType::Unknown, t) | (t, InferredType::Unknown) => {
250 if t.is_numeric() {
251 t.clone()
252 } else {
253 InferredType::Error(format!(
254 "Cannot apply arithmetic to {} and {}",
255 left_type, right_type
256 ))
257 }
258 }
259 _ => InferredType::Error(format!(
260 "Cannot apply arithmetic to {} and {}",
261 left_type, right_type
262 )),
263 },
264 BinaryOperator::Mul | BinaryOperator::Div => {
265 match (&left_type, &right_type) {
266 (InferredType::Integer, InferredType::Integer) => InferredType::Integer,
267 (InferredType::Float, InferredType::Float) => InferredType::Float,
268 (InferredType::Integer, InferredType::Float)
269 | (InferredType::Float, InferredType::Integer) => InferredType::Float,
270 (InferredType::Unknown, t) | (t, InferredType::Unknown) => {
272 if t.is_numeric() {
273 t.clone()
274 } else {
275 InferredType::Error(format!(
276 "Cannot apply arithmetic to {} and {}",
277 left_type, right_type
278 ))
279 }
280 }
281 _ => InferredType::Error(format!(
282 "Cannot apply arithmetic to {} and {}",
283 left_type, right_type
284 )),
285 }
286 }
287 BinaryOperator::Mod | BinaryOperator::Pow => {
288 if left_type.is_numeric() && right_type.is_numeric() {
289 InferredType::Integer
290 } else if left_type == InferredType::Unknown || right_type == InferredType::Unknown
291 {
292 InferredType::Integer
294 } else {
295 InferredType::Error(format!(
296 "Cannot apply operator to {} and {}",
297 left_type, right_type
298 ))
299 }
300 }
301 BinaryOperator::Eq
302 | BinaryOperator::Neq
303 | BinaryOperator::Lt
304 | BinaryOperator::Lte
305 | BinaryOperator::Gt
306 | BinaryOperator::Gte => InferredType::Boolean,
307 BinaryOperator::And | BinaryOperator::Or => InferredType::Boolean,
308 }
309 }
310
311 fn infer_unary_op(op: UnaryOperator, operand: &Expr) -> InferredType {
312 let operand_type = Self::infer_expr(operand);
313 match op {
314 UnaryOperator::Not => InferredType::Boolean,
315 UnaryOperator::Neg | UnaryOperator::Plus => operand_type,
316 }
317 }
318
319 fn infer_function_call(name: &str, args: &[Expr]) -> InferredType {
320 match name {
321 "length" | "uppercase" | "lowercase" | "trim" | "contains" | "starts_with"
322 | "ends_with" => InferredType::String,
323 "map" | "filter" | "sort" => InferredType::Array(Box::new(InferredType::Unknown)),
324 "abs" | "min" | "max" | "round" | "floor" | "ceil" => {
325 if args.is_empty() {
326 InferredType::Unknown
327 } else {
328 let arg_type = Self::infer_expr(&args[0]);
329 if arg_type.is_numeric() {
330 arg_type
331 } else {
332 InferredType::Error(format!("Expected numeric argument, got {}", arg_type))
333 }
334 }
335 }
336 "all" | "any" => InferredType::Boolean,
337 _ => InferredType::Unknown,
338 }
339 }
340}
341
342impl Default for TypeInferenceVisitor {
343 fn default() -> Self {
344 Self::new()
345 }
346}
347
348impl Visitor<InferredType> for TypeInferenceVisitor {
349 fn visit_expr(&mut self, expr: &Expr) -> InferredType {
350 Self::infer_expr(expr)
351 }
352
353 fn visit_literal(&mut self, lit: &Literal) -> InferredType {
354 match lit {
355 Literal::Integer(_) => InferredType::Integer,
356 Literal::Float(_) => InferredType::Float,
357 Literal::Boolean(_) => InferredType::Boolean,
358 }
359 }
360
361 fn visit_null(&mut self) -> InferredType {
362 InferredType::Null
363 }
364
365 fn visit_identifier(&mut self, _name: &str) -> InferredType {
366 InferredType::Unknown
367 }
368
369 fn visit_field_access(&mut self, _receiver: &Expr, _field: &str) -> InferredType {
370 InferredType::Unknown
371 }
372
373 fn visit_binary_op(&mut self, op: BinaryOperator, left: &Expr, right: &Expr) -> InferredType {
374 Self::infer_binary_op(op, left, right)
375 }
376
377 fn visit_unary_op(&mut self, op: UnaryOperator, operand: &Expr) -> InferredType {
378 Self::infer_unary_op(op, operand)
379 }
380
381 fn visit_function_call(&mut self, name: &str, args: &[Expr]) -> InferredType {
382 Self::infer_function_call(name, args)
383 }
384
385 fn visit_lambda(&mut self, _param: &str, _body: &Expr) -> InferredType {
386 InferredType::Unknown
387 }
388
389 fn visit_let(&mut self, _name: &str, _value: &Expr, body: &Expr) -> InferredType {
390 Self::infer_expr(body)
391 }
392
393 fn visit_if(
394 &mut self,
395 _condition: &Expr,
396 then_branch: &Expr,
397 else_branch: &Expr,
398 ) -> InferredType {
399 let then_type = Self::infer_expr(then_branch);
400 let else_type = Self::infer_expr(else_branch);
401 InferredType::common_type(&then_type, &else_type)
402 }
403
404 fn visit_array(&mut self, elements: &[Expr]) -> InferredType {
405 if elements.is_empty() {
406 InferredType::Array(Box::new(InferredType::Unknown))
407 } else {
408 let first_type = Self::infer_expr(&elements[0]);
409 let mut common = first_type;
410 for elem in &elements[1..] {
411 let elem_type = Self::infer_expr(elem);
412 common = InferredType::common_type(&common, &elem_type);
413 if common.is_error() {
414 break;
415 }
416 }
417 InferredType::Array(Box::new(common))
418 }
419 }
420
421 fn visit_object(&mut self, _fields: &[(String, Expr)]) -> InferredType {
422 InferredType::Object
423 }
424
425 fn visit_pipe(&mut self, value: &Expr, functions: &[Expr]) -> InferredType {
426 if functions.is_empty() {
427 Self::infer_expr(value)
428 } else {
429 Self::infer_expr(functions.last().unwrap())
430 }
431 }
432
433 fn visit_alternative(&mut self, primary: &Expr, alternative: &Expr) -> InferredType {
434 let primary_type = Self::infer_expr(primary);
435 let alt_type = Self::infer_expr(alternative);
436 InferredType::common_type(&primary_type, &alt_type)
437 }
438
439 fn visit_guard(&mut self, _condition: &Expr, body: &Expr) -> InferredType {
440 Self::infer_expr(body)
441 }
442
443 fn visit_date(&mut self, _date: &str) -> InferredType {
444 InferredType::String
445 }
446
447 fn visit_datetime(&mut self, _datetime: &str) -> InferredType {
448 InferredType::String
449 }
450
451 fn visit_duration(&mut self, _duration: &str) -> InferredType {
452 InferredType::String
453 }
454
455 fn visit_temporal_keyword(&mut self, _keyword: TemporalKeyword) -> InferredType {
456 InferredType::String
457 }
458
459 fn visit_string(&mut self, _value: &str) -> InferredType {
460 InferredType::String
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `source` and returns the type inferred for the whole expression.
    fn infer_source(source: &str) -> InferredType {
        let expr = Parser::parse(source).unwrap();
        TypeInferenceVisitor::infer_expr(&expr)
    }

    /// Shorthand for an array type with the given element type.
    fn array_of(element: InferredType) -> InferredType {
        InferredType::Array(Box::new(element))
    }

    #[test]
    fn test_infer_integer_literal() {
        assert_eq!(infer_source("42"), InferredType::Integer);
    }

    #[test]
    fn test_infer_float_literal() {
        assert_eq!(infer_source("3.14"), InferredType::Float);
    }

    #[test]
    fn test_infer_string_literal() {
        assert_eq!(infer_source("'hello'"), InferredType::String);
    }

    #[test]
    fn test_infer_boolean_literal() {
        assert_eq!(infer_source("true"), InferredType::Boolean);
    }

    #[test]
    fn test_infer_null_literal() {
        assert_eq!(infer_source("null"), InferredType::Null);
    }

    #[test]
    fn test_infer_integer_addition() {
        assert_eq!(infer_source("1 + 2"), InferredType::Integer);
    }

    #[test]
    fn test_infer_float_arithmetic() {
        assert_eq!(infer_source("3.0 + 2.0"), InferredType::Float);
    }

    #[test]
    fn test_infer_mixed_numeric() {
        assert_eq!(infer_source("1 + 2.0"), InferredType::Float);
    }

    #[test]
    fn test_infer_comparison() {
        assert_eq!(infer_source("5 > 3"), InferredType::Boolean);
    }

    #[test]
    fn test_infer_logical_and() {
        assert_eq!(infer_source("true && false"), InferredType::Boolean);
    }

    #[test]
    fn test_infer_array_integers() {
        assert_eq!(infer_source("[1, 2, 3]"), array_of(InferredType::Integer));
    }

    #[test]
    fn test_infer_array_mixed_numeric() {
        assert_eq!(infer_source("[1, 2.0, 3]"), array_of(InferredType::Float));
    }

    #[test]
    fn test_infer_empty_array() {
        assert_eq!(infer_source("[]"), array_of(InferredType::Unknown));
    }

    #[test]
    fn test_infer_if_same_types() {
        assert_eq!(infer_source("if true then 1 else 2"), InferredType::Integer);
    }

    #[test]
    fn test_infer_if_different_numeric_types() {
        assert_eq!(infer_source("if true then 1 else 2.0"), InferredType::Float);
    }

    #[test]
    fn test_infer_let_expression() {
        assert_eq!(infer_source("let x = 5 in x + 3"), InferredType::Integer);
    }

    #[test]
    fn test_infer_unary_not() {
        assert_eq!(infer_source("!true"), InferredType::Boolean);
    }

    #[test]
    fn test_infer_string_concat() {
        assert_eq!(infer_source("'hello' + ' world'"), InferredType::String);
    }

    #[test]
    fn test_type_common_type_same() {
        let unified = InferredType::common_type(&InferredType::Integer, &InferredType::Integer);
        assert_eq!(unified, InferredType::Integer);
    }

    #[test]
    fn test_type_common_type_numeric() {
        let unified = InferredType::common_type(&InferredType::Integer, &InferredType::Float);
        assert_eq!(unified, InferredType::Float);
    }

    #[test]
    fn test_type_common_type_unknown() {
        let unified = InferredType::common_type(&InferredType::Unknown, &InferredType::Integer);
        assert_eq!(unified, InferredType::Integer);
    }

    #[test]
    fn test_type_is_numeric() {
        for numeric in [
            InferredType::Integer,
            InferredType::Float,
            InferredType::Numeric,
        ]
        .iter()
        {
            assert!(numeric.is_numeric());
        }
        assert!(!InferredType::String.is_numeric());
    }

    #[test]
    fn test_type_is_scalar() {
        assert!(InferredType::Integer.is_scalar());
        assert!(InferredType::String.is_scalar());
        assert!(!array_of(InferredType::Integer).is_scalar());
    }
}