runmat_runtime/
indexing.rs

//! Matrix indexing and slicing operations
//!
//! Implements MATLAB-style (1-based, column-major) matrix indexing and access patterns.
4
5use runmat_builtins::{Tensor, Value};
6
7/// Get a single element from a matrix (1-based indexing like language)
8pub fn matrix_get_element(matrix: &Tensor, row: usize, col: usize) -> Result<f64, String> {
9    if row == 0 || col == 0 {
10        return Err("MATLAB uses 1-based indexing".to_string());
11    }
12    matrix.get2(row - 1, col - 1) // Convert to 0-based
13}
14
15/// Set a single element in a matrix (1-based indexing like language)
16pub fn matrix_set_element(
17    matrix: &mut Tensor,
18    row: usize,
19    col: usize,
20    value: f64,
21) -> Result<(), String> {
22    if row == 0 || col == 0 {
23        return Err("The MATLAB language uses 1-based indexing".to_string());
24    }
25    matrix.set2(row - 1, col - 1, value) // Convert to 0-based
26}
27
28/// Get a row from a matrix
29pub fn matrix_get_row(matrix: &Tensor, row: usize) -> Result<Tensor, String> {
30    if row == 0 || row > matrix.rows() {
31        return Err(format!(
32            "Row index {} out of bounds for {}x{} matrix",
33            row,
34            matrix.rows(),
35            matrix.cols()
36        ));
37    }
38
39    // Column-major: row slice picks every element spaced by rows across columns
40    let mut row_data = Vec::with_capacity(matrix.cols());
41    for c in 0..matrix.cols() {
42        row_data.push(matrix.data[(row - 1) + c * matrix.rows()]);
43    }
44    Tensor::new_2d(row_data, 1, matrix.cols())
45}
46
47/// Get a column from a matrix
48pub fn matrix_get_col(matrix: &Tensor, col: usize) -> Result<Tensor, String> {
49    if col == 0 || col > matrix.cols() {
50        return Err(format!(
51            "Column index {} out of bounds for {}x{} matrix",
52            col,
53            matrix.rows(),
54            matrix.cols()
55        ));
56    }
57
58    let mut col_data = Vec::with_capacity(matrix.rows());
59    for row in 0..matrix.rows() {
60        col_data.push(matrix.data[row + (col - 1) * matrix.rows()]);
61    }
62    Tensor::new_2d(col_data, matrix.rows(), 1)
63}
64
65/// Array indexing operation (used by all interpreters/compilers)
66/// In MATLAB, indexing is 1-based and supports:
67/// - Single element: A(i) for vectors, A(i,j) for matrices
68/// - Multiple indices: A(i1, i2, ..., iN)
69pub fn perform_indexing(base: &Value, indices: &[f64]) -> Result<Value, String> {
70    match base {
71        Value::GpuTensor(h) => {
72            let provider = runmat_accelerate_api::provider().ok_or_else(|| {
73                "Cannot index value of type GpuTensor without a provider".to_string()
74            })?;
75            if indices.is_empty() {
76                return Err("At least one index is required".to_string());
77            }
78            // Support scalar indexing cases mirroring Tensor branch
79            if indices.len() == 1 {
80                let idx = indices[0] as usize;
81                let total = h.shape.iter().product();
82                if idx < 1 || idx > total {
83                    return Err(format!("Index {} out of bounds (1 to {})", idx, total));
84                }
85                let lin0 = idx - 1; // 0-based
86                let val = provider
87                    .read_scalar(h, lin0)
88                    .map_err(|e| format!("gpu index: {e}"))?;
89                return Ok(Value::Num(val));
90            } else if indices.len() == 2 {
91                let row = indices[0] as usize;
92                let col = indices[1] as usize;
93                let rows = h.shape.first().copied().unwrap_or(1);
94                let cols = h.shape.get(1).copied().unwrap_or(1);
95                if row < 1 || row > rows || col < 1 || col > cols {
96                    return Err(format!(
97                        "Index ({row}, {col}) out of bounds for {rows}x{cols} tensor"
98                    ));
99                }
100                let lin0 = (row - 1) + (col - 1) * rows;
101                let val = provider
102                    .read_scalar(h, lin0)
103                    .map_err(|e| format!("gpu index: {e}"))?;
104                return Ok(Value::Num(val));
105            }
106            Err(format!("Cannot index value of type {:?}", base))
107        }
108        Value::Tensor(matrix) => {
109            if indices.is_empty() {
110                return Err("At least one index is required".to_string());
111            }
112
113            if indices.len() == 1 {
114                // Linear indexing (1-based)
115                let idx = indices[0] as usize;
116                if idx < 1 || idx > matrix.data.len() {
117                    return Err(format!(
118                        "Index {} out of bounds (1 to {})",
119                        idx,
120                        matrix.data.len()
121                    ));
122                }
123                Ok(Value::Num(matrix.data[idx - 1])) // Convert to 0-based
124            } else if indices.len() == 2 {
125                // Row-column indexing (1-based)
126                let row = indices[0] as usize;
127                let col = indices[1] as usize;
128
129                if row < 1 || row > matrix.rows {
130                    return Err(format!(
131                        "Row index {} out of bounds (1 to {})",
132                        row, matrix.rows
133                    ));
134                }
135                if col < 1 || col > matrix.cols {
136                    return Err(format!(
137                        "Column index {} out of bounds (1 to {})",
138                        col, matrix.cols
139                    ));
140                }
141
142                let linear_idx = (row - 1) + (col - 1) * matrix.rows; // Convert to 0-based, column-major
143                Ok(Value::Num(matrix.data[linear_idx]))
144            } else {
145                Err(format!(
146                    "Matrices support 1 or 2 indices, got {}",
147                    indices.len()
148                ))
149            }
150        }
151        Value::StringArray(sa) => {
152            if indices.is_empty() {
153                return Err("At least one index is required".to_string());
154            }
155            if indices.len() == 1 {
156                let idx = indices[0] as usize;
157                let total = sa.data.len();
158                if idx < 1 || idx > total {
159                    return Err(format!("Index {idx} out of bounds (1 to {total})"));
160                }
161                Ok(Value::String(sa.data[idx - 1].clone()))
162            } else if indices.len() == 2 {
163                let row = indices[0] as usize;
164                let col = indices[1] as usize;
165                if row < 1 || row > sa.rows || col < 1 || col > sa.cols {
166                    return Err("StringArray subscript out of bounds".to_string());
167                }
168                let idx = (row - 1) + (col - 1) * sa.rows;
169                Ok(Value::String(sa.data[idx].clone()))
170            } else {
171                Err(format!(
172                    "StringArray supports 1 or 2 indices, got {}",
173                    indices.len()
174                ))
175            }
176        }
177        Value::Num(_) | Value::Int(_) => {
178            if indices.len() == 1 && indices[0] == 1.0 {
179                // Scalar indexing with A(1) returns the scalar itself
180                Ok(base.clone())
181            } else {
182                Err("MATLAB:SliceNonTensor: Slicing only supported on tensors".to_string())
183            }
184        }
185        Value::Cell(ca) => {
186            if indices.is_empty() {
187                return Err("At least one index is required".to_string());
188            }
189            if indices.len() == 1 {
190                let idx = indices[0] as usize;
191                if idx < 1 || idx > ca.data.len() {
192                    return Err(format!(
193                        "Cell index {} out of bounds (1 to {})",
194                        idx,
195                        ca.data.len()
196                    ));
197                }
198                Ok((*ca.data[idx - 1]).clone())
199            } else if indices.len() == 2 {
200                let row = indices[0] as usize;
201                let col = indices[1] as usize;
202                if row < 1 || row > ca.rows || col < 1 || col > ca.cols {
203                    return Err("Cell subscript out of bounds".to_string());
204                }
205                Ok((*ca.data[(row - 1) * ca.cols + (col - 1)]).clone())
206            } else {
207                Err(format!(
208                    "Cell arrays support 1 or 2 indices, got {}",
209                    indices.len()
210                ))
211            }
212        }
213        _ => Err(format!("Cannot index value of type {base:?}")),
214    }
215}