Skip to main content

graphix_compiler/node/
lambda.rs

1use super::{compiler::compile, Nop};
2use crate::{
3    env::{Bind, Env},
4    expr::{self, Arg, ErrorContext, Expr, ExprId},
5    node::pattern::StructPatternNode,
6    typ::{FnArgType, FnType, Type},
7    wrap, Apply, BindId, CFlag, Event, ExecCtx, InitFn, LambdaId, Node, Refs, Rt, Scope,
8    TypecheckPhase, Update, UserEvent,
9};
10use anyhow::{anyhow, bail, Context, Result};
11use arcstr::ArcStr;
12use compact_str::format_compact;
13use enumflags2::BitFlags;
14use fxhash::FxHashSet;
15use netidx::{pack::Pack, subscriber::Value, utils::Either};
16use parking_lot::{Mutex, RwLock};
17use poolshark::local::LPooled;
18use std::{fmt, hash::Hash, sync::Arc as SArc};
19use triomphe::Arc;
20
21pub struct LambdaDef<R: Rt, E: UserEvent> {
22    pub id: LambdaId,
23    pub env: Env,
24    pub scope: Scope,
25    pub argspec: Arc<[Arg]>,
26    pub typ: Arc<FnType>,
27    pub init: InitFn<R, E>,
28    pub needs_callsite: bool,
29    pub check: Mutex<Option<Box<dyn Apply<R, E>>>>,
30}
31
32impl<R: Rt, E: UserEvent> fmt::Debug for LambdaDef<R, E> {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        write!(f, "lambda#{}", self.id.inner())
35    }
36}
37
38impl<R: Rt, E: UserEvent> PartialEq for LambdaDef<R, E> {
39    fn eq(&self, other: &Self) -> bool {
40        self.id == other.id
41    }
42}
43
44impl<R: Rt, E: UserEvent> Eq for LambdaDef<R, E> {}
45
46impl<R: Rt, E: UserEvent> PartialOrd for LambdaDef<R, E> {
47    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
48        Some(self.id.cmp(&other.id))
49    }
50}
51
52impl<R: Rt, E: UserEvent> Ord for LambdaDef<R, E> {
53    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
54        self.id.cmp(&other.id)
55    }
56}
57
58impl<R: Rt, E: UserEvent> Hash for LambdaDef<R, E> {
59    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
60        self.id.hash(state)
61    }
62}
63
64impl<R: Rt, E: UserEvent> Pack for LambdaDef<R, E> {
65    fn encoded_len(&self) -> usize {
66        0
67    }
68
69    fn encode(
70        &self,
71        _buf: &mut impl bytes::BufMut,
72    ) -> std::result::Result<(), netidx::pack::PackError> {
73        Err(netidx::pack::PackError::Application(0))
74    }
75
76    fn decode(
77        _buf: &mut impl bytes::Buf,
78    ) -> std::result::Result<Self, netidx::pack::PackError> {
79        Err(netidx::pack::PackError::Application(0))
80    }
81}
82
83#[derive(Debug)]
84struct GXLambda<R: Rt, E: UserEvent> {
85    args: Box<[StructPatternNode]>,
86    body: Node<R, E>,
87    typ: Arc<FnType>,
88}
89
90impl<R: Rt, E: UserEvent> Apply<R, E> for GXLambda<R, E> {
91    fn update(
92        &mut self,
93        ctx: &mut ExecCtx<R, E>,
94        from: &mut [Node<R, E>],
95        event: &mut Event<E>,
96    ) -> Option<Value> {
97        for (arg, pat) in from.iter_mut().zip(&self.args) {
98            if let Some(v) = arg.update(ctx, event) {
99                pat.bind(&v, &mut |id, v| {
100                    ctx.cached.insert(id, v.clone());
101                    event.variables.insert(id, v);
102                })
103            }
104        }
105        self.body.update(ctx, event)
106    }
107
108    fn typecheck(
109        &mut self,
110        ctx: &mut ExecCtx<R, E>,
111        args: &mut [Node<R, E>],
112        _phase: TypecheckPhase<'_>,
113    ) -> Result<()> {
114        for (arg, FnArgType { typ, .. }) in args.iter_mut().zip(self.typ.args.iter()) {
115            wrap!(arg, arg.typecheck(ctx))?;
116            wrap!(arg, typ.check_contains(&ctx.env, &arg.typ()))?;
117        }
118        wrap!(self.body, self.body.typecheck(ctx))?;
119        wrap!(self.body, self.typ.rtype.check_contains(&ctx.env, &self.body.typ()))?;
120        for (tv, tc) in self.typ.constraints.read().iter() {
121            tc.check_contains(&ctx.env, &Type::TVar(tv.clone()))?
122        }
123        Ok(())
124    }
125
126    fn typ(&self) -> Arc<FnType> {
127        Arc::clone(&self.typ)
128    }
129
130    fn refs(&self, refs: &mut Refs) {
131        for pat in &self.args {
132            pat.ids(&mut |id| {
133                refs.bound.insert(id);
134            })
135        }
136        self.body.refs(refs)
137    }
138
139    fn delete(&mut self, ctx: &mut ExecCtx<R, E>) {
140        self.body.delete(ctx);
141        for n in &self.args {
142            n.delete(ctx)
143        }
144    }
145
146    fn sleep(&mut self, ctx: &mut ExecCtx<R, E>) {
147        self.body.sleep(ctx);
148    }
149}
150
151impl<R: Rt, E: UserEvent> GXLambda<R, E> {
152    pub(super) fn new(
153        ctx: &mut ExecCtx<R, E>,
154        flags: BitFlags<CFlag>,
155        typ: Arc<FnType>,
156        argspec: Arc<[Arg]>,
157        args: &[Node<R, E>],
158        scope: &Scope,
159        tid: ExprId,
160        body: Expr,
161    ) -> Result<Self> {
162        if args.len() != argspec.len() {
163            bail!("arity mismatch, expected {} arguments", argspec.len())
164        }
165        let mut argpats = vec![];
166        for (a, atyp) in argspec.iter().zip(typ.args.iter()) {
167            let pattern = StructPatternNode::compile(ctx, &atyp.typ, &a.pattern, scope)?;
168            if pattern.is_refutable() {
169                bail!(
170                    "refutable patterns are not allowed in lambda arguments {}",
171                    a.pattern
172                )
173            }
174            argpats.push(pattern);
175        }
176        let body = compile(ctx, flags, body, &scope, tid)?;
177        Ok(Self { args: Box::from(argpats), typ, body })
178    }
179}
180
181#[derive(Debug)]
182struct BuiltInLambda<R: Rt, E: UserEvent> {
183    typ: Arc<FnType>,
184    apply: Box<dyn Apply<R, E> + Send + Sync + 'static>,
185}
186
187impl<R: Rt, E: UserEvent> Apply<R, E> for BuiltInLambda<R, E> {
188    fn update(
189        &mut self,
190        ctx: &mut ExecCtx<R, E>,
191        from: &mut [Node<R, E>],
192        event: &mut Event<E>,
193    ) -> Option<Value> {
194        self.apply.update(ctx, from, event)
195    }
196
197    fn typecheck(
198        &mut self,
199        ctx: &mut ExecCtx<R, E>,
200        args: &mut [Node<R, E>],
201        phase: TypecheckPhase<'_>,
202    ) -> Result<()> {
203        match &phase {
204            TypecheckPhase::CallSite(_) => (),
205            TypecheckPhase::Lambda => {
206                if args.len() < self.typ.args.len()
207                    || (args.len() > self.typ.args.len() && self.typ.vargs.is_none())
208                {
209                    let vargs = if self.typ.vargs.is_some() { "at least " } else { "" };
210                    bail!(
211                        "expected {}{} arguments got {}",
212                        vargs,
213                        self.typ.args.len(),
214                        args.len()
215                    )
216                }
217                for i in 0..args.len() {
218                    wrap!(args[i], args[i].typecheck(ctx))?;
219                    let atyp = if i < self.typ.args.len() {
220                        &self.typ.args[i].typ
221                    } else {
222                        self.typ.vargs.as_ref().unwrap()
223                    };
224                    wrap!(args[i], atyp.check_contains(&ctx.env, &args[i].typ()))?
225                }
226                for (tv, tc) in self.typ.constraints.read().iter() {
227                    tc.check_contains(&ctx.env, &Type::TVar(tv.clone()))?
228                }
229            }
230        }
231        self.apply.typecheck(ctx, args, phase)
232    }
233
234    fn typ(&self) -> Arc<FnType> {
235        Arc::clone(&self.typ)
236    }
237
238    fn refs(&self, refs: &mut Refs) {
239        self.apply.refs(refs)
240    }
241
242    fn delete(&mut self, ctx: &mut ExecCtx<R, E>) {
243        self.apply.delete(ctx)
244    }
245
246    fn sleep(&mut self, ctx: &mut ExecCtx<R, E>) {
247        self.apply.sleep(ctx);
248    }
249}
250
251#[derive(Debug)]
252pub(crate) struct Lambda {
253    top_id: ExprId,
254    spec: Expr,
255    def: Value,
256    flags: BitFlags<CFlag>,
257    typ: Type,
258}
259
260impl Lambda {
261    pub(crate) fn compile<R: Rt, E: UserEvent>(
262        ctx: &mut ExecCtx<R, E>,
263        flags: BitFlags<CFlag>,
264        spec: Expr,
265        scope: &Scope,
266        l: &expr::LambdaExpr,
267        top_id: ExprId,
268    ) -> Result<Node<R, E>> {
269        let mut s: LPooled<Vec<&ArcStr>> = LPooled::take();
270        for a in l.args.iter() {
271            a.pattern.with_names(&mut |n| s.push(n));
272        }
273        let len = s.len();
274        s.sort();
275        s.dedup();
276        if len != s.len() {
277            bail!("arguments must have unique names");
278        }
279        let id = LambdaId::new();
280        let vargs = match l.vargs.as_ref() {
281            None => None,
282            Some(None) => Some(None),
283            Some(Some(typ)) => Some(Some(typ.scope_refs(&scope.lexical))),
284        };
285        let rtype = l.rtype.as_ref().map(|t| t.scope_refs(&scope.lexical));
286        let throws = l.throws.as_ref().map(|t| t.scope_refs(&scope.lexical));
287        let mut argspec = l
288            .args
289            .iter()
290            .map(|a| match &a.constraint {
291                None => Arg {
292                    labeled: a.labeled.clone(),
293                    pattern: a.pattern.clone(),
294                    constraint: None,
295                },
296                Some(typ) => Arg {
297                    labeled: a.labeled.clone(),
298                    pattern: a.pattern.clone(),
299                    constraint: Some(typ.scope_refs(&scope.lexical)),
300                },
301            })
302            .collect::<LPooled<Vec<_>>>();
303        let argspec = Arc::from_iter(argspec.drain(..));
304        let constraints = l
305            .constraints
306            .iter()
307            .map(|(tv, tc)| {
308                let tv = tv.scope_refs(&scope.lexical);
309                let tc = tc.scope_refs(&scope.lexical);
310                Ok((tv, tc))
311            })
312            .collect::<Result<LPooled<Vec<_>>>>()?;
313        let constraints = Arc::new(RwLock::new(constraints));
314        let original_scope = scope.clone();
315        let _original_scope = scope.clone();
316        let scope = scope.append(&format_compact!("fn{}", id.0));
317        let _scope = scope.clone();
318        let env = ctx.env.clone();
319        let _env = ctx.env.clone();
320        let mut needs_callsite = false;
321        if let Either::Right(builtin) = &l.body {
322            if let Some((_, nc)) = ctx.builtins.get(builtin.as_str()) {
323                needs_callsite = *nc;
324            } else {
325                bail!("unknown builtin function {builtin}")
326            }
327            if !ctx.builtins_allowed {
328                bail!("defining builtins is not allowed in this context")
329            }
330            for a in argspec.iter() {
331                if a.constraint.is_none() {
332                    bail!("builtin function {builtin} requires all arguments to have type annotations")
333                }
334            }
335            if rtype.is_none() {
336                bail!("builtin function {builtin} requires a return type annotation")
337            }
338        }
339        let typ = {
340            let args = Arc::from_iter(argspec.iter().map(|a| FnArgType {
341                label: a.labeled.as_ref().and_then(|dv| {
342                    a.pattern.single_bind().map(|n| (n.clone(), dv.is_some()))
343                }),
344                typ: match a.constraint.as_ref() {
345                    Some(t) => t.clone(),
346                    None => Type::empty_tvar(),
347                },
348            }));
349            let vargs = match vargs {
350                Some(Some(t)) => Some(t.clone()),
351                Some(None) => Some(Type::empty_tvar()),
352                None => None,
353            };
354            let rtype = rtype.clone().unwrap_or_else(|| Type::empty_tvar());
355            let explicit_throws = throws.is_some();
356            let throws = throws.clone().unwrap_or_else(|| Type::empty_tvar());
357            Arc::new(FnType {
358                constraints,
359                args,
360                vargs,
361                rtype,
362                throws,
363                explicit_throws,
364                lambda_ids: Arc::new(RwLock::new(FxHashSet::default())),
365            })
366        };
367        typ.alias_tvars(&mut LPooled::take());
368        if needs_callsite {
369            typ.lambda_ids.write().insert(id);
370        }
371        let _typ = typ.clone();
372        let _argspec = argspec.clone();
373        let body = l.body.clone();
374        let init: InitFn<R, E> = SArc::new(move |scope, ctx, args, resolved, tid| {
375            // restore the lexical environment to the state it was in
376            // when the closure was created
377            ctx.with_restored(_env.clone(), |ctx| match body.clone() {
378                Either::Left(body) => {
379                    let scope = Scope {
380                        dynamic: scope.dynamic.clone(),
381                        lexical: _scope.lexical.clone(),
382                    };
383                    GXLambda::new(
384                        ctx,
385                        flags,
386                        _typ.clone(),
387                        _argspec.clone(),
388                        args,
389                        &scope,
390                        tid,
391                        body.clone(),
392                    )
393                    .map(|a| -> Box<dyn Apply<R, E>> { Box::new(a) })
394                }
395                Either::Right(builtin) => match ctx.builtins.get(&*builtin) {
396                    None => bail!("unknown builtin function {builtin}"),
397                    Some((init, _)) => init(ctx, &_typ, resolved, &_scope, args, tid)
398                        .map(|apply| {
399                            let f: Box<dyn Apply<R, E>> =
400                                Box::new(BuiltInLambda { typ: _typ.clone(), apply });
401                            f
402                        }),
403                },
404            })
405        });
406        let def = ctx.lambdawrap.wrap(LambdaDef {
407            id,
408            typ: typ.clone(),
409            env,
410            argspec,
411            init,
412            scope: original_scope,
413            needs_callsite,
414            check: Mutex::new(None),
415        });
416        ctx.lambda_defs.insert(id, def.clone());
417        Ok(Box::new(Self { spec, def, typ: Type::Fn(typ), top_id, flags }))
418    }
419}
420
421impl<R: Rt, E: UserEvent> Update<R, E> for Lambda {
422    fn update(
423        &mut self,
424        _ctx: &mut ExecCtx<R, E>,
425        event: &mut Event<E>,
426    ) -> Option<Value> {
427        event.init.then(|| self.def.clone())
428    }
429
430    fn spec(&self) -> &Expr {
431        &self.spec
432    }
433
434    fn refs(&self, _refs: &mut Refs) {}
435
436    fn delete(&mut self, _ctx: &mut ExecCtx<R, E>) {}
437
438    fn sleep(&mut self, _ctx: &mut ExecCtx<R, E>) {}
439
440    fn typ(&self) -> &Type {
441        &self.typ
442    }
443
444    fn typecheck(&mut self, ctx: &mut ExecCtx<R, E>) -> Result<()> {
445        let def = self
446            .def
447            .downcast_ref::<LambdaDef<R, E>>()
448            .ok_or_else(|| anyhow!("failed to unwrap lambda"))?;
449        let needs_callsite = def.needs_callsite;
450        let mut faux_args: LPooled<Vec<Node<R, E>>> = def
451            .argspec
452            .iter()
453            .zip(def.typ.args.iter())
454            .map(|(a, at)| match &a.labeled {
455                Some(Some(e)) => ctx.with_restored(def.env.clone(), |ctx| {
456                    compile(ctx, self.flags, e.clone(), &def.scope, self.top_id)
457                }),
458                Some(None) | None => {
459                    let n: Node<R, E> = Box::new(Nop { typ: at.typ.clone() });
460                    Ok(n)
461                }
462            })
463            .collect::<Result<_>>()?;
464        let faux_id = BindId::new();
465        ctx.env.by_id.insert_cow(
466            faux_id,
467            Bind {
468                doc: None,
469                export: false,
470                id: faux_id,
471                name: "faux".into(),
472                scope: def.scope.lexical.clone(),
473                typ: Type::empty_tvar(),
474            },
475        );
476        let prev_catch = ctx.env.catch.insert_cow(def.scope.dynamic.clone(), faux_id);
477        let res = (def.init)(&def.scope, ctx, &mut faux_args, None, ExprId::new())
478            .with_context(|| ErrorContext(Update::<R, E>::spec(self).clone()));
479        let res = res.and_then(|mut f| {
480            let ftyp = f.typ().clone();
481            let res = f
482                .typecheck(ctx, &mut faux_args, TypecheckPhase::Lambda)
483                .with_context(|| ErrorContext(Update::<R, E>::spec(self).clone()));
484            if !needs_callsite {
485                f.delete(ctx)
486            } else {
487                let def = self
488                    .def
489                    .downcast_ref::<LambdaDef<R, E>>()
490                    .expect("failed to unwrap lambda");
491                *def.check.lock() = Some(f);
492            }
493            res?;
494            let inferred_throws = ctx.env.by_id[&faux_id]
495                .typ
496                .with_deref(|t| t.cloned())
497                .unwrap_or(Type::Bottom)
498                .scope_refs(&def.scope.lexical)
499                .normalize();
500            ftyp.throws
501                .check_contains(&ctx.env, &inferred_throws)
502                .with_context(|| ErrorContext(Update::<R, E>::spec(self).clone()))?;
503            ftyp.constrain_known();
504            Ok(())
505        });
506        ctx.env.by_id.remove_cow(&faux_id);
507        match prev_catch {
508            Some(id) => ctx.env.catch.insert_cow(def.scope.dynamic.clone(), id),
509            None => ctx.env.catch.remove_cow(&def.scope.dynamic),
510        };
511        self.typ.unbind_tvars();
512        res
513    }
514}