use starkom_bluesky::Scalar;
use std::collections::{BTreeMap, BTreeSet, btree_map};
/// Rounds `n` up to the next power of two, with a floor of 2.
///
/// `0`, `1` and `2` all map to `2`; `3` maps to `4`, `5` to `8`, and so on.
///
/// # Panics
///
/// In debug builds, panics if the next power of two does not fit in `usize`
/// (inherited from [`usize::next_power_of_two`]).
pub(crate) fn padded_size(n: usize) -> usize {
    let rounded_up = n.next_power_of_two();
    rounded_up.max(2)
}
/// The five scalar coefficients attached to a single gate.
///
/// NOTE(review): the field names suggest the usual PLONK selector layout, i.e.
/// the gate enforces `ql*a + qr*b + qo*c + qm*a*b + qc = 0` for the left,
/// right and output wire values `a`, `b`, `c` — confirm against the code that
/// consumes this struct.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(crate) struct GateConstraint {
    // Presumably the left-input selector (q_L).
    pub(crate) ql: Scalar,
    // Presumably the right-input selector (q_R).
    pub(crate) qr: Scalar,
    // Presumably the output selector (q_O).
    pub(crate) qo: Scalar,
    // Presumably the multiplication selector (q_M).
    pub(crate) qm: Scalar,
    // Presumably the constant term (q_C).
    pub(crate) qc: Scalar,
}
/// One wire slot of the circuit: a given input/output position of a given gate.
///
/// The derived ordering compares the variant first (`LeftIn < RightIn < Out`)
/// and then the gate index.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum Wire {
    /// Left input of the gate with this index.
    LeftIn(usize),
    /// Right input of the gate with this index.
    RightIn(usize),
    /// Output of the gate with this index.
    Out(usize),
}

impl Wire {
    /// Returns the index of the gate this wire belongs to.
    pub fn gate(&self) -> usize {
        match *self {
            Self::LeftIn(gate) | Self::RightIn(gate) | Self::Out(gate) => gate,
        }
    }

    /// Returns this wire's position in a flat layout of `3 * n` slots:
    /// left inputs occupy `0..n`, right inputs `n..2n`, outputs `2n..3n`
    /// (assuming the gate index is below `n`).
    ///
    /// NOTE(review): the name suggests this indexes the copy-constraint
    /// permutation (sigma) — confirm against callers.
    pub(crate) fn sigma_index(&self, n: usize) -> usize {
        // Column of the flat layout this wire lives in.
        let column = match self {
            Self::LeftIn(_) => 0,
            Self::RightIn(_) => 1,
            Self::Out(_) => 2,
        };
        column * n + self.gate()
    }
}
/// Either a circuit [`Wire`] or a bare [`Scalar`] value.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum WireOrUnconstrained {
    /// A wire of the circuit.
    Wire(Wire),
    /// A standalone scalar value.
    Unconstrained(Scalar),
}

impl From<Wire> for WireOrUnconstrained {
    /// Wraps `wire` in the [`WireOrUnconstrained::Wire`] variant.
    fn from(wire: Wire) -> Self {
        Self::Wire(wire)
    }
}

impl From<Scalar> for WireOrUnconstrained {
    /// Wraps `value` in the [`WireOrUnconstrained::Unconstrained`] variant.
    fn from(value: Scalar) -> Self {
        Self::Unconstrained(value)
    }
}
/// Iterator over the wire sets held by a `WirePartitioning`, yielded in
/// ascending order of their internal node id.
pub(crate) struct NodeIterator<'a> {
    inner: btree_map::Iter<'a, usize, BTreeSet<Wire>>,
}

impl<'a> Iterator for NodeIterator<'a> {
    type Item = &'a BTreeSet<Wire>;

    fn next(&mut self) -> Option<Self::Item> {
        // Drop the node id; callers only need the wire set.
        self.inner.next().map(|(_id, wires)| wires)
    }

    // Forward the underlying map iterator's exact length so `collect()` and
    // `extend()` can preallocate instead of growing repeatedly.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

// `btree_map::Iter` is exact-sized, and `size_hint` is forwarded verbatim.
impl ExactSizeIterator for NodeIterator<'_> {}
/// Partition of circuit wires into groups ("nodes") of wires that have been
/// joined via [`WirePartitioning::connect`].
///
/// Connections are transitive: connecting `a`–`b` and then `b`–`c` places
/// `a`, `b` and `c` in the same node.
#[derive(Debug, Default, Clone)]
pub(crate) struct WirePartitioning {
    /// Id handed out to the next freshly created node.
    next_id: usize,
    /// Live nodes keyed by id; each value is the set of wires in that node.
    nodes: BTreeMap<usize, BTreeSet<Wire>>,
    /// Reverse index from each known wire to the id of the node holding it.
    /// Invariant: every id stored here is a key of `nodes`.
    node_by_wire: BTreeMap<Wire, usize>,
}

impl WirePartitioning {
    /// Records that `wire1` and `wire2` are connected.
    ///
    /// - If neither wire has been seen before, a new node `{wire1, wire2}` is created.
    /// - If only one has been seen, the other joins that wire's node.
    /// - If both belong to different nodes, the nodes are merged; `wire1`'s node
    ///   survives and every wire of the other node is re-pointed to it.
    /// - If both already share a node, nothing changes.
    pub(crate) fn connect(&mut self, wire1: Wire, wire2: Wire) {
        // Copy the ids out so no borrow of `node_by_wire` is held while we mutate it.
        let node1 = self.node_by_wire.get(&wire1).copied();
        let node2 = self.node_by_wire.get(&wire2).copied();
        match (node1, node2) {
            (Some(id1), Some(id2)) if id1 == id2 => {}
            (Some(id1), Some(id2)) => self.merge_nodes(id1, id2),
            (Some(id), None) => self.add_to_node(id, wire2),
            (None, Some(id)) => self.add_to_node(id, wire1),
            (None, None) => {
                let id = self.next_id;
                self.next_id += 1;
                self.nodes.insert(id, BTreeSet::from([wire1, wire2]));
                self.node_by_wire.insert(wire1, id);
                self.node_by_wire.insert(wire2, id);
            }
        }
    }

    /// Returns an iterator over every node's wire set, in ascending node-id order.
    pub(crate) fn iter_nodes(&self) -> NodeIterator<'_> {
        NodeIterator {
            inner: self.nodes.iter(),
        }
    }

    /// Adds `wire`, which is not yet in any node, to the existing node `id`.
    fn add_to_node(&mut self, id: usize, wire: Wire) {
        self.nodes
            .get_mut(&id)
            .expect("node_by_wire only references live nodes")
            .insert(wire);
        self.node_by_wire.insert(wire, id);
    }

    /// Moves every wire of node `from` into node `into` and deletes node `from`.
    fn merge_nodes(&mut self, into: usize, from: usize) {
        let mut moved = self
            .nodes
            .remove(&from)
            .expect("node_by_wire only references live nodes");
        // Re-point *all* moved wires, not just the one passed to `connect`;
        // otherwise they would keep referencing the deleted node id.
        for &wire in &moved {
            self.node_by_wire.insert(wire, into);
        }
        self.nodes
            .get_mut(&into)
            .expect("node_by_wire only references live nodes")
            .append(&mut moved);
    }
}