use hashbrown::HashMap;
use core::{
any::{type_name, Any},
cmp::Ordering,
fmt,
};
use crate::{
alloc::{vec, Rc, String, ToOwned, Vec},
arith::OrdArithmetic,
error::{AuxErrorInfo, Backtrace, CodeInModule, TupleLenMismatchContext},
executable::ExecutableFn,
fns, Error, ErrorKind, EvalResult, ModuleId,
};
use arithmetic_parser::{BinaryOp, LvalueLen, MaybeSpanned, Op, StripCode, UnaryOp};
mod env;
mod variable_map;
pub use self::{
env::Environment,
variable_map::{Assertions, Comparisons, Prelude, VariableMap},
};
/// Context of a function call: the call-site span, an optional backtrace,
/// and the arithmetic used to operate on values of type `T`.
#[derive(Debug)]
pub struct CallContext<'r, 'a, T> {
    /// Code span of the call site within its module; errors created via
    /// `call_site_error` are attributed to this span.
    call_span: CodeInModule<'a>,
    /// Backtrace to record calls into, if any (contexts created via `mock` have none).
    backtrace: Option<&'r mut Backtrace<'a>>,
    /// Arithmetic used for operations on numbers of type `T`.
    arithmetic: &'r dyn OrdArithmetic<T>,
}
impl<'r, 'a, T> CallContext<'r, 'a, T> {
pub fn mock(
module_id: &dyn ModuleId,
call_span: MaybeSpanned<'a>,
arithmetic: &'r dyn OrdArithmetic<T>,
) -> Self {
Self {
call_span: CodeInModule::new(module_id, call_span),
backtrace: None,
arithmetic,
}
}
pub(crate) fn new(
call_span: CodeInModule<'a>,
backtrace: Option<&'r mut Backtrace<'a>>,
arithmetic: &'r dyn OrdArithmetic<T>,
) -> Self {
Self {
call_span,
backtrace,
arithmetic,
}
}
pub(crate) fn backtrace(&mut self) -> Option<&mut Backtrace<'a>> {
self.backtrace.as_deref_mut()
}
pub(crate) fn arithmetic(&self) -> &'r dyn OrdArithmetic<T> {
self.arithmetic
}
pub fn call_span(&self) -> &CodeInModule<'a> {
&self.call_span
}
pub fn apply_call_span<U>(&self, value: U) -> MaybeSpanned<'a, U> {
self.call_span.code().copy_with_extra(value)
}
pub fn call_site_error(&self, error: ErrorKind) -> Error<'a> {
Error::from_parts(self.call_span.clone(), error)
}
pub fn check_args_count(
&self,
args: &[SpannedValue<'a, T>],
expected_count: impl Into<LvalueLen>,
) -> Result<(), Error<'a>> {
let expected_count = expected_count.into();
if expected_count.matches(args.len()) {
Ok(())
} else {
Err(self.call_site_error(ErrorKind::ArgsLenMismatch {
def: expected_count,
call: args.len(),
}))
}
}
}
/// Function implemented natively (in Rust), as opposed to an `InterpretedFn`.
///
/// A blanket implementation covers every `'static` closure / function with
/// the matching signature.
pub trait NativeFn<T> {
    /// Executes the function on the spanned `args` within the call `context`.
    fn evaluate<'a>(
        &self,
        args: Vec<SpannedValue<'a, T>>,
        context: &mut CallContext<'_, 'a, T>,
    ) -> EvalResult<'a, T>;
}
/// Any `'static` callable with the `NativeFn::evaluate` signature is a native function.
impl<T, F: 'static> NativeFn<T> for F
where
    F: for<'a> Fn(Vec<SpannedValue<'a, T>>, &mut CallContext<'_, 'a, T>) -> EvalResult<'a, T>,
{
    fn evaluate<'a>(
        &self,
        args: Vec<SpannedValue<'a, T>>,
        context: &mut CallContext<'_, 'a, T>,
    ) -> EvalResult<'a, T> {
        // Forward straight to the wrapped callable.
        let callable: &F = self;
        callable(args, context)
    }
}
impl<T> fmt::Debug for dyn NativeFn<T> {
    /// Native functions are opaque, so only the `NativeFn` marker is printed.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("NativeFn")
    }
}
impl<T> dyn NativeFn<T> {
    /// Returns the address of the function's data, discarding the vtable part of the
    /// fat pointer.
    ///
    /// Used to compare native functions by identity. Only the data part is compared
    /// because vtable addresses for the same type are not guaranteed to be unique.
    pub(crate) fn data_ptr(&self) -> *const () {
        let fat_ptr: *const dyn NativeFn<T> = self;
        fat_ptr.cast::<()>()
    }
}
/// Function defined in interpreted code, together with the values it captures.
#[derive(Debug)]
pub struct InterpretedFn<'a, T> {
    /// Executable definition of the function, shared via `Rc`.
    definition: Rc<ExecutableFn<'a, T>>,
    /// Captured values; paired index-wise with `capture_names` (see `captures()`).
    captures: Vec<Value<'a, T>>,
    /// Names of the captured variables, in the same order as `captures`.
    capture_names: Vec<String>,
}
impl<T: Clone> Clone for InterpretedFn<'_, T> {
    /// Clones the captures and names; the definition itself is shared via `Rc`.
    fn clone(&self) -> Self {
        let Self {
            definition,
            captures,
            capture_names,
        } = self;
        Self {
            definition: Rc::clone(definition),
            captures: captures.clone(),
            capture_names: capture_names.clone(),
        }
    }
}
impl<T: 'static + Clone> StripCode for InterpretedFn<'_, T> {
    type Stripped = InterpretedFn<'static, T>;

