Skip to main content

graphix_compiler/node/
lambda.rs

1use super::{compiler::compile, Nop};
2use crate::{
3    env::{Bind, Env},
4    expr::{self, Arg, ErrorContext, Expr, ExprId, Origin},
5    node::pattern::StructPatternNode,
6    typ::{FnArgKind, FnArgType, FnType, Type},
7    wrap, Apply, BindId, CFlag, Event, ExecCtx, InitFn, LambdaId, Node, Refs, Rt, Scope,
8    TypecheckPhase, Update, UserEvent,
9};
10use anyhow::{anyhow, bail, Context, Result};
11use arcstr::ArcStr;
12use combine::stream::position::SourcePosition;
13use compact_str::format_compact;
14use enumflags2::BitFlags;
15use netidx::{pack::Pack, subscriber::Value, utils::Either};
16use nohash::IntSet;
17use parking_lot::{Mutex, RwLock};
18use poolshark::local::LPooled;
19use std::{fmt, hash::Hash, sync::Arc as SArc};
20use triomphe::Arc;
21
/// The runtime definition of a lambda: everything needed to instantiate the
/// function at a call site.
///
/// Equality, ordering, and hashing of a `LambdaDef` are by `id` only.
pub struct LambdaDef<R: Rt, E: UserEvent> {
    /// Unique identity of this lambda; `Debug` renders it as `lambda#<id>`.
    pub id: LambdaId,
    /// Snapshot of `ctx.env` taken when the lambda expression was compiled.
    pub env: Env,
    /// The scope the lambda expression was defined in (taken before the
    /// per-lambda `fn<id>` segment was appended).
    pub scope: Scope,
    /// Argument specification, with type constraints resolved against the
    /// defining lexical scope.
    pub argspec: Arc<[Arg]>,
    /// The lambda's function type.
    pub typ: Arc<FnType>,
    /// Builds the `Apply` implementation for a call site.
    pub init: InitFn<R, E>,
    /// Copied from the builtin registry entry for builtin lambdas (always
    /// false otherwise); when set, the `Apply` built while typechecking the
    /// lambda is retained in `check` instead of being deleted.
    pub needs_callsite: bool,
    /// The `Apply` instance retained from lambda typechecking when
    /// `needs_callsite` is set.
    pub check: Mutex<Option<Box<dyn Apply<R, E>>>>,
}
32
33impl<R: Rt, E: UserEvent> fmt::Debug for LambdaDef<R, E> {
34    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35        write!(f, "lambda#{}", self.id.inner())
36    }
37}
38
39impl<R: Rt, E: UserEvent> PartialEq for LambdaDef<R, E> {
40    fn eq(&self, other: &Self) -> bool {
41        self.id == other.id
42    }
43}
44
45impl<R: Rt, E: UserEvent> Eq for LambdaDef<R, E> {}
46
47impl<R: Rt, E: UserEvent> PartialOrd for LambdaDef<R, E> {
48    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
49        Some(self.id.cmp(&other.id))
50    }
51}
52
53impl<R: Rt, E: UserEvent> Ord for LambdaDef<R, E> {
54    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
55        self.id.cmp(&other.id)
56    }
57}
58
59impl<R: Rt, E: UserEvent> Hash for LambdaDef<R, E> {
60    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
61        self.id.hash(state)
62    }
63}
64
65impl<R: Rt, E: UserEvent> Pack for LambdaDef<R, E> {
66    fn encoded_len(&self) -> usize {
67        0
68    }
69
70    fn encode(
71        &self,
72        _buf: &mut impl bytes::BufMut,
73    ) -> std::result::Result<(), netidx::pack::PackError> {
74        Err(netidx::pack::PackError::Application(0))
75    }
76
77    fn decode(
78        _buf: &mut impl bytes::Buf,
79    ) -> std::result::Result<Self, netidx::pack::PackError> {
80        Err(netidx::pack::PackError::Application(0))
81    }
82}
83
/// A user-defined (non-builtin) lambda instantiated for one call site.
#[derive(Debug)]
struct GXLambda<R: Rt, E: UserEvent> {
    /// One irrefutable pattern per declared argument, used to bind incoming
    /// argument values.
    args: Box<[StructPatternNode]>,
    /// The compiled lambda body.
    body: Node<R, E>,
    /// The lambda's function type.
    typ: Arc<FnType>,
}
90
91impl<R: Rt, E: UserEvent> Apply<R, E> for GXLambda<R, E> {
92    fn update(
93        &mut self,
94        ctx: &mut ExecCtx<R, E>,
95        from: &mut [Node<R, E>],
96        event: &mut Event<E>,
97    ) -> Option<Value> {
98        for (arg, pat) in from.iter_mut().zip(&self.args) {
99            if let Some(v) = arg.update(ctx, event) {
100                pat.bind(&v, &mut |id, v| {
101                    ctx.cached.insert(id, v.clone());
102                    event.variables.insert(id, v);
103                })
104            }
105        }
106        self.body.update(ctx, event)
107    }
108
109    fn typecheck(
110        &mut self,
111        ctx: &mut ExecCtx<R, E>,
112        args: &mut [Node<R, E>],
113        _phase: TypecheckPhase<'_>,
114    ) -> Result<()> {
115        for (arg, FnArgType { typ, .. }) in args.iter_mut().zip(self.typ.args.iter()) {
116            wrap!(arg, arg.typecheck(ctx))?;
117            wrap!(arg, typ.check_contains(&ctx.env, &arg.typ()))?;
118        }
119        wrap!(self.body, self.body.typecheck(ctx))?;
120        wrap!(self.body, self.typ.rtype.check_contains(&ctx.env, &self.body.typ()))?;
121        for (tv, tc) in self.typ.constraints.read().iter() {
122            tc.check_contains(&ctx.env, &Type::TVar(tv.clone()))?
123        }
124        Ok(())
125    }
126
127    fn typ(&self) -> Arc<FnType> {
128        Arc::clone(&self.typ)
129    }
130
131    fn refs(&self, refs: &mut Refs) {
132        for pat in &self.args {
133            pat.ids(&mut |id| {
134                refs.bound.insert(id);
135            })
136        }
137        self.body.refs(refs)
138    }
139
140    fn delete(&mut self, ctx: &mut ExecCtx<R, E>) {
141        self.body.delete(ctx);
142        for n in &self.args {
143            n.delete(ctx)
144        }
145    }
146
147    fn sleep(&mut self, ctx: &mut ExecCtx<R, E>) {
148        self.body.sleep(ctx);
149    }
150}
151
152impl<R: Rt, E: UserEvent> GXLambda<R, E> {
153    pub(super) fn new(
154        ctx: &mut ExecCtx<R, E>,
155        flags: BitFlags<CFlag>,
156        typ: Arc<FnType>,
157        argspec: Arc<[Arg]>,
158        args: &[Node<R, E>],
159        scope: &Scope,
160        tid: ExprId,
161        body: Expr,
162    ) -> Result<Self> {
163        if args.len() != argspec.len() {
164            bail!("arity mismatch, expected {} arguments", argspec.len())
165        }
166        let mut argpats = vec![];
167        for (a, atyp) in argspec.iter().zip(typ.args.iter()) {
168            let pattern = StructPatternNode::compile(
169                ctx,
170                &atyp.typ,
171                &a.pattern,
172                scope,
173                a.pos,
174                body.ori.clone(),
175            )?;
176            if pattern.is_refutable() {
177                bail!(
178                    "refutable patterns are not allowed in lambda arguments {}",
179                    a.pattern
180                )
181            }
182            argpats.push(pattern);
183        }
184        let body = compile(ctx, flags, body, &scope, tid)?;
185        Ok(Self { args: Box::from(argpats), typ, body })
186    }
187}
188
189#[derive(Debug)]
190struct BuiltInLambda<R: Rt, E: UserEvent> {
191    typ: Arc<FnType>,
192    apply: Box<dyn Apply<R, E> + Send + Sync + 'static>,
193}
194
195impl<R: Rt, E: UserEvent> Apply<R, E> for BuiltInLambda<R, E> {
196    fn update(
197        &mut self,
198        ctx: &mut ExecCtx<R, E>,
199        from: &mut [Node<R, E>],
200        event: &mut Event<E>,
201    ) -> Option<Value> {
202        self.apply.update(ctx, from, event)
203    }
204
205    fn typecheck(
206        &mut self,
207        ctx: &mut ExecCtx<R, E>,
208        args: &mut [Node<R, E>],
209        phase: TypecheckPhase<'_>,
210    ) -> Result<()> {
211        match &phase {
212            TypecheckPhase::CallSite(_) => (),
213            TypecheckPhase::Lambda => {
214                if args.len() < self.typ.args.len()
215                    || (args.len() > self.typ.args.len() && self.typ.vargs.is_none())
216                {
217                    let vargs = if self.typ.vargs.is_some() { "at least " } else { "" };
218                    bail!(
219                        "expected {}{} arguments got {}",
220                        vargs,
221                        self.typ.args.len(),
222                        args.len()
223                    )
224                }
225                for i in 0..args.len() {
226                    wrap!(args[i], args[i].typecheck(ctx))?;
227                    let atyp = if i < self.typ.args.len() {
228                        &self.typ.args[i].typ
229                    } else {
230                        self.typ.vargs.as_ref().unwrap()
231                    };
232                    wrap!(args[i], atyp.check_contains(&ctx.env, &args[i].typ()))?
233                }
234                for (tv, tc) in self.typ.constraints.read().iter() {
235                    tc.check_contains(&ctx.env, &Type::TVar(tv.clone()))?
236                }
237            }
238        }
239        self.apply.typecheck(ctx, args, phase)
240    }
241
242    fn typ(&self) -> Arc<FnType> {
243        Arc::clone(&self.typ)
244    }
245
246    fn refs(&self, refs: &mut Refs) {
247        self.apply.refs(refs)
248    }
249
250    fn delete(&mut self, ctx: &mut ExecCtx<R, E>) {
251        self.apply.delete(ctx)
252    }
253
254    fn sleep(&mut self, ctx: &mut ExecCtx<R, E>) {
255        self.apply.sleep(ctx);
256    }
257}
258
259#[derive(Debug)]
260pub(crate) struct Lambda {
261    top_id: ExprId,
262    spec: Expr,
263    def: Value,
264    flags: BitFlags<CFlag>,
265    typ: Type,
266}
267
268impl Lambda {
269    pub(crate) fn compile<R: Rt, E: UserEvent>(
270        ctx: &mut ExecCtx<R, E>,
271        flags: BitFlags<CFlag>,
272        spec: Expr,
273        scope: &Scope,
274        l: &expr::LambdaExpr,
275        top_id: ExprId,
276    ) -> Result<Node<R, E>> {
277        let mut s: LPooled<Vec<&ArcStr>> = LPooled::take();
278        for a in l.args.iter() {
279            a.pattern.with_names(&mut |n| s.push(n));
280        }
281        let len = s.len();
282        s.sort();
283        s.dedup();
284        if len != s.len() {
285            bail!("arguments must have unique names");
286        }
287        let id = LambdaId::new();
288        let vargs = match l.vargs.as_ref() {
289            None => None,
290            Some(None) => Some(None),
291            Some(Some(typ)) => Some(Some(typ.scope_refs(&scope.lexical))),
292        };
293        let rtype = l.rtype.as_ref().map(|t| t.scope_refs(&scope.lexical));
294        let throws = l.throws.as_ref().map(|t| t.scope_refs(&scope.lexical));
295        let mut argspec = l
296            .args
297            .iter()
298            .map(|a| match &a.constraint {
299                None => Arg {
300                    labeled: a.labeled.clone(),
301                    pattern: a.pattern.clone(),
302                    constraint: None,
303                    pos: a.pos,
304                },
305                Some(typ) => Arg {
306                    labeled: a.labeled.clone(),
307                    pattern: a.pattern.clone(),
308                    constraint: Some(typ.scope_refs(&scope.lexical)),
309                    pos: a.pos,
310                },
311            })
312            .collect::<LPooled<Vec<_>>>();
313        let argspec = Arc::from_iter(argspec.drain(..));
314        let constraints = l
315            .constraints
316            .iter()
317            .map(|(tv, tc)| {
318                let tv = tv.scope_refs(&scope.lexical);
319                let tc = tc.scope_refs(&scope.lexical);
320                Ok((tv, tc))
321            })
322            .collect::<Result<LPooled<Vec<_>>>>()?;
323        let constraints = Arc::new(RwLock::new(constraints));
324        let original_scope = scope.clone();
325        let _original_scope = scope.clone();
326        let scope = scope.append(&format_compact!("fn{}", id.0));
327        let _scope = scope.clone();
328        let env = ctx.env.clone();
329        let _env = ctx.env.clone();
330        let mut needs_callsite = false;
331        if let Either::Right(builtin) = &l.body {
332            if let Some((_, nc)) = ctx.builtins.get(builtin.as_str()) {
333                needs_callsite = *nc;
334            } else {
335                bail!("unknown builtin function {builtin}")
336            }
337            if !ctx.builtins_allowed {
338                bail!("defining builtins is not allowed in this context")
339            }
340            for a in argspec.iter() {
341                if a.constraint.is_none() {
342                    bail!("builtin function {builtin} requires all arguments to have type annotations")
343                }
344            }
345            if rtype.is_none() {
346                bail!("builtin function {builtin} requires a return type annotation")
347            }
348        }
349        let typ = {
350            let args = Arc::from_iter(argspec.iter().map(|a| {
351                let kind = match (a.labeled.as_ref(), a.pattern.single_bind()) {
352                    (Some(default), Some(name)) => FnArgKind::Labeled {
353                        name: name.clone(),
354                        has_default: default.is_some(),
355                    },
356                    (Some(_), None) => FnArgKind::Positional { name: None },
357                    (None, name) => FnArgKind::Positional { name: name.cloned() },
358                };
359                let typ = match a.constraint.as_ref() {
360                    Some(t) => t.clone(),
361                    None => Type::empty_tvar(),
362                };
363                FnArgType { kind, typ }
364            }));
365            let vargs = match vargs {
366                Some(Some(t)) => Some(t.clone()),
367                Some(None) => Some(Type::empty_tvar()),
368                None => None,
369            };
370            let rtype = rtype.clone().unwrap_or_else(|| Type::empty_tvar());
371            let explicit_throws = throws.is_some();
372            let throws = throws.clone().unwrap_or_else(|| Type::empty_tvar());
373            Arc::new(FnType {
374                constraints,
375                args,
376                vargs,
377                rtype,
378                throws,
379                explicit_throws,
380                lambda_ids: Arc::new(RwLock::new(IntSet::default())),
381            })
382        };
383        typ.alias_tvars(&mut LPooled::take());
384        if needs_callsite || ctx.env.lsp_mode {
385            typ.lambda_ids.write().insert(id);
386        }
387        let _typ = typ.clone();
388        let _argspec = argspec.clone();
389        let body = l.body.clone();
390        let init: InitFn<R, E> = SArc::new(move |scope, ctx, args, resolved, tid| {
391            // restore the lexical environment to the state it was in
392            // when the closure was created
393            ctx.with_restored(_env.clone(), |ctx| match body.clone() {
394                Either::Left(body) => {
395                    let scope = Scope {
396                        dynamic: scope.dynamic.clone(),
397                        lexical: _scope.lexical.clone(),
398                    };
399                    GXLambda::new(
400                        ctx,
401                        flags,
402                        _typ.clone(),
403                        _argspec.clone(),
404                        args,
405                        &scope,
406                        tid,
407                        body.clone(),
408                    )
409                    .map(|a| -> Box<dyn Apply<R, E>> { Box::new(a) })
410                }
411                Either::Right(builtin) => match ctx.builtins.get(&*builtin) {
412                    None => bail!("unknown builtin function {builtin}"),
413                    Some((init, _)) => init(ctx, &_typ, resolved, &_scope, args, tid)
414                        .map(|apply| {
415                            let f: Box<dyn Apply<R, E>> =
416                                Box::new(BuiltInLambda { typ: _typ.clone(), apply });
417                            f
418                        }),
419                },
420            })
421        });
422        let def = ctx.lambdawrap.wrap(LambdaDef {
423            id,
424            typ: typ.clone(),
425            env,
426            argspec,
427            init,
428            scope: original_scope,
429            needs_callsite,
430            check: Mutex::new(None),
431        });
432        ctx.lambda_defs.insert(id, def.clone());
433        Ok(Box::new(Self { spec, def, typ: Type::Fn(typ), top_id, flags }))
434    }
435}
436
437impl<R: Rt, E: UserEvent> Update<R, E> for Lambda {
438    fn update(
439        &mut self,
440        _ctx: &mut ExecCtx<R, E>,
441        event: &mut Event<E>,
442    ) -> Option<Value> {
443        event.init.then(|| self.def.clone())
444    }
445
446    fn spec(&self) -> &Expr {
447        &self.spec
448    }
449
450    fn refs(&self, _refs: &mut Refs) {}
451
452    fn delete(&mut self, _ctx: &mut ExecCtx<R, E>) {}
453
454    fn sleep(&mut self, _ctx: &mut ExecCtx<R, E>) {}
455
456    fn typ(&self) -> &Type {
457        &self.typ
458    }
459
460    fn typecheck(&mut self, ctx: &mut ExecCtx<R, E>) -> Result<()> {
461        let def = self
462            .def
463            .downcast_ref::<LambdaDef<R, E>>()
464            .ok_or_else(|| anyhow!("failed to unwrap lambda"))?;
465        let needs_callsite = def.needs_callsite;
466        let mut faux_args: LPooled<Vec<Node<R, E>>> = def
467            .argspec
468            .iter()
469            .zip(def.typ.args.iter())
470            .map(|(a, at)| match &a.labeled {
471                Some(Some(e)) => ctx.with_restored(def.env.clone(), |ctx| {
472                    compile(ctx, self.flags, e.clone(), &def.scope, self.top_id)
473                }),
474                Some(None) | None => {
475                    let n: Node<R, E> = Box::new(Nop { typ: at.typ.clone() });
476                    Ok(n)
477                }
478            })
479            .collect::<Result<_>>()?;
480        let faux_id = BindId::new();
481        ctx.env.by_id.insert_cow(
482            faux_id,
483            Bind {
484                doc: None,
485                export: false,
486                id: faux_id,
487                name: "faux".into(),
488                scope: def.scope.lexical.clone(),
489                typ: Type::empty_tvar(),
490                pos: SourcePosition::default(),
491                ori: Arc::new(Origin::default()),
492            },
493        );
494        let prev_catch = ctx.env.catch.insert_cow(def.scope.dynamic.clone(), faux_id);
495        let res = (def.init)(&def.scope, ctx, &mut faux_args, None, ExprId::new())
496            .with_context(|| ErrorContext(Update::<R, E>::spec(self).clone()));
497        let res = res.and_then(|mut f| {
498            let ftyp = f.typ().clone();
499            let res = f
500                .typecheck(ctx, &mut faux_args, TypecheckPhase::Lambda)
501                .with_context(|| ErrorContext(Update::<R, E>::spec(self).clone()));
502            if !needs_callsite {
503                f.delete(ctx)
504            } else {
505                let def = self
506                    .def
507                    .downcast_ref::<LambdaDef<R, E>>()
508                    .expect("failed to unwrap lambda");
509                *def.check.lock() = Some(f);
510            }
511            res?;
512            let inferred_throws = ctx.env.by_id[&faux_id]
513                .typ
514                .with_deref(|t| t.cloned())
515                .unwrap_or(Type::Bottom)
516                .scope_refs(&def.scope.lexical)
517                .normalize();
518            ftyp.throws
519                .check_contains(&ctx.env, &inferred_throws)
520                .with_context(|| ErrorContext(Update::<R, E>::spec(self).clone()))?;
521            ftyp.constrain_known();
522            Ok(())
523        });
524        ctx.env.by_id.remove_cow(&faux_id);
525        match prev_catch {
526            Some(id) => ctx.env.catch.insert_cow(def.scope.dynamic.clone(), id),
527            None => ctx.env.catch.remove_cow(&def.scope.dynamic),
528        };
529        self.typ.unbind_tvars();
530        res
531    }
532}