use crate::array_helpers;
use crate::buffer::Buffer;
use crate::error::Result;
use crate::pipe_common::{MetadataManager, PipeState, ShapeManager};
use crate::traits::{SizedDimension, Writable};
use bytemuck::Zeroable;
use ndarray::{ArrayViewMut, Dimension, StrideShape};
use std::cell::RefCell;
use std::io::Write;
use std::marker::PhantomData;
/// A write-only pipe that fills a caller-supplied slice chunk by chunk.
///
/// Each call to [`Writable::write`] hands out a mutable view over the next
/// `n_to_write` elements, starting at an internal cursor, and then advances that
/// cursor by `n_to_write`.
pub struct WriteOnlyPipeBuffer<'data, A, D, M>
where
    A: Copy + Zeroable,
    D: SizedDimension + Dimension,
    M: Clone,
{
    /// Backing storage borrowed from the caller; wrapped in `RefCell` so `write` can take `&self`.
    data: RefCell<&'data mut [A]>,
    /// Capacity in elements (not scalars): `data.len() / element_size` at construction.
    nelements: usize,
    /// Shape information handed to the view-construction helpers.
    shape_manager: ShapeManager<D>,
    /// Element offset at which the next write starts.
    write_ptr: RefCell<usize>,
    /// Optional user metadata attached to the pipe.
    metadata_manager: MetadataManager<M>,
}
impl<'data, A, D, M> WriteOnlyPipeBuffer<'data, A, D, M>
where
    A: Copy + Zeroable,
    D: SizedDimension + Dimension,
    M: Clone,
{
    /// Wraps `data` in a write-only pipe whose elements have the shape given by `shape_input`.
    ///
    /// Capacity is `data.len()` divided (integer division) by the number of scalars in
    /// one element, so trailing scalars that cannot hold a whole element are not counted.
    /// The write cursor starts at element `0` and no metadata is attached.
    ///
    /// # Panics
    ///
    /// NOTE(review): the division would panic if the shape reports an element size of
    /// zero — confirm whether such shapes can reach this constructor.
    pub fn new<Sh: Into<StrideShape<D>>>(data: &'data mut [A], shape_input: Sh) -> Result<Self> {
        let shape_manager = ShapeManager::new(shape_input);
        let scalars_per_element = shape_manager.element_size();
        Ok(Self {
            nelements: data.len() / scalars_per_element,
            data: RefCell::new(data),
            write_ptr: RefCell::new(0),
            shape_manager,
            metadata_manager: MetadataManager::new(),
        })
    }

    /// Returns the metadata currently attached to the pipe, or `None` if none was set.
    pub fn get_metadata(&self) -> Option<M> {
        self.metadata_manager.get()
    }

    /// Stores `m` as the pipe's metadata.
    pub fn set_metadata(&self, m: M) {
        self.metadata_manager.set(m);
    }
}
impl<'data, A, D, M> Writable<A, D, M> for WriteOnlyPipeBuffer<'data, A, D, M>
where
    A: Copy + Zeroable,
    D: SizedDimension + Dimension,
    M: Clone,
{
    /// Hands `f` a mutable view over the next `n_to_write` elements, then advances the cursor.
    ///
    /// `f` also receives a [`PipeState`] whose `write_ptr` is the element offset at which
    /// this chunk starts; `read_ptr` is always `0` for a write-only pipe.
    ///
    /// # Errors
    ///
    /// Returns an error without calling `f` if the bounds check fails (fewer than
    /// `n_to_write` elements remain) or the view cannot be built. The cursor only
    /// advances after `f` has returned.
    fn write<R>(
        &self,
        n_to_write: usize,
        f: impl FnOnce(ArrayViewMut<A, D::Larger>, PipeState) -> R,
    ) -> Result<R>
    where
        D::LargerSize: Into<StrideShape<D::Larger>> + Clone,
        D::CurrentSize: Clone,
    {
        let mut cursor = self.write_ptr.borrow_mut();
        let start = *cursor;
        array_helpers::validate_bounds(start, n_to_write, self.nelements, "Write-only pipe")?;

        let mut storage = self.data.borrow_mut();
        let view =
            array_helpers::create_write_view(&mut storage, start, n_to_write, &self.shape_manager)?;

        let output = f(view, PipeState { write_ptr: start, read_ptr: 0 });
        *cursor = start + n_to_write;
        Ok(output)
    }

    /// Trait hook that stores a clone of `metadata` on the pipe.
    fn set_metadata(&self, metadata: &M) {
        // Go straight to the manager instead of through the same-named inherent
        // method, so the call cannot be mistaken for (or turn into) recursion.
        self.metadata_manager.set(metadata.clone());
    }
}
/// A write-only pipe that sends every chunk straight to a [`std::io::Write`] sink.
///
/// Each write fills a reusable scratch [`Buffer`], then copies the chunk's bytes to
/// the writer and flushes it.
pub struct WriteOnlyPipeStream<W, A, D, M>
where
    W: Write,
    A: Copy + Zeroable,
    D: SizedDimension + Dimension,
    M: Clone,
{
    /// Destination sink; wrapped in `RefCell` so `write` can take `&self`.
    writer: RefCell<W>,
    /// Scratch space reused, and resized as needed, across writes.
    buffer: RefCell<Buffer>,
    /// Shape information used to build the per-write array views.
    shape_manager: ShapeManager<D>,
    /// Optional user metadata attached to the pipe.
    metadata_manager: MetadataManager<M>,
    /// Records the scalar type `A` without storing a value of it.
    _phantom: PhantomData<A>,
}
impl<W: Write, A: Copy + Zeroable, D: SizedDimension + Dimension, M: Clone>
    WriteOnlyPipeStream<W, A, D, M>
{
    /// Creates a stream pipe that writes chunks shaped by `shape_input` to `writer`.
    ///
    /// `min_buffer_size` is the initial scratch capacity, counted in scalars of type `A`.
    /// The scratch buffer is resized on demand when a larger chunk is written.
    ///
    /// # Errors
    ///
    /// Returns an I/O error of kind [`std::io::ErrorKind::InvalidInput`], converted into
    /// the crate error, if `min_buffer_size * size_of::<A>()` does not fit in a `usize`.
    pub fn new<Sh: Into<StrideShape<D>>>(
        writer: W,
        min_buffer_size: usize,
        shape_input: Sh,
    ) -> Result<WriteOnlyPipeStream<W, A, D, M>> {
        let shape_manager = ShapeManager::new(shape_input);
        // Checked so an oversized request is reported instead of panicking (debug)
        // or silently wrapping to a bogus capacity (release).
        let min_bytes = min_buffer_size
            .checked_mul(std::mem::size_of::<A>())
            .ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "minimum buffer size in bytes overflows usize",
                )
            })?;
        let buffer = Buffer::with_capacity(min_bytes);
        Ok(WriteOnlyPipeStream {
            writer: RefCell::new(writer),
            buffer: RefCell::new(buffer),
            shape_manager,
            metadata_manager: MetadataManager::new(),
            _phantom: PhantomData,
        })
    }

    /// Returns the metadata currently attached to the pipe, or `None` if none was set.
    pub fn get_metadata(&self) -> Option<M> {
        self.metadata_manager.get()
    }

    /// Stores `m` as the pipe's metadata.
    pub fn set_metadata(&self, m: M) {
        self.metadata_manager.set(m);
    }

    /// Consumes the pipe and returns the underlying writer.
    ///
    /// Every successful write has already been flushed, so dropping the scratch
    /// buffer here loses no output.
    pub fn into_writer(self) -> W {
        self.writer.into_inner()
    }
}
impl<W: Write, A: Copy + Zeroable, D: SizedDimension + Dimension, M: Clone> Writable<A, D, M>
    for WriteOnlyPipeStream<W, A, D, M>
{
    /// Lets `f` fill `n_to_write` elements in the scratch buffer, then sends those
    /// elements' bytes to the writer and flushes it.
    ///
    /// The [`PipeState`] passed to `f` always reports `write_ptr == 0` and
    /// `read_ptr == 0`, because each chunk is built at the start of the scratch buffer.
    ///
    /// # Errors
    ///
    /// - The request's byte size overflows `usize`. This is reported as an
    ///   [`std::io::ErrorKind::InvalidInput`] I/O error, and `f` is not called.
    /// - The scratch buffer view or array view cannot be set up; `f` is not called.
    /// - Writing or flushing fails. In that case `f` has already run, its return
    ///   value is dropped, and some bytes may already have reached the writer.
    fn write<R>(
        &self,
        n_to_write: usize,
        f: impl FnOnce(ArrayViewMut<A, D::Larger>, PipeState) -> R,
    ) -> Result<R>
    where
        D::LargerSize: Into<StrideShape<D::Larger>> + Clone,
        D::CurrentSize: Clone,
    {
        let mut buffer = self.buffer.borrow_mut();
        let required_elements = self.shape_manager.total_scalars(n_to_write);
        // Checked: in release builds a wrapped product would be used both to size the
        // scratch buffer and to slice the bytes sent to the writer.
        let required_bytes = required_elements
            .checked_mul(std::mem::size_of::<A>())
            .ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "write request size in bytes overflows usize",
                )
            })?;
        buffer.resize_to_fit(required_bytes);
        let slice: &mut [A] = buffer.view_mut()?;
        let data = array_helpers::create_write_view(slice, 0, n_to_write, &self.shape_manager)?;
        // A stream keeps no position: every chunk starts at offset 0 of the scratch buffer.
        let pipe_state = PipeState {
            write_ptr: 0,
            read_ptr: 0,
        };
        let result = f(data, pipe_state);
        let bytes_to_write = &buffer.as_bytes()[..required_bytes];
        let mut writer = self.writer.borrow_mut();
        writer.write_all(bytes_to_write)?;
        writer.flush()?;
        Ok(result)
    }

    /// Trait hook that stores a clone of `metadata` on the pipe.
    fn set_metadata(&self, metadata: &M) {
        // Call the manager directly rather than the same-named inherent method, so the
        // call cannot be mistaken for (or turn into) recursion.
        self.metadata_manager.set(metadata.clone());
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{PipeError, Result};
    use ndarray::Ix0;

    /// Copies raw stream output into a [`Buffer`] and reinterprets it as `f64` values,
    /// the same way the stream tests decode what the pipe emitted.
    fn decode_f64s(bytes: &[u8]) -> Result<Vec<f64>> {
        let mut buffer = Buffer::new(bytes.len());
        buffer.as_bytes_mut().copy_from_slice(bytes);
        let values: &[f64] = buffer.view()?;
        Ok(values.to_vec())
    }

    /// Two consecutive writes fill adjacent regions of the backing slice.
    #[test]
    fn test_writeonly_pipe_basic() -> Result<()> {
        let mut data: Vec<f64> = vec![0.0; 100];
        {
            let pipe = WriteOnlyPipeBuffer::<f64, Ix0, ()>::new(&mut data, [])?;
            pipe.write(25, |mut chunk, _state| {
                for (i, value) in chunk.iter_mut().enumerate() {
                    *value = i as f64;
                }
            })?;
            pipe.write(25, |mut chunk, _state| {
                for (i, value) in chunk.iter_mut().enumerate() {
                    *value = (25 + i) as f64;
                }
            })?;
        }
        for (i, &value) in data[..25].iter().enumerate() {
            assert_eq!(value, i as f64);
        }
        for (i, &value) in data[25..50].iter().enumerate() {
            assert_eq!(value, (25 + i) as f64);
        }
        Ok(())
    }

    /// Writing past the end of the slice reports an insufficient-data error.
    #[test]
    fn test_writeonly_pipe_end_of_space() -> Result<()> {
        let mut data: Vec<f64> = vec![0.0; 10];
        let pipe = WriteOnlyPipeBuffer::<f64, Ix0, ()>::new(&mut data, [])?;
        pipe.write(10, |mut chunk, _state| {
            for (i, value) in chunk.iter_mut().enumerate() {
                *value = i as f64;
            }
        })?;
        let result = pipe.write(5, |_chunk, _state| {});
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(error.is_insufficient_data());
        Ok(())
    }

    /// Metadata can be set, read back, and survives a write.
    #[test]
    fn test_writeonly_pipe_metadata() -> Result<()> {
        let mut data: Vec<f64> = vec![0.0; 50];
        let pipe = WriteOnlyPipeBuffer::<f64, Ix0, String>::new(&mut data, [])?;
        assert_eq!(pipe.get_metadata(), None);
        pipe.set_metadata("test_metadata".to_string());
        assert_eq!(pipe.get_metadata(), Some("test_metadata".to_string()));
        pipe.write(20, |mut chunk, _state| {
            for (i, value) in chunk.iter_mut().enumerate() {
                *value = i as f64;
            }
        })?;
        assert_eq!(pipe.get_metadata(), Some("test_metadata".to_string()));
        Ok(())
    }

    /// Values copied from an external source land in order at the start of the slice.
    #[test]
    fn test_writeonly_pipe_write_from_source() -> Result<()> {
        let mut data: Vec<f64> = vec![0.0; 50];
        {
            let pipe = WriteOnlyPipeBuffer::<f64, Ix0, ()>::new(&mut data, [])?;
            let source: Vec<f64> = (0..20).map(|i| i as f64).collect();
            pipe.write(20, |mut chunk, _state| {
                for (i, value) in chunk.iter_mut().enumerate() {
                    *value = source[i];
                }
            })?;
        }
        for (i, &value) in data[..20].iter().enumerate() {
            assert_eq!(value, i as f64);
        }
        Ok(())
    }

    /// Each chunk sees the correct starting offset, and the slice ends up fully populated.
    #[test]
    fn test_writeonly_pipe_partial_writes() -> Result<()> {
        let mut data: Vec<f64> = vec![0.0; 30];
        {
            let pipe = WriteOnlyPipeBuffer::<f64, Ix0, ()>::new(&mut data, [])?;
            for chunk_idx in 0..3 {
                pipe.write(10, |mut chunk, state| {
                    assert_eq!(state.write_ptr, chunk_idx * 10);
                    for (i, value) in chunk.iter_mut().enumerate() {
                        *value = (chunk_idx * 10 + i) as f64;
                    }
                })?;
            }
        }
        for (i, &value) in data.iter().enumerate() {
            assert_eq!(value, i as f64);
        }
        Ok(())
    }

    /// A single stream write emits exactly the written elements.
    #[test]
    fn test_writeonly_pipe_stream_basic() -> Result<()> {
        let mut output = Vec::new();
        {
            let pipe = WriteOnlyPipeStream::<_, f64, Ix0, ()>::new(&mut output, 10, [])?;
            pipe.write(5, |mut chunk, _state| {
                for (i, value) in chunk.iter_mut().enumerate() {
                    *value = i as f64;
                }
            })?;
        }
        assert_eq!(output.len(), 5 * 8);
        let values = decode_f64s(&output)?;
        for (i, &value) in values.iter().enumerate() {
            assert_eq!(value, i as f64);
        }
        Ok(())
    }

    /// A chunk larger than the initial scratch capacity is still written in full and in order.
    #[test]
    fn test_writeonly_pipe_stream_buffer_growth() -> Result<()> {
        let mut output = Vec::new();
        {
            let pipe = WriteOnlyPipeStream::<_, f64, Ix0, ()>::new(&mut output, 5, [])?;
            pipe.write(3, |mut chunk, _state| {
                for (i, value) in chunk.iter_mut().enumerate() {
                    *value = i as f64;
                }
            })?;
            pipe.write(10, |mut chunk, _state| {
                for (i, value) in chunk.iter_mut().enumerate() {
                    *value = (100 + i) as f64;
                }
            })?;
        }
        assert_eq!(output.len(), (3 + 10) * 8);
        let expected: Vec<f64> = (0_i32..3).chain(100..110).map(f64::from).collect();
        assert_eq!(decode_f64s(&output)?, expected);
        Ok(())
    }

    /// Stream metadata can be set, read back, and survives a write.
    #[test]
    fn test_writeonly_pipe_stream_metadata() -> Result<()> {
        let mut output = Vec::new();
        let pipe = WriteOnlyPipeStream::<_, f64, Ix0, String>::new(&mut output, 10, [])?;
        assert_eq!(pipe.get_metadata(), None);
        pipe.set_metadata("test metadata".to_string());
        assert_eq!(pipe.get_metadata(), Some("test metadata".to_string()));
        pipe.write(5, |mut chunk, _state| {
            for (i, value) in chunk.iter_mut().enumerate() {
                *value = i as f64;
            }
        })?;
        assert_eq!(pipe.get_metadata(), Some("test metadata".to_string()));
        Ok(())
    }

    /// I/O failures from the writer surface as `PipeError::IoError`.
    #[test]
    fn test_writeonly_pipe_stream_write_error_handling() -> Result<()> {
        struct FailingWriter {
            should_fail: bool,
        }
        impl Write for FailingWriter {
            fn write(&mut self, _buf: &[u8]) -> std::io::Result<usize> {
                if self.should_fail {
                    Err(std::io::Error::new(
                        std::io::ErrorKind::Other,
                        "Mock write failure",
                    ))
                } else {
                    Ok(0)
                }
            }
            fn flush(&mut self) -> std::io::Result<()> {
                Ok(())
            }
        }
        let failing_writer = FailingWriter { should_fail: true };
        let pipe = WriteOnlyPipeStream::<_, f64, Ix0, ()>::new(failing_writer, 10, [])?;
        let result = pipe.write(5, |mut chunk, _state| {
            for (i, value) in chunk.iter_mut().enumerate() {
                *value = i as f64;
            }
        });
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), PipeError::IoError(_)));
        Ok(())
    }

    /// Stream output works with a `Cursor` sink and reports a zero-based `PipeState`.
    #[test]
    fn test_writeonly_pipe_stream_with_cursor() -> Result<()> {
        use std::io::Cursor;
        let buffer = Vec::new();
        let cursor = Cursor::new(buffer);
        let pipe = WriteOnlyPipeStream::<_, u32, Ix0, ()>::new(cursor, 8, [])?;
        pipe.write(4, |mut chunk, state| {
            assert_eq!(state.write_ptr, 0);
            assert_eq!(state.read_ptr, 0);
            for (i, value) in chunk.iter_mut().enumerate() {
                *value = (i * i) as u32;
            }
        })?;
        let final_cursor = pipe.into_writer();
        let final_buffer = final_cursor.into_inner();
        assert_eq!(final_buffer.len(), 4 * 4);
        let mut buffer = Buffer::new(final_buffer.len());
        buffer.as_bytes_mut().copy_from_slice(&final_buffer);
        let u32_slice: &[u32] = buffer.view()?;
        assert_eq!(u32_slice, &[0, 1, 4, 9]);
        Ok(())
    }

    /// `into_writer` hands back the sink containing everything written so far.
    #[test]
    fn test_writeonly_pipe_stream_into_writer() -> Result<()> {
        use std::io::Cursor;
        let buffer = Vec::new();
        let cursor = Cursor::new(buffer);
        let pipe = WriteOnlyPipeStream::<_, u8, Ix0, ()>::new(cursor, 5, [])?;
        pipe.write(4, |mut chunk, _state| {
            for (i, value) in chunk.iter_mut().enumerate() {
                *value = (i + 65) as u8;
            }
        })?;
        let final_cursor = pipe.into_writer();
        let final_data = final_cursor.into_inner();
        assert_eq!(final_data, &[65, 66, 67, 68]);
        Ok(())
    }

    /// Out-of-bounds writes report `InsufficientData` with the expected context and counts.
    #[test]
    fn test_write_bounds_error() -> Result<()> {
        let mut data = vec![0.0; 5];
        let pipe = WriteOnlyPipeBuffer::<f64, Ix0, ()>::new(&mut data, [])?;
        match pipe.write(10, |_chunk, _state| {}) {
            Err(error) if error.is_insufficient_data() => {}
            other => panic!("Expected bounds error, got: {:?}", other),
        }
        match pipe.write(8, |_chunk, _state| {}) {
            Err(PipeError::InsufficientData {
                context,
                requested,
                position,
                available,
            }) => {
                assert_eq!(context, "Write-only pipe");
                assert_eq!(requested, 8);
                assert_eq!(position, 0);
                assert_eq!(available, 5);
            }
            other => panic!("Expected InsufficientData error, got: {:?}", other),
        }
        Ok(())
    }
}