    /// Converts into a `'static` function. The definition is stripped into a fresh `Rc`,
    /// and every captured value is stripped as well. Capture names are kept unchanged.
    fn strip_code(self) -> Self::Stripped {
        let Self {
            definition,
            captures,
            capture_names,
        } = self;
        let stripped_definition = Rc::new(definition.to_stripped_code());
        let stripped_captures: Vec<_> = captures
            .into_iter()
            .map(|capture| capture.strip_code())
            .collect();
        InterpretedFn {
            definition: stripped_definition,
            captures: stripped_captures,
            capture_names,
        }
    }
}
impl<'a, T> InterpretedFn<'a, T> {
    /// Creates a function from its executable definition and captured values.
    ///
    /// NOTE(review): `captures` and `capture_names` appear to be expected to align
    /// index-wise (they are zipped in `captures()`); confirm with callers.
    pub(crate) fn new(
        definition: Rc<ExecutableFn<'a, T>>,
        captures: Vec<Value<'a, T>>,
        capture_names: Vec<String>,
    ) -> Self {
        Self {
            definition,
            captures,
            capture_names,
        }
    }

    /// Returns the ID of the module reported by the function's executable.
    pub fn module_id(&self) -> &dyn ModuleId {
        self.definition.inner.id()
    }

    /// Returns the number of arguments this function accepts.
    pub fn arg_count(&self) -> LvalueLen {
        self.definition.arg_count
    }

    /// Returns the captured values keyed by their variable names.
    pub fn captures(&self) -> HashMap<&str, &Value<'a, T>> {
        let names = self.capture_names.iter().map(String::as_str);
        names.zip(self.captures.iter()).collect()
    }
}
impl<T: 'static + Clone> InterpretedFn<'_, T> {
    /// Returns a `'static` copy of this function, leaving `self` untouched.
    fn to_stripped_code(&self) -> InterpretedFn<'static, T> {
        let owned_copy = self.clone();
        owned_copy.strip_code()
    }
}
impl<'a, T: Clone> InterpretedFn<'a, T> {
pub fn evaluate(
&self,
args: Vec<SpannedValue<'a, T>>,
ctx: &mut CallContext<'_, 'a, T>,
) -> EvalResult<'a, T> {
if !self.arg_count().matches(args.len()) {
let err = ErrorKind::ArgsLenMismatch {
def: self.arg_count(),
call: args.len(),
};
return Err(ctx.call_site_error(err));
}
let args = args.into_iter().map(|arg| arg.extra).collect();
self.definition
.inner
.call_function(self.captures.clone(), args, ctx)
}
}
/// Function value: either native (implemented in Rust) or interpreted.
#[derive(Debug)]
pub enum Function<'a, T> {
    /// Native function, shared via `Rc`.
    Native(Rc<dyn NativeFn<T>>),
    /// Interpreted function, shared via `Rc`.
    Interpreted(Rc<InterpretedFn<'a, T>>),
}
// Implemented manually rather than derived: `#[derive(Clone)]` would add a
// `T: Clone` bound, which is unnecessary because only `Rc` handles are cloned.
impl<T> Clone for Function<'_, T> {
    /// Cheaply clones the function handle by incrementing the `Rc` reference count.
    fn clone(&self) -> Self {
        match self {
            Self::Native(function) => Self::Native(Rc::clone(function)),
            Self::Interpreted(function) => Self::Interpreted(Rc::clone(function)),
        }
    }
}
impl<T: 'static + Clone> StripCode for Function<'_, T> {
    type Stripped = Function<'static, T>;

