Skip to main content

runmat_vm/call/
shared.rs

1use crate::bytecode::{ArgSpec, UserFunction};
2use crate::compiler::CompileError;
3use runmat_builtins::{Type, Value};
4use runmat_hir::{remapping, HirProgram, VarId};
5use runmat_runtime::RuntimeError;
6use std::collections::HashMap;
7use std::future::Future;
8
9pub struct PreparedUserCall {
10    pub func: UserFunction,
11    pub var_map: HashMap<VarId, VarId>,
12    pub func_program: HirProgram,
13    pub func_vars: Vec<Value>,
14}
15
16pub fn lookup_user_function(
17    name: &str,
18    functions: &HashMap<String, UserFunction>,
19) -> Result<UserFunction, RuntimeError> {
20    functions.get(name).cloned().ok_or_else(|| {
21        crate::interpreter::errors::mex("UndefinedFunction", &format!("Undefined function: {name}"))
22    })
23}
24
25pub fn validate_user_function_arity(
26    name: &str,
27    func: &UserFunction,
28    arg_count: usize,
29) -> Result<(), RuntimeError> {
30    if !func.has_varargin {
31        if arg_count < func.params.len() {
32            return Err(crate::interpreter::errors::mex(
33                "NotEnoughInputs",
34                &format!(
35                    "Function '{name}' expects {} inputs, got {arg_count}",
36                    func.params.len()
37                ),
38            ));
39        }
40        if arg_count > func.params.len() {
41            return Err(crate::interpreter::errors::mex(
42                "TooManyInputs",
43                &format!(
44                    "Function '{name}' expects {} inputs, got {arg_count}",
45                    func.params.len()
46                ),
47            ));
48        }
49    } else {
50        let min_args = func.params.len().saturating_sub(1);
51        if arg_count < min_args {
52            return Err(crate::interpreter::errors::mex(
53                "NotEnoughInputs",
54                &format!("Function '{name}' expects at least {min_args} inputs, got {arg_count}"),
55            ));
56        }
57    }
58    Ok(())
59}
60
61pub fn prepare_user_call(
62    func: UserFunction,
63    args: &[Value],
64    vars: &[Value],
65) -> Result<PreparedUserCall, CompileError> {
66    let var_map =
67        remapping::create_complete_function_var_map(&func.params, &func.outputs, &func.body);
68    let local_var_count = var_map.len();
69    let remapped_body = remapping::remap_function_body(&func.body, &var_map);
70    let func_vars_count = local_var_count.max(func.params.len());
71    let mut func_vars = vec![Value::Num(0.0); func_vars_count];
72
73    if func.has_varargin {
74        let fixed = func.params.len().saturating_sub(1);
75        for i in 0..fixed {
76            if i < args.len() && i < func_vars.len() {
77                func_vars[i] = args[i].clone();
78            }
79        }
80        let mut rest: Vec<Value> = if args.len() > fixed {
81            args[fixed..].to_vec()
82        } else {
83            Vec::new()
84        };
85        let cell = runmat_builtins::CellArray::new(
86            std::mem::take(&mut rest),
87            1,
88            if args.len() > fixed {
89                args.len() - fixed
90            } else {
91                0
92            },
93        )
94        .map_err(|e| CompileError::new(format!("varargin: {e}")))?;
95        if fixed < func_vars.len() {
96            func_vars[fixed] = Value::Cell(cell);
97        }
98    } else {
99        for (i, _param_id) in func.params.iter().enumerate() {
100            if i < args.len() && i < func_vars.len() {
101                func_vars[i] = args[i].clone();
102            }
103        }
104    }
105
106    for (original_var_id, local_var_id) in &var_map {
107        let local_index = local_var_id.0;
108        let global_index = original_var_id.0;
109        if local_index < func_vars.len() && global_index < vars.len() {
110            let is_parameter = func
111                .params
112                .iter()
113                .any(|param_id| param_id == original_var_id);
114            if !is_parameter {
115                func_vars[local_index] = vars[global_index].clone();
116            }
117        }
118    }
119
120    if func.has_varargout {
121        if let Some(varargout_oid) = func.outputs.last() {
122            if let Some(local_id) = var_map.get(varargout_oid) {
123                if local_id.0 < func_vars.len() {
124                    let empty = runmat_builtins::CellArray::new(vec![], 1, 0)
125                        .map_err(|e| CompileError::new(format!("varargout init: {e}")))?;
126                    func_vars[local_id.0] = Value::Cell(empty);
127                }
128            }
129        }
130    }
131
132    let mut func_var_types = func.var_types.clone();
133    if func_var_types.len() < local_var_count {
134        func_var_types.resize(local_var_count, Type::Unknown);
135    }
136    let func_program = HirProgram {
137        body: remapped_body,
138        var_types: func_var_types,
139    };
140
141    Ok(PreparedUserCall {
142        func,
143        var_map,
144        func_program,
145        func_vars,
146    })
147}
148
149pub fn first_output_value(
150    func: &UserFunction,
151    var_map: &HashMap<VarId, VarId>,
152    func_result_vars: &[Value],
153) -> Value {
154    if func.outputs.is_empty() {
155        return Value::Num(0.0);
156    }
157    if func.has_varargout {
158        let total_named = func.outputs.len().saturating_sub(1);
159        if total_named > 0 {
160            if let Some(oid) = func.outputs.first() {
161                if let Some(local_id) = var_map.get(oid) {
162                    if let Some(value) = func_result_vars.get(local_id.0) {
163                        return value.clone();
164                    }
165                }
166            }
167        }
168        if let Some(varargout_oid) = func.outputs.last() {
169            if let Some(local_id) = var_map.get(varargout_oid) {
170                if let Some(Value::Cell(ca)) = func_result_vars.get(local_id.0) {
171                    if let Some(first) = ca.data.first() {
172                        return (**first).clone();
173                    }
174                }
175            }
176        }
177        return Value::Num(0.0);
178    }
179    let Some(output_id) = func.outputs.first() else {
180        return Value::Num(0.0);
181    };
182    let Some(local_id) = var_map.get(output_id) else {
183        return Value::Num(0.0);
184    };
185    func_result_vars
186        .get(local_id.0)
187        .cloned()
188        .unwrap_or(Value::Num(0.0))
189}
190
191pub fn collect_multi_outputs(
192    name: &str,
193    func: &UserFunction,
194    var_map: &HashMap<VarId, VarId>,
195    func_result_vars: &[Value],
196    out_count: usize,
197) -> Result<Vec<Value>, RuntimeError> {
198    let mut outputs = Vec::with_capacity(out_count);
199    if func.has_varargout {
200        let total_named = func.outputs.len().saturating_sub(1);
201        let mut pushed = 0usize;
202        for i in 0..total_named.min(out_count) {
203            if let Some(oid) = func.outputs.get(i) {
204                if let Some(local_id) = var_map.get(oid) {
205                    let idx = local_id.0;
206                    let v = func_result_vars
207                        .get(idx)
208                        .cloned()
209                        .unwrap_or(Value::Num(0.0));
210                    outputs.push(v);
211                    pushed += 1;
212                }
213            }
214        }
215        if pushed < out_count {
216            if let Some(varargout_oid) = func.outputs.last() {
217                if let Some(local_id) = var_map.get(varargout_oid) {
218                    if let Some(Value::Cell(ca)) = func_result_vars.get(local_id.0) {
219                        let available = ca.data.len();
220                        let need = out_count - pushed;
221                        if need > available {
222                            return Err(crate::interpreter::errors::mex(
223                                "VarargoutMismatch",
224                                &format!(
225                                    "Function '{name}' returned {available} varargout values, {need} requested"
226                                ),
227                            ));
228                        }
229                        for vi in 0..need {
230                            outputs.push((*ca.data[vi]).clone());
231                        }
232                    }
233                }
234            }
235        }
236    } else {
237        let defined = func.outputs.len();
238        if out_count > defined {
239            return Err(crate::interpreter::errors::mex(
240                "TooManyOutputs",
241                &format!("Function '{name}' defines {defined} outputs, {out_count} requested"),
242            ));
243        }
244        for i in 0..out_count {
245            let v = func
246                .outputs
247                .get(i)
248                .and_then(|oid| var_map.get(oid))
249                .map(|lid| lid.0)
250                .and_then(|idx| func_result_vars.get(idx))
251                .cloned()
252                .unwrap_or(Value::Num(0.0));
253            outputs.push(v);
254        }
255    }
256    Ok(outputs)
257}
258
259pub fn expand_cell_indices(
260    cell: &runmat_builtins::CellArray,
261    indices: &[Value],
262) -> Result<Vec<Value>, RuntimeError> {
263    match indices.len() {
264        1 => match &indices[0] {
265            Value::Num(n) => {
266                let idx = *n as usize;
267                if idx == 0 || idx > cell.data.len() {
268                    return Err(crate::interpreter::errors::mex(
269                        "CellIndexOutOfBounds",
270                        "Cell index out of bounds",
271                    ));
272                }
273                Ok(vec![(*cell.data[idx - 1]).clone()])
274            }
275            Value::Int(i) => {
276                let idx = i.to_i64() as usize;
277                if idx == 0 || idx > cell.data.len() {
278                    return Err(crate::interpreter::errors::mex(
279                        "CellIndexOutOfBounds",
280                        "Cell index out of bounds",
281                    ));
282                }
283                Ok(vec![(*cell.data[idx - 1]).clone()])
284            }
285            Value::Tensor(t) => {
286                let mut out = Vec::with_capacity(t.data.len());
287                for &val in &t.data {
288                    let idx = val as usize;
289                    if idx == 0 || idx > cell.data.len() {
290                        return Err(crate::interpreter::errors::mex(
291                            "CellIndexOutOfBounds",
292                            "Cell index out of bounds",
293                        ));
294                    }
295                    out.push((*cell.data[idx - 1]).clone());
296                }
297                Ok(out)
298            }
299            _ => Err(crate::interpreter::errors::mex(
300                "CellIndexType",
301                "Unsupported cell index type",
302            )),
303        },
304        2 => {
305            let r: f64 = (&indices[0]).try_into()?;
306            let c: f64 = (&indices[1]).try_into()?;
307            let (ir, ic) = (r as usize, c as usize);
308            if ir == 0 || ir > cell.rows || ic == 0 || ic > cell.cols {
309                return Err(crate::interpreter::errors::mex(
310                    "CellSubscriptOutOfBounds",
311                    "Cell subscript out of bounds",
312                ));
313            }
314            Ok(vec![(*cell.data[(ir - 1) * cell.cols + (ic - 1)]).clone()])
315        }
316        _ => Err(crate::interpreter::errors::mex(
317            "CellIndexType",
318            "Unsupported cell index type",
319        )),
320    }
321}
322
323pub fn expand_all_cell(cell: &runmat_builtins::CellArray) -> Vec<Value> {
324    cell.data.iter().map(|p| (*(*p)).clone()).collect()
325}
326
327pub fn subsref_paren_index_cell(indices: &[Value]) -> Result<Value, RuntimeError> {
328    Ok(Value::Cell(
329        runmat_builtins::CellArray::new(indices.to_vec(), 1, indices.len())
330            .map_err(|e| CompileError::new(format!("subsref build error: {e}")))?,
331    ))
332}
333
334pub fn subsref_brace_index_cell_raw(indices: &[Value]) -> Result<Value, RuntimeError> {
335    Ok(Value::Cell(
336        runmat_builtins::CellArray::new(indices.to_vec(), 1, indices.len())
337            .map_err(|e| CompileError::new(format!("subsref build error: {e}")))?,
338    ))
339}
340
341pub fn subsref_brace_numeric_index_values(indices: &[Value]) -> Vec<Value> {
342    indices
343        .iter()
344        .map(|v| Value::Num((v).try_into().unwrap_or(0.0)))
345        .collect()
346}
347
348pub fn subsref_empty_brace_cell() -> Result<Value, RuntimeError> {
349    Ok(Value::Cell(
350        runmat_builtins::CellArray::new(vec![], 1, 0)
351            .map_err(|e| CompileError::new(format!("subsref build error: {e}")))?,
352    ))
353}
354
355pub async fn build_expanded_args_from_specs<ExpandObjectAll, ExpandObjectIndices, FutAll, FutIdx>(
356    stack: &mut Vec<Value>,
357    specs: &[ArgSpec],
358    invalid_expand_all_msg: &str,
359    invalid_expand_msg: &str,
360    mut expand_object_all: ExpandObjectAll,
361    mut expand_object_indices: ExpandObjectIndices,
362) -> Result<Vec<Value>, RuntimeError>
363where
364    ExpandObjectAll: FnMut(Value) -> FutAll,
365    ExpandObjectIndices: FnMut(Value, Vec<Value>) -> FutIdx,
366    FutAll: Future<Output = Result<Vec<Value>, RuntimeError>>,
367    FutIdx: Future<Output = Result<Vec<Value>, RuntimeError>>,
368{
369    let mut temp: Vec<Value> = Vec::new();
370    for spec in specs.iter().rev() {
371        if spec.is_expand {
372            let mut indices = Vec::with_capacity(spec.num_indices);
373            for _ in 0..spec.num_indices {
374                indices.push(stack.pop().ok_or_else(|| {
375                    crate::interpreter::errors::mex("StackUnderflow", "stack underflow")
376                })?);
377            }
378            indices.reverse();
379            let base = stack.pop().ok_or_else(|| {
380                crate::interpreter::errors::mex("StackUnderflow", "stack underflow")
381            })?;
382
383            let expanded = if spec.expand_all {
384                match base {
385                    Value::Cell(ca) => expand_all_cell(&ca),
386                    other @ Value::Object(_) => expand_object_all(other).await?,
387                    _ => {
388                        return Err(crate::interpreter::errors::mex(
389                            "InvalidExpandAllTarget",
390                            invalid_expand_all_msg,
391                        ))
392                    }
393                }
394            } else {
395                match (base, indices.len()) {
396                    (Value::Cell(ca), 1) | (Value::Cell(ca), 2) => {
397                        expand_cell_indices(&ca, &indices)?
398                    }
399                    (other @ Value::Object(_), _) => expand_object_indices(other, indices).await?,
400                    _ => {
401                        return Err(crate::interpreter::errors::mex(
402                            "ExpandError",
403                            invalid_expand_msg,
404                        ))
405                    }
406                }
407            };
408            temp.extend(expanded);
409        } else {
410            temp.push(stack.pop().ok_or_else(|| {
411                crate::interpreter::errors::mex("StackUnderflow", "stack underflow")
412            })?);
413        }
414    }
415    temp.reverse();
416    Ok(temp)
417}