use lazy_static::lazy_static;
use snafu::prelude::*;
use std::any::Any;
use std::collections::BTreeMap;
use std::fmt;
use std::sync::Arc;
use super::{
error::*,
inference::mkcref,
inference::{CRef, Constrainable},
schema::*,
sql::get_rowtype,
Compiler,
};
use crate::ast::SourceLocation;
use crate::compile::coerce::{coerce_types, CoerceOp};
use crate::compile::sql::combine_crefs;
use crate::runtime;
use crate::types;
use crate::types::{AtomicType, Type};
/// A parameterized type (such as `SumAgg<T>` or `External<T>`) whose concrete
/// type may depend on type arguments that are still being inferred.
pub trait Generic: Send + Sync + fmt::Debug {
    /// Returns the generic's name (e.g. `SumAgg`).
    fn name(&self) -> &Ident;

    /// Exposes `self` as `&dyn Any` so callers can downcast to the concrete
    /// generic type with [`as_generic`].
    fn as_any(&self) -> &dyn Any;

    /// Returns a copy of this generic with `variables` substituted into its
    /// type arguments.
    fn substitute(&self, variables: &BTreeMap<Ident, CRef<MType>>) -> Result<Arc<dyn Generic>>;

    /// Unifies this generic with `other`.
    fn unify(&self, other: &MType) -> Result<()>;

    /// Converts the generic into a concrete runtime type. May fail, e.g. when
    /// a type argument is not known yet or is not supported by the generic.
    fn to_runtime_type(&self) -> runtime::error::Result<Type>;

    /// Produces a type reference for this generic, resolving it to a concrete
    /// type once its arguments allow it.
    fn resolve(&self, loc: &SourceLocation) -> Result<CRef<MType>>;

    /// Returns the row type for generics that have one; `None` by default.
    fn get_rowtype(&self, _compiler: crate::compile::Compiler) -> Result<Option<CRef<MType>>> {
        Ok(None)
    }
}
/// Statically-known constructor for a built-in generic, used through
/// `BuiltinGeneric<T>` to register the generic by name.
pub trait GenericConstructor: Send + Sync {
    /// Builds an instance of the generic from its type arguments.
    fn new(loc: &SourceLocation, args: Vec<CRef<MType>>) -> Result<Arc<dyn Generic>>;

    /// The name the generic is known by.
    fn static_name() -> &'static Ident;
}
pub fn as_generic<T: Generic + 'static>(g: &dyn Generic) -> Option<&T> {
g.as_any().downcast_ref::<T>()
}
/// Writes `name<arg1, arg2, ...>` to `f`, rendering each type argument with
/// its `Debug` implementation.
fn debug_fmt_generic(
    f: &mut std::fmt::Formatter<'_>,
    name: &Ident,
    args: Vec<CRef<MType>>,
) -> std::fmt::Result {
    write!(f, "{}<", name)?;
    let mut remaining = args.iter();
    if let Some(first) = remaining.next() {
        std::fmt::Debug::fmt(first, f)?;
        for arg in remaining {
            f.write_str(", ")?;
            std::fmt::Debug::fmt(arg, f)?;
        }
    }
    f.write_str(">")
}
/// Checks that the generic `name` received exactly `num` type arguments.
///
/// # Errors
/// Returns an internal `CompileError` at `loc` naming the generic, the
/// expected argument count, and the count actually supplied.
fn validate_args(
    loc: &SourceLocation,
    args: &[CRef<MType>],
    num: usize,
    name: &Ident,
) -> Result<()> {
    if args.len() != num {
        // Pluralize correctly and report what was actually passed, so the
        // message is actionable ("expects 1 argument, got 2").
        let noun = if num == 1 { "argument" } else { "arguments" };
        return Err(CompileError::internal(
            loc.clone(),
            format!("{} expects {} {}, got {}", name, num, noun, args.len()).as_str(),
        ));
    }
    Ok(())
}
/// Object-safe factory that instantiates a generic by name at runtime; stored
/// as `Box<dyn GenericFactory>` in `GLOBAL_GENERICS`.
pub trait GenericFactory: Send + Sync {
    /// The name the produced generic is registered under.
    fn name(&self) -> &Ident;

