use polar_core::terms::Term;
use std::any::TypeId;
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use crate::errors::{InvalidCallError, OsoError};
use super::class_method::{AttributeGetter, ClassMethod, Constructor, InstanceMethod};
use super::from_polar::FromPolarList;
use super::method::{Function, Method};
use super::to_polar::ToPolarResults;
use super::Host;
type Attributes = HashMap<&'static str, AttributeGetter>;
type ClassMethods = HashMap<&'static str, ClassMethod>;
type InstanceMethods = HashMap<&'static str, InstanceMethod>;
/// Builds the default equality check installed by `ClassBuilder::new`.
///
/// The returned closure ignores the right-hand operand and always fails with
/// [`OsoError::UnsupportedOperation`] for the `"equals"` operation, naming the
/// type of the left-hand operand.
fn equality_not_supported(
) -> Box<dyn Fn(&Host, &Instance, &Instance) -> crate::Result<bool> + Send + Sync> {
    Box::new(
        |host: &Host, lhs: &Instance, _rhs: &Instance| -> crate::Result<bool> {
            let type_name = lhs.name(host).to_owned();
            Err(OsoError::UnsupportedOperation {
                operation: String::from("equals"),
                type_name,
            })
        },
    )
}
/// Runtime description of a Rust type registered with the host.
///
/// Holds the type's constructor, attribute getters, instance methods, class
/// methods and the equality check used when comparing its instances.
#[derive(Clone)]
pub struct Class {
    /// Name of the class; `ClassBuilder` derives a default from the Rust
    /// type name and lets callers override it.
    pub name: String,
    /// `TypeId` of the Rust type this class describes.
    pub type_id: TypeId,
    /// Constructor used by `Class::init`; `None` makes `init` fail.
    constructor: Option<Constructor>,
    /// Attribute getters keyed by attribute name.
    attributes: Attributes,
    /// Instance methods keyed by method name.
    instance_methods: InstanceMethods,
    /// Class (static) methods keyed by method name, dispatched by `Class::call`.
    class_methods: ClassMethods,
    /// Predicate reporting whether a `TypeId` belongs to this class.
    class_check: Arc<dyn Fn(TypeId) -> bool + Send + Sync>,
    /// Equality check invoked with `(host, lhs, rhs)`.
    equality_check: Arc<dyn Fn(&Host, &Instance, &Instance) -> crate::Result<bool> + Send + Sync>,
}
impl Class {
    /// Starts a [`ClassBuilder`] for the Rust type `T`.
    pub fn builder<T: 'static>() -> ClassBuilder<T> {
        ClassBuilder::new()
    }

    /// Creates a new [`Instance`] of this class from the given Polar `fields`.
    ///
    /// # Errors
    /// Returns an `OsoError::Custom` ("MissingConstructorError") when no
    /// constructor was registered; otherwise forwards the constructor's result.
    pub fn init(&self, fields: Vec<Term>, host: &mut Host) -> crate::Result<Instance> {
        match &self.constructor {
            Some(ctor) => ctor.invoke(fields, host),
            None => Err(crate::OsoError::Custom {
                message: format!("MissingConstructorError: {} has no constructor", self.name),
            }),
        }
    }

    /// Invokes the class method named `attr` with `args`.
    ///
    /// # Errors
    /// Returns `InvalidCallError::ClassMethodNotFound` (converted into the
    /// crate error) when no class method with that name is registered.
    pub fn call(
        &self,
        attr: &str,
        args: Vec<Term>,
        host: &mut Host,
    ) -> crate::Result<super::to_polar::PolarResultIter> {
        match self.class_methods.get(attr) {
            Some(class_method) => class_method.clone().invoke(args, host),
            None => Err(InvalidCallError::ClassMethodNotFound {
                method_name: attr.to_owned(),
                type_name: self.name.clone(),
            }
            .into()),
        }
    }

    /// Resolves the instance method `name` for instances of this class.
    fn get_method(&self, name: &str) -> Option<InstanceMethod> {
        tracing::trace!({class=%self.name, name}, "get_method");
        // Ordinary classes look the method up in their own table.
        if self.type_id != TypeId::of::<Class>() {
            return self.instance_methods.get(name).cloned();
        }
        // Instances of `Class` itself (class objects) route every method name
        // through `InstanceMethod::from_class_method`.
        Some(InstanceMethod::from_class_method(name.to_string()))
    }

    /// Compares `lhs` and `rhs` with this class's configured equality check.
    fn equals(&self, host: &Host, lhs: &Instance, rhs: &Instance) -> crate::Result<bool> {
        let check = &self.equality_check;
        (check)(host, lhs, rhs)
    }
}
/// Fluent builder producing a [`Class`] that describes the Rust type `T`
/// (obtainable e.g. via `Class::builder::<T>()`).
pub struct ClassBuilder<T> {
    class: Class,
    ty: std::marker::PhantomData<T>,
}