    /// Converts into a `'static` function. Native functions are moved over as-is;
    /// interpreted functions are stripped into a new `Rc`.
    fn strip_code(self) -> Self::Stripped {
        match self {
            Self::Interpreted(interpreted) => {
                let stripped = interpreted.to_stripped_code();
                Function::Interpreted(Rc::new(stripped))
            }
            Self::Native(native) => Function::Native(native),
        }
    }
}
impl<'a, T> Function<'a, T> {
    /// Wraps a native function.
    pub fn native(function: impl NativeFn<T> + 'static) -> Self {
        let shared: Rc<dyn NativeFn<T>> = Rc::new(function);
        Self::Native(shared)
    }

    /// Checks whether `self` and `other` refer to the same function instance.
    ///
    /// Native functions are compared by data pointer, and interpreted functions by
    /// `Rc` identity. Functions of different kinds are never the same.
    pub fn is_same_function(&self, other: &Self) -> bool {
        if let (Self::Native(lhs), Self::Native(rhs)) = (self, other) {
            lhs.data_ptr() == rhs.data_ptr()
        } else if let (Self::Interpreted(lhs), Self::Interpreted(rhs)) = (self, other) {
            Rc::ptr_eq(lhs, rhs)
        } else {
            false
        }
    }

    /// Returns the span of the function definition. This is `None` for native functions.
    pub(crate) fn def_span(&self) -> Option<CodeInModule<'a>> {
        if let Self::Interpreted(function) = self {
            let module_id = function.module_id();
            Some(CodeInModule::new(module_id, function.definition.def_span))
        } else {
            None
        }
    }
}
impl<'a, T: Clone> Function<'a, T> {
    /// Evaluates the function with the provided arguments, dispatching on its kind.
    pub fn evaluate(
        &self,
        args: Vec<SpannedValue<'a, T>>,
        ctx: &mut CallContext<'_, 'a, T>,
    ) -> EvalResult<'a, T> {
        match self {
            Self::Interpreted(interpreted) => interpreted.evaluate(args, ctx),
            Self::Native(native) => native.evaluate(args, ctx),
        }
    }
}
/// High-level type of a value.
///
/// Equality is total (`Eq`), and the type is hashable, so it can be used as a key in
/// sets and maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ValueType {
    /// Number.
    Number,
    /// Boolean value.
    Bool,
    /// Function (native or interpreted).
    Function,
    /// Tuple with the specified number of elements.
    Tuple(usize),
    /// Array. Not produced by `Value::value_type` in this module; presumably used
    /// elsewhere (TODO confirm).
    Array,
    /// Opaque reference.
    Ref,
}
impl fmt::Display for ValueType {
    /// Writes a human-readable description of the type, e.g. `tuple with 2 elements`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::Number => "number",
            Self::Bool => "boolean value",
            Self::Function => "function",
            Self::Tuple(1) => "tuple with 1 element",
            Self::Tuple(size) => return write!(formatter, "tuple with {} elements", size),
            Self::Array => "array",
            Self::Ref => "reference",
        };
        formatter.write_str(description)
    }
}
/// Type-erased, reference-counted value that can be stored in `Value::Ref`.
///
/// `dyn Any` provides neither value equality nor value formatting. Both are therefore
/// dispatched through function pointers captured when the reference is created.
pub struct OpaqueRef {
    /// Shared type-erased value.
    value: Rc<dyn Any>,
    /// Name of the wrapped type as reported by `type_name`; used as a prefix in `Display`.
    type_name: &'static str,
    /// Type-specific equality on two erased values (the first is always `value`).
    dyn_eq: fn(&dyn Any, &dyn Any) -> bool,
    /// Type-specific `Debug` formatting of the erased value.
    dyn_fmt: fn(&dyn Any, &mut fmt::Formatter<'_>) -> fmt::Result,
}
impl OpaqueRef {
    /// Creates a reference to `value` that compares by value, using `T`'s `PartialEq`.
    ///
    /// References wrapping values of different types are never equal.
    pub fn new<T>(value: T) -> Self
    where
        T: Any + fmt::Debug + PartialEq,
    {
        Self {
            value: Rc::new(value),
            type_name: type_name::<T>(),
            dyn_eq: |lhs, rhs| {
                // `lhs` always comes from an `OpaqueRef` created for `T`, so this cannot fail.
                let lhs = lhs.downcast_ref::<T>().unwrap();
                match rhs.downcast_ref::<T>() {
                    Some(rhs) => rhs == lhs,
                    None => false,
                }
            },
            dyn_fmt: |erased, formatter| {
                let value = erased.downcast_ref::<T>().unwrap();
                fmt::Debug::fmt(value, formatter)
            },
        }
    }

