sp1_stark/
opts.rs

1use std::env;
2
3use serde::{Deserialize, Serialize};
4use sysinfo::System;
5
/// Shard size used by [`SP1CoreOpts::max`] when `SHARD_SIZE` is unset or cannot be parsed.
const MAX_SHARD_SIZE: usize = 1 << 21;
/// Shard size assigned by [`SP1CoreOpts::recursion`].
const RECURSION_MAX_SHARD_SIZE: usize = 1 << 22;
/// Shard batch size used by [`SP1CoreOpts::max`] when `SHARD_BATCH_SIZE` is unset or cannot be
/// parsed.
const MAX_SHARD_BATCH_SIZE: usize = 8;
/// Fallback for `TRACE_GEN_WORKERS`.
const DEFAULT_TRACE_GEN_WORKERS: usize = 1;
/// Fallback for `CHECKPOINTS_CHANNEL_CAPACITY`.
const DEFAULT_CHECKPOINTS_CHANNEL_CAPACITY: usize = 128;
/// Fallback for `RECORDS_AND_TRACES_CHANNEL_CAPACITY`.
const DEFAULT_RECORDS_AND_TRACES_CHANNEL_CAPACITY: usize = 1;
/// Baseline deferred split threshold passed to [`SplitOpts::new`].
const MAX_DEFERRED_SPLIT_THRESHOLD: usize = 1 << 15;
13
14/// Options to configure the SP1 prover for core and recursive proofs.
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
16pub struct SP1ProverOpts {
17    /// Options for the core prover.
18    pub core_opts: SP1CoreOpts,
19    /// Options for the recursion prover.
20    pub recursion_opts: SP1CoreOpts,
21}
22
23impl SP1ProverOpts {
24    /// Get the default prover options.
25    #[must_use]
26    pub fn auto() -> Self {
27        let cpu_ram_gb = System::new_all().total_memory() / (1024 * 1024 * 1024);
28        SP1ProverOpts::cpu(cpu_ram_gb as usize)
29    }
30
31    /// Get the memory options (shard size, shard batch size, and divisor) for a prover on CPU based
32    /// on the amount of CPU memory.
33    #[must_use]
34    fn get_memory_opts(cpu_ram_gb: usize) -> (usize, usize, usize) {
35        match cpu_ram_gb {
36            0..33 => (19, 1, 3),
37            33..49 => (20, 1, 2),
38            49..65 => (21, 1, 3),
39            65..81 => (21, 3, 1),
40            81.. => (21, 4, 1),
41        }
42    }
43
44    /// Get the default prover options for a prover on CPU based on the amount of CPU memory.
45    ///
46    /// We use a soft heuristic based on our understanding of the memory usage in the GPU prover.
47    #[must_use]
48    pub fn cpu(cpu_ram_gb: usize) -> Self {
49        let (log2_shard_size, shard_batch_size, log2_divisor) = Self::get_memory_opts(cpu_ram_gb);
50
51        let mut opts = SP1ProverOpts::default();
52        opts.core_opts.shard_size = 1 << log2_shard_size;
53        opts.core_opts.shard_batch_size = shard_batch_size;
54
55        opts.core_opts.records_and_traces_channel_capacity = 1;
56        opts.core_opts.trace_gen_workers = 1;
57
58        let divisor = 1 << log2_divisor;
59        opts.core_opts.split_opts.deferred /= divisor;
60        opts.core_opts.split_opts.keccak /= divisor;
61        opts.core_opts.split_opts.sha_extend /= divisor;
62        opts.core_opts.split_opts.sha_compress /= divisor;
63        opts.core_opts.split_opts.memory /= divisor;
64
65        opts.recursion_opts.shard_batch_size = 2;
66        opts.recursion_opts.records_and_traces_channel_capacity = 1;
67        opts.recursion_opts.trace_gen_workers = 1;
68
69        opts
70    }
71
72    /// Get the default prover options for a prover on GPU given the amount of CPU and GPU memory.
73    #[must_use]
74    pub fn gpu(cpu_ram_gb: usize, gpu_ram_gb: usize) -> Self {
75        let mut opts = SP1ProverOpts::default();
76
77        // Set the core options.
78        if 24 <= gpu_ram_gb {
79            let log2_shard_size = 21;
80            opts.core_opts.shard_size = 1 << log2_shard_size;
81            opts.core_opts.shard_batch_size = 1;
82
83            let log2_deferred_threshold = 14;
84            opts.core_opts.split_opts = SplitOpts::new(1 << log2_deferred_threshold);
85
86            opts.core_opts.records_and_traces_channel_capacity = 4;
87            opts.core_opts.trace_gen_workers = 4;
88
89            if cpu_ram_gb <= 20 {
90                opts.core_opts.records_and_traces_channel_capacity = 1;
91                opts.core_opts.trace_gen_workers = 2;
92            }
93        } else {
94            unreachable!("not enough gpu memory");
95        }
96
97        // Set the recursion options.
98        opts.recursion_opts.shard_batch_size = 1;
99
100        opts
101    }
102}
103
104/// Options for the core prover.
105#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
106pub struct SP1CoreOpts {
107    /// The size of a shard in terms of cycles.
108    pub shard_size: usize,
109    /// The size of a batch of shards in terms of cycles.
110    pub shard_batch_size: usize,
111    /// Options for splitting deferred events.
112    pub split_opts: SplitOpts,
113    /// The number of workers to use for generating traces.
114    pub trace_gen_workers: usize,
115    /// The capacity of the channel for checkpoints.
116    pub checkpoints_channel_capacity: usize,
117    /// The capacity of the channel for records and traces.
118    pub records_and_traces_channel_capacity: usize,
119}
120
121impl Default for SP1ProverOpts {
122    fn default() -> Self {
123        Self { core_opts: SP1CoreOpts::default(), recursion_opts: SP1CoreOpts::recursion() }
124    }
125}
126
127impl Default for SP1CoreOpts {
128    fn default() -> Self {
129        let cpu_ram_gb = System::new_all().total_memory() / (1024 * 1024 * 1024);
130        let (default_log2_shard_size, default_shard_batch_size, default_log2_divisor) =
131            SP1ProverOpts::get_memory_opts(cpu_ram_gb as usize);
132
133        let mut opts = Self {
134            shard_size: env::var("SHARD_SIZE").map_or_else(
135                |_| 1 << default_log2_shard_size,
136                |s| s.parse::<usize>().unwrap_or(1 << default_log2_shard_size),
137            ),
138            shard_batch_size: env::var("SHARD_BATCH_SIZE").map_or_else(
139                |_| default_shard_batch_size,
140                |s| s.parse::<usize>().unwrap_or(default_shard_batch_size),
141            ),
142            split_opts: SplitOpts::new(MAX_DEFERRED_SPLIT_THRESHOLD),
143            trace_gen_workers: env::var("TRACE_GEN_WORKERS").map_or_else(
144                |_| DEFAULT_TRACE_GEN_WORKERS,
145                |s| s.parse::<usize>().unwrap_or(DEFAULT_TRACE_GEN_WORKERS),
146            ),
147            checkpoints_channel_capacity: env::var("CHECKPOINTS_CHANNEL_CAPACITY").map_or_else(
148                |_| DEFAULT_CHECKPOINTS_CHANNEL_CAPACITY,
149                |s| s.parse::<usize>().unwrap_or(DEFAULT_CHECKPOINTS_CHANNEL_CAPACITY),
150            ),
151            records_and_traces_channel_capacity: env::var("RECORDS_AND_TRACES_CHANNEL_CAPACITY")
152                .map_or_else(
153                    |_| DEFAULT_RECORDS_AND_TRACES_CHANNEL_CAPACITY,
154                    |s| s.parse::<usize>().unwrap_or(DEFAULT_RECORDS_AND_TRACES_CHANNEL_CAPACITY),
155                ),
156        };
157
158        let divisor = 1 << default_log2_divisor;
159        opts.split_opts.deferred /= divisor;
160        opts.split_opts.keccak /= divisor;
161        opts.split_opts.sha_extend /= divisor;
162        opts.split_opts.sha_compress /= divisor;
163        opts.split_opts.memory /= divisor;
164
165        opts
166    }
167}
168
169impl SP1CoreOpts {
170    /// Get the default options for the recursion prover.
171    #[must_use]
172    pub fn recursion() -> Self {
173        let mut opts = Self::max();
174        opts.shard_size = RECURSION_MAX_SHARD_SIZE;
175        opts.shard_batch_size = 2;
176        opts
177    }
178
179    /// Get the maximum options for the core prover.
180    #[must_use]
181    pub fn max() -> Self {
182        let split_threshold = env::var("SPLIT_THRESHOLD")
183            .map(|s| s.parse::<usize>().unwrap_or(MAX_DEFERRED_SPLIT_THRESHOLD))
184            .unwrap_or(MAX_DEFERRED_SPLIT_THRESHOLD)
185            .max(MAX_DEFERRED_SPLIT_THRESHOLD);
186
187        let shard_size = env::var("SHARD_SIZE")
188            .map_or_else(|_| MAX_SHARD_SIZE, |s| s.parse::<usize>().unwrap_or(MAX_SHARD_SIZE));
189
190        Self {
191            shard_size,
192            shard_batch_size: env::var("SHARD_BATCH_SIZE").map_or_else(
193                |_| MAX_SHARD_BATCH_SIZE,
194                |s| s.parse::<usize>().unwrap_or(MAX_SHARD_BATCH_SIZE),
195            ),
196            split_opts: SplitOpts::new(split_threshold),
197            trace_gen_workers: env::var("TRACE_GEN_WORKERS").map_or_else(
198                |_| DEFAULT_TRACE_GEN_WORKERS,
199                |s| s.parse::<usize>().unwrap_or(DEFAULT_TRACE_GEN_WORKERS),
200            ),
201            checkpoints_channel_capacity: env::var("CHECKPOINTS_CHANNEL_CAPACITY").map_or_else(
202                |_| DEFAULT_CHECKPOINTS_CHANNEL_CAPACITY,
203                |s| s.parse::<usize>().unwrap_or(DEFAULT_CHECKPOINTS_CHANNEL_CAPACITY),
204            ),
205            records_and_traces_channel_capacity: env::var("RECORDS_AND_TRACES_CHANNEL_CAPACITY")
206                .map_or_else(
207                    |_| DEFAULT_RECORDS_AND_TRACES_CHANNEL_CAPACITY,
208                    |s| s.parse::<usize>().unwrap_or(DEFAULT_RECORDS_AND_TRACES_CHANNEL_CAPACITY),
209                ),
210        }
211    }
212}
213
214/// Options for splitting deferred events.
215#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
216pub struct SplitOpts {
217    /// The threshold for combining the memory init/finalize events in to the current shard in
218    /// terms of cycles.
219    pub combine_memory_threshold: usize,
220    /// The threshold for default events.
221    pub deferred: usize,
222    /// The threshold for keccak events.
223    pub keccak: usize,
224    /// The threshold for sha extend events.
225    pub sha_extend: usize,
226    /// The threshold for sha compress events.
227    pub sha_compress: usize,
228    /// The threshold for memory events.
229    pub memory: usize,
230}
231
232impl SplitOpts {
233    /// Create a new [`SplitOpts`] with the given threshold.
234    ///
235    /// The constants here need to be chosen very carefully to prevent OOM. Consult @jtguibas on
236    /// how to change them.
237    #[must_use]
238    pub fn new(deferred_split_threshold: usize) -> Self {
239        Self {
240            combine_memory_threshold: 1 << 17,
241            deferred: deferred_split_threshold,
242            keccak: 8 * deferred_split_threshold / 24,
243            sha_extend: 32 * deferred_split_threshold / 48,
244            sha_compress: 32 * deferred_split_threshold / 80,
245            memory: 64 * deferred_split_threshold,
246        }
247    }
248}
249
#[cfg(test)]
mod tests {
    // The resolved options are printed so they can be inspected with `--nocapture`.
    #![allow(clippy::print_stdout)]

