use std::sync::Arc;
use tatara_lisp::Span;
use crate::error::{EvalError, Result};
use crate::value::Value;
/// How many arguments a registered function accepts.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Arity {
    /// Exactly `n` arguments.
    Exact(usize),
    /// `n` or more arguments.
    AtLeast(usize),
    /// Between `lo` and `hi` arguments, both bounds inclusive.
    Range(usize, usize),
    /// Any number of arguments, including zero.
    Any,
}

impl Arity {
    /// Checks whether a call with `got` arguments satisfies this arity.
    ///
    /// # Errors
    ///
    /// Returns a human-readable description such as
    /// `"expected exactly 2, got 3"` when `got` is out of bounds. A `Range`
    /// whose `lo` exceeds `hi` accepts no argument count at all.
    pub fn check(&self, got: usize) -> std::result::Result<(), String> {
        match *self {
            Self::Exact(n) if got == n => Ok(()),
            Self::Exact(n) => Err(format!("expected exactly {n}, got {got}")),
            Self::AtLeast(n) if got >= n => Ok(()),
            Self::AtLeast(n) => Err(format!("expected at least {n}, got {got}")),
            // Idiomatic inclusive bounds test (clippy::manual_range_contains).
            Self::Range(lo, hi) if (lo..=hi).contains(&got) => Ok(()),
            Self::Range(lo, hi) => Err(format!("expected {lo}..={hi}, got {got}")),
            Self::Any => Ok(()),
        }
    }
}
/// A builtin that needs only its evaluated arguments and the host.
///
/// Implemented automatically for every matching closure or `fn` item.
pub trait NativeCallable<H>: Send + Sync + 'static {
    /// Invokes the builtin with `args`, mutable access to `host`, and the
    /// span of the call site.
    fn call(&self, args: &[Value], host: &mut H, call_span: Span) -> Result<Value>;
}

/// Blanket impl: any thread-safe `Fn(&[Value], &mut H, Span) -> Result<Value>`
/// is a [`NativeCallable`].
impl<H, F> NativeCallable<H> for F
where
    F: Fn(&[Value], &mut H, Span) -> Result<Value> + Send + Sync + 'static,
{
    fn call(&self, args: &[Value], host: &mut H, call_span: Span) -> Result<Value> {
        let func: &F = self;
        func(args, host, call_span)
    }
}
/// A builtin that additionally receives a [`Caller`], which lets it apply
/// other function values through the evaluator.
///
/// Implemented automatically for every matching closure or `fn` item.
pub trait HigherOrderCallable<H>: Send + Sync + 'static {
    /// Invokes the builtin with `args`, mutable access to `host`, a `caller`
    /// handle for applying function values, and the call-site span.
    fn call(
        &self,
        args: &[Value],
        host: &mut H,
        caller: &Caller<H>,
        call_span: Span,
    ) -> Result<Value>;
}

/// Blanket impl: any thread-safe
/// `Fn(&[Value], &mut H, &Caller<H>, Span) -> Result<Value>` is a
/// [`HigherOrderCallable`].
impl<H, F> HigherOrderCallable<H> for F
where
    F: Fn(&[Value], &mut H, &Caller<H>, Span) -> Result<Value> + Send + Sync + 'static,
{
    fn call(
        &self,
        args: &[Value],
        host: &mut H,
        caller: &Caller<H>,
        call_span: Span,
    ) -> Result<Value> {
        let func: &F = self;
        func(args, host, caller, call_span)
    }
}
/// Handle given to [`HigherOrderCallable`] builtins so they can apply
/// function values via `crate::eval::apply_external`.
pub struct Caller<'a, H> {
    pub(crate) registry: &'a FnRegistry<H>,
    pub(crate) expander: &'a tatara_lisp::SpannedExpander,
}

impl<'a, H: 'static> Caller<'a, H> {
    /// Applies `callee` to `args`. This delegates to
    /// `crate::eval::apply_external`, passing along this caller's registry
    /// and expander.
    pub fn apply_value(
        &self,
        callee: &Value,
        args: Vec<Value>,
        host: &mut H,
        call_span: Span,
    ) -> Result<Value> {
        // Both fields are shared references (Copy), so destructuring copies them out.
        let Self { registry, expander } = *self;
        crate::eval::apply_external(callee, args, call_span, registry, expander, host)
    }

    /// Returns the `SpannedExpander` this caller holds.
    pub fn expander(&self) -> &tatara_lisp::SpannedExpander {
        self.expander
    }

    /// Applies `f` to the single argument `x`.
    pub fn call1(&self, f: &Value, x: Value, host: &mut H, span: Span) -> Result<Value> {
        let args = vec![x];
        self.apply_value(f, args, host, span)
    }

    /// Applies `f` to the arguments `a` and `b`, in that order.
    pub fn call2(&self, f: &Value, a: Value, b: Value, host: &mut H, span: Span) -> Result<Value> {
        let args = vec![a, b];
        self.apply_value(f, args, host, span)
    }
}
/// The two kinds of registered builtin. Each is stored behind a shared `Arc`.
pub(crate) enum FnImpl<H> {
    /// Builtin that receives only its arguments, the host, and the call span.
    Native(Arc<dyn NativeCallable<H>>),
    /// Builtin that also receives a [`Caller`].
    Higher(Arc<dyn HigherOrderCallable<H>>),
}

