Skip to main content

sp1_hypercube/shape/
mod.rs

1//! Utilities for working with shapes
2
3// TODO: Deprecate the rest of this module.
4
5mod cluster;
6mod ordered;
7
8pub use cluster::*;
9pub use ordered::*;
10
11use itertools::Itertools;
12use slop_matrix::{dense::RowMajorMatrix, Matrix};
13
14use std::{fmt::Debug, hash::Hash, str::FromStr};
15
16use deepsize2::DeepSizeOf;
17use serde::{Deserialize, Serialize};
18use slop_algebra::PrimeField;
19use std::collections::{hash_map::IntoIter, HashMap, HashSet};
20
21use crate::air::MachineAir;
22
23/// A way to keep track of the log2 heights of some set of chips.
24#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, DeepSizeOf)]
25pub struct Shape<K: Clone + Eq + Hash> {
26    /// The nonzero log2 heights of each chip.
27    pub inner: HashMap<K, usize>,
28}
29
30// Manual `impl` to remove bound `K: Default`.
31impl<K: Clone + Eq + Hash> Default for Shape<K> {
32    fn default() -> Self {
33        Self { inner: HashMap::default() }
34    }
35}
36
37impl<K: Clone + Eq + Hash + FromStr> Shape<K> {
38    /// Create a new empty shape.
39    #[must_use]
40    pub fn new(inner: HashMap<K, usize>) -> Self {
41        Self { inner }
42    }
43
44    /// Create a shape from a list of log2 heights.
45    #[must_use]
46    pub fn from_log2_heights(log2_heights: &[(K, usize)]) -> Self {
47        Self { inner: log2_heights.iter().map(|(k, h)| (k.clone(), *h)).collect() }
48    }
49
50    /// Create a shape from a list of traces.
51    #[must_use]
52    pub fn from_traces<V: Clone + Send + Sync>(traces: &[(K, RowMajorMatrix<V>)]) -> Self {
53        Self {
54            inner: traces
55                .iter()
56                .map(|(name, trace)| (name.clone(), trace.height().ilog2() as usize))
57                .sorted_by_key(|(_, height)| *height)
58                .collect(),
59        }
60    }
61
62    /// The number of chips in the shape.
63    #[must_use]
64    pub fn len(&self) -> usize {
65        self.inner.len()
66    }
67
68    /// Whether the shape is empty.
69    #[must_use]
70    pub fn is_empty(&self) -> bool {
71        self.inner.is_empty()
72    }
73
74    /// Get the height of a given key.
75    pub fn height(&self, key: &K) -> Option<usize> {
76        self.inner.get(key).map(|height| 1 << *height)
77    }
78
79    /// Get the log2 height of a given key.
80    pub fn log2_height(&self, key: &K) -> Option<usize> {
81        self.inner.get(key).copied()
82    }
83
84    /// Whether the shape includes a given key.
85    pub fn contains(&self, key: &K) -> bool {
86        self.inner.contains_key(key)
87    }
88
89    /// Insert a key-height pair into the shape.
90    pub fn insert(&mut self, key: K, height: usize) {
91        self.inner.insert(key, height);
92    }
93
94    /// Whether the shape includes a given AIR.
95    ///
96    /// TODO: Deprecate by adding `air.id()`.
97    pub fn included<F: PrimeField, A: MachineAir<F>>(&self, air: &A) -> bool
98    where
99        <K as FromStr>::Err: std::fmt::Debug,
100    {
101        self.inner.contains_key(&K::from_str(air.name()).unwrap())
102    }
103
104    /// Get an iterator over the shape.
105    pub fn iter(&self) -> impl Iterator<Item = (&K, &usize)> {
106        self.inner.iter().sorted_by_key(|(_, v)| *v)
107    }
108
109    /// Estimate the lde size.
110    ///
111    /// WARNING: This is a heuristic, it may not be completely accurate. To be 100% sure that they
112    /// OOM, you should run the shape through the prover.
113    #[must_use]
114    pub fn estimate_lde_size(&self, costs: &HashMap<K, usize>) -> usize {
115        self.iter().map(|(k, h)| costs[k] * (1 << h)).sum()
116    }
117}
118
119impl<K: Clone + Eq + Hash> Extend<Shape<K>> for Shape<K> {
120    fn extend<T: IntoIterator<Item = Shape<K>>>(&mut self, iter: T) {
121        for shape in iter {
122            self.inner.extend(shape.inner);
123        }
124    }
125}
126
127impl<K: Clone + Eq + Hash> Extend<(K, usize)> for Shape<K> {
128    fn extend<T: IntoIterator<Item = (K, usize)>>(&mut self, iter: T) {
129        self.inner.extend(iter);
130    }
131}
132
133impl<K: Clone + Eq + Hash + FromStr> FromIterator<(K, usize)> for Shape<K> {
134    fn from_iter<T: IntoIterator<Item = (K, usize)>>(iter: T) -> Self {
135        Self { inner: iter.into_iter().collect() }
136    }
137}
138
139impl<K: Clone + Eq + Hash> IntoIterator for Shape<K> {
140    type Item = (K, usize);
141    type IntoIter = IntoIter<K, usize>;
142
143    fn into_iter(self) -> Self::IntoIter {
144        self.inner.into_iter()
145    }
146}
147
148impl<K: Clone + Eq + Hash> PartialOrd for Shape<K> {
149    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
150        let set = self.inner.keys().collect::<HashSet<_>>();
151        let other_set = other.inner.keys().collect::<HashSet<_>>();
152
153        if self == other {
154            return Some(std::cmp::Ordering::Equal);
155        }
156
157        if set.is_subset(&other_set) {
158            let mut less_seen = false;
159            let mut greater_seen = false;
160            for (name, &height) in self.inner.iter() {
161                let other_height = other.inner[name];
162                match height.cmp(&other_height) {
163                    std::cmp::Ordering::Less => less_seen = true,
164                    std::cmp::Ordering::Greater => greater_seen = true,
165                    std::cmp::Ordering::Equal => {}
166                }
167            }
168            if less_seen && greater_seen {
169                return None;
170            }
171
172            if less_seen {
173                return Some(std::cmp::Ordering::Less);
174            }
175        }
176
177        if other_set.is_subset(&set) {
178            let mut less_seen = false;
179            let mut greater_seen = false;
180            for (name, &other_height) in other.inner.iter() {
181                let height = self.inner[name];
182                match height.cmp(&other_height) {
183                    std::cmp::Ordering::Less => less_seen = true,
184                    std::cmp::Ordering::Greater => greater_seen = true,
185                    std::cmp::Ordering::Equal => {}
186                }
187            }
188            if less_seen && greater_seen {
189                return None;
190            }
191
192            if greater_seen {
193                return Some(std::cmp::Ordering::Greater);
194            }
195        }
196
197        None
198    }
199}