arithmetic_eval/values/
function.rs

1//! `Function` and closely related types.
2
3use hashbrown::HashMap;
4
5use core::fmt;
6
7use crate::{
8    alloc::{Rc, String, Vec},
9    arith::OrdArithmetic,
10    error::{Backtrace, CodeInModule},
11    executable::ExecutableFn,
12    Error, ErrorKind, EvalResult, ModuleId, SpannedValue, Value,
13};
14use arithmetic_parser::{LvalueLen, MaybeSpanned, StripCode};
15
16/// Context for native function calls.
17#[derive(Debug)]
18pub struct CallContext<'r, 'a, T> {
19    call_span: CodeInModule<'a>,
20    backtrace: Option<&'r mut Backtrace<'a>>,
21    arithmetic: &'r dyn OrdArithmetic<T>,
22}
23
24impl<'r, 'a, T> CallContext<'r, 'a, T> {
25    /// Creates a mock call context with the specified module ID and call span.
26    pub fn mock(
27        module_id: &dyn ModuleId,
28        call_span: MaybeSpanned<'a>,
29        arithmetic: &'r dyn OrdArithmetic<T>,
30    ) -> Self {
31        Self {
32            call_span: CodeInModule::new(module_id, call_span),
33            backtrace: None,
34            arithmetic,
35        }
36    }
37
38    pub(crate) fn new(
39        call_span: CodeInModule<'a>,
40        backtrace: Option<&'r mut Backtrace<'a>>,
41        arithmetic: &'r dyn OrdArithmetic<T>,
42    ) -> Self {
43        Self {
44            call_span,
45            backtrace,
46            arithmetic,
47        }
48    }
49
50    pub(crate) fn backtrace(&mut self) -> Option<&mut Backtrace<'a>> {
51        self.backtrace.as_deref_mut()
52    }
53
54    pub(crate) fn arithmetic(&self) -> &'r dyn OrdArithmetic<T> {
55        self.arithmetic
56    }
57
58    /// Returns the call span of the currently executing function.
59    pub fn call_span(&self) -> &CodeInModule<'a> {
60        &self.call_span
61    }
62
63    /// Applies the call span to the specified `value`.
64    pub fn apply_call_span<U>(&self, value: U) -> MaybeSpanned<'a, U> {
65        self.call_span.code().copy_with_extra(value)
66    }
67
68    /// Creates an error spanning the call site.
69    pub fn call_site_error(&self, error: ErrorKind) -> Error<'a> {
70        Error::from_parts(self.call_span.clone(), error)
71    }
72
73    /// Checks argument count and returns an error if it doesn't match.
74    pub fn check_args_count(
75        &self,
76        args: &[SpannedValue<'a, T>],
77        expected_count: impl Into<LvalueLen>,
78    ) -> Result<(), Error<'a>> {
79        let expected_count = expected_count.into();
80        if expected_count.matches(args.len()) {
81            Ok(())
82        } else {
83            Err(self.call_site_error(ErrorKind::ArgsLenMismatch {
84                def: expected_count,
85                call: args.len(),
86            }))
87        }
88    }
89}
90
91/// Function on zero or more [`Value`]s.
92///
93/// Native functions are defined in the Rust code and then can be used from the interpreted
94/// code. See [`fns`](crate::fns) module docs for different ways to define native functions.
95pub trait NativeFn<T> {
96    /// Executes the function on the specified arguments.
97    fn evaluate<'a>(
98        &self,
99        args: Vec<SpannedValue<'a, T>>,
100        context: &mut CallContext<'_, 'a, T>,
101    ) -> EvalResult<'a, T>;
102}
103
104impl<T, F: 'static> NativeFn<T> for F
105where
106    F: for<'a> Fn(Vec<SpannedValue<'a, T>>, &mut CallContext<'_, 'a, T>) -> EvalResult<'a, T>,
107{
108    fn evaluate<'a>(
109        &self,
110        args: Vec<SpannedValue<'a, T>>,
111        context: &mut CallContext<'_, 'a, T>,
112    ) -> EvalResult<'a, T> {
113        self(args, context)
114    }
115}
116
117impl<T> fmt::Debug for dyn NativeFn<T> {
118    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
119        formatter.debug_tuple("NativeFn").finish()
120    }
121}
122
123impl<T> dyn NativeFn<T> {
124    /// Extracts a data pointer from this trait object reference.
125    pub(crate) fn data_ptr(&self) -> *const () {
126        // `*const dyn Trait as *const ()` extracts the data pointer,
127        // see https://github.com/rust-lang/rust/issues/27751. This is seemingly
128        // the simplest way to extract the data pointer; `TraitObject` in `std::raw` is
129        // a more future-proof alternative, but it is unstable.
130        (self as *const dyn NativeFn<T>).cast()
131    }
132}
133
134/// Function defined within the interpreter.
135#[derive(Debug)]
136pub struct InterpretedFn<'a, T> {
137    definition: Rc<ExecutableFn<'a, T>>,
138    captures: Vec<Value<'a, T>>,
139    capture_names: Vec<String>,
140}
141
142impl<T: Clone> Clone for InterpretedFn<'_, T> {
143    fn clone(&self) -> Self {
144        Self {
145            definition: Rc::clone(&self.definition),
146            captures: self.captures.clone(),
147            capture_names: self.capture_names.clone(),
148        }
149    }
150}
151
152impl<T: 'static + Clone> StripCode for InterpretedFn<'_, T> {
153    type Stripped = InterpretedFn<'static, T>;
154
155    fn strip_code(self) -> Self::Stripped {
156        InterpretedFn {
157            definition: Rc::new(self.definition.to_stripped_code()),
158            captures: self
159                .captures
160                .into_iter()
161                .map(StripCode::strip_code)
162                .collect(),
163            capture_names: self.capture_names,
164        }
165    }
166}
167
168impl<'a, T> InterpretedFn<'a, T> {
169    pub(crate) fn new(
170        definition: Rc<ExecutableFn<'a, T>>,
171        captures: Vec<Value<'a, T>>,
172        capture_names: Vec<String>,
173    ) -> Self {
174        Self {
175            definition,
176            captures,
177            capture_names,
178        }
179    }
180
181    /// Returns ID of the module defining this function.
182    pub fn module_id(&self) -> &dyn ModuleId {
183        self.definition.inner.id()
184    }
185
186    /// Returns the number of arguments for this function.
187    pub fn arg_count(&self) -> LvalueLen {
188        self.definition.arg_count
189    }
190
191    /// Returns values captured by this function.
192    pub fn captures(&self) -> HashMap<&str, &Value<'a, T>> {
193        self.capture_names
194            .iter()
195            .zip(&self.captures)
196            .map(|(name, val)| (name.as_str(), val))
197            .collect()
198    }
199}
200
201impl<T: 'static + Clone> InterpretedFn<'_, T> {
202    fn to_stripped_code(&self) -> InterpretedFn<'static, T> {
203        self.clone().strip_code()
204    }
205}
206
207impl<'a, T: Clone> InterpretedFn<'a, T> {
208    /// Evaluates this function with the provided arguments and the execution context.
209    pub fn evaluate(
210        &self,
211        args: Vec<SpannedValue<'a, T>>,
212        ctx: &mut CallContext<'_, 'a, T>,
213    ) -> EvalResult<'a, T> {
214        if !self.arg_count().matches(args.len()) {
215            let err = ErrorKind::ArgsLenMismatch {
216                def: self.arg_count(),
217                call: args.len(),
218            };
219            return Err(ctx.call_site_error(err));
220        }
221
222        let args = args.into_iter().map(|arg| arg.extra).collect();
223        self.definition
224            .inner
225            .call_function(self.captures.clone(), args, ctx)
226    }
227}
228
229/// Function definition. Functions can be either native (defined in the Rust code) or defined
230/// in the interpreter.
231#[derive(Debug)]
232pub enum Function<'a, T> {
233    /// Native function.
234    Native(Rc<dyn NativeFn<T>>),
235    /// Interpreted function.
236    Interpreted(Rc<InterpretedFn<'a, T>>),
237}
238
239impl<T> Clone for Function<'_, T> {
240    fn clone(&self) -> Self {
241        match self {
242            Self::Native(function) => Self::Native(Rc::clone(&function)),
243            Self::Interpreted(function) => Self::Interpreted(Rc::clone(&function)),
244        }
245    }
246}
247
248impl<T: 'static + Clone> StripCode for Function<'_, T> {
249    type Stripped = Function<'static, T>;
250
251    fn strip_code(self) -> Self::Stripped {
252        match self {
253            Self::Native(function) => Function::Native(function),
254            Self::Interpreted(function) => {
255                Function::Interpreted(Rc::new(function.to_stripped_code()))
256            }
257        }
258    }
259}
260
261impl<'a, T> Function<'a, T> {
262    /// Creates a native function.
263    pub fn native(function: impl NativeFn<T> + 'static) -> Self {
264        Self::Native(Rc::new(function))
265    }
266
267    /// Checks if the provided function is the same as this one.
268    pub fn is_same_function(&self, other: &Self) -> bool {
269        match (self, other) {
270            (Self::Native(this), Self::Native(other)) => this.data_ptr() == other.data_ptr(),
271            (Self::Interpreted(this), Self::Interpreted(other)) => Rc::ptr_eq(this, other),
272            _ => false,
273        }
274    }
275
276    pub(crate) fn def_span(&self) -> Option<CodeInModule<'a>> {
277        match self {
278            Self::Native(_) => None,
279            Self::Interpreted(function) => Some(CodeInModule::new(
280                function.module_id(),
281                function.definition.def_span,
282            )),
283        }
284    }
285}
286
287impl<'a, T: Clone> Function<'a, T> {
288    /// Evaluates the function on the specified arguments.
289    pub fn evaluate(
290        &self,
291        args: Vec<SpannedValue<'a, T>>,
292        ctx: &mut CallContext<'_, 'a, T>,
293    ) -> EvalResult<'a, T> {
294        match self {
295            Self::Native(function) => function.evaluate(args, ctx),
296            Self::Interpreted(function) => function.evaluate(args, ctx),
297        }
298    }
299}