zksync-gpu-prover 0.155.9

ZKsync GPU prover utilities
use super::*;
use bellman::PrimeField;
use core::ops::Range;
use std::io::{Read, Write};

/// Host-side vector that can be the source or destination of asynchronous
/// copies to/from a GPU `DeviceBuf`.
///
/// The accessor methods wait on the two events below before touching the
/// values, so host code does not observe the buffer while an enqueued copy
/// that involves it may still be running.
pub struct AsyncVec<T, #[cfg(feature = "allocator")] A: Allocator = CudaAllocator> {
    /// Wrapped host vector (no custom allocator). Becomes `None` only once
    /// `into_inner` has moved it out.
    #[cfg(not(feature = "allocator"))]
    pub values: Option<Vec<T>>,
    /// Wrapped host vector allocated with `A`. Becomes `None` only once
    /// `into_inner` has moved it out.
    #[cfg(feature = "allocator")]
    pub values: Option<Vec<T, A>>,
    /// Recorded after enqueueing a copy that reads from this vector
    /// (`async_copy_to_device`); waited on before the host mutates it.
    pub(crate) read_event: Event,
    /// Recorded after enqueueing a copy that writes into this vector
    /// (`async_copy_from_device`); waited on before the host reads or mutates it.
    pub(crate) write_event: Event,
}

use std::fmt;

// Expands one inherent `impl` body for `AsyncVec` under both feature
// configurations: with the `allocator` feature the impl is generic over the
// element type and an allocator `A: Allocator + Default`; without it, only
// over the element type.
macro_rules! impl_async_vec {
    (impl AsyncVec $body:tt) => {
        #[cfg(feature = "allocator")]
        impl<T, A: Allocator + Default> AsyncVec<T, A> $body

        #[cfg(not(feature = "allocator"))]
        impl<T> AsyncVec<T> $body
    };
}

