1use super::{compiler::compile, Nop};
2use crate::{
3 env::{Bind, Env},
4 expr::{self, Arg, ErrorContext, Expr, ExprId, Origin},
5 node::pattern::StructPatternNode,
6 typ::{FnArgKind, FnArgType, FnType, Type},
7 wrap, Apply, BindId, CFlag, Event, ExecCtx, InitFn, LambdaId, Node, Refs, Rt, Scope,
8 TypecheckPhase, Update, UserEvent,
9};
10use anyhow::{anyhow, bail, Context, Result};
11use arcstr::ArcStr;
12use combine::stream::position::SourcePosition;
13use compact_str::format_compact;
14use enumflags2::BitFlags;
15use netidx::{pack::Pack, subscriber::Value, utils::Either};
16use nohash::IntSet;
17use parking_lot::{Mutex, RwLock};
18use poolshark::local::LPooled;
19use std::{fmt, hash::Hash, sync::Arc as SArc};
20use triomphe::Arc;
21
/// A compiled lambda definition. `Lambda::compile` wraps it into a `Value`
/// via `ctx.lambdawrap` and registers it in `ctx.lambda_defs`.
///
/// Equality, ordering and hashing consider only `id`.
pub struct LambdaDef<R: Rt, E: UserEvent> {
    /// Identifier allocated when the lambda expression was compiled.
    pub id: LambdaId,
    /// Copy of the compiler environment at the point of definition; default
    /// argument expressions are compiled against it during typechecking.
    pub env: Env,
    /// Scope in which the lambda was defined (not the per-lambda `fnN`
    /// child scope the body is compiled in).
    pub scope: Scope,
    /// Argument specifications, with type annotations resolved against the
    /// defining lexical scope.
    pub argspec: Arc<[Arg]>,
    /// The lambda's function type.
    pub typ: Arc<FnType>,
    /// Instantiates an `Apply` for the lambda: compiles the body for
    /// non-builtin lambdas, or calls the registered initializer for builtins.
    pub init: InitFn<R, E>,
    /// Flag taken from the builtins table; when set, the `Apply` created
    /// during definition-time typechecking is kept in `check` instead of
    /// being deleted.
    pub needs_callsite: bool,
    /// The `Apply` retained from definition-time typechecking (only
    /// populated when `needs_callsite` is set).
    pub check: Mutex<Option<Box<dyn Apply<R, E>>>>,
}
32
33impl<R: Rt, E: UserEvent> fmt::Debug for LambdaDef<R, E> {
34 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35 write!(f, "lambda#{}", self.id.inner())
36 }
37}
38
39impl<R: Rt, E: UserEvent> PartialEq for LambdaDef<R, E> {
40 fn eq(&self, other: &Self) -> bool {
41 self.id == other.id
42 }
43}
44
45impl<R: Rt, E: UserEvent> Eq for LambdaDef<R, E> {}
46
47impl<R: Rt, E: UserEvent> PartialOrd for LambdaDef<R, E> {
48 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
49 Some(self.id.cmp(&other.id))
50 }
51}
52
53impl<R: Rt, E: UserEvent> Ord for LambdaDef<R, E> {
54 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
55 self.id.cmp(&other.id)
56 }
57}
58
59impl<R: Rt, E: UserEvent> Hash for LambdaDef<R, E> {
60 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
61 self.id.hash(state)
62 }
63}
64
65impl<R: Rt, E: UserEvent> Pack for LambdaDef<R, E> {
66 fn encoded_len(&self) -> usize {
67 0
68 }
69
70 fn encode(
71 &self,
72 _buf: &mut impl bytes::BufMut,
73 ) -> std::result::Result<(), netidx::pack::PackError> {
74 Err(netidx::pack::PackError::Application(0))
75 }
76
77 fn decode(
78 _buf: &mut impl bytes::Buf,
79 ) -> std::result::Result<Self, netidx::pack::PackError> {
80 Err(netidx::pack::PackError::Application(0))
81 }
82}
83
/// A non-builtin lambda instantiated for one call site.
#[derive(Debug)]
struct GXLambda<R: Rt, E: UserEvent> {
    /// One irrefutable pattern per declared argument; each destructures the
    /// corresponding call-site argument value into bindings.
    args: Box<[StructPatternNode]>,
    /// The compiled lambda body.
    body: Node<R, E>,
    /// The lambda's function type (shared with its `LambdaDef`).
    typ: Arc<FnType>,
}
90
91impl<R: Rt, E: UserEvent> Apply<R, E> for GXLambda<R, E> {
92 fn update(
93 &mut self,
94 ctx: &mut ExecCtx<R, E>,
95 from: &mut [Node<R, E>],
96 event: &mut Event<E>,
97 ) -> Option<Value> {
98 for (arg, pat) in from.iter_mut().zip(&self.args) {
99 if let Some(v) = arg.update(ctx, event) {
100 pat.bind(&v, &mut |id, v| {
101 ctx.cached.insert(id, v.clone());
102 event.variables.insert(id, v);
103 })
104 }
105 }
106 self.body.update(ctx, event)
107 }
108
109 fn typecheck(
110 &mut self,
111 ctx: &mut ExecCtx<R, E>,
112 args: &mut [Node<R, E>],
113 _phase: TypecheckPhase<'_>,
114 ) -> Result<()> {
115 for (arg, FnArgType { typ, .. }) in args.iter_mut().zip(self.typ.args.iter()) {
116 wrap!(arg, arg.typecheck(ctx))?;
117 wrap!(arg, typ.check_contains(&ctx.env, &arg.typ()))?;
118 }
119 wrap!(self.body, self.body.typecheck(ctx))?;
120 wrap!(self.body, self.typ.rtype.check_contains(&ctx.env, &self.body.typ()))?;
121 for (tv, tc) in self.typ.constraints.read().iter() {
122 tc.check_contains(&ctx.env, &Type::TVar(tv.clone()))?
123 }
124 Ok(())
125 }
126
127 fn typ(&self) -> Arc<FnType> {
128 Arc::clone(&self.typ)
129 }
130
131 fn refs(&self, refs: &mut Refs) {
132 for pat in &self.args {
133 pat.ids(&mut |id| {
134 refs.bound.insert(id);
135 })
136 }
137 self.body.refs(refs)
138 }
139
140 fn delete(&mut self, ctx: &mut ExecCtx<R, E>) {
141 self.body.delete(ctx);
142 for n in &self.args {
143 n.delete(ctx)
144 }
145 }
146
147 fn sleep(&mut self, ctx: &mut ExecCtx<R, E>) {
148 self.body.sleep(ctx);
149 }
150}
151
152impl<R: Rt, E: UserEvent> GXLambda<R, E> {
153 pub(super) fn new(
154 ctx: &mut ExecCtx<R, E>,
155 flags: BitFlags<CFlag>,
156 typ: Arc<FnType>,
157 argspec: Arc<[Arg]>,
158 args: &[Node<R, E>],
159 scope: &Scope,
160 tid: ExprId,
161 body: Expr,
162 ) -> Result<Self> {
163 if args.len() != argspec.len() {
164 bail!("arity mismatch, expected {} arguments", argspec.len())
165 }
166 let mut argpats = vec![];
167 for (a, atyp) in argspec.iter().zip(typ.args.iter()) {
168 let pattern = StructPatternNode::compile(
169 ctx,
170 &atyp.typ,
171 &a.pattern,
172 scope,
173 a.pos,
174 body.ori.clone(),
175 )?;
176 if pattern.is_refutable() {
177 bail!(
178 "refutable patterns are not allowed in lambda arguments {}",
179 a.pattern
180 )
181 }
182 argpats.push(pattern);
183 }
184 let body = compile(ctx, flags, body, &scope, tid)?;
185 Ok(Self { args: Box::from(argpats), typ, body })
186 }
187}
188
/// A builtin lambda instantiated for one call site: the declared function
/// type plus the `Apply` returned by the builtin's registered initializer.
#[derive(Debug)]
struct BuiltInLambda<R: Rt, E: UserEvent> {
    /// The lambda's declared function type.
    typ: Arc<FnType>,
    /// The builtin implementation all operations are forwarded to.
    apply: Box<dyn Apply<R, E> + Send + Sync + 'static>,
}
194
195impl<R: Rt, E: UserEvent> Apply<R, E> for BuiltInLambda<R, E> {
196 fn update(
197 &mut self,
198 ctx: &mut ExecCtx<R, E>,
199 from: &mut [Node<R, E>],
200 event: &mut Event<E>,
201 ) -> Option<Value> {
202 self.apply.update(ctx, from, event)
203 }
204
205 fn typecheck(
206 &mut self,
207 ctx: &mut ExecCtx<R, E>,
208 args: &mut [Node<R, E>],
209 phase: TypecheckPhase<'_>,
210 ) -> Result<()> {
211 match &phase {
212 TypecheckPhase::CallSite(_) => (),
213 TypecheckPhase::Lambda => {
214 if args.len() < self.typ.args.len()
215 || (args.len() > self.typ.args.len() && self.typ.vargs.is_none())
216 {
217 let vargs = if self.typ.vargs.is_some() { "at least " } else { "" };
218 bail!(
219 "expected {}{} arguments got {}",
220 vargs,
221 self.typ.args.len(),
222 args.len()
223 )
224 }
225 for i in 0..args.len() {
226 wrap!(args[i], args[i].typecheck(ctx))?;
227 let atyp = if i < self.typ.args.len() {
228 &self.typ.args[i].typ
229 } else {
230 self.typ.vargs.as_ref().unwrap()
231 };
232 wrap!(args[i], atyp.check_contains(&ctx.env, &args[i].typ()))?
233 }
234 for (tv, tc) in self.typ.constraints.read().iter() {
235 tc.check_contains(&ctx.env, &Type::TVar(tv.clone()))?
236 }
237 }
238 }
239 self.apply.typecheck(ctx, args, phase)
240 }
241
242 fn typ(&self) -> Arc<FnType> {
243 Arc::clone(&self.typ)
244 }
245
246 fn refs(&self, refs: &mut Refs) {
247 self.apply.refs(refs)
248 }
249
250 fn delete(&mut self, ctx: &mut ExecCtx<R, E>) {
251 self.apply.delete(ctx)
252 }
253
254 fn sleep(&mut self, ctx: &mut ExecCtx<R, E>) {
255 self.apply.sleep(ctx);
256 }
257}
258
/// Node for a lambda expression; on initialization it yields the wrapped
/// [`LambdaDef`] value.
#[derive(Debug)]
pub(crate) struct Lambda {
    /// Expression id supplied at compile time; reused when compiling default
    /// argument expressions during typechecking.
    top_id: ExprId,
    /// The source expression this node was compiled from.
    spec: Expr,
    /// The `LambdaDef`, wrapped as a `Value` by `ctx.lambdawrap`.
    def: Value,
    /// Compiler flags, reused when compiling default argument expressions.
    flags: BitFlags<CFlag>,
    /// `Type::Fn` wrapping the lambda's function type.
    typ: Type,
}
267
268impl Lambda {
269 pub(crate) fn compile<R: Rt, E: UserEvent>(
270 ctx: &mut ExecCtx<R, E>,
271 flags: BitFlags<CFlag>,
272 spec: Expr,
273 scope: &Scope,
274 l: &expr::LambdaExpr,
275 top_id: ExprId,
276 ) -> Result<Node<R, E>> {
277 let mut s: LPooled<Vec<&ArcStr>> = LPooled::take();
278 for a in l.args.iter() {
279 a.pattern.with_names(&mut |n| s.push(n));
280 }
281 let len = s.len();
282 s.sort();
283 s.dedup();
284 if len != s.len() {
285 bail!("arguments must have unique names");
286 }
287 let id = LambdaId::new();
288 let vargs = match l.vargs.as_ref() {
289 None => None,
290 Some(None) => Some(None),
291 Some(Some(typ)) => Some(Some(typ.scope_refs(&scope.lexical))),
292 };
293 let rtype = l.rtype.as_ref().map(|t| t.scope_refs(&scope.lexical));
294 let throws = l.throws.as_ref().map(|t| t.scope_refs(&scope.lexical));
295 let mut argspec = l
296 .args
297 .iter()
298 .map(|a| match &a.constraint {
299 None => Arg {
300 labeled: a.labeled.clone(),
301 pattern: a.pattern.clone(),
302 constraint: None,
303 pos: a.pos,
304 },
305 Some(typ) => Arg {
306 labeled: a.labeled.clone(),
307 pattern: a.pattern.clone(),
308 constraint: Some(typ.scope_refs(&scope.lexical)),
309 pos: a.pos,
310 },
311 })
312 .collect::<LPooled<Vec<_>>>();
313 let argspec = Arc::from_iter(argspec.drain(..));
314 let constraints = l
315 .constraints
316 .iter()
317 .map(|(tv, tc)| {
318 let tv = tv.scope_refs(&scope.lexical);
319 let tc = tc.scope_refs(&scope.lexical);
320 Ok((tv, tc))
321 })
322 .collect::<Result<LPooled<Vec<_>>>>()?;
323 let constraints = Arc::new(RwLock::new(constraints));
324 let original_scope = scope.clone();
325 let _original_scope = scope.clone();
326 let scope = scope.append(&format_compact!("fn{}", id.0));
327 let _scope = scope.clone();
328 let env = ctx.env.clone();
329 let _env = ctx.env.clone();
330 let mut needs_callsite = false;
331 if let Either::Right(builtin) = &l.body {
332 if let Some((_, nc)) = ctx.builtins.get(builtin.as_str()) {
333 needs_callsite = *nc;
334 } else {
335 bail!("unknown builtin function {builtin}")
336 }
337 if !ctx.builtins_allowed {
338 bail!("defining builtins is not allowed in this context")
339 }
340 for a in argspec.iter() {
341 if a.constraint.is_none() {
342 bail!("builtin function {builtin} requires all arguments to have type annotations")
343 }
344 }
345 if rtype.is_none() {
346 bail!("builtin function {builtin} requires a return type annotation")
347 }
348 }
349 let typ = {
350 let args = Arc::from_iter(argspec.iter().map(|a| {
351 let kind = match (a.labeled.as_ref(), a.pattern.single_bind()) {
352 (Some(default), Some(name)) => FnArgKind::Labeled {
353 name: name.clone(),
354 has_default: default.is_some(),
355 },
356 (Some(_), None) => FnArgKind::Positional { name: None },
357 (None, name) => FnArgKind::Positional { name: name.cloned() },
358 };
359 let typ = match a.constraint.as_ref() {
360 Some(t) => t.clone(),
361 None => Type::empty_tvar(),
362 };
363 FnArgType { kind, typ }
364 }));
365 let vargs = match vargs {
366 Some(Some(t)) => Some(t.clone()),
367 Some(None) => Some(Type::empty_tvar()),
368 None => None,
369 };
370 let rtype = rtype.clone().unwrap_or_else(|| Type::empty_tvar());
371 let explicit_throws = throws.is_some();
372 let throws = throws.clone().unwrap_or_else(|| Type::empty_tvar());
373 Arc::new(FnType {
374 constraints,
375 args,
376 vargs,
377 rtype,
378 throws,
379 explicit_throws,
380 lambda_ids: Arc::new(RwLock::new(IntSet::default())),
381 })
382 };
383 typ.alias_tvars(&mut LPooled::take());
384 if needs_callsite || ctx.env.lsp_mode {
385 typ.lambda_ids.write().insert(id);
386 }
387 let _typ = typ.clone();
388 let _argspec = argspec.clone();
389 let body = l.body.clone();
390 let init: InitFn<R, E> = SArc::new(move |scope, ctx, args, resolved, tid| {
391 ctx.with_restored(_env.clone(), |ctx| match body.clone() {
394 Either::Left(body) => {
395 let scope = Scope {
396 dynamic: scope.dynamic.clone(),
397 lexical: _scope.lexical.clone(),
398 };
399 GXLambda::new(
400 ctx,
401 flags,
402 _typ.clone(),
403 _argspec.clone(),
404 args,
405 &scope,
406 tid,
407 body.clone(),
408 )
409 .map(|a| -> Box<dyn Apply<R, E>> { Box::new(a) })
410 }
411 Either::Right(builtin) => match ctx.builtins.get(&*builtin) {
412 None => bail!("unknown builtin function {builtin}"),
413 Some((init, _)) => init(ctx, &_typ, resolved, &_scope, args, tid)
414 .map(|apply| {
415 let f: Box<dyn Apply<R, E>> =
416 Box::new(BuiltInLambda { typ: _typ.clone(), apply });
417 f
418 }),
419 },
420 })
421 });
422 let def = ctx.lambdawrap.wrap(LambdaDef {
423 id,
424 typ: typ.clone(),
425 env,
426 argspec,
427 init,
428 scope: original_scope,
429 needs_callsite,
430 check: Mutex::new(None),
431 });
432 ctx.lambda_defs.insert(id, def.clone());
433 Ok(Box::new(Self { spec, def, typ: Type::Fn(typ), top_id, flags }))
434 }
435}
436
437impl<R: Rt, E: UserEvent> Update<R, E> for Lambda {
438 fn update(
439 &mut self,
440 _ctx: &mut ExecCtx<R, E>,
441 event: &mut Event<E>,
442 ) -> Option<Value> {
443 event.init.then(|| self.def.clone())
444 }
445
446 fn spec(&self) -> &Expr {
447 &self.spec
448 }
449
450 fn refs(&self, _refs: &mut Refs) {}
451
452 fn delete(&mut self, _ctx: &mut ExecCtx<R, E>) {}
453
454 fn sleep(&mut self, _ctx: &mut ExecCtx<R, E>) {}
455
456 fn typ(&self) -> &Type {
457 &self.typ
458 }
459
460 fn typecheck(&mut self, ctx: &mut ExecCtx<R, E>) -> Result<()> {
461 let def = self
462 .def
463 .downcast_ref::<LambdaDef<R, E>>()
464 .ok_or_else(|| anyhow!("failed to unwrap lambda"))?;
465 let needs_callsite = def.needs_callsite;
466 let mut faux_args: LPooled<Vec<Node<R, E>>> = def
467 .argspec
468 .iter()
469 .zip(def.typ.args.iter())
470 .map(|(a, at)| match &a.labeled {
471 Some(Some(e)) => ctx.with_restored(def.env.clone(), |ctx| {
472 compile(ctx, self.flags, e.clone(), &def.scope, self.top_id)
473 }),
474 Some(None) | None => {
475 let n: Node<R, E> = Box::new(Nop { typ: at.typ.clone() });
476 Ok(n)
477 }
478 })
479 .collect::<Result<_>>()?;
480 let faux_id = BindId::new();
481 ctx.env.by_id.insert_cow(
482 faux_id,
483 Bind {
484 doc: None,
485 export: false,
486 id: faux_id,
487 name: "faux".into(),
488 scope: def.scope.lexical.clone(),
489 typ: Type::empty_tvar(),
490 pos: SourcePosition::default(),
491 ori: Arc::new(Origin::default()),
492 },
493 );
494 let prev_catch = ctx.env.catch.insert_cow(def.scope.dynamic.clone(), faux_id);
495 let res = (def.init)(&def.scope, ctx, &mut faux_args, None, ExprId::new())
496 .with_context(|| ErrorContext(Update::<R, E>::spec(self).clone()));
497 let res = res.and_then(|mut f| {
498 let ftyp = f.typ().clone();
499 let res = f
500 .typecheck(ctx, &mut faux_args, TypecheckPhase::Lambda)
501 .with_context(|| ErrorContext(Update::<R, E>::spec(self).clone()));
502 if !needs_callsite {
503 f.delete(ctx)
504 } else {
505 let def = self
506 .def
507 .downcast_ref::<LambdaDef<R, E>>()
508 .expect("failed to unwrap lambda");
509 *def.check.lock() = Some(f);
510 }
511 res?;
512 let inferred_throws = ctx.env.by_id[&faux_id]
513 .typ
514 .with_deref(|t| t.cloned())
515 .unwrap_or(Type::Bottom)
516 .scope_refs(&def.scope.lexical)
517 .normalize();
518 ftyp.throws
519 .check_contains(&ctx.env, &inferred_throws)
520 .with_context(|| ErrorContext(Update::<R, E>::spec(self).clone()))?;
521 ftyp.constrain_known();
522 Ok(())
523 });
524 ctx.env.by_id.remove_cow(&faux_id);
525 match prev_catch {
526 Some(id) => ctx.env.catch.insert_cow(def.scope.dynamic.clone(), id),
527 None => ctx.env.catch.remove_cow(&def.scope.dynamic),
528 };
529 self.typ.unbind_tvars();
530 res
531 }
532}