use std::collections::HashMap;
use std::fmt;
/// A type that can be inferred for a math expression (see `TypeResult::inferred_type`).
///
/// Besides concrete mathematical types this also contains placeholder variants
/// (`TypeVar`, `Unknown`, `Error`) that do not name a concrete type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MathType {
    /// A numeric value (displayed as `Number`).
    Number,
    /// A variable (displayed as `Var`); `is_numeric` treats it as numeric and
    /// `compatible_with` accepts it where a `Number` is expected.
    Variable,
    /// A function type.
    Function {
        /// Number of arguments the function takes.
        arity: Arity,
        /// Argument types. NOTE(review): nothing in this type checks that the
        /// length agrees with `arity` — confirm callers keep them consistent.
        domain: Vec<MathType>,
        /// Result type.
        codomain: Box<MathType>,
    },
    /// Operator type; together with `UnaryOp` and `NaryOp` these are the
    /// variants recognised by `is_operator`.
    BinaryOp,
    UnaryOp,
    NaryOp,
    Relation,
    Set,
    /// A vector whose elements have type `element`.
    Vector {
        element: Box<MathType>,
        /// Optional length; when `None` no size is shown by `Display`.
        dimension: Option<usize>,
    },
    /// A matrix whose entries have type `element`.
    Matrix {
        element: Box<MathType>,
        /// Optional size, presumably `(rows, columns)` to match the `RxC`
        /// rendering in `Display`; when `None` no size is shown.
        dimensions: Option<(usize, usize)>,
    },
    /// Truth value (displayed as `Bool`).
    Boolean,
    /// The unit type (displayed as `()`).
    Unit,
    /// Type variable identified by a numeric id; compatible with every type.
    TypeVar(u32),
    /// Unknown type (displayed as `?`); compatible with every type.
    Unknown,
    /// Error placeholder carrying a message.
    Error(String),
}
impl MathType {
pub fn is_numeric(&self) -> bool {
matches!(
self,
MathType::Number | MathType::Variable | MathType::TypeVar(_)
)
}
pub fn is_function(&self) -> bool {
matches!(self, MathType::Function { .. })
}
pub fn is_operator(&self) -> bool {
matches!(
self,
MathType::BinaryOp | MathType::UnaryOp | MathType::NaryOp
)
}
pub fn compatible_with(&self, other: &MathType) -> bool {
match (self, other) {
(a, b) if a == b => true,
(MathType::TypeVar(_), _) | (_, MathType::TypeVar(_)) => true,
(MathType::Unknown, _) | (_, MathType::Unknown) => true,
(MathType::Variable, MathType::Number) | (MathType::Number, MathType::Variable) => true,
(MathType::Function { arity: a1, .. }, MathType::Function { arity: a2, .. }) => {
a1 == a2
}
(MathType::Vector { .. }, MathType::Vector { .. }) => true,
(MathType::Matrix { .. }, MathType::Matrix { .. }) => true,
_ => false,
}
}
}
impl fmt::Display for MathType {
    /// Renders a compact, human-readable form of the type.
    ///
    /// Compound types are rendered recursively, e.g. `(Number, Var) -> Number [arity: Binary]`,
    /// `Vec<Number>^3` or `Mat<Bool>^(2x3)`; sizes are omitted when unknown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            MathType::Function {
                arity,
                domain,
                codomain,
            } => {
                // Stream the parameter list straight into the formatter,
                // separating entries with ", ".
                f.write_str("(")?;
                for (idx, param) in domain.iter().enumerate() {
                    if idx != 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "{}", param)?;
                }
                write!(f, ") -> {} [arity: {:?}]", codomain, arity)
            }
            MathType::Vector { element, dimension } => match dimension {
                Some(len) => write!(f, "Vec<{}>^{}", element, len),
                None => write!(f, "Vec<{}>", element),
            },
            MathType::Matrix {
                element,
                dimensions,
            } => match dimensions {
                Some((rows, cols)) => write!(f, "Mat<{}>^({}x{})", element, rows, cols),
                None => write!(f, "Mat<{}>", element),
            },
            MathType::TypeVar(id) => write!(f, "T{}", id),
            MathType::Error(msg) => write!(f, "Error({})", msg),
            MathType::Number => f.write_str("Number"),
            MathType::Variable => f.write_str("Var"),
            MathType::BinaryOp => f.write_str("BinOp"),
            MathType::UnaryOp => f.write_str("UnaryOp"),
            MathType::NaryOp => f.write_str("NaryOp"),
            MathType::Relation => f.write_str("Relation"),
            MathType::Set => f.write_str("Set"),
            MathType::Boolean => f.write_str("Bool"),
            MathType::Unit => f.write_str("()"),
            MathType::Unknown => f.write_str("?"),
        }
    }
}
/// Number of arguments a function accepts.
///
/// Some counts have two spellings (e.g. `Unary` and `Fixed(1)`). `Arity::accepts`
/// treats them alike, but they are different values under `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Arity {
    /// Exactly zero arguments.
    Nullary,
    /// Exactly one argument.
    Unary,
    /// Exactly two arguments.
    Binary,
    /// Exactly three arguments.
    Ternary,
    /// Any number of arguments, including zero.
    Variadic,
    /// Exactly the given number of arguments.
    Fixed(usize),
}
impl Arity {
    /// Reports whether a call with `n` arguments satisfies this arity.
    ///
    /// `Variadic` accepts any count. Every other arity requires exactly
    /// `min_args()` arguments.
    pub fn accepts(&self, n: usize) -> bool {
        match self {
            Arity::Variadic => true,
            fixed => fixed.min_args() == n,
        }
    }