impl_async_vec! {
    impl AsyncVec{
        /// Allocates a host vector of `length` elements without initializing them.
        ///
        /// The contents are whatever the allocator returned; callers are expected
        /// to fill the buffer (e.g. with `async_copy_from_device`, `from_bytes` or
        /// `zeroize`) before reading it.
        // NOTE(review): `set_len` over uninitialized memory is only sound if every
        // element is written before it is read and `T` accepts arbitrary bit
        // patterns — confirm for every `T` this is instantiated with.
        pub fn allocate_new(length: usize) -> Self {
            #[cfg(feature = "allocator")]
            let mut values = Vec::with_capacity_in(length, A::default());
            #[cfg(not(feature = "allocator"))]
            let mut values = Vec::with_capacity(length);
            unsafe {
                values.set_len(length);
            }

            Self {
                values: Some(values),
                read_event: Event::new(),
                write_event: Event::new(),
            }
        }

        /// Returns the host values for reading.
        ///
        /// Waits on `write_event` first so a pending device-to-host copy into
        /// this vector has completed.
        ///
        /// # Errors
        /// Propagates the error from synchronizing `write_event`.
        pub fn get_values(&self) -> GpuResult<&[T]> {
            self.write_event.sync()?;
            Ok(self.values.as_ref().expect("async_vec inner is none"))
        }

        /// Returns the host values for mutation.
        ///
        /// Waits on both `read_event` and `write_event`, so no enqueued copy is
        /// still reading from or writing into this vector.
        ///
        /// # Errors
        /// Propagates the error from synchronizing either event.
        pub fn get_values_mut(&mut self) -> GpuResult<&mut [T]> {
            self.read_event.sync()?;
            self.write_event.sync()?;
            Ok(self.values.as_mut().expect("async_vec inner is none"))
        }

        /// Enqueues a host-to-device copy of `self[this_range]` into
        /// `other[other_range]` on the context's H2D stream.
        ///
        /// The stream first waits for pending writes into this vector and for
        /// pending reads/writes of `other`. After enqueueing, `self.read_event` and
        /// `other.write_event` are recorded on the stream so later users can wait
        /// for the copy.
        ///
        /// # Panics
        /// Panics if the two ranges have different lengths, or if `this_range`
        /// is out of bounds for this vector.
        ///
        /// # Errors
        /// Returns `GpuError::AsyncH2DErr` if the copy cannot be enqueued, or
        /// propagates device-selection / event errors.
        pub fn async_copy_to_device(
            &self,
            ctx: &mut GpuContext,
            other: &mut DeviceBuf<T>,
            this_range: Range<usize>,
            other_range: Range<usize>,
        ) -> GpuResult<()> {
            assert_eq!(this_range.len(), other_range.len());
            let length = std::mem::size_of::<T>() * this_range.len();
            set_device(ctx.device_id())?;

            ctx.h2d_stream.wait(self.write_event())?;
            ctx.h2d_stream.wait(other.read_event())?;
            ctx.h2d_stream.wait(other.write_event())?;

            let result = unsafe {
                bc_memcpy_async(
                    other.as_mut_ptr(other_range) as *mut c_void,
                    self.as_ptr(this_range) as *const c_void,
                    length as u64,
                    ctx.h2d_stream().inner,
                )
            };

            if result != 0 {
                return Err(GpuError::AsyncH2DErr(result));
            }

            self.read_event.record(ctx.h2d_stream())?;
            other.write_event.record(ctx.h2d_stream())?;

            Ok(())
        }

        /// Enqueues a device-to-host copy of `other[other_range]` into
        /// `self[this_range]` on the context's D2H stream.
        ///
        /// The stream first waits for pending reads/writes of this vector and for
        /// pending writes into `other`. After enqueueing, `self.write_event` and
        /// `other.read_event` are recorded on the stream.
        ///
        /// # Panics
        /// Panics if the two ranges have different lengths, or if `this_range`
        /// is out of bounds for this vector.
        ///
        /// # Errors
        /// Returns an error if the copy cannot be enqueued, or propagates
        /// device-selection / event errors.
        pub fn async_copy_from_device(
            &mut self,
            ctx: &mut GpuContext,
            other: &mut DeviceBuf<T>,
            this_range: Range<usize>,
            other_range: Range<usize>,
        ) -> GpuResult<()> {
            assert_eq!(this_range.len(), other_range.len());
            let length = std::mem::size_of::<T>() * this_range.len();
            set_device(ctx.device_id())?;

            ctx.d2h_stream.wait(self.write_event())?;
            ctx.d2h_stream.wait(self.read_event())?;
            ctx.d2h_stream.wait(other.write_event())?;

            let result = unsafe {
                bc_memcpy_async(
                    self.as_mut_ptr(this_range) as *mut c_void,
                    other.as_ptr(other_range) as *const c_void,
                    length as u64,
                    ctx.d2h_stream().inner,
                )
            };

            if result != 0 {
                // NOTE(review): this is the device-to-host path but reports the
                // H2D variant; switch to a D2H variant if `GpuError` has one.
                return Err(GpuError::AsyncH2DErr(result));
            }

            self.write_event.record(ctx.d2h_stream())?;
            other.read_event.record(ctx.d2h_stream())?;

            Ok(())
        }

        /// Number of elements in the host vector.
        pub fn len(&self) -> usize {
            self.values.as_ref().expect("async_vec inner is none").len()
        }

        /// Waits for all pending copies involving this vector and returns the
        /// wrapped host vector.
        #[cfg(feature = "allocator")]
        pub fn into_inner(mut self) -> GpuResult<std::vec::Vec<T, A>> {
            self.read_event.sync()?;
            self.write_event.sync()?;

            Ok(self.values.take().expect("async_vec inner is none"))
        }

        /// Waits for all pending copies involving this vector and returns the
        /// wrapped host vector.
        #[cfg(not(feature = "allocator"))]
        pub fn into_inner(mut self) -> GpuResult<std::vec::Vec<T>> {
            self.read_event.sync()?;
            self.write_event.sync()?;

            Ok(self.values.take().expect("async_vec inner is none"))
        }

        /// Event recorded after a copy that reads from this vector was enqueued.
        pub fn read_event(&self) -> &Event {
            &self.read_event
        }

        /// Event recorded after a copy that writes into this vector was enqueued.
        pub fn write_event(&self) -> &Event {
            &self.write_event
        }

        /// Raw pointer to the start of `range`. Does not synchronize any event.
        ///
        /// # Panics
        /// Panics if `range` is out of bounds.
        pub fn as_ptr(&self, range: Range<usize>) -> *const T {
            self.values.as_ref().expect("async_vec inner is none")[range].as_ptr()
        }

        /// Mutable raw pointer to the start of `range`. Does not synchronize any event.
        ///
        /// # Panics
        /// Panics if `range` is out of bounds.
        pub fn as_mut_ptr(&mut self, range: Range<usize>) -> *mut T {
            self.values.as_mut().expect("async_vec inner is none")[range].as_mut_ptr()
        }

        /// Sets every byte of the host buffer to zero.
        ///
        /// First waits for pending copies that read from or write into this
        /// buffer (via `get_values_mut`), so the zeroing cannot race with an
        /// in-flight transfer.
        ///
        /// # Panics
        /// Panics if synchronizing the events fails, because this method's
        /// signature cannot report the error.
        pub fn zeroize(&mut self){
            let values = self
                .get_values_mut()
                .expect("async_vec zeroize: failed to sync pending copies");
            // SAFETY: `values` is a live, exclusively borrowed slice of
            // `values.len()` elements; `write_bytes` counts in units of `T`, so it
            // writes exactly the slice's bytes.
            unsafe { std::ptr::write_bytes(values.as_mut_ptr(), 0, values.len()) };
        }
    }
}

