use crate::{
cost_and_height_per_syscall, rv64im_costs, utils::trunc_32, RetainedEventsPreset, RiscvAirId,
SyscallCode, BYTE_NUM_ROWS, RANGE_NUM_ROWS,
};
use enum_map::EnumMap;
use serde::{Deserialize, Serialize};
use std::{collections::HashSet, env};
/// Default value of `SP1CoreOpts::shard_size` (2^24), used when the `SHARD_SIZE` environment
/// variable is unset or does not parse.
const MAX_SHARD_SIZE: usize = 1 << 24;
/// Default trace-area budget per shard (`3 * 2^27`, i.e. `2^28 + 2^27`), used when the
/// `ELEMENT_THRESHOLD` environment variable is unset or does not parse.
pub const ELEMENT_THRESHOLD: u64 = 3 << 27;
/// Default maximum trace height per shard (2^22), used when the `HEIGHT_THRESHOLD`
/// environment variable is unset or does not parse.
pub const HEIGHT_THRESHOLD: u64 = 1 << 22;
/// Default for `SP1CoreOpts::minimal_trace_chunk_threshold`: the number of
/// `sp1_jit::MemValue` entries that fit in 2^31 bytes (2 GiB). Used when the
/// `MINIMAL_TRACE_CHUNK_THRESHOLD` environment variable is unset or does not parse.
pub const MINIMAL_TRACE_CHUNK_THRESHOLD: u64 =
    (1u64 << 31) / std::mem::size_of::<sp1_jit::MemValue>() as u64;