    /// Builds an instance of the generic from its type arguments.
    fn new(&self, loc: &SourceLocation, args: Vec<CRef<MType>>) -> Result<Arc<dyn Generic>>;
}
/// Zero-sized adapter that turns a statically-known `GenericConstructor` into
/// a boxed, object-safe `GenericFactory`.
pub struct BuiltinGeneric<T: GenericConstructor>(std::marker::PhantomData<T>);
impl<T: GenericConstructor + 'static> BuiltinGeneric<T> {
    /// Returns a boxed factory that delegates to `T`.
    pub fn constructor() -> Box<dyn GenericFactory> {
        let factory: BuiltinGeneric<T> = BuiltinGeneric(std::marker::PhantomData);
        Box::new(factory)
    }
}
// Both methods forward to the associated functions of `T`.
impl<T: GenericConstructor> GenericFactory for BuiltinGeneric<T> {
    fn name(&self) -> &Ident {
        <T as GenericConstructor>::static_name()
    }
    fn new(&self, loc: &SourceLocation, args: Vec<CRef<MType>>) -> Result<Arc<dyn Generic>> {
        <T as GenericConstructor>::new(loc, args)
    }
}
lazy_static! {
    // Names of the built-in generics.
    pub static ref SUM_GENERIC_NAME: Ident = "SumAgg".into();
    pub static ref EXTERNAL_GENERIC_NAME: Ident = "External".into();
    pub static ref CONNECTION_GENERIC_NAME: Ident = "Connection".into();
    pub static ref COERCE_GENERIC_NAME: Ident = "Coerce".into();

    // Registry of generics that can be instantiated by name, keyed by each
    // factory's `name()`. `Coerce` is not listed here; `CoerceGeneric` is
    // built directly through `CoerceGeneric::new`.
    pub static ref GLOBAL_GENERICS: BTreeMap<Ident, Box<dyn GenericFactory>> = {
        let factories: Vec<Box<dyn GenericFactory>> = vec![
            BuiltinGeneric::<SumGeneric>::constructor(),
            BuiltinGeneric::<ExternalType>::constructor(),
            BuiltinGeneric::<ConnectionType>::constructor(),
        ];
        let mut registry = BTreeMap::new();
        for factory in factories {
            registry.insert(factory.name().clone(), factory);
        }
        registry
    };
}
/// Shared `Generic::resolve` implementation.
///
/// First resolves any generics nested inside each of `args`, then waits until
/// all of them are known. At that point:
/// - if any argument still has no runtime type, the result is the generic `g`
///   itself (tagged with `loc`), left for later resolution;
/// - otherwise the result is `g.to_runtime_type()` converted back to an
///   `MType`.
///
/// # Errors
/// Propagates errors from resolving the arguments and from combining them.
/// Inside the deferred step, a failing `g.to_runtime_type()` is reported as a
/// runtime error at `loc`.
fn resolve_to_runtime_type<G: Generic + Clone + 'static>(
    loc: &SourceLocation,
    args: Vec<CRef<MType>>,
    g: &G,
) -> Result<CRef<MType>> {
    let loc = loc.clone();
    // `args` is already owned, so iterate it directly instead of cloning it.
    let resolved_args = args
        .iter()
        .map(|arg| arg.then(|a: Ref<MType>| a.read()?.resolve_generics()))
        .collect::<Result<Vec<_>>>()?;
    let this = g.clone();
    combine_crefs(resolved_args)?.then(move |args: Ref<Vec<Ref<MType>>>| {
        for arg in &*args.read()? {
            if arg.read()?.to_runtime_type().is_err() {
                // Keep the caller's location on the still-unresolved generic
                // so later diagnostics can point at it.
                return Ok(mkcref(MType::Generic(Located::new(
                    Arc::new(this.clone()),
                    loc.clone(),
                ))));
            }
        }
        let rt = this
            .to_runtime_type()
            .context(RuntimeSnafu { loc: loc.clone() })?;
        Ok(mkcref(MType::from_runtime_type(&rt)?))
    })
}
/// `SumAgg<T>`: the result type of `sum()` over values of type `T`.
#[derive(Clone)]
pub struct SumGeneric(CRef<MType>);
impl GenericConstructor for SumGeneric {
    fn static_name() -> &'static Ident {
        &SUM_GENERIC_NAME
    }
    /// Expects exactly one type argument: the summed value's type.
    fn new(loc: &SourceLocation, mut args: Vec<CRef<MType>>) -> Result<Arc<dyn Generic>> {
        validate_args(loc, &args, 1, Self::static_name())?;
        let inner = args.swap_remove(0);
        Ok(Arc::new(Self(inner)))
    }
}
impl std::fmt::Debug for SumGeneric {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
debug_fmt_generic(f, Self::static_name(), vec![self.0.clone()])
}
}
impl Generic for SumGeneric {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &Ident {
Self::static_name()
}
fn to_runtime_type(&self) -> runtime::error::Result<types::Type> {
let arg = self.0.must()?.read()?.to_runtime_type()?;
match &arg {
Type::Atom(at) => Ok(Type::Atom(match &at {
AtomicType::Int8
| AtomicType::Int16
| AtomicType::Int32
| AtomicType::Int64
| AtomicType::UInt8
| AtomicType::UInt16
| AtomicType::UInt32
| AtomicType::UInt64 => AtomicType::Decimal128(38, 0),
AtomicType::Float32 | AtomicType::Float64 => AtomicType::Float64,
AtomicType::Decimal128(..) | AtomicType::Decimal256(..) => at.clone(),
_ => {
return Err(runtime::error::RuntimeError::new(
format!(
"sum(): expected argument to be a numeric ype, got {:?}",
arg
)
.as_str(),
))
}
})),
_ => {
return Err(runtime::error::RuntimeError::new(
format!(
"sum(): expected argument to be an atomic type, got {:?}",
arg
)
.as_str(),
))
}
}
}
fn substitute(&self, variables: &BTreeMap<Ident, CRef<MType>>) -> Result<Arc<dyn Generic>> {
Ok(Arc::new(Self(self.0.substitute(variables)?)))
}
fn unify(&self, other: &MType) -> Result<()> {
let other = other.clone();
let arg = self.0.clone();
arg.constrain(move |arg: Ref<MType>| {
let arg = arg.read()?.to_runtime_type().context(RuntimeSnafu {
loc: ErrorLocation::Unknown,
})?;
let final_type = MType::from_runtime_type(&arg)?;
other.unify(&final_type)
})?;
Ok(())
}
fn resolve(&self, loc: &SourceLocation) -> Result<CRef<MType>> {
resolve_to_runtime_type(loc, vec![self.0.clone()], self)
}
}
/// `Coerce<...>`: the type obtained by coercing all `args` together with `op`.
/// `loc` is used when reporting coercion failures.
#[derive(Clone)]
pub struct CoerceGeneric {
    loc: SourceLocation,
    op: CoerceOp,
    args: Vec<CRef<MType>>,
}
impl CoerceGeneric {
    fn static_name() -> &'static Ident {
        &COERCE_GENERIC_NAME
    }
    /// Builds a coercion of `args` under `op`, reported at `loc`.
    pub fn new(loc: SourceLocation, op: CoerceOp, args: Vec<CRef<MType>>) -> CoerceGeneric {
        Self { loc, op, args }
    }
}
/// Coerces a list of known types into one runtime type by folding
/// `coerce_types` over them left to right with `op`. An empty list coerces to
/// `Null`.
///
/// # Errors
/// - a runtime error at `loc` if an argument is not known yet (`must()`) or
///   has no runtime type;
/// - a coercion error at `loc`, listing every argument, if some step of the
///   fold cannot be coerced.
fn coerce_list(loc: SourceLocation, op: CoerceOp, args: Vec<CRef<MType>>) -> Result<types::Type> {
    if args.is_empty() {
        return Ok(types::Type::Atom(AtomicType::Null));
    }

    let mut resolved_args: Vec<MType> = Vec::with_capacity(args.len());
    for arg in &args {
        let known = arg.must().context(RuntimeSnafu { loc: loc.clone() })?;
        let mtype = known.read()?.clone();
        resolved_args.push(mtype);
    }

    let runtime_args = resolved_args
        .iter()
        .map(|arg| arg.to_runtime_type())
        .collect::<runtime::error::Result<Vec<_>>>()
        .context(RuntimeSnafu { loc: loc.clone() })?;

    // `args` was non-empty, so `runtime_args` has at least one element.
    let (first, rest) = (&runtime_args[0], &runtime_args[1..]);
    let coerced = rest.iter().try_fold(first.clone(), |acc, next| {
        coerce_types(&acc, &op, next)
            .ok_or_else(|| CompileError::coercion(loc.clone(), &resolved_args.as_slice()))
    })?;
    Ok(coerced)
}
impl std::fmt::Debug for CoerceGeneric {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
debug_fmt_generic(f, Self::static_name(), self.args.clone())
}
}
impl Generic for CoerceGeneric {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &Ident {
Self::static_name()
}
fn substitute(&self, variables: &BTreeMap<Ident, CRef<MType>>) -> Result<Arc<dyn Generic>> {
Ok(Arc::new(Self {
loc: self.loc.clone(),
op: self.op.clone(),
args: self
.args
.iter()
.map(|arg| arg.substitute(variables))
.collect::<Result<_>>()?,
}))
}
fn to_runtime_type(&self) -> runtime::error::Result<types::Type> {
Ok(coerce_list(
self.loc.clone(),
self.op.clone(),
self.args.clone(),
)?)
}
fn unify(&self, other: &MType) -> Result<()> {
let loc = self.loc.clone();
let op = self.op.clone();
let args = self.args.clone();
let other = other.clone();
combine_crefs(self.args.clone())?.constrain(move |_: Ref<Vec<Ref<MType>>>| {
let coerced_type = coerce_list(loc.clone(), op.clone(), args.clone())?;
MType::unify(&other, &MType::from_runtime_type(&coerced_type)?)?;
Ok(())
})?;
Ok(())
}
fn resolve(&self, loc: &SourceLocation) -> Result<CRef<MType>> {
resolve_to_runtime_type(loc, self.args.clone(), self)
}
}
/// `External<T>`: a value of type `T` that lives outside the current query.
#[derive(Clone)]
pub struct ExternalType(CRef<MType>);
impl ExternalType {
    /// Returns (a clone of the reference to) the wrapped type `T`.
    pub fn inner_type(&self) -> CRef<MType> {
        let ExternalType(inner) = self;
        inner.clone()
    }
}
impl GenericConstructor for ExternalType {
    fn static_name() -> &'static Ident {
        &EXTERNAL_GENERIC_NAME
    }
    /// Expects exactly one type argument: the wrapped type.
    fn new(loc: &SourceLocation, mut args: Vec<CRef<MType>>) -> Result<Arc<dyn Generic>> {
        validate_args(loc, &args, 1, Self::static_name())?;
        let inner = args.swap_remove(0);
        Ok(Arc::new(Self(inner)))
    }
}
impl std::fmt::Debug for ExternalType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
debug_fmt_generic(f, Self::static_name(), vec![self.0.clone()])
}
}
impl Generic for ExternalType {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &Ident {
Self::static_name()
}
fn to_runtime_type(&self) -> crate::runtime::error::Result<crate::types::Type> {
self.0.must()?.read()?.to_runtime_type()
}
fn substitute(&self, variables: &BTreeMap<Ident, CRef<MType>>) -> Result<Arc<dyn Generic>> {
Ok(Arc::new(Self(self.0.substitute(variables)?)))
}
fn unify(&self, other: &MType) -> Result<()> {
let inner_type = &self.0;
match other {
MType::Generic(other_inner) => {
if let Some(other) = as_generic::<Self>(other_inner.get().as_ref()) {
inner_type.unify(&other.0)?;
} else {
inner_type.unify(&mkcref(other.clone()))?;
}
}
other => {
inner_type.unify(&mkcref(other.clone()))?;
}
};
Ok(())
}
fn get_rowtype(&self, compiler: Compiler) -> Result<Option<CRef<MType>>> {
Ok(Some(get_rowtype(compiler, self.0.clone())?))
}
fn resolve(&self, loc: &SourceLocation) -> Result<CRef<MType>> {
resolve_to_runtime_type(loc, vec![self.0.clone()], self)
}
}
/// `Connection<>`: the type of a connection value; it takes no type arguments.
#[derive(Clone)]
pub struct ConnectionType();
impl GenericConstructor for ConnectionType {
    fn static_name() -> &'static Ident {
        &CONNECTION_GENERIC_NAME
    }
    /// Rejects any type arguments.
    fn new(loc: &SourceLocation, args: Vec<CRef<MType>>) -> Result<Arc<dyn Generic>> {
        validate_args(loc, &args, 0, Self::static_name())?;
        Ok(Arc::new(Self()))
    }
}
impl std::fmt::Debug for ConnectionType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}<>", Self::static_name())
}
}
impl Generic for ConnectionType {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &Ident {
Self::static_name()
}
fn to_runtime_type(&self) -> crate::runtime::error::Result<crate::types::Type> {
Ok(crate::types::Type::Atom(crate::types::AtomicType::Utf8))
}
fn substitute(&self, _variables: &BTreeMap<Ident, CRef<MType>>) -> Result<Arc<dyn Generic>> {
Ok(Arc::new(Self()))
}
fn unify(&self, other: &MType) -> Result<()> {
match other {
MType::Generic(other_inner) => {
if let Some(_) = as_generic::<Self>(other_inner.get().as_ref()) {
return Ok(());
}
}
_ => {}
};
return Err(CompileError::invalid_conn(
SourceLocation::Unknown,
"cannot unify a connection type with another kind of type",
));
}
fn get_rowtype(&self, _compiler: Compiler) -> Result<Option<CRef<MType>>> {
Err(CompileError::internal(
SourceLocation::Unknown,
"connection type should have been optimized away",
))
}
fn resolve(&self, loc: &SourceLocation) -> Result<CRef<MType>> {
resolve_to_runtime_type(loc, vec![], self)
}
}