use std::collections::HashMap;
use miden_crypto::field::Field;
use super::ir::{NodeId, NodeKind};
use crate::layout::InputKey;
/// Builder for an expression DAG made of `NodeKind` values.
///
/// Nodes are appended to an internal list and indexed by a lookup table, so
/// asking for a node that is equal to one already stored yields the existing
/// `NodeId` instead of a duplicate entry.
#[derive(Debug)]
pub struct DagBuilder<EF> {
    /// Every node created so far, in creation order.
    nodes: Vec<NodeKind<EF>>,
    /// Lookup table from a stored node to the id it was assigned.
    cache: HashMap<NodeKind<EF>, NodeId>,
}
impl<EF> DagBuilder<EF>
where
    EF: Field,
{
    /// Creates a builder that holds no nodes.
    pub fn new() -> Self {
        Self { nodes: Vec::new(), cache: HashMap::new() }
    }

    /// Consumes the builder and returns its nodes in creation order.
    pub fn into_nodes(self) -> Vec<NodeKind<EF>> {
        self.nodes
    }

    /// Returns the id of the input node for `key`, creating the node on first use.
    pub fn input(&mut self, key: InputKey) -> NodeId {
        self.intern(NodeKind::Input(key))
    }

    /// Returns the id of the constant node holding `value`, creating it on first use.
    pub fn constant(&mut self, value: EF) -> NodeId {
        self.intern(NodeKind::Constant(value))
    }

    /// Returns the id of a node representing `a + b`.
    ///
    /// Simplifications applied, in order:
    /// - two constants are folded into a single constant;
    /// - a zero operand is dropped and the other operand is returned as-is.
    ///
    /// Otherwise the operands are ordered by id before interning, so `add(a, b)`
    /// and `add(b, a)` resolve to the same node.
    pub fn add(&mut self, a: NodeId, b: NodeId) -> NodeId {
        if let (Some(x), Some(y)) = (self.const_value(a), self.const_value(b)) {
            return self.constant(x + y);
        }
        if self.is_zero(a) {
            return b;
        }
        if self.is_zero(b) {
            return a;
        }
        // Canonical operand order for a commutative operation.
        let (l, r) = if a <= b { (a, b) } else { (b, a) };
        self.intern(NodeKind::Add(l, r))
    }

    /// Returns the id of a node representing `a - b`.
    ///
    /// Two constants are folded into a single constant, and subtracting a zero
    /// constant returns `a` unchanged. Operand order is kept as given.
    pub fn sub(&mut self, a: NodeId, b: NodeId) -> NodeId {
        if let (Some(x), Some(y)) = (self.const_value(a), self.const_value(b)) {
            return self.constant(x - y);
        }
        if self.is_zero(b) {
            return a;
        }
        self.intern(NodeKind::Sub(a, b))
    }

    /// Returns the id of a node representing `a * b`.
    ///
    /// Simplifications applied, in order:
    /// - two constants are folded into a single constant;
    /// - a zero operand yields the zero constant;
    /// - a one operand is dropped and the other operand is returned as-is.
    ///
    /// Otherwise the operands are ordered by id before interning, so `mul(a, b)`
    /// and `mul(b, a)` resolve to the same node.
    pub fn mul(&mut self, a: NodeId, b: NodeId) -> NodeId {
        if let (Some(x), Some(y)) = (self.const_value(a), self.const_value(b)) {
            return self.constant(x * y);
        }
        if self.is_zero(a) || self.is_zero(b) {
            return self.constant(EF::ZERO);
        }
        if self.is_one(a) {
            return b;
        }
        if self.is_one(b) {
            return a;
        }
        // Canonical operand order for a commutative operation.
        let (l, r) = if a <= b { (a, b) } else { (b, a) };
        self.intern(NodeKind::Mul(l, r))
    }

    /// Returns the id of a node representing `-a`, folding the negation when
    /// `a` is a constant.
    pub fn neg(&mut self, a: NodeId) -> NodeId {
        if let Some(x) = self.const_value(a) {
            return self.constant(-x);
        }
        self.intern(NodeKind::Neg(a))
    }

    /// Returns the value stored in node `id` if it is a constant.
    ///
    /// Yields `None` for any other node kind and for an id that does not
    /// index into this builder's node list.
    fn const_value(&self, id: NodeId) -> Option<EF> {
        match self.nodes.get(id.index())? {
            NodeKind::Constant(v) => Some(*v),
            _ => None,
        }
    }

    /// Reports whether node `id` is a constant equal to `EF::ZERO`.
    fn is_zero(&self, id: NodeId) -> bool {
        self.const_value(id).is_some_and(|v| v == EF::ZERO)
    }

    /// Reports whether node `id` is a constant equal to `EF::ONE`.
    fn is_one(&self, id: NodeId) -> bool {
        self.const_value(id).is_some_and(|v| v == EF::ONE)
    }

    /// Returns the id of `node`, storing it first if no equal node exists yet.
    ///
    /// A new node is appended to the node list and receives the list length
    /// before the append as its id.
    fn intern(&mut self, node: NodeKind<EF>) -> NodeId {
        use std::collections::hash_map::Entry;

        // The entry API hashes `node` once for both the lookup and the insert.
        match self.cache.entry(node) {
            Entry::Occupied(slot) => *slot.get(),
            Entry::Vacant(slot) => {
                let id = NodeId(self.nodes.len());
                // The map keeps the original as its key; the list gets the copy.
                self.nodes.push(slot.key().clone());
                slot.insert(id);
                id
            }
        }
    }
}
impl<EF> Default for DagBuilder<EF>
where
EF: Field,
{
fn default() -> Self {
Self::new()
}
}