/// Default for `SP1CoreOpts::trace_chunk_slots`, used when the `TRACE_CHUNK_SLOTS`
/// environment variable is unset or does not parse.
pub const DEFAULT_TRACE_CHUNK_SLOTS: usize = 5;
/// Default for `SP1CoreOpts::memory_limit`: 24 * 2^30 (24 GiB, assuming the limit is in
/// bytes — TODO confirm). Used when the `MEMORY_LIMIT` environment variable is unset or does
/// not parse.
pub const DEFAULT_MEMORY_LIMIT: u64 = 24 << 30;
/// Size limits for a single shard's trace. `SplitOpts::new` derives its splitting thresholds
/// from these two values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ShardingThreshold {
    /// Total trace-area budget per shard. `SplitOpts::new` subtracts the area of the
    /// fixed-size tables (program, byte, range) and divides the remainder among events.
    pub element_threshold: u64,
    /// Maximum trace height. `SplitOpts::new` uses it to cap the per-syscall and memory
    /// thresholds.
    pub height_threshold: u64,
}
/// Tunable options for core execution.
///
/// The `Default` impl reads each numeric field from an environment variable and falls back to
/// this module's constants when the variable is unset or malformed.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SP1CoreOpts {
    /// Overridable via `MINIMAL_TRACE_CHUNK_THRESHOLD`; defaults to
    /// `MINIMAL_TRACE_CHUNK_THRESHOLD`.
    pub minimal_trace_chunk_threshold: u64,
    /// Overridable via `TRACE_CHUNK_SLOTS`; defaults to `DEFAULT_TRACE_CHUNK_SLOTS`.
    pub trace_chunk_slots: usize,
    /// Overridable via `MEMORY_LIMIT`; defaults to `DEFAULT_MEMORY_LIMIT`
    /// (presumably bytes — TODO confirm).
    pub memory_limit: u64,
    /// Overridable via `SHARD_SIZE`; defaults to `MAX_SHARD_SIZE`.
    pub shard_size: usize,
    /// Per-shard element/height limits, overridable via `ELEMENT_THRESHOLD` and
    /// `HEIGHT_THRESHOLD`.
    pub sharding_threshold: ShardingThreshold,
    /// Presets whose events are retained. By default: Bls12381Field, Bn254Field, Sha256,
    /// Poseidon2 and U256Ops.
    pub retained_events_presets: HashSet<RetainedEventsPreset>,
    /// Defaults to `false`. NOTE(review): what this flag toggles is not visible here.
    pub global_dependencies_opt: bool,
}
impl Default for SP1CoreOpts {
    /// Builds the default core options and lets environment variables override the numeric
    /// knobs.
    ///
    /// Variables consulted: `MINIMAL_TRACE_CHUNK_THRESHOLD`, `TRACE_CHUNK_SLOTS`,
    /// `MEMORY_LIMIT`, `SHARD_SIZE`, `ELEMENT_THRESHOLD`, `HEIGHT_THRESHOLD`. A variable that
    /// is unset, not valid Unicode, or fails to parse silently yields the compiled-in default.
    fn default() -> Self {
        // Reads `key` from the environment and parses it as `T`; any failure yields `fallback`.
        fn env_or<T: std::str::FromStr>(key: &str, fallback: T) -> T {
            match env::var(key) {
                Ok(raw) => raw.parse::<T>().unwrap_or(fallback),
                Err(_) => fallback,
            }
        }

        let minimal_trace_chunk_threshold =
            env_or("MINIMAL_TRACE_CHUNK_THRESHOLD", MINIMAL_TRACE_CHUNK_THRESHOLD);
        let trace_chunk_slots = env_or("TRACE_CHUNK_SLOTS", DEFAULT_TRACE_CHUNK_SLOTS);
        let memory_limit = env_or("MEMORY_LIMIT", DEFAULT_MEMORY_LIMIT);
        let shard_size = env_or("SHARD_SIZE", MAX_SHARD_SIZE);
        let sharding_threshold = ShardingThreshold {
            element_threshold: env_or("ELEMENT_THRESHOLD", ELEMENT_THRESHOLD),
            height_threshold: env_or("HEIGHT_THRESHOLD", HEIGHT_THRESHOLD),
        };

        // Presets enabled in `retained_events_presets` out of the box.
        let retained_events_presets = HashSet::from([
            RetainedEventsPreset::Bls12381Field,
            RetainedEventsPreset::Bn254Field,
            RetainedEventsPreset::Sha256,
            RetainedEventsPreset::Poseidon2,
            RetainedEventsPreset::U256Ops,
        ]);

        Self {
            minimal_trace_chunk_threshold,
            trace_chunk_slots,
            memory_limit,
            shard_size,
            sharding_threshold,
            retained_events_presets,
            global_dependencies_opt: false,
        }
    }
}
/// Thresholds derived by `SplitOpts::new` from an `SP1CoreOpts` and the program size.
/// NOTE(review): presumably consulted when deciding how to split/pack events into shards —
/// confirm against callers.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SplitOpts {
    /// Two thirds of `sharding_threshold.element_threshold`.
    pub pack_trace_threshold: u64,
    /// `3/20` of the element threshold divided by the per-memory-entry cost
    /// (`MemoryGlobalInit` + `Global`), passed through `trunc_32`.
    pub combine_memory_threshold: usize,
    /// Always 0 from `SplitOpts::new` (page protection is rejected there).
    pub combine_page_prot_threshold: usize,
    /// Per-syscall threshold: the smaller of the area-based and height-based limits, or 0 for
    /// syscalls with `should_send() == 0` or without an AIR id.
    pub syscall_threshold: EnumMap<SyscallCode, usize>,
    /// Half of `min(available area / per-memory cost, height threshold)`, passed through
    /// `trunc_32`.
    pub memory: usize,
    /// Always 0 from `SplitOpts::new` (page protection is rejected there).
    pub page_prot: usize,
}
impl SplitOpts {
    /// Derives splitting thresholds from `opts` and the program size.
    ///
    /// The element budget `opts.sharding_threshold.element_threshold` is first reduced by the
    /// trace area of tables every shard carries:
    /// - the program table, with `program_size` rounded up to a multiple of 32 rows;
    /// - the byte table, at `BYTE_NUM_ROWS` rows;
    /// - the range table, at `RANGE_NUM_ROWS` rows.
    ///
    /// The remaining area, together with `height_threshold`, bounds the per-syscall and memory
    /// thresholds. `trunc_32` is applied to each count (presumably rounding down to a multiple
    /// of 32 — TODO confirm).
    ///
    /// # Panics
    ///
    /// - If `page_protect_allowed` is `true`.
    /// - If the element threshold is smaller than the fixed trace area.
    #[must_use]
    pub fn new(opts: &SP1CoreOpts, program_size: usize, page_protect_allowed: bool) -> Self {
        assert!(!page_protect_allowed, "page protection is turned off");
        let costs = rv64im_costs();
        let element_threshold = opts.sharding_threshold.element_threshold;
        let max_height = opts.sharding_threshold.height_threshold;

        // Area of the tables whose height does not depend on the executed events.
        let fixed_trace_area = program_size.next_multiple_of(32) * costs[&RiscvAirId::Program]
            + BYTE_NUM_ROWS as usize * costs[&RiscvAirId::Byte]
            + RANGE_NUM_ROWS as usize * costs[&RiscvAirId::Range];
        assert!(
            element_threshold >= fixed_trace_area as u64,
            "SP1CoreOpts's element threshold is too low"
        );
        let available_trace_area = element_threshold - fixed_trace_area as u64;

        // Each syscall's limit is the smaller of what fits in the remaining area and what fits
        // under the height cap. Syscalls that do not send or have no AIR get 0.
        let syscall_threshold = EnumMap::from_fn(|syscall_code: SyscallCode| {
            if syscall_code.should_send() == 0 || syscall_code.as_air_id().is_none() {
                return 0;
            }
            let (cost_per_syscall, max_height_per_syscall) =
                cost_and_height_per_syscall(syscall_code, &costs, page_protect_allowed);
            let by_area = trunc_32(available_trace_area as usize / cost_per_syscall);
            let by_height = trunc_32(max_height as usize / max_height_per_syscall);
            by_area.min(by_height)
        });

        // Per-memory-entry cost: the `MemoryGlobalInit` and `Global` row costs combined.
        let cost_per_memory = costs[&RiscvAirId::MemoryGlobalInit] + costs[&RiscvAirId::Global];
        let memory = trunc_32(
            (available_trace_area as usize / cost_per_memory).min(max_height as usize) / 2,
        );

        // `element_threshold` may come from the `ELEMENT_THRESHOLD` environment variable, so
        // `2 * x` / `3 * x` could overflow u64 (panic in debug, silent wrap in release).
        // Scaling in u128 is exact, and both quotients are <= `element_threshold`, so the
        // narrowing casts below cannot truncate on 64-bit targets.
        let pack_trace_threshold = (u128::from(element_threshold) * 2 / 3) as u64;
        let combine_memory_threshold = trunc_32(
            (u128::from(element_threshold) * 3 / cost_per_memory as u128 / 20) as usize,
        );

        Self {
            pack_trace_threshold,
            combine_memory_threshold,
            combine_page_prot_threshold: 0,
            syscall_threshold,
            memory,
            page_prot: 0,
        }
    }
}