    /// Smallest number of arguments this arity allows (`0` for `Variadic`).
    pub fn min_args(&self) -> usize {
        match *self {
            Arity::Nullary | Arity::Variadic => 0,
            Arity::Unary => 1,
            Arity::Binary => 2,
            Arity::Ternary => 3,
            Arity::Fixed(count) => count,
        }
    }
}
/// Type information describing a named symbol (e.g. `sin`).
#[derive(Debug, Clone)]
pub struct TypeSignature {
    /// Primary name of the symbol.
    pub name: String,
    /// Type of the symbol.
    pub math_type: MathType,
    /// Alternative names. `TypeSignature::new` starts this empty and `with_alias` appends to it.
    pub aliases: Vec<String>,
    /// Subject area the symbol belongs to.
    pub category: SemanticCategory,
}
impl TypeSignature {
pub fn new(name: impl Into<String>, math_type: MathType, category: SemanticCategory) -> Self {
Self {
name: name.into(),
math_type,
aliases: Vec::new(),
category,
}
}
pub fn with_alias(mut self, alias: impl Into<String>) -> Self {
self.aliases.push(alias.into());
self
}
}
/// Broad subject area a symbol belongs to; stored in `TypeSignature::category`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SemanticCategory {
    Arithmetic,
    Algebra,
    Calculus,
    SetTheory,
    Logic,
    LinearAlgebra,
    Trigonometry,
    Constant,
    Variable,
    Delimiter,
    Presentation,
}
/// Scoped mapping from names to types.
///
/// `lookup` searches this scope first, then each enclosing scope in turn.
#[derive(Debug, Clone, Default)]
pub struct TypeEnvironment {
    /// Bindings introduced in this scope only.
    bindings: HashMap<String, MathType>,
    /// Enclosing scope, if any. `child` stores an owned clone of the parent, so a
    /// child sees the parent as it was when `child` was called; bindings added to
    /// the original parent afterwards are not visible through the child.
    parent: Option<Box<TypeEnvironment>>,
}
impl TypeEnvironment {
pub fn new() -> Self {
Self::default()
}
pub fn child(&self) -> Self {
Self {
bindings: HashMap::new(),
parent: Some(Box::new(self.clone())),
}
}
pub fn bind(&mut self, name: impl Into<String>, ty: MathType) {
self.bindings.insert(name.into(), ty);
}
pub fn lookup(&self, name: &str) -> Option<&MathType> {
self.bindings
.get(name)
.or_else(|| self.parent.as_ref().and_then(|p| p.lookup(name)))
}
pub fn contains(&self, name: &str) -> bool {
self.lookup(name).is_some()
}
}
/// Outcome of inferring a type: the type plus any diagnostics collected with it.
#[derive(Debug, Clone)]
pub struct TypeResult {
    /// The inferred type (e.g. a `MathType::Error` placeholder when inference failed).
    pub inferred_type: MathType,
    /// Errors; the result counts as successful (`is_ok`) only when this is empty.
    pub errors: Vec<TypeError>,
    /// Warnings; these do not affect `is_ok`.
    pub warnings: Vec<TypeWarning>,
}
impl TypeResult {
pub fn ok(ty: MathType) -> Self {
Self {
inferred_type: ty,
errors: Vec::new(),
warnings: Vec::new(),
}
}
pub fn error(ty: MathType, error: TypeError) -> Self {
Self {
inferred_type: ty,
errors: vec![error],
warnings: Vec::new(),
}
}
pub fn is_ok(&self) -> bool {
self.errors.is_empty()
}
pub fn with_error(mut self, error: TypeError) -> Self {
self.errors.push(error);
self
}
pub fn with_warning(mut self, warning: TypeWarning) -> Self {
self.warnings.push(warning);
self
}
}
/// A type-checking error.
#[derive(Debug, Clone)]
pub struct TypeError {
    /// Category of the error.
    pub kind: TypeErrorKind,
    /// Optional location, set via `TypeError::at` and shown as `[pos]` by `Display`.
    /// NOTE(review): the unit (byte offset, token index, ...) is not defined here — confirm.
    pub position: Option<usize>,
    /// Human-readable description.
    pub message: String,
}
impl TypeError {
pub fn new(kind: TypeErrorKind, message: impl Into<String>) -> Self {
Self {
kind,
position: None,
message: message.into(),
}
}
pub fn at(mut self, pos: usize) -> Self {
self.position = Some(pos);
self
}
}
impl fmt::Display for TypeError {
    /// Formats as `[pos] Kind: message`, or `Kind: message` when no position is set.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The position prefix is optional; the kind/message tail is always written.
        if let Some(pos) = self.position {
            write!(f, "[{}] ", pos)?;
        }
        write!(f, "{:?}: {}", self.kind, self.message)
    }
}
/// Category of a `TypeError`; its `Debug` form (e.g. `TypeMismatch`) is what
/// `TypeError`'s `Display` output shows.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TypeErrorKind {
    TypeMismatch,
    ArityMismatch,
    UndefinedVariable,
    InvalidOperator,
    DivisionByZero,
    InvalidStructure,
    AmbiguousType,
}
/// A non-fatal diagnostic. It is stored in `TypeResult::warnings` and does not affect `TypeResult::is_ok`.
#[derive(Debug, Clone)]
pub struct TypeWarning {
    /// Category of the warning.
    pub kind: TypeWarningKind,
    /// Optional location. NOTE(review): the unit is not defined here — confirm.
    pub position: Option<usize>,
    /// Human-readable description.
    pub message: String,
}
impl TypeWarning {
    /// Creates a warning of the given kind with no position attached.
    pub fn new(kind: TypeWarningKind, message: impl Into<String>) -> Self {
        Self {
            kind,
            position: None,
            message: message.into(),
        }
    }