    /// Creates a reference to `value` that compares by identity.
    ///
    /// Two such references are equal only if they point to the same shared value,
    /// e.g. when one is a clone of the other.
    pub fn with_identity_eq<T>(value: T) -> Self
    where
        T: Any + fmt::Debug,
    {
        Self {
            value: Rc::new(value),
            type_name: type_name::<T>(),
            dyn_eq: |lhs, rhs| {
                // Compare only the data halves of the fat pointers; vtables are ignored.
                let lhs_data = (lhs as *const dyn Any).cast::<()>();
                let rhs_data = (rhs as *const dyn Any).cast::<()>();
                lhs_data == rhs_data
            },
            dyn_fmt: |erased, formatter| {
                let value = erased.downcast_ref::<T>().unwrap();
                fmt::Debug::fmt(value, formatter)
            },
        }
    }

    /// Attempts to downcast the wrapped value to `T`. Returns `None` if the value has
    /// a different type.
    pub fn downcast_ref<T: Any>(&self) -> Option<&T> {
        (*self.value).downcast_ref::<T>()
    }
}
impl Clone for OpaqueRef {
    /// Shares the wrapped value (bumping the `Rc` count) and copies the dispatch pointers.
    fn clone(&self) -> Self {
        let Self {
            value: shared,
            type_name: name,
            dyn_eq: eq_fn,
            dyn_fmt: fmt_fn,
        } = self;
        Self {
            value: Rc::clone(shared),
            type_name: *name,
            dyn_eq: *eq_fn,
            dyn_fmt: *fmt_fn,
        }
    }
}
impl PartialEq for OpaqueRef {
    /// Compares using the equality function captured when `self` was created.
    fn eq(&self, other: &Self) -> bool {
        let compare = self.dyn_eq;
        compare(&*self.value, &*other.value)
    }
}
impl fmt::Debug for OpaqueRef {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter
.debug_tuple("OpaqueRef")
.field(&self.value.as_ref())
.finish()
}
}
impl fmt::Display for OpaqueRef {
    /// Formats as `{type_name}::{value:?}`, e.g. `core::cmp::Ordering::Less`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str(self.type_name)?;
        formatter.write_str("::")?;
        let format_value = self.dyn_fmt;
        format_value(&*self.value, formatter)
    }
}
/// Value that can be produced or consumed during evaluation.
#[derive(Debug)]
#[non_exhaustive]
pub enum Value<'a, T> {
    /// Number of type `T`.
    Number(T),
    /// Boolean value.
    Bool(bool),
    /// Function (native or interpreted).
    Function(Function<'a, T>),
    /// Tuple of values; the empty tuple is treated as void (see `Value::void`).
    Tuple(Vec<Value<'a, T>>),
    /// Opaque reference to a value of an arbitrary type.
    Ref(OpaqueRef),
}
/// `Value` with an attached (possibly absent) code span.
pub type SpannedValue<'a, T> = MaybeSpanned<'a, Value<'a, T>>;
impl<'a, T> Value<'a, T> {
    /// Creates a value for a native function.
    pub fn native_fn(function: impl NativeFn<T> + 'static) -> Self {
        let shared: Rc<dyn NativeFn<T>> = Rc::new(function);
        Self::Function(Function::Native(shared))
    }

    /// Creates a value for a Rust function wrapped via `fns::wrap`.
    pub fn wrapped_fn<Args, F>(fn_to_wrap: F) -> Self
    where
        fns::FnWrapper<Args, F>: NativeFn<T> + 'static,
    {
        Self::native_fn(fns::wrap::<Args, _>(fn_to_wrap))
    }

    /// Creates a value for an interpreted function.
    pub(crate) fn interpreted_fn(function: InterpretedFn<'a, T>) -> Self {
        let shared = Rc::new(function);
        Self::Function(Function::Interpreted(shared))
    }

    /// Creates a void value (the empty tuple).
    pub fn void() -> Self {
        Self::Tuple(Vec::new())
    }

    /// Creates a reference to `value`, compared by its `PartialEq` impl.
    pub fn opaque_ref(value: impl Any + fmt::Debug + PartialEq) -> Self {
        let reference = OpaqueRef::new(value);
        Self::Ref(reference)
    }

    /// Returns the high-level type of this value.
    pub fn value_type(&self) -> ValueType {
        match self {
            Self::Tuple(elements) => ValueType::Tuple(elements.len()),
            Self::Number(_) => ValueType::Number,
            Self::Bool(_) => ValueType::Bool,
            Self::Function(_) => ValueType::Function,
            Self::Ref(_) => ValueType::Ref,
        }
    }

    /// Checks whether this value is void (the empty tuple).
    pub fn is_void(&self) -> bool {
        if let Self::Tuple(elements) = self {
            elements.is_empty()
        } else {
            false
        }
    }

    /// Checks whether this value is a function.
    pub fn is_function(&self) -> bool {
        self.value_type() == ValueType::Function
    }
}
impl<T: Clone> Clone for Value<'_, T> {
    /// Deep-clones numbers and tuples. Functions and references only have their
    /// `Rc` handles cloned.
    fn clone(&self) -> Self {
        match self {
            Self::Tuple(elements) => Self::Tuple(elements.clone()),
            Self::Function(handle) => Self::Function(handle.clone()),
            Self::Ref(opaque) => Self::Ref(opaque.clone()),
            Self::Number(number) => Self::Number(number.clone()),
            Self::Bool(flag) => Self::Bool(*flag),
        }
    }
}
impl<T: 'static + Clone> StripCode for Value<'_, T> {
    type Stripped = Value<'static, T>;

