use alloc::vec::Vec;
use core::marker::PhantomData;
use air::{proof::Queries, PartitionOptions, TraceInfo};
use crypto::VectorCommitment;
use tracing::info_span;
use super::{
ColMatrix, ElementHasher, EvaluationFrame, FieldElement, StarkDomain, TraceLde, TracePolyTable,
};
use crate::{RowMatrix, DEFAULT_SEGMENT_WIDTH};
#[cfg(test)]
mod tests;
/// Low-degree extension (LDE) of an execution trace together with vector commitments to the
/// rows of the extended main segment and, once set, the extended auxiliary segment.
///
/// The main segment is always present; the auxiliary segment is filled in later via
/// `TraceLde::set_aux_trace`.
pub struct DefaultTraceLde<E, H, V>
where
    E: FieldElement,
    H: ElementHasher<BaseField = E::BaseField>,
    V: VectorCommitment<H>,
{
    /// Extended main trace segment over the base field, read row by row.
    main_segment_lde: RowMatrix<E::BaseField>,
    /// Vector commitment built over the rows of `main_segment_lde`.
    main_segment_oracles: V,
    /// Extended auxiliary trace segment; `None` until an auxiliary trace is set.
    aux_segment_lde: Option<RowMatrix<E>>,
    /// Vector commitment over the rows of `aux_segment_lde`; `None` until it is set.
    aux_segment_oracles: Option<V>,
    /// Trace-to-LDE blowup factor of the domain the segments were extended over.
    blowup: usize,
    /// Shape description of the trace (used e.g. to bound the number of auxiliary segments).
    trace_info: TraceInfo,
    /// Partitioning options forwarded to the row commitments.
    partition_options: PartitionOptions,
    /// Ties the hasher type parameter to the struct without storing a value of it.
    _h: PhantomData<H>,
}
impl<E, H, V> DefaultTraceLde<E, H, V>
where
    E: FieldElement,
    H: ElementHasher<BaseField = E::BaseField>,
    V: VectorCommitment<H>,
{
    /// Extends `main_trace` over `domain`, commits to the rows of the extension, and wraps the
    /// result in a new [`DefaultTraceLde`] without an auxiliary segment.
    ///
    /// Returns the new instance together with a [`TracePolyTable`] that holds the polynomials
    /// interpolated from the main trace columns.
    pub fn new(
        trace_info: &TraceInfo,
        main_trace: &ColMatrix<E::BaseField>,
        domain: &StarkDomain<E::BaseField>,
        partition_options: PartitionOptions,
    ) -> (Self, TracePolyTable<E>) {
        let (lde, commitment, polys) =
            build_trace_commitment::<E, E::BaseField, H, V>(main_trace, domain, partition_options);
        let poly_table = TracePolyTable::new(polys);

        let instance = Self {
            main_segment_lde: lde,
            main_segment_oracles: commitment,
            aux_segment_lde: None,
            aux_segment_oracles: None,
            blowup: domain.trace_to_lde_blowup(),
            trace_info: trace_info.clone(),
            partition_options,
            _h: PhantomData,
        };

        (instance, poly_table)
    }

    /// Number of columns in the extended main segment (test-only helper).
    #[cfg(test)]
    pub fn main_segment_width(&self) -> usize {
        self.main_segment_lde.num_cols()
    }

    /// Borrows the extended main segment (test-only helper).
    #[cfg(test)]
    pub fn get_main_segment(&self) -> &RowMatrix<E::BaseField> {
        &self.main_segment_lde
    }

    /// Copies column `col_idx` of the extended main segment into a vector, top row first
    /// (test-only helper).
    #[cfg(test)]
    pub fn get_main_segment_column(&self, col_idx: usize) -> Vec<E::BaseField> {
        let lde = &self.main_segment_lde;
        let num_rows = lde.num_rows();
        (0..num_rows).map(|row| lde.get(col_idx, row)).collect()
    }
}
impl<E, H, V> TraceLde<E> for DefaultTraceLde<E, H, V>
where
    E: FieldElement,
    H: ElementHasher<BaseField = E::BaseField> + core::marker::Sync,
    V: VectorCommitment<H> + core::marker::Sync,
{
    type HashFn = H;
    type VC = V;

    /// Returns the commitment to the rows of the extended main trace segment.
    fn get_main_trace_commitment(&self) -> H::Digest {
        self.main_segment_oracles.commitment()
    }

    /// Extends `aux_trace` over `domain`, commits to the rows of the extension, and stores
    /// both the extended segment and its commitment in `self`.
    ///
    /// Returns the polynomials interpolated from the auxiliary trace columns and the
    /// commitment to the extended auxiliary segment.
    ///
    /// # Panics
    /// - If `self` already holds as many auxiliary segments as `trace_info` declares. This
    ///   covers both "an auxiliary trace was already added" and "`trace_info` declares no
    ///   auxiliary segments".
    /// - If the extended auxiliary segment does not have the same number of rows as the
    ///   extended main segment.
    fn set_aux_trace(
        &mut self,
        aux_trace: &ColMatrix<E>,
        domain: &StarkDomain<E::BaseField>,
    ) -> (ColMatrix<E>, H::Digest) {
        // Check the precondition before doing any work. Interpolating, extending and committing
        // to the auxiliary trace is the expensive part, and its result would be discarded if
        // this check failed.
        assert!(
            usize::from(self.aux_segment_lde.is_some()) < self.trace_info.num_aux_segments(),
            "the auxiliary trace has already been added"
        );

        let (aux_segment_lde, aux_segment_oracles, aux_segment_polys) =
            build_trace_commitment::<E, E, H, Self::VC>(aux_trace, domain, self.partition_options);

        assert_eq!(
            self.main_segment_lde.num_rows(),
            aux_segment_lde.num_rows(),
            "the number of rows in the auxiliary segment must be the same as in the main segment"
        );

        let aux_commitment = aux_segment_oracles.commitment();
        self.aux_segment_lde = Some(aux_segment_lde);
        self.aux_segment_oracles = Some(aux_segment_oracles);
        (aux_segment_polys, aux_commitment)
    }

    /// Fills `frame` with the main-segment row at `lde_step` (current) and the row `blowup`
    /// steps later (next). The next-row index wraps around the LDE length.
    fn read_main_trace_frame_into(
        &self,
        lde_step: usize,
        frame: &mut EvaluationFrame<E::BaseField>,
    ) {
        // One trace step corresponds to `blowup` steps in the extended domain.
        let next_lde_step = (lde_step + self.blowup()) % self.trace_len();
        frame.current_mut().copy_from_slice(self.main_segment_lde.row(lde_step));
        frame.next_mut().copy_from_slice(self.main_segment_lde.row(next_lde_step));
    }

    /// Same as `read_main_trace_frame_into`, but reads the auxiliary segment.
    ///
    /// # Panics
    /// If no auxiliary segment has been set.
    fn read_aux_trace_frame_into(&self, lde_step: usize, frame: &mut EvaluationFrame<E>) {
        let next_lde_step = (lde_step + self.blowup()) % self.trace_len();
        let segment = self.aux_segment_lde.as_ref().expect("expected aux segment to be present");
        frame.current_mut().copy_from_slice(segment.row(lde_step));
        frame.next_mut().copy_from_slice(segment.row(next_lde_step));
    }

    /// Builds query responses for `positions`. The main segment comes first, followed by the
    /// auxiliary segment if one has been set.
    fn query(&self, positions: &[usize]) -> Vec<Queries> {
        let main_queries = build_segment_queries::<E::BaseField, H, V>(
            &self.main_segment_lde,
            &self.main_segment_oracles,
            positions,
        );
        let mut result = vec![main_queries];

        if let Some(segment_oracles) = self.aux_segment_oracles.as_ref() {
            // The LDE and its commitment are always stored together in `set_aux_trace`.
            let segment_lde =
                self.aux_segment_lde.as_ref().expect("expected aux segment to be present");
            result.push(build_segment_queries::<E, H, V>(segment_lde, segment_oracles, positions));
        }

        result
    }

    /// Number of rows in the extended (LDE) main segment.
    fn trace_len(&self) -> usize {
        self.main_segment_lde.num_rows()
    }

    /// Trace-to-LDE blowup factor captured at construction.
    fn blowup(&self) -> usize {
        self.blowup
    }

    /// Shape description of the underlying trace.
    fn trace_info(&self) -> &TraceInfo {
        &self.trace_info
    }
}
/// Interpolates every column of `trace`, evaluates the resulting polynomials over the LDE
/// domain of `domain`, and commits to the rows of the extended trace with `V`.
///
/// Returns `(extended trace, row commitment, interpolated column polynomials)`.
///
/// # Panics
/// If any of the internal shape consistency checks fail:
/// - the extension's column count differs from the trace's;
/// - the interpolated polynomials' row count differs from the trace's;
/// - the extension's row count differs from the LDE domain size;
/// - the commitment's domain length differs from the extension's row count.
fn build_trace_commitment<E, F, H, V>(
    trace: &ColMatrix<F>,
    domain: &StarkDomain<E::BaseField>,
    partition_options: PartitionOptions,
) -> (RowMatrix<F>, V, ColMatrix<F>)
where
    E: FieldElement,
    F: FieldElement<BaseField = E::BaseField>,
    H: ElementHasher<BaseField = E::BaseField>,
    V: VectorCommitment<H>,
{
    // Interpolate the columns and extend them over the LDE domain inside a tracing span.
    let (trace_lde, trace_polys) = info_span!(
        "extend_execution_trace",
        num_cols = trace.num_cols(),
        blowup = domain.trace_to_lde_blowup()
    )
    .in_scope(|| {
        let polys = trace.interpolate_columns();
        let lde = RowMatrix::evaluate_polys_over::<DEFAULT_SEGMENT_WIDTH>(&polys, domain);
        (lde, polys)
    });

    assert_eq!(trace_lde.num_cols(), trace.num_cols());
    assert_eq!(trace_polys.num_rows(), trace.num_rows());
    assert_eq!(trace_lde.num_rows(), domain.lde_domain_size());

    // Commit to the rows of the extended trace.
    let lde_rows = trace_lde.num_rows();
    let row_commitment = info_span!(
        "compute_execution_trace_commitment",
        commitment_domain_size = lde_rows
    )
    .in_scope(|| trace_lde.commit_to_rows::<H, V>(partition_options));
    assert_eq!(row_commitment.domain_len(), lde_rows);

    (trace_lde, row_commitment, trace_polys)
}
/// Builds the query response for a single trace segment.
///
/// The response contains the full extended row at each entry of `positions` (in the given
/// order) plus the batch opening proof that `segment_vector_com` produces for those positions.
///
/// # Panics
/// If the vector commitment fails to produce the batch opening proof.
fn build_segment_queries<E, H, V>(
    segment_lde: &RowMatrix<E>,
    segment_vector_com: &V,
    positions: &[usize],
) -> Queries
where
    E: FieldElement,
    H: ElementHasher<BaseField = E::BaseField>,
    V: VectorCommitment<H>,
{
    let mut queried_rows = Vec::with_capacity(positions.len());
    for &position in positions {
        queried_rows.push(segment_lde.row(position).to_vec());
    }

    let batch_opening = segment_vector_com
        .open_many(positions)
        .expect("failed to generate a batch opening proof for trace queries");
    // Only the proof part of the opening is needed for the query response.
    let opening_proof = batch_opening.1;

    Queries::new::<H, E, V>(opening_proof, queried_rows)
}