    /// Returns the warning with its position set to `pos`, replacing any previous one.
    ///
    /// Mirrors `TypeError::at`, so warnings can be located the same way errors are.
    pub fn at(mut self, pos: usize) -> Self {
        self.position = Some(pos);
        self
    }
}
/// Category of a `TypeWarning`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TypeWarningKind {
    ImplicitCoercion,
    UnusedVariable,
    Ambiguity,
    Deprecated,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `MathType::Function` without repeating the `Box::new` boilerplate.
    fn func(arity: Arity, domain: Vec<MathType>, codomain: MathType) -> MathType {
        MathType::Function {
            arity,
            domain,
            codomain: Box::new(codomain),
        }
    }

    #[test]
    fn test_math_type_display() {
        assert_eq!(format!("{}", MathType::Number), "Number");
        assert_eq!(format!("{}", MathType::Variable), "Var");
        assert_eq!(format!("{}", MathType::BinaryOp), "BinOp");
    }

    #[test]
    fn test_math_type_display_compound() {
        assert_eq!(
            func(
                Arity::Binary,
                vec![MathType::Number, MathType::Variable],
                MathType::Number
            )
            .to_string(),
            "(Number, Var) -> Number [arity: Binary]"
        );
        assert_eq!(
            func(Arity::Nullary, vec![], MathType::Set).to_string(),
            "() -> Set [arity: Nullary]"
        );
        let vector = |dimension: Option<usize>| MathType::Vector {
            element: Box::new(MathType::Number),
            dimension,
        };
        assert_eq!(vector(Some(3)).to_string(), "Vec<Number>^3");
        assert_eq!(vector(None).to_string(), "Vec<Number>");
        let matrix = |dimensions: Option<(usize, usize)>| MathType::Matrix {
            element: Box::new(MathType::Boolean),
            dimensions,
        };
        assert_eq!(matrix(Some((2, 3))).to_string(), "Mat<Bool>^(2x3)");
        assert_eq!(matrix(None).to_string(), "Mat<Bool>");
        assert_eq!(MathType::TypeVar(4).to_string(), "T4");
        assert_eq!(MathType::Unit.to_string(), "()");
        assert_eq!(MathType::Unknown.to_string(), "?");
        assert_eq!(MathType::Error("oops".to_string()).to_string(), "Error(oops)");
    }

    #[test]
    fn test_math_type_is_numeric() {
        assert!(MathType::Number.is_numeric());
        assert!(MathType::Variable.is_numeric());
        assert!(!MathType::Set.is_numeric());
    }

    #[test]
    fn test_math_type_predicates() {
        assert!(func(Arity::Unary, vec![MathType::Number], MathType::Number).is_function());
        assert!(!MathType::Number.is_function());
        assert!(MathType::UnaryOp.is_operator());
        assert!(MathType::NaryOp.is_operator());
        assert!(!MathType::Relation.is_operator());
        assert!(MathType::TypeVar(1).is_numeric());
        assert!(!MathType::Boolean.is_numeric());
    }

    #[test]
    fn test_math_type_compatible() {
        assert!(MathType::Number.compatible_with(&MathType::Number));
        assert!(MathType::Number.compatible_with(&MathType::Variable));
        assert!(MathType::TypeVar(0).compatible_with(&MathType::Set));
        assert!(!MathType::Set.compatible_with(&MathType::Number));
    }

    #[test]
    fn test_math_type_compatible_wildcards_and_functions() {
        assert!(MathType::Variable.compatible_with(&MathType::Number));
        assert!(MathType::Set.compatible_with(&MathType::Unknown));
        assert!(MathType::Unknown.compatible_with(&MathType::Boolean));
        // Function compatibility looks at arity only, not at domain/codomain.
        let unary_num = func(Arity::Unary, vec![MathType::Number], MathType::Number);
        let unary_set = func(Arity::Unary, vec![MathType::Set], MathType::Set);
        let binary = func(
            Arity::Binary,
            vec![MathType::Number, MathType::Number],
            MathType::Number,
        );
        assert!(unary_num.compatible_with(&unary_set));
        assert!(!unary_num.compatible_with(&binary));
    }

    #[test]
    fn test_arity_accepts() {
        assert!(Arity::Nullary.accepts(0));
        assert!(!Arity::Nullary.accepts(1));
        assert!(Arity::Unary.accepts(1));
        assert!(Arity::Binary.accepts(2));
        assert!(Arity::Variadic.accepts(5));
        assert!(Arity::Fixed(3).accepts(3));
        assert!(!Arity::Fixed(3).accepts(2));
    }

    #[test]
    fn test_arity_min_args() {
        assert_eq!(Arity::Nullary.min_args(), 0);
        assert_eq!(Arity::Unary.min_args(), 1);
        assert_eq!(Arity::Binary.min_args(), 2);
        assert_eq!(Arity::Ternary.min_args(), 3);
        assert_eq!(Arity::Variadic.min_args(), 0);
        assert_eq!(Arity::Fixed(5).min_args(), 5);
        assert!(Arity::Ternary.accepts(3));
        assert!(!Arity::Ternary.accepts(2));
        assert!(Arity::Variadic.accepts(0));
    }

    #[test]
    fn test_type_environment() {
        let mut env = TypeEnvironment::new();
        env.bind("x", MathType::Number);
        env.bind(
            "f",
            MathType::Function {
                arity: Arity::Unary,
                domain: vec![MathType::Number],
                codomain: Box::new(MathType::Number),
            },
        );
        assert_eq!(env.lookup("x"), Some(&MathType::Number));
        assert!(env.lookup("f").is_some());
        assert!(env.lookup("y").is_none());
    }

    #[test]
    fn test_type_environment_scoping() {
        let mut parent = TypeEnvironment::new();
        parent.bind("x", MathType::Number);
        let mut child = parent.child();
        child.bind("y", MathType::Variable);
        assert!(child.lookup("x").is_some());
        assert!(child.lookup("y").is_some());
        assert!(parent.lookup("y").is_none());
    }

    #[test]
    fn test_type_environment_shadowing_and_snapshot() {
        let mut parent = TypeEnvironment::new();
        parent.bind("x", MathType::Number);
        let mut child = parent.child();
        child.bind("x", MathType::Variable);
        assert_eq!(child.lookup("x"), Some(&MathType::Variable));
        assert_eq!(parent.lookup("x"), Some(&MathType::Number));
        // `child` holds a copy of the parent, so later parent bindings are not visible.
        parent.bind("z", MathType::Set);
        assert!(parent.contains("z"));
        assert!(!child.contains("z"));
    }

    #[test]
    fn test_type_result() {
        let ok = TypeResult::ok(MathType::Number);
        assert!(ok.is_ok());
        let err = TypeResult::error(
            MathType::Error("test".to_string()),
            TypeError::new(TypeErrorKind::TypeMismatch, "mismatch"),
        );
        assert!(!err.is_ok());
    }

    #[test]
    fn test_type_result_accumulates_diagnostics() {
        let result = TypeResult::ok(MathType::Number)
            .with_warning(TypeWarning::new(TypeWarningKind::ImplicitCoercion, "coerced"));
        assert!(result.is_ok());
        assert_eq!(result.warnings.len(), 1);
        let result = result.with_error(TypeError::new(TypeErrorKind::UndefinedVariable, "y"));
        assert!(!result.is_ok());
        assert_eq!(result.errors.len(), 1);
        assert_eq!(result.errors[0].kind, TypeErrorKind::UndefinedVariable);
    }

    #[test]
    fn test_type_error_display() {
        let err = TypeError::new(TypeErrorKind::ArityMismatch, "expected 2 arguments");
        assert_eq!(err.to_string(), "ArityMismatch: expected 2 arguments");
        let located = err.at(7);
        assert_eq!(located.position, Some(7));
        assert_eq!(located.to_string(), "[7] ArityMismatch: expected 2 arguments");
    }

    #[test]
    fn test_type_signature() {
        let sig = TypeSignature::new(
            "sin",
            MathType::Function {
                arity: Arity::Unary,
                domain: vec![MathType::Number],
                codomain: Box::new(MathType::Number),
            },
            SemanticCategory::Trigonometry,
        )
        .with_alias("sine");
        assert_eq!(sig.name, "sin");
        assert_eq!(sig.aliases, vec!["sine"]);
        assert_eq!(sig.category, SemanticCategory::Trigonometry);
    }
}