use itertools::Itertools;
use triton_vm::prelude::triton_asm;
use crate::data_type::ArrayType;
use crate::data_type::DataType;
use crate::traits::basic_snippet::BasicSnippet;
/// Evaluates a polynomial at an extension-field point using Horner's rule.
///
/// The polynomial's `num_coefficients` coefficients are extension-field
/// elements stored in memory as an array, where element `i` is the
/// coefficient of `x^i`. The generated code is specialised to (and fully
/// unrolled for) this number of coefficients.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct HornerEvaluation {
    /// Number of coefficients, i.e., the polynomial's degree plus one.
    pub num_coefficients: usize,
}

impl HornerEvaluation {
    /// Creates a snippet for polynomials with `num_coefficients` coefficients.
    pub const fn new(num_coefficients: usize) -> Self {
        Self { num_coefficients }
    }
}
impl BasicSnippet for HornerEvaluation {
fn inputs(&self) -> Vec<(crate::data_type::DataType, String)> {
vec![
(
DataType::Array(Box::new(ArrayType {
element_type: DataType::Xfe,
length: self.num_coefficients,
})),
"*coefficients".to_string(),
),
(DataType::Xfe, "indeterminate".to_string()),
]
}
fn outputs(&self) -> Vec<(crate::data_type::DataType, String)> {
vec![(DataType::Xfe, "value".to_string())]
}
fn entrypoint(&self) -> String {
format!(
"tasmlib_array_horner_evaluation_with_{}_coefficients",
self.num_coefficients
)
}
fn code(
&self,
_library: &mut crate::library::Library,
) -> Vec<triton_vm::prelude::LabelledInstruction> {
let entrypoint = self.entrypoint();
let update_running_evaluation = triton_asm! {
dup 5 dup 5 dup 5 xx_mul dup 6 read_mem 3 swap 10 pop 1 xx_add };
let update_running_evaluation_for_each_coefficient = (0..self.num_coefficients)
.flat_map(|_| update_running_evaluation.clone())
.collect_vec();
let jump_to_end = self.num_coefficients as isize * 3 - 1;
triton_asm! {
{entrypoint}:
swap 3
push {jump_to_end} add
swap 3
push 0 push 0 push 0
{&update_running_evaluation_for_each_coefficient}
swap 4 pop 1 swap 4 pop 1 swap 4 pop 1 pop 1 return
}
}
}
#[cfg(test)]
mod test {
    use std::collections::HashMap;

    use num::Zero;
    use rand::prelude::*;
    use triton_vm::twenty_first::prelude::*;

    use super::*;
    use crate::empty_stack;
    use crate::rust_shadowing_helper_functions::array::array_get;
    use crate::rust_shadowing_helper_functions::array::insert_as_array;
    use crate::traits::function::Function;
    use crate::traits::function::FunctionInitialState;
    use crate::traits::function::ShadowedFunction;
    use crate::traits::rust_shadow::RustShadow;

    impl Function for HornerEvaluation {
        /// Reference implementation of the snippet.
        ///
        /// Pops the indeterminate `x` (coefficient 0 on top) and the coefficient
        /// pointer. It then evaluates `sum_i c_i * x^i` with Horner's rule and
        /// pushes the result with its coefficient 0 on top.
        fn rust_shadow(
            &self,
            stack: &mut Vec<triton_vm::prelude::BFieldElement>,
            memory: &mut std::collections::HashMap<
                triton_vm::prelude::BFieldElement,
                triton_vm::prelude::BFieldElement,
            >,
        ) {
            // The top of the stack holds x's coefficient 0.
            let x = XFieldElement::new([
                stack.pop().unwrap(),
                stack.pop().unwrap(),
                stack.pop().unwrap(),
            ]);
            let pointer = stack.pop().unwrap();

            let coefficients = (0..self.num_coefficients)
                .map(|i| array_get(pointer, i, memory, 3))
                .map(|bfes| XFieldElement::new(bfes.try_into().unwrap()))
                .collect_vec();

            // Horner's rule, starting from the highest-degree coefficient.
            let value = coefficients
                .into_iter()
                .rev()
                .fold(XFieldElement::zero(), |acc, c| acc * x + c);

            // Push coefficient 2 first so that coefficient 0 ends up on top.
            stack.extend(value.coefficients.iter().rev().copied());
        }

        /// Random coefficients stored as an array at a random address, plus a
        /// random indeterminate; stack layout `_ *coefficients x2 x1 x0`.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            _bench_case: Option<crate::snippet_bencher::BenchmarkCase>,
        ) -> crate::traits::function::FunctionInitialState {
            let mut rng: StdRng = SeedableRng::from_seed(seed);
            let coefficients = (0..self.num_coefficients)
                .map(|_| rng.gen::<XFieldElement>())
                .collect_vec();
            // Address bounded below 2^30.
            let address = BFieldElement::new(rng.next_u64() % (1 << 30));

            let mut memory: HashMap<BFieldElement, BFieldElement> = HashMap::new();
            insert_as_array(address, &mut memory, coefficients);

            let x: XFieldElement = rng.gen();
            let mut stack = empty_stack();
            stack.push(address);
            // Push x's coefficient 2 first so that coefficient 0 ends up on top.
            stack.extend(x.coefficients.iter().rev().copied());

            FunctionInitialState { stack, memory }
        }
    }

    #[test]
    fn horner_evaluation() {
        for n in [0, 1, 20, 587, 1000] {
            let horner = HornerEvaluation {
                num_coefficients: n,
            };
            ShadowedFunction::new(horner).test();
        }
    }
}
#[cfg(test)]
mod benches {
    use super::*;
    use crate::traits::function::ShadowedFunction;
    use crate::traits::rust_shadow::RustShadow;

    /// Runs the benchmark harness on the snippet specialised to
    /// `num_coefficients` coefficients.
    fn bench_horner(num_coefficients: usize) {
        let snippet = HornerEvaluation { num_coefficients };
        ShadowedFunction::new(snippet).bench();
    }

    #[test]
    fn horner_evaluation_100() {
        bench_horner(100);
    }

    #[test]
    fn horner_evaluation_587() {
        bench_horner(587);
    }
}