1use super::{compiler::compile, Nop};
2use crate::{
3 env::{Bind, Env},
4 expr::{self, Arg, ErrorContext, Expr, ExprId},
5 node::pattern::StructPatternNode,
6 typ::{FnArgType, FnType, Type},
7 wrap, Apply, BindId, CFlag, Event, ExecCtx, InitFn, LambdaId, Node, Refs, Rt, Scope,
8 TypecheckPhase, Update, UserEvent,
9};
10use anyhow::{anyhow, bail, Context, Result};
11use arcstr::ArcStr;
12use compact_str::format_compact;
13use enumflags2::BitFlags;
14use fxhash::FxHashSet;
15use netidx::{pack::Pack, subscriber::Value, utils::Either};
16use parking_lot::{Mutex, RwLock};
17use poolshark::local::LPooled;
18use std::{fmt, hash::Hash, sync::Arc as SArc};
19use triomphe::Arc;
20
/// A compiled lambda definition.
///
/// It is wrapped into a `Value` by `Lambda::compile` (via `ctx.lambdawrap`)
/// and registered in `ctx.lambda_defs`. Identity, ordering and hashing are
/// determined solely by `id`.
pub struct LambdaDef<R: Rt, E: UserEvent> {
    /// Unique id of this lambda; the only field used by `Eq`/`Ord`/`Hash`.
    pub id: LambdaId,
    /// Snapshot of `ctx.env` taken when the lambda was compiled; default
    /// argument expressions are compiled in this environment during typecheck.
    pub env: Env,
    /// The scope the lambda expression was written in (not the body's child scope).
    pub scope: Scope,
    /// Argument specifications with type annotations resolved to the defining scope.
    pub argspec: Arc<[Arg]>,
    /// The lambda's function type.
    pub typ: Arc<FnType>,
    /// Closure that instantiates the lambda's body (a compiled body or a builtin)
    /// as a `Box<dyn Apply>`.
    pub init: InitFn<R, E>,
    /// Taken from the builtin registry entry; `false` for non-builtin lambdas.
    pub needs_callsite: bool,
    /// When `needs_callsite` is set, holds the `Apply` instance that was
    /// typechecked in the `Lambda` phase by `Lambda::typecheck`.
    pub check: Mutex<Option<Box<dyn Apply<R, E>>>>,
}
31
32impl<R: Rt, E: UserEvent> fmt::Debug for LambdaDef<R, E> {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 write!(f, "lambda#{}", self.id.inner())
35 }
36}
37
38impl<R: Rt, E: UserEvent> PartialEq for LambdaDef<R, E> {
39 fn eq(&self, other: &Self) -> bool {
40 self.id == other.id
41 }
42}
43
44impl<R: Rt, E: UserEvent> Eq for LambdaDef<R, E> {}
45
46impl<R: Rt, E: UserEvent> PartialOrd for LambdaDef<R, E> {
47 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
48 Some(self.id.cmp(&other.id))
49 }
50}
51
52impl<R: Rt, E: UserEvent> Ord for LambdaDef<R, E> {
53 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
54 self.id.cmp(&other.id)
55 }
56}
57
58impl<R: Rt, E: UserEvent> Hash for LambdaDef<R, E> {
59 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
60 self.id.hash(state)
61 }
62}
63
64impl<R: Rt, E: UserEvent> Pack for LambdaDef<R, E> {
65 fn encoded_len(&self) -> usize {
66 0
67 }
68
69 fn encode(
70 &self,
71 _buf: &mut impl bytes::BufMut,
72 ) -> std::result::Result<(), netidx::pack::PackError> {
73 Err(netidx::pack::PackError::Application(0))
74 }
75
76 fn decode(
77 _buf: &mut impl bytes::Buf,
78 ) -> std::result::Result<Self, netidx::pack::PackError> {
79 Err(netidx::pack::PackError::Application(0))
80 }
81}
82
/// A callable instance of a lambda whose body is a compiled expression
/// (as opposed to a builtin).
#[derive(Debug)]
struct GXLambda<R: Rt, E: UserEvent> {
    /// One irrefutable pattern per declared argument; argument values are bound through these.
    args: Box<[StructPatternNode]>,
    /// The compiled lambda body.
    body: Node<R, E>,
    /// The lambda's function type.
    typ: Arc<FnType>,
}
89
90impl<R: Rt, E: UserEvent> Apply<R, E> for GXLambda<R, E> {
91 fn update(
92 &mut self,
93 ctx: &mut ExecCtx<R, E>,
94 from: &mut [Node<R, E>],
95 event: &mut Event<E>,
96 ) -> Option<Value> {
97 for (arg, pat) in from.iter_mut().zip(&self.args) {
98 if let Some(v) = arg.update(ctx, event) {
99 pat.bind(&v, &mut |id, v| {
100 ctx.cached.insert(id, v.clone());
101 event.variables.insert(id, v);
102 })
103 }
104 }
105 self.body.update(ctx, event)
106 }
107
108 fn typecheck(
109 &mut self,
110 ctx: &mut ExecCtx<R, E>,
111 args: &mut [Node<R, E>],
112 _phase: TypecheckPhase<'_>,
113 ) -> Result<()> {
114 for (arg, FnArgType { typ, .. }) in args.iter_mut().zip(self.typ.args.iter()) {
115 wrap!(arg, arg.typecheck(ctx))?;
116 wrap!(arg, typ.check_contains(&ctx.env, &arg.typ()))?;
117 }
118 wrap!(self.body, self.body.typecheck(ctx))?;
119 wrap!(self.body, self.typ.rtype.check_contains(&ctx.env, &self.body.typ()))?;
120 for (tv, tc) in self.typ.constraints.read().iter() {
121 tc.check_contains(&ctx.env, &Type::TVar(tv.clone()))?
122 }
123 Ok(())
124 }
125
126 fn typ(&self) -> Arc<FnType> {
127 Arc::clone(&self.typ)
128 }
129
130 fn refs(&self, refs: &mut Refs) {
131 for pat in &self.args {
132 pat.ids(&mut |id| {
133 refs.bound.insert(id);
134 })
135 }
136 self.body.refs(refs)
137 }
138
139 fn delete(&mut self, ctx: &mut ExecCtx<R, E>) {
140 self.body.delete(ctx);
141 for n in &self.args {
142 n.delete(ctx)
143 }
144 }
145
146 fn sleep(&mut self, ctx: &mut ExecCtx<R, E>) {
147 self.body.sleep(ctx);
148 }
149}
150
151impl<R: Rt, E: UserEvent> GXLambda<R, E> {
152 pub(super) fn new(
153 ctx: &mut ExecCtx<R, E>,
154 flags: BitFlags<CFlag>,
155 typ: Arc<FnType>,
156 argspec: Arc<[Arg]>,
157 args: &[Node<R, E>],
158 scope: &Scope,
159 tid: ExprId,
160 body: Expr,
161 ) -> Result<Self> {
162 if args.len() != argspec.len() {
163 bail!("arity mismatch, expected {} arguments", argspec.len())
164 }
165 let mut argpats = vec![];
166 for (a, atyp) in argspec.iter().zip(typ.args.iter()) {
167 let pattern = StructPatternNode::compile(ctx, &atyp.typ, &a.pattern, scope)?;
168 if pattern.is_refutable() {
169 bail!(
170 "refutable patterns are not allowed in lambda arguments {}",
171 a.pattern
172 )
173 }
174 argpats.push(pattern);
175 }
176 let body = compile(ctx, flags, body, &scope, tid)?;
177 Ok(Self { args: Box::from(argpats), typ, body })
178 }
179}
180
/// A callable instance of a builtin lambda.
///
/// It adds lambda-phase argument checking around the `Apply` produced by the
/// builtin's registered init function.
#[derive(Debug)]
struct BuiltInLambda<R: Rt, E: UserEvent> {
    /// The lambda's declared function type.
    typ: Arc<FnType>,
    /// The builtin implementation; all runtime work is delegated to it.
    apply: Box<dyn Apply<R, E> + Send + Sync + 'static>,
}
186
187impl<R: Rt, E: UserEvent> Apply<R, E> for BuiltInLambda<R, E> {
188 fn update(
189 &mut self,
190 ctx: &mut ExecCtx<R, E>,
191 from: &mut [Node<R, E>],
192 event: &mut Event<E>,
193 ) -> Option<Value> {
194 self.apply.update(ctx, from, event)
195 }
196
197 fn typecheck(
198 &mut self,
199 ctx: &mut ExecCtx<R, E>,
200 args: &mut [Node<R, E>],
201 phase: TypecheckPhase<'_>,
202 ) -> Result<()> {
203 match &phase {
204 TypecheckPhase::CallSite(_) => (),
205 TypecheckPhase::Lambda => {
206 if args.len() < self.typ.args.len()
207 || (args.len() > self.typ.args.len() && self.typ.vargs.is_none())
208 {
209 let vargs = if self.typ.vargs.is_some() { "at least " } else { "" };
210 bail!(
211 "expected {}{} arguments got {}",
212 vargs,
213 self.typ.args.len(),
214 args.len()
215 )
216 }
217 for i in 0..args.len() {
218 wrap!(args[i], args[i].typecheck(ctx))?;
219 let atyp = if i < self.typ.args.len() {
220 &self.typ.args[i].typ
221 } else {
222 self.typ.vargs.as_ref().unwrap()
223 };
224 wrap!(args[i], atyp.check_contains(&ctx.env, &args[i].typ()))?
225 }
226 for (tv, tc) in self.typ.constraints.read().iter() {
227 tc.check_contains(&ctx.env, &Type::TVar(tv.clone()))?
228 }
229 }
230 }
231 self.apply.typecheck(ctx, args, phase)
232 }
233
234 fn typ(&self) -> Arc<FnType> {
235 Arc::clone(&self.typ)
236 }
237
238 fn refs(&self, refs: &mut Refs) {
239 self.apply.refs(refs)
240 }
241
242 fn delete(&mut self, ctx: &mut ExecCtx<R, E>) {
243 self.apply.delete(ctx)
244 }
245
246 fn sleep(&mut self, ctx: &mut ExecCtx<R, E>) {
247 self.apply.sleep(ctx);
248 }
249}
250
/// Graph node for a lambda expression.
///
/// On the init event it emits the wrapped lambda definition as its value.
#[derive(Debug)]
pub(crate) struct Lambda {
    /// Expression id passed when compiling default-argument expressions during typecheck.
    top_id: ExprId,
    /// The source expression this node was compiled from.
    spec: Expr,
    /// The `LambdaDef` wrapped into a `Value` by `ctx.lambdawrap`.
    def: Value,
    /// Compile flags the lambda was compiled with.
    flags: BitFlags<CFlag>,
    /// `Type::Fn` of the lambda's function type.
    typ: Type,
}
259
260impl Lambda {
261 pub(crate) fn compile<R: Rt, E: UserEvent>(
262 ctx: &mut ExecCtx<R, E>,
263 flags: BitFlags<CFlag>,
264 spec: Expr,
265 scope: &Scope,
266 l: &expr::LambdaExpr,
267 top_id: ExprId,
268 ) -> Result<Node<R, E>> {
269 let mut s: LPooled<Vec<&ArcStr>> = LPooled::take();
270 for a in l.args.iter() {
271 a.pattern.with_names(&mut |n| s.push(n));
272 }
273 let len = s.len();
274 s.sort();
275 s.dedup();
276 if len != s.len() {
277 bail!("arguments must have unique names");
278 }
279 let id = LambdaId::new();
280 let vargs = match l.vargs.as_ref() {
281 None => None,
282 Some(None) => Some(None),
283 Some(Some(typ)) => Some(Some(typ.scope_refs(&scope.lexical))),
284 };
285 let rtype = l.rtype.as_ref().map(|t| t.scope_refs(&scope.lexical));
286 let throws = l.throws.as_ref().map(|t| t.scope_refs(&scope.lexical));
287 let mut argspec = l
288 .args
289 .iter()
290 .map(|a| match &a.constraint {
291 None => Arg {
292 labeled: a.labeled.clone(),
293 pattern: a.pattern.clone(),
294 constraint: None,
295 },
296 Some(typ) => Arg {
297 labeled: a.labeled.clone(),
298 pattern: a.pattern.clone(),
299 constraint: Some(typ.scope_refs(&scope.lexical)),
300 },
301 })
302 .collect::<LPooled<Vec<_>>>();
303 let argspec = Arc::from_iter(argspec.drain(..));
304 let constraints = l
305 .constraints
306 .iter()
307 .map(|(tv, tc)| {
308 let tv = tv.scope_refs(&scope.lexical);
309 let tc = tc.scope_refs(&scope.lexical);
310 Ok((tv, tc))
311 })
312 .collect::<Result<LPooled<Vec<_>>>>()?;
313 let constraints = Arc::new(RwLock::new(constraints));
314 let original_scope = scope.clone();
315 let _original_scope = scope.clone();
316 let scope = scope.append(&format_compact!("fn{}", id.0));
317 let _scope = scope.clone();
318 let env = ctx.env.clone();
319 let _env = ctx.env.clone();
320 let mut needs_callsite = false;
321 if let Either::Right(builtin) = &l.body {
322 if let Some((_, nc)) = ctx.builtins.get(builtin.as_str()) {
323 needs_callsite = *nc;
324 } else {
325 bail!("unknown builtin function {builtin}")
326 }
327 if !ctx.builtins_allowed {
328 bail!("defining builtins is not allowed in this context")
329 }
330 for a in argspec.iter() {
331 if a.constraint.is_none() {
332 bail!("builtin function {builtin} requires all arguments to have type annotations")
333 }
334 }
335 if rtype.is_none() {
336 bail!("builtin function {builtin} requires a return type annotation")
337 }
338 }
339 let typ = {
340 let args = Arc::from_iter(argspec.iter().map(|a| FnArgType {
341 label: a.labeled.as_ref().and_then(|dv| {
342 a.pattern.single_bind().map(|n| (n.clone(), dv.is_some()))
343 }),
344 typ: match a.constraint.as_ref() {
345 Some(t) => t.clone(),
346 None => Type::empty_tvar(),
347 },
348 }));
349 let vargs = match vargs {
350 Some(Some(t)) => Some(t.clone()),
351 Some(None) => Some(Type::empty_tvar()),
352 None => None,
353 };
354 let rtype = rtype.clone().unwrap_or_else(|| Type::empty_tvar());
355 let explicit_throws = throws.is_some();
356 let throws = throws.clone().unwrap_or_else(|| Type::empty_tvar());
357 Arc::new(FnType {
358 constraints,
359 args,
360 vargs,
361 rtype,
362 throws,
363 explicit_throws,
364 lambda_ids: Arc::new(RwLock::new(FxHashSet::default())),
365 })
366 };
367 typ.alias_tvars(&mut LPooled::take());
368 if needs_callsite {
369 typ.lambda_ids.write().insert(id);
370 }
371 let _typ = typ.clone();
372 let _argspec = argspec.clone();
373 let body = l.body.clone();
374 let init: InitFn<R, E> = SArc::new(move |scope, ctx, args, resolved, tid| {
375 ctx.with_restored(_env.clone(), |ctx| match body.clone() {
378 Either::Left(body) => {
379 let scope = Scope {
380 dynamic: scope.dynamic.clone(),
381 lexical: _scope.lexical.clone(),
382 };
383 GXLambda::new(
384 ctx,
385 flags,
386 _typ.clone(),
387 _argspec.clone(),
388 args,
389 &scope,
390 tid,
391 body.clone(),
392 )
393 .map(|a| -> Box<dyn Apply<R, E>> { Box::new(a) })
394 }
395 Either::Right(builtin) => match ctx.builtins.get(&*builtin) {
396 None => bail!("unknown builtin function {builtin}"),
397 Some((init, _)) => init(ctx, &_typ, resolved, &_scope, args, tid)
398 .map(|apply| {
399 let f: Box<dyn Apply<R, E>> =
400 Box::new(BuiltInLambda { typ: _typ.clone(), apply });
401 f
402 }),
403 },
404 })
405 });
406 let def = ctx.lambdawrap.wrap(LambdaDef {
407 id,
408 typ: typ.clone(),
409 env,
410 argspec,
411 init,
412 scope: original_scope,
413 needs_callsite,
414 check: Mutex::new(None),
415 });
416 ctx.lambda_defs.insert(id, def.clone());
417 Ok(Box::new(Self { spec, def, typ: Type::Fn(typ), top_id, flags }))
418 }
419}
420
421impl<R: Rt, E: UserEvent> Update<R, E> for Lambda {
422 fn update(
423 &mut self,
424 _ctx: &mut ExecCtx<R, E>,
425 event: &mut Event<E>,
426 ) -> Option<Value> {
427 event.init.then(|| self.def.clone())
428 }
429
430 fn spec(&self) -> &Expr {
431 &self.spec
432 }
433
434 fn refs(&self, _refs: &mut Refs) {}
435
436 fn delete(&mut self, _ctx: &mut ExecCtx<R, E>) {}
437
438 fn sleep(&mut self, _ctx: &mut ExecCtx<R, E>) {}
439
440 fn typ(&self) -> &Type {
441 &self.typ
442 }
443
444 fn typecheck(&mut self, ctx: &mut ExecCtx<R, E>) -> Result<()> {
445 let def = self
446 .def
447 .downcast_ref::<LambdaDef<R, E>>()
448 .ok_or_else(|| anyhow!("failed to unwrap lambda"))?;
449 let needs_callsite = def.needs_callsite;
450 let mut faux_args: LPooled<Vec<Node<R, E>>> = def
451 .argspec
452 .iter()
453 .zip(def.typ.args.iter())
454 .map(|(a, at)| match &a.labeled {
455 Some(Some(e)) => ctx.with_restored(def.env.clone(), |ctx| {
456 compile(ctx, self.flags, e.clone(), &def.scope, self.top_id)
457 }),
458 Some(None) | None => {
459 let n: Node<R, E> = Box::new(Nop { typ: at.typ.clone() });
460 Ok(n)
461 }
462 })
463 .collect::<Result<_>>()?;
464 let faux_id = BindId::new();
465 ctx.env.by_id.insert_cow(
466 faux_id,
467 Bind {
468 doc: None,
469 export: false,
470 id: faux_id,
471 name: "faux".into(),
472 scope: def.scope.lexical.clone(),
473 typ: Type::empty_tvar(),
474 },
475 );
476 let prev_catch = ctx.env.catch.insert_cow(def.scope.dynamic.clone(), faux_id);
477 let res = (def.init)(&def.scope, ctx, &mut faux_args, None, ExprId::new())
478 .with_context(|| ErrorContext(Update::<R, E>::spec(self).clone()));
479 let res = res.and_then(|mut f| {
480 let ftyp = f.typ().clone();
481 let res = f
482 .typecheck(ctx, &mut faux_args, TypecheckPhase::Lambda)
483 .with_context(|| ErrorContext(Update::<R, E>::spec(self).clone()));
484 if !needs_callsite {
485 f.delete(ctx)
486 } else {
487 let def = self
488 .def
489 .downcast_ref::<LambdaDef<R, E>>()
490 .expect("failed to unwrap lambda");
491 *def.check.lock() = Some(f);
492 }
493 res?;
494 let inferred_throws = ctx.env.by_id[&faux_id]
495 .typ
496 .with_deref(|t| t.cloned())
497 .unwrap_or(Type::Bottom)
498 .scope_refs(&def.scope.lexical)
499 .normalize();
500 ftyp.throws
501 .check_contains(&ctx.env, &inferred_throws)
502 .with_context(|| ErrorContext(Update::<R, E>::spec(self).clone()))?;
503 ftyp.constrain_known();
504 Ok(())
505 });
506 ctx.env.by_id.remove_cow(&faux_id);
507 match prev_catch {
508 Some(id) => ctx.env.catch.insert_cow(def.scope.dynamic.clone(), id),
509 None => ctx.env.catch.remove_cow(&def.scope.dynamic),
510 };
511 self.typ.unbind_tvars();
512 res
513 }
514}