use std::{
collections::HashMap,
fmt::{Display, Formatter},
ops::Deref,
};
use spacetimedb_lib::AlgebraicType;
use spacetimedb_primitives::TableId;
use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type;
use spacetimedb_sql_parser::ast::BinOp;
use string_interner::{backend::StringBackend, symbol::SymbolU32, StringInterner};
use thiserror::Error;
use super::errors::{ExpectedRelation, InvalidOp};
/// Identifier of a type registered in a [`TyCtx`].
///
/// Ids `0..TyId::N` are reserved for the builtin types named by the
/// constants below; any larger id indexes the context's table of
/// additional (non-builtin) types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TyId(u32);

impl TyId {
    pub const BOOL: Self = Self(0);
    pub const I8: Self = Self(1);
    pub const U8: Self = Self(2);
    pub const I16: Self = Self(3);
    pub const U16: Self = Self(4);
    pub const I32: Self = Self(5);
    pub const U32: Self = Self(6);
    pub const I64: Self = Self(7);
    pub const U64: Self = Self(8);
    pub const I128: Self = Self(9);
    pub const U128: Self = Self(10);
    pub const I256: Self = Self(11);
    pub const U256: Self = Self(12);
    pub const F32: Self = Self(13);
    pub const F64: Self = Self(14);
    pub const STR: Self = Self(15);
    pub const BYTES: Self = Self(16);
    pub const IDENT: Self = Self(17);

    /// Number of reserved builtin ids: one past the largest builtin constant.
    const N: usize = Self::IDENT.0 as usize + 1;
}

impl Display for TyId {
    /// Writes the raw numeric id.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self(raw) = *self;
        write!(f, "{}", raw)
    }
}
pub type Symbol = SymbolU32;
#[derive(Debug)]
pub enum Type {
Var(TableId, Box<[(Symbol, TyId)]>),
Row(Box<[(Symbol, TyId)]>),
Alg(AlgebraicType),
}
impl Type {
pub const BOOL: Self = Self::Alg(AlgebraicType::Bool);
pub const I8: Self = Self::Alg(AlgebraicType::I8);
pub const U8: Self = Self::Alg(AlgebraicType::U8);
pub const I16: Self = Self::Alg(AlgebraicType::I16);
pub const U16: Self = Self::Alg(AlgebraicType::U16);
pub const I32: Self = Self::Alg(AlgebraicType::I32);
pub const U32: Self = Self::Alg(AlgebraicType::U32);
pub const I64: Self = Self::Alg(AlgebraicType::I64);
pub const U64: Self = Self::Alg(AlgebraicType::U64);
pub const I128: Self = Self::Alg(AlgebraicType::I128);
pub const U128: Self = Self::Alg(AlgebraicType::U128);
pub const I256: Self = Self::Alg(AlgebraicType::I256);
pub const U256: Self = Self::Alg(AlgebraicType::U256);
pub const F32: Self = Self::Alg(AlgebraicType::F32);
pub const F64: Self = Self::Alg(AlgebraicType::F64);
pub const STR: Self = Self::Alg(AlgebraicType::String);
pub fn is_compatible_with(&self, op: BinOp) -> bool {
match (op, self) {
(BinOp::And | BinOp::Or, Type::Alg(AlgebraicType::Bool)) => true,
(BinOp::And | BinOp::Or, _) => false,
(BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Gt | BinOp::Lte | BinOp::Gte, Type::Alg(t)) => {
t.is_bool()
|| t.is_integer()
|| t.is_float()
|| t.is_string()
|| t.is_bytes()
|| t.is_identity()
|| t.is_address()
}
(BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Gt | BinOp::Lte | BinOp::Gte, _) => false,
}
}
}
/// Type context: owns every non-builtin [`Type`] and the name interner.
#[derive(Debug)]
pub struct TyCtx {
    /// The builtin bytes type. It is stored here (unlike the `Type::*`
    /// constants) because it is built by a function call, not a constant.
    bytes: Type,
    /// The builtin identity type, stored for the same reason as `bytes`.
    ident: Type,
    /// Non-builtin types; index `i` holds the type with id `TyId::N + i`.
    types: Vec<Type>,
    /// Interner that maps names to [`Symbol`]s and back.
    names: StringInterner<StringBackend>,
}

impl Default for TyCtx {
    /// Creates a context containing only the builtin types and no symbols.
    fn default() -> Self {
        let bytes = Type::Alg(AlgebraicType::bytes());
        let ident = Type::Alg(AlgebraicType::identity());
        Self {
            bytes,
            ident,
            types: Vec::new(),
            names: StringInterner::new(),
        }
    }
}