#[cfg(feature = "allocator")]
impl<T: fmt::Debug, A: Allocator + Default> fmt::Debug for AsyncVec<T, A> {
    /// Formats the host values after waiting for pending writes into the buffer.
    ///
    /// If synchronizing fails, the GPU error is shown in an `Error` field
    /// instead of panicking, so `{:?}` never aborts the caller.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("AsyncVec");
        match self.get_values() {
            Ok(values) => out.field("Values", &values),
            Err(err) => out.field("Error", &err),
        };
        out.finish()
    }
}
#[cfg(not(feature = "allocator"))]
impl<T: fmt::Debug> fmt::Debug for AsyncVec<T> {
    /// Formats the host values after waiting for pending writes into the buffer.
    ///
    /// If synchronizing fails, the GPU error is shown in an `Error` field
    /// instead of panicking, so `{:?}` never aborts the caller.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("AsyncVec");
        match self.get_values() {
            Ok(values) => out.field("Values", &values),
            Err(err) => out.field("Error", &err),
        };
        out.finish()
    }
}

#[cfg(feature = "allocator")]
impl<T, A: Allocator> From<Vec<T, A>> for AsyncVec<T, A> {
    /// Wraps an existing host vector, pairing it with freshly created read and
    /// write events.
    fn from(values: Vec<T, A>) -> Self {
        let read_event = Event::new();
        let write_event = Event::new();
        AsyncVec {
            values: Some(values),
            read_event,
            write_event,
        }
    }
}
#[cfg(not(feature = "allocator"))]
impl<T> From<Vec<T>> for AsyncVec<T> {
    fn from(values: Vec<T>) -> Self {
        Self {
            values: Some(values),
            read_event: Event::new(),
            write_event: Event::new(),
        }
    }
}

#[cfg(feature = "allocator")]
impl<T, A: Allocator + Default> From<AsyncVec<T, A>> for Vec<T, A> {
    /// Unwraps the host vector after waiting for all pending copies.
    ///
    /// Panics if synchronizing the events fails.
    fn from(vector: AsyncVec<T, A>) -> Self {
        let synced = AsyncVec::into_inner(vector);
        synced.unwrap()
    }
}

#[cfg(not(feature = "allocator"))]
impl<T> From<AsyncVec<T>> for Vec<T> {
    fn from(vector: AsyncVec<T>) -> Self {
        vector.into_inner().unwrap()
    }
}

#[cfg(feature = "allocator")]
impl<T, A: Allocator> Drop for AsyncVec<T, A> {
    /// Blocks until every enqueued copy that reads from or writes into this
    /// vector has finished, so the host memory is not released under a pending
    /// transfer. Panics if a synchronization fails.
    fn drop(&mut self) {
        for event in [&self.read_event, &self.write_event] {
            event.sync().unwrap();
        }
    }
}

#[cfg(not(feature = "allocator"))]
impl<T> Drop for AsyncVec<T> {
    /// Blocks until every enqueued copy that reads from or writes into this
    /// vector has finished, so the host memory is not released under a pending
    /// transfer. Panics if a synchronization fails.
    fn drop(&mut self) {
        for event in [&self.read_event, &self.write_event] {
            event.sync().unwrap();
        }
    }
}

