1use crate::{
2 expr::{ExprId, ModPath, Pattern, StructurePattern},
3 node::{compiler, Cached},
4 typ::{NoRefs, Type},
5 BindId, Ctx, Event, ExecCtx, UserEvent,
6};
7use anyhow::{anyhow, bail, Result};
8use arcstr::ArcStr;
9use netidx::{publisher::Typ, subscriber::Value};
10use smallvec::SmallVec;
11use std::fmt::Debug;
12
/// Compiled form of a [`StructurePattern`]: the shape a value must have to
/// match, plus the [`BindId`]s that receive the parts of a matched value.
///
/// Arrays, tuples, structs and variants with arguments are all inspected as
/// `Value::Array` at runtime; each variant below notes the layout it expects.
#[derive(Debug)]
pub enum StructPatternNode {
    /// Matches any value and binds nothing.
    Ignore,
    /// Matches only a value equal to this literal; binds nothing.
    Literal(Value),
    /// Matches any value and binds the whole value to this id.
    Bind(BindId),
    /// Matches an array of exactly `binds.len()` elements, element-wise.
    Slice {
        /// `true` when compiled from a tuple pattern rather than an array
        /// slice pattern; tuple patterns are refutable only if one of their
        /// element patterns is.
        tuple: bool,
        /// Optional binding for the whole matched value.
        all: Option<BindId>,
        /// One sub-pattern per element.
        binds: Box<[StructPatternNode]>,
    },
    /// Matches an array with at least `prefix.len()` elements whose leading
    /// elements match `prefix`.
    SlicePrefix {
        /// Optional binding for the whole matched value.
        all: Option<BindId>,
        /// Sub-patterns for the leading elements.
        prefix: Box<[StructPatternNode]>,
        /// Optional binding for the elements after the prefix.
        tail: Option<BindId>,
    },
    /// Matches an array with at least `suffix.len()` elements whose trailing
    /// elements match `suffix`.
    SliceSuffix {
        /// Optional binding for the whole matched value.
        all: Option<BindId>,
        /// Optional binding for the elements before the suffix.
        head: Option<BindId>,
        /// Sub-patterns for the trailing elements.
        suffix: Box<[StructPatternNode]>,
    },
    /// Matches a struct value, inspected as an array of two-element entries
    /// whose second element is the field's value.
    Struct {
        /// Optional binding for the whole matched value.
        all: Option<BindId>,
        /// `(field name, field index in the struct type, pattern for the
        /// field's value)`.
        binds: Box<[(ArcStr, usize, StructPatternNode)]>,
    },
    /// Matches a variant. A variant with no arguments is a bare string equal
    /// to `tag`; otherwise it is an array whose first element is the tag
    /// string and whose remaining elements match `binds`.
    Variant {
        /// The variant tag that must be present.
        tag: ArcStr,
        /// Optional binding for the whole matched value.
        all: Option<BindId>,
        /// One sub-pattern per variant argument.
        binds: Box<[StructPatternNode]>,
    },
}
43
44impl StructPatternNode {
45 pub fn compile<C: Ctx, E: UserEvent>(
46 ctx: &mut ExecCtx<C, E>,
47 type_predicate: &Type<NoRefs>,
48 spec: &StructurePattern,
49 scope: &ModPath,
50 ) -> Result<Self> {
51 if !spec.binds_uniq() {
52 bail!("bound variables must have unique names")
53 }
54 Self::compile_int(ctx, type_predicate, spec, scope)
55 }
56
57 fn compile_int<C: Ctx, E: UserEvent>(
58 ctx: &mut ExecCtx<C, E>,
59 type_predicate: &Type<NoRefs>,
60 spec: &StructurePattern,
61 scope: &ModPath,
62 ) -> Result<Self> {
63 macro_rules! with_pref_suf {
64 ($all:expr, $single:expr, $multi:expr) => {
65 match &type_predicate {
66 Type::Array(et) => {
67 let all = $all.as_ref().map(|n| {
68 ctx.env.bind_variable(scope, n, type_predicate.clone()).id
69 });
70 let single = $single.as_ref().map(|n| {
71 ctx.env.bind_variable(scope, n, type_predicate.clone()).id
72 });
73 let multi = $multi
74 .iter()
75 .map(|n| Self::compile_int(ctx, et, n, scope))
76 .collect::<Result<Box<[Self]>>>()?;
77 (all, single, multi)
78 }
79 t => bail!("slice patterns can't match {t}"),
80 }
81 };
82 }
83 let t = match &spec {
84 StructurePattern::Ignore => Self::Ignore,
85 StructurePattern::Literal(v) => {
86 type_predicate.check_contains(&Type::Primitive(Typ::get(v).into()))?;
87 Self::Literal(v.clone())
88 }
89 StructurePattern::Bind(name) => {
90 let id = ctx.env.bind_variable(scope, name, type_predicate.clone()).id;
91 Self::Bind(id)
92 }
93 StructurePattern::SlicePrefix { all, prefix, tail } => {
94 let (all, tail, prefix) = with_pref_suf!(all, tail, prefix);
95 Self::SlicePrefix { all, prefix, tail }
96 }
97 StructurePattern::SliceSuffix { all, head, suffix } => {
98 let (all, head, suffix) = with_pref_suf!(all, head, suffix);
99 Self::SliceSuffix { all, head, suffix }
100 }
101 StructurePattern::Slice { all, binds } => match &type_predicate {
102 Type::Array(et) => {
103 let all = all.as_ref().map(|n| {
104 ctx.env.bind_variable(scope, n, type_predicate.clone()).id
105 });
106 let binds = binds
107 .iter()
108 .map(|b| Self::compile_int(ctx, et, b, scope))
109 .collect::<Result<Box<[Self]>>>()?;
110 Self::Slice { tuple: false, all, binds }
111 }
112 t => bail!("slice patterns can't match {t}"),
113 },
114 StructurePattern::Tuple { all, binds } => match &type_predicate {
115 Type::Tuple(elts) => {
116 if binds.len() != elts.len() {
117 bail!("expected a tuple of length {}", elts.len())
118 }
119 let all = all.as_ref().map(|n| {
120 ctx.env.bind_variable(scope, n, type_predicate.clone()).id
121 });
122 let binds = elts
123 .iter()
124 .zip(binds.iter())
125 .map(|(t, b)| Self::compile_int(ctx, t, b, scope))
126 .collect::<Result<Box<[Self]>>>()?;
127 Self::Slice { tuple: true, all, binds }
128 }
129 t => bail!("tuple patterns can't match {t}"),
130 },
131 StructurePattern::Variant { all, tag, binds } => match &type_predicate {
132 Type::Variant(ttag, elts) => {
133 if ttag != tag {
134 bail!("pattern cannot match type, tag mismatch {ttag} vs {tag}")
135 }
136 if binds.len() != elts.len() {
137 bail!("expected a variant with {} args", elts.len())
138 }
139 let all = all.as_ref().map(|n| {
140 ctx.env.bind_variable(scope, n, type_predicate.clone()).id
141 });
142 let binds = elts
143 .iter()
144 .zip(binds.iter())
145 .map(|(t, b)| Self::compile_int(ctx, t, b, scope))
146 .collect::<Result<Box<[Self]>>>()?;
147 Self::Variant { tag: tag.clone(), all, binds }
148 }
149 t => bail!("variant patterns can't match {t}"),
150 },
151 StructurePattern::Struct { exhaustive, all, binds } => {
152 struct Ifo {
153 name: ArcStr,
154 index: usize,
155 pattern: StructurePattern,
156 typ: Type<NoRefs>,
157 }
158 match &type_predicate {
159 Type::Struct(elts) => {
160 let binds = binds
161 .iter()
162 .map(|(field, pat)| {
163 let r = elts.iter().enumerate().find_map(
164 |(i, (name, typ))| {
165 if field == name {
166 Some(Ifo {
167 name: name.clone(),
168 index: i,
169 pattern: pat.clone(),
170 typ: typ.clone(),
171 })
172 } else {
173 None
174 }
175 },
176 );
177 r.ok_or_else(|| anyhow!("no such struct field {field}"))
178 })
179 .collect::<Result<SmallVec<[Ifo; 8]>>>()?;
180 if *exhaustive && binds.len() < elts.len() {
181 bail!("missing bindings for struct fields")
182 }
183 let all = all.as_ref().map(|n| {
184 ctx.env.bind_variable(scope, n, type_predicate.clone()).id
185 });
186 let binds = binds
187 .into_iter()
188 .map(|ifo| {
189 Ok((
190 ifo.name,
191 ifo.index,
192 Self::compile_int(
193 ctx,
194 &ifo.typ,
195 &ifo.pattern,
196 scope,
197 )?,
198 ))
199 })
200 .collect::<Result<Box<[(ArcStr, usize, Self)]>>>()?;
201 Self::Struct { all, binds }
202 }
203 t => bail!("struct patterns can't match {t}"),
204 }
205 }
206 };
207 Ok(t)
208 }
209
210 pub fn ids<'a>(&'a self, f: &mut (dyn FnMut(BindId) + 'a)) {
211 match &self {
212 Self::Ignore | Self::Literal(_) => (),
213 Self::Bind(id) => f(*id),
214 Self::Slice { tuple: _, all, binds } => {
215 if let Some(id) = all {
216 f(*id);
217 }
218 for n in binds.iter() {
219 n.ids(f)
220 }
221 }
222 Self::Variant { tag: _, all, binds } => {
223 if let Some(id) = all {
224 f(*id)
225 }
226 for n in binds.iter() {
227 n.ids(f)
228 }
229 }
230 Self::SlicePrefix { all, prefix, tail } => {
231 if let Some(id) = all {
232 f(*id)
233 }
234 for n in prefix.iter() {
235 n.ids(f)
236 }
237 if let Some(id) = tail {
238 f(*id)
239 }
240 }
241 Self::SliceSuffix { all, head, suffix } => {
242 if let Some(id) = all {
243 f(*id)
244 }
245 if let Some(id) = head {
246 f(*id)
247 }
248 for n in suffix.iter() {
249 n.ids(f)
250 }
251 }
252 Self::Struct { all, binds } => {
253 if let Some(id) = all {
254 f(*id)
255 }
256 for (_, _, n) in binds.iter() {
257 n.ids(f)
258 }
259 }
260 }
261 }
262
263 pub fn bind<F: FnMut(BindId, Value)>(&self, v: &Value, f: &mut F) {
264 match &self {
265 Self::Ignore | Self::Literal(_) => (),
266 Self::Bind(id) => f(*id, v.clone()),
267 Self::Slice { tuple: _, all, binds } => match v {
268 Value::Array(a) if a.len() == binds.len() => {
269 if let Some(id) = all {
270 f(*id, v.clone());
271 }
272 for (j, n) in binds.iter().enumerate() {
273 n.bind(&a[j], f)
274 }
275 }
276 _ => (),
277 },
278 Self::Variant { tag: _, all, binds } => {
279 if let Some(id) = all {
280 f(*id, v.clone())
281 }
282 match v {
283 Value::Array(a) if a.len() == binds.len() + 1 => {
284 for (j, n) in binds.iter().enumerate() {
285 n.bind(&a[j + 1], f)
286 }
287 }
288 _ => (),
289 }
290 }
291 Self::SlicePrefix { all, prefix, tail } => match v {
292 Value::Array(a) if a.len() >= prefix.len() => {
293 if let Some(id) = all {
294 f(*id, v.clone())
295 }
296 for (j, n) in prefix.iter().enumerate() {
297 n.bind(&a[j], f)
298 }
299 if let Some(id) = tail {
300 let ss = a.subslice(prefix.len()..).unwrap();
301 f(*id, Value::Array(ss))
302 }
303 }
304 _ => (),
305 },
306 Self::SliceSuffix { all, head, suffix } => match v {
307 Value::Array(a) if a.len() >= suffix.len() => {
308 if let Some(id) = all {
309 f(*id, v.clone())
310 }
311 if let Some(id) = head {
312 let ss = a.subslice(..suffix.len()).unwrap();
313 f(*id, Value::Array(ss))
314 }
315 let tail = a.subslice(suffix.len()..).unwrap();
316 for (j, n) in suffix.iter().enumerate() {
317 n.bind(&tail[j], f)
318 }
319 }
320 _ => (),
321 },
322 Self::Struct { all, binds } => match v {
323 Value::Array(a) if a.len() >= binds.len() => {
324 if let Some(id) = all {
325 f(*id, v.clone())
326 }
327 for (_, i, n) in binds.iter() {
328 if let Some(v) = a.get(*i) {
329 match v {
330 Value::Array(a) if a.len() == 2 => n.bind(&a[1], f),
331 _ => (),
332 }
333 }
334 }
335 }
336 _ => (),
337 },
338 }
339 }
340
341 pub fn unbind<F: FnMut(BindId)>(&self, f: &mut F) {
342 match &self {
343 Self::Ignore | Self::Literal(_) => (),
344 Self::Bind(id) => f(*id),
345 Self::Slice { tuple: _, all, binds }
346 | Self::Variant { tag: _, all, binds } => {
347 if let Some(id) = all {
348 f(*id)
349 }
350 for n in binds.iter() {
351 n.unbind(f)
352 }
353 }
354 Self::SlicePrefix { all, prefix, tail } => {
355 if let Some(id) = all {
356 f(*id)
357 }
358 if let Some(id) = tail {
359 f(*id)
360 }
361 for n in prefix.iter() {
362 n.unbind(f)
363 }
364 }
365 Self::SliceSuffix { all, head, suffix } => {
366 if let Some(id) = all {
367 f(*id)
368 }
369 if let Some(id) = head {
370 f(*id)
371 }
372 for n in suffix.iter() {
373 n.unbind(f)
374 }
375 }
376 Self::Struct { all, binds } => {
377 if let Some(id) = all {
378 f(*id)
379 }
380 for (_, _, n) in binds.iter() {
381 n.unbind(f)
382 }
383 }
384 }
385 }
386
387 pub fn is_match(&self, v: &Value) -> bool {
388 match &self {
389 Self::Ignore | Self::Bind(_) => true,
390 Self::Literal(o) => v == o,
391 Self::Slice { tuple: _, all: _, binds } => match v {
392 Value::Array(a) => {
393 a.len() == binds.len()
394 && binds.iter().zip(a.iter()).all(|(b, v)| b.is_match(v))
395 }
396 _ => false,
397 },
398 Self::Variant { tag, all: _, binds } if binds.len() == 0 => match v {
399 Value::String(s) => tag == s,
400 _ => false,
401 },
402 Self::Variant { tag, all: _, binds } => match v {
403 Value::Array(a) => {
404 a.len() == binds.len() + 1
405 && match &a[0] {
406 Value::String(s) => s == tag,
407 _ => false,
408 }
409 && binds.iter().zip(a[1..].iter()).all(|(b, v)| b.is_match(v))
410 }
411 _ => false,
412 },
413 Self::SlicePrefix { all: _, prefix, tail: _ } => match v {
414 Value::Array(a) => {
415 a.len() >= prefix.len()
416 && prefix.iter().zip(a.iter()).all(|(b, v)| b.is_match(v))
417 }
418 _ => false,
419 },
420 Self::SliceSuffix { all: _, head: _, suffix } => match v {
421 Value::Array(a) => {
422 a.len() >= suffix.len()
423 && suffix
424 .iter()
425 .zip(a.iter().skip(a.len() - suffix.len()))
426 .all(|(b, v)| b.is_match(v))
427 }
428 _ => false,
429 },
430 Self::Struct { all: _, binds } => match v {
431 Value::Array(a) => {
432 a.len() >= binds.len()
433 && binds.iter().all(|(_, i, p)| match a.get(*i) {
434 Some(Value::Array(a)) if a.len() == 2 => p.is_match(&a[1]),
435 _ => false,
436 })
437 }
438 _ => false,
439 },
440 }
441 }
442
443 pub fn is_refutable(&self) -> bool {
444 match &self {
445 Self::Bind(_) | Self::Ignore => false,
446 Self::Literal(_) => true,
447 Self::Slice { tuple: true, all: _, binds } => {
448 binds.iter().any(|p| p.is_refutable())
449 }
450 Self::Struct { all: _, binds } => {
451 binds.iter().any(|(_, _, p)| p.is_refutable())
452 }
453 Self::Variant { .. }
454 | Self::Slice { tuple: false, .. }
455 | Self::SlicePrefix { .. }
456 | Self::SliceSuffix { .. } => true,
457 }
458 }
459}
460
/// A compiled pattern made of three checks: a type test, a structural test,
/// and an optional guard. All three must pass for a value to match.
#[derive(Debug)]
pub struct PatternNode<C: Ctx, E: UserEvent> {
    /// The type a matched value must belong to. It comes from the pattern's
    /// explicit annotation, or is inferred from its structure when there is none.
    pub type_predicate: Type<NoRefs>,
    /// The structural pattern. It also produces the variable bindings.
    pub structure_predicate: StructPatternNode,
    /// Optional guard expression; its cached value must be `true` for the
    /// pattern to match.
    pub guard: Option<Cached<C, E>>,
}
467
468impl<C: Ctx, E: UserEvent> PatternNode<C, E> {
469 pub(super) fn compile(
470 ctx: &mut ExecCtx<C, E>,
471 spec: &Pattern,
472 scope: &ModPath,
473 top_id: ExprId,
474 ) -> Result<Self> {
475 let type_predicate = match &spec.type_predicate {
476 Some(t) => t.resolve_typerefs(scope, &ctx.env)?,
477 None => spec.structure_predicate.infer_type_predicate(),
478 };
479 match &type_predicate {
480 Type::Fn(_) => bail!("can't match on Fn type"),
481 Type::Bottom(_)
482 | Type::Primitive(_)
483 | Type::Set(_)
484 | Type::TVar(_)
485 | Type::Array(_)
486 | Type::Tuple(_)
487 | Type::Variant(_, _)
488 | Type::Struct(_) => (),
489 Type::Ref(_) => unreachable!(),
490 }
491 let structure_predicate = StructPatternNode::compile(
492 ctx,
493 &type_predicate,
494 &spec.structure_predicate,
495 scope,
496 )?;
497 let guard = spec
498 .guard
499 .as_ref()
500 .map(|g| compiler::compile(ctx, g.clone(), &scope, top_id))
501 .transpose()?
502 .map(Cached::new);
503 Ok(PatternNode { type_predicate, structure_predicate, guard })
504 }
505
506 pub(super) fn bind_event(&self, event: &mut Event<E>, v: &Value) {
507 self.structure_predicate.bind(v, &mut |id, v| {
508 event.variables.insert(id, v);
509 })
510 }
511
512 pub(super) fn unbind_event(&self, event: &mut Event<E>) {
513 self.structure_predicate.unbind(&mut |id| {
514 event.variables.remove(&id);
515 })
516 }
517
518 pub(super) fn update(
519 &mut self,
520 ctx: &mut ExecCtx<C, E>,
521 event: &mut Event<E>,
522 ) -> bool {
523 match &mut self.guard {
524 None => false,
525 Some(g) => g.update(ctx, event),
526 }
527 }
528
529 pub(super) fn is_match(&self, typ: Typ, v: &Value) -> bool {
530 let tmatch = match (&self.type_predicate, typ) {
531 (Type::Array(_), Typ::Array)
532 | (Type::Tuple(_), Typ::Array)
533 | (Type::Struct(_), Typ::Array)
534 | (Type::Variant(_, _), Typ::Array | Typ::String) => true,
535 _ => self.type_predicate.contains(&Type::Primitive(typ.into())),
536 };
537 tmatch
538 && self.structure_predicate.is_match(v)
539 && match &self.guard {
540 None => true,
541 Some(g) => g
542 .cached
543 .as_ref()
544 .and_then(|v| v.clone().get_as::<bool>())
545 .unwrap_or(false),
546 }
547 }
548}