// Implemented by hand: `#[derive(Clone)]` would require `T: Clone` even though
// no `T` value is stored, making builders for non-`Clone` types uncloneable.
impl<T> Clone for ClassBuilder<T> {
    fn clone(&self) -> Self {
        Self {
            class: self.class.clone(),
            ty: std::marker::PhantomData,
        }
    }
}
impl<T> ClassBuilder<T>
where
    T: 'static,
{
    /// Creates a builder for `T` with no constructor, attributes or methods,
    /// and with an equality check that reports the operation as unsupported.
    ///
    /// The default class name is the last `::`-separated segment of
    /// [`std::any::type_name`] for `T`, taken after generic arguments are cut
    /// off. For example, `alloc::vec::Vec<my::Foo>` becomes `Vec` rather than
    /// the invalid `Foo>`. Use [`ClassBuilder::name`] to pick another name.
    fn new() -> Self {
        let fq_name = std::any::type_name::<T>();
        // Drop generic arguments first; otherwise the `::` separators inside
        // them would be split on and the name would end with `>`.
        // `split`/`rsplit` always yield at least one item, so the fallbacks
        // never actually fire.
        let base_path = fq_name.split('<').next().unwrap_or(fq_name);
        let short_name = base_path.rsplit("::").next().unwrap_or(base_path);
        Self {
            class: Class {
                name: short_name.to_string(),
                constructor: None,
                attributes: HashMap::new(),
                instance_methods: InstanceMethods::new(),
                class_methods: ClassMethods::new(),
                class_check: Arc::new(|type_id| TypeId::of::<T>() == type_id),
                equality_check: Arc::from(equality_not_supported()),
                type_id: TypeId::of::<T>(),
            },
            ty: std::marker::PhantomData,
        }
    }

    /// Creates a builder whose constructor is `T::default`.
    pub fn with_default() -> Self
    where
        T: std::default::Default,
        T: Send + Sync,
    {
        Self::with_constructor(T::default)
    }

    /// Creates a builder whose constructor is `f`.
    pub fn with_constructor<F, Args>(f: F) -> Self
    where
        F: Function<Args, Result = T>,
        T: Send + Sync,
        Args: FromPolarList,
    {
        Self::new().set_constructor(f)
    }

    /// Sets (or replaces) the constructor used by `Class::init`.
    pub fn set_constructor<F, Args>(mut self, f: F) -> Self
    where
        F: Function<Args, Result = T>,
        T: Send + Sync,
        Args: FromPolarList,
    {
        self.class.constructor = Some(Constructor::new(f));
        self
    }

    /// Uses `f` to compare two instances of this class.
    ///
    /// Both operands are downcast to `T` first. If either downcast fails, its
    /// error (converted with `.user()`) is returned instead of calling `f`.
    pub fn set_equality_check<F>(mut self, f: F) -> Self
    where
        F: Fn(&T, &T) -> bool + Send + Sync + 'static,
    {
        self.class.equality_check = Arc::new(move |host, a, b| {
            tracing::trace!("equality check");
            let a = a.downcast(Some(host)).map_err(|e| e.user())?;
            let b = b.downcast(Some(host)).map_err(|e| e.user())?;
            Ok((f)(a, b))
        });
        self
    }

    /// Uses `T`'s [`PartialEq`] implementation as the equality check.
    pub fn with_equality_check(self) -> Self
    where
        T: PartialEq<T>,
    {
        self.set_equality_check(|a, b| PartialEq::eq(a, b))
    }

    /// Registers `f` as the getter for the attribute `name`.
    pub fn add_attribute_getter<F, R>(mut self, name: &'static str, f: F) -> Self
    where
        F: Fn(&T) -> R + Send + Sync + 'static,
        R: crate::ToPolar,
    {
        self.class.attributes.insert(name, AttributeGetter::new(f));
        self
    }

    /// Overrides the class name derived from the Rust type name.
    pub fn name(mut self, name: &str) -> Self {
        self.class.name = name.to_string();
        self
    }

    /// Registers `f` as the instance method `name`.
    pub fn add_method<F, Args, R>(mut self, name: &'static str, f: F) -> Self
    where
        Args: FromPolarList,
        F: Method<T, Args, Result = R>,
        R: ToPolarResults + 'static,
    {
        self.class
            .instance_methods
            .insert(name, InstanceMethod::new(f));
        self
    }

    /// Registers `f`, whose result is iterable, as the instance method `name`.
    pub fn add_iterator_method<F, Args, I>(mut self, name: &'static str, f: F) -> Self
    where
        Args: FromPolarList,
        F: Method<T, Args>,
        F::Result: IntoIterator<Item = I>,
        <<F as Method<T, Args>>::Result as IntoIterator>::IntoIter: Sized + 'static,
        I: ToPolarResults + 'static,
    {
        self.class
            .instance_methods
            .insert(name, InstanceMethod::new_iterator(f));
        self
    }

    /// Registers `f` as the class (static) method `name`.
    pub fn add_class_method<F, Args, R>(mut self, name: &'static str, f: F) -> Self
    where
        F: Function<Args, Result = R>,
        Args: FromPolarList,
        R: ToPolarResults + 'static,
    {
        self.class.class_methods.insert(name, ClassMethod::new(f));
        self
    }

    /// Finishes the builder and returns the configured [`Class`].
    pub fn build(self) -> Class {
        self.class
    }
}
/// Type-erased, reference-counted handle to a Rust value.
///
/// Cloning an `Instance` shares the same underlying value (`Arc` clone).
#[derive(Clone)]
pub struct Instance {
    /// The wrapped value.
    inner: Arc<dyn std::any::Any + Send + Sync>,
    /// Rust type name captured at construction, used for `Debug` output and as
    /// a fallback name when no registered class is available.
    debug_type_name: &'static str,
}
impl fmt::Debug for Instance {
    /// Renders as `Instance<TYPE>` using the Rust type name captured when the
    /// instance was created; the wrapped value itself is not printed and any
    /// width/fill flags on the outer formatter are ignored.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Instance<")?;
        f.write_str(self.debug_type_name)?;
        f.write_str(">")
    }
}
impl Instance {
pub fn new<T: Send + Sync + 'static>(instance: T) -> Self {
Self {
inner: Arc::new(instance),
debug_type_name: std::any::type_name::<T>(),
}
}
pub fn instance_of(&self, class: &Class) -> bool {
self.inner.as_ref().type_id() == class.type_id
}
pub fn class<'a>(&self, host: &'a Host) -> crate::Result<&'a Class> {
host.get_class_by_type_id(self.inner.as_ref().type_id())
.map_err(|_| OsoError::MissingClassError {
name: self.name(&host).to_owned(),
})
}
pub fn name<'a>(&self, host: &'a Host) -> &'a str {
self.class(host)
.map(|class| class.name.as_ref())
.unwrap_or_else(|_| self.debug_type_name)
}
pub fn get_attr(&self, name: &str, host: &mut Host) -> crate::Result<Term> {
tracing::trace!({ method = %name }, "get_attr");
let attr = self
.class(host)
.and_then(|c| {
c.attributes.get(name).ok_or_else(|| {
InvalidCallError::AttributeNotFound {
attribute_name: name.to_owned(),
type_name: self.name(&host).to_owned(),
}
.into()
})
})?
.clone();
attr.invoke(self, host)
}
pub fn call(
&self,
name: &str,
args: Vec<Term>,
host: &mut Host,
) -> crate::Result<super::to_polar::PolarResultIter> {
tracing::trace!({method = %name, ?args}, "call");
let method = self.class(host).and_then(|c| {
c.get_method(name).ok_or_else(|| {
InvalidCallError::MethodNotFound {
method_name: name.to_owned(),
type_name: self.name(&host).to_owned(),
}
.into()
})
})?;
method.invoke(self, args, host)
}
pub fn equals(&self, other: &Self, host: &Host) -> crate::Result<bool> {
tracing::trace!("equals");
self.class(host)
.and_then(|class| class.equals(host, &self, other))
}
pub fn downcast<T: 'static>(
&self,
host: Option<&Host>,
) -> Result<&T, crate::errors::TypeError> {
let name = host
.map(|h| self.name(h).to_owned())
.unwrap_or_else(|| self.debug_type_name.to_owned());
let expected_name = host
.and_then(|h| {
h.get_class_by_type_id(std::any::TypeId::of::<T>())
.map(|class| class.name.clone())
.ok()
})
.unwrap_or_else(|| std::any::type_name::<T>().to_owned());
self.inner
.as_ref()
.downcast_ref()
.ok_or_else(|| crate::errors::TypeError::expected(expected_name).got(name))
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// An instance matches the class built for its own type and no other.
    #[test]
    fn test_instance_of() {
        struct Alpha;
        struct Beta;

        let alpha = Instance::new(Alpha);
        let alpha_class = Class::builder::<Alpha>().build();
        let beta_class = Class::builder::<Beta>().build();

        // Same concrete type -> match.
        assert!(alpha.instance_of(&alpha_class));
        // Unrelated type -> no match.
        assert!(!alpha.instance_of(&beta_class));
    }
}