Skip to main content

runmat_runtime/
indexing.rs

//! Matrix indexing and slicing operations
//!
//! Implements MATLAB-style tensor indexing and access patterns.
4
5use crate::builtins::common::shape::normalize_scalar_shape;
6use crate::{build_runtime_error, RuntimeError};
7use runmat_builtins::{Tensor, Value};
8
9fn indexing_error(message: impl Into<String>) -> RuntimeError {
10    build_runtime_error(message).build()
11}
12
13fn indexing_error_with_identifier(message: impl Into<String>, identifier: &str) -> RuntimeError {
14    build_runtime_error(message)
15        .with_identifier(identifier)
16        .build()
17}
18
19/// Get a single element from a tensor (1-based indexing like language)
20pub fn matrix_get_element(tensor: &Tensor, row: usize, col: usize) -> Result<f64, RuntimeError> {
21    if row == 0 || col == 0 {
22        return Err(indexing_error_with_identifier(
23            "MATLAB uses 1-based indexing",
24            "RunMat:IndexOutOfBounds",
25        ));
26    }
27    tensor
28        .get2(row - 1, col - 1)
29        .map_err(|err| indexing_error_with_identifier(err, "RunMat:IndexOutOfBounds"))
30}
31
32/// Set a single element in a tensor (1-based indexing like language)
33pub fn matrix_set_element(
34    tensor: &mut Tensor,
35    row: usize,
36    col: usize,
37    value: f64,
38) -> Result<(), RuntimeError> {
39    if row == 0 || col == 0 {
40        return Err(indexing_error_with_identifier(
41            "The MATLAB language uses 1-based indexing",
42            "RunMat:IndexOutOfBounds",
43        ));
44    }
45    tensor
46        .set2(row - 1, col - 1, value)
47        .map_err(|err| indexing_error_with_identifier(err, "RunMat:IndexOutOfBounds"))
48}
49
50/// Get a row from a tensor
51pub fn matrix_get_row(tensor: &Tensor, row: usize) -> Result<Tensor, RuntimeError> {
52    if row == 0 || row > tensor.rows() {
53        return Err(indexing_error_with_identifier(
54            format!(
55                "Row index {} out of bounds for {}x{} tensor",
56                row,
57                tensor.rows(),
58                tensor.cols()
59            ),
60            "RunMat:IndexOutOfBounds",
61        ));
62    }
63
64    // Column-major: row slice picks every element spaced by rows across columns
65    let mut row_data = Vec::with_capacity(tensor.cols());
66    for c in 0..tensor.cols() {
67        row_data.push(tensor.data[(row - 1) + c * tensor.rows()]);
68    }
69    Tensor::new_2d(row_data, 1, tensor.cols()).map_err(|err| indexing_error(err))
70}
71
72/// Get a column from a tensor
73pub fn matrix_get_col(tensor: &Tensor, col: usize) -> Result<Tensor, RuntimeError> {
74    if col == 0 || col > tensor.cols() {
75        return Err(indexing_error_with_identifier(
76            format!(
77                "Column index {} out of bounds for {}x{} tensor",
78                col,
79                tensor.rows(),
80                tensor.cols()
81            ),
82            "RunMat:IndexOutOfBounds",
83        ));
84    }
85
86    let mut col_data = Vec::with_capacity(tensor.rows());
87    for row in 0..tensor.rows() {
88        col_data.push(tensor.data[row + (col - 1) * tensor.rows()]);
89    }
90    Tensor::new_2d(col_data, tensor.rows(), 1).map_err(|err| indexing_error(err))
91}
92
93/// Array indexing operation (used by all interpreters/compilers)
94/// In MATLAB, indexing is 1-based and supports:
95/// - Single element: A(i) for vectors, A(i,j) for tensors
96/// - Multiple indices: A(i1, i2, ..., iN)
97pub async fn perform_indexing(base: &Value, indices: &[f64]) -> Result<Value, RuntimeError> {
98    match base {
99        Value::GpuTensor(h) => {
100            let provider = runmat_accelerate_api::provider().ok_or_else(|| {
101                indexing_error("Cannot index value of type GpuTensor without a provider")
102            })?;
103            if indices.is_empty() {
104                return Err(indexing_error("At least one index is required"));
105            }
106            // Support scalar indexing cases mirroring Tensor branch
107            if indices.len() == 1 {
108                let idx = indices[0] as usize;
109                let total = h.shape.iter().product();
110                if idx < 1 || idx > total {
111                    return Err(indexing_error_with_identifier(
112                        format!("Index {} out of bounds (1 to {})", idx, total),
113                        "RunMat:IndexOutOfBounds",
114                    ));
115                }
116                let lin0 = idx - 1; // 0-based
117                let val = gpu_index_scalar(provider, h, lin0).await?;
118                return Ok(Value::Num(val));
119            } else if indices.len() == 2 {
120                let row = indices[0] as usize;
121                let col = indices[1] as usize;
122                let rows = h.shape.first().copied().unwrap_or(1);
123                let cols = h.shape.get(1).copied().unwrap_or(1);
124                if row < 1 || row > rows || col < 1 || col > cols {
125                    return Err(indexing_error_with_identifier(
126                        format!("Index ({row}, {col}) out of bounds for {rows}x{cols} tensor"),
127                        "RunMat:IndexOutOfBounds",
128                    ));
129                }
130                let lin0 = (row - 1) + (col - 1) * rows;
131                let val = gpu_index_scalar(provider, h, lin0).await?;
132                return Ok(Value::Num(val));
133            }
134            Err(indexing_error_with_identifier(
135                format!("Cannot index value of type {base:?}"),
136                "RunMat:SliceNonTensor",
137            ))
138        }
139        Value::Tensor(tensor) => {
140            if indices.is_empty() {
141                return Err(indexing_error("At least one index is required"));
142            }
143
144            if indices.len() == 1 {
145                // Linear indexing (1-based)
146                let idx = indices[0] as usize;
147                if idx < 1 || idx > tensor.data.len() {
148                    return Err(indexing_error_with_identifier(
149                        format!("Index {} out of bounds (1 to {})", idx, tensor.data.len()),
150                        "RunMat:IndexOutOfBounds",
151                    ));
152                }
153                Ok(Value::Num(tensor.data[idx - 1])) // Convert to 0-based
154            } else if indices.len() == 2 {
155                // Row-column indexing (1-based)
156                let row = indices[0] as usize;
157                let col = indices[1] as usize;
158                let shape = normalize_scalar_shape(&tensor.shape);
159                let rows = shape.first().copied().unwrap_or(1);
160                let cols = shape.get(1).copied().unwrap_or(1);
161
162                if row < 1 || row > rows {
163                    return Err(indexing_error_with_identifier(
164                        format!("Row index {} out of bounds (1 to {})", row, rows),
165                        "RunMat:IndexOutOfBounds",
166                    ));
167                }
168                if col < 1 || col > cols {
169                    return Err(indexing_error_with_identifier(
170                        format!("Column index {} out of bounds (1 to {})", col, cols),
171                        "RunMat:IndexOutOfBounds",
172                    ));
173                }
174
175                let linear_idx = (row - 1) + (col - 1) * rows; // Convert to 0-based, column-major
176                Ok(Value::Num(tensor.data[linear_idx]))
177            } else {
178                Err(indexing_error(format!(
179                    "Tensors support 1 or 2 indices, got {}",
180                    indices.len()
181                )))
182            }
183        }
184        Value::ComplexTensor(tensor) => {
185            if indices.is_empty() {
186                return Err(indexing_error("At least one index is required"));
187            }
188
189            if indices.len() == 1 {
190                let idx = indices[0] as usize;
191                if idx < 1 || idx > tensor.data.len() {
192                    return Err(indexing_error_with_identifier(
193                        format!("Index {} out of bounds (1 to {})", idx, tensor.data.len()),
194                        "RunMat:IndexOutOfBounds",
195                    ));
196                }
197                let (re, im) = tensor.data[idx - 1];
198                Ok(Value::Complex(re, im))
199            } else if indices.len() == 2 {
200                let row = indices[0] as usize;
201                let col = indices[1] as usize;
202                let shape = normalize_scalar_shape(&tensor.shape);
203                let rows = shape.first().copied().unwrap_or(1);
204                let cols = shape.get(1).copied().unwrap_or(1);
205
206                if row < 1 || row > rows {
207                    return Err(indexing_error_with_identifier(
208                        format!("Row index {} out of bounds (1 to {})", row, rows),
209                        "RunMat:IndexOutOfBounds",
210                    ));
211                }
212                if col < 1 || col > cols {
213                    return Err(indexing_error_with_identifier(
214                        format!("Column index {} out of bounds (1 to {})", col, cols),
215                        "RunMat:IndexOutOfBounds",
216                    ));
217                }
218
219                let linear_idx = (row - 1) + (col - 1) * rows;
220                let (re, im) = tensor.data[linear_idx];
221                Ok(Value::Complex(re, im))
222            } else {
223                Err(indexing_error(format!(
224                    "Complex tensors support 1 or 2 indices, got {}",
225                    indices.len()
226                )))
227            }
228        }
229        Value::StringArray(sa) => {
230            if indices.is_empty() {
231                return Err(indexing_error("At least one index is required"));
232            }
233            if indices.len() == 1 {
234                let idx = indices[0] as usize;
235                let total = sa.data.len();
236                if idx < 1 || idx > total {
237                    return Err(indexing_error_with_identifier(
238                        format!("Index {idx} out of bounds (1 to {total})"),
239                        "RunMat:IndexOutOfBounds",
240                    ));
241                }
242                Ok(Value::String(sa.data[idx - 1].clone()))
243            } else if indices.len() == 2 {
244                let row = indices[0] as usize;
245                let col = indices[1] as usize;
246                let shape = normalize_scalar_shape(&sa.shape);
247                let rows = shape.first().copied().unwrap_or(1);
248                let cols = shape.get(1).copied().unwrap_or(1);
249                if row < 1 || row > rows || col < 1 || col > cols {
250                    return Err(indexing_error_with_identifier(
251                        "StringArray subscript out of bounds",
252                        "RunMat:IndexOutOfBounds",
253                    ));
254                }
255                let idx = (row - 1) + (col - 1) * rows;
256                Ok(Value::String(sa.data[idx].clone()))
257            } else {
258                Err(indexing_error(format!(
259                    "StringArray supports 1 or 2 indices, got {}",
260                    indices.len()
261                )))
262            }
263        }
264        Value::Num(_) | Value::Int(_) => {
265            if indices.len() == 1 && indices[0] == 1.0 {
266                // Scalar indexing with A(1) returns the scalar itself
267                Ok(base.clone())
268            } else {
269                Err(indexing_error_with_identifier(
270                    "Slicing only supported on tensors",
271                    "RunMat:SliceNonTensor",
272                ))
273            }
274        }
275        Value::Cell(ca) => {
276            if indices.is_empty() {
277                return Err(indexing_error("At least one index is required"));
278            }
279            if indices.len() == 1 {
280                let idx = indices[0] as usize;
281                if idx < 1 || idx > ca.data.len() {
282                    return Err(indexing_error_with_identifier(
283                        format!("Cell index {} out of bounds (1 to {})", idx, ca.data.len()),
284                        "RunMat:CellIndexOutOfBounds",
285                    ));
286                }
287                Ok((*ca.data[idx - 1]).clone())
288            } else if indices.len() == 2 {
289                let row = indices[0] as usize;
290                let col = indices[1] as usize;
291                if row < 1 || row > ca.rows || col < 1 || col > ca.cols {
292                    return Err(indexing_error_with_identifier(
293                        "Cell subscript out of bounds",
294                        "RunMat:CellSubscriptOutOfBounds",
295                    ));
296                }
297                Ok((*ca.data[(row - 1) * ca.cols + (col - 1)]).clone())
298            } else {
299                Err(indexing_error(format!(
300                    "Cell arrays support 1 or 2 indices, got {}",
301                    indices.len()
302                )))
303            }
304        }
305        _ => Err(indexing_error_with_identifier(
306            format!("Cannot index value of type {base:?}"),
307            "RunMat:SliceNonTensor",
308        )),
309    }
310}
311
312async fn gpu_index_scalar(
313    provider: &dyn runmat_accelerate_api::AccelProvider,
314    handle: &runmat_accelerate_api::GpuTensorHandle,
315    lin0: usize,
316) -> Result<f64, RuntimeError> {
317    #[cfg(target_arch = "wasm32")]
318    {
319        let host = provider
320            .download(handle)
321            .await
322            .map_err(|e| indexing_error(format!("gpu index: {e}")))?;
323        if lin0 >= host.data.len() {
324            return Err(indexing_error(format!(
325                "gpu index: index {} out of bounds (len {})",
326                lin0 + 1,
327                host.data.len()
328            )));
329        }
330        Ok(host.data[lin0])
331    }
332    #[cfg(not(target_arch = "wasm32"))]
333    {
334        provider
335            .read_scalar(handle, lin0)
336            .map_err(|e| indexing_error(format!("gpu index: {e}")))
337    }
338}