Skip to main content

rexlang_engine/
evaluator.rs

1use std::collections::hash_map::DefaultHasher;
2use std::hash::{Hash, Hasher};
3use std::ops::Deref;
4use std::path::Path;
5use std::sync::Arc;
6
7use rexlang_ast::expr::{Expr, Program, Symbol, sym};
8use rexlang_typesystem::{Subst, Type, TypeError, TypedExpr, Types, unify};
9use rexlang_util::{GasMeter, sha256_hex};
10
11use crate::engine::{
12    CompiledProgram, NativeImpl, OverloadedFn, RuntimeSnapshot, check_runtime_cancelled,
13    eval_typed_expr, impl_matches_type, is_function_type, type_head_is_var,
14};
15use crate::libraries::{LibraryId, ReplState, ResolvedLibrary, ResolvedLibraryContent};
16use crate::value::Value;
17use crate::{
18    CompileError, Compiler, EngineError, Env, EvalError, ExecutionError, Pointer, RuntimeEnv,
19};
20
21pub struct Evaluator<State = ()>
22where
23    State: Clone + Send + Sync + 'static,
24{
25    pub(crate) runtime: RuntimeEnv<State>,
26    pub(crate) compiler: Option<Compiler<State>>,
27}
28
29#[derive(Clone, Copy)]
30pub struct EvaluatorRef<'a, State = ()>
31where
32    State: Clone + Send + Sync + 'static,
33{
34    runtime: &'a RuntimeSnapshot<State>,
35}
36
37impl<State> Evaluator<State>
38where
39    State: Clone + Send + Sync + 'static,
40{
41    pub fn new(runtime: RuntimeEnv<State>) -> Self {
42        Self {
43            runtime,
44            compiler: None,
45        }
46    }
47
48    pub fn new_with_compiler(runtime: RuntimeEnv<State>, compiler: Compiler<State>) -> Self {
49        Self {
50            runtime,
51            compiler: Some(compiler),
52        }
53    }
54
55    pub(crate) fn sync_runtime_from_compiler(&mut self) {
56        if let Some(compiler) = &self.compiler {
57            self.runtime.sync_from_engine(&compiler.engine);
58        }
59    }
60
61    pub async fn run(
62        &mut self,
63        program: &CompiledProgram,
64        gas: &mut GasMeter,
65    ) -> Result<Pointer, EvalError> {
66        self.run_internal(program, gas)
67            .await
68            .map_err(EvalError::from)
69    }
70
71    pub(crate) async fn run_internal(
72        &mut self,
73        program: &CompiledProgram,
74        gas: &mut GasMeter,
75    ) -> Result<Pointer, EngineError> {
76        check_runtime_cancelled(&self.runtime.runtime)?;
77        self.runtime.validate_internal(program)?;
78        eval_typed_expr(
79            &self.runtime.runtime,
80            &program.env,
81            program.expr.as_ref(),
82            gas,
83        )
84        .await
85    }
86
87    pub async fn eval(
88        &mut self,
89        expr: &Expr,
90        gas: &mut GasMeter,
91    ) -> Result<(Pointer, Type), ExecutionError> {
92        self.prepare_and_run(gas, |compiler, _gas| compiler.compile_expr(expr))
93            .await
94    }
95
96    pub async fn run_prepared(
97        &mut self,
98        program: CompiledProgram,
99        gas: &mut GasMeter,
100    ) -> Result<(Pointer, Type), ExecutionError> {
101        self.sync_runtime_from_compiler();
102        let typ = program.result_type().clone();
103        let value = self.run(&program, gas).await?;
104        Ok((value, typ))
105    }
106
107    pub(crate) async fn prepare_and_run<F>(
108        &mut self,
109        gas: &mut GasMeter,
110        compile: F,
111    ) -> Result<(Pointer, Type), ExecutionError>
112    where
113        F: FnOnce(&mut Compiler<State>, &mut GasMeter) -> Result<CompiledProgram, CompileError>,
114    {
115        let compiler = self.compiler.as_mut().ok_or_else(|| {
116            CompileError::from(EngineError::Internal("evaluator has no compiler".into()))
117        })?;
118        let program = compile(compiler, gas)?;
119        self.run_prepared(program, gas).await
120    }
121
122    pub async fn eval_library_file(
123        &mut self,
124        path: impl AsRef<Path>,
125        gas: &mut GasMeter,
126    ) -> Result<(Pointer, Type), ExecutionError> {
127        let (id, bytes) = self
128            .runtime
129            .loader
130            .read_local_library_bytes(path.as_ref())
131            .map_err(CompileError::from)?;
132        let source_fingerprint = sha256_hex(&bytes);
133        if let Some(inst) = self
134            .runtime
135            .loader
136            .modules
137            .cached(&id)
138            .map_err(EvalError::from)?
139        {
140            if inst.source_fingerprint.as_deref() == Some(source_fingerprint.as_str()) {
141                return Ok((inst.init_value, inst.init_type));
142            }
143            self.runtime
144                .loader
145                .invalidate_library_caches(&id)
146                .map_err(EvalError::from)?;
147        }
148        let source = self
149            .runtime
150            .loader
151            .decode_local_library_source(&id, bytes)
152            .map_err(CompileError::from)?;
153        let inst = self
154            .runtime
155            .loader
156            .load_library_from_resolved(
157                ResolvedLibrary {
158                    id,
159                    content: ResolvedLibraryContent::Source(source),
160                },
161                gas,
162            )
163            .await
164            .map_err(CompileError::from)?;
165        Ok((inst.init_value, inst.init_type))
166    }
167
168    pub async fn eval_library_source(
169        &mut self,
170        source: &str,
171        gas: &mut GasMeter,
172    ) -> Result<(Pointer, Type), ExecutionError> {
173        let mut hasher = DefaultHasher::new();
174        source.hash(&mut hasher);
175        let id = LibraryId::Virtual(format!("<inline:{:016x}>", hasher.finish()));
176        if let Some(inst) = self
177            .runtime
178            .loader
179            .modules
180            .cached(&id)
181            .map_err(EvalError::from)?
182        {
183            return Ok((inst.init_value, inst.init_type));
184        }
185        let inst = self
186            .runtime
187            .loader
188            .load_library_from_resolved(
189                ResolvedLibrary {
190                    id,
191                    content: ResolvedLibraryContent::Source(source.to_string()),
192                },
193                gas,
194            )
195            .await
196            .map_err(CompileError::from)?;
197        Ok((inst.init_value, inst.init_type))
198    }
199
200    pub async fn eval_snippet(
201        &mut self,
202        source: &str,
203        gas: &mut GasMeter,
204    ) -> Result<(Pointer, Type), ExecutionError> {
205        self.prepare_and_run(gas, |compiler, gas| compiler.compile_snippet(source, gas))
206            .await
207    }
208
209    pub async fn eval_repl_program(
210        &mut self,
211        program: &Program,
212        state: &mut ReplState,
213        gas: &mut GasMeter,
214    ) -> Result<(Pointer, Type), ExecutionError> {
215        let compiler = self.compiler.as_mut().ok_or_else(|| {
216            CompileError::from(EngineError::Internal("evaluator has no compiler".into()))
217        })?;
218        let compiled = compiler.compile_repl_program(program, state, gas).await?;
219        self.run_prepared(compiled, gas).await
220    }
221
222    pub async fn eval_snippet_at(
223        &mut self,
224        source: &str,
225        importer_path: impl AsRef<Path>,
226        gas: &mut GasMeter,
227    ) -> Result<(Pointer, Type), ExecutionError> {
228        let path = importer_path.as_ref().to_path_buf();
229        self.prepare_and_run(gas, |compiler, gas| {
230            compiler.compile_snippet_at(source, &path, gas)
231        })
232        .await
233    }
234}
235
236impl<'a, State> EvaluatorRef<'a, State>
237where
238    State: Clone + Send + Sync + 'static,
239{
240    pub(crate) fn new(runtime: &'a RuntimeSnapshot<State>) -> Self {
241        Self { runtime }
242    }
243
244    fn resolve_typeclass_method_impl(
245        &self,
246        name: &Symbol,
247        call_type: &Type,
248    ) -> Result<(Env, Arc<TypedExpr>, Subst), EngineError> {
249        let info = self
250            .runtime
251            .type_system
252            .class_methods
253            .get(name)
254            .ok_or_else(|| EngineError::UnknownVar(name.clone()))?;
255
256        let s_method = unify(&info.scheme.typ, call_type).map_err(EngineError::Type)?;
257        let class_pred = info
258            .scheme
259            .preds
260            .iter()
261            .find(|p| p.class == info.class)
262            .ok_or(EngineError::Type(TypeError::UnsupportedExpr(
263                "method scheme missing class predicate",
264            )))?;
265        let param_type = class_pred.typ.apply(&s_method);
266        if type_head_is_var(&param_type) {
267            return Err(EngineError::AmbiguousOverload { name: name.clone() });
268        }
269
270        self.runtime
271            .typeclasses
272            .resolve(&info.class, name, &param_type)
273    }
274
275    fn cached_class_method(&self, name: &Symbol, typ: &Type) -> Option<Pointer> {
276        if !typ.ftv().is_empty() {
277            return None;
278        }
279        let cache = self.runtime.typeclass_cache.lock().ok()?;
280        cache.get(&(name.clone(), typ.clone())).cloned()
281    }
282
283    fn insert_cached_class_method(&self, name: &Symbol, typ: &Type, pointer: &Pointer) {
284        if typ.ftv().is_empty()
285            && let Ok(mut cache) = self.runtime.typeclass_cache.lock()
286        {
287            cache.insert((name.clone(), typ.clone()), *pointer);
288        }
289    }
290
291    fn resolve_class_method_plan(
292        &self,
293        name: &Symbol,
294        typ: &Type,
295    ) -> Result<Result<(Env, TypedExpr), Pointer>, EngineError> {
296        let (def_env, typed, s) = match self.resolve_typeclass_method_impl(name, typ) {
297            Ok(res) => res,
298            Err(EngineError::AmbiguousOverload { .. }) if is_function_type(typ) => {
299                let (name, typ, applied, applied_types) =
300                    OverloadedFn::new(name.clone(), typ.clone()).into_parts();
301                let pointer =
302                    self.runtime
303                        .heap
304                        .alloc_overloaded(name, typ, applied, applied_types)?;
305                return Ok(Err(pointer));
306            }
307            Err(err) => return Err(err),
308        };
309        let specialized = typed.as_ref().apply(&s);
310        Ok(Ok((def_env, specialized)))
311    }
312
313    pub(crate) async fn resolve_class_method(
314        &self,
315        name: &Symbol,
316        typ: &Type,
317        gas: &mut GasMeter,
318    ) -> Result<Pointer, EngineError> {
319        if let Some(pointer) = self.cached_class_method(name, typ) {
320            return Ok(pointer);
321        }
322
323        let pointer = match self.resolve_class_method_plan(name, typ)? {
324            Ok((def_env, specialized)) => {
325                eval_typed_expr(self.runtime, &def_env, &specialized, gas).await?
326            }
327            Err(pointer) => pointer,
328        };
329
330        if typ.ftv().is_empty() {
331            self.insert_cached_class_method(name, typ, &pointer);
332        }
333        Ok(pointer)
334    }
335
336    pub(crate) fn resolve_native_impl(
337        &self,
338        name: &str,
339        typ: &Type,
340    ) -> Result<NativeImpl<State>, EngineError> {
341        let sym_name = sym(name);
342        let impls = self
343            .runtime
344            .natives
345            .get(&sym_name)
346            .ok_or_else(|| EngineError::UnknownVar(sym_name.clone()))?;
347        let matches: Vec<NativeImpl<State>> = impls
348            .iter()
349            .filter(|imp| impl_matches_type(imp, typ))
350            .cloned()
351            .collect();
352        match matches.len() {
353            0 => Err(EngineError::MissingImpl {
354                name: sym_name.clone(),
355                typ: typ.to_string(),
356            }),
357            1 => Ok(matches[0].clone()),
358            _ => Err(EngineError::AmbiguousImpl {
359                name: sym_name,
360                typ: typ.to_string(),
361            }),
362        }
363    }
364
365    pub(crate) fn resolve_native(
366        &self,
367        name: &str,
368        typ: &Type,
369        _gas: &mut GasMeter,
370    ) -> Result<Pointer, EngineError> {
371        let sym_name = sym(name);
372        let impls = self
373            .runtime
374            .natives
375            .get(&sym_name)
376            .ok_or_else(|| EngineError::UnknownVar(sym_name.clone()))?;
377        let matches: Vec<NativeImpl<State>> = impls
378            .iter()
379            .filter(|imp| impl_matches_type(imp, typ))
380            .cloned()
381            .collect();
382        match matches.len() {
383            0 => Err(EngineError::MissingImpl {
384                name: sym_name.clone(),
385                typ: typ.to_string(),
386            }),
387            1 => {
388                let imp = matches[0].clone();
389                let (native_id, name, arity, typ, gas_cost, applied, applied_types) =
390                    imp.to_native_fn(typ.clone()).into_parts();
391                self.runtime.heap.alloc_native(
392                    native_id,
393                    name,
394                    arity,
395                    typ,
396                    gas_cost,
397                    applied,
398                    applied_types,
399                )
400            }
401            _ => {
402                if typ.ftv().is_empty() {
403                    Err(EngineError::AmbiguousImpl {
404                        name: sym_name.clone(),
405                        typ: typ.to_string(),
406                    })
407                } else if is_function_type(typ) {
408                    let (name, typ, applied, applied_types) =
409                        OverloadedFn::new(sym_name.clone(), typ.clone()).into_parts();
410                    self.runtime
411                        .heap
412                        .alloc_overloaded(name, typ, applied, applied_types)
413                } else {
414                    Err(EngineError::AmbiguousOverload { name: sym_name })
415                }
416            }
417        }
418    }
419
420    pub(crate) async fn resolve_global(
421        &self,
422        name: &Symbol,
423        typ: &Type,
424    ) -> Result<Pointer, EngineError> {
425        if let Some(ptr) = self.runtime.env.get(name) {
426            let value = self.runtime.heap.get(&ptr)?;
427            match value.as_ref() {
428                Value::Native(native) if native.is_zero_unapplied() => {
429                    let mut gas = GasMeter::default();
430                    native.call_zero(self.runtime, &mut gas).await
431                }
432                _ => Ok(ptr),
433            }
434        } else if self.runtime.type_system.class_methods.contains_key(name) {
435            let mut gas = GasMeter::default();
436            self.resolve_class_method(name, typ, &mut gas).await
437        } else {
438            let mut gas = GasMeter::default();
439            let pointer = self.resolve_native(name.as_ref(), typ, &mut gas)?;
440            let value = self.runtime.heap.get(&pointer)?;
441            match value.as_ref() {
442                Value::Native(native) if native.is_zero_unapplied() => {
443                    let mut gas = GasMeter::default();
444                    native.call_zero(self.runtime, &mut gas).await
445                }
446                _ => Ok(pointer),
447            }
448        }
449    }
450
451    pub(crate) async fn call_native_impl(
452        &self,
453        name: &str,
454        typ: &Type,
455        args: &[Pointer],
456    ) -> Result<Pointer, EngineError> {
457        let imp = self.resolve_native_impl(name, typ)?;
458        imp.func.call(self.runtime, typ.clone(), args).await
459    }
460}
461
462impl<'a, State> Deref for EvaluatorRef<'a, State>
463where
464    State: Clone + Send + Sync + 'static,
465{
466    type Target = RuntimeSnapshot<State>;
467
468    fn deref(&self) -> &Self::Target {
469        self.runtime
470    }
471}