Skip to main content

graphix_compiler/node/
lambda.rs

1use super::{compiler::compile, Nop};
2use crate::{
3    env::{Bind, Env},
4    expr::{self, Arg, ErrorContext, Expr, ExprId},
5    node::pattern::StructPatternNode,
6    typ::{FnArgType, FnType, Type},
7    wrap, Apply, BindId, CFlag, Event, ExecCtx, InitFn, LambdaId, Node, Refs, Rt, Scope,
8    Update, UserEvent,
9};
10use anyhow::{anyhow, bail, Context, Result};
11use arcstr::ArcStr;
12use compact_str::format_compact;
13use enumflags2::BitFlags;
14use netidx::{pack::Pack, subscriber::Value, utils::Either};
15use parking_lot::RwLock;
16use poolshark::local::LPooled;
17use std::{fmt, hash::Hash, sync::Arc as SArc};
18use triomphe::Arc;
19
20pub struct LambdaDef<R: Rt, E: UserEvent> {
21    pub id: LambdaId,
22    pub env: Env,
23    pub scope: Scope,
24    pub argspec: Arc<[Arg]>,
25    pub typ: Arc<FnType>,
26    pub init: InitFn<R, E>,
27}
28
29impl<R: Rt, E: UserEvent> fmt::Debug for LambdaDef<R, E> {
30    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31        write!(f, "lambda#{}", self.id.inner())
32    }
33}
34
35impl<R: Rt, E: UserEvent> PartialEq for LambdaDef<R, E> {
36    fn eq(&self, other: &Self) -> bool {
37        self.id == other.id
38    }
39}
40
41impl<R: Rt, E: UserEvent> Eq for LambdaDef<R, E> {}
42
43impl<R: Rt, E: UserEvent> PartialOrd for LambdaDef<R, E> {
44    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
45        Some(self.id.cmp(&other.id))
46    }
47}
48
49impl<R: Rt, E: UserEvent> Ord for LambdaDef<R, E> {
50    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
51        self.id.cmp(&other.id)
52    }
53}
54
55impl<R: Rt, E: UserEvent> Hash for LambdaDef<R, E> {
56    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
57        self.id.hash(state)
58    }
59}
60
61impl<R: Rt, E: UserEvent> Pack for LambdaDef<R, E> {
62    fn encoded_len(&self) -> usize {
63        0
64    }
65
66    fn encode(
67        &self,
68        _buf: &mut impl bytes::BufMut,
69    ) -> std::result::Result<(), netidx::pack::PackError> {
70        Err(netidx::pack::PackError::Application(0))
71    }
72
73    fn decode(
74        _buf: &mut impl bytes::Buf,
75    ) -> std::result::Result<Self, netidx::pack::PackError> {
76        Err(netidx::pack::PackError::Application(0))
77    }
78}
79
80#[derive(Debug)]
81struct GXLambda<R: Rt, E: UserEvent> {
82    args: Box<[StructPatternNode]>,
83    body: Node<R, E>,
84    typ: Arc<FnType>,
85}
86
87impl<R: Rt, E: UserEvent> Apply<R, E> for GXLambda<R, E> {
88    fn update(
89        &mut self,
90        ctx: &mut ExecCtx<R, E>,
91        from: &mut [Node<R, E>],
92        event: &mut Event<E>,
93    ) -> Option<Value> {
94        for (arg, pat) in from.iter_mut().zip(&self.args) {
95            if let Some(v) = arg.update(ctx, event) {
96                pat.bind(&v, &mut |id, v| {
97                    ctx.cached.insert(id, v.clone());
98                    event.variables.insert(id, v);
99                })
100            }
101        }
102        self.body.update(ctx, event)
103    }
104
105    fn typecheck(
106        &mut self,
107        ctx: &mut ExecCtx<R, E>,
108        args: &mut [Node<R, E>],
109    ) -> Result<()> {
110        for (arg, FnArgType { typ, .. }) in args.iter_mut().zip(self.typ.args.iter()) {
111            wrap!(arg, arg.typecheck(ctx))?;
112            wrap!(arg, typ.check_contains(&ctx.env, &arg.typ()))?;
113        }
114        wrap!(self.body, self.body.typecheck(ctx))?;
115        wrap!(self.body, self.typ.rtype.check_contains(&ctx.env, &self.body.typ()))?;
116        for (tv, tc) in self.typ.constraints.read().iter() {
117            tc.check_contains(&ctx.env, &Type::TVar(tv.clone()))?
118        }
119        Ok(())
120    }
121
122    fn typ(&self) -> Arc<FnType> {
123        Arc::clone(&self.typ)
124    }
125
126    fn refs(&self, refs: &mut Refs) {
127        for pat in &self.args {
128            pat.ids(&mut |id| {
129                refs.bound.insert(id);
130            })
131        }
132        self.body.refs(refs)
133    }
134
135    fn delete(&mut self, ctx: &mut ExecCtx<R, E>) {
136        self.body.delete(ctx);
137        for n in &self.args {
138            n.delete(ctx)
139        }
140    }
141
142    fn sleep(&mut self, ctx: &mut ExecCtx<R, E>) {
143        self.body.sleep(ctx);
144    }
145}
146
147impl<R: Rt, E: UserEvent> GXLambda<R, E> {
148    pub(super) fn new(
149        ctx: &mut ExecCtx<R, E>,
150        flags: BitFlags<CFlag>,
151        typ: Arc<FnType>,
152        argspec: Arc<[Arg]>,
153        args: &[Node<R, E>],
154        scope: &Scope,
155        tid: ExprId,
156        body: Expr,
157    ) -> Result<Self> {
158        if args.len() != argspec.len() {
159            bail!("arity mismatch, expected {} arguments", argspec.len())
160        }
161        let mut argpats = vec![];
162        for (a, atyp) in argspec.iter().zip(typ.args.iter()) {
163            let pattern = StructPatternNode::compile(ctx, &atyp.typ, &a.pattern, scope)?;
164            if pattern.is_refutable() {
165                bail!(
166                    "refutable patterns are not allowed in lambda arguments {}",
167                    a.pattern
168                )
169            }
170            argpats.push(pattern);
171        }
172        let body = compile(ctx, flags, body, &scope, tid)?;
173        Ok(Self { args: Box::from(argpats), typ, body })
174    }
175}
176
177#[derive(Debug)]
178struct BuiltInLambda<R: Rt, E: UserEvent> {
179    typ: Arc<FnType>,
180    apply: Box<dyn Apply<R, E> + Send + Sync + 'static>,
181}
182
183impl<R: Rt, E: UserEvent> Apply<R, E> for BuiltInLambda<R, E> {
184    fn update(
185        &mut self,
186        ctx: &mut ExecCtx<R, E>,
187        from: &mut [Node<R, E>],
188        event: &mut Event<E>,
189    ) -> Option<Value> {
190        self.apply.update(ctx, from, event)
191    }
192
193    fn typecheck(
194        &mut self,
195        ctx: &mut ExecCtx<R, E>,
196        args: &mut [Node<R, E>],
197    ) -> Result<()> {
198        if args.len() < self.typ.args.len()
199            || (args.len() > self.typ.args.len() && self.typ.vargs.is_none())
200        {
201            let vargs = if self.typ.vargs.is_some() { "at least " } else { "" };
202            bail!(
203                "expected {}{} arguments got {}",
204                vargs,
205                self.typ.args.len(),
206                args.len()
207            )
208        }
209        for i in 0..args.len() {
210            wrap!(args[i], args[i].typecheck(ctx))?;
211            let atyp = if i < self.typ.args.len() {
212                &self.typ.args[i].typ
213            } else {
214                self.typ.vargs.as_ref().unwrap()
215            };
216            wrap!(args[i], atyp.check_contains(&ctx.env, &args[i].typ()))?
217        }
218        for (tv, tc) in self.typ.constraints.read().iter() {
219            tc.check_contains(&ctx.env, &Type::TVar(tv.clone()))?
220        }
221        self.apply.typecheck(ctx, args)?;
222        Ok(())
223    }
224
225    fn typ(&self) -> Arc<FnType> {
226        Arc::clone(&self.typ)
227    }
228
229    fn refs(&self, refs: &mut Refs) {
230        self.apply.refs(refs)
231    }
232
233    fn delete(&mut self, ctx: &mut ExecCtx<R, E>) {
234        self.apply.delete(ctx)
235    }
236
237    fn sleep(&mut self, ctx: &mut ExecCtx<R, E>) {
238        self.apply.sleep(ctx);
239    }
240}
241
242#[derive(Debug)]
243pub(crate) struct Lambda {
244    top_id: ExprId,
245    spec: Expr,
246    def: Value,
247    flags: BitFlags<CFlag>,
248    typ: Type,
249}
250
251impl Lambda {
252    pub(crate) fn compile<R: Rt, E: UserEvent>(
253        ctx: &mut ExecCtx<R, E>,
254        flags: BitFlags<CFlag>,
255        spec: Expr,
256        scope: &Scope,
257        l: &expr::LambdaExpr,
258        top_id: ExprId,
259    ) -> Result<Node<R, E>> {
260        let mut s: LPooled<Vec<&ArcStr>> = LPooled::take();
261        for a in l.args.iter() {
262            a.pattern.with_names(&mut |n| s.push(n));
263        }
264        let len = s.len();
265        s.sort();
266        s.dedup();
267        if len != s.len() {
268            bail!("arguments must have unique names");
269        }
270        let id = LambdaId::new();
271        let vargs = match l.vargs.as_ref() {
272            None => None,
273            Some(None) => Some(None),
274            Some(Some(typ)) => Some(Some(typ.scope_refs(&scope.lexical))),
275        };
276        let rtype = l.rtype.as_ref().map(|t| t.scope_refs(&scope.lexical));
277        let throws = l.throws.as_ref().map(|t| t.scope_refs(&scope.lexical));
278        let mut argspec = l
279            .args
280            .iter()
281            .map(|a| match &a.constraint {
282                None => Arg {
283                    labeled: a.labeled.clone(),
284                    pattern: a.pattern.clone(),
285                    constraint: None,
286                },
287                Some(typ) => Arg {
288                    labeled: a.labeled.clone(),
289                    pattern: a.pattern.clone(),
290                    constraint: Some(typ.scope_refs(&scope.lexical)),
291                },
292            })
293            .collect::<LPooled<Vec<_>>>();
294        let argspec = Arc::from_iter(argspec.drain(..));
295        let constraints = l
296            .constraints
297            .iter()
298            .map(|(tv, tc)| {
299                let tv = tv.scope_refs(&scope.lexical);
300                let tc = tc.scope_refs(&scope.lexical);
301                Ok((tv, tc))
302            })
303            .collect::<Result<LPooled<Vec<_>>>>()?;
304        let constraints = Arc::new(RwLock::new(constraints));
305        let original_scope = scope.clone();
306        let _original_scope = scope.clone();
307        let scope = scope.append(&format_compact!("fn{}", id.0));
308        let _scope = scope.clone();
309        let env = ctx.env.clone();
310        let _env = ctx.env.clone();
311        let typ = match &l.body {
312            Either::Left(_) => {
313                let args = Arc::from_iter(argspec.iter().map(|a| FnArgType {
314                    label: a.labeled.as_ref().and_then(|dv| {
315                        a.pattern.single_bind().map(|n| (n.clone(), dv.is_some()))
316                    }),
317                    typ: match a.constraint.as_ref() {
318                        Some(t) => t.clone(),
319                        None => Type::empty_tvar(),
320                    },
321                }));
322                let vargs = match vargs {
323                    Some(Some(t)) => Some(t.clone()),
324                    Some(None) => Some(Type::empty_tvar()),
325                    None => None,
326                };
327                let rtype = rtype.clone().unwrap_or_else(|| Type::empty_tvar());
328                let explicit_throws = throws.is_some();
329                let throws = throws.clone().unwrap_or_else(|| Type::empty_tvar());
330                Arc::new(FnType {
331                    constraints,
332                    args,
333                    vargs,
334                    rtype,
335                    throws,
336                    explicit_throws,
337                })
338            }
339            Either::Right(builtin) => match ctx.builtins.get(builtin.as_str()) {
340                None => bail!("unknown builtin function {builtin}"),
341                Some((styp, _)) => {
342                    if !ctx.builtins_allowed {
343                        bail!("defining builtins is not allowed in this context")
344                    }
345                    Arc::new(styp.clone().scope_refs(&_original_scope.lexical))
346                }
347            },
348        };
349        typ.alias_tvars(&mut LPooled::take());
350        let _typ = typ.clone();
351        let _argspec = argspec.clone();
352        let body = l.body.clone();
353        let init: InitFn<R, E> = SArc::new(move |scope, ctx, args, tid, tcerr| {
354            // restore the lexical environment to the state it was in
355            // when the closure was created
356            ctx.with_restored(_env.clone(), |ctx| match body.clone() {
357                Either::Left(body) => {
358                    let scope = Scope {
359                        dynamic: scope.dynamic.clone(),
360                        lexical: _scope.lexical.clone(),
361                    };
362                    let apply = GXLambda::new(
363                        ctx,
364                        flags,
365                        _typ.clone(),
366                        _argspec.clone(),
367                        args,
368                        &scope,
369                        tid,
370                        body.clone(),
371                    );
372                    apply.and_then(|a| {
373                        let mut f: Box<dyn Apply<R, E>> = Box::new(a);
374                        if tcerr {
375                            f.typecheck(ctx, args).map(|()| f)
376                        } else {
377                            let _ = f.typecheck(ctx, args);
378                            Ok(f)
379                        }
380                    })
381                }
382                Either::Right(builtin) => match ctx.builtins.get(&*builtin) {
383                    None => bail!("unknown builtin function {builtin}"),
384                    Some((_, init)) => {
385                        let init = SArc::clone(init);
386                        init(ctx, &_typ, &_scope, args, tid).and_then(|apply| {
387                            let mut f: Box<dyn Apply<R, E>> =
388                                Box::new(BuiltInLambda { typ: _typ.clone(), apply });
389                            if tcerr {
390                                f.typecheck(ctx, args).map(|()| f)
391                            } else {
392                                let _ = f.typecheck(ctx, args);
393                                Ok(f)
394                            }
395                        })
396                    }
397                },
398            })
399        });
400        let def = ctx.lambdawrap.wrap(LambdaDef {
401            id,
402            typ: typ.clone(),
403            env,
404            argspec,
405            init,
406            scope: original_scope,
407        });
408        Ok(Box::new(Self { spec, def, typ: Type::Fn(typ), top_id, flags }))
409    }
410}
411
412impl<R: Rt, E: UserEvent> Update<R, E> for Lambda {
413    fn update(
414        &mut self,
415        _ctx: &mut ExecCtx<R, E>,
416        event: &mut Event<E>,
417    ) -> Option<Value> {
418        event.init.then(|| self.def.clone())
419    }
420
421    fn spec(&self) -> &Expr {
422        &self.spec
423    }
424
425    fn refs(&self, _refs: &mut Refs) {}
426
427    fn delete(&mut self, _ctx: &mut ExecCtx<R, E>) {}
428
429    fn sleep(&mut self, _ctx: &mut ExecCtx<R, E>) {}
430
431    fn typ(&self) -> &Type {
432        &self.typ
433    }
434
435    fn typecheck(&mut self, ctx: &mut ExecCtx<R, E>) -> Result<()> {
436        let def = self
437            .def
438            .downcast_ref::<LambdaDef<R, E>>()
439            .ok_or_else(|| anyhow!("failed to unwrap lambda"))?;
440        let mut faux_args: LPooled<Vec<Node<R, E>>> = def
441            .argspec
442            .iter()
443            .zip(def.typ.args.iter())
444            .map(|(a, at)| match &a.labeled {
445                Some(Some(e)) => ctx.with_restored(def.env.clone(), |ctx| {
446                    compile(ctx, self.flags, e.clone(), &def.scope, self.top_id)
447                }),
448                Some(None) | None => {
449                    let n: Node<R, E> = Box::new(Nop { typ: at.typ.clone() });
450                    Ok(n)
451                }
452            })
453            .collect::<Result<_>>()?;
454        let faux_id = BindId::new();
455        ctx.env.by_id.insert_cow(
456            faux_id,
457            Bind {
458                doc: None,
459                export: false,
460                id: faux_id,
461                name: "faux".into(),
462                scope: def.scope.lexical.clone(),
463                typ: Type::empty_tvar(),
464            },
465        );
466        let prev_catch = ctx.env.catch.insert_cow(def.scope.dynamic.clone(), faux_id);
467        let res = (def.init)(&def.scope, ctx, &mut faux_args, ExprId::new(), true)
468            .with_context(|| ErrorContext(Update::<R, E>::spec(self).clone()));
469        let res = res.and_then(|mut f| {
470            let ftyp = f.typ().clone();
471            f.delete(ctx);
472            let inferred_throws = ctx.env.by_id[&faux_id]
473                .typ
474                .with_deref(|t| t.cloned())
475                .unwrap_or(Type::Bottom)
476                .scope_refs(&def.scope.lexical)
477                .normalize();
478            ftyp.throws
479                .check_contains(&ctx.env, &inferred_throws)
480                .with_context(|| ErrorContext(Update::<R, E>::spec(self).clone()))?;
481            ftyp.constrain_known();
482            Ok(())
483        });
484        ctx.env.by_id.remove_cow(&faux_id);
485        match prev_catch {
486            Some(id) => ctx.env.catch.insert_cow(def.scope.dynamic.clone(), id),
487            None => ctx.env.catch.remove_cow(&def.scope.dynamic),
488        };
489        self.typ.unbind_tvars();
490        res
491    }
492}