/// Error for a [`TyId`] that does not name any type in the [`TyCtx`].
#[derive(Debug, Error)]
#[error("Invalid type id {0}")]
pub struct InvalidTypeId(TyId);
impl TyCtx {
    /// Returns the builtin `Bool` type.
    pub fn bool(&self) -> TypeWithCtx {
        TypeWithCtx(&Type::BOOL, self)
    }

    /// Returns the builtin `I8` type.
    pub fn i8(&self) -> TypeWithCtx {
        TypeWithCtx(&Type::I8, self)
    }

    /// Returns the builtin `U8` type.
    pub fn u8(&self) -> TypeWithCtx {
        TypeWithCtx(&Type::U8, self)
    }

    /// Returns the builtin `I16` type.
    pub fn i16(&self) -> TypeWithCtx {
        TypeWithCtx(&Type::I16, self)
    }

    /// Returns the builtin `U16` type.
    pub fn u16(&self) -> TypeWithCtx {
        TypeWithCtx(&Type::U16, self)
    }

    /// Returns the builtin `I32` type.
    pub fn i32(&self) -> TypeWithCtx {
        TypeWithCtx(&Type::I32, self)
    }

    /// Returns the builtin `U32` type.
    pub fn u32(&self) -> TypeWithCtx {
        TypeWithCtx(&Type::U32, self)
    }

    /// Returns the builtin `I64` type.
    pub fn i64(&self) -> TypeWithCtx {
        TypeWithCtx(&Type::I64, self)
    }

    /// Returns the builtin `U64` type.
    pub fn u64(&self) -> TypeWithCtx {
        TypeWithCtx(&Type::U64, self)
    }

    /// Returns the builtin `I128` type.
    pub fn i128(&self) -> TypeWithCtx {
        TypeWithCtx(&Type::I128, self)
    }

    /// Returns the builtin `U128` type.
    pub fn u128(&self) -> TypeWithCtx {
        TypeWithCtx(&Type::U128, self)
    }

    /// Returns the builtin `I256` type.
    pub fn i256(&self) -> TypeWithCtx {
        TypeWithCtx(&Type::I256, self)
    }

    /// Returns the builtin `U256` type.
    pub fn u256(&self) -> TypeWithCtx {
        TypeWithCtx(&Type::U256, self)
    }

    /// Returns the builtin `F32` type.
    pub fn f32(&self) -> TypeWithCtx {
        TypeWithCtx(&Type::F32, self)
    }

    /// Returns the builtin `F64` type.
    pub fn f64(&self) -> TypeWithCtx {
        TypeWithCtx(&Type::F64, self)
    }

    /// Returns the builtin `String` type.
    pub fn str(&self) -> TypeWithCtx {
        TypeWithCtx(&Type::STR, self)
    }

    /// Returns the builtin bytes (`u8` array) type.
    pub fn bytes(&self) -> TypeWithCtx {
        TypeWithCtx(&self.bytes, self)
    }

    /// Returns the builtin identity type.
    pub fn ident(&self) -> TypeWithCtx {
        TypeWithCtx(&self.ident, self)
    }

    /// Looks up the type named by `id`.
    ///
    /// # Errors
    /// Returns [`InvalidTypeId`] if `id` is neither a builtin id nor an index
    /// into this context's table of added types.
    pub fn try_resolve(&self, id: TyId) -> Result<TypeWithCtx, InvalidTypeId> {
        match id {
            TyId::BOOL => Ok(self.bool()),
            TyId::I8 => Ok(self.i8()),
            TyId::U8 => Ok(self.u8()),
            TyId::I16 => Ok(self.i16()),
            TyId::U16 => Ok(self.u16()),
            TyId::I32 => Ok(self.i32()),
            TyId::U32 => Ok(self.u32()),
            TyId::I64 => Ok(self.i64()),
            TyId::U64 => Ok(self.u64()),
            TyId::I128 => Ok(self.i128()),
            TyId::U128 => Ok(self.u128()),
            TyId::I256 => Ok(self.i256()),
            TyId::U256 => Ok(self.u256()),
            TyId::F32 => Ok(self.f32()),
            TyId::F64 => Ok(self.f64()),
            TyId::STR => Ok(self.str()),
            TyId::BYTES => Ok(self.bytes()),
            TyId::IDENT => Ok(self.ident()),
            // Added types live at `TyId::N + index`. `checked_sub` keeps an
            // unmatched builtin id from underflowing into a bogus index.
            _ => (id.0 as usize)
                .checked_sub(TyId::N)
                .and_then(|index| self.types.get(index))
                .map(|ty| TypeWithCtx(ty, self))
                .ok_or(InvalidTypeId(id)),
        }
    }

