shivini 0.156.0

Shivini is a library implementing a GPU-accelerated zkSync prover.
use super::*;
use boojum::worker::Worker;
use era_cudart::execution::{launch_host_fn, HostFn};
pub use era_cudart::memory::memory_copy_async;
use era_cudart::stream::CudaStreamWaitEventFlags;
use std::intrinsics::copy_nonoverlapping;
use std::ops::DerefMut;
use std::slice;

/// Enqueues a host-to-device copy of `host` into `device` using `memory_copy_async` on the
/// stream returned by `get_h2d_stream()`.
///
/// # Panics
/// Panics if `host` is empty or if `host` and `device` differ in length.
///
/// # Errors
/// Propagates any error from `memory_copy_async`. The copy is only issued when not in dry-run
/// mode (see `if_not_dry_run!`).
pub fn h2d<T>(host: &[T], device: &mut [T]) -> CudaResult<()> {
    assert!(!host.is_empty());
    assert_eq!(host.len(), device.len());
    if_not_dry_run! {
        {
            let h2d_stream = get_h2d_stream();
            memory_copy_async(device, host, h2d_stream)
        }
    }
}

/// Copies `host` into `device` in `chunk_size`-element pieces, staging each piece through the
/// auxiliary H2D host buffer returned by `_aux_h2d_buffer()`.
///
/// Chunk `i` is handled on auxiliary stream `i % STREAMS_COUNT` (`STREAMS_COUNT` is 2). A host
/// callback launched with `launch_host_fn` copies the host chunk into that stream's region of
/// the staging buffer, in parallel on `worker` threads. `memory_copy_async` then moves the staged
/// data to the matching `device` chunk on the same stream. The auxiliary streams are ordered after
/// work already enqueued on the stream returned by `get_h2d_stream()`. Before returning, a wait on
/// both auxiliary streams is enqueued on that stream, so work enqueued on it later observes the
/// copies.
///
/// Returns the `HostFn` callbacks that were launched. NOTE(review): `launch_host_fn` only
/// receives a reference to each callback, so presumably the caller must keep the returned
/// vector alive until the H2D stream has completed; confirm against the `HostFn` docs.
///
/// If `is_dry_run()` returns `true`, nothing is enqueued and an empty vector is returned.
///
/// # Panics
/// Panics in any of these cases:
/// - `host` is empty.
/// - `host` and `device` differ in length.
/// - `chunk_size` is zero.
/// - `STREAMS_COUNT > NUM_AUX_STREAMS_AND_EVENTS`.
/// - `chunk_size * STREAMS_COUNT * size_of::<T>()` exceeds `AUX_H2D_BUFFER_SIZE` (presumably a
///   byte count; confirm).
///
/// # Errors
/// Propagates errors from `is_dry_run`, event recording/waiting, `launch_host_fn`, and
/// `memory_copy_async`.
// Allowed because `STREAMS_COUNT <= NUM_AUX_STREAMS_AND_EVENTS` below compares two constants.
#[allow(clippy::assertions_on_constants)]
pub fn h2d_buffered<'a, T: Send + Sync>(
    host: &'a [T],
    device: &'a mut [T],
    chunk_size: usize,
    worker: &'a Worker,
) -> CudaResult<Vec<HostFn<'a>>> {
    assert!(!host.is_empty());
    assert_eq!(host.len(), device.len());
    assert_ne!(chunk_size, 0);
    if is_dry_run()? {
        Ok(vec![])
    } else {
        // Number of auxiliary streams the chunks are round-robined across. Each stream owns a
        // `chunk_size`-element region of the staging buffer.
        const STREAMS_COUNT: usize = 2;
        assert!(STREAMS_COUNT <= NUM_AUX_STREAMS_AND_EVENTS);
        assert!(chunk_size * STREAMS_COUNT * size_of::<T>() <= AUX_H2D_BUFFER_SIZE);
        let events = &_aux_events()[0..STREAMS_COUNT];
        let streams = &_aux_streams()[0..STREAMS_COUNT];
        // Reinterpret the auxiliary H2D staging buffer as a slice of `T`.
        // NOTE(review): `transmute` of a slice reference keeps the original element count. If
        // `_aux_h2d_buffer()` derefs to e.g. `[u8]`, `buffer.len()` is a byte count rather than a
        // count of `T`s, and accesses stay in bounds only because of the size assert above.
        // It also erases the borrow's lifetime: if `_aux_h2d_buffer()` returns a guard, that
        // guard is dropped at the end of this statement. Alignment of the buffer for `T` is
        // assumed. Confirm all three against the definition of `_aux_h2d_buffer`.
        let buffer: &mut [T] = unsafe { std::mem::transmute(_aux_h2d_buffer().deref_mut()) };
        let main_stream = get_h2d_stream();
        // Parallel host-side copy of `src` into `dst`. `worker.scope` picks the piece size (this
        // inner `chunk_size` shadows the function parameter), and each piece is copied on a
        // worker thread. Callers pass `src` and `dst` of equal length, so the zipped pieces match.
        let copy = |src: &[T], dst: &mut [T]| {
            worker.scope(src.len(), |scope, chunk_size| {
                for (src_chunk, dst_chunk) in src.chunks(chunk_size).zip(dst.chunks_mut(chunk_size))
                {
                    scope.spawn(|_| unsafe {
                        copy_nonoverlapping(
                            src_chunk.as_ptr(),
                            dst_chunk.as_mut_ptr(),
                            src_chunk.len(),
                        )
                    })
                }
            });
        };
        // Make every auxiliary stream wait for work already enqueued on the main H2D stream.
        // `events[0]` is recorded again on `streams[0]` at the end. Per CUDA semantics, a stream
        // wait refers to the event's most recent record at the time the wait is enqueued, so
        // these waits are not affected by that later record.
        events[0].record(main_stream)?;
        for stream in streams.iter() {
            stream.wait_event(&events[0], CudaStreamWaitEventFlags::DEFAULT)?;
        }
        let mut pending_callbacks = vec![];
        // Round-robin the chunks over the auxiliary streams. Chunk `i` uses the staging region
        // starting at `idx * chunk_size`. The last chunk may be shorter, so the region is sized
        // by `src.len()`. Chunk `i + STREAMS_COUNT` reuses the same region. This relies on work
        // within one stream running in order, so its callback runs only after the previous
        // memcpy out of that region has finished.
        for (i, (src, dst)) in host
            .chunks(chunk_size)
            .zip(device.chunks_mut(chunk_size))
            .enumerate()
        {
            let idx = i % STREAMS_COUNT;
            let stream = &streams[idx];
            let buffer_offset = idx * chunk_size;
            let buffer = &buffer[buffer_offset..buffer_offset + src.len()];
            // Host callback: fill this chunk's staging region from the host chunk.
            // NOTE(review): this writes through a pointer obtained from a shared slice
            // (`buffer.as_ptr() as *mut T`). Rust's aliasing rules forbid mutating through
            // `&[T]`-derived pointers. The sound form would derive the pointer from the
            // `&mut [T]` staging buffer.
            let callback = HostFn::new(move || {
                let dst =
                    unsafe { slice::from_raw_parts_mut(buffer.as_ptr() as *mut T, buffer.len()) };
                copy(src, dst);
            });
            launch_host_fn(stream, &callback)?;
            // `launch_host_fn` was only given a reference, so keep the callback alive by
            // returning it to the caller.
            pending_callbacks.push(callback);
            // Async copy from the staging region into the matching device chunk, on the same
            // stream so it runs after the callback above.
            let dst = unsafe { DeviceSlice::from_mut_slice(dst) };
            memory_copy_async(dst, buffer, stream)?;
        }

        // Join: each auxiliary stream records its event, and the main H2D stream waits on it.
        for (event, stream) in events.iter().zip(streams.iter()) {
            event.record(stream)?;
            main_stream.wait_event(event, CudaStreamWaitEventFlags::DEFAULT)?;
        }
        Ok(pending_callbacks)
    }
}

