sp1_hypercube/shape/
mod.rs1mod cluster;
6mod ordered;
7
8pub use cluster::*;
9pub use ordered::*;
10
11use itertools::Itertools;
12use slop_matrix::{dense::RowMajorMatrix, Matrix};
13
14use std::{fmt::Debug, hash::Hash, str::FromStr};
15
16use deepsize2::DeepSizeOf;
17use serde::{Deserialize, Serialize};
18use slop_algebra::PrimeField;
19use std::collections::{hash_map::IntoIter, HashMap, HashSet};
20
21use crate::air::MachineAir;
22
23#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, DeepSizeOf)]
25pub struct Shape<K: Clone + Eq + Hash> {
26 pub inner: HashMap<K, usize>,
28}
29
30impl<K: Clone + Eq + Hash> Default for Shape<K> {
32 fn default() -> Self {
33 Self { inner: HashMap::default() }
34 }
35}
36
37impl<K: Clone + Eq + Hash + FromStr> Shape<K> {
38 #[must_use]
40 pub fn new(inner: HashMap<K, usize>) -> Self {
41 Self { inner }
42 }
43
44 #[must_use]
46 pub fn from_log2_heights(log2_heights: &[(K, usize)]) -> Self {
47 Self { inner: log2_heights.iter().map(|(k, h)| (k.clone(), *h)).collect() }
48 }
49
50 #[must_use]
52 pub fn from_traces<V: Clone + Send + Sync>(traces: &[(K, RowMajorMatrix<V>)]) -> Self {
53 Self {
54 inner: traces
55 .iter()
56 .map(|(name, trace)| (name.clone(), trace.height().ilog2() as usize))
57 .sorted_by_key(|(_, height)| *height)
58 .collect(),
59 }
60 }
61
62 #[must_use]
64 pub fn len(&self) -> usize {
65 self.inner.len()
66 }
67
68 #[must_use]
70 pub fn is_empty(&self) -> bool {
71 self.inner.is_empty()
72 }
73
74 pub fn height(&self, key: &K) -> Option<usize> {
76 self.inner.get(key).map(|height| 1 << *height)
77 }
78
79 pub fn log2_height(&self, key: &K) -> Option<usize> {
81 self.inner.get(key).copied()
82 }
83
84 pub fn contains(&self, key: &K) -> bool {
86 self.inner.contains_key(key)
87 }
88
89 pub fn insert(&mut self, key: K, height: usize) {
91 self.inner.insert(key, height);
92 }
93
94 pub fn included<F: PrimeField, A: MachineAir<F>>(&self, air: &A) -> bool
98 where
99 <K as FromStr>::Err: std::fmt::Debug,
100 {
101 self.inner.contains_key(&K::from_str(air.name()).unwrap())
102 }
103
104 pub fn iter(&self) -> impl Iterator<Item = (&K, &usize)> {
106 self.inner.iter().sorted_by_key(|(_, v)| *v)
107 }
108
109 #[must_use]
114 pub fn estimate_lde_size(&self, costs: &HashMap<K, usize>) -> usize {
115 self.iter().map(|(k, h)| costs[k] * (1 << h)).sum()
116 }
117}
118
119impl<K: Clone + Eq + Hash> Extend<Shape<K>> for Shape<K> {
120 fn extend<T: IntoIterator<Item = Shape<K>>>(&mut self, iter: T) {
121 for shape in iter {
122 self.inner.extend(shape.inner);
123 }
124 }
125}
126
127impl<K: Clone + Eq + Hash> Extend<(K, usize)> for Shape<K> {
128 fn extend<T: IntoIterator<Item = (K, usize)>>(&mut self, iter: T) {
129 self.inner.extend(iter);
130 }
131}
132
133impl<K: Clone + Eq + Hash + FromStr> FromIterator<(K, usize)> for Shape<K> {
134 fn from_iter<T: IntoIterator<Item = (K, usize)>>(iter: T) -> Self {
135 Self { inner: iter.into_iter().collect() }
136 }
137}
138
139impl<K: Clone + Eq + Hash> IntoIterator for Shape<K> {
140 type Item = (K, usize);
141 type IntoIter = IntoIter<K, usize>;
142
143 fn into_iter(self) -> Self::IntoIter {
144 self.inner.into_iter()
145 }
146}
147
148impl<K: Clone + Eq + Hash> PartialOrd for Shape<K> {
149 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
150 let set = self.inner.keys().collect::<HashSet<_>>();
151 let other_set = other.inner.keys().collect::<HashSet<_>>();
152
153 if self == other {
154 return Some(std::cmp::Ordering::Equal);
155 }
156
157 if set.is_subset(&other_set) {
158 let mut less_seen = false;
159 let mut greater_seen = false;
160 for (name, &height) in self.inner.iter() {
161 let other_height = other.inner[name];
162 match height.cmp(&other_height) {
163 std::cmp::Ordering::Less => less_seen = true,
164 std::cmp::Ordering::Greater => greater_seen = true,
165 std::cmp::Ordering::Equal => {}
166 }
167 }
168 if less_seen && greater_seen {
169 return None;
170 }
171
172 if less_seen {
173 return Some(std::cmp::Ordering::Less);
174 }
175 }
176
177 if other_set.is_subset(&set) {
178 let mut less_seen = false;
179 let mut greater_seen = false;
180 for (name, &other_height) in other.inner.iter() {
181 let height = self.inner[name];
182 match height.cmp(&other_height) {
183 std::cmp::Ordering::Less => less_seen = true,
184 std::cmp::Ordering::Greater => greater_seen = true,
185 std::cmp::Ordering::Equal => {}
186 }
187 }
188 if less_seen && greater_seen {
189 return None;
190 }
191
192 if greater_seen {
193 return Some(std::cmp::Ordering::Greater);
194 }
195 }
196
197 None
198 }
199}