// Hand-written instead of derived: `#[derive(Clone)]` would add an
// `H: Clone` bound, yet cloning here only bumps an `Arc` refcount.
impl<H> Clone for FnImpl<H> {
    fn clone(&self) -> Self {
        match self {
            FnImpl::Native(inner) => FnImpl::Native(Arc::clone(inner)),
            FnImpl::Higher(inner) => FnImpl::Higher(Arc::clone(inner)),
        }
    }
}
/// Name-indexed table of builtin functions.
pub(crate) struct FnRegistry<H> {
    // Keyed by `FnEntry::name`, so `insert` and `lookup` are O(1) on average
    // instead of a linear scan over every registered builtin.
    entries: std::collections::HashMap<Arc<str>, FnEntry<H>>,
}

/// One registered builtin: its name, declared arity, and implementation.
pub(crate) struct FnEntry<H> {
    pub name: Arc<str>,
    // NOTE(review): nothing in this file reads `arity`, so the lint is
    // silenced. Confirm whether the evaluator is meant to enforce it.
    #[allow(dead_code)]
    pub arity: Arity,
    pub callable: FnImpl<H>,
}

// Hand-written so that `FnRegistry<H>: Default` holds for every `H`;
// `#[derive(Default)]` would require `H: Default`.
impl<H> Default for FnRegistry<H> {
    fn default() -> Self {
        Self {
            entries: Default::default(),
        }
    }
}

impl<H> FnRegistry<H> {
    /// Creates an empty registry.
    pub(crate) fn new() -> Self {
        Self::default()
    }

    /// Registers `entry`. Any existing entry with the same name is replaced.
    pub(crate) fn insert(&mut self, entry: FnEntry<H>) {
        self.entries.insert(Arc::clone(&entry.name), entry);
    }

    /// Returns the entry registered under `name`, if any.
    pub(crate) fn lookup(&self, name: &str) -> Option<&FnEntry<H>> {
        // `Arc<str>: Borrow<str>` allows querying the map with a plain `&str`.
        self.entries.get(name)
    }
}
pub trait FromValue: Sized {
fn from_value(v: &Value, at: Span) -> Result<Self>;
}
impl FromValue for Value {
fn from_value(v: &Value, _at: Span) -> Result<Self> {
Ok(v.clone())
}
}
/// Accepts only `Value::Int`.
impl FromValue for i64 {
    fn from_value(v: &Value, at: Span) -> Result<Self> {
        if let Value::Int(n) = v {
            return Ok(*n);
        }
        Err(EvalError::type_mismatch("integer", v.type_name(), at))
    }
}

/// Accepts `Value::Float` unchanged and widens `Value::Int` with an `as`
/// cast. Integers with magnitude above 2^53 may be rounded.
impl FromValue for f64 {
    fn from_value(v: &Value, at: Span) -> Result<Self> {
        match *v {
            Value::Float(x) => Ok(x),
            Value::Int(i) => Ok(i as f64),
            ref other => Err(EvalError::type_mismatch("number", other.type_name(), at)),
        }
    }
}

/// Accepts only `Value::Bool`.
impl FromValue for bool {
    fn from_value(v: &Value, at: Span) -> Result<Self> {
        if let Value::Bool(flag) = v {
            return Ok(*flag);
        }
        Err(EvalError::type_mismatch("bool", v.type_name(), at))
    }
}

/// Accepts only `Value::Str` and copies the text into a new `String`.
impl FromValue for String {
    fn from_value(v: &Value, at: Span) -> Result<Self> {
        if let Value::Str(text) = v {
            return Ok(String::from(&**text));
        }
        Err(EvalError::type_mismatch("string", v.type_name(), at))
    }
}
/// Accepts strings, symbols, and keywords. Returns the shared `Arc<str>`
/// text without copying it.
impl FromValue for Arc<str> {
    fn from_value(v: &Value, at: Span) -> Result<Self> {
        match v {
            Value::Str(s) | Value::Symbol(s) | Value::Keyword(s) => Ok(Arc::clone(s)),
            other => Err(EvalError::type_mismatch(
                // Lists every accepted variant; the old label omitted keywords.
                "string/symbol/keyword",
                other.type_name(),
                at,
            )),
        }
    }
}
/// Accepts a list and copies out its elements. `nil` is treated as the empty
/// list.
impl FromValue for Vec<Value> {
    fn from_value(v: &Value, at: Span) -> Result<Self> {
        match v {
            Value::List(items) => Ok(items.as_ref().clone()),
            Value::Nil => Ok(Vec::new()),
            other => Err(EvalError::type_mismatch("list", other.type_name(), at)),
        }
    }
}

