1use crate::builtins::common::shape::normalize_scalar_shape;
6use crate::{build_runtime_error, RuntimeError};
7use runmat_builtins::{Tensor, Value};
8
9fn indexing_error(message: impl Into<String>) -> RuntimeError {
10 build_runtime_error(message).build()
11}
12
13fn indexing_error_with_identifier(message: impl Into<String>, identifier: &str) -> RuntimeError {
14 build_runtime_error(message)
15 .with_identifier(identifier)
16 .build()
17}
18
19pub fn matrix_get_element(tensor: &Tensor, row: usize, col: usize) -> Result<f64, RuntimeError> {
21 if row == 0 || col == 0 {
22 return Err(indexing_error_with_identifier(
23 "MATLAB uses 1-based indexing",
24 "RunMat:IndexOutOfBounds",
25 ));
26 }
27 tensor
28 .get2(row - 1, col - 1)
29 .map_err(|err| indexing_error_with_identifier(err, "RunMat:IndexOutOfBounds"))
30}
31
32pub fn matrix_set_element(
34 tensor: &mut Tensor,
35 row: usize,
36 col: usize,
37 value: f64,
38) -> Result<(), RuntimeError> {
39 if row == 0 || col == 0 {
40 return Err(indexing_error_with_identifier(
41 "The MATLAB language uses 1-based indexing",
42 "RunMat:IndexOutOfBounds",
43 ));
44 }
45 tensor
46 .set2(row - 1, col - 1, value)
47 .map_err(|err| indexing_error_with_identifier(err, "RunMat:IndexOutOfBounds"))
48}
49
50pub fn matrix_get_row(tensor: &Tensor, row: usize) -> Result<Tensor, RuntimeError> {
52 if row == 0 || row > tensor.rows() {
53 return Err(indexing_error_with_identifier(
54 format!(
55 "Row index {} out of bounds for {}x{} tensor",
56 row,
57 tensor.rows(),
58 tensor.cols()
59 ),
60 "RunMat:IndexOutOfBounds",
61 ));
62 }
63
64 let mut row_data = Vec::with_capacity(tensor.cols());
66 for c in 0..tensor.cols() {
67 row_data.push(tensor.data[(row - 1) + c * tensor.rows()]);
68 }
69 Tensor::new_2d(row_data, 1, tensor.cols()).map_err(|err| indexing_error(err))
70}
71
72pub fn matrix_get_col(tensor: &Tensor, col: usize) -> Result<Tensor, RuntimeError> {
74 if col == 0 || col > tensor.cols() {
75 return Err(indexing_error_with_identifier(
76 format!(
77 "Column index {} out of bounds for {}x{} tensor",
78 col,
79 tensor.rows(),
80 tensor.cols()
81 ),
82 "RunMat:IndexOutOfBounds",
83 ));
84 }
85
86 let mut col_data = Vec::with_capacity(tensor.rows());
87 for row in 0..tensor.rows() {
88 col_data.push(tensor.data[row + (col - 1) * tensor.rows()]);
89 }
90 Tensor::new_2d(col_data, tensor.rows(), 1).map_err(|err| indexing_error(err))
91}
92
93pub async fn perform_indexing(base: &Value, indices: &[f64]) -> Result<Value, RuntimeError> {
98 match base {
99 Value::GpuTensor(h) => {
100 let provider = runmat_accelerate_api::provider().ok_or_else(|| {
101 indexing_error("Cannot index value of type GpuTensor without a provider")
102 })?;
103 if indices.is_empty() {
104 return Err(indexing_error("At least one index is required"));
105 }
106 if indices.len() == 1 {
108 let idx = indices[0] as usize;
109 let total = h.shape.iter().product();
110 if idx < 1 || idx > total {
111 return Err(indexing_error_with_identifier(
112 format!("Index {} out of bounds (1 to {})", idx, total),
113 "RunMat:IndexOutOfBounds",
114 ));
115 }
116 let lin0 = idx - 1; let val = gpu_index_scalar(provider, h, lin0).await?;
118 return Ok(Value::Num(val));
119 } else if indices.len() == 2 {
120 let row = indices[0] as usize;
121 let col = indices[1] as usize;
122 let rows = h.shape.first().copied().unwrap_or(1);
123 let cols = h.shape.get(1).copied().unwrap_or(1);
124 if row < 1 || row > rows || col < 1 || col > cols {
125 return Err(indexing_error_with_identifier(
126 format!("Index ({row}, {col}) out of bounds for {rows}x{cols} tensor"),
127 "RunMat:IndexOutOfBounds",
128 ));
129 }
130 let lin0 = (row - 1) + (col - 1) * rows;
131 let val = gpu_index_scalar(provider, h, lin0).await?;
132 return Ok(Value::Num(val));
133 }
134 Err(indexing_error_with_identifier(
135 format!("Cannot index value of type {base:?}"),
136 "RunMat:SliceNonTensor",
137 ))
138 }
139 Value::Tensor(tensor) => {
140 if indices.is_empty() {
141 return Err(indexing_error("At least one index is required"));
142 }
143
144 if indices.len() == 1 {
145 let idx = indices[0] as usize;
147 if idx < 1 || idx > tensor.data.len() {
148 return Err(indexing_error_with_identifier(
149 format!("Index {} out of bounds (1 to {})", idx, tensor.data.len()),
150 "RunMat:IndexOutOfBounds",
151 ));
152 }
153 Ok(Value::Num(tensor.data[idx - 1])) } else if indices.len() == 2 {
155 let row = indices[0] as usize;
157 let col = indices[1] as usize;
158 let shape = normalize_scalar_shape(&tensor.shape);
159 let rows = shape.first().copied().unwrap_or(1);
160 let cols = shape.get(1).copied().unwrap_or(1);
161
162 if row < 1 || row > rows {
163 return Err(indexing_error_with_identifier(
164 format!("Row index {} out of bounds (1 to {})", row, rows),
165 "RunMat:IndexOutOfBounds",
166 ));
167 }
168 if col < 1 || col > cols {
169 return Err(indexing_error_with_identifier(
170 format!("Column index {} out of bounds (1 to {})", col, cols),
171 "RunMat:IndexOutOfBounds",
172 ));
173 }
174
175 let linear_idx = (row - 1) + (col - 1) * rows; Ok(Value::Num(tensor.data[linear_idx]))
177 } else {
178 Err(indexing_error(format!(
179 "Tensors support 1 or 2 indices, got {}",
180 indices.len()
181 )))
182 }
183 }
184 Value::ComplexTensor(tensor) => {
185 if indices.is_empty() {
186 return Err(indexing_error("At least one index is required"));
187 }
188
189 if indices.len() == 1 {
190 let idx = indices[0] as usize;
191 if idx < 1 || idx > tensor.data.len() {
192 return Err(indexing_error_with_identifier(
193 format!("Index {} out of bounds (1 to {})", idx, tensor.data.len()),
194 "RunMat:IndexOutOfBounds",
195 ));
196 }
197 let (re, im) = tensor.data[idx - 1];
198 Ok(Value::Complex(re, im))
199 } else if indices.len() == 2 {
200 let row = indices[0] as usize;
201 let col = indices[1] as usize;
202 let shape = normalize_scalar_shape(&tensor.shape);
203 let rows = shape.first().copied().unwrap_or(1);
204 let cols = shape.get(1).copied().unwrap_or(1);
205
206 if row < 1 || row > rows {
207 return Err(indexing_error_with_identifier(
208 format!("Row index {} out of bounds (1 to {})", row, rows),
209 "RunMat:IndexOutOfBounds",
210 ));
211 }
212 if col < 1 || col > cols {
213 return Err(indexing_error_with_identifier(
214 format!("Column index {} out of bounds (1 to {})", col, cols),
215 "RunMat:IndexOutOfBounds",
216 ));
217 }
218
219 let linear_idx = (row - 1) + (col - 1) * rows;
220 let (re, im) = tensor.data[linear_idx];
221 Ok(Value::Complex(re, im))
222 } else {
223 Err(indexing_error(format!(
224 "Complex tensors support 1 or 2 indices, got {}",
225 indices.len()
226 )))
227 }
228 }
229 Value::StringArray(sa) => {
230 if indices.is_empty() {
231 return Err(indexing_error("At least one index is required"));
232 }
233 if indices.len() == 1 {
234 let idx = indices[0] as usize;
235 let total = sa.data.len();
236 if idx < 1 || idx > total {
237 return Err(indexing_error_with_identifier(
238 format!("Index {idx} out of bounds (1 to {total})"),
239 "RunMat:IndexOutOfBounds",
240 ));
241 }
242 Ok(Value::String(sa.data[idx - 1].clone()))
243 } else if indices.len() == 2 {
244 let row = indices[0] as usize;
245 let col = indices[1] as usize;
246 let shape = normalize_scalar_shape(&sa.shape);
247 let rows = shape.first().copied().unwrap_or(1);
248 let cols = shape.get(1).copied().unwrap_or(1);
249 if row < 1 || row > rows || col < 1 || col > cols {
250 return Err(indexing_error_with_identifier(
251 "StringArray subscript out of bounds",
252 "RunMat:IndexOutOfBounds",
253 ));
254 }
255 let idx = (row - 1) + (col - 1) * rows;
256 Ok(Value::String(sa.data[idx].clone()))
257 } else {
258 Err(indexing_error(format!(
259 "StringArray supports 1 or 2 indices, got {}",
260 indices.len()
261 )))
262 }
263 }
264 Value::Num(_) | Value::Int(_) => {
265 if indices.len() == 1 && indices[0] == 1.0 {
266 Ok(base.clone())
268 } else {
269 Err(indexing_error_with_identifier(
270 "Slicing only supported on tensors",
271 "RunMat:SliceNonTensor",
272 ))
273 }
274 }
275 Value::Cell(ca) => {
276 if indices.is_empty() {
277 return Err(indexing_error("At least one index is required"));
278 }
279 if indices.len() == 1 {
280 let idx = indices[0] as usize;
281 if idx < 1 || idx > ca.data.len() {
282 return Err(indexing_error_with_identifier(
283 format!("Cell index {} out of bounds (1 to {})", idx, ca.data.len()),
284 "RunMat:CellIndexOutOfBounds",
285 ));
286 }
287 Ok((*ca.data[idx - 1]).clone())
288 } else if indices.len() == 2 {
289 let row = indices[0] as usize;
290 let col = indices[1] as usize;
291 if row < 1 || row > ca.rows || col < 1 || col > ca.cols {
292 return Err(indexing_error_with_identifier(
293 "Cell subscript out of bounds",
294 "RunMat:CellSubscriptOutOfBounds",
295 ));
296 }
297 Ok((*ca.data[(row - 1) * ca.cols + (col - 1)]).clone())
298 } else {
299 Err(indexing_error(format!(
300 "Cell arrays support 1 or 2 indices, got {}",
301 indices.len()
302 )))
303 }
304 }
305 _ => Err(indexing_error_with_identifier(
306 format!("Cannot index value of type {base:?}"),
307 "RunMat:SliceNonTensor",
308 )),
309 }
310}
311
312async fn gpu_index_scalar(
313 provider: &dyn runmat_accelerate_api::AccelProvider,
314 handle: &runmat_accelerate_api::GpuTensorHandle,
315 lin0: usize,
316) -> Result<f64, RuntimeError> {
317 #[cfg(target_arch = "wasm32")]
318 {
319 let host = provider
320 .download(handle)
321 .await
322 .map_err(|e| indexing_error(format!("gpu index: {e}")))?;
323 if lin0 >= host.data.len() {
324 return Err(indexing_error(format!(
325 "gpu index: index {} out of bounds (len {})",
326 lin0 + 1,
327 host.data.len()
328 )));
329 }
330 Ok(host.data[lin0])
331 }
332 #[cfg(not(target_arch = "wasm32"))]
333 {
334 provider
335 .read_scalar(handle, lin0)
336 .map_err(|e| indexing_error(format!("gpu index: {e}")))
337 }
338}