    use super::*;

    /// Checks the fields that `SP1ProverOpts::cpu` sets unconditionally for a range of memory
    /// sizes, including the boundaries between memory tiers.
    ///
    /// The split thresholds are only printed, not asserted: they also depend on the memory of
    /// the machine running the test.
    #[test]
    fn test_opts() {
        // (cpu_ram_gb, expected log2 shard size, expected shard batch size)
        const CASES: &[(usize, usize, usize)] = &[
            (8, 19, 1),
            (15, 19, 1),
            (16, 19, 1),
            (32, 19, 1),
            (33, 20, 1),
            (36, 20, 1),
            (48, 20, 1),
            (49, 21, 1),
            (64, 21, 1),
            (65, 21, 3),
            (80, 21, 3),
            (81, 21, 4),
            (128, 21, 4),
            (256, 21, 4),
            (512, 21, 4),
        ];

        for &(cpu_ram_gb, log2_shard_size, shard_batch_size) in CASES {
            let opts = SP1ProverOpts::cpu(cpu_ram_gb);
            println!("{}: {:?}", cpu_ram_gb, opts.core_opts);

            assert_eq!(opts.core_opts.shard_size, 1usize << log2_shard_size);
            assert_eq!(opts.core_opts.shard_batch_size, shard_batch_size);
            assert_eq!(opts.core_opts.trace_gen_workers, 1);
            assert_eq!(opts.core_opts.records_and_traces_channel_capacity, 1);

            assert_eq!(opts.recursion_opts.shard_batch_size, 2);
            assert_eq!(opts.recursion_opts.trace_gen_workers, 1);
            assert_eq!(opts.recursion_opts.records_and_traces_channel_capacity, 1);
        }

        // `auto` goes through `cpu`, so the shard size is always one of the table's powers of two.
        let opts = SP1ProverOpts::auto();
        println!("auto: {:?}", opts.core_opts);
        assert!(opts.core_opts.shard_size.is_power_of_two());
        assert!((1usize << 19..=1usize << 21).contains(&opts.core_opts.shard_size));
        assert_eq!(opts.recursion_opts.shard_batch_size, 2);
    }
}