runmat_runtime/
indexing.rs1use runmat_builtins::{Tensor, Value};
6
7pub fn matrix_get_element(matrix: &Tensor, row: usize, col: usize) -> Result<f64, String> {
9 if row == 0 || col == 0 {
10 return Err("MATLAB uses 1-based indexing".to_string());
11 }
12 matrix.get2(row - 1, col - 1) }
14
15pub fn matrix_set_element(
17 matrix: &mut Tensor,
18 row: usize,
19 col: usize,
20 value: f64,
21) -> Result<(), String> {
22 if row == 0 || col == 0 {
23 return Err("The MATLAB language uses 1-based indexing".to_string());
24 }
25 matrix.set2(row - 1, col - 1, value) }
27
28pub fn matrix_get_row(matrix: &Tensor, row: usize) -> Result<Tensor, String> {
30 if row == 0 || row > matrix.rows() {
31 return Err(format!(
32 "Row index {} out of bounds for {}x{} matrix",
33 row,
34 matrix.rows(),
35 matrix.cols()
36 ));
37 }
38
39 let mut row_data = Vec::with_capacity(matrix.cols());
41 for c in 0..matrix.cols() {
42 row_data.push(matrix.data[(row - 1) + c * matrix.rows()]);
43 }
44 Tensor::new_2d(row_data, 1, matrix.cols())
45}
46
47pub fn matrix_get_col(matrix: &Tensor, col: usize) -> Result<Tensor, String> {
49 if col == 0 || col > matrix.cols() {
50 return Err(format!(
51 "Column index {} out of bounds for {}x{} matrix",
52 col,
53 matrix.rows(),
54 matrix.cols()
55 ));
56 }
57
58 let mut col_data = Vec::with_capacity(matrix.rows());
59 for row in 0..matrix.rows() {
60 col_data.push(matrix.data[row + (col - 1) * matrix.rows()]);
61 }
62 Tensor::new_2d(col_data, matrix.rows(), 1)
63}
64
65pub fn perform_indexing(base: &Value, indices: &[f64]) -> Result<Value, String> {
70 match base {
71 Value::GpuTensor(h) => {
72 let provider = runmat_accelerate_api::provider().ok_or_else(|| {
73 "Cannot index value of type GpuTensor without a provider".to_string()
74 })?;
75 if indices.is_empty() {
76 return Err("At least one index is required".to_string());
77 }
78 if indices.len() == 1 {
80 let idx = indices[0] as usize;
81 let total = h.shape.iter().product();
82 if idx < 1 || idx > total {
83 return Err(format!("Index {} out of bounds (1 to {})", idx, total));
84 }
85 let lin0 = idx - 1; let val = provider
87 .read_scalar(h, lin0)
88 .map_err(|e| format!("gpu index: {e}"))?;
89 return Ok(Value::Num(val));
90 } else if indices.len() == 2 {
91 let row = indices[0] as usize;
92 let col = indices[1] as usize;
93 let rows = h.shape.first().copied().unwrap_or(1);
94 let cols = h.shape.get(1).copied().unwrap_or(1);
95 if row < 1 || row > rows || col < 1 || col > cols {
96 return Err(format!(
97 "Index ({row}, {col}) out of bounds for {rows}x{cols} tensor"
98 ));
99 }
100 let lin0 = (row - 1) + (col - 1) * rows;
101 let val = provider
102 .read_scalar(h, lin0)
103 .map_err(|e| format!("gpu index: {e}"))?;
104 return Ok(Value::Num(val));
105 }
106 Err(format!("Cannot index value of type {:?}", base))
107 }
108 Value::Tensor(matrix) => {
109 if indices.is_empty() {
110 return Err("At least one index is required".to_string());
111 }
112
113 if indices.len() == 1 {
114 let idx = indices[0] as usize;
116 if idx < 1 || idx > matrix.data.len() {
117 return Err(format!(
118 "Index {} out of bounds (1 to {})",
119 idx,
120 matrix.data.len()
121 ));
122 }
123 Ok(Value::Num(matrix.data[idx - 1])) } else if indices.len() == 2 {
125 let row = indices[0] as usize;
127 let col = indices[1] as usize;
128
129 if row < 1 || row > matrix.rows {
130 return Err(format!(
131 "Row index {} out of bounds (1 to {})",
132 row, matrix.rows
133 ));
134 }
135 if col < 1 || col > matrix.cols {
136 return Err(format!(
137 "Column index {} out of bounds (1 to {})",
138 col, matrix.cols
139 ));
140 }
141
142 let linear_idx = (row - 1) + (col - 1) * matrix.rows; Ok(Value::Num(matrix.data[linear_idx]))
144 } else {
145 Err(format!(
146 "Matrices support 1 or 2 indices, got {}",
147 indices.len()
148 ))
149 }
150 }
151 Value::StringArray(sa) => {
152 if indices.is_empty() {
153 return Err("At least one index is required".to_string());
154 }
155 if indices.len() == 1 {
156 let idx = indices[0] as usize;
157 let total = sa.data.len();
158 if idx < 1 || idx > total {
159 return Err(format!("Index {idx} out of bounds (1 to {total})"));
160 }
161 Ok(Value::String(sa.data[idx - 1].clone()))
162 } else if indices.len() == 2 {
163 let row = indices[0] as usize;
164 let col = indices[1] as usize;
165 if row < 1 || row > sa.rows || col < 1 || col > sa.cols {
166 return Err("StringArray subscript out of bounds".to_string());
167 }
168 let idx = (row - 1) + (col - 1) * sa.rows;
169 Ok(Value::String(sa.data[idx].clone()))
170 } else {
171 Err(format!(
172 "StringArray supports 1 or 2 indices, got {}",
173 indices.len()
174 ))
175 }
176 }
177 Value::Num(_) | Value::Int(_) => {
178 if indices.len() == 1 && indices[0] == 1.0 {
179 Ok(base.clone())
181 } else {
182 Err("MATLAB:SliceNonTensor: Slicing only supported on tensors".to_string())
183 }
184 }
185 Value::Cell(ca) => {
186 if indices.is_empty() {
187 return Err("At least one index is required".to_string());
188 }
189 if indices.len() == 1 {
190 let idx = indices[0] as usize;
191 if idx < 1 || idx > ca.data.len() {
192 return Err(format!(
193 "Cell index {} out of bounds (1 to {})",
194 idx,
195 ca.data.len()
196 ));
197 }
198 Ok((*ca.data[idx - 1]).clone())
199 } else if indices.len() == 2 {
200 let row = indices[0] as usize;
201 let col = indices[1] as usize;
202 if row < 1 || row > ca.rows || col < 1 || col > ca.cols {
203 return Err("Cell subscript out of bounds".to_string());
204 }
205 Ok((*ca.data[(row - 1) * ca.cols + (col - 1)]).clone())
206 } else {
207 Err(format!(
208 "Cell arrays support 1 or 2 indices, got {}",
209 indices.len()
210 ))
211 }
212 }
213 _ => Err(format!("Cannot index value of type {base:?}")),
214 }
215}