use super::ast::{BinOp, Expr, FieldRef, UnaryOp};
use crate::storage::schema::cast_catalog::{can_implicit_cast, CastContext};
use crate::storage::schema::types::{DataType, TypeCategory, Value};
/// Errors produced while type-checking an expression tree with `type_expr`.
#[derive(Debug, Clone)]
pub enum TypeError {
    /// A column / property reference the [`Scope`] could not resolve. `table` may be
    /// empty, in which case the message omits the qualifier.
    UnknownColumn { table: String, column: String },
    /// A binary operator applied to operand types it does not accept. Also used to
    /// report a BETWEEN bound that is incompatible with the tested value.
    OperatorMismatch {
        op: BinOp,
        lhs: DataType,
        rhs: DataType,
    },
    /// A unary operator applied to an operand type it does not accept.
    UnaryMismatch { op: UnaryOp, operand: DataType },
    /// The cast catalog does not allow an explicit cast from `src` to `target`.
    InvalidCast { src: DataType, target: DataType },
    /// Two CASE result branches (or a branch and ELSE) have no common type.
    /// `first` is the type accumulated so far, `other` the conflicting one.
    CaseBranchMismatch { first: DataType, other: DataType },
    /// An IN-list element whose type neither equals nor implicitly casts to the
    /// type of the tested expression.
    InListMismatch { target: DataType, element: DataType },
}
/// Renders a one-line, human-readable message for each error. Operators and types
/// are printed with their `Debug` form; a column with an empty `table` is shown
/// without the `table.` prefix.
impl std::fmt::Display for TypeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnknownColumn { table, column } if table.is_empty() => {
                write!(f, "unknown column `{column}`")
            }
            Self::UnknownColumn { table, column } => {
                write!(f, "unknown column `{table}.{column}`")
            }
            Self::OperatorMismatch { op, lhs, rhs } => write!(
                f,
                "operator `{op:?}` cannot apply to `{lhs:?}` and `{rhs:?}`"
            ),
            Self::UnaryMismatch { op, operand } => {
                write!(f, "unary `{op:?}` cannot apply to `{operand:?}`")
            }
            Self::InvalidCast { src, target } => {
                write!(f, "no cast from `{src:?}` to `{target:?}`")
            }
            Self::CaseBranchMismatch { first, other } => write!(
                f,
                "CASE branches disagree on type: `{first:?}` vs `{other:?}`"
            ),
            Self::InListMismatch { target, element } => write!(
                f,
                "IN list element `{element:?}` is incompatible with target `{target:?}`"
            ),
        }
    }
}

/// `TypeError` wraps no underlying cause, so the default `source()` (`None`) is used.
impl std::error::Error for TypeError {}
/// A type-checked expression: the node shape plus the [`DataType`] it evaluates to.
#[derive(Debug, Clone)]
pub struct TypedExpr {
    /// Node shape; any children are themselves `TypedExpr`s.
    pub kind: TypedExprKind,
    /// Result type assigned by `type_expr`. `DataType::Nullable` is used for NULL
    /// literals, parameters, and results that could not be resolved more precisely
    /// (e.g. unknown functions, all-NULL CASE).
    pub ty: DataType,
}
/// The shape of a [`TypedExpr`] node, mirroring the untyped `Expr` variants.
#[derive(Debug, Clone)]
pub enum TypedExprKind {
    /// A constant value. `type_expr` also lowers query parameters to
    /// `Literal(Value::Null)`.
    Literal(Value),
    /// A table-column, node-property, edge-property or node-id reference that the
    /// scope resolved.
    Column(FieldRef),
    /// A unary operator applied to one typed operand.
    UnaryOp {
        op: UnaryOp,
        operand: Box<TypedExpr>,
    },
    /// A binary operator applied to two typed operands.
    BinaryOp {
        op: BinOp,
        lhs: Box<TypedExpr>,
        rhs: Box<TypedExpr>,
    },
    /// An explicit cast. The target type is not stored here; it is the enclosing
    /// node's `ty`.
    Cast {
        inner: Box<TypedExpr>,
    },
    /// A function call; `name` is kept exactly as it appeared in the source `Expr`.
    FunctionCall {
        name: String,
        args: Vec<TypedExpr>,
    },
    /// CASE WHEN ... THEN ...: `(condition, result)` pairs plus an optional ELSE result.
    Case {
        branches: Vec<(TypedExpr, TypedExpr)>,
        else_: Option<Box<TypedExpr>>,
    },
    /// A null test on `operand`; `negated` is copied from the source expression
    /// (presumably `IS NOT NULL` when set — confirm against the parser).
    IsNull {
        operand: Box<TypedExpr>,
        negated: bool,
    },
    /// Membership of `target` in `values`; `negated` is copied from the source
    /// expression (presumably `NOT IN` when set).
    InList {
        target: Box<TypedExpr>,
        values: Vec<TypedExpr>,
        negated: bool,
    },
    /// Range test `low <= target <= high`; `negated` is copied from the source
    /// expression (presumably `NOT BETWEEN` when set).
    Between {
        target: Box<TypedExpr>,
        low: Box<TypedExpr>,
        high: Box<TypedExpr>,
        negated: bool,
    },
}
/// Name-resolution environment that `type_expr` consults for field references.
pub trait Scope {
    /// Returns the type bound to `column` under qualifier `table`, or `None` when the
    /// pair is not known. `table` can be empty (node-id references pass `("", "")`).
    fn lookup(&self, table: &str, column: &str) -> Option<DataType>;
}

