use crate::error::{PipeError, Result};
use crate::pipe_common::ShapeManager;
use crate::traits::SizedDimension;
use ndarray::{ArrayView, ArrayViewMut, Dimension, StrideShape};
/// Borrows `n_elements` consecutive elements, starting at element `start_element`, from the
/// flat buffer `data` as a read-only array view of dimension `D::Larger`.
///
/// Each element occupies `shape_manager.element_size()` consecutive entries of `data`, so the
/// borrowed region is `data[start_element * size .. (start_element + n_elements) * size]`.
/// The view's shape is `shape_manager.get_larger_array_size(n_elements)`.
///
/// # Errors
///
/// - The error built by `PipeError::array_bounds_error(start, end, data.len())` when the
///   requested index range does not lie inside `data`. This includes an inverted range
///   (`start > end`), which is rejected here instead of causing a slice-index panic.
/// - Whatever error `ArrayView::from_shape` reports (converted through `?`) when the shape
///   from `shape_manager` does not fit the borrowed slice.
pub fn create_read_view<'a, A, D: SizedDimension + Dimension>(
    data: &'a [A],
    start_element: usize,
    n_elements: usize,
    shape_manager: &ShapeManager<D>,
) -> Result<ArrayView<'a, A, D::Larger>>
where
    D::LargerSize: Into<StrideShape<D::Larger>> + Clone,
    D::CurrentSize: Clone,
{
    let (start_idx, end_idx) =
        calculate_slice_bounds(start_element, n_elements, shape_manager.element_size());
    let available = data.len();
    // `get` returns `None` both for `end_idx > len` and for `start_idx > end_idx`, so there
    // is no remaining panic path between the bounds check and the slicing.
    let slice = data
        .get(start_idx..end_idx)
        .ok_or_else(|| PipeError::array_bounds_error(start_idx, end_idx, available))?;
    let view = ArrayView::<A, D::Larger>::from_shape(
        shape_manager.get_larger_array_size(n_elements),
        slice,
    )?;
    Ok(view)
}
/// Mutably borrows `n_elements` consecutive elements, starting at element `start_element`,
/// from the flat buffer `data` as a writable array view of dimension `D::Larger`.
///
/// Each element occupies `shape_manager.element_size()` consecutive entries of `data`, so the
/// borrowed region is `data[start_element * size .. (start_element + n_elements) * size]`.
/// Writes through the returned view land in that region of `data`. The view's shape is
/// `shape_manager.get_larger_array_size(n_elements)`.
///
/// # Errors
///
/// - The error built by `PipeError::array_bounds_error(start, end, data.len())` when the
///   requested index range does not lie inside `data`. This includes an inverted range
///   (`start > end`), which is rejected here instead of causing a slice-index panic.
/// - Whatever error `ArrayViewMut::from_shape` reports (converted through `?`) when the shape
///   from `shape_manager` does not fit the borrowed slice.
pub fn create_write_view<'a, A, D: SizedDimension + Dimension>(
    data: &'a mut [A],
    start_element: usize,
    n_elements: usize,
    shape_manager: &ShapeManager<D>,
) -> Result<ArrayViewMut<'a, A, D::Larger>>
where
    D::LargerSize: Into<StrideShape<D::Larger>> + Clone,
    D::CurrentSize: Clone,
{
    let (start_idx, end_idx) =
        calculate_slice_bounds(start_element, n_elements, shape_manager.element_size());
    // Capture the length up front: once `get_mut` holds the mutable borrow, the error
    // closure can no longer call `data.len()`.
    let available = data.len();
    // `get_mut` returns `None` both for `end_idx > len` and for `start_idx > end_idx`, so
    // there is no remaining panic path between the bounds check and the slicing.
    let slice = data
        .get_mut(start_idx..end_idx)
        .ok_or_else(|| PipeError::array_bounds_error(start_idx, end_idx, available))?;
    let view = ArrayViewMut::<A, D::Larger>::from_shape(
        shape_manager.get_larger_array_size(n_elements),
        slice,
    )?;
    Ok(view)
}
/// Converts an element range into a half-open `(start, end)` index range over a flat buffer
/// in which every element spans `element_size` consecutive entries.
///
/// For example, elements `2..6` with `element_size == 3` map to indices `(6, 18)`.
///
/// The arithmetic saturates at `usize::MAX` instead of overflowing. An impossible request
/// therefore yields an `end` larger than the length of any slice of non-zero-sized elements,
/// and a caller's `end > len` check rejects it. Without saturation, such a request would
/// panic in debug builds or wrap to a small, plausible-looking index in release builds.
/// The returned pair always satisfies `start <= end`.
pub fn calculate_slice_bounds(
    start_element: usize,
    n_elements: usize,
    element_size: usize,
) -> (usize, usize) {
    let start_idx = start_element.saturating_mul(element_size);
    let end_idx = start_idx.saturating_add(n_elements.saturating_mul(element_size));
    (start_idx, end_idx)
}
/// Checks that the element range `start_element .. start_element + n_elements` fits within
/// `total_elements` elements.
///
/// `context` is forwarded unchanged into the error so the caller can say which operation
/// failed.
///
/// # Errors
///
/// Returns the error built by
/// `PipeError::insufficient_data(context, n_elements, start_element, total_elements)` when the
/// range extends past `total_elements`. A range whose end would overflow `usize` also counts
/// as out of range.
pub fn validate_bounds(
    start_element: usize,
    n_elements: usize,
    total_elements: usize,
    context: &str,
) -> Result<()> {
    // `checked_add` keeps an overflowing end position from panicking (debug) or wrapping
    // around below `total_elements` and passing the check (release).
    let fits = matches!(
        start_element.checked_add(n_elements),
        Some(end) if end <= total_elements
    );
    if !fits {
        return Err(PipeError::insufficient_data(
            context,
            n_elements,
            start_element,
            total_elements,
        ));
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{PipeError, Result};
    use crate::pipe_common::ShapeManager;
    use ndarray::{Ix0, Ix1};

    // Element ranges scale by `element_size` into flat index ranges.
    #[test]
    fn test_calculate_slice_bounds() {
        assert_eq!(calculate_slice_bounds(5, 10, 1), (5, 15));
        assert_eq!(calculate_slice_bounds(2, 4, 3), (6, 18));
    }

    // A scalar-element (Ix0) read view sees exactly the requested window.
    #[test]
    fn test_create_read_view_ix0() -> Result<()> {
        let data: Vec<f64> = (0..20).map(|i| i as f64).collect();
        let shape_manager = ShapeManager::<Ix0>::new([]);
        let view = create_read_view(&data, 5, 10, &shape_manager)?;
        assert_eq!(view.len(), 10);
        assert_eq!(view[0], 5.0);
        assert_eq!(view[9], 14.0);
        Ok(())
    }

    // Writes through a scalar-element (Ix0) view touch only the viewed window.
    #[test]
    fn test_create_write_view_ix0() -> Result<()> {
        let mut data: Vec<f64> = vec![0.0; 20];
        let shape_manager = ShapeManager::<Ix0>::new([]);
        {
            let mut view = create_write_view(&mut data, 5, 10, &shape_manager)?;
            for (val, marker) in view.iter_mut().zip(100..) {
                *val = marker as f64;
            }
        }
        let expected_middle: Vec<f64> = (100..110).map(|i| i as f64).collect();
        assert_eq!(&data[..5], &[0.0; 5]);
        assert_eq!(&data[5..15], expected_middle.as_slice());
        assert_eq!(&data[15..], &[0.0; 5]);
        Ok(())
    }

    // A vector-element (Ix1, width 3) read view is laid out row-per-element.
    #[test]
    fn test_create_read_view_ix1() -> Result<()> {
        let data: Vec<f64> = (0..30).map(|i| i as f64).collect();
        let shape_manager = ShapeManager::<Ix1>::new([3]);
        let view = create_read_view(&data, 2, 4, &shape_manager)?;
        assert_eq!(view.shape(), &[4, 3]);
        assert_eq!(view[(0, 0)], 6.0);
        assert_eq!(view[(0, 2)], 8.0);
        assert_eq!(view[(3, 2)], 17.0);
        Ok(())
    }

    #[test]
    fn test_bounds_validation() {
        assert!(validate_bounds(5, 10, 20, "test").is_ok());
        let result = validate_bounds(15, 10, 20, "test");
        assert!(matches!(result, Err(PipeError::InsufficientData { .. })));
    }

    #[test]
    fn test_bounds_error_on_read() {
        let data: Vec<f64> = vec![0.0; 10];
        let shape_manager = ShapeManager::<Ix0>::new([]);
        let result = create_read_view(&data, 8, 5, &shape_manager);
        assert!(matches!(result, Err(PipeError::ArrayBoundsError { .. })));
    }

    #[test]
    fn test_bounds_error_on_write() {
        let mut data: Vec<f64> = vec![0.0; 10];
        let shape_manager = ShapeManager::<Ix0>::new([]);
        let result = create_write_view(&mut data, 7, 5, &shape_manager);
        assert!(matches!(result, Err(PipeError::ArrayBoundsError { .. })));
    }
}