1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
use std::cmp;
use std::marker::PhantomData;
use std::ptr;

use crate::anyhow::Result;
use crate::runtime::Block;
use crate::runtime::BlockMeta;
use crate::runtime::BlockMetaBuilder;
use crate::runtime::MessageIo;
use crate::runtime::MessageIoBuilder;
use crate::runtime::StreamIo;
use crate::runtime::StreamIoBuilder;
use crate::runtime::SyncKernel;
use crate::runtime::WorkIo;

/// Stream block that forwards items from its `in` port to its `out` port in
/// randomly sized chunks.
///
/// Each `work` call copies between one and `max_copy` items, additionally
/// bounded by how many items are available on the input and how many fit on
/// the output.
pub struct CopyRand<T: Send + 'static> {
    /// Upper bound on the number of items copied by a single `work` call.
    max_copy: usize,
    /// Ties the block to its item type `T` without storing any `T`.
    _type: PhantomData<T>,
}

impl<T: Send + 'static> CopyRand<T> {
    /// Creates a synchronous `CopyRand` block for items of type `T`.
    ///
    /// The block has one stream input (`"in"`) and one stream output
    /// (`"out"`), both sized to `size_of::<T>()`, and no message ports.
    ///
    /// `max_copy` caps the number of items moved per `work` call.
    pub fn new(max_copy: usize) -> Block {
        let item_size = std::mem::size_of::<T>();

        let meta = BlockMetaBuilder::new("CopyRand").build();
        let stream_io = StreamIoBuilder::new()
            .add_input("in", item_size)
            .add_output("out", item_size)
            .build();
        let message_io = MessageIoBuilder::<Self>::new().build();
        let kernel = Self {
            max_copy,
            _type: PhantomData,
        };

        Block::new_sync(meta, stream_io, message_io, kernel)
    }
}

#[async_trait]
impl<T: Send + 'static> SyncKernel for CopyRand<T> {
    /// Copies a randomly sized chunk of items from input 0 to output 0.
    ///
    /// The chunk size is drawn (approximately uniformly) from `1..=limit`.
    /// Here `limit` is the smallest of three values: the whole items available
    /// on the input, the whole items that fit on the output, and `max_copy`.
    /// Nothing is copied when `limit` is zero.
    ///
    /// After a non-empty copy the kernel asks to be called again, because the
    /// random chunk may have left items in the input buffer.
    ///
    /// The block marks itself finished once input 0 reports `finished()` and
    /// this call consumed every byte remaining in the input slice.
    ///
    /// # Panics
    ///
    /// Panics with a division by zero if `T` is a zero-sized type.
    fn work(
        &mut self,
        io: &mut WorkIo,
        sio: &mut StreamIo,
        _mio: &mut MessageIo<Self>,
        _meta: &mut BlockMeta,
    ) -> Result<()> {
        // Operate on raw bytes so `T` needs no `Copy` bound; item counts are
        // converted to and from byte counts through `item_size`.
        let i = sio.input(0).slice::<u8>();
        let o = sio.output(0).slice::<u8>();
        let item_size = std::mem::size_of::<T>();

        // Number of whole items that can be moved in this call.
        let mut m = cmp::min(i.len(), o.len());
        m /= item_size;

        m = cmp::min(m, self.max_copy);

        if m > 0 {
            // Pick a chunk size in 1..=m.
            m = rand::random::<usize>() % m + 1;

            // SAFETY: `m <= min(i.len(), o.len()) / item_size`, so
            // `m * item_size` bytes are in bounds for both slices. `i` (shared)
            // and `o` (exclusive) are live at the same time, so they must not
            // overlap.
            unsafe {
                ptr::copy_nonoverlapping(i.as_ptr(), o.as_mut_ptr(), m * item_size);
            }

            sio.input(0).consume(m);
            sio.output(0).produce(m);

            // The random chunk may leave input behind. Ask to be scheduled
            // again instead of relying on a neighbour's notification to drain it.
            io.call_again = true;
        }

        if sio.input(0).finished() && m * item_size == i.len() {
            io.finished = true;
        }

        Ok(())
    }
}

/// Builder for `CopyRand` blocks.
///
/// By default `max_copy` is `usize::MAX`. Chunk sizes are then limited only by
/// the available input items and free output space.
pub struct CopyRandBuilder<T: Send + 'static> {
    /// Per-call copy limit forwarded to `CopyRand::new`.
    max_copy: usize,
    /// Ties the builder to the item type `T` without storing any `T`.
    _type: PhantomData<T>,
}

impl<T: Send + 'static> CopyRandBuilder<T> {
    /// Starts a builder with no effective per-call copy limit (`usize::MAX`).
    pub fn new() -> Self {
        Self {
            max_copy: usize::MAX,
            _type: PhantomData,
        }
    }

    /// Sets the maximum number of items copied by a single `work` call.
    #[must_use]
    pub fn max_copy(self, max_copy: usize) -> Self {
        Self { max_copy, ..self }
    }

    /// Consumes the builder and creates the configured `CopyRand` block.
    pub fn build(self) -> Block {
        let Self { max_copy, .. } = self;
        CopyRand::<T>::new(max_copy)
    }
}

impl<T: Send + 'static> Default for CopyRandBuilder<T> {
    /// Equivalent to `CopyRandBuilder::new()`.
    fn default() -> Self {
        Self::new()
    }
}