    /// Returns the string interned as `id`, or `None` if the interner does
    /// not know `id`.
    pub fn resolve_symbol(&self, id: Symbol) -> Option<&str> {
        self.names.resolve(id)
    }

    /// Returns the id for `ty`.
    ///
    /// Builtin scalars, `u8` arrays (bytes) and identity products map to their
    /// fixed builtin ids. Any other algebraic type is cloned into the context
    /// under a fresh id; equal types added twice get two distinct ids.
    pub fn add_algebraic_type(&mut self, ty: &AlgebraicType) -> TyId {
        match ty {
            AlgebraicType::Bool => TyId::BOOL,
            AlgebraicType::I8 => TyId::I8,
            AlgebraicType::U8 => TyId::U8,
            AlgebraicType::I16 => TyId::I16,
            AlgebraicType::U16 => TyId::U16,
            AlgebraicType::I32 => TyId::I32,
            AlgebraicType::U32 => TyId::U32,
            AlgebraicType::I64 => TyId::I64,
            AlgebraicType::U64 => TyId::U64,
            AlgebraicType::I128 => TyId::I128,
            AlgebraicType::U128 => TyId::U128,
            AlgebraicType::I256 => TyId::I256,
            AlgebraicType::U256 => TyId::U256,
            AlgebraicType::F32 => TyId::F32,
            AlgebraicType::F64 => TyId::F64,
            AlgebraicType::String => TyId::STR,
            AlgebraicType::Array(ty) if ty.elem_ty.is_u8() => TyId::BYTES,
            AlgebraicType::Product(ty) if ty.is_identity() => TyId::IDENT,
            _ => self.push_type(Type::Alg(ty.clone())),
        }
    }

    /// Registers a relation type for `table_id` with the given ordered columns.
    pub fn add_var_type(&mut self, table_id: TableId, fields: Vec<(Symbol, TyId)>) -> TyId {
        self.push_type(Type::Var(table_id, fields.into_boxed_slice()))
    }

    /// Registers an anonymous row type with the given ordered fields.
    pub fn add_row_type(&mut self, fields: Vec<(Symbol, TyId)>) -> TyId {
        self.push_type(Type::Row(fields.into_boxed_slice()))
    }

    /// Interns `name`, reusing the existing symbol if it is already interned.
    pub fn gen_symbol(&mut self, name: impl AsRef<str>) -> Symbol {
        self.names.get_or_intern(name)
    }

    /// Returns the symbol for `name` only if it was already interned.
    pub fn get_symbol(&self, name: impl AsRef<str>) -> Option<Symbol> {
        self.names.get(name)
    }

    /// Structural type equality.
    ///
    /// Builtin ids compare by id. Scalars compare their algebraic types.
    /// `Var`s compare by table id only. `Row`s compare field-by-field,
    /// covering names, order and (recursively) types.
    ///
    /// # Errors
    /// Returns [`InvalidTypeId`] if either id does not resolve, including
    /// when the other id is a builtin.
    pub fn eq(&self, a: TyId, b: TyId) -> Result<bool, InvalidTypeId> {
        // Resolve both sides first so an invalid id is never hidden behind
        // the builtin short-circuit below.
        let ty_a = self.try_resolve(a)?;
        let ty_b = self.try_resolve(b)?;
        // `add_algebraic_type` canonicalizes every builtin to its fixed id,
        // so when either side is builtin, id equality is type equality.
        if a.0 < TyId::N as u32 || b.0 < TyId::N as u32 {
            return Ok(a == b);
        }
        match (&*ty_a, &*ty_b) {
            (Type::Alg(a), Type::Alg(b)) => Ok(a == b),
            (Type::Var(a, _), Type::Var(b, _)) => Ok(a == b),
            (Type::Row(a), Type::Row(b)) => {
                if a.len() != b.len() {
                    return Ok(false);
                }
                for ((a_name, a_ty), (b_name, b_ty)) in a.iter().zip(b.iter()) {
                    if a_name != b_name || !self.eq(*a_ty, *b_ty)? {
                        return Ok(false);
                    }
                }
                Ok(true)
            }
            _ => Ok(false),
        }
    }

    /// Appends `ty` to the table of added types and returns its new id.
    fn push_type(&mut self, ty: Type) -> TyId {
        // A plain `as u32` would silently wrap and alias an existing id.
        let id = u32::try_from(self.types.len() + TyId::N).expect("type id space exhausted");
        self.types.push(ty);
        TyId(id)
    }
}
/// A borrowed [`Type`] paired with the [`TyCtx`] needed to resolve the type
/// ids and symbols it refers to.
#[derive(Debug)]
pub struct TypeWithCtx<'a>(&'a Type, &'a TyCtx);

impl<'a> Deref for TypeWithCtx<'a> {
    type Target = Type;

