use crate::indexing::plan::{build_index_plan, IndexPlan};
use crate::indexing::selectors::{build_slice_selectors, SliceSelector};
use runmat_builtins::{ComplexTensor, StringArray, Tensor, Value};
use runmat_runtime::RuntimeError;
fn map_slice_shape_error(err: impl std::fmt::Display) -> RuntimeError {
crate::interpreter::errors::mex(
"ShapeMismatch",
&format!("shape mismatch for slice result: {err}"),
)
}
fn map_slice_acceleration_error(err: impl std::fmt::Display) -> RuntimeError {
crate::interpreter::errors::mex("AccelerationOperationFailed", &format!("slice: {err}"))
}
pub async fn read_tensor_slice_1d(
tensor: &Tensor,
colon_mask: u32,
end_mask: u32,
numeric: &[Value],
) -> Result<Value, RuntimeError> {
read_tensor_slice_nd(tensor, 1, colon_mask, end_mask, numeric).await
}
pub fn try_tensor_slice_2d_fast_path(
tensor: &Tensor,
dims: usize,
selectors: &[SliceSelector],
) -> Result<Option<Value>, RuntimeError> {
if dims != 2 {
return Ok(None);
}
let rows = tensor.shape.first().copied().unwrap_or(1);
let cols = tensor.shape.get(1).copied().unwrap_or(1);
match (&selectors[0], &selectors[1]) {
(SliceSelector::Colon, SliceSelector::Scalar(j)) => {
let j0 = *j - 1;
if j0 >= cols {
return Err(crate::interpreter::errors::mex(
"IndexOutOfBounds",
"Index out of bounds",
));
}
let start = j0 * rows;
let out = tensor.data[start..start + rows].to_vec();
if out.len() == 1 {
Ok(Some(Value::Num(out[0])))
} else {
let tens = Tensor::new(out, vec![rows, 1]).map_err(map_slice_shape_error)?;
Ok(Some(Value::Tensor(tens)))
}
}
(SliceSelector::Scalar(i), SliceSelector::Colon) => {
let i0 = *i - 1;
if i0 >= rows {
return Err(crate::interpreter::errors::mex(
"IndexOutOfBounds",
"Index out of bounds",
));
}
let mut out: Vec<f64> = Vec::with_capacity(cols);
for c in 0..cols {
out.push(tensor.data[i0 + c * rows]);
}
if out.len() == 1 {
Ok(Some(Value::Num(out[0])))
} else {
let tens = Tensor::new(out, vec![1, cols]).map_err(map_slice_shape_error)?;
Ok(Some(Value::Tensor(tens)))
}
}
(SliceSelector::Colon, SliceSelector::Indices(js)) => {
if js.is_empty() {
let tens = Tensor::new(Vec::new(), vec![rows, 0]).map_err(map_slice_shape_error)?;
Ok(Some(Value::Tensor(tens)))
} else {
let mut out: Vec<f64> = Vec::with_capacity(rows * js.len());
for &j in js {
let j0 = j - 1;
if j0 >= cols {
return Err(crate::interpreter::errors::mex(
"IndexOutOfBounds",
"Index out of bounds",
));
}
let start = j0 * rows;
out.extend_from_slice(&tensor.data[start..start + rows]);
}
let tens = Tensor::new(out, vec![rows, js.len()]).map_err(map_slice_shape_error)?;
Ok(Some(Value::Tensor(tens)))
}
}
(SliceSelector::Indices(is), SliceSelector::Colon) => {
if is.is_empty() {
let tens = Tensor::new(Vec::new(), vec![0, cols]).map_err(map_slice_shape_error)?;
Ok(Some(Value::Tensor(tens)))
} else {
let mut out: Vec<f64> = Vec::with_capacity(is.len() * cols);
for c in 0..cols {
for &i in is {
let i0 = i - 1;
if i0 >= rows {
return Err(crate::interpreter::errors::mex(
"IndexOutOfBounds",
"Index out of bounds",
));
}
out.push(tensor.data[i0 + c * rows]);
}
}
let tens = Tensor::new(out, vec![is.len(), cols]).map_err(map_slice_shape_error)?;
Ok(Some(Value::Tensor(tens)))
}
}
_ => Ok(None),
}
}
pub async fn read_tensor_slice_nd(
tensor: &Tensor,
dims: usize,
colon_mask: u32,
end_mask: u32,
numeric: &[Value],
) -> Result<Value, RuntimeError> {
let selectors =
build_slice_selectors(dims, colon_mask, end_mask, numeric, &tensor.shape).await?;
if let Some(value) = try_tensor_slice_2d_fast_path(tensor, dims, &selectors)? {
return Ok(value);
}
let plan = build_index_plan(&selectors, dims, &tensor.shape)?;
if plan.indices.is_empty() {
let out_tensor =
Tensor::new(Vec::new(), plan.output_shape).map_err(map_slice_shape_error)?;
return Ok(Value::Tensor(out_tensor));
}
let mut out_data: Vec<f64> = Vec::with_capacity(plan.indices.len());
for &lin in &plan.indices {
out_data.push(tensor.data[lin as usize]);
}
if out_data.len() == 1 {
Ok(Value::Num(out_data[0]))
} else {
let out_tensor = Tensor::new(out_data, plan.output_shape).map_err(map_slice_shape_error)?;
Ok(Value::Tensor(out_tensor))
}
}
pub fn read_tensor_slice_from_plan(
tensor: &Tensor,
plan: &IndexPlan,
) -> Result<Value, RuntimeError> {
if plan.indices.is_empty() {
let out_tensor =
Tensor::new(Vec::new(), plan.output_shape.clone()).map_err(map_slice_shape_error)?;
return Ok(Value::Tensor(out_tensor));
}
let mut out_data: Vec<f64> = Vec::with_capacity(plan.indices.len());
for &lin in &plan.indices {
out_data.push(tensor.data[lin as usize]);
}
if out_data.len() == 1 {
Ok(Value::Num(out_data[0]))
} else {
let out_tensor =
Tensor::new(out_data, plan.output_shape.clone()).map_err(map_slice_shape_error)?;
Ok(Value::Tensor(out_tensor))
}
}
pub async fn read_complex_slice(
tensor: &ComplexTensor,
dims: usize,
colon_mask: u32,
end_mask: u32,
numeric: &[Value],
) -> Result<Value, RuntimeError> {
let selectors =
build_slice_selectors(dims, colon_mask, end_mask, numeric, &tensor.shape).await?;
let plan = build_index_plan(&selectors, dims, &tensor.shape)?;
read_complex_slice_from_plan(tensor, &plan)
}
pub fn read_complex_slice_from_plan(
tensor: &ComplexTensor,
plan: &IndexPlan,
) -> Result<Value, RuntimeError> {
if plan.indices.is_empty() {
let empty = ComplexTensor::new(Vec::new(), plan.output_shape.clone())
.map_err(map_slice_shape_error)?;
return Ok(Value::ComplexTensor(empty));
}
if plan.indices.len() == 1 {
let lin = plan.indices[0] as usize;
let (re, im) = tensor.data.get(lin).copied().ok_or_else(|| {
crate::interpreter::errors::mex(
"IndexOutOfBounds",
"Slice error: complex index out of bounds",
)
})?;
return Ok(Value::Complex(re, im));
}
let mut out = Vec::with_capacity(plan.indices.len());
for &lin in &plan.indices {
let idx = lin as usize;
let value = tensor.data.get(idx).copied().ok_or_else(|| {
crate::interpreter::errors::mex(
"IndexOutOfBounds",
"Slice error: complex index out of bounds",
)
})?;
out.push(value);
}
let out_ct =
ComplexTensor::new(out, plan.output_shape.clone()).map_err(map_slice_shape_error)?;
Ok(Value::ComplexTensor(out_ct))
}
pub async fn read_gpu_slice(
handle: &runmat_accelerate_api::GpuTensorHandle,
dims: usize,
colon_mask: u32,
end_mask: u32,
numeric: &[Value],
) -> Result<Value, RuntimeError> {
let base_shape = handle.shape.clone();
let selectors = build_slice_selectors(dims, colon_mask, end_mask, numeric, &base_shape).await?;
let plan = build_index_plan(&selectors, dims, &base_shape)?;
read_gpu_slice_from_plan(handle, &plan)
}
pub fn read_gpu_slice_from_plan(
handle: &runmat_accelerate_api::GpuTensorHandle,
plan: &IndexPlan,
) -> Result<Value, RuntimeError> {
let provider = runmat_accelerate_api::provider().ok_or_else(|| {
crate::interpreter::errors::mex(
"AccelerationProviderUnavailable",
"No acceleration provider registered",
)
})?;
if plan.indices.is_empty() {
let zeros = provider
.zeros(&plan.output_shape)
.map_err(map_slice_acceleration_error)?;
Ok(Value::GpuTensor(zeros))
} else {
let result = provider
.gather_linear(handle, &plan.indices, &plan.output_shape)
.map_err(map_slice_acceleration_error)?;
Ok(Value::GpuTensor(result))
}
}
pub async fn read_string_slice(
sa: &StringArray,
dims: usize,
colon_mask: u32,
end_mask: u32,
numeric: &[Value],
) -> Result<Value, RuntimeError> {
let selectors = build_slice_selectors(dims, colon_mask, end_mask, numeric, &sa.shape).await?;
let plan = build_index_plan(&selectors, dims, &sa.shape)?;
gather_string_slice(sa, &plan)
}
pub fn gather_string_slice(sa: &StringArray, plan: &IndexPlan) -> Result<Value, RuntimeError> {
if plan.indices.is_empty() {
let empty = StringArray::new(Vec::new(), plan.output_shape.clone())
.map_err(map_slice_shape_error)?;
return Ok(Value::StringArray(empty));
}
if plan.indices.len() == 1 {
let lin = plan.indices[0] as usize;
let value = sa.data.get(lin).cloned().ok_or_else(|| {
crate::interpreter::errors::mex(
"IndexOutOfBounds",
"Slice error: string index out of bounds",
)
})?;
return Ok(Value::String(value));
}
let mut out = Vec::with_capacity(plan.indices.len());
for &lin in &plan.indices {
let idx = lin as usize;
let value = sa.data.get(idx).cloned().ok_or_else(|| {
crate::interpreter::errors::mex(
"IndexOutOfBounds",
"Slice error: string index out of bounds",
)
})?;
out.push(value);
}
let out_sa = StringArray::new(out, plan.output_shape.clone()).map_err(map_slice_shape_error)?;
Ok(Value::StringArray(out_sa))
}
#[cfg(test)]
mod tests {
use super::{
gather_string_slice, map_slice_acceleration_error, read_complex_slice_from_plan,
read_string_slice, read_tensor_slice_from_plan,
};
use crate::indexing::plan::IndexPlan;
use futures::executor::block_on;
use runmat_builtins::{ComplexTensor, StringArray, Tensor, Value};
#[test]
fn string_slice_linear_tensor_indices_preserve_selector_shape() {
let sa = StringArray::new(
vec![
"a".to_string(),
"b".to_string(),
"c".to_string(),
"d".to_string(),
],
vec![2, 2],
)
.expect("string array");
let selector =
Value::Tensor(Tensor::new(vec![1.0, 3.0], vec![1, 2]).expect("selector tensor"));
let result = block_on(read_string_slice(&sa, 1, 0, 0, &[selector])).expect("slice");
match result {
Value::StringArray(out) => {
assert_eq!(out.shape, vec![1, 2]);
assert_eq!(out.data, vec!["a".to_string(), "c".to_string()]);
}
other => panic!("expected string array result, got {other:?}"),
}
}
#[test]
fn string_slice_colon_then_scalar_selects_column() {
let sa = StringArray::new(
vec![
"a".to_string(),
"b".to_string(),
"c".to_string(),
"d".to_string(),
],
vec![2, 2],
)
.expect("string array");
let result =
block_on(read_string_slice(&sa, 2, 0b01, 0, &[Value::Num(2.0)])).expect("slice");
match result {
Value::StringArray(out) => {
assert_eq!(out.shape, vec![2, 1]);
assert_eq!(out.data, vec!["c".to_string(), "d".to_string()]);
}
other => panic!("expected string array result, got {other:?}"),
}
}
#[test]
fn tensor_slice_plan_shape_mismatch_reports_identifier() {
let tensor = Tensor::new(vec![10.0, 20.0], vec![1, 2]).expect("tensor");
let plan = IndexPlan::new(vec![0, 1], vec![1, 1], vec![2], 1, vec![1, 2]);
let err = read_tensor_slice_from_plan(&tensor, &plan)
.expect_err("shape-mismatch plan should fail");
assert_eq!(err.identifier(), Some("RunMat:ShapeMismatch"));
}
#[test]
fn string_slice_plan_shape_mismatch_reports_identifier() {
let sa = StringArray::new(
vec![
"a".to_string(),
"b".to_string(),
"c".to_string(),
"d".to_string(),
],
vec![2, 2],
)
.expect("string array");
let plan = IndexPlan::new(vec![0, 1], vec![1, 1], vec![2], 1, vec![2, 2]);
let err = gather_string_slice(&sa, &plan).expect_err("shape-mismatch plan should fail");
assert_eq!(err.identifier(), Some("RunMat:ShapeMismatch"));
}
#[test]
fn complex_slice_plan_shape_mismatch_reports_identifier() {
let ct = ComplexTensor::new(vec![(1.0, 0.0), (2.0, 0.0)], vec![1, 2]).expect("complex");
let plan = IndexPlan::new(vec![0, 1], vec![1, 1], vec![2], 1, vec![1, 2]);
let err =
read_complex_slice_from_plan(&ct, &plan).expect_err("shape-mismatch plan should fail");
assert_eq!(err.identifier(), Some("RunMat:ShapeMismatch"));
}
#[test]
fn slice_acceleration_error_mapping_reports_identifier() {
let err = map_slice_acceleration_error("provider failed");
assert_eq!(err.identifier(), Some("RunMat:AccelerationOperationFailed"));
}
}