    /// Converts into a `'static` value. Only functions (and tuples containing them)
    /// carry code-bound data; everything else is moved over unchanged.
    fn strip_code(self) -> Self::Stripped {
        match self {
            Self::Tuple(elements) => {
                let stripped: Vec<_> = elements
                    .into_iter()
                    .map(|element| element.strip_code())
                    .collect();
                Value::Tuple(stripped)
            }
            Self::Function(function) => Value::Function(function.strip_code()),
            Self::Number(number) => Value::Number(number),
            Self::Bool(flag) => Value::Bool(flag),
            Self::Ref(opaque) => Value::Ref(opaque),
        }
    }
}
impl<'a, T: Clone> From<&Value<'a, T>> for Value<'a, T> {
    /// Creates an owned copy of the referenced value.
    fn from(reference: &Value<'a, T>) -> Self {
        reference.clone()
    }
}
impl<T: PartialEq> PartialEq for Value<'_, T> {
    /// Structural equality. Functions are compared by identity; values of different
    /// variants are never equal.
    fn eq(&self, rhs: &Self) -> bool {
        match (self, rhs) {
            (Self::Function(lhs_fn), Self::Function(rhs_fn)) => lhs_fn.is_same_function(rhs_fn),
            (Self::Ref(lhs_ref), Self::Ref(rhs_ref)) => lhs_ref == rhs_ref,
            (Self::Tuple(lhs_items), Self::Tuple(rhs_items)) => lhs_items == rhs_items,
            (Self::Bool(lhs_flag), Self::Bool(rhs_flag)) => lhs_flag == rhs_flag,
            (Self::Number(lhs_num), Self::Number(rhs_num)) => lhs_num == rhs_num,
            _ => false,
        }
    }
}
impl<T: fmt::Display> fmt::Display for Value<'_, T> {
    /// Formats the value in a user-facing form. Tuples are rendered as `(a, b, c)`;
    /// functions render as `[function]`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Number(number) => fmt::Display::fmt(number, formatter),
            Self::Bool(flag) => formatter.write_str(if *flag { "true" } else { "false" }),
            Self::Ref(opaque_ref) => fmt::Display::fmt(opaque_ref, formatter),
            Self::Function(_) => formatter.write_str("[function]"),
            Self::Tuple(elements) => {
                formatter.write_str("(")?;
                for (index, element) in elements.iter().enumerate() {
                    // Separators go before every element except the first.
                    if index > 0 {
                        formatter.write_str(", ")?;
                    }
                    fmt::Display::fmt(element, formatter)?;
                }
                formatter.write_str(")")
            }
        }
    }
}
/// Operand of a binary operation to which an error is attributed.
#[derive(Debug, Clone, Copy)]
enum OpSide {
    /// Left-hand side operand.
    Lhs,
    /// Right-hand side operand.
    Rhs,
}
/// Error from a binary operation, before code spans are attached via `span()`.
#[derive(Debug)]
struct BinaryOpError {
    /// Kind of the error.
    inner: ErrorKind,
    /// Operand to highlight; `None` means the whole operation is highlighted.
    side: Option<OpSide>,
}
impl BinaryOpError {
fn new(op: BinaryOp) -> Self {
Self {
inner: ErrorKind::UnexpectedOperand { op: Op::Binary(op) },
side: None,
}
}
fn tuple(op: BinaryOp, lhs: usize, rhs: usize) -> Self {
Self {
inner: ErrorKind::TupleLenMismatch {
lhs: lhs.into(),
rhs,
context: TupleLenMismatchContext::BinaryOp(op),
},
side: Some(OpSide::Lhs),
}
}
fn with_side(mut self, side: OpSide) -> Self {
self.side = Some(side);
self
}
fn with_error_kind(mut self, error_kind: ErrorKind) -> Self {
self.inner = error_kind;
self
}
fn span<'a>(
self,
module_id: &dyn ModuleId,
total_span: MaybeSpanned<'a>,
lhs_span: MaybeSpanned<'a>,
rhs_span: MaybeSpanned<'a>,
) -> Error<'a> {
let main_span = match self.side {
Some(OpSide::Lhs) => lhs_span,
Some(OpSide::Rhs) => rhs_span,
None => total_span,
};
let aux_info = if let ErrorKind::TupleLenMismatch { rhs, .. } = self.inner {
Some(AuxErrorInfo::UnbalancedRhs(rhs))
} else {
None
};
let mut err = Error::new(module_id, &main_span, self.inner);
if let Some(aux_info) = aux_info {
err = err.with_span(&rhs_span, aux_info);
}
err
}
}
impl<'a, T: Clone> Value<'a, T> {
    /// Applies an arithmetic binary `op` (`Add`, `Sub`, `Mul`, `Div` or `Power`).
    ///
    /// Numbers are combined with `arithmetic`. A number paired with a tuple is
    /// broadcast over the tuple's elements. Two tuples are combined element-wise and
    /// must have equal lengths.
    ///
    /// # Panics
    ///
    /// Panics if two numbers are combined with an `op` that is not listed above.
    fn try_binary_op_inner(
        self,
        rhs: Self,
        op: BinaryOp,
        arithmetic: &dyn OrdArithmetic<T>,
    ) -> Result<Self, BinaryOpError> {
        match (self, rhs) {
            (Self::Number(x), Self::Number(y)) => {
                let result = match op {
                    BinaryOp::Add => arithmetic.add(x, y),
                    BinaryOp::Sub => arithmetic.sub(x, y),
                    BinaryOp::Mul => arithmetic.mul(x, y),
                    BinaryOp::Div => arithmetic.div(x, y),
                    BinaryOp::Power => arithmetic.pow(x, y),
                    _ => unreachable!(),
                };
                result.map(Self::Number).map_err(|err| {
                    BinaryOpError::new(op).with_error_kind(ErrorKind::Arithmetic(err))
                })
            }

            (scalar @ Self::Number(_), Self::Tuple(elements)) => elements
                .into_iter()
                .map(|element| scalar.clone().try_binary_op_inner(element, op, arithmetic))
                .collect::<Result<Vec<_>, _>>()
                .map(Self::Tuple),

            (Self::Tuple(elements), scalar @ Self::Number(_)) => elements
                .into_iter()
                .map(|element| element.try_binary_op_inner(scalar.clone(), op, arithmetic))
                .collect::<Result<Vec<_>, _>>()
                .map(Self::Tuple),

            (Self::Tuple(lhs_items), Self::Tuple(rhs_items)) => {
                if lhs_items.len() != rhs_items.len() {
                    return Err(BinaryOpError::tuple(op, lhs_items.len(), rhs_items.len()));
                }
                lhs_items
                    .into_iter()
                    .zip(rhs_items)
                    .map(|(x, y)| x.try_binary_op_inner(y, op, arithmetic))
                    .collect::<Result<Vec<_>, _>>()
                    .map(Self::Tuple)
            }

            // The LHS is a valid operand, so the RHS is at fault.
            (Self::Number(_), _) | (Self::Tuple(_), _) => {
                Err(BinaryOpError::new(op).with_side(OpSide::Rhs))
            }
            _ => Err(BinaryOpError::new(op).with_side(OpSide::Lhs)),
        }
    }