    /// Exposes the wrapped [`Type`] directly.
    fn deref(&self) -> &Type {
        let Self(ty, _) = *self;
        ty
    }
}
impl TypeWithCtx<'_> {
    /// Checks that this type may be used as an operand of `op`.
    ///
    /// # Errors
    /// Returns an [`InvalidOp`] built from `op` and this type when
    /// [`Type::is_compatible_with`] rejects the pair.
    pub fn expect_op(&self, op: BinOp) -> Result<(), InvalidOp> {
        if !self.0.is_compatible_with(op) {
            return Err(InvalidOp::new(op, self));
        }
        Ok(())
    }

    /// Views this type's columns, requiring it to be a [`Type::Var`].
    ///
    /// # Errors
    /// Returns [`ExpectedRelvar`] for `Row` and `Alg` types.
    pub fn expect_relvar(&self) -> Result<RelType, ExpectedRelvar> {
        match self.0 {
            Type::Alg(_) | Type::Row(_) => Err(ExpectedRelvar),
            Type::Var(_, fields) => Ok(RelType { fields }),
        }
    }

    /// Returns the algebraic type, requiring this to be a [`Type::Alg`].
    ///
    /// # Errors
    /// Returns [`ExpectedScalar`] for `Var` and `Row` types.
    pub fn expect_scalar(&self) -> Result<&AlgebraicType, ExpectedScalar> {
        match self.0 {
            Type::Row(..) | Type::Var(..) => Err(ExpectedScalar),
            Type::Alg(scalar) => Ok(scalar),
        }
    }

    /// Views this type's fields, accepting either a `Var` or a `Row`.
    ///
    /// # Errors
    /// Returns an [`ExpectedRelation`] built from this type for `Alg` types.
    pub fn expect_relation(&self) -> Result<RelType, ExpectedRelation> {
        match self.0 {
            Type::Alg(_) => Err(ExpectedRelation::new(self)),
            Type::Var(_, fields) | Type::Row(fields) => Ok(RelType { fields }),
        }
    }
}

/// Error marker from [`TypeWithCtx::expect_relvar`]: the type was not a `Var`.
pub struct ExpectedRelvar;

/// Error marker from [`TypeWithCtx::expect_scalar`]: the type was not an `Alg`.
pub struct ExpectedScalar;
impl<'a> Display for TypeWithCtx<'a> {
    /// Formats scalars with `fmt_algebraic_type` and relations as
    /// `(name: type, name: type, ...)`.
    ///
    /// Names or type ids the context cannot resolve print as `UNKNOWN`.
    /// A relation with no fields prints as `()` instead of panicking.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        const UNKNOWN: &str = "UNKNOWN";
        match self.0 {
            Type::Alg(ty) => write!(f, "{}", fmt_algebraic_type(ty)),
            Type::Var(_, fields) | Type::Row(fields) => {
                write!(f, "(")?;
                for (i, (symbol, id)) in fields.iter().enumerate() {
                    // Comma-separate columns so adjacent `name: type` pairs
                    // do not run together.
                    if i > 0 {
                        write!(f, ", ")?;
                    }
                    let name = self.1.resolve_symbol(*symbol).unwrap_or(UNKNOWN);
                    match self.1.try_resolve(*id) {
                        Ok(ty) => write!(f, "{}: {}", name, ty)?,
                        Err(_) => write!(f, "{}: {}", name, UNKNOWN)?,
                    }
                }
                write!(f, ")")
            }
        }
    }
}
#[derive(Debug)]
pub struct RelType<'a> {
fields: &'a [(Symbol, TyId)],
}
impl<'a> RelType<'a> {
pub fn iter(&'a self) -> impl Iterator<Item = (usize, Symbol, TyId)> + '_ {
self.fields.iter().enumerate().map(|(i, (name, ty))| (i, *name, *ty))
}
pub fn find(&'a self, name: Symbol) -> Option<(usize, TyId)> {
self.iter()
.find(|(_, field, _)| *field == name)
.map(|(i, _, ty)| (i, ty))
}
}
/// Environment that binds names to their type ids.
#[derive(Debug, Clone, Default)]
pub struct TyEnv(HashMap<Symbol, TyId>);

impl TyEnv {
    /// Binds `name` to `ty` and returns the type previously bound to `name`,
    /// if there was one.
    pub fn add(&mut self, name: Symbol, ty: TyId) -> Option<TyId> {
        let Self(bindings) = self;
        bindings.insert(name, ty)
    }

    /// Returns the type currently bound to `name`, if any.
    pub fn find(&self, name: Symbol) -> Option<TyId> {
        let Self(bindings) = self;
        let bound = bindings.get(&name);
        bound.copied()
    }
}