/// `nil` becomes `None`. Any other value is converted with `T::from_value`
/// and wrapped in `Some`, and conversion errors are propagated.
impl<T: FromValue> FromValue for Option<T> {
    fn from_value(v: &Value, at: Span) -> Result<Self> {
        if matches!(v, Value::Nil) {
            Ok(None)
        } else {
            Ok(Some(T::from_value(v, at)?))
        }
    }
}
/// Infallible conversion from a Rust value into an interpreter [`Value`].
pub trait IntoValue {
    /// Consumes `self` and produces the corresponding [`Value`].
    fn into_value(self) -> Value;
}

/// Identity conversion.
impl IntoValue for Value {
    fn into_value(self) -> Value {
        self
    }
}

/// The unit value maps to `nil`.
impl IntoValue for () {
    fn into_value(self) -> Value {
        Value::Nil
    }
}

impl IntoValue for bool {
    fn into_value(self) -> Value {
        Value::Bool(self)
    }
}

impl IntoValue for i64 {
    fn into_value(self) -> Value {
        Value::Int(self)
    }
}

impl IntoValue for f64 {
    fn into_value(self) -> Value {
        Value::Float(self)
    }
}

/// Moves the text into a shared `Arc<str>`.
impl IntoValue for String {
    fn into_value(self) -> Value {
        Value::Str(self.into())
    }
}

/// Copies the borrowed text into a shared `Arc<str>`.
impl IntoValue for &str {
    fn into_value(self) -> Value {
        Value::Str(self.into())
    }
}

/// Reuses the existing shared text without copying.
impl IntoValue for Arc<str> {
    fn into_value(self) -> Value {
        Value::Str(self)
    }
}
/// `None` becomes `nil`. `Some(x)` becomes `x.into_value()`.
impl<T: IntoValue> IntoValue for Option<T> {
    fn into_value(self) -> Value {
        self.map_or(Value::Nil, IntoValue::into_value)
    }
}

/// Converts each element in order and builds a list with `Value::list`.
impl<T: IntoValue> IntoValue for Vec<T> {
    fn into_value(self) -> Value {
        let items = self.into_iter().map(T::into_value);
        Value::list(items)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn arity_check() {
        let accepted = [
            (Arity::Exact(2), 2),
            (Arity::AtLeast(1), 5),
            (Arity::Range(1, 3), 2),
            (Arity::Any, 0),
            (Arity::Any, 1000),
        ];
        for (arity, got) in accepted {
            assert!(arity.check(got).is_ok(), "{arity:?} should accept {got}");
        }

        let rejected = [
            (Arity::Exact(2), 3),
            (Arity::AtLeast(1), 0),
            (Arity::Range(1, 3), 4),
        ];
        for (arity, got) in rejected {
            assert!(arity.check(got).is_err(), "{arity:?} should reject {got}");
        }
    }

    #[test]
    fn from_value_round_trips_primitives() {
        let span = Span::synthetic();
        assert_eq!(i64::from_value(&Value::Int(42), span).unwrap(), 42);
        assert_eq!(f64::from_value(&Value::Float(1.5), span).unwrap(), 1.5);
        assert!(bool::from_value(&Value::Bool(true), span).unwrap());
        let text = String::from_value(&Value::Str(Arc::from("hi")), span).unwrap();
        assert_eq!(text, "hi");
    }

    #[test]
    fn from_value_int_to_float_coerces() {
        let span = Span::synthetic();
        let widened = f64::from_value(&Value::Int(3), span).unwrap();
        assert_eq!(widened, 3.0);
    }

    #[test]
    fn from_value_option_nil_is_none() {
        let span = Span::synthetic();
        let absent = <Option<i64> as FromValue>::from_value(&Value::Nil, span).unwrap();
        assert_eq!(absent, None);
        let present = <Option<i64> as FromValue>::from_value(&Value::Int(7), span).unwrap();
        assert_eq!(present, Some(7));
    }

    #[test]
    fn from_value_type_mismatch_reports_expected_kind() {
        let span = Span::synthetic();
        let err = i64::from_value(&Value::Str(Arc::from("x")), span).unwrap_err();
        let reports_integer = matches!(
            err,
            EvalError::TypeMismatch {
                expected: "integer",
                ..
            }
        );
        assert!(reports_integer);
    }

    #[test]
    fn into_value_round_trips() {
        assert!(matches!(42i64.into_value(), Value::Int(42)));
        assert!(matches!(true.into_value(), Value::Bool(true)));
        assert!(matches!(().into_value(), Value::Nil));
        let converted = String::from("hello").into_value();
        if let Value::Str(s) = &converted {
            assert_eq!(&**s, "hello");
        } else {
            panic!("{converted:?}");
        }
    }

    #[test]
    fn into_value_vec_produces_list() {
        let numbers: Vec<i64> = vec![1, 2, 3];
        let converted = numbers.into_value();
        if let Value::List(items) = &converted {
            assert_eq!(items.len(), 3);
            assert!(matches!(&items[0], Value::Int(1)));
        } else {
            panic!("{converted:?}");
        }
    }
}