bend/fun/net_to_term.rs

1use crate::{
2  diagnostics::{DiagnosticOrigin, Diagnostics, Severity},
3  fun::{term_to_net::Labels, Book, FanKind, Name, Num, Op, Pattern, Tag, Term},
4  maybe_grow,
5  net::{BendLab, CtrKind, INet, NodeId, NodeKind, Port, SlotId, ROOT},
6};
7use hvm::hvm::Numb;
8use std::collections::{BTreeSet, HashMap, HashSet};
9
10/// Converts an Interaction-INet to a Lambda Calculus term
11pub fn net_to_term(
12  net: &INet,
13  book: &Book,
14  labels: &Labels,
15  linear: bool,
16  diagnostics: &mut Diagnostics,
17) -> Term {
18  let mut reader = Reader {
19    net,
20    labels,
21    book,
22    recursive_defs: &book.recursive_defs(),
23    dup_paths: if linear { None } else { Some(Default::default()) },
24    scope: Default::default(),
25    seen_fans: Default::default(),
26    namegen: Default::default(),
27    seen: Default::default(),
28    errors: Default::default(),
29  };
30
31  let mut term = reader.read_term(net.enter_port(ROOT));
32
33  while let Some(node) = reader.scope.pop_first() {
34    let val = reader.read_term(reader.net.enter_port(Port(node, 0)));
35    let fst = reader.namegen.decl_name(net, Port(node, 1));
36    let snd = reader.namegen.decl_name(net, Port(node, 2));
37
38    let (fan, tag) = match reader.net.node(node).kind {
39      NodeKind::Ctr(CtrKind::Tup(lab)) => (FanKind::Tup, reader.labels.tup.to_tag(lab)),
40      NodeKind::Ctr(CtrKind::Dup(lab)) => (FanKind::Dup, reader.labels.dup.to_tag(Some(lab))),
41      _ => unreachable!(),
42    };
43
44    let split = &mut Split { fan, tag, fst, snd, val };
45
46    let uses = term.insert_split(split, usize::MAX).unwrap();
47    let result = term.insert_split(split, uses);
48    debug_assert_eq!(result, None);
49  }
50
51  reader.report_errors(diagnostics);
52
53  let mut unscoped = HashSet::new();
54  let mut scope = Vec::new();
55  term.collect_unscoped(&mut unscoped, &mut scope);
56  term.apply_unscoped(&unscoped);
57
58  term
59}
60
61// BTreeSet for consistent readback of dups
62type Scope = BTreeSet<NodeId>;
63
64pub struct Reader<'a> {
65  pub book: &'a Book,
66  pub namegen: NameGen,
67  net: &'a INet,
68  labels: &'a Labels,
69  dup_paths: Option<HashMap<u16, Vec<SlotId>>>,
70  /// Store for floating/unscoped terms, like dups and let tups.
71  scope: Scope,
72  // To avoid reinserting things in the scope.
73  seen_fans: Scope,
74  seen: HashSet<Port>,
75  errors: Vec<ReadbackError>,
76  recursive_defs: &'a BTreeSet<Name>,
77}
78
79impl Reader<'_> {
80  fn read_term(&mut self, next: Port) -> Term {
81    use CtrKind::*;
82
83    maybe_grow(|| {
84      if !self.seen.insert(next) && self.dup_paths.is_none() {
85        self.error(ReadbackError::Cyclic);
86        return Term::Var { nam: Name::new("...") };
87      }
88
89      let node = next.node_id();
90      match &self.net.node(node).kind {
91        NodeKind::Era => Term::Era,
92        NodeKind::Ctr(CtrKind::Con(lab)) => self.read_con(next, *lab),
93        NodeKind::Swi => self.read_swi(next),
94        NodeKind::Ref { def_name } => Term::Ref { nam: def_name.clone() },
95        NodeKind::Ctr(kind @ (Dup(_) | Tup(_))) => self.read_fan(next, *kind),
96        NodeKind::Num { val } => num_from_bits_with_type(*val, *val),
97        NodeKind::Opr => self.read_opr(next),
98        NodeKind::Rot => {
99          self.error(ReadbackError::ReachedRoot);
100          Term::Err
101        }
102      }
103    })
104  }
105
106  /// Reads a term from a CON node.
107  /// Could be a lambda, an application, a CON tuple or a CON tuple elimination.
108  fn read_con(&mut self, next: Port, label: Option<BendLab>) -> Term {
109    let node = next.node_id();
110    match next.slot() {
111      // If we're visiting a port 0, then it is a tuple or a lambda.
112      0 => {
113        if self.is_tup(node) {
114          // A tuple
115          let lft = self.read_term(self.net.enter_port(Port(node, 1)));
116          let rgt = self.read_term(self.net.enter_port(Port(node, 2)));
117          Term::Fan { fan: FanKind::Tup, tag: self.labels.con.to_tag(label), els: vec![lft, rgt] }
118        } else {
119          // A lambda
120          let nam = self.namegen.decl_name(self.net, Port(node, 1));
121          let bod = self.read_term(self.net.enter_port(Port(node, 2)));
122          Term::Lam {
123            tag: self.labels.con.to_tag(label),
124            pat: Box::new(Pattern::Var(nam)),
125            bod: Box::new(bod),
126          }
127        }
128      }
129      // If we're visiting a port 1, then it is a variable.
130      1 => Term::Var { nam: self.namegen.var_name(next) },
131      // If we're visiting a port 2, then it is an application.
132      2 => {
133        let fun = self.read_term(self.net.enter_port(Port(node, 0)));
134        let arg = self.read_term(self.net.enter_port(Port(node, 1)));
135        Term::App { tag: self.labels.con.to_tag(label), fun: Box::new(fun), arg: Box::new(arg) }
136      }
137      _ => unreachable!(),
138    }
139  }
140
141  /// Reads a fan term from a DUP node.
142  /// Could be a superposition, a duplication, a DUP tuple or a DUP tuple elimination.
143  fn read_fan(&mut self, next: Port, kind: CtrKind) -> Term {
144    let node = next.node_id();
145    let (fan, lab) = match kind {
146      CtrKind::Tup(lab) => (FanKind::Tup, lab),
147      CtrKind::Dup(lab) => (FanKind::Dup, Some(lab)),
148      _ => unreachable!(),
149    };
150    match next.slot() {
151      // If we're visiting a port 0, then it is a pair.
152      0 => {
153        // If this superposition is in a readback path with a paired Dup,
154        // we resolve it by splitting the two sup values into the two Dup variables.
155        // If we find that it's not paired with a Dup, we just keep the Sup as a term.
156        // The latter are all the early returns.
157
158        if fan != FanKind::Dup {
159          return self.decay_or_get_ports(node).unwrap_or_else(|(fst, snd)| Term::Fan {
160            fan,
161            tag: self.labels[fan].to_tag(lab),
162            els: vec![fst, snd],
163          });
164        }
165
166        let Some(dup_paths) = &mut self.dup_paths else {
167          return self.decay_or_get_ports(node).unwrap_or_else(|(fst, snd)| Term::Fan {
168            fan,
169            tag: self.labels[fan].to_tag(lab),
170            els: vec![fst, snd],
171          });
172        };
173
174        let stack = dup_paths.entry(lab.unwrap()).or_default();
175        let Some(slot) = stack.pop() else {
176          return self.decay_or_get_ports(node).unwrap_or_else(|(fst, snd)| Term::Fan {
177            fan,
178            tag: self.labels[fan].to_tag(lab),
179            els: vec![fst, snd],
180          });
181        };
182
183        // Found a paired Dup, so we "decay" the superposition according to the original direction we came from the Dup.
184        let term = self.read_term(self.net.enter_port(Port(node, slot)));
185        self.dup_paths.as_mut().unwrap().get_mut(&lab.unwrap()).unwrap().push(slot);
186        term
187      }
188      // If we're visiting a port 1 or 2, then it is a variable.
189      // Also, that means we found a dup, so we store it to read later.
190      1 | 2 => {
191        // If doing non-linear readback, we also store dup paths to try to resolve them later.
192        if let Some(dup_paths) = &mut self.dup_paths {
193          if fan == FanKind::Dup {
194            dup_paths.entry(lab.unwrap()).or_default().push(next.slot());
195            let term = self.read_term(self.net.enter_port(Port(node, 0)));
196            self.dup_paths.as_mut().unwrap().entry(lab.unwrap()).or_default().pop().unwrap();
197            return term;
198          }
199        }
200        // Otherwise, just store the new dup/let tup and return the variable.
201        if self.seen_fans.insert(node) {
202          self.scope.insert(node);
203        }
204        Term::Var { nam: self.namegen.var_name(next) }
205      }
206      _ => unreachable!(),
207    }
208  }
209
210  /// Reads an Opr term from an OPR node.
211  fn read_opr(&mut self, next: Port) -> Term {
212    /// Read one of the argument ports of an operation.
213    fn add_arg(
214      reader: &mut Reader,
215      port: Port,
216      args: &mut Vec<Result<hvm::hvm::Val, Term>>,
217      types: &mut Vec<hvm::hvm::Tag>,
218      ops: &mut Vec<hvm::hvm::Tag>,
219    ) {
220      if let NodeKind::Num { val } = reader.net.node(port.node_id()).kind {
221        match hvm::hvm::Numb::get_typ(&Numb(val)) {
222          // Contains an operation
223          hvm::hvm::TY_SYM => {
224            ops.push(hvm::hvm::Numb(val).get_sym());
225          }
226          // Contains a number with a type
227          typ @ hvm::hvm::TY_U24..=hvm::hvm::TY_F24 => {
228            types.push(typ);
229            args.push(Ok(val));
230          }
231          // Contains a partially applied number with operation and no type
232          op @ hvm::hvm::OP_ADD.. => {
233            ops.push(op);
234            args.push(Ok(val));
235          }
236        }
237      } else {
238        // Some other non-number argument
239        let term = reader.read_term(port);
240        args.push(Err(term));
241      }
242    }
243
244    /// Creates an Opr term from the arguments of the subnet of an OPR node.
245    fn opr_term_from_hvm_args(
246      args: &mut Vec<Result<hvm::hvm::Val, Term>>,
247      types: &mut Vec<hvm::hvm::Tag>,
248      ops: &mut Vec<hvm::hvm::Tag>,
249      is_flipped: bool,
250    ) -> Term {
251      let typ = match types.as_slice() {
252        [typ] => *typ,
253        // Use U24 as default number type
254        [] => hvm::hvm::TY_U24,
255        _ => {
256          // Too many types
257          return Term::Err;
258        }
259      };
260      match (args.as_slice(), ops.as_slice()) {
261        ([arg1, arg2], [op]) => {
262          // Correct number of arguments
263          let arg1 = match arg1 {
264            Ok(val) => num_from_bits_with_type(*val, typ as u32),
265            Err(val) => val.clone(),
266          };
267          let arg2 = match arg2 {
268            Ok(val) => num_from_bits_with_type(*val, typ as u32),
269            Err(val) => val.clone(),
270          };
271          let (arg1, arg2) = if is_flipped ^ op_is_flipped(*op) { (arg2, arg1) } else { (arg1, arg2) };
272          let Some(op) = op_from_native_tag(*op, typ) else {
273            // Invalid operator
274            return Term::Err;
275          };
276          Term::Oper { opr: op, fst: Box::new(arg1), snd: Box::new(arg2) }
277        }
278        _ => {
279          // Invalid number of arguments/types/operators
280          Term::Err
281        }
282      }
283    }
284
285    fn op_is_flipped(op: hvm::hvm::Tag) -> bool {
286      [hvm::hvm::FP_DIV, hvm::hvm::FP_REM, hvm::hvm::FP_SHL, hvm::hvm::FP_SHR, hvm::hvm::FP_SUB].contains(&op)
287    }
288
289    fn op_from_native_tag(val: hvm::hvm::Tag, typ: hvm::hvm::Tag) -> Option<Op> {
290      let op = match val {
291        hvm::hvm::OP_ADD => Op::ADD,
292        hvm::hvm::OP_SUB => Op::SUB,
293        hvm::hvm::FP_SUB => Op::SUB,
294        hvm::hvm::OP_MUL => Op::MUL,
295        hvm::hvm::OP_DIV => Op::DIV,
296        hvm::hvm::FP_DIV => Op::DIV,
297        hvm::hvm::OP_REM => Op::REM,
298        hvm::hvm::FP_REM => Op::REM,
299        hvm::hvm::OP_EQ => Op::EQ,
300        hvm::hvm::OP_NEQ => Op::NEQ,
301        hvm::hvm::OP_LT => Op::LT,
302        hvm::hvm::OP_GT => Op::GT,
303        hvm::hvm::OP_AND => {
304          if typ == hvm::hvm::TY_F24 {
305            todo!("Implement readback of atan2")
306          } else {
307            Op::AND
308          }
309        }
310        hvm::hvm::OP_OR => {
311          if typ == hvm::hvm::TY_F24 {
312            todo!("Implement readback of log")
313          } else {
314            Op::OR
315          }
316        }
317        hvm::hvm::OP_XOR => {
318          if typ == hvm::hvm::TY_F24 {
319            Op::POW
320          } else {
321            Op::XOR
322          }
323        }
324        hvm::hvm::OP_SHL => Op::SHL,
325        hvm::hvm::FP_SHL => Op::SHL,
326        hvm::hvm::OP_SHR => Op::SHR,
327        hvm::hvm::FP_SHR => Op::SHR,
328        _ => return None,
329      };
330      Some(op)
331    }
332
333    let node = next.node_id();
334    match next.slot() {
335      2 => {
336        // If port1 has a partially applied number, the operation has 1 node.
337        // Port0 has arg1 and port1 has arg2.
338        // The operation is interpreted as being pre-flipped (if its a FP_, they cancel and don't flip).
339        let port1_kind = self.net.node(self.net.enter_port(Port(node, 1)).node_id()).kind.clone();
340        if let NodeKind::Num { val } = port1_kind {
341          match hvm::hvm::Numb::get_typ(&Numb(val)) {
342            hvm::hvm::OP_ADD.. => {
343              let x1_port = self.net.enter_port(Port(node, 0));
344              let x2_port = self.net.enter_port(Port(node, 1));
345              let mut args = vec![];
346              let mut types = vec![];
347              let mut ops = vec![];
348              add_arg(self, x1_port, &mut args, &mut types, &mut ops);
349              add_arg(self, x2_port, &mut args, &mut types, &mut ops);
350              let term = opr_term_from_hvm_args(&mut args, &mut types, &mut ops, true);
351              if let Term::Err = term {
352                // Since that function doesn't have access to the reader, add the error here.
353                self.error(ReadbackError::InvalidNumericOp);
354              }
355              return term;
356            }
357            _ => {
358              // Not a partially applied number, handle it in the next case
359            }
360          }
361        }
362
363        // If port0 has a partially applied number, it also has 1 node.
364        // The operation is interpreted as not pre-flipped.
365        let port0_kind = self.net.node(self.net.enter_port(Port(node, 0)).node_id()).kind.clone();
366        if let NodeKind::Num { val } = port0_kind {
367          match hvm::hvm::Numb::get_typ(&Numb(val)) {
368            hvm::hvm::OP_ADD.. => {
369              let x1_port = self.net.enter_port(Port(node, 0));
370              let x2_port = self.net.enter_port(Port(node, 1));
371              let mut args = vec![];
372              let mut types = vec![];
373              let mut ops = vec![];
374              add_arg(self, x1_port, &mut args, &mut types, &mut ops);
375              add_arg(self, x2_port, &mut args, &mut types, &mut ops);
376              let term = opr_term_from_hvm_args(&mut args, &mut types, &mut ops, false);
377              if let Term::Err = term {
378                // Since that function doesn't have access to the reader, add the error here.
379                self.error(ReadbackError::InvalidNumericOp);
380              }
381              return term;
382            }
383            _ => {
384              // Not a partially applied number, handle it in the next case
385            }
386          }
387        }
388
389        // Otherwise, the operation has 2 nodes.
390        // Read the top node port0 and 1, bottom node port1.
391        // Args are in that order, skipping the operation.
392        let bottom_id = node;
393        let top_id = self.net.enter_port(Port(bottom_id, 0)).node_id();
394        if let NodeKind::Opr = self.net.node(top_id).kind {
395          let x1_port = self.net.enter_port(Port(top_id, 0));
396          let x2_port = self.net.enter_port(Port(top_id, 1));
397          let x3_port = self.net.enter_port(Port(bottom_id, 1));
398          let mut args = vec![];
399          let mut types = vec![];
400          let mut ops = vec![];
401          add_arg(self, x1_port, &mut args, &mut types, &mut ops);
402          add_arg(self, x2_port, &mut args, &mut types, &mut ops);
403          add_arg(self, x3_port, &mut args, &mut types, &mut ops);
404          let term = opr_term_from_hvm_args(&mut args, &mut types, &mut ops, false);
405          if let Term::Err = term {
406            self.error(ReadbackError::InvalidNumericOp);
407          }
408          term
409        } else {
410          // Port 0 was not an OPR node, invalid.
411          self.error(ReadbackError::InvalidNumericOp);
412          Term::Err
413        }
414      }
415      _ => {
416        // Entered from a port other than 2, invalid.
417        self.error(ReadbackError::InvalidNumericOp);
418        Term::Err
419      }
420    }
421  }
422
423  /// Reads a switch term from a SWI node.
424  fn read_swi(&mut self, next: Port) -> Term {
425    let node = next.node_id();
426    match next.slot() {
427      2 => {
428        // Read the matched expression
429        let arg = self.read_term(self.net.enter_port(Port(node, 0)));
430        let bnd = if let Term::Var { nam } = &arg { nam.clone() } else { self.namegen.unique() };
431
432        // Read the pattern matching node
433        let sel_node = self.net.enter_port(Port(node, 1)).node_id();
434
435        // We expect the pattern matching node to be a CON
436        let sel_kind = &self.net.node(sel_node).kind;
437        if sel_kind != &NodeKind::Ctr(CtrKind::Con(None)) {
438          // TODO: Is there any case where we expect a different node type here on readback?
439          self.error(ReadbackError::InvalidNumericMatch);
440          return Term::Err;
441        }
442
443        let zero = self.read_term(self.net.enter_port(Port(sel_node, 1)));
444        let mut succ = self.read_term(self.net.enter_port(Port(sel_node, 2)));
445        // Call expand_generated in case of succ_term be a lifted term
446        succ.expand_generated(self.book, self.recursive_defs);
447
448        // Succ term should be a lambda
449        let succ = match &mut succ {
450          Term::Lam { pat, bod, .. } => {
451            if let Pattern::Var(nam) = pat.as_ref() {
452              let mut bod = std::mem::take(bod.as_mut());
453              if let Some(nam) = nam {
454                bod.subst(nam, &Term::Var { nam: Name::new(format!("{bnd}-1")) });
455              }
456              bod
457            } else {
458              // Readback should never generate non-var patterns for lambdas.
459              self.error(ReadbackError::InvalidNumericMatch);
460              succ
461            }
462          }
463          _ => {
464            self.error(ReadbackError::InvalidNumericMatch);
465            succ
466          }
467        };
468        Term::Swt {
469          arg: Box::new(arg),
470          bnd: Some(bnd),
471          with_arg: vec![],
472          with_bnd: vec![],
473          pred: None,
474          arms: vec![zero, succ],
475        }
476      }
477      _ => {
478        self.error(ReadbackError::InvalidNumericMatch);
479        Term::Err
480      }
481    }
482  }
483
484  /// Enters both ports 1 and 2 of a node. Returns a Term if it is
485  /// possible to simplify the net, or the Terms on the two ports of the node.
486  /// The two possible outcomes are always equivalent.
487  ///
488  /// If:
489  ///  - The node Kind is CON/TUP/DUP
490  ///  - Both ports 1 and 2 are connected to the same node on slots 1 and 2 respectively
491  ///  - That node Kind is the same as the given node Kind
492  ///
493  /// Then:
494  ///   Reads the port 0 of the connected node, and returns that term.
495  ///
496  /// Otherwise:
497  ///   Returns the terms on ports 1 and 2 of the given node.
498  ///
499  /// # Example
500  ///
501  /// ```hvm
502  /// // λa let (a, b) = a; (a, b)
503  /// ([a b] [a b])
504  ///
505  /// // The node `(a, b)` is just a reconstruction of the destructuring of `a`,
506  /// // So we can skip both steps and just return the "value" unchanged:
507  ///
508  /// // λa a
509  /// (a a)
510  /// ```
511  ///
512  fn decay_or_get_ports(&mut self, node: NodeId) -> Result<Term, (Term, Term)> {
513    let fst_port = self.net.enter_port(Port(node, 1));
514    let snd_port = self.net.enter_port(Port(node, 2));
515
516    let node_kind = &self.net.node(node).kind;
517
518    // Eta-reduce the readback inet.
519    // This is not valid for all kinds of nodes, only CON/TUP/DUP, due to their interaction rules.
520    if matches!(node_kind, NodeKind::Ctr(_)) {
521      match (fst_port, snd_port) {
522        (Port(fst_node, 1), Port(snd_node, 2)) if fst_node == snd_node => {
523          if self.net.node(fst_node).kind == *node_kind {
524            self.scope.remove(&fst_node);
525
526            let port_zero = self.net.enter_port(Port(fst_node, 0));
527            let term = self.read_term(port_zero);
528            return Ok(term);
529          }
530        }
531        _ => {}
532      }
533    }
534
535    let fst = self.read_term(fst_port);
536    let snd = self.read_term(snd_port);
537    Err((fst, snd))
538  }
539
540  pub fn error(&mut self, error: ReadbackError) {
541    self.errors.push(error);
542  }
543
544  pub fn report_errors(&mut self, diagnostics: &mut Diagnostics) {
545    let mut err_counts = std::collections::HashMap::new();
546    for err in &self.errors {
547      *err_counts.entry(*err).or_insert(0) += 1;
548    }
549
550    for (err, count) in err_counts {
551      let count_msg = if count > 1 { format!(" ({count} occurrences)") } else { "".to_string() };
552      let msg = format!("{}{}", err, count_msg);
553      diagnostics.add_diagnostic(
554        msg.as_str(),
555        Severity::Warning,
556        DiagnosticOrigin::Readback,
557        Default::default(),
558      );
559    }
560  }
561
562  /// Returns whether the given port represents a tuple or some other
563  /// term (usually a lambda).
564  ///
565  /// Used heuristic: a con node is a tuple if port 1 is a closed tree and not an ERA.
566  fn is_tup(&self, node: NodeId) -> bool {
567    if !matches!(self.net.node(node).kind, NodeKind::Ctr(CtrKind::Con(_))) {
568      return false;
569    }
570    if self.net.node(self.net.enter_port(Port(node, 1)).node_id()).kind == NodeKind::Era {
571      return false;
572    }
573    let mut wires = HashSet::new();
574    let mut to_check = vec![self.net.enter_port(Port(node, 1))];
575    while let Some(port) = to_check.pop() {
576      match port.slot() {
577        0 => {
578          let node = port.node_id();
579          let lft = self.net.enter_port(Port(node, 1));
580          let rgt = self.net.enter_port(Port(node, 2));
581          to_check.push(lft);
582          to_check.push(rgt);
583        }
584        1 | 2 => {
585          // Mark as a wire. If already present, mark as visited by removing it.
586          if !(wires.insert(port) && wires.insert(self.net.enter_port(port))) {
587            wires.remove(&port);
588            wires.remove(&self.net.enter_port(port));
589          }
590        }
591        _ => unreachable!(),
592      }
593    }
594    // No hanging wires = a combinator = a tuple
595    wires.is_empty()
596  }
597}
598
599/* Utils for numbers and numeric operations */
600
601/// From an hvm number carrying the value and another carrying the type, return a Num term.
602fn num_from_bits_with_type(val: u32, typ: u32) -> Term {
603  match hvm::hvm::Numb::get_typ(&Numb(typ)) {
604    // No type information, assume u24 by default
605    hvm::hvm::TY_SYM => Term::Num { val: Num::U24(Numb::get_u24(&Numb(val))) },
606    hvm::hvm::TY_U24 => Term::Num { val: Num::U24(Numb::get_u24(&Numb(val))) },
607    hvm::hvm::TY_I24 => Term::Num { val: Num::I24(Numb::get_i24(&Numb(val))) },
608    hvm::hvm::TY_F24 => Term::Num { val: Num::F24(Numb::get_f24(&Numb(val))) },
609    _ => Term::Err,
610  }
611}
612
613/* Insertion of dups in the middle of the term */
614
615/// Represents `let #tag(fst, snd) = val` / `let #tag{fst snd} = val`
616struct Split {
617  fan: FanKind,
618  tag: Tag,
619  fst: Option<Name>,
620  snd: Option<Name>,
621  val: Term,
622}
623
624impl Default for Split {
625  fn default() -> Self {
626    Self {
627      fan: FanKind::Dup,
628      tag: Default::default(),
629      fst: Default::default(),
630      snd: Default::default(),
631      val: Default::default(),
632    }
633  }
634}
635
636impl Term {
637  /// Calculates the number of times `fst` and `snd` appear in this term. If
638  /// that is `>= threshold`, it inserts the split at this term, and returns
639  /// `None`. Otherwise, returns `Some(uses)`.
640  ///
641  /// This is only really useful when called in two passes – first, with
642  /// `threshold = usize::MAX`, to count the number of uses, and then with
643  /// `threshold = uses`.
644  ///
645  /// This has the effect of inserting the split at the lowest common ancestor
646  /// of all of the uses of `fst` and `snd`.
647  fn insert_split(&mut self, split: &mut Split, threshold: usize) -> Option<usize> {
648    maybe_grow(|| {
649      let mut n = match self {
650        Term::Var { nam } => usize::from(split.fst == *nam || split.snd == *nam),
651        _ => 0,
652      };
653      for child in self.children_mut() {
654        n += child.insert_split(split, threshold)?;
655      }
656
657      if n >= threshold {
658        let Split { fan, tag, fst, snd, val } = std::mem::take(split);
659        let nxt = Box::new(std::mem::take(self));
660        *self = Term::Let {
661          pat: Box::new(Pattern::Fan(fan, tag, vec![Pattern::Var(fst), Pattern::Var(snd)])),
662          val: Box::new(val),
663          nxt,
664        };
665        None
666      } else {
667        Some(n)
668      }
669    })
670  }
671}
672
673/* Variable name generation */
674
675#[derive(Default)]
676pub struct NameGen {
677  pub var_port_to_id: HashMap<Port, u64>,
678  pub id_counter: u64,
679}
680
681impl NameGen {
682  // Given a port, returns its name, or assigns one if it wasn't named yet.
683  fn var_name(&mut self, var_port: Port) -> Name {
684    let id = self.var_port_to_id.entry(var_port).or_insert_with(|| {
685      let id = self.id_counter;
686      self.id_counter += 1;
687      id
688    });
689    Name::from(*id)
690  }
691
692  fn decl_name(&mut self, net: &INet, var_port: Port) -> Option<Name> {
693    // If port is linked to an erase node, return an unused variable
694    let var_use = net.enter_port(var_port);
695    let var_kind = &net.node(var_use.node_id()).kind;
696    (*var_kind != NodeKind::Era).then(|| self.var_name(var_port))
697  }
698
699  pub fn unique(&mut self) -> Name {
700    let id = self.id_counter;
701    self.id_counter += 1;
702    Name::from(id)
703  }
704}
705
706/* Readback errors */
707
/// Problems that can be found while reading a net back into a term.
#[derive(Debug, Clone, Copy)]
pub enum ReadbackError {
  InvalidNumericMatch,
  InvalidNumericOp,
  ReachedRoot,
  Cyclic,
}

