1use hashbrown::HashMap;
4
5use core::fmt;
6
7use crate::{
8 alloc::{Rc, String, Vec},
9 arith::OrdArithmetic,
10 error::{Backtrace, CodeInModule},
11 executable::ExecutableFn,
12 Error, ErrorKind, EvalResult, ModuleId, SpannedValue, Value,
13};
14use arithmetic_parser::{LvalueLen, MaybeSpanned, StripCode};
15
/// Context of a function call.
///
/// Bundles the span of the call site (together with the module it belongs to),
/// an optional mutable reference to the evaluation backtrace, and the arithmetic
/// used to operate on values of type `T`.
#[derive(Debug)]
pub struct CallContext<'r, 'a, T> {
    /// Code span of the call site, tied to its module.
    call_span: CodeInModule<'a>,
    /// Backtrace to record calls in; `None` if no backtrace is attached
    /// (e.g., for contexts created via `CallContext::mock`).
    backtrace: Option<&'r mut Backtrace<'a>>,
    /// Arithmetic used during the call.
    arithmetic: &'r dyn OrdArithmetic<T>,
}
23
24impl<'r, 'a, T> CallContext<'r, 'a, T> {
25 pub fn mock(
27 module_id: &dyn ModuleId,
28 call_span: MaybeSpanned<'a>,
29 arithmetic: &'r dyn OrdArithmetic<T>,
30 ) -> Self {
31 Self {
32 call_span: CodeInModule::new(module_id, call_span),
33 backtrace: None,
34 arithmetic,
35 }
36 }
37
38 pub(crate) fn new(
39 call_span: CodeInModule<'a>,
40 backtrace: Option<&'r mut Backtrace<'a>>,
41 arithmetic: &'r dyn OrdArithmetic<T>,
42 ) -> Self {
43 Self {
44 call_span,
45 backtrace,
46 arithmetic,
47 }
48 }
49
50 pub(crate) fn backtrace(&mut self) -> Option<&mut Backtrace<'a>> {
51 self.backtrace.as_deref_mut()
52 }
53
54 pub(crate) fn arithmetic(&self) -> &'r dyn OrdArithmetic<T> {
55 self.arithmetic
56 }
57
58 pub fn call_span(&self) -> &CodeInModule<'a> {
60 &self.call_span
61 }
62
63 pub fn apply_call_span<U>(&self, value: U) -> MaybeSpanned<'a, U> {
65 self.call_span.code().copy_with_extra(value)
66 }
67
68 pub fn call_site_error(&self, error: ErrorKind) -> Error<'a> {
70 Error::from_parts(self.call_span.clone(), error)
71 }
72
73 pub fn check_args_count(
75 &self,
76 args: &[SpannedValue<'a, T>],
77 expected_count: impl Into<LvalueLen>,
78 ) -> Result<(), Error<'a>> {
79 let expected_count = expected_count.into();
80 if expected_count.matches(args.len()) {
81 Ok(())
82 } else {
83 Err(self.call_site_error(ErrorKind::ArgsLenMismatch {
84 def: expected_count,
85 call: args.len(),
86 }))
87 }
88 }
89}
90
/// Function implemented in Rust, as opposed to an [`InterpretedFn`].
///
/// A blanket implementation is provided for `'static` closures / functions
/// with a matching signature.
pub trait NativeFn<T> {
    /// Executes the function on the specified arguments.
    ///
    /// `context` provides the call site span, the arithmetic and, if attached,
    /// the backtrace of the current evaluation.
    ///
    /// # Errors
    ///
    /// Returns an error if the call cannot be evaluated; the exact conditions
    /// depend on the implementation.
    fn evaluate<'a>(
        &self,
        args: Vec<SpannedValue<'a, T>>,
        context: &mut CallContext<'_, 'a, T>,
    ) -> EvalResult<'a, T>;
}
103
104impl<T, F: 'static> NativeFn<T> for F
105where
106 F: for<'a> Fn(Vec<SpannedValue<'a, T>>, &mut CallContext<'_, 'a, T>) -> EvalResult<'a, T>,
107{
108 fn evaluate<'a>(
109 &self,
110 args: Vec<SpannedValue<'a, T>>,
111 context: &mut CallContext<'_, 'a, T>,
112 ) -> EvalResult<'a, T> {
113 self(args, context)
114 }
115}
116
117impl<T> fmt::Debug for dyn NativeFn<T> {
118 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
119 formatter.debug_tuple("NativeFn").finish()
120 }
121}
122
123impl<T> dyn NativeFn<T> {
124 pub(crate) fn data_ptr(&self) -> *const () {
126 (self as *const dyn NativeFn<T>).cast()
131 }
132}
133
/// Function defined in interpreted code, together with the values it captures.
#[derive(Debug)]
pub struct InterpretedFn<'a, T> {
    /// Executable definition of the function, shared via `Rc`.
    definition: Rc<ExecutableFn<'a, T>>,
    /// Captured values; the i-th value corresponds to the i-th name in `capture_names`.
    captures: Vec<Value<'a, T>>,
    /// Names of the captured variables, in the same order as `captures`.
    capture_names: Vec<String>,
}
141
142impl<T: Clone> Clone for InterpretedFn<'_, T> {
143 fn clone(&self) -> Self {
144 Self {
145 definition: Rc::clone(&self.definition),
146 captures: self.captures.clone(),
147 capture_names: self.capture_names.clone(),
148 }
149 }
150}
151
152impl<T: 'static + Clone> StripCode for InterpretedFn<'_, T> {
153 type Stripped = InterpretedFn<'static, T>;
154
155 fn strip_code(self) -> Self::Stripped {
156 InterpretedFn {
157 definition: Rc::new(self.definition.to_stripped_code()),
158 captures: self
159 .captures
160 .into_iter()
161 .map(StripCode::strip_code)
162 .collect(),
163 capture_names: self.capture_names,
164 }
165 }
166}
167
168impl<'a, T> InterpretedFn<'a, T> {
169 pub(crate) fn new(
170 definition: Rc<ExecutableFn<'a, T>>,
171 captures: Vec<Value<'a, T>>,
172 capture_names: Vec<String>,
173 ) -> Self {
174 Self {
175 definition,
176 captures,
177 capture_names,
178 }
179 }
180
181 pub fn module_id(&self) -> &dyn ModuleId {
183 self.definition.inner.id()
184 }
185
186 pub fn arg_count(&self) -> LvalueLen {
188 self.definition.arg_count
189 }
190
191 pub fn captures(&self) -> HashMap<&str, &Value<'a, T>> {
193 self.capture_names
194 .iter()
195 .zip(&self.captures)
196 .map(|(name, val)| (name.as_str(), val))
197 .collect()
198 }
199}
200
201impl<T: 'static + Clone> InterpretedFn<'_, T> {
202 fn to_stripped_code(&self) -> InterpretedFn<'static, T> {
203 self.clone().strip_code()
204 }
205}
206
207impl<'a, T: Clone> InterpretedFn<'a, T> {
208 pub fn evaluate(
210 &self,
211 args: Vec<SpannedValue<'a, T>>,
212 ctx: &mut CallContext<'_, 'a, T>,
213 ) -> EvalResult<'a, T> {
214 if !self.arg_count().matches(args.len()) {
215 let err = ErrorKind::ArgsLenMismatch {
216 def: self.arg_count(),
217 call: args.len(),
218 };
219 return Err(ctx.call_site_error(err));
220 }
221
222 let args = args.into_iter().map(|arg| arg.extra).collect();
223 self.definition
224 .inner
225 .call_function(self.captures.clone(), args, ctx)
226 }
227}
228
/// Function that can be called: either native (implemented in Rust) or interpreted.
#[derive(Debug)]
pub enum Function<'a, T> {
    /// Native function, implemented in Rust.
    Native(Rc<dyn NativeFn<T>>),
    /// Function defined in interpreted code.
    Interpreted(Rc<InterpretedFn<'a, T>>),
}
238
239impl<T> Clone for Function<'_, T> {
240 fn clone(&self) -> Self {
241 match self {
242 Self::Native(function) => Self::Native(Rc::clone(&function)),
243 Self::Interpreted(function) => Self::Interpreted(Rc::clone(&function)),
244 }
245 }
246}
247
248impl<T: 'static + Clone> StripCode for Function<'_, T> {
249 type Stripped = Function<'static, T>;
250
251 fn strip_code(self) -> Self::Stripped {
252 match self {
253 Self::Native(function) => Function::Native(function),
254 Self::Interpreted(function) => {
255 Function::Interpreted(Rc::new(function.to_stripped_code()))
256 }
257 }
258 }
259}
260
261impl<'a, T> Function<'a, T> {
262 pub fn native(function: impl NativeFn<T> + 'static) -> Self {
264 Self::Native(Rc::new(function))
265 }
266
267 pub fn is_same_function(&self, other: &Self) -> bool {
269 match (self, other) {
270 (Self::Native(this), Self::Native(other)) => this.data_ptr() == other.data_ptr(),
271 (Self::Interpreted(this), Self::Interpreted(other)) => Rc::ptr_eq(this, other),
272 _ => false,
273 }
274 }
275
276 pub(crate) fn def_span(&self) -> Option<CodeInModule<'a>> {
277 match self {
278 Self::Native(_) => None,
279 Self::Interpreted(function) => Some(CodeInModule::new(
280 function.module_id(),
281 function.definition.def_span,
282 )),
283 }
284 }
285}
286
287impl<'a, T: Clone> Function<'a, T> {
288 pub fn evaluate(
290 &self,
291 args: Vec<SpannedValue<'a, T>>,
292 ctx: &mut CallContext<'_, 'a, T>,
293 ) -> EvalResult<'a, T> {
294 match self {
295 Self::Native(function) => function.evaluate(args, ctx),
296 Self::Interpreted(function) => function.evaluate(args, ctx),
297 }
298 }
299}