    /// Applies an arithmetic binary `op` to spanned operands. Any error is attributed
    /// to the appropriate span.
    #[inline]
    pub(crate) fn try_binary_op(
        module_id: &dyn ModuleId,
        total_span: MaybeSpanned<'a>,
        lhs: MaybeSpanned<'a, Self>,
        rhs: MaybeSpanned<'a, Self>,
        op: BinaryOp,
        arithmetic: &dyn OrdArithmetic<T>,
    ) -> Result<Self, Error<'a>> {
        let (lhs_span, rhs_span) = (lhs.with_no_extra(), rhs.with_no_extra());
        let (lhs_value, rhs_value) = (lhs.extra, rhs.extra);
        lhs_value
            .try_binary_op_inner(rhs_value, op, arithmetic)
            .map_err(|err| err.span(module_id, total_span, lhs_span, rhs_span))
    }
}
impl<'a, T> Value<'a, T> {
pub(crate) fn try_neg(self, arithmetic: &dyn OrdArithmetic<T>) -> Result<Self, ErrorKind> {
match self {
Self::Number(val) => arithmetic
.neg(val)
.map(Self::Number)
.map_err(ErrorKind::Arithmetic),
Self::Tuple(tuple) => {
let res: Result<Vec<_>, _> = tuple
.into_iter()
.map(|elem| Value::try_neg(elem, arithmetic))
.collect();
res.map(Self::Tuple)
}
_ => Err(ErrorKind::UnexpectedOperand {
op: UnaryOp::Neg.into(),
}),
}
}
pub(crate) fn try_not(self) -> Result<Self, ErrorKind> {
match self {
Self::Bool(val) => Ok(Self::Bool(!val)),
Self::Tuple(tuple) => {
let res: Result<Vec<_>, _> = tuple.into_iter().map(Value::try_not).collect();
res.map(Self::Tuple)
}
_ => Err(ErrorKind::UnexpectedOperand {
op: UnaryOp::Not.into(),
}),
}
}
pub(crate) fn eq_by_arithmetic(&self, rhs: &Self, arithmetic: &dyn OrdArithmetic<T>) -> bool {
match (self, rhs) {
(Self::Number(this), Self::Number(other)) => arithmetic.eq(this, other),
(Self::Bool(this), Self::Bool(other)) => this == other,
(Self::Tuple(this), Self::Tuple(other)) => {
if this.len() == other.len() {
this.iter()
.zip(other.iter())
.all(|(x, y)| x.eq_by_arithmetic(y, arithmetic))
} else {
false
}
}
(Self::Function(this), Self::Function(other)) => this.is_same_function(other),
(Self::Ref(this), Self::Ref(other)) => this == other,
_ => false,
}
}
pub(crate) fn compare(
module_id: &dyn ModuleId,
lhs: &MaybeSpanned<'a, Self>,
rhs: &MaybeSpanned<'a, Self>,
op: BinaryOp,
arithmetic: &dyn OrdArithmetic<T>,
) -> Result<Self, Error<'a>> {
let lhs_number = match &lhs.extra {
Value::Number(number) => number,
_ => return Err(Error::new(module_id, &lhs, ErrorKind::CannotCompare)),
};
let rhs_number = match &rhs.extra {
Value::Number(number) => number,
_ => return Err(Error::new(module_id, &rhs, ErrorKind::CannotCompare)),
};
let maybe_ordering = arithmetic.partial_cmp(lhs_number, rhs_number);
let cmp_result = maybe_ordering.map_or(false, |ordering| match op {
BinaryOp::Gt => ordering == Ordering::Greater,
BinaryOp::Lt => ordering == Ordering::Less,
BinaryOp::Ge => ordering != Ordering::Less,
BinaryOp::Le => ordering != Ordering::Greater,
_ => unreachable!(),
});
Ok(Value::Bool(cmp_result))
}
pub(crate) fn try_and(
module_id: &dyn ModuleId,
lhs: &MaybeSpanned<'a, Self>,
rhs: &MaybeSpanned<'a, Self>,
) -> Result<Self, Error<'a>> {
match (&lhs.extra, &rhs.extra) {
(Value::Bool(this), Value::Bool(other)) => Ok(Value::Bool(*this && *other)),
(Value::Bool(_), _) => {
let err = ErrorKind::UnexpectedOperand {
op: BinaryOp::And.into(),
};
Err(Error::new(module_id, &rhs, err))
}
_ => {
let err = ErrorKind::UnexpectedOperand {
op: BinaryOp::And.into(),
};
Err(Error::new(module_id, &lhs, err))
}
}
}
pub(crate) fn try_or(
module_id: &dyn ModuleId,
lhs: &MaybeSpanned<'a, Self>,
rhs: &MaybeSpanned<'a, Self>,
) -> Result<Self, Error<'a>> {
match (&lhs.extra, &rhs.extra) {
(Value::Bool(this), Value::Bool(other)) => Ok(Value::Bool(*this || *other)),
(Value::Bool(_), _) => {
let err = ErrorKind::UnexpectedOperand {
op: BinaryOp::Or.into(),
};
Err(Error::new(module_id, &rhs, err))
}
_ => {
let err = ErrorKind::UnexpectedOperand {
op: BinaryOp::Or.into(),
};
Err(Error::new(module_id, &lhs, err))
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use core::cmp::Ordering;

    /// By-value opaque references are equal iff the wrapped values are equal.
    #[test]
    fn opaque_ref_equality() {
        let less = Value::<f32>::opaque_ref(Ordering::Less);
        let another_less = Value::<f32>::opaque_ref(Ordering::Less);
        let greater = Value::<f32>::opaque_ref(Ordering::Greater);

        assert_eq!(less, another_less);
        assert_eq!(less, less.clone());
        assert_ne!(less, greater);
    }

    /// `Display` prefixes the `Debug` output with the full type name.
    #[test]
    fn opaque_ref_formatting() {
        let reference = OpaqueRef::new(Ordering::Less);
        assert_eq!(reference.to_string(), "core::cmp::Ordering::Less");
    }
}