use std::collections::BTreeMap;
use std::ops::{Deref, DerefMut, Range};
use slop_algebra::Field;
use slop_alloc::{Backend, Buffer, CpuBackend, HasBackend};
use slop_tensor::{Dimensions, Tensor, TensorView};
use sp1_gpu_cudart::{DeviceBuffer, TaskScope};
use crate::jagged::JaggedMle;
use crate::{DenseData, DenseDataMut};
#[derive(Clone, Debug)]
pub struct TraceOffset {
pub dense_offset: Range<usize>,
pub poly_size: usize,
pub num_polys: usize,
}
#[derive(Clone)]
pub struct JaggedTraceMle<F: Field, B: Backend>(pub JaggedMle<TraceDenseData<F, B>, B>);
impl<F: Field, B: Backend> HasBackend for JaggedTraceMle<F, B> {
type Backend = B;
fn backend(&self) -> &B {
self.0.backend()
}
}
impl<F: Field, B: Backend> Deref for JaggedTraceMle<F, B> {
type Target = JaggedMle<TraceDenseData<F, B>, B>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<F: Field, B: Backend> DerefMut for JaggedTraceMle<F, B> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
#[derive(Clone, Debug)]
pub struct TraceDenseData<F: Field, B: Backend> {
pub dense: Buffer<F, B>,
pub preprocessed_offset: usize,
pub preprocessed_cols: usize,
pub preprocessed_padding: usize,
pub main_padding: usize,
pub preprocessed_table_index: BTreeMap<String, TraceOffset>,
pub main_table_index: BTreeMap<String, TraceOffset>,
}
impl<F: Field, B: Backend> TraceDenseData<F, B> {
pub fn main_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, B> {
let ptr = unsafe { self.dense.as_ptr().add(self.preprocessed_offset) };
let sizes = Dimensions::try_from([
self.main_size() / (1 << log_stacking_height),
1 << log_stacking_height,
])
.unwrap();
unsafe { TensorView::from_raw_parts(ptr, sizes, self.backend().clone()) }
}
pub fn main_tensor(&self, log_stacking_height: u32) -> Tensor<F, B> {
let mut tensor = Tensor::with_sizes_in(
[self.main_size() / (1 << log_stacking_height), 1 << log_stacking_height],
self.backend().clone(),
);
let backend = self.dense.backend();
unsafe {
tensor.assume_init();
tensor
.as_mut_buffer()
.copy_from_slice(&self.dense[self.preprocessed_offset..], backend)
.unwrap();
}
tensor
}
pub fn preprocessed_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, B> {
let ptr = self.dense.as_ptr();
let sizes = Dimensions::try_from([
self.preprocessed_offset / (1 << log_stacking_height),
1 << log_stacking_height,
])
.unwrap();
unsafe { TensorView::from_raw_parts(ptr, sizes, self.backend().clone()) }
}
pub fn preprocessed_tensor(&self, log_stacking_height: u32) -> Tensor<F, B> {
let mut tensor = Tensor::with_sizes_in(
[self.preprocessed_offset / (1 << log_stacking_height), 1 << log_stacking_height],
self.backend().clone(),
);
let backend = self.dense.backend();
unsafe {
tensor.assume_init();
tensor
.as_mut_buffer()
.copy_from_slice(&self.dense[..self.preprocessed_offset], backend)
.unwrap();
}
tensor
}
#[inline]
pub fn main_poly_height(&self, name: &str) -> Option<usize> {
self.main_table_index.get(name).map(|offset| offset.poly_size)
}
#[inline]
pub fn preprocessed_poly_height(&self, name: &str) -> Option<usize> {
self.preprocessed_table_index.get(name).map(|offset| offset.poly_size)
}
#[inline]
pub fn main_num_polys(&self, name: &str) -> Option<usize> {
self.main_table_index.get(name).map(|offset| offset.num_polys)
}
#[inline]
pub fn main_size(&self) -> usize {
self.dense.len() - self.preprocessed_offset
}
#[inline]
pub fn preprocessed_num_polys(&self, name: &str) -> Option<usize> {
self.preprocessed_table_index.get(name).map(|offset| offset.num_polys)
}
}
impl<F: Field, B: Backend> HasBackend for TraceDenseData<F, B> {
type Backend = B;
fn backend(&self) -> &B {
self.dense.backend()
}
}
impl<F: Field, B: Backend> JaggedTraceMle<F, B> {
pub fn new(
dense_data: TraceDenseData<F, B>,
col_index: Buffer<u32, B>,
start_indices: Buffer<u32, B>,
column_heights: Vec<u32>,
) -> Self {
JaggedTraceMle(JaggedMle::new(dense_data, col_index, start_indices, column_heights))
}
}
impl<F: Field> JaggedTraceMle<F, TaskScope> {
pub fn preprocessed_virtual_tensor(
&'_ self,
log_stacking_height: u32,
) -> TensorView<'_, F, TaskScope> {
self.dense_data.preprocessed_virtual_tensor(log_stacking_height)
}
pub fn main_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, TaskScope> {
self.dense_data.main_virtual_tensor(log_stacking_height)
}
pub fn main_poly_height(&self, name: &str) -> Option<usize> {
self.dense_data.main_poly_height(name)
}
pub fn preprocessed_poly_height(&self, name: &str) -> Option<usize> {
self.dense_data.preprocessed_poly_height(name)
}
pub fn main_num_polys(&self, name: &str) -> Option<usize> {
self.dense_data.main_num_polys(name)
}
pub fn main_size(&self) -> usize {
self.dense_data.main_size()
}
pub fn preprocessed_num_polys(&self, name: &str) -> Option<usize> {
self.dense_data.preprocessed_num_polys(name)
}
}
#[repr(C)]
pub struct TraceDenseDataRaw<F> {
dense: *const F,
}
#[repr(C)]
pub struct TraceDenseDataMutRaw<F> {
dense: *mut F,
}
impl<F: Field, B: Backend> DenseData<B> for TraceDenseData<F, B> {
type DenseDataRaw = TraceDenseDataRaw<F>;
fn as_ptr(&self) -> TraceDenseDataRaw<F> {
TraceDenseDataRaw { dense: self.dense.as_ptr() }
}
}
impl<F: Field, B: Backend> DenseDataMut<B> for TraceDenseData<F, B> {
type DenseDataMutRaw = TraceDenseDataMutRaw<F>;
fn as_mut_ptr(&mut self) -> TraceDenseDataMutRaw<F> {
TraceDenseDataMutRaw { dense: self.dense.as_mut_ptr() }
}
}
impl<F: Field> JaggedTraceMle<F, CpuBackend> {
pub fn into_device(self, t: &TaskScope) -> JaggedTraceMle<F, TaskScope> {
let JaggedMle { col_index, start_indices, column_heights, dense_data } = self.0;
JaggedTraceMle::new(
dense_data.into_device_in(t),
DeviceBuffer::from_host(&col_index, t).unwrap().into_inner(),
DeviceBuffer::from_host(&start_indices, t).unwrap().into_inner(),
column_heights,
)
}
}
impl<F: Field> TraceDenseData<F, CpuBackend> {
pub fn into_device_in(self, t: &TaskScope) -> TraceDenseData<F, TaskScope> {
TraceDenseData {
dense: DeviceBuffer::from_host(&self.dense, t).unwrap().into_inner(),
preprocessed_offset: self.preprocessed_offset,
preprocessed_cols: self.preprocessed_cols,
preprocessed_table_index: self.preprocessed_table_index,
main_table_index: self.main_table_index,
preprocessed_padding: self.preprocessed_padding,
main_padding: self.main_padding,
}
}
}
impl<F: Field> JaggedTraceMle<F, TaskScope> {
pub fn into_host(self) -> JaggedTraceMle<F, CpuBackend> {
let JaggedMle { col_index, start_indices, column_heights, dense_data } = self.0;
let host_dense = dense_data.into_host();
let col_index_host = DeviceBuffer::from_raw(col_index).to_host().unwrap().into();
let start_indices_host = DeviceBuffer::from_raw(start_indices).to_host().unwrap().into();
JaggedTraceMle::new(host_dense, col_index_host, start_indices_host, column_heights)
}
}
impl<F: Field> TraceDenseData<F, TaskScope> {
pub fn into_host(self) -> TraceDenseData<F, CpuBackend> {
let host_dense = DeviceBuffer::from_raw(self.dense).to_host().unwrap().into();
TraceDenseData {
dense: host_dense,
preprocessed_offset: self.preprocessed_offset,
preprocessed_cols: self.preprocessed_cols,
preprocessed_table_index: self.preprocessed_table_index,
main_table_index: self.main_table_index,
preprocessed_padding: self.preprocessed_padding,
main_padding: self.main_padding,
}
}
}