Skip to main content

graphix_compiler/node/
lambda.rs

1use super::{compiler::compile, Nop};
2use crate::{
3    env::{Bind, Env},
4    expr::{self, Arg, ErrorContext, Expr, ExprId},
5    node::pattern::StructPatternNode,
6    typ::{FnArgType, FnType, Type},
7    wrap, Apply, BindId, CFlag, Event, ExecCtx, InitFn, LambdaId, Node, Refs, Rt, Scope,
8    Update, UserEvent,
9};
10use anyhow::{anyhow, bail, Context, Result};
11use arcstr::ArcStr;
12use compact_str::format_compact;
13use enumflags2::BitFlags;
14use netidx::{pack::Pack, subscriber::Value, utils::Either};
15use parking_lot::RwLock;
16use poolshark::local::LPooled;
17use std::{fmt, hash::Hash, sync::Arc as SArc};
18use triomphe::Arc;
19
20pub struct LambdaDef<R: Rt, E: UserEvent> {
21    pub id: LambdaId,
22    pub env: Env,
23    pub scope: Scope,
24    pub argspec: Arc<[Arg]>,
25    pub typ: Arc<FnType>,
26    pub init: InitFn<R, E>,
27}
28
29impl<R: Rt, E: UserEvent> fmt::Debug for LambdaDef<R, E> {
30    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31        write!(f, "lambda#{}", self.id.inner())
32    }
33}
34
35impl<R: Rt, E: UserEvent> PartialEq for LambdaDef<R, E> {
36    fn eq(&self, other: &Self) -> bool {
37        self.id == other.id
38    }
39}
40
41impl<R: Rt, E: UserEvent> Eq for LambdaDef<R, E> {}
42
43impl<R: Rt, E: UserEvent> PartialOrd for LambdaDef<R, E> {
44    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
45        Some(self.id.cmp(&other.id))
46    }
47}
48
49impl<R: Rt, E: UserEvent> Ord for LambdaDef<R, E> {
50    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
51        self.id.cmp(&other.id)
52    }
53}
54
55impl<R: Rt, E: UserEvent> Hash for LambdaDef<R, E> {
56    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
57        self.id.hash(state)
58    }
59}
60
61impl<R: Rt, E: UserEvent> Pack for LambdaDef<R, E> {
62    fn encoded_len(&self) -> usize {
63        0
64    }
65
66    fn encode(
67        &self,
68        _buf: &mut impl bytes::BufMut,
69    ) -> std::result::Result<(), netidx::pack::PackError> {
70        Err(netidx::pack::PackError::Application(0))
71    }
72
73    fn decode(
74        _buf: &mut impl bytes::Buf,
75    ) -> std::result::Result<Self, netidx::pack::PackError> {
76        Err(netidx::pack::PackError::Application(0))
77    }
78}
79
80#[derive(Debug)]
81struct GXLambda<R: Rt, E: UserEvent> {
82    args: Box<[StructPatternNode]>,
83    body: Node<R, E>,
84    typ: Arc<FnType>,
85}
86
87impl<R: Rt, E: UserEvent> Apply<R, E> for GXLambda<R, E> {
88    fn update(
89        &mut self,
90        ctx: &mut ExecCtx<R, E>,
91        from: &mut [Node<R, E>],
92        event: &mut Event<E>,
93    ) -> Option<Value> {
94        for (arg, pat) in from.iter_mut().zip(&self.args) {
95            if let Some(v) = arg.update(ctx, event) {
96                pat.bind(&v, &mut |id, v| {
97                    ctx.cached.insert(id, v.clone());
98                    event.variables.insert(id, v);
99                })
100            }
101        }
102        self.body.update(ctx, event)
103    }
104
105    fn typecheck(
106        &mut self,
107        ctx: &mut ExecCtx<R, E>,
108        args: &mut [Node<R, E>],
109    ) -> Result<()> {
110        for (arg, FnArgType { typ, .. }) in args.iter_mut().zip(self.typ.args.iter()) {
111            wrap!(arg, arg.typecheck(ctx))?;
112            wrap!(arg, typ.check_contains(&ctx.env, &arg.typ()))?;
113        }
114        wrap!(self.body, self.body.typecheck(ctx))?;
115        wrap!(self.body, self.typ.rtype.check_contains(&ctx.env, &self.body.typ()))?;
116        for (tv, tc) in self.typ.constraints.read().iter() {
117            tc.check_contains(&ctx.env, &Type::TVar(tv.clone()))?
118        }
119        Ok(())
120    }
121
122    fn typ(&self) -> Arc<FnType> {
123        Arc::clone(&self.typ)
124    }
125
126    fn refs(&self, refs: &mut Refs) {
127        for pat in &self.args {
128            pat.ids(&mut |id| {
129                refs.bound.insert(id);
130            })
131        }
132        self.body.refs(refs)
133    }
134
135    fn delete(&mut self, ctx: &mut ExecCtx<R, E>) {
136        self.body.delete(ctx);
137        for n in &self.args {
138            n.delete(ctx)
139        }
140    }
141
142    fn sleep(&mut self, ctx: &mut ExecCtx<R, E>) {
143        self.body.sleep(ctx);
144    }
145}
146
147impl<R: Rt, E: UserEvent> GXLambda<R, E> {
148    pub(super) fn new(
149        ctx: &mut ExecCtx<R, E>,
150        flags: BitFlags<CFlag>,
151        typ: Arc<FnType>,
152        argspec: Arc<[Arg]>,
153        args: &[Node<R, E>],
154        scope: &Scope,
155        tid: ExprId,
156        body: Expr,
157    ) -> Result<Self> {
158        if args.len() != argspec.len() {
159            bail!("arity mismatch, expected {} arguments", argspec.len())
160        }
161        let mut argpats = vec![];
162        for (a, atyp) in argspec.iter().zip(typ.args.iter()) {
163            let pattern = StructPatternNode::compile(ctx, &atyp.typ, &a.pattern, scope)?;
164            if pattern.is_refutable() {
165                bail!(
166                    "refutable patterns are not allowed in lambda arguments {}",
167                    a.pattern
168                )
169            }
170            argpats.push(pattern);
171        }
172        let body = compile(ctx, flags, body, &scope, tid)?;
173        Ok(Self { args: Box::from(argpats), typ, body })
174    }
175}
176
177#[derive(Debug)]
178struct BuiltInLambda<R: Rt, E: UserEvent> {
179    typ: Arc<FnType>,
180    apply: Box<dyn Apply<R, E> + Send + Sync + 'static>,
181}
182
183impl<R: Rt, E: UserEvent> Apply<R, E> for BuiltInLambda<R, E> {
184    fn update(
185        &mut self,
186        ctx: &mut ExecCtx<R, E>,
187        from: &mut [Node<R, E>],
188        event: &mut Event<E>,
189    ) -> Option<Value> {
190        self.apply.update(ctx, from, event)
191    }
192
193    fn typecheck(
194        &mut self,
195        ctx: &mut ExecCtx<R, E>,
196        args: &mut [Node<R, E>],
197    ) -> Result<()> {
198        if args.len() < self.typ.args.len()
199            || (args.len() > self.typ.args.len() && self.typ.vargs.is_none())
200        {
201            let vargs = if self.typ.vargs.is_some() { "at least " } else { "" };
202            bail!(
203                "expected {}{} arguments got {}",
204                vargs,
205                self.typ.args.len(),
206                args.len()
207            )
208        }
209        for i in 0..args.len() {
210            wrap!(args[i], args[i].typecheck(ctx))?;
211            let atyp = if i < self.typ.args.len() {
212                &self.typ.args[i].typ
213            } else {
214                self.typ.vargs.as_ref().unwrap()
215            };
216            wrap!(args[i], atyp.check_contains(&ctx.env, &args[i].typ()))?
217        }
218        for (tv, tc) in self.typ.constraints.read().iter() {
219            tc.check_contains(&ctx.env, &Type::TVar(tv.clone()))?
220        }
221        self.apply.typecheck(ctx, args)?;
222        Ok(())
223    }
224
225    fn typ(&self) -> Arc<FnType> {
226        Arc::clone(&self.typ)
227    }
228
229    fn refs(&self, refs: &mut Refs) {
230        self.apply.refs(refs)
231    }
232
233    fn delete(&mut self, ctx: &mut ExecCtx<R, E>) {
234        self.apply.delete(ctx)
235    }
236
237    fn sleep(&mut self, ctx: &mut ExecCtx<R, E>) {
238        self.apply.sleep(ctx);
239    }
240}
241
242#[derive(Debug)]
243pub(crate) struct Lambda {
244    top_id: ExprId,
245    spec: Expr,
246    def: Value,
247    flags: BitFlags<CFlag>,
248    typ: Type,
249}
250
251impl Lambda {
252    pub(crate) fn compile<R: Rt, E: UserEvent>(
253        ctx: &mut ExecCtx<R, E>,
254        flags: BitFlags<CFlag>,
255        spec: Expr,
256        scope: &Scope,
257        l: &expr::LambdaExpr,
258        top_id: ExprId,
259    ) -> Result<Node<R, E>> {
260        let mut s: LPooled<Vec<&ArcStr>> = LPooled::take();
261        for a in l.args.iter() {
262            a.pattern.with_names(&mut |n| s.push(n));
263        }
264        let len = s.len();
265        s.sort();
266        s.dedup();
267        if len != s.len() {
268            bail!("arguments must have unique names");
269        }
270        let id = LambdaId::new();
271        let vargs = match l.vargs.as_ref() {
272            None => None,
273            Some(None) => Some(None),
274            Some(Some(typ)) => Some(Some(typ.scope_refs(&scope.lexical))),
275        };
276        let rtype = l.rtype.as_ref().map(|t| t.scope_refs(&scope.lexical));
277        let throws = l.throws.as_ref().map(|t| t.scope_refs(&scope.lexical));
278        let mut argspec = l
279            .args
280            .iter()
281            .map(|a| match &a.constraint {
282                None => Arg {
283                    labeled: a.labeled.clone(),
284                    pattern: a.pattern.clone(),
285                    constraint: None,
286                },
287                Some(typ) => Arg {
288                    labeled: a.labeled.clone(),
289                    pattern: a.pattern.clone(),
290                    constraint: Some(typ.scope_refs(&scope.lexical)),
291                },
292            })
293            .collect::<LPooled<Vec<_>>>();
294        let argspec = Arc::from_iter(argspec.drain(..));
295        let constraints = l
296            .constraints
297            .iter()
298            .map(|(tv, tc)| {
299                let tv = tv.scope_refs(&scope.lexical);
300                let tc = tc.scope_refs(&scope.lexical);
301                Ok((tv, tc))
302            })
303            .collect::<Result<LPooled<Vec<_>>>>()?;
304        let constraints = Arc::new(RwLock::new(constraints));
305        let original_scope = scope.clone();
306        let _original_scope = scope.clone();
307        let scope = scope.append(&format_compact!("fn{}", id.0));
308        let _scope = scope.clone();
309        let env = ctx.env.clone();
310        let _env = ctx.env.clone();
311        let typ = match &l.body {
312            Either::Left(_) => {
313                let args = Arc::from_iter(argspec.iter().map(|a| FnArgType {
314                    label: a.labeled.as_ref().and_then(|dv| {
315                        a.pattern.single_bind().map(|n| (n.clone(), dv.is_some()))
316                    }),
317                    typ: match a.constraint.as_ref() {
318                        Some(t) => t.clone(),
319                        None => Type::empty_tvar(),
320                    },
321                }));
322                let vargs = match vargs {
323                    Some(Some(t)) => Some(t.clone()),
324                    Some(None) => Some(Type::empty_tvar()),
325                    None => None,
326                };
327                let rtype = rtype.clone().unwrap_or_else(|| Type::empty_tvar());
328                let explicit_throws = throws.is_some();
329                let throws = throws.clone().unwrap_or_else(|| Type::empty_tvar());
330                Arc::new(FnType {
331                    constraints,
332                    args,
333                    vargs,
334                    rtype,
335                    throws,
336                    explicit_throws,
337                })
338            }
339            Either::Right(builtin) => match ctx.builtins.get(builtin.as_str()) {
340                None => bail!("unknown builtin function {builtin}"),
341                Some((styp, _)) => {
342                    if !ctx.builtins_allowed {
343                        bail!("defining builtins is not allowed in this context")
344                    }
345                    Arc::new(styp.scope_refs(&_original_scope.lexical))
346                }
347            },
348        };
349        typ.alias_tvars(&mut LPooled::take());
350        let _typ = typ.clone();
351        let _argspec = argspec.clone();
352        let body = l.body.clone();
353        let init: InitFn<R, E> = SArc::new(move |scope, ctx, args, tid, tcerr| {
354            // restore the lexical environment to the state it was in
355            // when the closure was created
356            ctx.with_restored(_env.clone(), |ctx| match body.clone() {
357                Either::Left(body) => {
358                    let scope = Scope {
359                        dynamic: scope.dynamic.clone(),
360                        lexical: _scope.lexical.clone(),
361                    };
362                    let apply = GXLambda::new(
363                        ctx,
364                        flags,
365                        _typ.clone(),
366                        _argspec.clone(),
367                        args,
368                        &scope,
369                        tid,
370                        body.clone(),
371                    );
372                    apply.and_then(|a| {
373                        let mut f: Box<dyn Apply<R, E>> = Box::new(a);
374                        if tcerr {
375                            f.typecheck(ctx, args).map(|()| f)
376                        } else {
377                            let _ = f.typecheck(ctx, args);
378                            Ok(f)
379                        }
380                    })
381                }
382                Either::Right(builtin) => match ctx.builtins.get(&*builtin) {
383                    None => bail!("unknown builtin function {builtin}"),
384                    Some((_, init)) => {
385                        init(ctx, &_typ, &_scope, args, tid).and_then(|apply| {
386                            let mut f: Box<dyn Apply<R, E>> =
387                                Box::new(BuiltInLambda { typ: _typ.clone(), apply });
388                            if tcerr {
389                                f.typecheck(ctx, args).map(|()| f)
390                            } else {
391                                let _ = f.typecheck(ctx, args);
392                                Ok(f)
393                            }
394                        })
395                    }
396                },
397            })
398        });
399        let def = ctx.lambdawrap.wrap(LambdaDef {
400            id,
401            typ: typ.clone(),
402            env,
403            argspec,
404            init,
405            scope: original_scope,
406        });
407        Ok(Box::new(Self { spec, def, typ: Type::Fn(typ), top_id, flags }))
408    }
409}
410
411impl<R: Rt, E: UserEvent> Update<R, E> for Lambda {
412    fn update(
413        &mut self,
414        _ctx: &mut ExecCtx<R, E>,
415        event: &mut Event<E>,
416    ) -> Option<Value> {
417        event.init.then(|| self.def.clone())
418    }
419
420    fn spec(&self) -> &Expr {
421        &self.spec
422    }
423
424    fn refs(&self, _refs: &mut Refs) {}
425
426    fn delete(&mut self, _ctx: &mut ExecCtx<R, E>) {}
427
428    fn sleep(&mut self, _ctx: &mut ExecCtx<R, E>) {}
429
430    fn typ(&self) -> &Type {
431        &self.typ
432    }
433
434    fn typecheck(&mut self, ctx: &mut ExecCtx<R, E>) -> Result<()> {
435        let def = self
436            .def
437            .downcast_ref::<LambdaDef<R, E>>()
438            .ok_or_else(|| anyhow!("failed to unwrap lambda"))?;
439        let mut faux_args: LPooled<Vec<Node<R, E>>> = def
440            .argspec
441            .iter()
442            .zip(def.typ.args.iter())
443            .map(|(a, at)| match &a.labeled {
444                Some(Some(e)) => ctx.with_restored(def.env.clone(), |ctx| {
445                    compile(ctx, self.flags, e.clone(), &def.scope, self.top_id)
446                }),
447                Some(None) | None => {
448                    let n: Node<R, E> = Box::new(Nop { typ: at.typ.clone() });
449                    Ok(n)
450                }
451            })
452            .collect::<Result<_>>()?;
453        let faux_id = BindId::new();
454        ctx.env.by_id.insert_cow(
455            faux_id,
456            Bind {
457                doc: None,
458                export: false,
459                id: faux_id,
460                name: "faux".into(),
461                scope: def.scope.lexical.clone(),
462                typ: Type::empty_tvar(),
463            },
464        );
465        let prev_catch = ctx.env.catch.insert_cow(def.scope.dynamic.clone(), faux_id);
466        let res = (def.init)(&def.scope, ctx, &mut faux_args, ExprId::new(), true)
467            .with_context(|| ErrorContext(Update::<R, E>::spec(self).clone()));
468        let res = res.and_then(|mut f| {
469            let ftyp = f.typ().clone();
470            f.delete(ctx);
471            let inferred_throws = ctx.env.by_id[&faux_id]
472                .typ
473                .with_deref(|t| t.cloned())
474                .unwrap_or(Type::Bottom)
475                .scope_refs(&def.scope.lexical)
476                .normalize();
477            ftyp.throws
478                .check_contains(&ctx.env, &inferred_throws)
479                .with_context(|| ErrorContext(Update::<R, E>::spec(self).clone()))?;
480            ftyp.constrain_known();
481            Ok(())
482        });
483        ctx.env.by_id.remove_cow(&faux_id);
484        match prev_catch {
485            Some(id) => ctx.env.catch.insert_cow(def.scope.dynamic.clone(), id),
486            None => ctx.env.catch.remove_cow(&def.scope.dynamic),
487        };
488        self.typ.unbind_tvars();
489        res
490    }
491}