use crate::shape::coord_to_flat;
use crate::{MattenError, Tensor};
/// Maximum length, in bytes, of a slice string accepted by `parse_slice_str`;
/// longer inputs are rejected before any component is parsed.
const MAX_SLICE_STR_BYTES: usize = 512;
/// How a single tensor axis is sliced.
///
/// Exactly one `SliceSpec` is supplied per axis of the tensor being sliced.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum SliceSpec {
    /// Keep the whole axis.
    All,
    /// Select one position; the axis is removed from the output shape.
    Index(usize),
    /// Half-open range `start..end`, taking every `step`-th position.
    ///
    /// A missing `start` means `0` and a missing `end` means the axis length.
    /// `step` must be at least 1; this is checked when the spec is resolved
    /// against a concrete axis.
    Range {
        start: Option<usize>,
        end: Option<usize>,
        step: usize,
    },
}
/// Expands one axis specification into the concrete positions it selects.
///
/// Returns `(indices, keep)`: `indices` are the selected positions along the
/// axis in ascending order, and `keep` is `false` only for
/// [`SliceSpec::Index`], whose axis is dropped from the output shape.
///
/// `dim` is the length of the axis; `axis` and `operation` are used only to
/// build error messages.
///
/// # Errors
///
/// Returns `MattenError::Slice` (with `input: None`) when an index or range
/// bound exceeds `dim`, when a range start is greater than its end, or when a
/// range step is zero.
fn resolve_spec(
    spec: &SliceSpec,
    dim: usize,
    axis: usize,
    operation: &'static str,
) -> Result<(Vec<usize>, bool), MattenError> {
    let err = |msg: String| MattenError::Slice {
        input: None,
        message: msg,
    };
    match spec {
        SliceSpec::All => Ok(((0..dim).collect(), true)),
        SliceSpec::Index(i) => {
            if *i >= dim {
                // Name the failing entry point, as the range errors below do.
                return Err(err(format!(
                    "index {i} is out of range for axis {axis} with size {dim} \
                     in {operation}"
                )));
            }
            Ok((vec![*i], false))
        }
        SliceSpec::Range { start, end, step } => {
            let s = start.unwrap_or(0);
            let e = end.unwrap_or(dim);
            if s > dim || e > dim {
                return Err(err(format!(
                    "range {s}..{e} is out of range for axis {axis} with size {dim} \
                     in {operation}"
                )));
            }
            if s > e {
                return Err(err(format!(
                    "range start {s} > end {e} for axis {axis} in {operation}"
                )));
            }
            if *step == 0 {
                return Err(err(format!(
                    "step must be >= 1 for axis {axis} in {operation}"
                )));
            }
            Ok(((s..e).step_by(*step).collect(), true))
        }
    }
}
/// Materialises a new tensor by applying one [`SliceSpec`] per axis of `tensor`.
///
/// Axes sliced with [`SliceSpec::Index`] are removed from the output shape.
/// Axes sliced with [`SliceSpec::All`] or [`SliceSpec::Range`] are kept, with
/// the length of their selection. If every axis is indexed, the result has an
/// empty shape and holds exactly one element.
///
/// `operation` names the calling entry point and is embedded in error messages.
///
/// # Errors
///
/// - `MattenError::Unsupported` for dynamic tensors (only when the `dynamic`
///   feature is enabled).
/// - `MattenError::Slice` if `specs.len()` differs from the tensor rank, or if
///   any spec is invalid for its axis.
pub(crate) fn execute_slice(
    tensor: &Tensor,
    specs: &[SliceSpec],
    operation: &'static str,
) -> Result<Tensor, MattenError> {
    #[cfg(feature = "dynamic")]
    if tensor.is_dynamic() {
        return Err(MattenError::Unsupported {
            operation,
            message: "dynamic tensors do not support the slice builder or slice_str; use get_element(&[row, col]) for element access, or call try_numeric() first".to_string(),
        });
    }
    let ndim = tensor.ndim();
    if specs.len() != ndim {
        return Err(MattenError::Slice {
            input: None,
            message: format!(
                "slice has {} specifications but tensor has rank {ndim}",
                specs.len()
            ),
        });
    }
    // Resolve every axis up front, so an invalid spec fails before any output
    // buffer is allocated. Collecting into `Result` stops at the first error.
    let per_axis: Vec<(Vec<usize>, bool)> = specs
        .iter()
        .zip(tensor.shape())
        .enumerate()
        .map(|(axis, (spec, &dim))| resolve_spec(spec, dim, axis, operation))
        .collect::<Result<_, _>>()?;
    let out_shape: Vec<usize> = per_axis
        .iter()
        .filter(|(_, keep)| *keep)
        .map(|(idxs, _)| idxs.len())
        .collect();
    // The product over an empty shape is 1, which is exactly the element count
    // of a fully-indexed result, so no special case is needed.
    let out_len: usize = out_shape.iter().product();
    let mut out_data = vec![0.0f64; out_len];
    let counts: Vec<usize> = per_axis.iter().map(|(v, _)| v.len()).collect();
    // Indexed axes contribute a count of 1, so `total` equals `out_len`. When
    // any count is 0, `total` is 0 and the loop below never divides by zero.
    let total: usize = counts.iter().product();
    // Scratch coordinate buffers are reused for every element rather than
    // reallocated on each iteration.
    let mut src_coord = vec![0usize; ndim];
    let mut out_coord_kept: Vec<usize> = Vec::with_capacity(out_shape.len());
    for out_flat in 0..total {
        // Decompose `out_flat` into one local position per axis, with the last
        // axis varying fastest. Every entry of `src_coord` is overwritten here.
        let mut rem = out_flat;
        let mut stride = total;
        out_coord_kept.clear();
        for (ax, (sel, keep)) in per_axis.iter().enumerate() {
            stride /= counts[ax];
            let local = rem / stride;
            rem %= stride;
            src_coord[ax] = sel[local];
            if *keep {
                out_coord_kept.push(local);
            }
        }
        let src_flat = coord_to_flat(&src_coord, tensor.shape())
            .expect("constructed coordinate is always valid");
        let dst_flat = if out_shape.is_empty() {
            0
        } else {
            coord_to_flat(&out_coord_kept, &out_shape).expect("kept coordinate is always valid")
        };
        out_data[dst_flat] = tensor.data[src_flat];
    }
    Ok(Tensor {
        data: out_data,
        shape: out_shape,
        #[cfg(feature = "dynamic")]
        dynamic: None,
    })
}
/// Opaque wrapper that lets the public [`SliceConvert`] trait hand back the
/// crate-private [`SliceSpec`]. The inner field is `pub(crate)`, so code
/// outside this crate cannot build or inspect it directly.
#[doc(hidden)]
pub struct SliceSpecRepr(pub(crate) SliceSpec);
/// Conversion from a standard range type into an internal slice specification.
///
/// Implemented in this module for `Range`, `RangeFrom`, `RangeTo`, `RangeFull`
/// and `RangeInclusive` over `usize`.
pub trait SliceConvert {
// Hidden from docs: callers only see the range types that implement it.
#[doc(hidden)]
fn into_repr(self) -> SliceSpecRepr;
}
/// Marker trait for range types accepted by `SliceBuilder::range`.
///
/// It carries no methods of its own; the conversion comes from the
/// [`SliceConvert`] supertrait. The impls below list every accepted form:
/// `a..b`, `a..`, `..b`, `..` and `a..=b`.
pub trait IntoSliceRange: SliceConvert {}
impl IntoSliceRange for std::ops::Range<usize> {}
impl IntoSliceRange for std::ops::RangeFrom<usize> {}
impl IntoSliceRange for std::ops::RangeTo<usize> {}
impl IntoSliceRange for std::ops::RangeFull {}
impl IntoSliceRange for std::ops::RangeInclusive<usize> {}
impl SliceConvert for std::ops::Range<usize> {
    /// `a..b` maps directly onto a unit-step range with both bounds explicit.
    fn into_repr(self) -> SliceSpecRepr {
        let std::ops::Range { start: lo, end: hi } = self;
        let spec = SliceSpec::Range {
            start: Some(lo),
            end: Some(hi),
            step: 1,
        };
        SliceSpecRepr(spec)
    }
}
impl SliceConvert for std::ops::RangeFrom<usize> {
    /// `a..` keeps an explicit start and leaves the end open (the axis length).
    fn into_repr(self) -> SliceSpecRepr {
        let std::ops::RangeFrom { start: lo } = self;
        let spec = SliceSpec::Range {
            start: Some(lo),
            end: None,
            step: 1,
        };
        SliceSpecRepr(spec)
    }
}
impl SliceConvert for std::ops::RangeTo<usize> {
    /// `..b` leaves the start open (zero) and keeps an explicit exclusive end.
    fn into_repr(self) -> SliceSpecRepr {
        let std::ops::RangeTo { end: hi } = self;
        let spec = SliceSpec::Range {
            start: None,
            end: Some(hi),
            step: 1,
        };
        SliceSpecRepr(spec)
    }
}
impl SliceConvert for std::ops::RangeFull {
    /// `..` selects the entire axis.
    fn into_repr(self) -> SliceSpecRepr {
        let whole_axis = SliceSpec::All;
        SliceSpecRepr(whole_axis)
    }
}
impl SliceConvert for std::ops::RangeInclusive<usize> {
    /// Converts `a..=b` into the equivalent half-open range `a..b+1`.
    ///
    /// The exclusive end is computed with `saturating_add`, so `a..=usize::MAX`
    /// can no longer overflow. Previously that case panicked in debug builds
    /// and wrapped to `0` in release builds, which silently produced an empty
    /// selection or a misleading "start > end" error. The saturated end of
    /// `usize::MAX` is larger than any axis that can actually exist, so it is
    /// reported as out of range when the slice is executed.
    fn into_repr(self) -> SliceSpecRepr {
        let (start, last) = self.into_inner();
        SliceSpecRepr(SliceSpec::Range {
            start: Some(start),
            end: Some(last.saturating_add(1)),
            step: 1,
        })
    }
}
/// Builder that records one slice specification per axis of a borrowed
/// [`Tensor`] and produces the sliced tensor when `build` is called.
///
/// Specifications apply to axes in the order they are added. All validation
/// (rank match, bounds, step) is deferred until `build`.
#[must_use = "a SliceBuilder does nothing until `build` is called"]
pub struct SliceBuilder<'a> {
    tensor: &'a Tensor,
    specs: Vec<SliceSpec>,
}
impl<'a> SliceBuilder<'a> {
pub(crate) fn new(tensor: &'a Tensor) -> Self {
Self {
tensor,
specs: Vec::with_capacity(tensor.ndim()),
}
}
pub fn all(mut self) -> Self {
self.specs.push(SliceSpec::All);
self
}
pub fn index(mut self, index: usize) -> Self {
self.specs.push(SliceSpec::Index(index));
self
}
pub fn range<R: IntoSliceRange>(mut self, range: R) -> Self {
self.specs.push(range.into_repr().0);
self
}
pub fn build(self) -> Result<Tensor, MattenError> {
execute_slice(self.tensor, &self.specs, "slice_builder")
}
}
/// Parses a comma-separated slice string (e.g. `"0, 1:5, ::2"`) into one
/// [`SliceSpec`] per component.
///
/// Each component is trimmed and passed to `parse_axis_spec`. The first
/// failing component aborts the parse.
///
/// # Errors
///
/// Returns `MattenError::Slice`, carrying the full input string, when the input
/// is longer than `MAX_SLICE_STR_BYTES` bytes or any component is malformed.
pub(crate) fn parse_slice_str(spec: &str) -> Result<Vec<SliceSpec>, MattenError> {
    // Cap the input size before doing any parsing work.
    if spec.len() > MAX_SLICE_STR_BYTES {
        return Err(MattenError::Slice {
            input: Some(spec.to_string()),
            message: format!(
                "slice spec exceeds the maximum length of {MAX_SLICE_STR_BYTES} bytes"
            ),
        });
    }
    let mut parsed = Vec::new();
    for component in spec.split(',') {
        parsed.push(parse_axis_spec(component.trim(), spec)?);
    }
    Ok(parsed)
}
/// Parses one (already trimmed) component of a slice string into a [`SliceSpec`].
///
/// Accepted forms, where every bound is a non-negative integer that may be omitted:
/// - `n`: select index `n` and drop the axis;
/// - `start:end`: half-open range with step 1 (a bare `:` selects the whole axis);
/// - `start:end:step`: stepped range, where `step` must be a positive integer.
///
/// `full` is the complete slice string and is attached to errors for context.
/// Bounds are not checked against any axis length here; that happens later,
/// when the spec is resolved against the tensor's shape.
///
/// # Errors
///
/// Returns `MattenError::Slice` for an unrecognised component, a non-integer
/// bound, a trailing colon with no step, or a step that is not a positive integer.
fn parse_axis_spec(part: &str, full: &str) -> Result<SliceSpec, MattenError> {
    let err = |msg: String| MattenError::Slice {
        input: Some(full.to_string()),
        message: msg,
    };
    if let Ok(n) = part.parse::<usize>() {
        return Ok(SliceSpec::Index(n));
    }
    if !part.contains(':') {
        return Err(err(format!("unrecognised slice component {part:?}")));
    }
    // `part` contains at least one ':', so there are always 2 or 3 segments.
    let segments: Vec<&str> = part.splitn(3, ':').collect();
    let parse_opt = |s: &str| -> Result<Option<usize>, MattenError> {
        let s = s.trim();
        if s.is_empty() {
            Ok(None)
        } else {
            s.parse::<usize>()
                .map(Some)
                .map_err(|_| err(format!("expected integer, got {s:?}")))
        }
    };
    let start = parse_opt(segments[0])?;
    let end = parse_opt(segments[1])?;
    let step = if segments.len() == 3 {
        let s = segments[2].trim();
        if s.is_empty() {
            return Err(err(format!(
                "trailing colon in {part:?} is not valid; write {part:?} without the trailing colon, or supply a step value (e.g. \"0:10:2\")"
            )));
        }
        match s.parse::<usize>() {
            Ok(n) if n >= 1 => n,
            // A step of `0` parses as a usize but is not positive. Rejecting it
            // here keeps the promise of this message and attaches the input.
            _ => {
                return Err(err(format!(
                    "step must be a positive integer, got {s:?}"
                )))
            }
        }
    } else {
        1
    };
    if start.is_none() && end.is_none() && step == 1 {
        return Ok(SliceSpec::All);
    }
    Ok(SliceSpec::Range { start, end, step })
}