/// Any closure or function `(table, column) -> Option<DataType>` can serve as a scope.
impl<F: Fn(&str, &str) -> Option<DataType>> Scope for F {
    fn lookup(&self, table: &str, column: &str) -> Option<DataType> {
        let resolve: &F = self;
        resolve(table, column)
    }
}
/// Type-checks `expr` against `scope`, producing a [`TypedExpr`] tree in which
/// every node carries its resolved [`DataType`].
///
/// Rules applied per node kind:
/// - literals are typed by `literal_type`;
/// - column / property references are resolved through [`Scope::lookup`];
/// - parameters are typed `DataType::Nullable` and lowered to a NULL literal;
/// - unary and binary operators follow `unary_result_type` / `binop_result_type`;
/// - casts must be allowed by the explicit-cast catalog and take the target type;
/// - function calls take the type chosen by `resolve_function_return_type`;
/// - CASE takes the common type of its result branches (NULL results are skipped;
///   a CASE with no non-NULL result is `Nullable`);
/// - `IS NULL`, `IN` and `BETWEEN` are always `Boolean`.
///
/// # Errors
///
/// Returns the first [`TypeError`] met while walking the tree; sub-expressions are
/// typed before the rule for their parent is applied.
pub fn type_expr(expr: &Expr, scope: &dyn Scope) -> Result<TypedExpr, TypeError> {
    match expr {
        Expr::Literal { value, .. } => Ok(TypedExpr {
            ty: literal_type(value),
            kind: TypedExprKind::Literal(value.clone()),
        }),
        Expr::Column { field, .. } => {
            // Every field-reference flavour is looked up as a (qualifier, name) pair.
            let (table, column) = match field {
                FieldRef::TableColumn { table, column } => (table.as_str(), column.as_str()),
                FieldRef::NodeProperty { alias, property } => (alias.as_str(), property.as_str()),
                FieldRef::EdgeProperty { alias, property } => (alias.as_str(), property.as_str()),
                // NOTE(review): node ids are looked up as ("", ""); confirm every Scope
                // implementation recognises this pair.
                FieldRef::NodeId { .. } => ("", ""),
            };
            // `ok_or_else` defers building the error (two String allocations) to the
            // failure path instead of paying for it on every successful lookup.
            let ty = scope
                .lookup(table, column)
                .ok_or_else(|| TypeError::UnknownColumn {
                    table: table.to_string(),
                    column: column.to_string(),
                })?;
            Ok(TypedExpr {
                ty,
                kind: TypedExprKind::Column(field.clone()),
            })
        }
        // NOTE(review): the parameter's identity is dropped here; it becomes an
        // untyped NULL literal. Confirm downstream binding does not need it.
        Expr::Parameter { .. } => Ok(TypedExpr {
            ty: DataType::Nullable,
            kind: TypedExprKind::Literal(Value::Null),
        }),
        Expr::UnaryOp { op, operand, .. } => {
            let inner = type_expr(operand, scope)?;
            let ty = unary_result_type(*op, inner.ty)?;
            Ok(TypedExpr {
                ty,
                kind: TypedExprKind::UnaryOp {
                    op: *op,
                    operand: Box::new(inner),
                },
            })
        }
        Expr::BinaryOp { op, lhs, rhs, .. } => {
            let l = type_expr(lhs, scope)?;
            let r = type_expr(rhs, scope)?;
            let ty = binop_result_type(*op, l.ty, r.ty)?;
            Ok(TypedExpr {
                ty,
                kind: TypedExprKind::BinaryOp {
                    op: *op,
                    lhs: Box::new(l),
                    rhs: Box::new(r),
                },
            })
        }
        Expr::Cast { inner, target, .. } => {
            let inner_typed = type_expr(inner, scope)?;
            if !crate::storage::schema::cast_catalog::can_explicit_cast(inner_typed.ty, *target) {
                return Err(TypeError::InvalidCast {
                    src: inner_typed.ty,
                    target: *target,
                });
            }
            Ok(TypedExpr {
                ty: *target,
                kind: TypedExprKind::Cast {
                    inner: Box::new(inner_typed),
                },
            })
        }
        Expr::FunctionCall { name, args, .. } => {
            let typed_args = args
                .iter()
                .map(|a| type_expr(a, scope))
                .collect::<Result<Vec<_>, _>>()?;
            let arg_dt: Vec<DataType> = typed_args.iter().map(|t| t.ty).collect();
            let return_ty = resolve_function_return_type(name, &arg_dt);
            Ok(TypedExpr {
                ty: return_ty,
                kind: TypedExprKind::FunctionCall {
                    name: name.clone(),
                    args: typed_args,
                },
            })
        }
        Expr::Case {
            branches, else_, ..
        } => {
            // Folds one more result type into the running CASE type. On conflict it
            // reports the type accumulated so far (or the new type if none yet).
            let merge = |current: Option<DataType>, next: DataType| {
                merge_compatible_type(current, next).map_err(|_| TypeError::CaseBranchMismatch {
                    first: current.unwrap_or(next),
                    other: next,
                })
            };
            let mut typed_branches = Vec::with_capacity(branches.len());
            let mut result_ty: Option<DataType> = None;
            for (cond, val) in branches {
                // NOTE(review): the condition's type is not required to be Boolean.
                let cond_typed = type_expr(cond, scope)?;
                let val_typed = type_expr(val, scope)?;
                result_ty = merge(result_ty, val_typed.ty)?;
                typed_branches.push((cond_typed, val_typed));
            }
            let typed_else = if let Some(else_expr) = else_ {
                let e = type_expr(else_expr, scope)?;
                result_ty = merge(result_ty, e.ty)?;
                Some(Box::new(e))
            } else {
                None
            };
            let ty = result_ty.unwrap_or(DataType::Nullable);
            Ok(TypedExpr {
                ty,
                kind: TypedExprKind::Case {
                    branches: typed_branches,
                    else_: typed_else,
                },
            })
        }
        Expr::IsNull {
            operand, negated, ..
        } => {
            let inner = type_expr(operand, scope)?;
            Ok(TypedExpr {
                ty: DataType::Boolean,
                kind: TypedExprKind::IsNull {
                    operand: Box::new(inner),
                    negated: *negated,
                },
            })
        }
        Expr::InList {
            target,
            values,
            negated,
            ..
        } => {
            let target_typed = type_expr(target, scope)?;
            let mut typed_values = Vec::with_capacity(values.len());
            for v in values {
                let vt = type_expr(v, scope)?;
                // Each element must have the target's type or implicitly cast to it
                // (element -> target direction only).
                if vt.ty != target_typed.ty && !can_implicit_cast(vt.ty, target_typed.ty) {
                    return Err(TypeError::InListMismatch {
                        target: target_typed.ty,
                        element: vt.ty,
                    });
                }
                typed_values.push(vt);
            }
            Ok(TypedExpr {
                ty: DataType::Boolean,
                kind: TypedExprKind::InList {
                    target: Box::new(target_typed),
                    values: typed_values,
                    negated: *negated,
                },
            })
        }
        Expr::Between {
            target,
            low,
            high,
            negated,
            ..
        } => {
            let target_typed = type_expr(target, scope)?;
            let low_typed = type_expr(low, scope)?;
            let high_typed = type_expr(high, scope)?;
            // BETWEEN means `target >= low AND target <= high`. Each bound must have the
            // target's type or implicitly cast to it. A failure reports the comparison
            // for that particular bound: `Ge` for low, `Le` for high.
            for (op, bound) in [(BinOp::Ge, &low_typed), (BinOp::Le, &high_typed)] {
                if bound.ty != target_typed.ty && !can_implicit_cast(bound.ty, target_typed.ty) {
                    return Err(TypeError::OperatorMismatch {
                        op,
                        lhs: target_typed.ty,
                        rhs: bound.ty,
                    });
                }
            }
            Ok(TypedExpr {
                ty: DataType::Boolean,
                kind: TypedExprKind::Between {
                    target: Box::new(target_typed),
                    low: Box::new(low_typed),
                    high: Box::new(high_typed),
                    negated: *negated,
                },
            })
        }
    }
}
/// Maps a literal [`Value`] to the [`DataType`] it is typed as; payloads are ignored.
///
/// There is one arm per `Value` variant and no catch-all. Adding a variant to
/// `Value` therefore fails to compile here until its type is chosen.
/// `Value::Null` maps to `DataType::Nullable`.
fn literal_type(v: &Value) -> DataType {
    match v {
        Value::Null => DataType::Nullable,
        Value::Boolean(_) => DataType::Boolean,
        Value::Integer(_) => DataType::Integer,
        Value::UnsignedInteger(_) => DataType::UnsignedInteger,
        Value::Float(_) => DataType::Float,
        Value::BigInt(_) => DataType::BigInt,
        Value::Decimal(_) => DataType::Decimal,
        Value::Text(_) => DataType::Text,
        Value::Blob(_) => DataType::Blob,
        Value::Timestamp(_) => DataType::Timestamp,
        Value::TimestampMs(_) => DataType::TimestampMs,
        Value::Duration(_) => DataType::Duration,
        Value::Date(_) => DataType::Date,
        Value::Time(_) => DataType::Time,
        Value::IpAddr(_) => DataType::IpAddr,
        Value::Ipv4(_) => DataType::Ipv4,
        Value::Ipv6(_) => DataType::Ipv6,
        Value::Subnet(_, _) => DataType::Subnet,
        Value::Cidr(_, _) => DataType::Cidr,
        Value::MacAddr(_) => DataType::MacAddr,
        Value::Port(_) => DataType::Port,
        Value::Latitude(_) => DataType::Latitude,
        Value::Longitude(_) => DataType::Longitude,
        Value::GeoPoint(_, _) => DataType::GeoPoint,
        Value::Country2(_) => DataType::Country2,
        Value::Country3(_) => DataType::Country3,
        Value::Lang2(_) => DataType::Lang2,
        Value::Lang5(_) => DataType::Lang5,
        Value::Currency(_) => DataType::Currency,
        Value::AssetCode(_) => DataType::AssetCode,
        Value::Money { .. } => DataType::Money,
        Value::Color(_) => DataType::Color,
        Value::ColorAlpha(_) => DataType::ColorAlpha,
        Value::Email(_) => DataType::Email,
        Value::Url(_) => DataType::Url,
        Value::Phone(_) => DataType::Phone,
        Value::Semver(_) => DataType::Semver,
        Value::Uuid(_) => DataType::Uuid,
        Value::Vector(_) => DataType::Vector,
        Value::Array(_) => DataType::Array,
        Value::Json(_) => DataType::Json,
        // Note the name change: `EnumValue` values are typed as `DataType::Enum`.
        Value::EnumValue(_) => DataType::Enum,
        Value::NodeRef(_) => DataType::NodeRef,
        Value::EdgeRef(_) => DataType::EdgeRef,
        Value::VectorRef(_, _) => DataType::VectorRef,
        Value::RowRef(_, _) => DataType::RowRef,
        Value::KeyRef(_, _) => DataType::KeyRef,
        Value::DocRef(_, _) => DataType::DocRef,
        Value::TableRef(_) => DataType::TableRef,
        Value::PageRef(_) => DataType::PageRef,
        Value::Secret(_) => DataType::Secret,
        Value::Password(_) => DataType::Password,
    }
}
/// Chooses the result type of a call to `name` with argument types `arg_types`.
///
/// A few functions are special-cased here, matched ASCII-case-insensitively.
/// `COALESCE` derives its type from its arguments. Anything else is looked up in
/// the function catalog, and an unresolved call yields `DataType::Nullable`.
fn resolve_function_return_type(name: &str, arg_types: &[DataType]) -> DataType {
    let canonical = name.to_ascii_uppercase();
    let builtin = match canonical.as_str() {
        "CONCAT" | "CONCAT_WS" | "QUOTE_LITERAL" => Some(DataType::Text),
        "MONEY" => Some(DataType::Money),
        "MONEY_ASSET" => Some(DataType::AssetCode),
        "MONEY_MINOR" => Some(DataType::BigInt),
        "MONEY_SCALE" => Some(DataType::Integer),
        "COALESCE" => Some(resolve_coalesce_return_type(arg_types)),
        _ => None,
    };
    builtin.unwrap_or_else(|| {
        // The catalog is given the caller's original spelling, not the upper-cased copy.
        crate::storage::schema::function_catalog::resolve(name, arg_types)
            .map(|entry| entry.return_type)
            .unwrap_or(DataType::Nullable)
    })
}
/// Result type of `COALESCE(args...)`: the common type of the non-NULL arguments,
/// folded left to right with `merge_compatible_type`. It is `DataType::Nullable`
/// when there are no non-NULL arguments or two of them have no common type.
fn resolve_coalesce_return_type(arg_types: &[DataType]) -> DataType {
    arg_types
        .iter()
        .try_fold(None, |common, &ty| merge_compatible_type(common, ty))
        .ok()
        .flatten()
        .unwrap_or(DataType::Nullable)
}
/// Merges `next` into the running common type `current`.
///
/// - `DataType::Nullable` never changes the accumulator.
/// - An empty accumulator adopts `next`.
/// - Equal types keep the current type.
/// - Otherwise the type that the other implicitly casts to wins. The current type
///   is checked first, so it is kept when casts exist in both directions.
///
/// Returns `Err(())` when neither type implicitly casts to the other.
fn merge_compatible_type(
    current: Option<DataType>,
    next: DataType,
) -> Result<Option<DataType>, ()> {
    if next == DataType::Nullable {
        return Ok(current);
    }
    let prev = match current {
        Some(prev) => prev,
        None => return Ok(Some(next)),
    };
    if prev == next || can_implicit_cast(next, prev) {
        Ok(Some(prev))
    } else if can_implicit_cast(prev, next) {
        Ok(Some(next))
    } else {
        Err(())
    }
}
/// Result type of a unary operator.
///
/// Negation requires a numeric-category operand and keeps its type. `Not` requires
/// `Boolean`. Any other combination yields [`TypeError::UnaryMismatch`].
fn unary_result_type(op: UnaryOp, operand: DataType) -> Result<DataType, TypeError> {
    if matches!(op, UnaryOp::Neg) && operand.category() == TypeCategory::Numeric {
        return Ok(operand);
    }
    if matches!(op, UnaryOp::Not) && operand == DataType::Boolean {
        return Ok(DataType::Boolean);
    }
    Err(TypeError::UnaryMismatch { op, operand })
}
/// Result type of `lhs <op> rhs`.
///
/// - `And` / `Or`: both operands `Boolean`; yields `Boolean`.
/// - Comparisons: identical types, or same category with an implicit cast in at
///   least one direction; yields `Boolean`.
/// - Arithmetic: both operands numeric-category. The result is `Float` if either
///   side is `Float`, else `Decimal`, else `BigInt`, otherwise `Integer`.
/// - `Concat`: both operands `Text`; yields `Text`.
///
/// # Errors
///
/// [`TypeError::OperatorMismatch`] with `op`, `lhs` and `rhs` when rejected.
fn binop_result_type(op: BinOp, lhs: DataType, rhs: DataType) -> Result<DataType, TypeError> {
    use BinOp::*;
    let reject = move || -> Result<DataType, TypeError> {
        Err(TypeError::OperatorMismatch { op, lhs, rhs })
    };
    let both_are = |wanted: DataType| lhs == wanted && rhs == wanted;
    match op {
        And | Or if both_are(DataType::Boolean) => Ok(DataType::Boolean),
        Eq | Ne | Lt | Le | Gt | Ge => {
            let comparable = lhs == rhs
                || (lhs.category() == rhs.category()
                    && (can_implicit_cast(lhs, rhs) || can_implicit_cast(rhs, lhs)));
            if comparable {
                Ok(DataType::Boolean)
            } else {
                reject()
            }
        }
        Add | Sub | Mul | Div | Mod => {
            let is_numeric = |t: DataType| t.category() == TypeCategory::Numeric;
            if !(is_numeric(lhs) && is_numeric(rhs)) {
                return reject();
            }
            // Widening priority: the first entry present on either side wins.
            let widest = [DataType::Float, DataType::Decimal, DataType::BigInt]
                .iter()
                .copied()
                .find(|&t| lhs == t || rhs == t)
                .unwrap_or(DataType::Integer);
            Ok(widest)
        }
        Concat if both_are(DataType::Text) => Ok(DataType::Text),
        And | Or | Concat => reject(),
    }
}
// NOTE(review): nothing in the visible code calls this function. It appears to exist
// only so the `CastContext` import is referenced (no other use of `CastContext` is
// visible in this file). Confirm, and consider dropping the import instead.
#[allow(dead_code)]
fn _ctx_explicit() -> CastContext {
    CastContext::Explicit
}