1use crate::bytecode::{ArgSpec, UserFunction};
2use crate::compiler::CompileError;
3use runmat_builtins::{Type, Value};
4use runmat_hir::{remapping, HirProgram, VarId};
5use runmat_runtime::RuntimeError;
6use std::collections::HashMap;
7use std::future::Future;
8
9pub struct PreparedUserCall {
10 pub func: UserFunction,
11 pub var_map: HashMap<VarId, VarId>,
12 pub func_program: HirProgram,
13 pub func_vars: Vec<Value>,
14}
15
16pub fn lookup_user_function(
17 name: &str,
18 functions: &HashMap<String, UserFunction>,
19) -> Result<UserFunction, RuntimeError> {
20 functions.get(name).cloned().ok_or_else(|| {
21 crate::interpreter::errors::mex("UndefinedFunction", &format!("Undefined function: {name}"))
22 })
23}
24
25pub fn validate_user_function_arity(
26 name: &str,
27 func: &UserFunction,
28 arg_count: usize,
29) -> Result<(), RuntimeError> {
30 if !func.has_varargin {
31 if arg_count < func.params.len() {
32 return Err(crate::interpreter::errors::mex(
33 "NotEnoughInputs",
34 &format!(
35 "Function '{name}' expects {} inputs, got {arg_count}",
36 func.params.len()
37 ),
38 ));
39 }
40 if arg_count > func.params.len() {
41 return Err(crate::interpreter::errors::mex(
42 "TooManyInputs",
43 &format!(
44 "Function '{name}' expects {} inputs, got {arg_count}",
45 func.params.len()
46 ),
47 ));
48 }
49 } else {
50 let min_args = func.params.len().saturating_sub(1);
51 if arg_count < min_args {
52 return Err(crate::interpreter::errors::mex(
53 "NotEnoughInputs",
54 &format!("Function '{name}' expects at least {min_args} inputs, got {arg_count}"),
55 ));
56 }
57 }
58 Ok(())
59}
60
61pub fn prepare_user_call(
62 func: UserFunction,
63 args: &[Value],
64 vars: &[Value],
65) -> Result<PreparedUserCall, CompileError> {
66 let var_map =
67 remapping::create_complete_function_var_map(&func.params, &func.outputs, &func.body);
68 let local_var_count = var_map.len();
69 let remapped_body = remapping::remap_function_body(&func.body, &var_map);
70 let func_vars_count = local_var_count.max(func.params.len());
71 let mut func_vars = vec![Value::Num(0.0); func_vars_count];
72
73 if func.has_varargin {
74 let fixed = func.params.len().saturating_sub(1);
75 for i in 0..fixed {
76 if i < args.len() && i < func_vars.len() {
77 func_vars[i] = args[i].clone();
78 }
79 }
80 let mut rest: Vec<Value> = if args.len() > fixed {
81 args[fixed..].to_vec()
82 } else {
83 Vec::new()
84 };
85 let cell = runmat_builtins::CellArray::new(
86 std::mem::take(&mut rest),
87 1,
88 if args.len() > fixed {
89 args.len() - fixed
90 } else {
91 0
92 },
93 )
94 .map_err(|e| CompileError::new(format!("varargin: {e}")))?;
95 if fixed < func_vars.len() {
96 func_vars[fixed] = Value::Cell(cell);
97 }
98 } else {
99 for (i, _param_id) in func.params.iter().enumerate() {
100 if i < args.len() && i < func_vars.len() {
101 func_vars[i] = args[i].clone();
102 }
103 }
104 }
105
106 for (original_var_id, local_var_id) in &var_map {
107 let local_index = local_var_id.0;
108 let global_index = original_var_id.0;
109 if local_index < func_vars.len() && global_index < vars.len() {
110 let is_parameter = func
111 .params
112 .iter()
113 .any(|param_id| param_id == original_var_id);
114 if !is_parameter {
115 func_vars[local_index] = vars[global_index].clone();
116 }
117 }
118 }
119
120 if func.has_varargout {
121 if let Some(varargout_oid) = func.outputs.last() {
122 if let Some(local_id) = var_map.get(varargout_oid) {
123 if local_id.0 < func_vars.len() {
124 let empty = runmat_builtins::CellArray::new(vec![], 1, 0)
125 .map_err(|e| CompileError::new(format!("varargout init: {e}")))?;
126 func_vars[local_id.0] = Value::Cell(empty);
127 }
128 }
129 }
130 }
131
132 let mut func_var_types = func.var_types.clone();
133 if func_var_types.len() < local_var_count {
134 func_var_types.resize(local_var_count, Type::Unknown);
135 }
136 let func_program = HirProgram {
137 body: remapped_body,
138 var_types: func_var_types,
139 };
140
141 Ok(PreparedUserCall {
142 func,
143 var_map,
144 func_program,
145 func_vars,
146 })
147}
148
149pub fn first_output_value(
150 func: &UserFunction,
151 var_map: &HashMap<VarId, VarId>,
152 func_result_vars: &[Value],
153) -> Value {
154 if func.outputs.is_empty() {
155 return Value::Num(0.0);
156 }
157 if func.has_varargout {
158 let total_named = func.outputs.len().saturating_sub(1);
159 if total_named > 0 {
160 if let Some(oid) = func.outputs.first() {
161 if let Some(local_id) = var_map.get(oid) {
162 if let Some(value) = func_result_vars.get(local_id.0) {
163 return value.clone();
164 }
165 }
166 }
167 }
168 if let Some(varargout_oid) = func.outputs.last() {
169 if let Some(local_id) = var_map.get(varargout_oid) {
170 if let Some(Value::Cell(ca)) = func_result_vars.get(local_id.0) {
171 if let Some(first) = ca.data.first() {
172 return (**first).clone();
173 }
174 }
175 }
176 }
177 return Value::Num(0.0);
178 }
179 let Some(output_id) = func.outputs.first() else {
180 return Value::Num(0.0);
181 };
182 let Some(local_id) = var_map.get(output_id) else {
183 return Value::Num(0.0);
184 };
185 func_result_vars
186 .get(local_id.0)
187 .cloned()
188 .unwrap_or(Value::Num(0.0))
189}
190
191pub fn collect_multi_outputs(
192 name: &str,
193 func: &UserFunction,
194 var_map: &HashMap<VarId, VarId>,
195 func_result_vars: &[Value],
196 out_count: usize,
197) -> Result<Vec<Value>, RuntimeError> {
198 let mut outputs = Vec::with_capacity(out_count);
199 if func.has_varargout {
200 let total_named = func.outputs.len().saturating_sub(1);
201 let mut pushed = 0usize;
202 for i in 0..total_named.min(out_count) {
203 if let Some(oid) = func.outputs.get(i) {
204 if let Some(local_id) = var_map.get(oid) {
205 let idx = local_id.0;
206 let v = func_result_vars
207 .get(idx)
208 .cloned()
209 .unwrap_or(Value::Num(0.0));
210 outputs.push(v);
211 pushed += 1;
212 }
213 }
214 }
215 if pushed < out_count {
216 if let Some(varargout_oid) = func.outputs.last() {
217 if let Some(local_id) = var_map.get(varargout_oid) {
218 if let Some(Value::Cell(ca)) = func_result_vars.get(local_id.0) {
219 let available = ca.data.len();
220 let need = out_count - pushed;
221 if need > available {
222 return Err(crate::interpreter::errors::mex(
223 "VarargoutMismatch",
224 &format!(
225 "Function '{name}' returned {available} varargout values, {need} requested"
226 ),
227 ));
228 }
229 for vi in 0..need {
230 outputs.push((*ca.data[vi]).clone());
231 }
232 }
233 }
234 }
235 }
236 } else {
237 let defined = func.outputs.len();
238 if out_count > defined {
239 return Err(crate::interpreter::errors::mex(
240 "TooManyOutputs",
241 &format!("Function '{name}' defines {defined} outputs, {out_count} requested"),
242 ));
243 }
244 for i in 0..out_count {
245 let v = func
246 .outputs
247 .get(i)
248 .and_then(|oid| var_map.get(oid))
249 .map(|lid| lid.0)
250 .and_then(|idx| func_result_vars.get(idx))
251 .cloned()
252 .unwrap_or(Value::Num(0.0));
253 outputs.push(v);
254 }
255 }
256 Ok(outputs)
257}
258
259pub fn expand_cell_indices(
260 cell: &runmat_builtins::CellArray,
261 indices: &[Value],
262) -> Result<Vec<Value>, RuntimeError> {
263 match indices.len() {
264 1 => match &indices[0] {
265 Value::Num(n) => {
266 let idx = *n as usize;
267 if idx == 0 || idx > cell.data.len() {
268 return Err(crate::interpreter::errors::mex(
269 "CellIndexOutOfBounds",
270 "Cell index out of bounds",
271 ));
272 }
273 Ok(vec![(*cell.data[idx - 1]).clone()])
274 }
275 Value::Int(i) => {
276 let idx = i.to_i64() as usize;
277 if idx == 0 || idx > cell.data.len() {
278 return Err(crate::interpreter::errors::mex(
279 "CellIndexOutOfBounds",
280 "Cell index out of bounds",
281 ));
282 }
283 Ok(vec![(*cell.data[idx - 1]).clone()])
284 }
285 Value::Tensor(t) => {
286 let mut out = Vec::with_capacity(t.data.len());
287 for &val in &t.data {
288 let idx = val as usize;
289 if idx == 0 || idx > cell.data.len() {
290 return Err(crate::interpreter::errors::mex(
291 "CellIndexOutOfBounds",
292 "Cell index out of bounds",
293 ));
294 }
295 out.push((*cell.data[idx - 1]).clone());
296 }
297 Ok(out)
298 }
299 _ => Err(crate::interpreter::errors::mex(
300 "CellIndexType",
301 "Unsupported cell index type",
302 )),
303 },
304 2 => {
305 let r: f64 = (&indices[0]).try_into()?;
306 let c: f64 = (&indices[1]).try_into()?;
307 let (ir, ic) = (r as usize, c as usize);
308 if ir == 0 || ir > cell.rows || ic == 0 || ic > cell.cols {
309 return Err(crate::interpreter::errors::mex(
310 "CellSubscriptOutOfBounds",
311 "Cell subscript out of bounds",
312 ));
313 }
314 Ok(vec![(*cell.data[(ir - 1) * cell.cols + (ic - 1)]).clone()])
315 }
316 _ => Err(crate::interpreter::errors::mex(
317 "CellIndexType",
318 "Unsupported cell index type",
319 )),
320 }
321}
322
323pub fn expand_all_cell(cell: &runmat_builtins::CellArray) -> Vec<Value> {
324 cell.data.iter().map(|p| (*(*p)).clone()).collect()
325}
326
327pub fn subsref_paren_index_cell(indices: &[Value]) -> Result<Value, RuntimeError> {
328 Ok(Value::Cell(
329 runmat_builtins::CellArray::new(indices.to_vec(), 1, indices.len())
330 .map_err(|e| CompileError::new(format!("subsref build error: {e}")))?,
331 ))
332}
333
334pub fn subsref_brace_index_cell_raw(indices: &[Value]) -> Result<Value, RuntimeError> {
335 Ok(Value::Cell(
336 runmat_builtins::CellArray::new(indices.to_vec(), 1, indices.len())
337 .map_err(|e| CompileError::new(format!("subsref build error: {e}")))?,
338 ))
339}
340
341pub fn subsref_brace_numeric_index_values(indices: &[Value]) -> Vec<Value> {
342 indices
343 .iter()
344 .map(|v| Value::Num((v).try_into().unwrap_or(0.0)))
345 .collect()
346}
347
348pub fn subsref_empty_brace_cell() -> Result<Value, RuntimeError> {
349 Ok(Value::Cell(
350 runmat_builtins::CellArray::new(vec![], 1, 0)
351 .map_err(|e| CompileError::new(format!("subsref build error: {e}")))?,
352 ))
353}
354
355pub async fn build_expanded_args_from_specs<ExpandObjectAll, ExpandObjectIndices, FutAll, FutIdx>(
356 stack: &mut Vec<Value>,
357 specs: &[ArgSpec],
358 invalid_expand_all_msg: &str,
359 invalid_expand_msg: &str,
360 mut expand_object_all: ExpandObjectAll,
361 mut expand_object_indices: ExpandObjectIndices,
362) -> Result<Vec<Value>, RuntimeError>
363where
364 ExpandObjectAll: FnMut(Value) -> FutAll,
365 ExpandObjectIndices: FnMut(Value, Vec<Value>) -> FutIdx,
366 FutAll: Future<Output = Result<Vec<Value>, RuntimeError>>,
367 FutIdx: Future<Output = Result<Vec<Value>, RuntimeError>>,
368{
369 let mut temp: Vec<Value> = Vec::new();
370 for spec in specs.iter().rev() {
371 if spec.is_expand {
372 let mut indices = Vec::with_capacity(spec.num_indices);
373 for _ in 0..spec.num_indices {
374 indices.push(stack.pop().ok_or_else(|| {
375 crate::interpreter::errors::mex("StackUnderflow", "stack underflow")
376 })?);
377 }
378 indices.reverse();
379 let base = stack.pop().ok_or_else(|| {
380 crate::interpreter::errors::mex("StackUnderflow", "stack underflow")
381 })?;
382
383 let expanded = if spec.expand_all {
384 match base {
385 Value::Cell(ca) => expand_all_cell(&ca),
386 other @ Value::Object(_) => expand_object_all(other).await?,
387 _ => {
388 return Err(crate::interpreter::errors::mex(
389 "InvalidExpandAllTarget",
390 invalid_expand_all_msg,
391 ))
392 }
393 }
394 } else {
395 match (base, indices.len()) {
396 (Value::Cell(ca), 1) | (Value::Cell(ca), 2) => {
397 expand_cell_indices(&ca, &indices)?
398 }
399 (other @ Value::Object(_), _) => expand_object_indices(other, indices).await?,
400 _ => {
401 return Err(crate::interpreter::errors::mex(
402 "ExpandError",
403 invalid_expand_msg,
404 ))
405 }
406 }
407 };
408 temp.extend(expanded);
409 } else {
410 temp.push(stack.pop().ok_or_else(|| {
411 crate::interpreter::errors::mex("StackUnderflow", "stack underflow")
412 })?);
413 }
414 }
415 temp.reverse();
416 Ok(temp)
417}