use alloc::vec::Vec;
use air::PartitionOptions;
use crypto::{hashers::Blake3_256, ElementHasher, MerkleTree};
use math::{
fields::f128::BaseElement, get_power_series, get_power_series_with_offset, polynom,
FieldElement, StarkField,
};
use crate::{
tests::{build_fib_trace, MockAir},
DefaultTraceLde, StarkDomain, Trace, TraceLde,
};
/// `Blake3_256` hasher instantiated over `BaseElement` (the f128 base field); used both as the
/// trace LDE hasher and as the Merkle tree hasher in the tests below.
type Blake3 = Blake3_256<BaseElement>;
/// Extends an 8-row, two-register Fibonacci trace and checks three things:
/// - the shape of the resulting LDE;
/// - that the interpolated trace polynomials reproduce the original columns on the trace domain;
/// - that interpolating the LDE columns over the LDE domain yields those same polynomials.
#[test]
fn extend_trace_table() {
    // `build_fib_trace` is given twice the row count because each row holds two Fibonacci steps.
    let trace_length = 8;
    let air = MockAir::with_trace_length(trace_length);
    let trace = build_fib_trace(trace_length * 2);
    let domain = StarkDomain::new(&air);
    let (trace_lde, trace_polys) = DefaultTraceLde::<BaseElement, Blake3, MerkleTree<Blake3>>::new(
        trace.info(),
        trace.main_segment(),
        &domain,
        PartitionOptions::default(),
    );

    let num_columns = 2;

    // Shape of the extended main segment.
    assert_eq!(num_columns, trace_lde.main_segment_width());
    assert_eq!(64, trace_lde.trace_len());

    // Evaluating each trace polynomial over the trace domain must give back the original column.
    let g_trace = BaseElement::get_root_of_unity(trace_length.ilog2());
    let trace_domain = get_power_series(g_trace, trace_length);
    assert_eq!(num_columns, trace_polys.num_main_trace_polys());

    let expected_columns: [[u32; 8]; 2] = [
        [1, 2, 5, 13, 34, 89, 233, 610],
        [1, 3, 8, 21, 55, 144, 377, 987],
    ];
    for (col_idx, expected) in expected_columns.iter().enumerate() {
        let expected: Vec<BaseElement> = expected.iter().copied().map(BaseElement::from).collect();
        let actual = polynom::eval_many(trace_polys.get_main_trace_poly(col_idx), &trace_domain);
        assert_eq!(expected, actual);
    }

    // Interpolating each extended column over the LDE domain must recover the same polynomial.
    let lde_domain = build_lde_domain(domain.lde_domain_size());
    for col_idx in 0..num_columns {
        let lde_column = trace_lde.get_main_segment_column(col_idx);
        assert_eq!(
            trace_polys.get_main_trace_poly(col_idx),
            polynom::interpolate(&lde_domain, &lde_column, true)
        );
    }
}
/// Checks that the main trace commitment equals the root of a Merkle tree whose leaves are the
/// `Blake3` hashes of the rows of the extended main segment, taken in row order.
#[test]
fn commit_trace_table() {
    let trace_length = 8;
    let air = MockAir::with_trace_length(trace_length);
    let trace = build_fib_trace(trace_length * 2);
    let domain = StarkDomain::new(&air);
    let (trace_lde, _) = DefaultTraceLde::<BaseElement, Blake3, MerkleTree<Blake3>>::new(
        trace.info(),
        trace.main_segment(),
        &domain,
        PartitionOptions::default(),
    );

    // Recompute the expected leaves independently: one hash per LDE row over all main columns.
    let width = trace_lde.main_segment_width();
    let main_segment = trace_lde.get_main_segment();
    let leaves: Vec<_> = (0..trace_lde.trace_len())
        .map(|row_idx| {
            let row: Vec<BaseElement> =
                (0..width).map(|col_idx| main_segment.get(col_idx, row_idx)).collect();
            Blake3::hash_elements(&row)
        })
        .collect();

    let expected_tree = MerkleTree::<Blake3>::new(leaves).unwrap();
    assert_eq!(*expected_tree.root(), trace_lde.get_main_trace_commitment());
}
/// Builds the LDE evaluation domain used to re-interpolate LDE columns in these tests.
///
/// The result is `get_power_series_with_offset(g, B::GENERATOR, domain_size)`, where `g` is the
/// root of unity returned by `B::get_root_of_unity(log2(domain_size))`. That is, it is the
/// `domain_size` points of the coset shifted by `B::GENERATOR`.
///
/// # Panics
/// Panics if `domain_size` is not a power of two, including zero. Without this check, `ilog2`
/// would panic on zero and silently round down for other sizes. The resulting root of unity
/// would then have the wrong order and produce an incorrect domain.
fn build_lde_domain<B: StarkField>(domain_size: usize) -> Vec<B> {
    assert!(
        domain_size.is_power_of_two(),
        "LDE domain size must be a power of two, but was {domain_size}"
    );
    let g = B::get_root_of_unity(domain_size.ilog2());
    get_power_series_with_offset(g, B::GENERATOR, domain_size)
}