// Like `impl_async_vec`, but restricts the element type to a `PrimeField`
// so the expanded body can use field-element operations, emitting the impl
// for both the `allocator` and non-`allocator` feature configurations.
macro_rules! impl_async_vec_for_field {
    (impl AsyncVec $body:tt) => {
        #[cfg(feature = "allocator")]
        impl<F: PrimeField, A: Allocator + Default> AsyncVec<F, A> $body

        #[cfg(not(feature = "allocator"))]
        impl<F: PrimeField> AsyncVec<F> $body
    };
}

impl_async_vec_for_field! {
    impl AsyncVec{
        /// Byte size of one serialized field element: the number of limbs in
        /// its raw representation times 8 bytes per limb.
        ///
        /// # Panics
        /// Panics if this differs from `size_of::<F>()`. `to_bytes` and
        /// `from_bytes` copy element memory byte-for-byte, so a mismatch would make
        /// them copy the wrong number of bytes.
        fn serialized_element_size() -> usize {
            let size = F::zero().into_raw_repr().as_ref().len() * 8;
            assert_eq!(
                size,
                std::mem::size_of::<F>(),
                "field element in-memory size differs from its raw representation size"
            );
            size
        }

        /// Writes all elements' raw in-memory bytes to `writer`.
        ///
        /// Waits for pending device-to-host copies into this vector first.
        ///
        /// # Panics
        /// Panics if writing to `writer` fails.
        ///
        /// # Errors
        /// Propagates event synchronization errors.
        pub fn write<W: Write>(&self, writer: &mut W) -> GpuResult<()> {
            let byte_len = self.len() * Self::serialized_element_size();
            // Zero-initialized so uninitialized memory is never exposed as `&mut [u8]`.
            let mut poly_bytes = vec![0u8; byte_len];

            self.to_bytes(&mut poly_bytes[..])?;
            writer
                .write_all(&poly_bytes[..])
                .expect("Can't write AsyncVec");

            Ok(())
        }

        /// Copies all elements' raw in-memory bytes into `dst`.
        ///
        /// Waits for pending device-to-host copies into this vector first.
        ///
        /// # Panics
        /// Panics if `dst.len()` is not `len() * element size`.
        ///
        /// # Errors
        /// Propagates event synchronization errors.
        pub fn to_bytes(&self, dst: &mut [u8]) -> GpuResult<()> {
            let length = self.len();
            let element_size = Self::serialized_element_size();
            assert_eq!(length * element_size, dst.len(), "Wrong destination length");

            let values = self.get_values()?;
            // SAFETY: `values` holds `length` elements of `element_size` bytes each
            // (size checked in `serialized_element_size`), `dst` is exactly that long
            // (asserted above), and a `&mut [u8]` cannot alias the shared `values`.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    values.as_ptr() as *const u8,
                    dst.as_mut_ptr(),
                    dst.len(),
                )
            };

            Ok(())
        }

        /// Reads `len() * element size` bytes from `reader` and stores them as
        /// the raw bytes of this vector's elements.
        ///
        /// # Panics
        /// Panics if reading from `reader` fails.
        ///
        /// # Errors
        /// Propagates event synchronization errors.
        pub fn read<R: Read>(&mut self, reader: &mut R) -> GpuResult<()> {
            let byte_len = self.len() * Self::serialized_element_size();
            // Zero-initialized: `read_exact` must never be handed uninitialized bytes.
            let mut res_bytes = vec![0u8; byte_len];

            reader
                .read_exact(&mut res_bytes)
                .expect("Can't read AsyncVec");

            self.from_bytes(&res_bytes[..])
        }

        /// Overwrites all elements with the raw bytes in `src`.
        ///
        /// Waits for pending copies that read from or write into this vector
        /// first, so the overwrite cannot race with an in-flight transfer.
        ///
        /// # Panics
        /// Panics if `src.len()` is not `len() * element size`.
        ///
        /// # Errors
        /// Propagates event synchronization errors.
        pub fn from_bytes(&mut self, src: &[u8]) -> GpuResult<()> {
            let length = self.len();
            let element_size = Self::serialized_element_size();
            assert_eq!(length * element_size, src.len(), "Wrong source length");

            // A genuinely mutable borrow: the original wrote through a pointer
            // derived from `&self`.
            let values = self.get_values_mut()?;
            // SAFETY: `values` is an exclusive slice of `length` elements of
            // `element_size` bytes each, `src` is exactly that long (asserted
            // above), and the shared `src` cannot alias the `&mut` `values`.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    src.as_ptr(),
                    values.as_mut_ptr() as *mut u8,
                    src.len(),
                )
            };

            Ok(())
        }
    }
}