// Errors are compared and hashed by variant alone, ignoring any data a variant
// might carry — presumably so errors of the same kind are grouped together.
impl PartialEq for ReadbackError {
  fn eq(&self, other: &Self) -> bool {
    std::mem::discriminant(self).eq(&std::mem::discriminant(other))
  }
}

impl Eq for ReadbackError {}

impl std::hash::Hash for ReadbackError {
  fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
    std::hash::Hash::hash(&std::mem::discriminant(self), state);
  }
}

impl std::fmt::Display for ReadbackError {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    let msg = match self {
      ReadbackError::InvalidNumericMatch => "Encountered an invalid 'switch'.",
      ReadbackError::InvalidNumericOp => "Encountered an invalid numeric operation.",
      ReadbackError::ReachedRoot => {
        "Unable to interpret the HVM result as a valid Bend term. (Reached Root)"
      }
      ReadbackError::Cyclic => "Unable to interpret the HVM result as a valid Bend term. (Cyclic Term)",
    };
    f.write_str(msg)
  }
}
744
745/* Recover unscoped vars */
746
747impl Term {
748  pub fn collect_unscoped(&self, unscoped: &mut HashSet<Name>, scope: &mut Vec<Name>) {
749    maybe_grow(|| match self {
750      Term::Var { nam } if !scope.contains(nam) => _ = unscoped.insert(nam.clone()),
751      Term::Swt { arg, bnd, with_bnd: _, with_arg, pred: _, arms } => {
752        arg.collect_unscoped(unscoped, scope);
753        for arg in with_arg {
754          arg.collect_unscoped(unscoped, scope);
755        }
756        arms[0].collect_unscoped(unscoped, scope);
757        if let Some(bnd) = bnd {
758          scope.push(Name::new(format!("{bnd}-1")));
759        }
760        arms[1].collect_unscoped(unscoped, scope);
761        if bnd.is_some() {
762          scope.pop();
763        }
764      }
765      _ => {
766        for (child, binds) in self.children_with_binds() {
767          let binds: Vec<_> = binds.collect();
768          for bind in binds.iter().copied().flatten() {
769            scope.push(bind.clone());
770          }
771          child.collect_unscoped(unscoped, scope);
772          for _bind in binds.into_iter().flatten() {
773            scope.pop();
774          }
775        }
776      }
777    })
778  }
779
780  /// Transform the variables that we previously found were unscoped into their unscoped variants.
781  pub fn apply_unscoped(&mut self, unscoped: &HashSet<Name>) {
782    maybe_grow(|| {
783      if let Term::Var { nam } = self {
784        if unscoped.contains(nam) {
785          *self = Term::Link { nam: std::mem::take(nam) }
786        }
787      }
788      if let Some(pat) = self.pattern_mut() {
789        pat.apply_unscoped(unscoped);
790      }
791      for child in self.children_mut() {
792        child.apply_unscoped(unscoped);
793      }
794    })
795  }
796}
797
798impl Pattern {
799  fn apply_unscoped(&mut self, unscoped: &HashSet<Name>) {
800    maybe_grow(|| {
801      if let Pattern::Var(Some(nam)) = self {
802        if unscoped.contains(nam) {
803          let nam = std::mem::take(nam);
804          *self = Pattern::Chn(nam);
805        }
806      }
807      for child in self.children_mut() {
808        child.apply_unscoped(unscoped)
809      }
810    })
811  }
812}