/// Enqueues a host-to-device copy of `host` into `device` with `memory_copy_async` on the
/// caller-supplied `stream`.
///
/// This variant lets the caller choose which stream the copy is ordered on.
///
/// # Panics
/// Panics if `host` is empty or if `host` and `device` differ in length.
///
/// # Errors
/// Propagates any error from `memory_copy_async`. The copy is only issued when not in dry-run
/// mode (see `if_not_dry_run!`).
#[allow(dead_code)]
pub fn h2d_on_stream<T>(host: &[T], device: &mut [T], stream: &CudaStream) -> CudaResult<()> {
    assert!(!host.is_empty());
    assert_eq!(host.len(), device.len());
    if_not_dry_run! {
        {
            let (dst, src) = (device, host);
            memory_copy_async(dst, src, stream)
        }
    }
}

/// Enqueues a device-to-host copy of `device` into `host` using `memory_copy_async` on the
/// stream returned by `get_d2h_stream()`.
///
/// # Panics
/// Panics if `host` is empty or if `host` and `device` differ in length.
///
/// # Errors
/// Propagates any error from `memory_copy_async`. The copy is only issued when not in dry-run
/// mode (see `if_not_dry_run!`).
pub fn d2h<T>(device: &[T], host: &mut [T]) -> CudaResult<()> {
    assert!(!host.is_empty());
    assert_eq!(host.len(), device.len());
    if_not_dry_run! {
        {
            let d2h_stream = get_d2h_stream();
            memory_copy_async(host, device, d2h_stream)
        }
    }
}

/// Enqueues a device-to-device copy of `src` into `dst` using `memory_copy_async` on the
/// stream returned by `get_stream()`.
///
/// # Panics
/// Panics if `src` is empty or if `src` and `dst` differ in length.
///
/// # Errors
/// Propagates any error from `memory_copy_async`. The copy is only issued when not in dry-run
/// mode (see `if_not_dry_run!`).
pub fn d2d<T>(src: &[T], dst: &mut [T]) -> CudaResult<()> {
    assert!(!src.is_empty());
    assert_eq!(src.len(), dst.len());
    if_not_dry_run! {
        {
            let default_stream = get_stream();
            memory_copy_async(dst, src, default_stream)
        }
    }
}