netidx_bscript/node/
pattern.rs

1use crate::{
2    expr::{ExprId, ModPath, Pattern, StructurePattern},
3    node::{compiler, Cached},
4    typ::{NoRefs, Type},
5    BindId, Ctx, Event, ExecCtx, UserEvent,
6};
7use anyhow::{anyhow, bail, Result};
8use arcstr::ArcStr;
9use netidx::{publisher::Typ, subscriber::Value};
10use smallvec::SmallVec;
11use std::fmt::Debug;
12
13#[derive(Debug)]
14pub enum StructPatternNode {
15    Ignore,
16    Literal(Value),
17    Bind(BindId),
18    Slice {
19        tuple: bool,
20        all: Option<BindId>,
21        binds: Box<[StructPatternNode]>,
22    },
23    SlicePrefix {
24        all: Option<BindId>,
25        prefix: Box<[StructPatternNode]>,
26        tail: Option<BindId>,
27    },
28    SliceSuffix {
29        all: Option<BindId>,
30        head: Option<BindId>,
31        suffix: Box<[StructPatternNode]>,
32    },
33    Struct {
34        all: Option<BindId>,
35        binds: Box<[(ArcStr, usize, StructPatternNode)]>,
36    },
37    Variant {
38        tag: ArcStr,
39        all: Option<BindId>,
40        binds: Box<[StructPatternNode]>,
41    },
42}
43
44impl StructPatternNode {
45    pub fn compile<C: Ctx, E: UserEvent>(
46        ctx: &mut ExecCtx<C, E>,
47        type_predicate: &Type<NoRefs>,
48        spec: &StructurePattern,
49        scope: &ModPath,
50    ) -> Result<Self> {
51        if !spec.binds_uniq() {
52            bail!("bound variables must have unique names")
53        }
54        Self::compile_int(ctx, type_predicate, spec, scope)
55    }
56
57    fn compile_int<C: Ctx, E: UserEvent>(
58        ctx: &mut ExecCtx<C, E>,
59        type_predicate: &Type<NoRefs>,
60        spec: &StructurePattern,
61        scope: &ModPath,
62    ) -> Result<Self> {
63        macro_rules! with_pref_suf {
64            ($all:expr, $single:expr, $multi:expr) => {
65                match &type_predicate {
66                    Type::Array(et) => {
67                        let all = $all.as_ref().map(|n| {
68                            ctx.env.bind_variable(scope, n, type_predicate.clone()).id
69                        });
70                        let single = $single.as_ref().map(|n| {
71                            ctx.env.bind_variable(scope, n, type_predicate.clone()).id
72                        });
73                        let multi = $multi
74                            .iter()
75                            .map(|n| Self::compile_int(ctx, et, n, scope))
76                            .collect::<Result<Box<[Self]>>>()?;
77                        (all, single, multi)
78                    }
79                    t => bail!("slice patterns can't match {t}"),
80                }
81            };
82        }
83        let t = match &spec {
84            StructurePattern::Ignore => Self::Ignore,
85            StructurePattern::Literal(v) => {
86                type_predicate.check_contains(&Type::Primitive(Typ::get(v).into()))?;
87                Self::Literal(v.clone())
88            }
89            StructurePattern::Bind(name) => {
90                let id = ctx.env.bind_variable(scope, name, type_predicate.clone()).id;
91                Self::Bind(id)
92            }
93            StructurePattern::SlicePrefix { all, prefix, tail } => {
94                let (all, tail, prefix) = with_pref_suf!(all, tail, prefix);
95                Self::SlicePrefix { all, prefix, tail }
96            }
97            StructurePattern::SliceSuffix { all, head, suffix } => {
98                let (all, head, suffix) = with_pref_suf!(all, head, suffix);
99                Self::SliceSuffix { all, head, suffix }
100            }
101            StructurePattern::Slice { all, binds } => match &type_predicate {
102                Type::Array(et) => {
103                    let all = all.as_ref().map(|n| {
104                        ctx.env.bind_variable(scope, n, type_predicate.clone()).id
105                    });
106                    let binds = binds
107                        .iter()
108                        .map(|b| Self::compile_int(ctx, et, b, scope))
109                        .collect::<Result<Box<[Self]>>>()?;
110                    Self::Slice { tuple: false, all, binds }
111                }
112                t => bail!("slice patterns can't match {t}"),
113            },
114            StructurePattern::Tuple { all, binds } => match &type_predicate {
115                Type::Tuple(elts) => {
116                    if binds.len() != elts.len() {
117                        bail!("expected a tuple of length {}", elts.len())
118                    }
119                    let all = all.as_ref().map(|n| {
120                        ctx.env.bind_variable(scope, n, type_predicate.clone()).id
121                    });
122                    let binds = elts
123                        .iter()
124                        .zip(binds.iter())
125                        .map(|(t, b)| Self::compile_int(ctx, t, b, scope))
126                        .collect::<Result<Box<[Self]>>>()?;
127                    Self::Slice { tuple: true, all, binds }
128                }
129                t => bail!("tuple patterns can't match {t}"),
130            },
131            StructurePattern::Variant { all, tag, binds } => match &type_predicate {
132                Type::Variant(ttag, elts) => {
133                    if ttag != tag {
134                        bail!("pattern cannot match type, tag mismatch {ttag} vs {tag}")
135                    }
136                    if binds.len() != elts.len() {
137                        bail!("expected a variant with {} args", elts.len())
138                    }
139                    let all = all.as_ref().map(|n| {
140                        ctx.env.bind_variable(scope, n, type_predicate.clone()).id
141                    });
142                    let binds = elts
143                        .iter()
144                        .zip(binds.iter())
145                        .map(|(t, b)| Self::compile_int(ctx, t, b, scope))
146                        .collect::<Result<Box<[Self]>>>()?;
147                    Self::Variant { tag: tag.clone(), all, binds }
148                }
149                t => bail!("variant patterns can't match {t}"),
150            },
151            StructurePattern::Struct { exhaustive, all, binds } => {
152                struct Ifo {
153                    name: ArcStr,
154                    index: usize,
155                    pattern: StructurePattern,
156                    typ: Type<NoRefs>,
157                }
158                match &type_predicate {
159                    Type::Struct(elts) => {
160                        let binds = binds
161                            .iter()
162                            .map(|(field, pat)| {
163                                let r = elts.iter().enumerate().find_map(
164                                    |(i, (name, typ))| {
165                                        if field == name {
166                                            Some(Ifo {
167                                                name: name.clone(),
168                                                index: i,
169                                                pattern: pat.clone(),
170                                                typ: typ.clone(),
171                                            })
172                                        } else {
173                                            None
174                                        }
175                                    },
176                                );
177                                r.ok_or_else(|| anyhow!("no such struct field {field}"))
178                            })
179                            .collect::<Result<SmallVec<[Ifo; 8]>>>()?;
180                        if *exhaustive && binds.len() < elts.len() {
181                            bail!("missing bindings for struct fields")
182                        }
183                        let all = all.as_ref().map(|n| {
184                            ctx.env.bind_variable(scope, n, type_predicate.clone()).id
185                        });
186                        let binds = binds
187                            .into_iter()
188                            .map(|ifo| {
189                                Ok((
190                                    ifo.name,
191                                    ifo.index,
192                                    Self::compile_int(
193                                        ctx,
194                                        &ifo.typ,
195                                        &ifo.pattern,
196                                        scope,
197                                    )?,
198                                ))
199                            })
200                            .collect::<Result<Box<[(ArcStr, usize, Self)]>>>()?;
201                        Self::Struct { all, binds }
202                    }
203                    t => bail!("struct patterns can't match {t}"),
204                }
205            }
206        };
207        Ok(t)
208    }
209
210    pub fn ids<'a>(&'a self, f: &mut (dyn FnMut(BindId) + 'a)) {
211        match &self {
212            Self::Ignore | Self::Literal(_) => (),
213            Self::Bind(id) => f(*id),
214            Self::Slice { tuple: _, all, binds } => {
215                if let Some(id) = all {
216                    f(*id);
217                }
218                for n in binds.iter() {
219                    n.ids(f)
220                }
221            }
222            Self::Variant { tag: _, all, binds } => {
223                if let Some(id) = all {
224                    f(*id)
225                }
226                for n in binds.iter() {
227                    n.ids(f)
228                }
229            }
230            Self::SlicePrefix { all, prefix, tail } => {
231                if let Some(id) = all {
232                    f(*id)
233                }
234                for n in prefix.iter() {
235                    n.ids(f)
236                }
237                if let Some(id) = tail {
238                    f(*id)
239                }
240            }
241            Self::SliceSuffix { all, head, suffix } => {
242                if let Some(id) = all {
243                    f(*id)
244                }
245                if let Some(id) = head {
246                    f(*id)
247                }
248                for n in suffix.iter() {
249                    n.ids(f)
250                }
251            }
252            Self::Struct { all, binds } => {
253                if let Some(id) = all {
254                    f(*id)
255                }
256                for (_, _, n) in binds.iter() {
257                    n.ids(f)
258                }
259            }
260        }
261    }
262
263    pub fn bind<F: FnMut(BindId, Value)>(&self, v: &Value, f: &mut F) {
264        match &self {
265            Self::Ignore | Self::Literal(_) => (),
266            Self::Bind(id) => f(*id, v.clone()),
267            Self::Slice { tuple: _, all, binds } => match v {
268                Value::Array(a) if a.len() == binds.len() => {
269                    if let Some(id) = all {
270                        f(*id, v.clone());
271                    }
272                    for (j, n) in binds.iter().enumerate() {
273                        n.bind(&a[j], f)
274                    }
275                }
276                _ => (),
277            },
278            Self::Variant { tag: _, all, binds } => {
279                if let Some(id) = all {
280                    f(*id, v.clone())
281                }
282                match v {
283                    Value::Array(a) if a.len() == binds.len() + 1 => {
284                        for (j, n) in binds.iter().enumerate() {
285                            n.bind(&a[j + 1], f)
286                        }
287                    }
288                    _ => (),
289                }
290            }
291            Self::SlicePrefix { all, prefix, tail } => match v {
292                Value::Array(a) if a.len() >= prefix.len() => {
293                    if let Some(id) = all {
294                        f(*id, v.clone())
295                    }
296                    for (j, n) in prefix.iter().enumerate() {
297                        n.bind(&a[j], f)
298                    }
299                    if let Some(id) = tail {
300                        let ss = a.subslice(prefix.len()..).unwrap();
301                        f(*id, Value::Array(ss))
302                    }
303                }
304                _ => (),
305            },
306            Self::SliceSuffix { all, head, suffix } => match v {
307                Value::Array(a) if a.len() >= suffix.len() => {
308                    if let Some(id) = all {
309                        f(*id, v.clone())
310                    }
311                    if let Some(id) = head {
312                        let ss = a.subslice(..suffix.len()).unwrap();
313                        f(*id, Value::Array(ss))
314                    }
315                    let tail = a.subslice(suffix.len()..).unwrap();
316                    for (j, n) in suffix.iter().enumerate() {
317                        n.bind(&tail[j], f)
318                    }
319                }
320                _ => (),
321            },
322            Self::Struct { all, binds } => match v {
323                Value::Array(a) if a.len() >= binds.len() => {
324                    if let Some(id) = all {
325                        f(*id, v.clone())
326                    }
327                    for (_, i, n) in binds.iter() {
328                        if let Some(v) = a.get(*i) {
329                            match v {
330                                Value::Array(a) if a.len() == 2 => n.bind(&a[1], f),
331                                _ => (),
332                            }
333                        }
334                    }
335                }
336                _ => (),
337            },
338        }
339    }
340
341    pub fn unbind<F: FnMut(BindId)>(&self, f: &mut F) {
342        match &self {
343            Self::Ignore | Self::Literal(_) => (),
344            Self::Bind(id) => f(*id),
345            Self::Slice { tuple: _, all, binds }
346            | Self::Variant { tag: _, all, binds } => {
347                if let Some(id) = all {
348                    f(*id)
349                }
350                for n in binds.iter() {
351                    n.unbind(f)
352                }
353            }
354            Self::SlicePrefix { all, prefix, tail } => {
355                if let Some(id) = all {
356                    f(*id)
357                }
358                if let Some(id) = tail {
359                    f(*id)
360                }
361                for n in prefix.iter() {
362                    n.unbind(f)
363                }
364            }
365            Self::SliceSuffix { all, head, suffix } => {
366                if let Some(id) = all {
367                    f(*id)
368                }
369                if let Some(id) = head {
370                    f(*id)
371                }
372                for n in suffix.iter() {
373                    n.unbind(f)
374                }
375            }
376            Self::Struct { all, binds } => {
377                if let Some(id) = all {
378                    f(*id)
379                }
380                for (_, _, n) in binds.iter() {
381                    n.unbind(f)
382                }
383            }
384        }
385    }
386
387    pub fn is_match(&self, v: &Value) -> bool {
388        match &self {
389            Self::Ignore | Self::Bind(_) => true,
390            Self::Literal(o) => v == o,
391            Self::Slice { tuple: _, all: _, binds } => match v {
392                Value::Array(a) => {
393                    a.len() == binds.len()
394                        && binds.iter().zip(a.iter()).all(|(b, v)| b.is_match(v))
395                }
396                _ => false,
397            },
398            Self::Variant { tag, all: _, binds } if binds.len() == 0 => match v {
399                Value::String(s) => tag == s,
400                _ => false,
401            },
402            Self::Variant { tag, all: _, binds } => match v {
403                Value::Array(a) => {
404                    a.len() == binds.len() + 1
405                        && match &a[0] {
406                            Value::String(s) => s == tag,
407                            _ => false,
408                        }
409                        && binds.iter().zip(a[1..].iter()).all(|(b, v)| b.is_match(v))
410                }
411                _ => false,
412            },
413            Self::SlicePrefix { all: _, prefix, tail: _ } => match v {
414                Value::Array(a) => {
415                    a.len() >= prefix.len()
416                        && prefix.iter().zip(a.iter()).all(|(b, v)| b.is_match(v))
417                }
418                _ => false,
419            },
420            Self::SliceSuffix { all: _, head: _, suffix } => match v {
421                Value::Array(a) => {
422                    a.len() >= suffix.len()
423                        && suffix
424                            .iter()
425                            .zip(a.iter().skip(a.len() - suffix.len()))
426                            .all(|(b, v)| b.is_match(v))
427                }
428                _ => false,
429            },
430            Self::Struct { all: _, binds } => match v {
431                Value::Array(a) => {
432                    a.len() >= binds.len()
433                        && binds.iter().all(|(_, i, p)| match a.get(*i) {
434                            Some(Value::Array(a)) if a.len() == 2 => p.is_match(&a[1]),
435                            _ => false,
436                        })
437                }
438                _ => false,
439            },
440        }
441    }
442
443    pub fn is_refutable(&self) -> bool {
444        match &self {
445            Self::Bind(_) | Self::Ignore => false,
446            Self::Literal(_) => true,
447            Self::Slice { tuple: true, all: _, binds } => {
448                binds.iter().any(|p| p.is_refutable())
449            }
450            Self::Struct { all: _, binds } => {
451                binds.iter().any(|(_, _, p)| p.is_refutable())
452            }
453            Self::Variant { .. }
454            | Self::Slice { tuple: false, .. }
455            | Self::SlicePrefix { .. }
456            | Self::SliceSuffix { .. } => true,
457        }
458    }
459}
460
461#[derive(Debug)]
462pub struct PatternNode<C: Ctx, E: UserEvent> {
463    pub type_predicate: Type<NoRefs>,
464    pub structure_predicate: StructPatternNode,
465    pub guard: Option<Cached<C, E>>,
466}
467
468impl<C: Ctx, E: UserEvent> PatternNode<C, E> {
469    pub(super) fn compile(
470        ctx: &mut ExecCtx<C, E>,
471        spec: &Pattern,
472        scope: &ModPath,
473        top_id: ExprId,
474    ) -> Result<Self> {
475        let type_predicate = match &spec.type_predicate {
476            Some(t) => t.resolve_typerefs(scope, &ctx.env)?,
477            None => spec.structure_predicate.infer_type_predicate(),
478        };
479        match &type_predicate {
480            Type::Fn(_) => bail!("can't match on Fn type"),
481            Type::Bottom(_)
482            | Type::Primitive(_)
483            | Type::Set(_)
484            | Type::TVar(_)
485            | Type::Array(_)
486            | Type::Tuple(_)
487            | Type::Variant(_, _)
488            | Type::Struct(_) => (),
489            Type::Ref(_) => unreachable!(),
490        }
491        let structure_predicate = StructPatternNode::compile(
492            ctx,
493            &type_predicate,
494            &spec.structure_predicate,
495            scope,
496        )?;
497        let guard = spec
498            .guard
499            .as_ref()
500            .map(|g| compiler::compile(ctx, g.clone(), &scope, top_id))
501            .transpose()?
502            .map(Cached::new);
503        Ok(PatternNode { type_predicate, structure_predicate, guard })
504    }
505
506    pub(super) fn bind_event(&self, event: &mut Event<E>, v: &Value) {
507        self.structure_predicate.bind(v, &mut |id, v| {
508            event.variables.insert(id, v);
509        })
510    }
511
512    pub(super) fn unbind_event(&self, event: &mut Event<E>) {
513        self.structure_predicate.unbind(&mut |id| {
514            event.variables.remove(&id);
515        })
516    }
517
518    pub(super) fn update(
519        &mut self,
520        ctx: &mut ExecCtx<C, E>,
521        event: &mut Event<E>,
522    ) -> bool {
523        match &mut self.guard {
524            None => false,
525            Some(g) => g.update(ctx, event),
526        }
527    }
528
529    pub(super) fn is_match(&self, typ: Typ, v: &Value) -> bool {
530        let tmatch = match (&self.type_predicate, typ) {
531            (Type::Array(_), Typ::Array)
532            | (Type::Tuple(_), Typ::Array)
533            | (Type::Struct(_), Typ::Array)
534            | (Type::Variant(_, _), Typ::Array | Typ::String) => true,
535            _ => self.type_predicate.contains(&Type::Primitive(typ.into())),
536        };
537        tmatch
538            && self.structure_predicate.is_match(v)
539            && match &self.guard {
540                None => true,
541                Some(g) => g
542                    .cached
543                    .as_ref()
544                    .and_then(|v| v.clone().get_as::<bool>())
545                    .unwrap_or(false),
546            }
547    }
548}