use hashbrown::HashMap;
use core::fmt;
use crate::{
alloc::{Rc, String, Vec},
arith::OrdArithmetic,
error::{Backtrace, CodeInModule},
executable::ExecutableFn,
Error, ErrorKind, EvalResult, ModuleId, SpannedValue, Value,
};
use arithmetic_parser::{LvalueLen, MaybeSpanned, StripCode};
#[derive(Debug)]
pub struct CallContext<'r, 'a, T> {
call_span: CodeInModule<'a>,
backtrace: Option<&'r mut Backtrace<'a>>,
arithmetic: &'r dyn OrdArithmetic<T>,
}
impl<'r, 'a, T> CallContext<'r, 'a, T> {
pub fn mock(
module_id: &dyn ModuleId,
call_span: MaybeSpanned<'a>,
arithmetic: &'r dyn OrdArithmetic<T>,
) -> Self {
Self {
call_span: CodeInModule::new(module_id, call_span),
backtrace: None,
arithmetic,
}
}
pub(crate) fn new(
call_span: CodeInModule<'a>,
backtrace: Option<&'r mut Backtrace<'a>>,
arithmetic: &'r dyn OrdArithmetic<T>,
) -> Self {
Self {
call_span,
backtrace,
arithmetic,
}
}
pub(crate) fn backtrace(&mut self) -> Option<&mut Backtrace<'a>> {
self.backtrace.as_deref_mut()
}
pub(crate) fn arithmetic(&self) -> &'r dyn OrdArithmetic<T> {
self.arithmetic
}
pub fn call_span(&self) -> &CodeInModule<'a> {
&self.call_span
}
pub fn apply_call_span<U>(&self, value: U) -> MaybeSpanned<'a, U> {
self.call_span.code().copy_with_extra(value)
}
pub fn call_site_error(&self, error: ErrorKind) -> Error<'a> {
Error::from_parts(self.call_span.clone(), error)
}
pub fn check_args_count(
&self,
args: &[SpannedValue<'a, T>],
expected_count: impl Into<LvalueLen>,
) -> Result<(), Error<'a>> {
let expected_count = expected_count.into();
if expected_count.matches(args.len()) {
Ok(())
} else {
Err(self.call_site_error(ErrorKind::ArgsLenMismatch {
def: expected_count,
call: args.len(),
}))
}
}
}
pub trait NativeFn<T> {
fn evaluate<'a>(
&self,
args: Vec<SpannedValue<'a, T>>,
context: &mut CallContext<'_, 'a, T>,
) -> EvalResult<'a, T>;
}
impl<T, F: 'static> NativeFn<T> for F
where
F: for<'a> Fn(Vec<SpannedValue<'a, T>>, &mut CallContext<'_, 'a, T>) -> EvalResult<'a, T>,
{
fn evaluate<'a>(
&self,
args: Vec<SpannedValue<'a, T>>,
context: &mut CallContext<'_, 'a, T>,
) -> EvalResult<'a, T> {
self(args, context)
}
}
impl<T> fmt::Debug for dyn NativeFn<T> {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.debug_tuple("NativeFn").finish()
}
}
impl<T> dyn NativeFn<T> {
pub(crate) fn data_ptr(&self) -> *const () {
(self as *const dyn NativeFn<T>).cast()
}
}
#[derive(Debug)]
pub struct InterpretedFn<'a, T> {
definition: Rc<ExecutableFn<'a, T>>,
captures: Vec<Value<'a, T>>,
capture_names: Vec<String>,
}
impl<T: Clone> Clone for InterpretedFn<'_, T> {
fn clone(&self) -> Self {
Self {
definition: Rc::clone(&self.definition),
captures: self.captures.clone(),
capture_names: self.capture_names.clone(),
}
}
}
impl<T: 'static + Clone> StripCode for InterpretedFn<'_, T> {
type Stripped = InterpretedFn<'static, T>;
fn strip_code(self) -> Self::Stripped {
InterpretedFn {
definition: Rc::new(self.definition.to_stripped_code()),
captures: self
.captures
.into_iter()
.map(StripCode::strip_code)
.collect(),
capture_names: self.capture_names,
}
}
}
impl<'a, T> InterpretedFn<'a, T> {
pub(crate) fn new(
definition: Rc<ExecutableFn<'a, T>>,
captures: Vec<Value<'a, T>>,
capture_names: Vec<String>,
) -> Self {
Self {
definition,
captures,
capture_names,
}
}
pub fn module_id(&self) -> &dyn ModuleId {
self.definition.inner.id()
}
pub fn arg_count(&self) -> LvalueLen {
self.definition.arg_count
}
pub fn captures(&self) -> HashMap<&str, &Value<'a, T>> {
self.capture_names
.iter()
.zip(&self.captures)
.map(|(name, val)| (name.as_str(), val))
.collect()
}
}
impl<T: 'static + Clone> InterpretedFn<'_, T> {
fn to_stripped_code(&self) -> InterpretedFn<'static, T> {
self.clone().strip_code()
}
}
impl<'a, T: Clone> InterpretedFn<'a, T> {
pub fn evaluate(
&self,
args: Vec<SpannedValue<'a, T>>,
ctx: &mut CallContext<'_, 'a, T>,
) -> EvalResult<'a, T> {
if !self.arg_count().matches(args.len()) {
let err = ErrorKind::ArgsLenMismatch {
def: self.arg_count(),
call: args.len(),
};
return Err(ctx.call_site_error(err));
}
let args = args.into_iter().map(|arg| arg.extra).collect();
self.definition
.inner
.call_function(self.captures.clone(), args, ctx)
}
}
#[derive(Debug)]
pub enum Function<'a, T> {
Native(Rc<dyn NativeFn<T>>),
Interpreted(Rc<InterpretedFn<'a, T>>),
}
impl<T> Clone for Function<'_, T> {
fn clone(&self) -> Self {
match self {
Self::Native(function) => Self::Native(Rc::clone(&function)),
Self::Interpreted(function) => Self::Interpreted(Rc::clone(&function)),
}
}
}
impl<T: 'static + Clone> StripCode for Function<'_, T> {
type Stripped = Function<'static, T>;
fn strip_code(self) -> Self::Stripped {
match self {
Self::Native(function) => Function::Native(function),
Self::Interpreted(function) => {
Function::Interpreted(Rc::new(function.to_stripped_code()))
}
}
}
}
impl<'a, T> Function<'a, T> {
pub fn native(function: impl NativeFn<T> + 'static) -> Self {
Self::Native(Rc::new(function))
}
pub fn is_same_function(&self, other: &Self) -> bool {
match (self, other) {
(Self::Native(this), Self::Native(other)) => this.data_ptr() == other.data_ptr(),
(Self::Interpreted(this), Self::Interpreted(other)) => Rc::ptr_eq(this, other),
_ => false,
}
}
pub(crate) fn def_span(&self) -> Option<CodeInModule<'a>> {
match self {
Self::Native(_) => None,
Self::Interpreted(function) => Some(CodeInModule::new(
function.module_id(),
function.definition.def_span,
)),
}
}
}
impl<'a, T: Clone> Function<'a, T> {
pub fn evaluate(
&self,
args: Vec<SpannedValue<'a, T>>,
ctx: &mut CallContext<'_, 'a, T>,
) -> EvalResult<'a, T> {
match self {
Self::Native(function) => function.evaluate(args, ctx),
Self::Interpreted(function) => function.evaluate(args, ctx),
}
}
}