use triton_vm::prelude::*;
use crate::list::length::Length;
use crate::list::new::New;
use crate::list::push::Push;
use crate::prelude::*;
/// Snippet that squeezes the sponge to sample `number` indices below
/// `upper_bound` and returns them as a list of `u32`s.
///
/// The stack contract lives in this type's `BasicSnippet` implementation.
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
pub struct SampleIndices;
// Snippet contract: `_ number upper_bound` → `_ *indices`, where `*indices`
// points to a freshly allocated list of `number` u32 indices sampled from the
// sponge.
impl BasicSnippet for SampleIndices {
/// Inputs, bottom to top: `number` (how many indices to sample) and
/// `upper_bound` (the indices are reduced with the mask `upper_bound - 1`).
fn parameters(&self) -> Vec<(DataType, String)> {
vec![
(DataType::U32, "number".to_string()),
(DataType::U32, "upper_bound".to_string()),
]
}
/// Output: a pointer to a list of `u32` indices.
fn return_values(&self) -> Vec<(DataType, String)> {
vec![(
DataType::List(Box::new(DataType::U32)),
"*indices".to_string(),
)]
}
/// Label of the snippet; also used as the prefix for its internal labels.
fn entrypoint(&self) -> String {
"tasmlib_hashing_algebraic_hasher_sample_indices".into()
}
/// Generates the snippet's TASM.
///
/// Stack effect: `_ number upper_bound` → `_ *indices`.
///
/// Rejection sampling over the sponge output. Each round runs
/// `sponge_squeeze`, then one unrolled step per squeezed element.
///
/// An element `e` is accepted only while the list is not yet full and
/// `e != -1`. An accepted element is reduced to `lo & (upper_bound - 1)`,
/// where `lo` is the low half produced by `split`, and pushed onto the list.
/// Rejected elements are discarded, as are any elements left over once the
/// list is full. Rounds repeat until the list holds `number` elements.
///
/// The tests compare this against the Rust shadow's `Tip5::sample_indices`.
///
/// The bit-mask equals reduction modulo `upper_bound` only when
/// `upper_bound` is a power of two. NOTE(review): nothing here checks that
/// precondition. With `upper_bound == 0` the mask becomes `-1`, which `and`
/// presumably rejects as a non-u32 operand. Confirm that callers always pass
/// a nonzero power of two.
fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
let entrypoint = self.entrypoint();
let main_loop = format!("{entrypoint}_main_loop");
let then_reduce_and_save = format!("{entrypoint}_then_reduce_and_save");
let else_drop_tip = format!("{entrypoint}_else_drop_tip");
let new_list = library.import(Box::new(New));
let length = library.import(Box::new(Length));
let push_element = library.import(Box::new(Push::new(DataType::U32)));
// Acceptance test for the squeezed element `e` lying directly below three
// copied words. Stack effect (bottom → top):
//   _ e number mask *list  →  _ e number mask *list (1 - c) c
// with c = (len(*list) != number) * (e != -1).
// The caller follows it with `skiz call then_reduce_and_save
// skiz call else_drop_tip`, so exactly one of the two handlers runs.
let if_can_sample = triton_asm! (
dup 0 call {length} dup 3 eq push 0 eq dup 4 push -1 eq push 0 eq mul dup 0 push 0 eq swap 1 );
triton_asm! (
// BEFORE: _ number upper_bound
{entrypoint}:
call {new_list}
// _ number upper_bound *list
// Replace upper_bound with the mask upper_bound - 1, then run the loop.
swap 1 push -1 add swap 1 call {main_loop}
// _ number mask *list
swap 2 pop 2
// AFTER: _ *list
return
// INVARIANT on entry to every iteration: _ number mask *list
{main_loop}:
// Stop once the list holds `number` elements.
dup 0 call {length} dup 3 eq skiz return
sponge_squeeze
// _ number mask *list [squeezed]
// Copy the three words above the freshly squeezed elements.
dup 12 dup 12 dup 12
// _ number mask *list [squeezed] number mask *list
// One step per squeezed element. Each step consumes the topmost remaining
// squeezed element through exactly one of the two handlers below and
// leaves the copied `number mask *list` on top.
{&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
{&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
{&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
{&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
{&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
{&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
{&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
{&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
{&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
{&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
// _ number mask *list number mask *list
pop 3 recurse
// Accept branch. `[..]` denotes the squeezed elements not yet consumed.
// _ [..] e number mask *list 0
{then_reduce_and_save}:
pop 1 swap 2 swap 3 split dup 2 and swap 1 pop 1
// _ [..] number *list mask (lo(e) & mask)
swap 1 swap 2 swap 1 dup 1 swap 1 call {push_element}
// _ [..] number mask *list   (the reduced index is now appended to the list)
push 0
// This 0 makes the caller's following `skiz call {else_drop_tip}` a no-op.
return
// Reject branch: drop e.
// _ [..] e number mask *list  →  _ [..] number mask *list
{else_drop_tip}:
swap 2 swap 3 pop 1 swap 1 return
)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::empty_stack;
    use crate::rust_shadowing_helper_functions;
    use crate::test_prelude::*;

    impl Procedure for SampleIndices {
        /// Reference implementation of the snippet.
        ///
        /// Pops `upper_bound` (top) and `number`, and samples the indices with
        /// `Tip5::sample_indices` on the provided sponge. It stores them in a
        /// freshly allocated list and pushes the list pointer.
        ///
        /// # Errors
        /// Returns `SpongeUninitialized` if no sponge is present and
        /// `StackUnderflow` if either argument is missing. Errors from
        /// `list_push` are propagated.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
            _: &NonDeterminism,
            _: &[BFieldElement],
            sponge: &mut Option<Tip5>,
        ) -> Result<Vec<BFieldElement>, RustShadowError> {
            let Some(sponge) = sponge.as_mut() else {
                return Err(RustShadowError::SpongeUninitialized);
            };

            // Both arguments are declared `U32` by the snippet, so these narrowing
            // casts are lossless for well-formed input.
            let upper_bound = stack.pop().ok_or(RustShadowError::StackUnderflow)?.value() as u32;
            let number = stack.pop().ok_or(RustShadowError::StackUnderflow)?.value() as usize;

            let indices = sponge.sample_indices(upper_bound, number);

            let list_pointer =
                rust_shadowing_helper_functions::dyn_malloc::dynamic_allocator(memory);
            rust_shadowing_helper_functions::list::list_new(list_pointer, memory);
            for index in indices.iter() {
                rust_shadowing_helper_functions::list::list_push(
                    list_pointer,
                    vec![BFieldElement::new(*index as u64)],
                    memory,
                )?;
            }

            stack.push(list_pointer);
            Ok(Vec::new())
        }

        /// Builds `_ number upper_bound` on the stack together with a random sponge.
        ///
        /// Benchmarks use fixed sizes. Otherwise `number` is drawn from `0..20` and
        /// `upper_bound` is `2^k` for `k` drawn from `0..20`.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> ProcedureInitialState {
            let mut rng = StdRng::from_seed(seed);

            // Tuple operands evaluate left to right, so the RNG is consumed in the
            // same order as before: first `number`, then the exponent.
            let (number, upper_bound) = match bench_case {
                Some(BenchmarkCase::CommonCase) => (40, 1 << 12),
                Some(BenchmarkCase::WorstCase) => (80, 1 << 23),
                None => (rng.random_range(0..20), 1 << rng.random_range(0..20)),
            };

            let mut stack = empty_stack();
            stack.push(BFieldElement::new(number as u64));
            stack.push(BFieldElement::new(upper_bound as u64));

            let sponge = Tip5 {
                state: rng.random(),
            };

            ProcedureInitialState {
                stack,
                nondeterminism: NonDeterminism::default(),
                public_input: vec![],
                sponge: Some(sponge),
            }
        }
    }

    #[macro_rules_attr::apply(test)]
    fn test() {
        ShadowedProcedure::new(SampleIndices).test();
    }
}
/// Benchmarks for the `SampleIndices` snippet, run through the
/// shadowed-procedure harness.
#[cfg(test)]
mod bench {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the harness's benchmark cases for the snippet.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        ShadowedProcedure::new(SampleIndices::default()).bench();
    }
}