1use super::ast::{BinOp, Expr, FieldRef, UnaryOp};
32use crate::storage::schema::cast_catalog::{can_implicit_cast, CastContext};
33use crate::storage::schema::types::{DataType, TypeCategory, Value};
34
/// Error produced while type-checking an [`Expr`] tree with [`type_expr`].
///
/// Every variant carries the data types involved so a caller can render a
/// precise diagnostic.
#[derive(Debug, Clone)]
pub enum TypeError {
    /// A column reference that the [`Scope`] could not resolve. `table` holds
    /// the table name or node/edge alias and may be empty.
    UnknownColumn { table: String, column: String },
    /// A binary operator applied to operand types it does not accept.
    /// BETWEEN also reports an incompatible bound with this variant, using
    /// `BinOp::Ge` as `op`.
    OperatorMismatch {
        op: BinOp,
        lhs: DataType,
        rhs: DataType,
    },
    /// A unary operator applied to an operand type it does not accept.
    UnaryMismatch { op: UnaryOp, operand: DataType },
    /// An explicit cast for which the cast catalog reports no conversion.
    InvalidCast { src: DataType, target: DataType },
    /// CASE result types that cannot be unified. `first` is the result type
    /// accumulated so far; `other` is the branch type that conflicts with it.
    CaseBranchMismatch { first: DataType, other: DataType },
    /// An IN-list element whose type differs from the target's type and cannot
    /// be implicitly cast to it.
    InListMismatch { target: DataType, element: DataType },
}
59
60impl std::fmt::Display for TypeError {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 match self {
63 Self::UnknownColumn { table, column } => {
64 if table.is_empty() {
65 write!(f, "unknown column `{column}`")
66 } else {
67 write!(f, "unknown column `{table}.{column}`")
68 }
69 }
70 Self::OperatorMismatch { op, lhs, rhs } => {
71 write!(
72 f,
73 "operator `{op:?}` cannot apply to `{lhs:?}` and `{rhs:?}`"
74 )
75 }
76 Self::UnaryMismatch { op, operand } => {
77 write!(f, "unary `{op:?}` cannot apply to `{operand:?}`")
78 }
79 Self::InvalidCast { src, target } => {
80 write!(f, "no cast from `{src:?}` to `{target:?}`")
81 }
82 Self::CaseBranchMismatch { first, other } => {
83 write!(
84 f,
85 "CASE branches disagree on type: `{first:?}` vs `{other:?}`"
86 )
87 }
88 Self::InListMismatch { target, element } => {
89 write!(
90 f,
91 "IN list element `{element:?}` is incompatible with target `{target:?}`"
92 )
93 }
94 }
95 }
96}
97
98impl std::error::Error for TypeError {}
99
/// An expression node annotated with its inferred result type, as produced by
/// [`type_expr`].
#[derive(Debug, Clone)]
pub struct TypedExpr {
    /// The node payload; child expressions are themselves typed.
    pub kind: TypedExprKind,
    /// The node's result type. For `Cast` nodes this is the cast's target type.
    pub ty: DataType,
}
112
/// The node-specific payload of a [`TypedExpr`].
///
/// Most variants mirror an untyped [`Expr`] form handled by [`type_expr`], with
/// each child replaced by its typed counterpart.
#[derive(Debug, Clone)]
pub enum TypedExprKind {
    /// A constant value. [`type_expr`] also emits `Literal(Value::Null)` as the
    /// placeholder for query parameters and for subqueries.
    Literal(Value),
    /// A column or node/edge property reference resolved through a [`Scope`].
    Column(FieldRef),
    /// A unary operator applied to a single operand.
    UnaryOp {
        op: UnaryOp,
        operand: Box<TypedExpr>,
    },
    /// A binary operator applied to two operands.
    BinaryOp {
        op: BinOp,
        lhs: Box<TypedExpr>,
        rhs: Box<TypedExpr>,
    },
    /// An explicit cast of `inner`. The target type is not stored here; it is
    /// the `ty` of the enclosing [`TypedExpr`].
    Cast {
        inner: Box<TypedExpr>,
    },
    /// A function call. `name` is kept exactly as written in the source
    /// expression (it is not case-normalized).
    FunctionCall {
        name: String,
        args: Vec<TypedExpr>,
    },
    /// A CASE expression: `(condition, result)` pairs plus an optional ELSE
    /// result.
    Case {
        branches: Vec<(TypedExpr, TypedExpr)>,
        else_: Option<Box<TypedExpr>>,
    },
    /// A NULL test; `negated` is copied from the source expression.
    IsNull {
        operand: Box<TypedExpr>,
        negated: bool,
    },
    /// An IN-list membership test; `negated` is copied from the source
    /// expression.
    InList {
        target: Box<TypedExpr>,
        values: Vec<TypedExpr>,
        negated: bool,
    },
    /// A BETWEEN range test; `negated` is copied from the source expression.
    Between {
        target: Box<TypedExpr>,
        low: Box<TypedExpr>,
        high: Box<TypedExpr>,
        negated: bool,
    },
}
153
154pub trait Scope {
159 fn lookup(&self, table: &str, column: &str) -> Option<DataType>;
160}
161
162impl<F> Scope for F
163where
164 F: Fn(&str, &str) -> Option<DataType>,
165{
166 fn lookup(&self, table: &str, column: &str) -> Option<DataType> {
167 self(table, column)
168 }
169}
170
171pub fn type_expr(expr: &Expr, scope: &dyn Scope) -> Result<TypedExpr, TypeError> {
173 match expr {
174 Expr::Literal { value, .. } => Ok(TypedExpr {
175 ty: literal_type(value),
176 kind: TypedExprKind::Literal(value.clone()),
177 }),
178 Expr::Column { field, .. } => {
179 let (table, column) = match field {
180 FieldRef::TableColumn { table, column } => (table.as_str(), column.as_str()),
181 FieldRef::NodeProperty { alias, property } => (alias.as_str(), property.as_str()),
182 FieldRef::EdgeProperty { alias, property } => (alias.as_str(), property.as_str()),
183 FieldRef::NodeId { .. } => ("", ""),
184 };
185 let ty = scope
186 .lookup(table, column)
187 .ok_or(TypeError::UnknownColumn {
188 table: table.to_string(),
189 column: column.to_string(),
190 })?;
191 Ok(TypedExpr {
192 ty,
193 kind: TypedExprKind::Column(field.clone()),
194 })
195 }
196 Expr::Parameter { .. } => {
197 Ok(TypedExpr {
201 ty: DataType::Nullable,
202 kind: TypedExprKind::Literal(Value::Null),
203 })
204 }
205 Expr::UnaryOp { op, operand, .. } => {
206 let inner = type_expr(operand, scope)?;
207 let ty = unary_result_type(*op, inner.ty)?;
208 Ok(TypedExpr {
209 ty,
210 kind: TypedExprKind::UnaryOp {
211 op: *op,
212 operand: Box::new(inner),
213 },
214 })
215 }
216 Expr::BinaryOp { op, lhs, rhs, .. } => {
217 let l = type_expr(lhs, scope)?;
218 let r = type_expr(rhs, scope)?;
219 let ty = binop_result_type(*op, l.ty, r.ty)?;
220 Ok(TypedExpr {
221 ty,
222 kind: TypedExprKind::BinaryOp {
223 op: *op,
224 lhs: Box::new(l),
225 rhs: Box::new(r),
226 },
227 })
228 }
229 Expr::Cast { inner, target, .. } => {
230 let inner_typed = type_expr(inner, scope)?;
231 if !crate::storage::schema::cast_catalog::can_explicit_cast(inner_typed.ty, *target) {
234 return Err(TypeError::InvalidCast {
235 src: inner_typed.ty,
236 target: *target,
237 });
238 }
239 Ok(TypedExpr {
240 ty: *target,
241 kind: TypedExprKind::Cast {
242 inner: Box::new(inner_typed),
243 },
244 })
245 }
246 Expr::FunctionCall { name, args, .. } => {
247 let typed_args = args
248 .iter()
249 .map(|a| type_expr(a, scope))
250 .collect::<Result<Vec<_>, _>>()?;
251 let arg_dt: Vec<DataType> = typed_args.iter().map(|t| t.ty).collect();
259 let return_ty = resolve_function_return_type(name, &arg_dt);
260 Ok(TypedExpr {
261 ty: return_ty,
262 kind: TypedExprKind::FunctionCall {
263 name: name.clone(),
264 args: typed_args,
265 },
266 })
267 }
268 Expr::Case {
269 branches, else_, ..
270 } => {
271 let mut typed_branches = Vec::with_capacity(branches.len());
272 let mut result_ty: Option<DataType> = None;
273 for (cond, val) in branches {
274 let cond_typed = type_expr(cond, scope)?;
275 let val_typed = type_expr(val, scope)?;
276 let prev_ty = result_ty;
277 result_ty = merge_compatible_type(result_ty, val_typed.ty).map_err(|_| {
278 TypeError::CaseBranchMismatch {
279 first: prev_ty.unwrap_or(val_typed.ty),
280 other: val_typed.ty,
281 }
282 })?;
283 typed_branches.push((cond_typed, val_typed));
284 }
285 let typed_else = if let Some(else_expr) = else_ {
286 let e = type_expr(else_expr, scope)?;
287 let prev_ty = result_ty;
288 result_ty = merge_compatible_type(result_ty, e.ty).map_err(|_| {
289 TypeError::CaseBranchMismatch {
290 first: prev_ty.unwrap_or(e.ty),
291 other: e.ty,
292 }
293 })?;
294 Some(Box::new(e))
295 } else {
296 None
297 };
298 let ty = result_ty.unwrap_or(DataType::Nullable);
299 Ok(TypedExpr {
300 ty,
301 kind: TypedExprKind::Case {
302 branches: typed_branches,
303 else_: typed_else,
304 },
305 })
306 }
307 Expr::IsNull {
308 operand, negated, ..
309 } => {
310 let inner = type_expr(operand, scope)?;
311 Ok(TypedExpr {
312 ty: DataType::Boolean,
313 kind: TypedExprKind::IsNull {
314 operand: Box::new(inner),
315 negated: *negated,
316 },
317 })
318 }
319 Expr::InList {
320 target,
321 values,
322 negated,
323 ..
324 } => {
325 let target_typed = type_expr(target, scope)?;
326 let mut typed_values = Vec::with_capacity(values.len());
327 for v in values {
328 let vt = type_expr(v, scope)?;
329 if vt.ty != target_typed.ty && !can_implicit_cast(vt.ty, target_typed.ty) {
330 return Err(TypeError::InListMismatch {
331 target: target_typed.ty,
332 element: vt.ty,
333 });
334 }
335 typed_values.push(vt);
336 }
337 Ok(TypedExpr {
338 ty: DataType::Boolean,
339 kind: TypedExprKind::InList {
340 target: Box::new(target_typed),
341 values: typed_values,
342 negated: *negated,
343 },
344 })
345 }
346 Expr::Between {
347 target,
348 low,
349 high,
350 negated,
351 ..
352 } => {
353 let target_typed = type_expr(target, scope)?;
354 let low_typed = type_expr(low, scope)?;
355 let high_typed = type_expr(high, scope)?;
356 for bound in &[&low_typed, &high_typed] {
358 if bound.ty != target_typed.ty && !can_implicit_cast(bound.ty, target_typed.ty) {
359 return Err(TypeError::OperatorMismatch {
360 op: BinOp::Ge,
361 lhs: target_typed.ty,
362 rhs: bound.ty,
363 });
364 }
365 }
366 Ok(TypedExpr {
367 ty: DataType::Boolean,
368 kind: TypedExprKind::Between {
369 target: Box::new(target_typed),
370 low: Box::new(low_typed),
371 high: Box::new(high_typed),
372 negated: *negated,
373 },
374 })
375 }
376 Expr::Subquery { .. } => Ok(TypedExpr {
377 ty: DataType::Nullable,
378 kind: TypedExprKind::Literal(Value::Null),
379 }),
380 }
381}
382
383fn literal_type(v: &Value) -> DataType {
388 match v {
389 Value::Null => DataType::Nullable,
390 Value::Boolean(_) => DataType::Boolean,
391 Value::Integer(_) => DataType::Integer,
392 Value::UnsignedInteger(_) => DataType::UnsignedInteger,
393 Value::Float(_) => DataType::Float,
394 Value::BigInt(_) => DataType::BigInt,
395 Value::Decimal(_) => DataType::Decimal,
396 Value::Text(_) => DataType::Text,
397 Value::Blob(_) => DataType::Blob,
398 Value::Timestamp(_) => DataType::Timestamp,
399 Value::TimestampMs(_) => DataType::TimestampMs,
400 Value::Duration(_) => DataType::Duration,
401 Value::Date(_) => DataType::Date,
402 Value::Time(_) => DataType::Time,
403 Value::IpAddr(_) => DataType::IpAddr,
404 Value::Ipv4(_) => DataType::Ipv4,
405 Value::Ipv6(_) => DataType::Ipv6,
406 Value::Subnet(_, _) => DataType::Subnet,
407 Value::Cidr(_, _) => DataType::Cidr,
408 Value::MacAddr(_) => DataType::MacAddr,
409 Value::Port(_) => DataType::Port,
410 Value::Latitude(_) => DataType::Latitude,
411 Value::Longitude(_) => DataType::Longitude,
412 Value::GeoPoint(_, _) => DataType::GeoPoint,
413 Value::Country2(_) => DataType::Country2,
414 Value::Country3(_) => DataType::Country3,
415 Value::Lang2(_) => DataType::Lang2,
416 Value::Lang5(_) => DataType::Lang5,
417 Value::Currency(_) => DataType::Currency,
418 Value::AssetCode(_) => DataType::AssetCode,
419 Value::Money { .. } => DataType::Money,
420 Value::Color(_) => DataType::Color,
421 Value::ColorAlpha(_) => DataType::ColorAlpha,
422 Value::Email(_) => DataType::Email,
423 Value::Url(_) => DataType::Url,
424 Value::Phone(_) => DataType::Phone,
425 Value::Semver(_) => DataType::Semver,
426 Value::Uuid(_) => DataType::Uuid,
427 Value::Vector(_) => DataType::Vector,
428 Value::Array(_) => DataType::Array,
429 Value::Json(_) => DataType::Json,
430 Value::EnumValue(_) => DataType::Enum,
431 Value::NodeRef(_) => DataType::NodeRef,
432 Value::EdgeRef(_) => DataType::EdgeRef,
433 Value::VectorRef(_, _) => DataType::VectorRef,
434 Value::RowRef(_, _) => DataType::RowRef,
435 Value::KeyRef(_, _) => DataType::KeyRef,
436 Value::DocRef(_, _) => DataType::DocRef,
437 Value::TableRef(_) => DataType::TableRef,
438 Value::PageRef(_) => DataType::PageRef,
439 Value::Secret(_) => DataType::Secret,
440 Value::Password(_) => DataType::Password,
441 }
442}
443
444fn resolve_function_return_type(name: &str, arg_types: &[DataType]) -> DataType {
445 let upper = name.to_ascii_uppercase();
446 match upper.as_str() {
447 "CONCAT" | "CONCAT_WS" | "QUOTE_LITERAL" => DataType::Text,
451 "MONEY" => DataType::Money,
452 "MONEY_ASSET" => DataType::AssetCode,
453 "MONEY_MINOR" => DataType::BigInt,
454 "MONEY_SCALE" => DataType::Integer,
455 "COALESCE" => resolve_coalesce_return_type(arg_types),
458 _ => crate::storage::schema::function_catalog::resolve(name, arg_types)
459 .map(|entry| entry.return_type)
460 .unwrap_or(DataType::Nullable),
461 }
462}
463
464fn resolve_coalesce_return_type(arg_types: &[DataType]) -> DataType {
465 let mut resolved: Option<DataType> = None;
466
467 for &arg_ty in arg_types {
468 match merge_compatible_type(resolved, arg_ty) {
469 Ok(next) => resolved = next,
470 Err(_) => return DataType::Nullable,
471 }
472 }
473
474 resolved.unwrap_or(DataType::Nullable)
475}
476
477fn merge_compatible_type(
478 current: Option<DataType>,
479 next: DataType,
480) -> Result<Option<DataType>, ()> {
481 if next == DataType::Nullable {
482 return Ok(current);
483 }
484
485 match current {
486 None => Ok(Some(next)),
487 Some(prev) if prev == next => Ok(Some(prev)),
488 Some(prev) if can_implicit_cast(next, prev) => Ok(Some(prev)),
489 Some(prev) if can_implicit_cast(prev, next) => Ok(Some(next)),
490 Some(_) => Err(()),
491 }
492}
493
494fn unary_result_type(op: UnaryOp, operand: DataType) -> Result<DataType, TypeError> {
497 match op {
498 UnaryOp::Neg if operand.category() == TypeCategory::Numeric => Ok(operand),
499 UnaryOp::Not if operand == DataType::Boolean => Ok(DataType::Boolean),
500 _ => Err(TypeError::UnaryMismatch { op, operand }),
501 }
502}
503
504fn binop_result_type(op: BinOp, lhs: DataType, rhs: DataType) -> Result<DataType, TypeError> {
516 use BinOp::*;
517 match op {
518 And | Or => {
519 if lhs == DataType::Boolean && rhs == DataType::Boolean {
520 Ok(DataType::Boolean)
521 } else {
522 Err(TypeError::OperatorMismatch { op, lhs, rhs })
523 }
524 }
525 Eq | Ne | Lt | Le | Gt | Ge => {
526 if lhs == rhs {
530 return Ok(DataType::Boolean);
531 }
532 if lhs.category() == rhs.category()
533 && (can_implicit_cast(lhs, rhs) || can_implicit_cast(rhs, lhs))
534 {
535 return Ok(DataType::Boolean);
536 }
537 Err(TypeError::OperatorMismatch { op, lhs, rhs })
538 }
539 Add | Sub | Mul | Div | Mod => {
540 if lhs.category() != TypeCategory::Numeric || rhs.category() != TypeCategory::Numeric {
541 return Err(TypeError::OperatorMismatch { op, lhs, rhs });
542 }
543 if lhs == DataType::Float || rhs == DataType::Float {
547 Ok(DataType::Float)
548 } else if lhs == DataType::Decimal || rhs == DataType::Decimal {
549 Ok(DataType::Decimal)
550 } else if lhs == DataType::BigInt || rhs == DataType::BigInt {
551 Ok(DataType::BigInt)
552 } else {
553 Ok(DataType::Integer)
554 }
555 }
556 Concat => {
557 if lhs == DataType::Text && rhs == DataType::Text {
558 Ok(DataType::Text)
559 } else {
560 Err(TypeError::OperatorMismatch { op, lhs, rhs })
561 }
562 }
563 }
564}
565
566#[allow(dead_code)]
570fn _ctx_explicit() -> CastContext {
571 CastContext::Explicit
572}