use std::cell::{Cell, RefCell};
use std::collections::{HashMap, HashSet};
use indexmap::IndexMap;
use ndarray::Array2;
use num_complex::Complex;
use rand::{Rng, SeedableRng, rngs::SmallRng};
use crate::env::{Env, LambdaFn, Value};
use crate::io::IoContext;
/// Callback used to execute a user-defined function value.
///
/// Receives the called name, the function value itself, the already
/// evaluated arguments, the caller's environment and the I/O context.
pub type FnCallHook = fn(
    name: &str,
    func: &Value,
    args: &[Value],
    caller_env: &Env,
    io: &mut IoContext,
) -> Result<Value, String>;
thread_local! {
    /// Installed user-function runner; `None` until `set_fn_call_hook` is called.
    static FN_CALL_HOOK: Cell<Option<FnCallHook>> = const { Cell::new(None) };
}
/// Installs (or replaces) the user-function runner for the current thread.
pub fn set_fn_call_hook(f: FnCallHook) {
    FN_CALL_HOOK.set(Some(f));
}
/// Callback asked to load `name` on demand.
///
/// A `true` return is taken to mean the value can now be read back from
/// `AUTOLOAD_CACHE` (presumably populated via `autoload_cache_insert`).
pub type AutoloadHook = fn(name: &str) -> bool;
thread_local! {
    /// Installed autoload callback; `None` until `set_autoload_hook` is called.
    static AUTOLOAD_HOOK: Cell<Option<AutoloadHook>> = const { Cell::new(None) };
    /// Values that have been autoloaded so far, keyed by name.
    static AUTOLOAD_CACHE: RefCell<Env> = RefCell::new(Env::new());
    /// Names known not to be autoloadable; cleared by `clear_autoload_miss_cache`.
    static AUTOLOAD_MISS_CACHE: RefCell<HashSet<String>> = RefCell::new(HashSet::new());
}
/// Installs (or replaces) the autoload callback for the current thread.
pub fn set_autoload_hook(f: AutoloadHook) {
    AUTOLOAD_HOOK.set(Some(f));
}
/// Stores `val` under `name` in the per-thread autoload cache.
pub fn autoload_cache_insert(name: String, val: Value) {
    AUTOLOAD_CACHE.with_borrow_mut(|cache| cache.insert(name, val));
}
/// Forgets every recorded autoload miss so those names are retried on next lookup.
pub fn clear_autoload_miss_cache() {
    AUTOLOAD_MISS_CACHE.with_borrow_mut(|misses| misses.clear());
}
/// Resolves `name` through the autoload machinery.
///
/// Lookup order:
/// 1. a cache hit is returned immediately;
/// 2. a name recorded as a miss returns `None` without calling the hook;
/// 3. otherwise the hook (if installed) is invoked and the cache is
///    consulted again. A name still absent afterwards is recorded as a miss.
///
/// The hook's boolean result is ignored here; success is judged solely by
/// whether the cache now contains `name`.
pub fn resolve_autoloaded(name: &str) -> Option<Value> {
    let lookup = || AUTOLOAD_CACHE.with_borrow(|cache| cache.get(name).cloned());
    if let Some(hit) = lookup() {
        return Some(hit);
    }
    if AUTOLOAD_MISS_CACHE.with_borrow(|misses| misses.contains(name)) {
        return None;
    }
    if let Some(hook) = AUTOLOAD_HOOK.get() {
        let _ = hook(name);
    }
    let found = lookup();
    if found.is_none() {
        AUTOLOAD_MISS_CACHE.with_borrow_mut(|misses| misses.insert(name.to_owned()));
    }
    found
}
/// Callback that evaluates a source string `code` in `env`.
pub type EvalStrHook = fn(code: &str, env: &Env) -> Result<Value, String>;
thread_local! {
    /// Installed string evaluator; `None` until `set_eval_str_hook` is called.
    static EVAL_STR_HOOK: Cell<Option<EvalStrHook>> = const { Cell::new(None) };
}
/// Installs (or replaces) the string evaluator for the current thread.
pub fn set_eval_str_hook(f: EvalStrHook) {
    EVAL_STR_HOOK.set(Some(f));
}
/// Runs the installed string evaluator, or reports that none was installed.
fn call_eval_str_hook(code: &str, env: &Env) -> Result<Value, String> {
    let Some(hook) = EVAL_STR_HOOK.get() else {
        return Err("eval: exec::init() not called".to_string());
    };
    hook(code, env)
}
thread_local! {
    /// Start instant for elapsed-time measurement; presumably written by
    /// `tic` and read by `toc` (neither is visible here).
    static TIC_TIME: Cell<Option<std::time::Instant>> = const { Cell::new(None) };
    /// Most recent error message recorded on this thread (empty initially).
    static LAST_ERR: RefCell<String> = const { RefCell::new(String::new()) };
}
/// Records `msg` as the most recent error, replacing any previous one.
pub fn set_last_err(msg: &str) {
    LAST_ERR.set(msg.to_owned());
}
/// Returns a copy of the most recently recorded error message.
pub fn get_last_err() -> String {
    LAST_ERR.with_borrow(String::clone)
}
thread_local! {
    /// Requested number of outputs for the current call; starts at 1.
    static NARGOUT: Cell<usize> = const { Cell::new(1) };
}
/// Sets the requested number of outputs for the current thread.
pub fn set_nargout(n: usize) {
    NARGOUT.set(n);
}
/// Reads the requested number of outputs for the current thread.
fn get_nargout() -> usize {
    NARGOUT.get()
}
thread_local! {
    /// Current numeric display format (starts as `FormatMode::Short`).
    static DISPLAY_FMT: RefCell<FormatMode> = const { RefCell::new(FormatMode::Short) };
    /// Current numeric display base (starts as `Base::Dec`).
    static DISPLAY_BASE: Cell<Base> = const { Cell::new(Base::Dec) };
    /// Whether compact display is enabled (starts disabled).
    static DISPLAY_COMPACT: Cell<bool> = const { Cell::new(false) };
}
/// Replaces all three display settings for the current thread at once.
pub fn set_display_ctx(fmt: &FormatMode, base: Base, compact: bool) {
    DISPLAY_FMT.with_borrow_mut(|slot| slot.clone_from(fmt));
    DISPLAY_BASE.set(base);
    DISPLAY_COMPACT.set(compact);
}
/// Returns a copy of the current display format.
pub fn get_display_fmt() -> FormatMode {
    DISPLAY_FMT.with_borrow(FormatMode::clone)
}
/// Returns the current display base.
pub fn get_display_base() -> Base {
    DISPLAY_BASE.get()
}
/// Returns whether compact display is enabled.
pub fn get_display_compact() -> bool {
    DISPLAY_COMPACT.get()
}
thread_local! {
    /// Values of variables declared `global`, shared by every frame on this thread.
    static GLOBAL_ENV: RefCell<Env> = RefCell::new(Env::new());
    /// One set of `global`-declared names per frame; starts with a single base frame.
    static GLOBAL_NAMES_STACK: RefCell<Vec<HashSet<String>>> =
        RefCell::new(vec![HashSet::new()]);
}
/// Opens a fresh (empty) set of global declarations for a new frame.
pub fn global_frame_push() {
    GLOBAL_NAMES_STACK.with_borrow_mut(|frames| frames.push(HashSet::new()));
}
/// Discards the innermost frame's global declarations (no-op when the stack is empty).
pub fn global_frame_pop() {
    GLOBAL_NAMES_STACK.with_borrow_mut(|frames| {
        frames.pop();
    });
}
/// Marks `name` as global in the innermost frame.
pub fn global_declare(name: &str) {
    GLOBAL_NAMES_STACK.with_borrow_mut(|frames| {
        if let Some(top) = frames.last_mut() {
            top.insert(name.to_owned());
        }
    });
}
/// Reports whether `name` was declared global in the innermost frame.
pub fn is_global(name: &str) -> bool {
    GLOBAL_NAMES_STACK.with_borrow(|frames| matches!(frames.last(), Some(top) if top.contains(name)))
}
/// Returns a copy of the global value stored under `name`, if any.
pub fn global_get(name: &str) -> Option<Value> {
    GLOBAL_ENV.with_borrow(|store| store.get(name).cloned())
}
/// Stores `val` as the global value of `name`.
pub fn global_set(name: &str, val: Value) {
    GLOBAL_ENV.with_borrow_mut(|store| store.insert(name.to_owned(), val));
}
/// Gives global `name` the value `0` unless it already has one.
pub fn global_init_if_absent(name: &str) {
    GLOBAL_ENV.with_borrow_mut(|store| {
        store
            .entry(name.to_owned())
            .or_insert(Value::Scalar(0.0));
    });
}
/// Copies the current value of every global declared in the innermost frame into `env`.
///
/// Declared names that have no stored global value are left untouched in `env`.
pub fn global_refresh_into_env(env: &mut crate::env::Env) {
    GLOBAL_NAMES_STACK.with_borrow(|frames| {
        let Some(declared) = frames.last() else {
            return;
        };
        GLOBAL_ENV.with_borrow(|store| {
            for name in declared {
                if let Some(val) = store.get(name) {
                    env.insert(name.clone(), val.clone());
                }
            }
        });
    });
}
thread_local! {
    /// Saved persistent-variable values keyed by `"<function>\0<variable>"`.
    static PERSISTENT_STORE: RefCell<HashMap<String, Value>> =
        RefCell::new(HashMap::new());
    /// Function names pushed by `persistent_frame_push`; the base entry is "".
    static FUNC_NAME_STACK: RefCell<Vec<String>> =
        RefCell::new(vec![String::new()]);
    /// One set of `persistent`-declared names per frame; starts with a base frame.
    static PERSISTENT_NAMES_STACK: RefCell<Vec<HashSet<String>>> =
        RefCell::new(vec![HashSet::new()]);
}
/// Enters a frame for `func_name` with no persistent declarations yet.
pub fn persistent_frame_push(func_name: &str) {
    FUNC_NAME_STACK.with_borrow_mut(|names| names.push(func_name.to_owned()));
    PERSISTENT_NAMES_STACK.with_borrow_mut(|frames| frames.push(HashSet::new()));
}
/// Leaves the innermost frame, returning its function name and declared
/// persistent names (defaults when a stack is already empty).
pub fn persistent_frame_pop() -> (String, HashSet<String>) {
    let func_name = FUNC_NAME_STACK
        .with_borrow_mut(|names| names.pop())
        .unwrap_or_default();
    let names = PERSISTENT_NAMES_STACK
        .with_borrow_mut(|frames| frames.pop())
        .unwrap_or_default();
    (func_name, names)
}
/// Marks `name` as persistent in the innermost frame.
pub fn persistent_declare(name: &str) {
    PERSISTENT_NAMES_STACK.with_borrow_mut(|frames| {
        if let Some(top) = frames.last_mut() {
            top.insert(name.to_owned());
        }
    });
}
/// Returns the saved value of `var_name` for `func_name`, if any.
pub fn persistent_load(func_name: &str, var_name: &str) -> Option<Value> {
    let key = format!("{func_name}\x00{var_name}");
    PERSISTENT_STORE.with_borrow(|store| store.get(&key).cloned())
}
/// Saves `val` as the value of `var_name` for `func_name`.
pub fn persistent_save(func_name: &str, var_name: &str, val: Value) {
    let key = format!("{func_name}\x00{var_name}");
    PERSISTENT_STORE.with_borrow_mut(|store| store.insert(key, val));
}
/// Returns the innermost frame's function name ("" at the base frame or when empty).
pub fn current_func_name() -> String {
    FUNC_NAME_STACK
        .with_borrow(|names| names.last().cloned())
        .unwrap_or_default()
}
/// Reports whether `name` was declared persistent in the innermost frame.
pub fn is_persistent(name: &str) -> bool {
    PERSISTENT_NAMES_STACK
        .with_borrow(|frames| matches!(frames.last(), Some(top) if top.contains(name)))
}
thread_local! {
    /// Per-thread generator, entropy-seeded on first use; presumably backs
    /// the `rand`/`randn`/`rng` builtins.
    static RNG: RefCell<SmallRng> = RefCell::new(SmallRng::from_entropy());
}
/// Reseeds the generator deterministically from `seed`.
pub fn rng_seed(seed: u64) {
    RNG.with_borrow_mut(|rng| *rng = SmallRng::seed_from_u64(seed));
}
/// Reseeds the generator from OS entropy.
pub fn rng_shuffle() {
    RNG.with_borrow_mut(|rng| *rng = SmallRng::from_entropy());
}
/// Draws a uniform sample from `[0, 1)`.
fn rand_uniform() -> f64 {
    RNG.with_borrow_mut(|rng| rng.gen_range(0.0_f64..1.0))
}
/// Draws a standard-normal sample using the cosine branch of Box–Muller.
///
/// The first uniform is clamped to at least `f64::EPSILON` so `ln` never sees 0.
fn rand_normal() -> f64 {
    let u1 = rand_uniform().max(f64::EPSILON);
    let u2 = rand_uniform();
    let radius = (-2.0 * u1.ln()).sqrt();
    let angle = 2.0 * std::f64::consts::PI * u2;
    radius * angle.cos()
}
/// Expression tree evaluated by `eval` / `eval_with_io`.
#[derive(Debug, Clone)]
pub enum Expr {
    /// Numeric literal; evaluates to `Value::Scalar`.
    Number(f64),
    /// Variable reference. If absent from the env, falls back to a declared
    /// global, the constant `e`, then a zero-argument builtin.
    Var(String),
    /// Unary negation.
    UnaryMinus(Box<Expr>),
    /// Logical NOT: zero becomes 1, anything else becomes 0.
    UnaryNot(Box<Expr>),
    /// Binary operation; the left operand is evaluated before the right.
    BinOp(Box<Expr>, Op, Box<Expr>),
    /// `name(args)`. Indexes a non-callable variable, otherwise calls a
    /// lambda, user function, autoloaded function or builtin. `try` with two
    /// arguments is handled as a special form.
    Call(String, Vec<Expr>),
    /// Matrix literal: outer `Vec` is rows, inner `Vec` is a row's elements.
    Matrix(Vec<Vec<Expr>>),
    /// Transpose that also conjugates complex values.
    Transpose(Box<Expr>),
    /// Range `(start, optional step, stop)`; the step defaults to 1.
    Range(Box<Expr>, Option<Box<Expr>>, Box<Expr>),
    /// Bare colon; an error outside index expressions.
    Colon,
    /// Char-array literal; evaluates to `Value::Str`.
    StrLiteral(String),
    /// String-object literal; evaluates to `Value::StringObj`.
    StringObjLiteral(String),
    /// Anonymous function capturing a copy of the defining environment.
    Lambda {
        /// Parameter names, bound positionally.
        params: Vec<String>,
        /// Body expression.
        body: Box<Expr>,
        /// Source text stored alongside the resulting `LambdaFn`.
        source: String,
    },
    /// Transpose without complex conjugation.
    PlainTranspose(Box<Expr>),
    /// Cell-array literal.
    CellLiteral(Vec<Expr>),
    /// Brace indexing of a cell array with a single scalar index.
    CellIndex(Box<Expr>, Box<Expr>),
    /// Function handle to a named variable or builtin (source shown as `@name`).
    FuncHandle(String),
    /// Static field access; also `Count` on maps and per-element field
    /// gathering on struct arrays.
    FieldGet(Box<Expr>, String),
    /// Field access whose name comes from a string-valued expression.
    DynFieldGet(Box<Expr>, Box<Expr>),
    /// Dotted call: `containers.Map`, a call through struct fields, or an
    /// autoloaded package function named by the joined segments.
    DotCall(Vec<String>, Vec<Expr>),
    /// Not-a-Time literal; evaluates to a NaN `DateTime`.
    NaT,
}
/// Binary operators carried by `Expr::BinOp` and applied by `eval_binop`.
///
/// NOTE(review): the `Elem*` variants presumably denote element-wise forms
/// of `Mul`/`Div`/`Pow`/`And`/`Or`; confirm against `eval_binop`.
#[derive(Debug, Clone)]
pub enum Op {
    /// Addition; also concatenates two string objects.
    Add,
    Sub,
    Mul,
    Div,
    Pow,
    ElemMul,
    ElemDiv,
    ElemPow,
    Eq,
    NotEq,
    Lt,
    Gt,
    LtEq,
    GtEq,
    And,
    Or,
    ElemAnd,
    ElemOr,
    LDiv,
}
/// Numeric base used for display; stored per thread by `set_display_ctx`
/// and read back with `get_display_base`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum Base {
    /// The default base, also the initial thread-local display base.
    #[default]
    Dec,
    Hex,
    Bin,
    Oct,
}
/// Numeric display format; `name` gives its user-facing spelling.
#[derive(Debug, Clone, PartialEq)]
pub enum FormatMode {
    Short,
    Long,
    ShortE,
    LongE,
    ShortG,
    LongG,
    Bank,
    Rat,
    Hex,
    Plus,
    /// Custom format carrying a digit count (meaning defined by the formatter).
    Custom(usize),
}
impl Default for FormatMode {
    // NOTE(review): this is `Custom(10)`, whereas the thread-local display
    // format starts as `Short`; confirm the two are meant to differ.
    fn default() -> Self {
        Self::Custom(10)
    }
}
impl FormatMode {
    /// Returns the user-facing name, e.g. `"shortE"`, `"+"` or `"custom(4)"`.
    pub fn name(&self) -> String {
        let fixed = match self {
            Self::Short => "short",
            Self::Long => "long",
            Self::ShortE => "shortE",
            Self::LongE => "longE",
            Self::ShortG => "shortG",
            Self::LongG => "longG",
            Self::Bank => "bank",
            Self::Rat => "rat",
            Self::Hex => "hex",
            Self::Plus => "+",
            Self::Custom(digits) => return format!("custom({digits})"),
        };
        String::from(fixed)
    }
}
pub fn eval(expr: &Expr, env: &Env) -> Result<Value, String> {
eval_inner(expr, env, None)
}
/// Evaluates `expr` in `env`, routing builtin and user-function I/O through `io`.
pub fn eval_with_io(expr: &Expr, env: &Env, io: &mut IoContext) -> Result<Value, String> {
    let io_ctx = Some(io);
    eval_inner(expr, env, io_ctx)
}
/// Recursive evaluator behind `eval` and `eval_with_io`.
///
/// `io` is re-borrowed with `as_deref_mut()` for each sub-expression so that
/// nested builtin / lambda / user-function calls all write to the same
/// context. When `io` is `None`, calls dispatched through `FN_CALL_HOOK`
/// receive a throw-away `IoContext`.
///
/// # Errors
/// Returns a message for undefined names, operators or literals applied to
/// unsupported types, shape mismatches, bad indices, and any error reported
/// by builtins, lambdas or the installed hooks.
fn eval_inner(expr: &Expr, env: &Env, mut io: Option<&mut IoContext>) -> Result<Value, String> {
    match expr {
        Expr::Number(n) => Ok(Value::Scalar(*n)),
        Expr::Var(name) => env.get(name).cloned().ok_or(()).or_else(|_| {
            // Not a local: try a declared global, the constant `e`, then a
            // zero-argument builtin, before reporting the name as undefined.
            if is_global(name)
                && let Some(val) = global_get(name)
            {
                return Ok(val);
            }
            if name == "e" {
                return Ok(Value::Scalar(std::f64::consts::E));
            }
            if let Ok(val) = call_builtin(name, &[], env, io.as_deref_mut()) {
                return Ok(val);
            }
            let hint = suggest_similar(name, env);
            match hint {
                Some(s) => Err(format!("Undefined variable '{name}'; did you mean '{s}'?")),
                None => Err(format!("Undefined variable: '{name}'")),
            }
        }),
        Expr::UnaryMinus(e) => match eval_inner(e, env, io)? {
            Value::Void => Err("Unary minus is not applicable to void".to_string()),
            Value::Scalar(n) => Ok(Value::Scalar(-n)),
            Value::Matrix(m) => Ok(Value::Matrix(Box::new(m.mapv(|x| -x)))),
            Value::Complex(re, im) => Ok(Value::Complex(-re, -im)),
            Value::ComplexMatrix(m) => Ok(Value::ComplexMatrix(Box::new(m.mapv(|c| -c)))),
            Value::Str(s) => match str_to_numeric(&s) {
                Value::Scalar(n) => Ok(Value::Scalar(-n)),
                Value::Matrix(m) => Ok(Value::Matrix(Box::new(m.mapv(|x| -x)))),
                _ => unreachable!(),
            },
            Value::StringObj(_) => {
                Err("Unary minus is not applicable to string objects".to_string())
            }
            Value::Lambda(_)
            | Value::Function(_)
            | Value::Tuple(_)
            | Value::Cell(_)
            | Value::Struct(_)
            | Value::StructArray(_)
            | Value::DateTime(_)
            | Value::Duration(_)
            | Value::DateTimeArray(_)
            | Value::DurationArray(_)
            | Value::Map(_) => Err("Unary minus is not applicable to this type".to_string()),
        },
        Expr::UnaryNot(e) => match eval_inner(e, env, io)? {
            Value::Void => Err("Logical NOT is not applicable to void".to_string()),
            Value::Scalar(n) => Ok(Value::Scalar(if n == 0.0 { 1.0 } else { 0.0 })),
            Value::Matrix(m) => Ok(Value::Matrix(Box::new(
                m.mapv(|x| if x == 0.0 { 1.0 } else { 0.0 }),
            ))),
            Value::Complex(re, im) => Ok(Value::Scalar(if re == 0.0 && im == 0.0 {
                1.0
            } else {
                0.0
            })),
            Value::ComplexMatrix(m) => {
                Ok(Value::Matrix(Box::new(m.mapv(|c| {
                    if c.re == 0.0 && c.im == 0.0 { 1.0 } else { 0.0 }
                }))))
            }
            Value::Str(s) => match str_to_numeric(&s) {
                Value::Scalar(n) => Ok(Value::Scalar(if n == 0.0 { 1.0 } else { 0.0 })),
                Value::Matrix(m) => Ok(Value::Matrix(Box::new(
                    m.mapv(|x| if x == 0.0 { 1.0 } else { 0.0 }),
                ))),
                _ => unreachable!(),
            },
            Value::StringObj(_) => {
                Err("Logical NOT is not applicable to string objects".to_string())
            }
            Value::Lambda(_)
            | Value::Function(_)
            | Value::Tuple(_)
            | Value::Cell(_)
            | Value::Struct(_)
            | Value::StructArray(_)
            | Value::DateTime(_)
            | Value::Duration(_)
            | Value::DateTimeArray(_)
            | Value::DurationArray(_)
            | Value::Map(_) => Err("Logical NOT is not applicable to this type".to_string()),
        },
        Expr::BinOp(left, op, right) => {
            let l = eval_inner(left, env, io.as_deref_mut())?;
            let r = eval_inner(right, env, io)?;
            eval_binop(l, op, r)
        }
        Expr::Call(name, args) => {
            // `try(expr, fallback)`: the fallback runs only if `expr` fails,
            // and the failure message is recorded via `set_last_err`.
            if name == "try" && args.len() == 2 {
                return match eval_inner(&args[0], env, io.as_deref_mut()) {
                    Ok(v) => Ok(v),
                    Err(msg) => {
                        set_last_err(&msg);
                        eval_inner(&args[1], env, io.as_deref_mut())
                    }
                };
            }
            if let Some(env_val) = env.get(name) {
                // A non-callable variable followed by `(...)` is indexing.
                if !matches!(env_val, Value::Lambda(_) | Value::Function(_)) {
                    return eval_index(env_val, args, env);
                }
                let val = env_val.clone();
                match &val {
                    Value::Lambda(f) => {
                        let mut evaled = Vec::with_capacity(args.len().max(1));
                        for a in args {
                            evaled.push(eval_inner(a, env, io.as_deref_mut())?);
                        }
                        // A lambda called with no arguments receives `ans`.
                        if evaled.is_empty() {
                            evaled.push(env.get("ans").cloned().unwrap_or(Value::Scalar(0.0)));
                        }
                        let f = f.clone();
                        return f.0(&evaled, io);
                    }
                    Value::Function(_) => {
                        let mut evaled = Vec::with_capacity(args.len());
                        for a in args {
                            evaled.push(eval_inner(a, env, io.as_deref_mut())?);
                        }
                        let had_io = io.is_some();
                        return dispatch_fn_call_hook(name, &val, &evaled, env, io)
                            .unwrap_or_else(|| {
                                Err(if had_io {
                                    format!(
                                        "'{name}': user function execution not initialized (call exec::init() first)"
                                    )
                                } else {
                                    format!("'{name}': user function execution not initialized")
                                })
                            });
                    }
                    _ => unreachable!(),
                }
            }
            if let Some(val) = lookup_autoload_cached(name, true) {
                let mut evaled = Vec::with_capacity(args.len());
                for a in args {
                    evaled.push(eval_inner(a, env, io.as_deref_mut())?);
                }
                return dispatch_fn_call_hook(name, &val, &evaled, env, io)
                    .unwrap_or_else(|| Err(format!("'{name}': exec::init() not called")));
            }
            let mut evaled = Vec::with_capacity(args.len().max(1));
            for a in args {
                evaled.push(eval_inner(a, env, io.as_deref_mut())?);
            }
            // Builtins that must see an empty argument list instead of an
            // implicitly injected `ans`.
            let no_ans_inject = matches!(
                name.as_str(),
                "struct"
                    | "fieldnames"
                    | "isfield"
                    | "rmfield"
                    | "isstruct"
                    | "cell"
                    | "iscell"
                    | "isempty"
                    | "cellfun"
                    | "error"
                    | "warning"
                    | "lasterr"
                    | "pcall"
                    | "rand"
                    | "randn"
                    | "rng"
                    | "tic"
                    | "toc"
            );
            if evaled.is_empty() && !no_ans_inject {
                evaled.push(env.get("ans").cloned().unwrap_or(Value::Scalar(0.0)));
            }
            call_builtin(name, &evaled, env, io)
        }
        Expr::Lambda {
            params,
            body,
            source,
        } => {
            let captured_env = env.clone();
            let captured_params = params.clone();
            let captured_body = *body.clone();
            let src = source.clone();
            let lambda = LambdaFn(
                std::rc::Rc::new(move |args: &[Value], io: Option<&mut IoContext>| {
                    // One surplus argument is tolerated and dropped, because
                    // zero-argument calls get `ans` injected by the caller.
                    // NOTE(review): this also silently drops a genuine single
                    // extra argument; only two or more extras are rejected.
                    let effective = if args.len() > captured_params.len() {
                        if args.len() > captured_params.len() + 1 {
                            return Err(format!(
                                "Lambda: too many arguments (expected at most {}, got {})",
                                captured_params.len(),
                                args.len()
                            ));
                        }
                        &args[..captured_params.len()]
                    } else {
                        args
                    };
                    let mut local_env = captured_env.clone();
                    for (p, a) in captured_params.iter().zip(effective.iter()) {
                        local_env.insert(p.clone(), a.clone());
                    }
                    local_env.insert("nargin".to_string(), Value::Scalar(effective.len() as f64));
                    eval_inner(&captured_body, &local_env, io)
                }),
                src,
            );
            Ok(Value::Lambda(Box::new(lambda)))
        }
        Expr::CellLiteral(elems) => {
            let mut vals = Vec::with_capacity(elems.len());
            for e in elems {
                vals.push(eval_inner(e, env, io.as_deref_mut())?);
            }
            Ok(Value::Cell(Box::new(vals)))
        }
        Expr::CellIndex(cell_expr, idx_expr) => {
            let cell = eval_inner(cell_expr, env, io.as_deref_mut())?;
            let idx = eval_inner(idx_expr, env, io)?;
            match (cell, idx) {
                (Value::Cell(v), Value::Scalar(i)) => {
                    // Reject fractional, NaN and infinite subscripts rather
                    // than letting the `as` cast truncate them silently.
                    if !i.is_finite() || i.fract() != 0.0 {
                        return Err("Cell index must be a scalar integer".to_string());
                    }
                    let i = i as isize;
                    if i < 1 || i as usize > v.len() {
                        Err(format!("Cell index {} out of range (1..{})", i, v.len()))
                    } else {
                        Ok(v[(i - 1) as usize].clone())
                    }
                }
                (Value::Cell(_), _) => Err("Cell index must be a scalar integer".to_string()),
                _ => Err("Brace indexing '{}' is only valid on cell arrays".to_string()),
            }
        }
        Expr::DynFieldGet(base_expr, field_expr) => {
            let base_val = eval_inner(base_expr, env, io.as_deref_mut())?;
            let field_val = eval_inner(field_expr, env, io)?;
            let field = match &field_val {
                Value::Str(s) | Value::StringObj(s) => s.clone(),
                _ => return Err("Dynamic field name must be a string".to_string()),
            };
            match base_val {
                Value::Struct(map) => map
                    .get(&field)
                    .cloned()
                    .ok_or_else(|| format!("No field '{field}' in struct")),
                _ => Err(format!(
                    "Cannot access field '{field}' on a non-struct value"
                )),
            }
        }
        Expr::FieldGet(base_expr, field) => {
            let base_val = eval_inner(base_expr, env, io)?;
            match base_val {
                Value::Map(ref map) if field == "Count" => Ok(Value::Scalar(map.len() as f64)),
                Value::Map(_) => Err(format!(
                    "Map has no property '{field}'; use 'Count', isKey(), keys(), values()"
                )),
                Value::Struct(map) => map
                    .get(field)
                    .cloned()
                    .ok_or_else(|| format!("No field '{field}' in struct")),
                Value::StructArray(arr) => {
                    // Gather the field from every element: all-scalar results
                    // become a 1xN matrix, anything else becomes a cell.
                    let mut values: Vec<Value> = Vec::with_capacity(arr.len());
                    for (idx, elem) in arr.iter().enumerate() {
                        let v = elem.get(field).cloned().ok_or_else(|| {
                            format!("No field '{field}' in struct array element {}", idx + 1)
                        })?;
                        values.push(v);
                    }
                    let all_scalar = values.iter().all(|v| matches!(v, Value::Scalar(_)));
                    if all_scalar {
                        let nums: Vec<f64> = values
                            .into_iter()
                            .map(|v| {
                                if let Value::Scalar(n) = v {
                                    n
                                } else {
                                    unreachable!()
                                }
                            })
                            .collect();
                        let n = nums.len();
                        Ok(Value::Matrix(Box::new(
                            Array2::from_shape_vec((1, n), nums).unwrap(),
                        )))
                    } else {
                        Ok(Value::Cell(Box::new(values)))
                    }
                }
                _ => Err(format!(
                    "Cannot access field '{field}' on a non-struct value"
                )),
            }
        }
        Expr::DotCall(segs, args) => {
            let qualified = segs.join(".");
            if segs == &["containers", "Map"] {
                return make_containers_map(args, env, io);
            }
            // `a.b.c(...)` where `a` is a variable: walk struct fields to the callee.
            if let Some(head_val) = env.get(&segs[0]).cloned() {
                let mut val = head_val;
                for field in &segs[1..] {
                    val = match val {
                        Value::Struct(ref map) => map
                            .get(field)
                            .cloned()
                            .ok_or_else(|| format!("No field '{field}' in struct"))?,
                        _ => {
                            return Err(format!(
                                "Cannot access field '{field}' on a non-struct value"
                            ));
                        }
                    };
                }
                let mut evaled = Vec::with_capacity(args.len());
                for a in args {
                    evaled.push(eval_inner(a, env, io.as_deref_mut())?);
                }
                return match val {
                    Value::Lambda(f) => {
                        if evaled.is_empty() {
                            evaled.push(env.get("ans").cloned().unwrap_or(Value::Scalar(0.0)));
                        }
                        f.0(&evaled, io)
                    }
                    Value::Function(_) => dispatch_fn_call_hook(&qualified, &val, &evaled, env, io)
                        .unwrap_or_else(|| Err(format!("'{qualified}': exec::init() not called"))),
                    _ => Err(format!("'{qualified}': not a callable")),
                };
            }
            // Otherwise treat the dotted name as a package function to autoload.
            if let Some(val) = lookup_autoload_cached(&qualified, false) {
                let mut evaled = Vec::with_capacity(args.len());
                for a in args {
                    evaled.push(eval_inner(a, env, io.as_deref_mut())?);
                }
                return dispatch_fn_call_hook(&qualified, &val, &evaled, env, io)
                    .unwrap_or_else(|| Err(format!("'{qualified}': exec::init() not called")));
            }
            Err(format!("Unknown package function: '{qualified}'"))
        }
        Expr::FuncHandle(name) => {
            let name = name.clone();
            let captured_env = env.clone();
            let src = format!("@{name}");
            let lambda = LambdaFn(
                std::rc::Rc::new(move |args: &[Value], io: Option<&mut IoContext>| {
                    // Prefer a value bound in the defining env; otherwise a builtin.
                    if let Some(f) = captured_env.get(&name) {
                        let f = f.clone();
                        call_function_value(&f, args, io)
                    } else {
                        call_builtin(&name, args, &captured_env, io)
                    }
                }),
                src,
            );
            Ok(Value::Lambda(Box::new(lambda)))
        }
        Expr::PlainTranspose(e) => match eval_inner(e, env, io)? {
            Value::Void => Err("Transpose is not applicable to void".to_string()),
            Value::Scalar(n) => Ok(Value::Scalar(n)),
            Value::Matrix(m) => Ok(Value::Matrix(Box::new(m.t().to_owned()))),
            Value::Complex(re, im) => Ok(Value::Complex(re, im)),
            Value::ComplexMatrix(m) => Ok(Value::ComplexMatrix(Box::new(m.t().to_owned()))),
            Value::Str(s) => Ok(Value::Str(s)),
            Value::StringObj(s) => Ok(Value::StringObj(s)),
            v @ (Value::DateTimeArray(_) | Value::DurationArray(_)) => Ok(v),
            Value::Lambda(_)
            | Value::Function(_)
            | Value::Tuple(_)
            | Value::Cell(_)
            | Value::Struct(_)
            | Value::StructArray(_)
            | Value::DateTime(_)
            | Value::Duration(_)
            | Value::Map(_) => Err("Transpose is not applicable to this type".to_string()),
        },
        Expr::Colon => Err("':' is only valid inside index expressions".to_string()),
        Expr::NaT => Ok(Value::DateTime(f64::NAN)),
        Expr::Matrix(rows) => {
            // Rows with no elements are dropped before assembly.
            let mut evaluated: Vec<Vec<Value>> = Vec::with_capacity(rows.len());
            for row in rows.iter().filter(|row| !row.is_empty()) {
                let mut ev_row: Vec<Value> = Vec::with_capacity(row.len());
                for elem_expr in row {
                    ev_row.push(eval_inner(elem_expr, env, io.as_deref_mut())?);
                }
                evaluated.push(ev_row);
            }
            assemble_matrix_literal(&evaluated)
        }
        Expr::Transpose(e) => match eval_inner(e, env, io)? {
            Value::Void => Err("Transpose is not applicable to void".to_string()),
            Value::Scalar(n) => Ok(Value::Scalar(n)),
            Value::Matrix(m) => Ok(Value::Matrix(Box::new(m.t().to_owned()))),
            Value::Complex(re, im) => Ok(Value::Complex(re, -im)),
            Value::ComplexMatrix(m) => Ok(Value::ComplexMatrix(Box::new(m.t().mapv(|c| c.conj())))),
            Value::Str(s) => Ok(Value::Str(s)),
            Value::StringObj(s) => Ok(Value::StringObj(s)),
            v @ (Value::DateTimeArray(_) | Value::DurationArray(_)) => Ok(v),
            Value::Lambda(_)
            | Value::Function(_)
            | Value::Tuple(_)
            | Value::Cell(_)
            | Value::Struct(_)
            | Value::StructArray(_)
            | Value::DateTime(_)
            | Value::Duration(_)
            | Value::Map(_) => Err("Transpose is not applicable to this type".to_string()),
        },
        Expr::StrLiteral(s) => Ok(Value::Str(s.clone())),
        Expr::StringObjLiteral(s) => Ok(Value::StringObj(s.clone())),
        Expr::Range(start_expr, step_expr, stop_expr) => {
            // Evaluation order is start, stop, then step.
            let start = match eval_inner(start_expr, env, io.as_deref_mut())? {
                Value::Scalar(n) => n,
                _ => return Err("Range bounds must be real scalars".to_string()),
            };
            let stop = match eval_inner(stop_expr, env, io.as_deref_mut())? {
                Value::Scalar(n) => n,
                _ => return Err("Range bounds must be real scalars".to_string()),
            };
            let step = match step_expr {
                None => 1.0,
                Some(s) => match eval_inner(s, env, io)? {
                    Value::Scalar(n) => n,
                    _ => return Err("Range step must be a real scalar".to_string()),
                },
            };
            make_colon_range(start, step, stop)
        }
    }
}

/// Runs a user-defined function value through the installed `FN_CALL_HOOK`.
///
/// When `io` is `None` the hook receives a fresh `IoContext`. Returns `None`
/// if no hook is installed, so each caller can report its own error text.
fn dispatch_fn_call_hook(
    name: &str,
    func: &Value,
    args: &[Value],
    env: &Env,
    io: Option<&mut IoContext>,
) -> Option<Result<Value, String>> {
    let hook = FN_CALL_HOOK.with(|c| c.get())?;
    Some(match io {
        Some(io_ref) => hook(name, func, args, env, io_ref),
        None => hook(name, func, args, env, &mut IoContext::new()),
    })
}

/// Returns the autoloaded value for `name`, invoking `AUTOLOAD_HOOK` on a cache miss.
///
/// A hook that reports `true` is trusted to have filled `AUTOLOAD_CACHE`.
/// When `remember_misses` is set, names the hook could not load are recorded
/// in `AUTOLOAD_MISS_CACHE` and are not retried until that cache is cleared.
fn lookup_autoload_cached(name: &str, remember_misses: bool) -> Option<Value> {
    if let Some(hit) = AUTOLOAD_CACHE.with(|c| c.borrow().get(name).cloned()) {
        return Some(hit);
    }
    if remember_misses && AUTOLOAD_MISS_CACHE.with(|c| c.borrow().contains(name)) {
        return None;
    }
    let loaded = AUTOLOAD_HOOK
        .with(|c| c.get())
        .is_some_and(|hook| hook(name));
    if loaded {
        AUTOLOAD_CACHE.with(|c| c.borrow().get(name).cloned())
    } else {
        if remember_misses {
            AUTOLOAD_MISS_CACHE.with(|c| c.borrow_mut().insert(name.to_string()));
        }
        None
    }
}

/// Builds a value from an evaluated matrix literal (rows of element values).
///
/// The result is complex if any element is complex. Otherwise its kind
/// follows the first element that is not an empty numeric matrix, so `[]`
/// never dictates the result type (e.g. `[[], 'ab']` is a char array).
/// Empty numeric matrices contribute nothing to the concatenation.
fn assemble_matrix_literal(evaluated: &[Vec<Value>]) -> Result<Value, String> {
    let Some(first) = evaluated.iter().flatten().next() else {
        return Ok(Value::Matrix(Box::new(Array2::<f64>::zeros((0, 0)))));
    };
    let has_complex = evaluated
        .iter()
        .flat_map(|row| row.iter())
        .any(|v| matches!(v, Value::Complex(_, _) | Value::ComplexMatrix(_)));
    enum MatKind {
        ComplexNumeric,
        Numeric,
        DateTime,
        Duration,
        Str,
    }
    let kind = if has_complex {
        MatKind::ComplexNumeric
    } else {
        let lead = evaluated
            .iter()
            .flat_map(|row| row.iter())
            .find(|v| !matches!(v, Value::Matrix(m) if m.is_empty()))
            .unwrap_or(first);
        match lead {
            Value::Scalar(_) | Value::Matrix(_) => MatKind::Numeric,
            Value::DateTime(_) | Value::DateTimeArray(_) => MatKind::DateTime,
            Value::Duration(_) | Value::DurationArray(_) => MatKind::Duration,
            Value::Str(_) | Value::StringObj(_) => MatKind::Str,
            Value::Void => {
                return Err("Void value cannot be used in matrix literal".to_string());
            }
            Value::Lambda(_)
            | Value::Function(_)
            | Value::Tuple(_)
            | Value::Cell(_)
            | Value::Struct(_)
            | Value::StructArray(_)
            | Value::Map(_) => {
                return Err("This type cannot be used in matrix literals".to_string());
            }
            Value::Complex(_, _) | Value::ComplexMatrix(_) => unreachable!(),
        }
    };
    match kind {
        MatKind::ComplexNumeric => {
            let mut grid: Vec<Vec<Array2<Complex<f64>>>> = Vec::with_capacity(evaluated.len());
            for ev_row in evaluated {
                let mut blocks: Vec<Array2<Complex<f64>>> = Vec::with_capacity(ev_row.len());
                for val in ev_row {
                    let block: Array2<Complex<f64>> = match val {
                        Value::Scalar(n) => Array2::from_elem((1, 1), Complex::new(*n, 0.0)),
                        Value::Complex(re, im) => {
                            Array2::from_elem((1, 1), Complex::new(*re, *im))
                        }
                        Value::Matrix(m) => cm_from_real(m),
                        Value::ComplexMatrix(m) => (**m).clone(),
                        _ => {
                            return Err(
                                "This type cannot be used in a complex matrix literal".to_string()
                            );
                        }
                    };
                    blocks.push(block);
                }
                grid.push(blocks);
            }
            Ok(Value::ComplexMatrix(Box::new(concat_block_grid(grid)?)))
        }
        MatKind::DateTime => {
            let mut ts: Vec<f64> = Vec::new();
            for val in evaluated.iter().flatten() {
                match val {
                    Value::DateTime(t) => ts.push(*t),
                    Value::DateTimeArray(v) => ts.extend_from_slice(v),
                    // `[]` contributes nothing to a concatenation.
                    Value::Matrix(m) if m.is_empty() => {}
                    _ => {
                        return Err(
                            "Matrix literal: cannot mix datetime with other types".to_string()
                        );
                    }
                }
            }
            Ok(Value::DateTimeArray(ts))
        }
        MatKind::Duration => {
            let mut sv: Vec<f64> = Vec::new();
            for val in evaluated.iter().flatten() {
                match val {
                    Value::Duration(s) => sv.push(*s),
                    Value::DurationArray(v) => sv.extend_from_slice(v),
                    // `[]` contributes nothing to a concatenation.
                    Value::Matrix(m) if m.is_empty() => {}
                    _ => {
                        return Err(
                            "Matrix literal: cannot mix duration with other types".to_string()
                        );
                    }
                }
            }
            Ok(Value::DurationArray(sv))
        }
        MatKind::Numeric => {
            let mut grid: Vec<Vec<Array2<f64>>> = Vec::with_capacity(evaluated.len());
            for ev_row in evaluated {
                let mut blocks: Vec<Array2<f64>> = Vec::with_capacity(ev_row.len());
                for val in ev_row {
                    let block = match val {
                        Value::Scalar(n) => Array2::from_elem((1, 1), *n),
                        Value::Matrix(m) => (**m).clone(),
                        Value::Void => {
                            return Err("Void value cannot be used in matrix literal".to_string());
                        }
                        // Text inside a numeric literal contributes its character codes.
                        Value::Str(s) | Value::StringObj(s) => {
                            let codes: Vec<f64> = s.chars().map(|c| c as u32 as f64).collect();
                            if codes.is_empty() {
                                Array2::<f64>::zeros((1, 0))
                            } else {
                                Array2::from_shape_vec((1, codes.len()), codes)
                                    .map_err(|e| format!("Matrix shape error: {e}"))?
                            }
                        }
                        _ => {
                            return Err("This type cannot be used in matrix literals".to_string());
                        }
                    };
                    blocks.push(block);
                }
                grid.push(blocks);
            }
            Ok(Value::Matrix(Box::new(concat_block_grid(grid)?)))
        }
        MatKind::Str => {
            if evaluated.len() > 1 {
                return Err("Multi-row char-array literals are not supported".to_string());
            }
            let mut out = String::new();
            for val in &evaluated[0] {
                match val {
                    Value::Str(s) | Value::StringObj(s) => out.push_str(s),
                    Value::Scalar(n) => out.push(code_point_to_char(*n)?),
                    Value::Matrix(m) => {
                        for &n in m.iter() {
                            out.push(code_point_to_char(n)?);
                        }
                    }
                    _ => {
                        return Err("This type cannot be used in a char-array literal".to_string());
                    }
                }
            }
            Ok(Value::Str(out))
        }
    }
}

/// Concatenates a grid of blocks into one array.
///
/// Each inner `Vec` is joined horizontally (its blocks must share a row
/// count). The resulting rows are then stacked vertically (they must share
/// a column count). At each level, blocks with zero elements are skipped as
/// long as some non-empty block exists, so `x = []; [x, 5]` and `[[]; 5]`
/// both give `5`. If everything at a level is empty, the shapes are combined
/// unchanged.
fn concat_block_grid<T: Copy>(grid: Vec<Vec<Array2<T>>>) -> Result<Array2<T>, String> {
    let mut row_blocks: Vec<Array2<T>> = Vec::with_capacity(grid.len());
    for row in grid {
        let elems = nonempty_blocks_with_pos(row);
        let Some((_, first)) = elems.first() else {
            continue;
        };
        let nrows = first.nrows();
        for (pos, m) in elems.iter().skip(1) {
            if m.nrows() != nrows {
                return Err(format!(
                    "Matrix row height mismatch: expected {} rows, element {} has {} rows",
                    nrows,
                    pos + 1,
                    m.nrows()
                ));
            }
        }
        let ncols: usize = elems.iter().map(|(_, m)| m.ncols()).sum();
        let mut flat: Vec<T> = Vec::with_capacity(nrows * ncols);
        for r in 0..nrows {
            for (_, m) in &elems {
                flat.extend(m.row(r).iter().copied());
            }
        }
        row_blocks.push(
            Array2::from_shape_vec((nrows, ncols), flat)
                .map_err(|e| format!("Matrix shape error: {e}"))?,
        );
    }
    let rows = nonempty_blocks_with_pos(row_blocks);
    let Some((_, first)) = rows.first() else {
        return Array2::from_shape_vec((0, 0), Vec::new())
            .map_err(|e| format!("Matrix shape error: {e}"));
    };
    let ncols = first.ncols();
    if ncols == 0 {
        // Only reachable when every row is empty: keep the total row count.
        let total_rows: usize = rows.iter().map(|(_, b)| b.nrows()).sum();
        return Array2::from_shape_vec((total_rows, 0), Vec::new())
            .map_err(|e| format!("Matrix shape error: {e}"));
    }
    for (pos, blk) in rows.iter().skip(1) {
        if blk.ncols() != ncols {
            return Err(format!(
                "Matrix column count mismatch: expected {} columns, row {} has {} columns",
                ncols,
                pos + 1,
                blk.ncols()
            ));
        }
    }
    let total_rows: usize = rows.iter().map(|(_, b)| b.nrows()).sum();
    let mut flat: Vec<T> = Vec::with_capacity(total_rows * ncols);
    for (_, blk) in &rows {
        flat.extend(blk.iter().copied());
    }
    Array2::from_shape_vec((total_rows, ncols), flat)
        .map_err(|e| format!("Matrix shape error: {e}"))
}

/// Pairs each block with its 0-based position, dropping zero-element blocks
/// unless every block is empty (in which case all are kept).
fn nonempty_blocks_with_pos<T>(blocks: Vec<Array2<T>>) -> Vec<(usize, Array2<T>)> {
    let all_empty = blocks.iter().all(|b| b.is_empty());
    blocks
        .into_iter()
        .enumerate()
        .filter(|(_, b)| all_empty || !b.is_empty())
        .collect()
}

/// Converts a numeric character code, rounded to the nearest integer, to a `char`.
///
/// # Errors
/// Negative, NaN, out-of-range and non-scalar-value codes (e.g. surrogates)
/// are rejected.
fn code_point_to_char(n: f64) -> Result<char, String> {
    let code = n.round();
    // A bare `as u32` would saturate negative/NaN codes to 0 (NUL) instead of failing.
    if !(0.0..=f64::from(u32::MAX)).contains(&code) {
        return Err(format!("char: invalid code {n}"));
    }
    char::from_u32(code as u32).ok_or_else(|| format!("char: invalid code {n}"))
}

/// Builds the row vector `start:step:stop`.
///
/// An absolute tolerance of `1e-10` on the element count absorbs rounding in
/// `(stop - start) / step`. A NaN operand yields `NaN`. An infinite element
/// count, or one too large to allocate, is reported as an error rather than
/// overflowing the `usize` conversion.
fn make_colon_range(start: f64, step: f64, stop: f64) -> Result<Value, String> {
    if step == 0.0 {
        return Err("Range step cannot be zero".to_string());
    }
    if start.is_nan() || step.is_nan() || stop.is_nan() {
        return Ok(Value::Matrix(Box::new(Array2::from_elem((1, 1), f64::NAN))));
    }
    let n_float = (stop - start) / step;
    if n_float < -1e-10 {
        return Ok(Value::Matrix(Box::new(Array2::zeros((1, 0)))));
    }
    // Beyond this a Vec<f64> cannot be allocated; an infinite count would also
    // saturate the `as usize` cast below and overflow the `+ 1`.
    let max_len = (isize::MAX as usize / std::mem::size_of::<f64>()) as f64;
    if !n_float.is_finite() || n_float + 1.0 >= max_len {
        return Err("Range has too many elements".to_string());
    }
    let n = (n_float + 1e-10).floor() as usize + 1;
    let vals: Vec<f64> = (0..n).map(|i| start + i as f64 * step).collect();
    let m = Array2::from_shape_vec((1, n), vals).map_err(|e| format!("Range error: {e}"))?;
    Ok(Value::Matrix(Box::new(m)))
}
/// Applies the binary operator `op` to two already-evaluated operands.
///
/// Dispatch order matters:
/// 1. void operands are rejected;
/// 2. string objects support `+`, `==` and `~=` among themselves;
/// 3. char arrays (`Value::Str`) are converted to their character codes and
///    the operation is re-dispatched numerically;
/// 4. non-numeric containers (maps, functions, tuples, cells, structs) are rejected;
/// 5. datetime / duration arithmetic, then complex arithmetic, then real
///    scalar / matrix arithmetic.
///
/// # Errors
/// Returns a message when the operand types or shapes do not support `op`.
fn eval_binop(l: Value, op: &Op, r: Value) -> Result<Value, String> {
    match (l, r) {
        (Value::Void, _) | (_, Value::Void) => {
            Err("Cannot apply operator to void value".to_string())
        }
        (Value::StringObj(a), Value::StringObj(b)) => match op {
            Op::Add => Ok(Value::StringObj(a + &b)),
            Op::Eq => Ok(Value::Scalar(bool_to_f64(a == b))),
            Op::NotEq => Ok(Value::Scalar(bool_to_f64(a != b))),
            _ => Err("Operator not supported on string objects".to_string()),
        },
        // Char arrays behave numerically: operate on their character codes.
        (Value::Str(s), r) => eval_binop(str_to_numeric(&s), op, r),
        (l, Value::Str(s)) => eval_binop(l, op, str_to_numeric(&s)),
        (Value::StringObj(_), _) | (_, Value::StringObj(_)) => {
            Err("String object cannot be combined with non-string values".to_string())
        }
        (Value::Map(_), _) | (_, Value::Map(_)) => {
            Err("Cannot apply operator to a Map value".to_string())
        }
        // Previously these shared the Map message above, which misreported the type.
        (Value::Lambda(_), _)
        | (_, Value::Lambda(_))
        | (Value::Function(_), _)
        | (_, Value::Function(_))
        | (Value::Tuple(_), _)
        | (_, Value::Tuple(_))
        | (Value::Cell(_), _)
        | (_, Value::Cell(_))
        | (Value::Struct(_), _)
        | (_, Value::Struct(_))
        | (Value::StructArray(_), _)
        | (_, Value::StructArray(_)) => Err(
            "Cannot apply operator to a function, tuple, cell, or struct value".to_string(),
        ),
        (Value::DateTime(t), Value::Duration(d)) => match op {
            Op::Add => Ok(Value::DateTime(t + d)),
            Op::Sub => Ok(Value::DateTime(t - d)),
            _ => Err("Unsupported operator between datetime and duration".to_string()),
        },
        (Value::Duration(d), Value::DateTime(t)) => match op {
            Op::Add => Ok(Value::DateTime(t + d)),
            _ => Err("Unsupported operator between duration and datetime".to_string()),
        },
        (Value::DateTime(t1), Value::DateTime(t2)) => match op {
            Op::Sub => Ok(Value::Duration(t1 - t2)),
            Op::Eq => Ok(Value::Scalar(bool_to_f64(
                (t1 - t2).abs() < 1e-9 || (t1.is_nan() && t2.is_nan()),
            ))),
            Op::NotEq => Ok(Value::Scalar(bool_to_f64(
                (t1 - t2).abs() >= 1e-9 && !(t1.is_nan() && t2.is_nan()),
            ))),
            Op::Lt => Ok(Value::Scalar(bool_to_f64(t1 < t2))),
            Op::Gt => Ok(Value::Scalar(bool_to_f64(t1 > t2))),
            Op::LtEq => Ok(Value::Scalar(bool_to_f64(t1 <= t2))),
            Op::GtEq => Ok(Value::Scalar(bool_to_f64(t1 >= t2))),
            _ => Err("Unsupported operator between two datetimes".to_string()),
        },
        (Value::Duration(d1), Value::Duration(d2)) => match op {
            Op::Add => Ok(Value::Duration(d1 + d2)),
            Op::Sub => Ok(Value::Duration(d1 - d2)),
            Op::Div | Op::ElemDiv => Ok(Value::Scalar(d1 / d2)),
            Op::Eq => Ok(Value::Scalar(bool_to_f64((d1 - d2).abs() < 1e-9))),
            Op::NotEq => Ok(Value::Scalar(bool_to_f64((d1 - d2).abs() >= 1e-9))),
            Op::Lt => Ok(Value::Scalar(bool_to_f64(d1 < d2))),
            Op::Gt => Ok(Value::Scalar(bool_to_f64(d1 > d2))),
            Op::LtEq => Ok(Value::Scalar(bool_to_f64(d1 <= d2))),
            Op::GtEq => Ok(Value::Scalar(bool_to_f64(d1 >= d2))),
            _ => Err("Unsupported operator between two durations".to_string()),
        },
        (Value::Duration(d), Value::Scalar(s)) => match op {
            Op::Mul | Op::ElemMul => Ok(Value::Duration(d * s)),
            Op::Div | Op::ElemDiv => Ok(Value::Duration(d / s)),
            _ => Err("Unsupported operator between duration and scalar".to_string()),
        },
        (Value::Scalar(s), Value::Duration(d)) => match op {
            Op::Mul | Op::ElemMul => Ok(Value::Duration(s * d)),
            _ => Err("Unsupported operator between scalar and duration".to_string()),
        },
        (Value::DateTime(t), Value::DurationArray(dv)) => match op {
            Op::Add => Ok(Value::DateTimeArray(dv.iter().map(|d| t + d).collect())),
            Op::Sub => Ok(Value::DateTimeArray(dv.iter().map(|d| t - d).collect())),
            _ => Err("Unsupported operator between datetime and duration array".to_string()),
        },
        (Value::DurationArray(dv), Value::DateTime(t)) => match op {
            Op::Add => Ok(Value::DateTimeArray(dv.iter().map(|d| t + d).collect())),
            _ => Err("Unsupported operator between duration array and datetime".to_string()),
        },
        (Value::DateTimeArray(tv), Value::Duration(d)) => match op {
            Op::Add => Ok(Value::DateTimeArray(tv.iter().map(|t| t + d).collect())),
            Op::Sub => Ok(Value::DateTimeArray(tv.iter().map(|t| t - d).collect())),
            _ => Err("Unsupported operator between datetime array and duration".to_string()),
        },
        (Value::DateTimeArray(tv), Value::DurationArray(dv)) => match op {
            Op::Add if tv.len() == dv.len() => Ok(Value::DateTimeArray(
                tv.iter().zip(&dv).map(|(t, d)| t + d).collect(),
            )),
            Op::Sub if tv.len() == dv.len() => Ok(Value::DateTimeArray(
                tv.iter().zip(&dv).map(|(t, d)| t - d).collect(),
            )),
            _ => Err("Unsupported or mismatched datetime/duration array operation".to_string()),
        },
        (Value::DateTimeArray(tv1), Value::DateTimeArray(tv2)) => match op {
            Op::Sub if tv1.len() == tv2.len() => Ok(Value::DurationArray(
                tv1.iter().zip(&tv2).map(|(a, b)| a - b).collect(),
            )),
            _ => Err("Unsupported operator between two datetime arrays".to_string()),
        },
        (Value::DurationArray(dv), Value::Scalar(s)) => match op {
            Op::Mul | Op::ElemMul => Ok(Value::DurationArray(dv.iter().map(|d| d * s).collect())),
            Op::Div | Op::ElemDiv => Ok(Value::DurationArray(dv.iter().map(|d| d / s).collect())),
            _ => Err("Unsupported operator between duration array and scalar".to_string()),
        },
        (Value::Scalar(s), Value::DurationArray(dv)) => match op {
            Op::Mul | Op::ElemMul => Ok(Value::DurationArray(dv.iter().map(|d| s * d).collect())),
            _ => Err("Unsupported operator between scalar and duration array".to_string()),
        },
        (Value::DateTime(_), _)
        | (_, Value::DateTime(_))
        | (Value::Duration(_), _)
        | (_, Value::Duration(_))
        | (Value::DateTimeArray(_), _)
        | (_, Value::DateTimeArray(_))
        | (Value::DurationArray(_), _)
        | (_, Value::DurationArray(_)) => {
            Err("Unsupported operation on datetime or duration value".to_string())
        }
        (Value::Complex(re1, im1), Value::Complex(re2, im2)) => {
            complex_binop(re1, im1, op, re2, im2)
        }
        (Value::Complex(re, im), Value::Scalar(s)) => complex_binop(re, im, op, s, 0.0),
        (Value::Scalar(s), Value::Complex(re, im)) => complex_binop(s, 0.0, op, re, im),
        (Value::Complex(re, im), Value::Matrix(m)) => {
            complex_binop_cm(re, im, op, cm_from_real(&m))
        }
        (Value::Matrix(m), Value::Complex(re, im)) => {
            cm_binop_complex(cm_from_real(&m), op, re, im)
        }
        (Value::ComplexMatrix(a), Value::ComplexMatrix(b)) => complex_matrix_binop(*a, op, *b),
        (Value::ComplexMatrix(cm), Value::Matrix(m)) => {
            complex_matrix_binop(*cm, op, cm_from_real(&m))
        }
        (Value::Matrix(m), Value::ComplexMatrix(cm)) => {
            complex_matrix_binop(cm_from_real(&m), op, *cm)
        }
        (Value::ComplexMatrix(cm), Value::Scalar(s)) => cm_binop_scalar(*cm, op, s),
        (Value::Scalar(s), Value::ComplexMatrix(cm)) => scalar_binop_cm(s, op, *cm),
        (Value::ComplexMatrix(cm), Value::Complex(re, im)) => cm_binop_complex(*cm, op, re, im),
        (Value::Complex(re, im), Value::ComplexMatrix(cm)) => complex_binop_cm(re, im, op, *cm),
        (Value::Scalar(lv), Value::Scalar(rv)) => {
            let result = match op {
                Op::Add => lv + rv,
                Op::Sub => lv - rv,
                Op::Mul | Op::ElemMul => lv * rv,
                Op::Div | Op::ElemDiv => lv / rv,
                Op::LDiv => rv / lv,
                Op::Pow | Op::ElemPow => lv.powf(rv),
                Op::Eq => bool_to_f64(lv == rv),
                Op::NotEq => bool_to_f64(lv != rv),
                Op::Lt => bool_to_f64(lv < rv),
                Op::Gt => bool_to_f64(lv > rv),
                Op::LtEq => bool_to_f64(lv <= rv),
                Op::GtEq => bool_to_f64(lv >= rv),
                Op::And | Op::ElemAnd => bool_to_f64(lv != 0.0 && rv != 0.0),
                Op::Or | Op::ElemOr => bool_to_f64(lv != 0.0 || rv != 0.0),
            };
            Ok(Value::Scalar(result))
        }
        (Value::Matrix(lm), Value::Matrix(rm)) => match op {
            Op::Add => {
                check_same_shape(&lm, &rm)?;
                Ok(Value::Matrix(Box::new(&*lm + &*rm)))
            }
            Op::Sub => {
                check_same_shape(&lm, &rm)?;
                Ok(Value::Matrix(Box::new(&*lm - &*rm)))
            }
            Op::Mul => {
                if lm.ncols() != rm.nrows() {
                    return Err(format!(
                        "Inner dimensions must agree: {}x{} * {}x{}",
                        lm.nrows(),
                        lm.ncols(),
                        rm.nrows(),
                        rm.ncols()
                    ));
                }
                Ok(Value::Matrix(Box::new(lm.dot(&*rm))))
            }
            Op::ElemMul => {
                check_same_shape(&lm, &rm)?;
                Ok(Value::Matrix(Box::new(&*lm * &*rm)))
            }
            Op::ElemDiv => {
                check_same_shape(&lm, &rm)?;
                Ok(Value::Matrix(Box::new(&*lm / &*rm)))
            }
            Op::ElemPow => {
                check_same_shape(&lm, &rm)?;
                Ok(Value::Matrix(Box::new(
                    ndarray::Zip::from(&*lm)
                        .and(&*rm)
                        .map_collect(|a, b| a.powf(*b)),
                )))
            }
            Op::Eq
            | Op::NotEq
            | Op::Lt
            | Op::Gt
            | Op::LtEq
            | Op::GtEq
            | Op::And
            | Op::Or
            | Op::ElemAnd
            | Op::ElemOr => {
                check_same_shape(&lm, &rm)?;
                Ok(Value::Matrix(Box::new(
                    ndarray::Zip::from(&*lm)
                        .and(&*rm)
                        .map_collect(|a, b| bool_to_f64(cmp_op(op, *a, *b))),
                )))
            }
            Op::Div => Err("Matrix / Matrix: use inv(B)*A or A*inv(B)".to_string()),
            Op::LDiv => Ok(Value::Matrix(Box::new(solve_linear(&lm, &rm)?))),
            Op::Pow => Err("Matrix ^ Matrix: not supported".to_string()),
        },
        (Value::Scalar(s), Value::Matrix(m)) => match op {
            Op::Add => Ok(Value::Matrix(Box::new(s + &*m))),
            Op::Sub => Ok(Value::Matrix(Box::new(m.mapv(|x| s - x)))),
            Op::Mul | Op::ElemMul => Ok(Value::Matrix(Box::new(s * &*m))),
            Op::Div => Err("Scalar / Matrix: not supported".to_string()),
            // Element-wise division is well-defined: divide `s` by every element.
            Op::ElemDiv => Ok(Value::Matrix(Box::new(m.mapv(|x| s / x)))),
            Op::LDiv => {
                if s == 0.0 {
                    return Err("Left division by zero (a \\ B requires a ≠ 0)".to_string());
                }
                Ok(Value::Matrix(Box::new(m.mapv(|x| x / s))))
            }
            Op::Pow | Op::ElemPow => Ok(Value::Matrix(Box::new(m.mapv(|x| s.powf(x))))),
            Op::Eq
            | Op::NotEq
            | Op::Lt
            | Op::Gt
            | Op::LtEq
            | Op::GtEq
            | Op::And
            | Op::Or
            | Op::ElemAnd
            | Op::ElemOr => Ok(Value::Matrix(Box::new(
                m.mapv(|x| bool_to_f64(cmp_op(op, s, x))),
            ))),
        },
        (Value::Matrix(m), Value::Scalar(s)) => match op {
            Op::Add => Ok(Value::Matrix(Box::new(&*m + s))),
            Op::Sub => Ok(Value::Matrix(Box::new(&*m - s))),
            Op::Mul | Op::ElemMul => Ok(Value::Matrix(Box::new(&*m * s))),
            Op::Div | Op::ElemDiv => Ok(Value::Matrix(Box::new(m.mapv(|x| x / s)))),
            Op::LDiv => {
                let b = Array2::from_elem((m.nrows(), 1), s);
                Ok(Value::Matrix(Box::new(solve_linear(&m, &b)?)))
            }
            Op::Pow | Op::ElemPow => Ok(Value::Matrix(Box::new(m.mapv(|x| x.powf(s))))),
            Op::Eq
            | Op::NotEq
            | Op::Lt
            | Op::Gt
            | Op::LtEq
            | Op::GtEq
            | Op::And
            | Op::Or
            | Op::ElemAnd
            | Op::ElemOr => Ok(Value::Matrix(Box::new(
                m.mapv(|x| bool_to_f64(cmp_op(op, x, s))),
            ))),
        },
    }
}
/// Encodes a boolean as the interpreter's numeric truth value (1.0 / 0.0).
#[inline]
fn bool_to_f64(b: bool) -> f64 {
    f64::from(u8::from(b))
}
/// Evaluates a comparison or logical operator on two real numbers.
///
/// Logical operators treat any non-zero value (including NaN) as true.
///
/// # Panics
/// Panics if `op` is an arithmetic operator; callers only pass comparison or
/// logical operators here.
fn cmp_op(op: &Op, a: f64, b: f64) -> bool {
    let truthy = |x: f64| x != 0.0;
    match op {
        Op::And | Op::ElemAnd => truthy(a) && truthy(b),
        Op::Or | Op::ElemOr => truthy(a) || truthy(b),
        Op::Eq => a == b,
        Op::NotEq => a != b,
        Op::Lt => a < b,
        Op::Gt => a > b,
        Op::LtEq => a <= b,
        Op::GtEq => a >= b,
        _ => unreachable!(),
    }
}
/// Applies `op` to two complex scalars `re1 + im1·i` and `re2 + im2·i`.
///
/// Results whose imaginary part is exactly zero collapse to `Value::Scalar`
/// via `make_complex`. Integer powers use exact binary exponentiation; other
/// powers use the principal branch `exp(z2 · ln z1)`.
///
/// # Errors
/// Ordering comparisons and left division are rejected for complex operands.
fn complex_binop(re1: f64, im1: f64, op: &Op, re2: f64, im2: f64) -> Result<Value, String> {
    match op {
        Op::Add => Ok(make_complex(re1 + re2, im1 + im2)),
        Op::Sub => Ok(make_complex(re1 - re2, im1 - im2)),
        Op::Mul | Op::ElemMul => {
            Ok(make_complex(re1 * re2 - im1 * im2, re1 * im2 + im1 * re2))
        }
        Op::Div | Op::ElemDiv => {
            let denom = re2 * re2 + im2 * im2;
            if denom == 0.0 {
                // Division by complex zero: propagate IEEE inf/NaN per component.
                return Ok(make_complex(re1 / 0.0_f64, im1 / 0.0_f64));
            }
            Ok(make_complex(
                (re1 * re2 + im1 * im2) / denom,
                (im1 * re2 - re1 * im2) / denom,
            ))
        }
        Op::Pow | Op::ElemPow => {
            let r1 = (re1 * re1 + im1 * im1).sqrt();
            if r1 == 0.0 {
                // Zero base: agree with `f64::powf` (used by Scalar ^ Scalar)
                // for real exponents, i.e. 0^0 = 1 and 0^(negative) = Inf.
                if re2 == 0.0 && im2 == 0.0 {
                    return Ok(Value::Scalar(1.0));
                }
                if re2 > 0.0 {
                    return Ok(Value::Scalar(0.0));
                }
                if im2 == 0.0 && re2 < 0.0 {
                    return Ok(Value::Scalar(f64::INFINITY));
                }
                return Ok(Value::Complex(f64::NAN, f64::NAN));
            }
            if im2 == 0.0 && re2.fract() == 0.0 && re2.abs() < 1_000_000.0 {
                // Exact integer power by repeated squaring (avoids log/exp round-off).
                let n = re2 as i64;
                if n == 0 {
                    return Ok(Value::Scalar(1.0));
                }
                let (mut rr, mut ri) = (1.0_f64, 0.0_f64);
                let (mut br, mut bi) = (re1, im1);
                let mut exp = n.unsigned_abs();
                while exp > 0 {
                    if exp & 1 == 1 {
                        let nr = rr * br - ri * bi;
                        let ni = rr * bi + ri * br;
                        rr = nr;
                        ri = ni;
                    }
                    let nr = br * br - bi * bi;
                    let ni = 2.0 * br * bi;
                    br = nr;
                    bi = ni;
                    exp >>= 1;
                }
                if n < 0 {
                    let denom = rr * rr + ri * ri;
                    return Ok(make_complex(rr / denom, -ri / denom));
                }
                return Ok(make_complex(rr, ri));
            }
            // General case: z1^z2 = exp(z2 * ln z1) on the principal branch.
            let theta1 = im1.atan2(re1);
            let ln_r1 = r1.ln();
            let exp_re = re2 * ln_r1 - im2 * theta1;
            let exp_im = im2 * ln_r1 + re2 * theta1;
            let mag = exp_re.exp();
            Ok(make_complex(mag * exp_im.cos(), mag * exp_im.sin()))
        }
        Op::Eq => Ok(Value::Scalar(bool_to_f64(re1 == re2 && im1 == im2))),
        Op::NotEq => Ok(Value::Scalar(bool_to_f64(re1 != re2 || im1 != im2))),
        Op::Lt | Op::Gt | Op::LtEq | Op::GtEq => {
            Err("Ordering is not defined for complex numbers".to_string())
        }
        Op::And | Op::ElemAnd => Ok(Value::Scalar(bool_to_f64(
            (re1 != 0.0 || im1 != 0.0) && (re2 != 0.0 || im2 != 0.0),
        ))),
        Op::Or | Op::ElemOr => Ok(Value::Scalar(bool_to_f64(
            re1 != 0.0 || im1 != 0.0 || re2 != 0.0 || im2 != 0.0,
        ))),
        Op::LDiv => Err("Left division (\\) is not supported for complex numbers".to_string()),
    }
}
/// Builds a numeric value from real and imaginary parts, demoting it to a
/// real `Value::Scalar` when the imaginary part is exactly zero (±0.0).
/// A NaN imaginary part keeps the value complex.
#[inline]
fn make_complex(re: f64, im: f64) -> Value {
    if im != 0.0 {
        return Value::Complex(re, im);
    }
    Value::Scalar(re)
}
/// Lifts a real matrix into a complex matrix with all-zero imaginary parts.
#[inline]
fn cm_from_real(m: &Array2<f64>) -> Array2<Complex<f64>> {
    m.mapv(Complex::from)
}
/// Applies `op` to two complex matrices.
///
/// Element-wise arithmetic, equality and logical operators require identical
/// shapes; `*` is a matrix product and requires `a.ncols() == b.nrows()`.
/// Comparison and logical operators produce a real 0/1 matrix.
///
/// # Errors
/// Returns a message on shape mismatch, and for `^`, `/`, `\` and ordering
/// comparisons, which are not supported for complex matrices.
fn complex_matrix_binop(
    a: Array2<Complex<f64>>,
    op: &Op,
    b: Array2<Complex<f64>>,
) -> Result<Value, String> {
    /// Builds a real 0/1 matrix by testing corresponding elements of `lhs` and `rhs`.
    fn mask(
        lhs: &Array2<Complex<f64>>,
        rhs: &Array2<Complex<f64>>,
        test: impl Fn(&Complex<f64>, &Complex<f64>) -> bool,
    ) -> Result<Value, String> {
        Ok(Value::Matrix(Box::new(
            ndarray::Zip::from(lhs)
                .and(rhs)
                .map_collect(|x, y| bool_to_f64(test(x, y))),
        )))
    }

    let needs_same_shape = matches!(
        op,
        Op::Add
            | Op::Sub
            | Op::ElemMul
            | Op::ElemDiv
            | Op::ElemPow
            | Op::Eq
            | Op::NotEq
            | Op::And
            | Op::ElemAnd
            | Op::Or
            | Op::ElemOr
    );
    if needs_same_shape && a.shape() != b.shape() {
        return Err(format!(
            "Matrix dimensions must agree: {}×{} vs {}×{}",
            a.nrows(),
            a.ncols(),
            b.nrows(),
            b.ncols()
        ));
    }

    let nonzero = |z: &Complex<f64>| z.re != 0.0 || z.im != 0.0;
    match op {
        Op::Add => Ok(Value::ComplexMatrix(Box::new(a + b))),
        Op::Sub => Ok(Value::ComplexMatrix(Box::new(a - b))),
        Op::ElemMul => Ok(Value::ComplexMatrix(Box::new(a * b))),
        Op::ElemDiv => Ok(Value::ComplexMatrix(Box::new(a / b))),
        Op::ElemPow => {
            let powered = ndarray::Zip::from(&a)
                .and(&b)
                .map_collect(|x, y| x.powc(*y));
            Ok(Value::ComplexMatrix(Box::new(powered)))
        }
        Op::Mul if a.ncols() != b.nrows() => Err(format!(
            "Inner dimensions must agree: {}×{} * {}×{}",
            a.nrows(),
            a.ncols(),
            b.nrows(),
            b.ncols()
        )),
        Op::Mul => Ok(Value::ComplexMatrix(Box::new(a.dot(&b)))),
        Op::Eq => mask(&a, &b, |x, y| x == y),
        Op::NotEq => mask(&a, &b, |x, y| x != y),
        Op::And | Op::ElemAnd => mask(&a, &b, |x, y| nonzero(x) && nonzero(y)),
        Op::Or | Op::ElemOr => mask(&a, &b, |x, y| nonzero(x) || nonzero(y)),
        Op::Pow => Err(
            "ComplexMatrix ^ ComplexMatrix: not supported; use .^ for element-wise power"
                .to_string(),
        ),
        Op::Div | Op::LDiv => {
            Err("Complex matrix / and \\ not supported; use inv(A)*B".to_string())
        }
        Op::Lt | Op::Gt | Op::LtEq | Op::GtEq => {
            Err("Ordering comparison not defined for complex matrices".to_string())
        }
    }
}
/// Applies `op` between a complex matrix (left) and a real scalar (right),
/// element by element. Equality tests yield a real 0/1 matrix.
///
/// # Errors
/// Any operator other than `+ - * .* / ./ ^ .^ == ~=` is rejected.
fn cm_binop_scalar(cm: Array2<Complex<f64>>, op: &Op, s: f64) -> Result<Value, String> {
    let rhs = Complex::new(s, 0.0);
    let complex = |m: Array2<Complex<f64>>| -> Result<Value, String> {
        Ok(Value::ComplexMatrix(Box::new(m)))
    };
    let logical = |m: Array2<f64>| -> Result<Value, String> { Ok(Value::Matrix(Box::new(m))) };
    match op {
        Op::Add => complex(cm.mapv(|x| x + rhs)),
        Op::Sub => complex(cm.mapv(|x| x - rhs)),
        Op::Mul | Op::ElemMul => complex(cm.mapv(|x| x * rhs)),
        Op::Div | Op::ElemDiv => complex(cm.mapv(|x| x / rhs)),
        Op::Pow | Op::ElemPow => complex(cm.mapv(|x| x.powf(s))),
        Op::Eq => logical(cm.mapv(|x| bool_to_f64(x == rhs))),
        Op::NotEq => logical(cm.mapv(|x| bool_to_f64(x != rhs))),
        _ => Err("Unsupported operator between complex matrix and scalar".to_string()),
    }
}
/// Applies `op` between a real scalar (left) and a complex matrix (right),
/// element by element. Equality tests yield a real 0/1 matrix.
///
/// `s ./ CM` divides `s` by every element, mirroring `CM ./ s` in
/// `cm_binop_scalar`. Matrix-style `s / CM` remains unsupported.
///
/// # Errors
/// Any operator other than `+ - * .* ./ ^ .^ == ~=` is rejected.
fn scalar_binop_cm(s: f64, op: &Op, cm: Array2<Complex<f64>>) -> Result<Value, String> {
    let c = Complex::new(s, 0.0);
    match op {
        Op::Add => Ok(Value::ComplexMatrix(Box::new(cm.mapv(|x| c + x)))),
        Op::Sub => Ok(Value::ComplexMatrix(Box::new(cm.mapv(|x| c - x)))),
        Op::Mul | Op::ElemMul => Ok(Value::ComplexMatrix(Box::new(cm.mapv(|x| c * x)))),
        Op::ElemDiv => Ok(Value::ComplexMatrix(Box::new(cm.mapv(|x| c / x)))),
        Op::Pow | Op::ElemPow => Ok(Value::ComplexMatrix(Box::new(cm.mapv(|x| c.powc(x))))),
        Op::Eq => Ok(Value::Matrix(Box::new(cm.mapv(|x| bool_to_f64(c == x))))),
        Op::NotEq => Ok(Value::Matrix(Box::new(cm.mapv(|x| bool_to_f64(c != x))))),
        _ => Err("Unsupported operator between scalar and complex matrix".to_string()),
    }
}
/// Applies `op` between a complex matrix (left) and the complex scalar
/// `re + im·i` (right), element by element. Equality tests yield a real 0/1 matrix.
///
/// # Errors
/// Any operator other than `+ - * .* / ./ ^ .^ == ~=` is rejected.
fn cm_binop_complex(cm: Array2<Complex<f64>>, op: &Op, re: f64, im: f64) -> Result<Value, String> {
    let rhs = Complex::new(re, im);
    let complex = |m: Array2<Complex<f64>>| -> Result<Value, String> {
        Ok(Value::ComplexMatrix(Box::new(m)))
    };
    let logical = |m: Array2<f64>| -> Result<Value, String> { Ok(Value::Matrix(Box::new(m))) };
    match op {
        Op::Add => complex(cm.mapv(|x| x + rhs)),
        Op::Sub => complex(cm.mapv(|x| x - rhs)),
        Op::Mul | Op::ElemMul => complex(cm.mapv(|x| x * rhs)),
        Op::Div | Op::ElemDiv => complex(cm.mapv(|x| x / rhs)),
        Op::Pow | Op::ElemPow => complex(cm.mapv(|x| x.powc(rhs))),
        Op::Eq => logical(cm.mapv(|x| bool_to_f64(x == rhs))),
        Op::NotEq => logical(cm.mapv(|x| bool_to_f64(x != rhs))),
        _ => Err("Unsupported operator between complex matrix and complex scalar".to_string()),
    }
}
/// Applies `op` between the complex scalar `re + im·i` (left) and a complex
/// matrix (right), element by element. Equality tests yield a real 0/1 matrix.
///
/// `z ./ CM` divides `z` by every element, mirroring `CM ./ z` in
/// `cm_binop_complex`. Matrix-style `z / CM` remains unsupported.
///
/// # Errors
/// Any operator other than `+ - * .* ./ ^ .^ == ~=` is rejected.
fn complex_binop_cm(re: f64, im: f64, op: &Op, cm: Array2<Complex<f64>>) -> Result<Value, String> {
    let c = Complex::new(re, im);
    match op {
        Op::Add => Ok(Value::ComplexMatrix(Box::new(cm.mapv(|x| c + x)))),
        Op::Sub => Ok(Value::ComplexMatrix(Box::new(cm.mapv(|x| c - x)))),
        Op::Mul | Op::ElemMul => Ok(Value::ComplexMatrix(Box::new(cm.mapv(|x| c * x)))),
        Op::ElemDiv => Ok(Value::ComplexMatrix(Box::new(cm.mapv(|x| c / x)))),
        Op::Pow | Op::ElemPow => Ok(Value::ComplexMatrix(Box::new(cm.mapv(|x| c.powc(x))))),
        Op::Eq => Ok(Value::Matrix(Box::new(cm.mapv(|x| bool_to_f64(c == x))))),
        Op::NotEq => Ok(Value::Matrix(Box::new(cm.mapv(|x| bool_to_f64(c != x))))),
        _ => Err("Unsupported operator between complex scalar and complex matrix".to_string()),
    }
}
/// Converts a char array to its Unicode code points.
///
/// An empty string becomes a 1×0 matrix, a single character becomes a scalar,
/// and anything longer becomes a 1×n row vector.
fn str_to_numeric(s: &str) -> Value {
    let codes: Vec<f64> = s.chars().map(|ch| f64::from(u32::from(ch))).collect();
    if codes.is_empty() {
        return Value::Matrix(Box::new(Array2::zeros((1, 0))));
    }
    if let &[single] = codes.as_slice() {
        return Value::Scalar(single);
    }
    let width = codes.len();
    Value::Matrix(Box::new(Array2::from_shape_vec((1, width), codes).unwrap()))
}
/// Borrows the text of a char array or string object passed as argument `pos`
/// of builtin `fname`.
///
/// # Errors
/// Returns a message naming the function and position for any other value.
fn string_arg<'a>(v: &'a Value, fname: &str, pos: usize) -> Result<&'a str, String> {
    if let Value::Str(text) | Value::StringObj(text) = v {
        Ok(text.as_str())
    } else {
        Err(format!("Function '{fname}' argument {pos} must be a string"))
    }
}
/// Ensures two real matrices have identical dimensions.
///
/// # Errors
/// Returns `"Matrix size mismatch: RxC vs RxC"` when they differ.
fn check_same_shape(lm: &Array2<f64>, rm: &Array2<f64>) -> Result<(), String> {
    let (left_rows, left_cols) = lm.dim();
    let (right_rows, right_cols) = rm.dim();
    if left_rows == right_rows && left_cols == right_cols {
        Ok(())
    } else {
        Err(format!(
            "Matrix size mismatch: {left_rows}x{left_cols} vs {right_rows}x{right_cols}"
        ))
    }
}
/// Extracts a real scalar from argument `pos` of builtin `fname`.
///
/// A complex value with an exactly-zero imaginary part and a one-character
/// char array (its code point) are accepted as scalars.
///
/// # Errors
/// Returns a message describing what kind of value was supplied instead.
fn scalar_arg(v: &Value, fname: &str, pos: usize) -> Result<f64, String> {
    let got = match v {
        Value::Scalar(n) => return Ok(*n),
        Value::Complex(re, im) if *im == 0.0 => return Ok(*re),
        Value::Str(s) if s.chars().count() == 1 => {
            return Ok(s.chars().next().unwrap() as u32 as f64);
        }
        Value::Complex(_, _) => {
            return Err(format!(
                "Function '{fname}' argument {pos} must be real, got a complex number"
            ));
        }
        Value::Void => "void",
        Value::Matrix(_) => "a matrix",
        Value::ComplexMatrix(_) => "a complex matrix",
        Value::Str(_) | Value::StringObj(_) => "a string",
        Value::Lambda(_)
        | Value::Function(_)
        | Value::Tuple(_)
        | Value::Cell(_)
        | Value::Struct(_)
        | Value::StructArray(_)
        | Value::DateTime(_)
        | Value::Duration(_)
        | Value::DateTimeArray(_)
        | Value::DurationArray(_)
        | Value::Map(_) => "a non-numeric value",
    };
    Err(format!(
        "Function '{fname}' argument {pos} must be a scalar, got {got}"
    ))
}
/// Interprets a size argument as `(rows, cols)`.
///
/// A scalar `n` (or a one-element matrix) means `n×n`; a two-element matrix
/// means `[rows cols]`. Values are converted with `as usize`, so fractions
/// truncate and negatives/NaN saturate to 0.
///
/// # Errors
/// Returns a message for any other shape or value type.
fn size_arg(v: &Value, fname: &str) -> Result<(usize, usize), String> {
    let m = match v {
        Value::Scalar(n) => {
            let side = *n as usize;
            return Ok((side, side));
        }
        Value::Matrix(m) => m,
        _ => {
            return Err(format!(
                "{fname}: size argument must be a scalar or a [rows cols] vector"
            ));
        }
    };
    if m.len() == 1 || m.len() == 2 {
        let mut dims = m.iter().map(|&x| x as usize);
        let rows = dims.next().unwrap();
        let cols = dims.next().unwrap_or(rows);
        return Ok((rows, cols));
    }
    Err(format!(
        "{fname}: size argument must be a scalar or a 1×2 vector, \
got a {}×{} matrix",
        m.nrows(),
        m.ncols()
    ))
}
/// Parses the first argument of `randi` into an inclusive `(min, max)` range.
///
/// A scalar `max` means `1..=max`; a two-element matrix means `[min, max]`.
/// Bounds are truncated toward zero by `as i64`.
///
/// # Errors
/// Returns a message when `max < 1`, when `min > max`, or for other shapes.
fn randi_range(v: &Value) -> Result<(i64, i64), String> {
    match v {
        Value::Scalar(max) => {
            let upper = *max as i64;
            if upper >= 1 {
                Ok((1, upper))
            } else {
                Err("randi: max must be a positive integer".to_string())
            }
        }
        Value::Matrix(m) if m.len() == 2 => {
            let bounds: Vec<i64> = m.iter().map(|&x| x as i64).collect();
            let (lower, upper) = (bounds[0], bounds[1]);
            if lower <= upper {
                Ok((lower, upper))
            } else {
                Err("randi: [min, max] range is empty".to_string())
            }
        }
        _ => Err("randi: first argument must be a scalar max or a [min, max] vector".to_string()),
    }
}
/// Flattens a real scalar or matrix into a vector (matrix elements in
/// `ndarray` logical iteration order).
///
/// # Errors
/// Returns `"{fname}: argument must be numeric"` for any other value.
fn numeric_vec(v: &Value, fname: &str) -> Result<Vec<f64>, String> {
    if let Value::Scalar(x) = v {
        return Ok(vec![*x]);
    }
    if let Value::Matrix(grid) = v {
        return Ok(grid.iter().copied().collect());
    }
    Err(format!("{fname}: argument must be numeric"))
}
/// Computes the variance of `vals`.
///
/// `population == true` divides by `n`; otherwise by `n - 1` (sample variance).
/// Returns NaN for an empty slice. A single finite value has variance 0. A
/// single non-finite value (NaN or ±Inf) yields NaN, consistent with how
/// non-finite values propagate through the general formula for longer inputs.
fn stat_var_vec(vals: &[f64], population: bool) -> f64 {
    let n = vals.len();
    match vals {
        [] => return f64::NAN,
        [only] => return if only.is_finite() { 0.0 } else { f64::NAN },
        _ => {}
    }
    let mean = vals.iter().sum::<f64>() / n as f64;
    let ss: f64 = vals.iter().map(|&x| (x - mean).powi(2)).sum();
    let denom = if population { n as f64 } else { (n - 1) as f64 };
    ss / denom
}
/// Applies the statistic `f` to a real scalar or matrix.
///
/// Scalars and row/column vectors reduce to a single scalar; any other matrix
/// is reduced column by column into a 1×ncols row vector.
///
/// # Errors
/// Returns `"{fname}: argument must be numeric"` for non-real-numeric values.
fn apply_stat<F>(v: &Value, mut f: F, fname: &str) -> Result<Value, String>
where
    F: FnMut(&[f64]) -> f64,
{
    let grid = match v {
        Value::Scalar(n) => return Ok(Value::Scalar(f(&[*n]))),
        Value::Matrix(m) => m,
        _ => return Err(format!("{fname}: argument must be numeric")),
    };
    let is_vector = grid.nrows() == 1 || grid.ncols() == 1;
    if is_vector {
        let flat: Vec<f64> = grid.iter().copied().collect();
        return Ok(Value::Scalar(f(&flat)));
    }
    let per_column: Vec<f64> = (0..grid.ncols())
        .map(|c| f(&grid.column(c).to_vec()))
        .collect();
    let width = per_column.len();
    Ok(Value::Matrix(Box::new(
        Array2::from_shape_vec((1, width), per_column).unwrap(),
    )))
}
/// Returns the `p`-th percentile (0–100, clamped) of an ascending slice.
///
/// Element `i` (0-based) is taken to sit at percentile `100·(i + 0.5)/n`;
/// values in between are linearly interpolated, and percentiles beyond the
/// first/last midpoint clamp to the extreme elements. Empty input yields NaN.
fn percentile_sorted(sorted: &[f64], p: f64) -> f64 {
    match sorted {
        [] => f64::NAN,
        [only] => *only,
        _ => {
            let count = sorted.len();
            let last = (count - 1) as f64;
            let position = (p.clamp(0.0, 100.0) / 100.0 * count as f64 - 0.5)
                .max(0.0)
                .min(last);
            let below = position.floor() as usize;
            let above = position.ceil() as usize;
            let weight = position - below as f64;
            sorted[below] * (1.0 - weight) + sorted[above] * weight
        }
    }
}
/// Applies a real element-wise function `f` to a scalar or matrix.
///
/// A complex value whose imaginary part is exactly zero is treated as real.
///
/// # Errors
/// Rejects void, genuinely complex values, strings and non-numeric values.
fn apply_elem<F: Fn(f64) -> f64>(v: &Value, f: F) -> Result<Value, String> {
    let reason = match v {
        Value::Scalar(n) => return Ok(Value::Scalar(f(*n))),
        Value::Complex(re, im) if *im == 0.0 => return Ok(Value::Scalar(f(*re))),
        Value::Matrix(m) => return Ok(Value::Matrix(Box::new(m.mapv(f)))),
        Value::Void => "Element-wise function not applicable to void",
        Value::Complex(_, _) => "Element-wise real function not applicable to complex values",
        Value::ComplexMatrix(_) => "Element-wise real function not applicable to complex matrices",
        Value::Str(_) | Value::StringObj(_) => "Element-wise function not applicable to strings",
        Value::Lambda(_)
        | Value::Function(_)
        | Value::Tuple(_)
        | Value::Cell(_)
        | Value::Struct(_)
        | Value::StructArray(_)
        | Value::DateTime(_)
        | Value::Duration(_)
        | Value::DateTimeArray(_)
        | Value::DurationArray(_)
        | Value::Map(_) => "Element-wise function not applicable to this type",
    };
    Err(reason.to_string())
}
/// Applies the real reduction `f` to a scalar or matrix.
///
/// Scalars and row/column vectors reduce to a single scalar; any other matrix
/// is reduced column by column into a 1×ncols row vector.
///
/// # Errors
/// Rejects void, complex values, strings and non-numeric values.
fn apply_reduction<F>(v: &Value, f: F) -> Result<Value, String>
where
    F: Fn(&[f64]) -> f64,
{
    let reason = match v {
        Value::Scalar(n) => return Ok(Value::Scalar(f(&[*n]))),
        Value::Matrix(grid) => {
            if grid.nrows() == 1 || grid.ncols() == 1 {
                let flat: Vec<f64> = grid.iter().copied().collect();
                return Ok(Value::Scalar(f(&flat)));
            }
            let per_column: Vec<f64> = (0..grid.ncols())
                .map(|c| f(&grid.column(c).to_vec()))
                .collect();
            let width = per_column.len();
            return Ok(Value::Matrix(Box::new(
                Array2::from_shape_vec((1, width), per_column).unwrap(),
            )));
        }
        Value::Void => "Reduction not applicable to void",
        Value::Complex(_, _) => "Reduction not applicable to complex values",
        Value::ComplexMatrix(_) => "Reduction not applicable to complex matrices",
        Value::Str(_) | Value::StringObj(_) => "Reduction not applicable to strings",
        Value::Lambda(_)
        | Value::Function(_)
        | Value::Tuple(_)
        | Value::Cell(_)
        | Value::Struct(_)
        | Value::StructArray(_)
        | Value::DateTime(_)
        | Value::Duration(_)
        | Value::DateTimeArray(_)
        | Value::DurationArray(_)
        | Value::Map(_) => "Reduction not applicable to this type",
    };
    Err(reason.to_string())
}
fn apply_cm_reduction<F>(v: &Value, f: F) -> Result<Value, String>
where
F: Fn(&[Complex<f64>]) -> Complex<f64>,
{
let make_scalar = |c: Complex<f64>| -> Value {
if c.im == 0.0 {
Value::Scalar(c.re)
} else {
Value::Complex(c.re, c.im)
}
};
match v {
Value::Scalar(n) => Ok(make_scalar(f(&[Complex::new(*n, 0.0)]))),
Value::Complex(re, im) => Ok(make_scalar(f(&[Complex::new(*re, *im)]))),
Value::Matrix(m) => {
if m.nrows() == 1 || m.ncols() == 1 {
let vals: Vec<Complex<f64>> = m.iter().map(|&x| Complex::new(x, 0.0)).collect();
Ok(make_scalar(f(&vals)))
} else {
let ncols = m.ncols();
let result: Vec<Complex<f64>> = (0..ncols)
.map(|c| {
let col: Vec<Complex<f64>> =
m.column(c).iter().map(|&x| Complex::new(x, 0.0)).collect();
f(&col)
})
.collect();
if result.iter().all(|c| c.im == 0.0) {
let reals: Vec<f64> = result.iter().map(|c| c.re).collect();
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, ncols), reals).unwrap(),
)))
} else {
Ok(Value::ComplexMatrix(Box::new(
Array2::from_shape_vec((1, ncols), result).unwrap(),
)))
}
}
}
Value::ComplexMatrix(m) => {
if m.nrows() == 1 || m.ncols() == 1 {
let vals: Vec<Complex<f64>> = m.iter().copied().collect();
Ok(make_scalar(f(&vals)))
} else {
let ncols = m.ncols();
let result: Vec<Complex<f64>> = (0..ncols)
.map(|c| {
let col: Vec<Complex<f64>> = m.column(c).iter().copied().collect();
f(&col)
})
.collect();
if result.iter().all(|c| c.im == 0.0) {
let reals: Vec<f64> = result.iter().map(|c| c.re).collect();
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, ncols), reals).unwrap(),
)))
} else {
Ok(Value::ComplexMatrix(Box::new(
Array2::from_shape_vec((1, ncols), result).unwrap(),
)))
}
}
}
_ => Err("Reduction not applicable to this type".to_string()),
}
}
/// Computes a running (cumulative) reduction such as `cumsum` or `cumprod`.
///
/// Row and column vectors are scanned end to end; any other matrix is scanned
/// independently down each column. Each scan is seeded with its first element
/// (`out[0] = x[0]`, `out[i] = combine(out[i-1], x[i])`). This is correct for
/// any associative `combine` (sum, product, max, min, ...). A guessed identity
/// value would be wrong for operations like `max`/`min`.
///
/// # Errors
/// Rejects void, complex values, strings and non-numeric values.
fn apply_cumulative<F>(v: &Value, combine: F) -> Result<Value, String>
where
    F: Fn(f64, f64) -> f64,
{
    /// Replaces each cell with the running combination of all cells up to it.
    fn scan_in_place<'a, I, G>(cells: I, combine: &G)
    where
        I: Iterator<Item = &'a mut f64>,
        G: Fn(f64, f64) -> f64,
    {
        let mut acc: Option<f64> = None;
        for cell in cells {
            let next = match acc {
                Some(prev) => combine(prev, *cell),
                None => *cell,
            };
            *cell = next;
            acc = Some(next);
        }
    }

    match v {
        Value::Void => Err("Cumulative reduction not applicable to void".to_string()),
        Value::Scalar(n) => Ok(Value::Scalar(*n)),
        Value::Complex(_, _) => {
            Err("Cumulative reduction not applicable to complex values".to_string())
        }
        Value::ComplexMatrix(_) => {
            Err("Cumulative reduction not applicable to complex matrices".to_string())
        }
        Value::Str(_) | Value::StringObj(_) => {
            Err("Cumulative reduction not applicable to strings".to_string())
        }
        Value::Lambda(_)
        | Value::Function(_)
        | Value::Tuple(_)
        | Value::Cell(_)
        | Value::Struct(_)
        | Value::StructArray(_)
        | Value::DateTime(_)
        | Value::Duration(_)
        | Value::DateTimeArray(_)
        | Value::DurationArray(_)
        | Value::Map(_) => Err("Cumulative reduction not applicable to this type".to_string()),
        Value::Matrix(m) => {
            let mut result = m.clone();
            let (nrows, ncols) = (result.nrows(), result.ncols());
            if nrows == 1 || ncols == 1 {
                scan_in_place(result.iter_mut(), &combine);
            } else {
                for c in 0..ncols {
                    let mut column = result.column_mut(c);
                    scan_in_place(column.iter_mut(), &combine);
                }
            }
            Ok(Value::Matrix(result))
        }
    }
}
/// Returns the 1-based, column-major linear indices of the first `max_k`
/// non-zero elements, as a 1×k row vector (1×0 when none are found).
///
/// NaN counts as non-zero. A scalar or complex scalar yields `[1]` when it is
/// non-zero and `max_k >= 1`.
///
/// # Errors
/// Rejects void, complex matrices, strings and non-numeric values.
fn find_nonzero(v: &Value, max_k: usize) -> Result<Value, String> {
    let hits: Vec<f64> = match v {
        Value::Void => return Err("find: not applicable to void".to_string()),
        Value::ComplexMatrix(_) => {
            return Err("find: not applicable to complex matrices".to_string());
        }
        Value::Str(_) | Value::StringObj(_) => {
            return Err("find: not applicable to strings".to_string());
        }
        Value::Lambda(_)
        | Value::Function(_)
        | Value::Tuple(_)
        | Value::Cell(_)
        | Value::Struct(_)
        | Value::StructArray(_)
        | Value::DateTime(_)
        | Value::Duration(_)
        | Value::DateTimeArray(_)
        | Value::DurationArray(_)
        | Value::Map(_) => return Err("find: not applicable to this type".to_string()),
        Value::Complex(re, im) => {
            let nonzero = *re != 0.0 || *im != 0.0;
            if nonzero && max_k >= 1 { vec![1.0] } else { Vec::new() }
        }
        Value::Scalar(n) => {
            if *n != 0.0 && max_k >= 1 { vec![1.0] } else { Vec::new() }
        }
        // Iterating the transposed view visits elements in column-major order,
        // so the enumeration index is the 0-based linear index.
        Value::Matrix(m) => m
            .t()
            .iter()
            .enumerate()
            .filter(|&(_, &x)| x != 0.0)
            .take(max_k)
            .map(|(i, _)| (i + 1) as f64)
            .collect(),
    };
    if hits.is_empty() {
        return Ok(Value::Matrix(Box::new(Array2::zeros((1, 0)))));
    }
    let count = hits.len();
    Ok(Value::Matrix(Box::new(
        Array2::from_shape_vec((1, count), hits).unwrap(),
    )))
}
/// Formats `args` according to a C/MATLAB-style `fmt` string
/// (as used by `fprintf` / `sprintf`).
///
/// Supports the escapes `\n \t \\ \' \"`, `%%`, the flags `- + 0 space`, a
/// numeric width and `.precision`, and the conversions `d i f e E g G x X s`.
/// The format is re-applied while arguments remain; a pass that consumes no
/// argument stops the loop. A conversion with no remaining argument is skipped.
///
/// Non-finite numbers in numeric conversions print as `NaN`, `Inf` or `-Inf`
/// and are never zero-padded.
///
/// # Errors
/// Returns a message for an incomplete trailing `%`, an unknown conversion,
/// or an argument of the wrong kind for its conversion.
pub fn format_printf(fmt: &str, args: &[Value]) -> Result<String, String> {
    /// Text for a non-finite number: `NaN`, or `Inf` / `-Inf` (with `+`/space flags).
    fn nonfinite_text(n: f64, flag_plus: bool, flag_space: bool) -> String {
        if n.is_nan() {
            "NaN".to_string()
        } else {
            printf_sign_str(n > 0.0, flag_plus, flag_space, "Inf".to_string())
        }
    }

    let mut result = String::new();
    let mut arg_idx = 0;
    loop {
        let consumed_before = arg_idx;
        let mut chars = fmt.chars().peekable();
        while let Some(c) = chars.next() {
            if c == '\\' {
                match chars.next() {
                    Some('n') => result.push('\n'),
                    Some('t') => result.push('\t'),
                    Some('\\') => result.push('\\'),
                    Some('\'') => result.push('\''),
                    Some('"') => result.push('"'),
                    Some(other) => {
                        result.push('\\');
                        result.push(other);
                    }
                    None => result.push('\\'),
                }
                continue;
            }
            if c != '%' {
                result.push(c);
                continue;
            }
            if chars.peek() == Some(&'%') {
                chars.next();
                result.push('%');
                continue;
            }
            let mut flag_minus = false;
            let mut flag_plus = false;
            let mut flag_zero = false;
            let mut flag_space = false;
            loop {
                match chars.peek() {
                    Some('-') => flag_minus = true,
                    Some('+') => flag_plus = true,
                    Some('0') => flag_zero = true,
                    Some(' ') => flag_space = true,
                    _ => break,
                }
                chars.next();
            }
            let mut width_str = String::new();
            while let Some(&d) = chars.peek() {
                if !d.is_ascii_digit() {
                    break;
                }
                width_str.push(d);
                chars.next();
            }
            let width: usize = width_str.parse().unwrap_or(0);
            let mut precision: Option<usize> = None;
            if chars.peek() == Some(&'.') {
                chars.next();
                let mut p = String::new();
                while let Some(&d) = chars.peek() {
                    if !d.is_ascii_digit() {
                        break;
                    }
                    p.push(d);
                    chars.next();
                }
                precision = Some(p.parse().unwrap_or(0));
            }
            let spec = match chars.next() {
                Some(s) => s,
                None => {
                    return Err("fprintf: incomplete format specifier at end of string".to_string());
                }
            };
            if arg_idx >= args.len() {
                continue;
            }
            let arg = &args[arg_idx];
            arg_idx += 1;
            let formatted = match spec {
                's' => {
                    let s = printf_string(arg)?;
                    let s = if let Some(max_len) = precision {
                        s.chars().take(max_len).collect::<String>()
                    } else {
                        s
                    };
                    printf_pad(s, width, flag_minus, false)
                }
                'd' | 'i' | 'f' | 'e' | 'E' | 'g' | 'G' | 'x' | 'X' => {
                    let n = printf_scalar(arg, spec)?;
                    if !n.is_finite() {
                        printf_pad(nonfinite_text(n, flag_plus, flag_space), width, flag_minus, false)
                    } else {
                        let body = match spec {
                            'd' | 'i' => {
                                let i = n.trunc() as i64;
                                // `unsigned_abs` cannot overflow, unlike `abs` on i64::MIN
                                // (reached when huge negative values saturate).
                                printf_sign_str(
                                    i >= 0,
                                    flag_plus,
                                    flag_space,
                                    i.unsigned_abs().to_string(),
                                )
                            }
                            'f' => {
                                let prec = precision.unwrap_or(6);
                                printf_sign_str(
                                    n >= 0.0,
                                    flag_plus,
                                    flag_space,
                                    format!("{:.prec$}", n.abs(), prec = prec),
                                )
                            }
                            'e' | 'E' => {
                                let prec = precision.unwrap_or(6);
                                printf_format_sci(n, prec, flag_plus, flag_space, spec == 'E')
                            }
                            'g' | 'G' => {
                                let prec = precision.unwrap_or(6).max(1);
                                printf_format_g(n, prec, flag_plus, flag_space, spec == 'G')
                            }
                            _ => {
                                // 'x' | 'X'
                                let i = n.trunc() as u64;
                                if spec == 'X' {
                                    format!("{i:X}")
                                } else {
                                    format!("{i:x}")
                                }
                            }
                        };
                        printf_pad(body, width, flag_minus, flag_zero)
                    }
                }
                other => return Err(format!("fprintf: unknown format specifier '%{other}'")),
            };
            result.push_str(&formatted);
        }
        if arg_idx >= args.len() || arg_idx == consumed_before {
            break;
        }
    }
    Ok(result)
}
/// Extracts the numeric value for a numeric printf conversion `%spec`.
///
/// Accepts real scalars, complex values with an exactly-zero imaginary part,
/// and one-character char arrays (formatted as their code point).
///
/// # Errors
/// Returns a message naming the conversion and the kind of value received.
fn printf_scalar(v: &Value, spec: char) -> Result<f64, String> {
    let got = match v {
        Value::Scalar(n) => return Ok(*n),
        Value::Complex(re, im) if *im == 0.0 => return Ok(*re),
        Value::Str(s) if s.chars().count() == 1 => {
            return Ok(s.chars().next().unwrap() as u32 as f64);
        }
        // Human-readable type names; `{:?}` of a `Discriminant` is opaque to users.
        Value::Void => "void",
        Value::Complex(_, _) => "a complex number",
        Value::Matrix(_) => "a matrix",
        Value::ComplexMatrix(_) => "a complex matrix",
        Value::Str(_) | Value::StringObj(_) => "a string",
        Value::Lambda(_) | Value::Function(_) => "a function",
        Value::Tuple(_) => "a tuple",
        Value::Cell(_) => "a cell array",
        Value::Struct(_) | Value::StructArray(_) => "a struct",
        Value::DateTime(_) | Value::DateTimeArray(_) => "a datetime",
        Value::Duration(_) | Value::DurationArray(_) => "a duration",
        Value::Map(_) => "a map",
    };
    Err(format!(
        "fprintf: expected numeric argument for '%{spec}', got {got}"
    ))
}
/// Renders a value as the operand of a `%s` conversion.
///
/// Text values (`Str`, `StringObj`) are copied verbatim. Scalars go through
/// `format_number`, complex scalars through `format_complex` with
/// `FormatMode::Custom(6)`, and `DateTime` / `Duration` through the
/// `crate::datetime` formatters.
///
/// # Errors
/// Void, real or complex matrices, and the remaining composite or callable
/// variants cannot be rendered and produce an error.
fn printf_string(v: &Value) -> Result<String, String> {
    let rendered = match v {
        Value::Str(text) | Value::StringObj(text) => text.clone(),
        Value::Scalar(x) => format_number(*x),
        Value::Complex(re, im) => format_complex(*re, *im, &FormatMode::Custom(6)),
        Value::DateTime(stamp) => crate::datetime::format_datetime(*stamp),
        Value::Duration(span) => crate::datetime::format_duration(*span),
        Value::Void => return Err("fprintf: cannot format void as string".to_string()),
        Value::Matrix(_) => return Err("fprintf: cannot format matrix as string".to_string()),
        Value::ComplexMatrix(_) => {
            return Err("fprintf: cannot format complex matrix as string".to_string());
        }
        Value::Lambda(_)
        | Value::Function(_)
        | Value::Tuple(_)
        | Value::Cell(_)
        | Value::Struct(_)
        | Value::StructArray(_)
        | Value::DateTimeArray(_)
        | Value::DurationArray(_)
        | Value::Map(_) => {
            return Err("fprintf: cannot format this type as string".to_string());
        }
    };
    Ok(rendered)
}
/// Prefixes already-formatted magnitude `digits` with the sign printf would emit.
///
/// Non-positive values (`positive == false`) always get `-`. Otherwise the `+`
/// flag wins over the space flag, and with neither flag `digits` is returned
/// unchanged.
fn printf_sign_str(positive: bool, flag_plus: bool, flag_space: bool, digits: String) -> String {
    let prefix = match (positive, flag_plus, flag_space) {
        (false, _, _) => "-",
        (true, true, _) => "+",
        (true, false, true) => " ",
        (true, false, false) => return digits,
    };
    format!("{prefix}{digits}")
}
/// Pads `s` to at least `width` characters for a printf-style conversion.
///
/// * `left_align` (the `-` flag): pad on the right with spaces. It takes
///   precedence over `zero_pad`.
/// * `zero_pad` (the `0` flag): insert `0`s after a leading `+`, `-` or space,
///   so the sign stays in front (`-005`).
/// * otherwise: pad on the left with spaces.
///
/// A string that is already `width` characters or longer is returned unchanged.
fn printf_pad(s: String, width: usize, left_align: bool, zero_pad: bool) -> String {
    // Width counts characters, not UTF-8 bytes. This matches the `%.Ns`
    // precision handling, which truncates by char; measuring bytes would
    // under-pad any non-ASCII text.
    let len = s.chars().count();
    if len >= width {
        return s;
    }
    let pad_len = width - len;
    if left_align {
        format!("{s}{}", " ".repeat(pad_len))
    } else if zero_pad {
        // Sign characters are ASCII, so splitting at byte 1 is a char boundary.
        let (prefix, rest) = if s.starts_with(['+', '-', ' ']) {
            s.split_at(1)
        } else {
            ("", s.as_str())
        };
        format!("{prefix}{}{rest}", "0".repeat(pad_len))
    } else {
        format!("{}{s}", " ".repeat(pad_len))
    }
}
/// Formats `n` the way C's `%e` / `%E` conversion does: one leading digit,
/// `prec` fractional digits, and an exponent with a sign and at least two
/// digits (e.g. `1.234560e+03`).
///
/// Values below zero get `-`. Otherwise `flag_plus` adds `+`, or else
/// `flag_space` adds a space. Negative zero compares equal to `0.0`, so it is
/// printed without `-`. `upper` selects `E` instead of `e`. Infinities and NaN
/// use Rust's `inf` / `NaN` spelling with the same sign rules.
fn printf_format_sci(
    n: f64,
    prec: usize,
    flag_plus: bool,
    flag_space: bool,
    upper: bool,
) -> String {
    let sign = if n < 0.0 {
        "-"
    } else if flag_plus {
        "+"
    } else if flag_space {
        " "
    } else {
        ""
    };
    let abs_n = n.abs();
    if !abs_n.is_finite() {
        return format!("{sign}{abs_n}");
    }
    // Let std do the exactly-rounded conversion (`1.00e1`) and only reshape the
    // exponent. Computing the exponent from `log10` before rounding the mantissa
    // printed values such as 9.9999999 at `%.2e` as `10.00e+00`.
    let sci = format!("{abs_n:.prec$e}");
    let (mantissa, exp_digits) = sci
        .split_once('e')
        .expect("`{:e}` output of a finite float always contains 'e'");
    let exp: i32 = exp_digits
        .parse()
        .expect("`{:e}` exponent is a decimal integer");
    let e_char = if upper { 'E' } else { 'e' };
    let exp_sign = if exp < 0 { '-' } else { '+' };
    format!("{sign}{mantissa}{e_char}{exp_sign}{:02}", exp.unsigned_abs())
}
/// Formats `n` the way C's `%g` / `%G` conversion does.
///
/// `prec` is the number of significant digits. As in C, 0 is treated as 1.
/// Let `X` be the decimal exponent of `n` *after* rounding to `prec`
/// significant digits. Scientific notation (exponent with sign and at least
/// two digits) is used when `X < -4` or `X >= prec`; fixed notation is used
/// otherwise. In both forms, trailing fractional zeros and a then-dangling `.`
/// are removed.
///
/// Values below zero get `-`. Otherwise `flag_plus` adds `+`, or else
/// `flag_space` adds a space. `upper` selects `E` instead of `e`. Zero prints
/// as `0`. Infinities and NaN use Rust's `inf` / `NaN` spelling with the same
/// sign rules.
fn printf_format_g(n: f64, prec: usize, flag_plus: bool, flag_space: bool, upper: bool) -> String {
    let sign = if n < 0.0 {
        "-"
    } else if flag_plus {
        "+"
    } else if flag_space {
        " "
    } else {
        ""
    };
    if n == 0.0 {
        return format!("{sign}0");
    }
    let abs_n = n.abs();
    if !abs_n.is_finite() {
        return format!("{sign}{abs_n}");
    }
    let prec = prec.max(1);
    // Pick the notation from the exponent of the *rounded* value. Using `log10`
    // of the unrounded value misclassifies numbers that round up to the next
    // power of ten: 999999.7 must print as `1e+06` at 6 digits, and 0.000099999
    // as `0.0001` at 3 digits. Rust's `{:e}` rounds exactly, so read the
    // exponent back from its output.
    let frac_digits = prec - 1;
    let sci = format!("{abs_n:.frac_digits$e}");
    let (mantissa, exp_digits) = sci
        .split_once('e')
        .expect("`{:e}` output of a finite float always contains 'e'");
    let exp: i32 = exp_digits
        .parse()
        .expect("`{:e}` exponent is a decimal integer");
    let use_sci = exp < -4 || usize::try_from(exp).is_ok_and(|e| e >= prec);
    if use_sci {
        let e_char = if upper { 'E' } else { 'e' };
        let exp_sign = if exp < 0 { '-' } else { '+' };
        let s = format!("{sign}{mantissa}{e_char}{exp_sign}{:02}", exp.unsigned_abs());
        trim_g_sci(s, upper)
    } else {
        // Here -4 <= exp < prec, so the decimal count is non-negative and keeps
        // exactly `prec` significant digits.
        let decimals = (prec as i64 - 1 - i64::from(exp)) as usize;
        let fixed = format!("{abs_n:.decimals$}");
        let trimmed = if fixed.contains('.') {
            fixed.trim_end_matches('0').trim_end_matches('.')
        } else {
            fixed.as_str()
        };
        format!("{sign}{trimmed}")
    }
}
/// Strips trailing fractional zeros (and a then-dangling `.`) from the mantissa
/// of a scientific-notation string, keeping the exponent part untouched.
///
/// `upper` selects which exponent marker (`E` or `e`) to look for. Strings
/// without that marker are returned unchanged.
fn trim_g_sci(s: String, upper: bool) -> String {
    let e_char = if upper { 'E' } else { 'e' };
    if let Some(e_pos) = s.find(e_char) {
        let mantissa = &s[..e_pos];
        let exp_part = &s[e_pos..];
        let trimmed = if mantissa.contains('.') {
            mantissa.trim_end_matches('0').trim_end_matches('.')
        } else {
            mantissa
        };
        format!("{trimmed}{exp_part}")
    } else {
        s
    }
}
/// Calls a callable `Value` with `args`.
///
/// * `Value::Lambda`: the wrapped closure is invoked directly with `args` and `io`.
/// * `Value::Function`: execution is delegated to the hook installed with
///   `set_fn_call_hook`, called with the name `"<anonymous>"` and an empty
///   `Env`. When `io` is `None`, a fresh `IoContext::new()` is passed instead.
///
/// # Errors
/// Fails when `f` is neither a lambda nor a function, when no function-call
/// hook has been installed, or with whatever error the callee returns.
fn call_function_value(
    f: &Value,
    args: &[Value],
    io: Option<&mut IoContext>,
) -> Result<Value, String> {
    if let Value::Lambda(lambda) = f {
        let lambda = lambda.clone();
        return lambda.0(args, io);
    }
    if !matches!(f, Value::Function(_)) {
        return Err(
            "cellfun/arrayfun: first argument must be a function or lambda (@fn)".to_string(),
        );
    }
    let empty_env = Env::new();
    // Only build a scratch I/O context when the caller did not supply one.
    let mut scratch_io;
    let io_ref = match io {
        Some(caller_io) => caller_io,
        None => {
            scratch_io = IoContext::new();
            &mut scratch_io
        }
    };
    FN_CALL_HOOK.with(|slot| match slot.get() {
        Some(hook) => hook("<anonymous>", f, args, &empty_env, io_ref),
        None => Err("User function execution not initialized".to_string()),
    })
}
/// Returns the list of built-in function names.
///
/// One consumer is `suggest_similar`, which keeps the *first* candidate found
/// at the smallest edit distance. The order of this list therefore decides
/// ties and must be kept stable. It is mostly, but not strictly, alphabetical.
pub fn builtin_names() -> &'static [&'static str] {
    const NAMES: &[&str] = &[
        "abs", "acos", "all", "angle", "any", "arrayfun", "asin", "assert",
        "atan", "atan2", "bitand", "bitnot", "bitor", "bitshift", "bitxor", "ceil",
        "cell", "cellfun", "chol", "complex", "cond", "cross", "conj", "contains",
        "conv", "cos", "cov", "cumprod", "cumsum", "datenum", "datestr", "datevec",
        "datetime", "day", "days", "deconv", "det", "diag", "diff", "dir",
        "dot", "disp", "dlmread", "dlmwrite", "eig", "endsWith", "erf", "eval",
        "erfc", "exist", "exp", "eye", "fclose", "fft", "fftfreq", "fftshift",
        "fgetl", "fgets", "fieldnames", "ifft", "ifftshift", "figure", "find", "fliplr",
        "flipud", "floor", "fopen", "fprintf", "genpath", "histc", "hour", "hours",
        "hypot", "imag", "ind2sub", "int2str", "interp1", "intersect", "inv", "iqr",
        "iscell", "ismember", "ischar", "isdatetime", "isduration", "isempty",
        "isfield", "isfile",
        "isfinite", "isfolder", "isinf", "isnan", "isnat", "isreal", "isKey", "isstring",
        "isstruct", "jsonencode", "jsondecode", "keys", "kron", "kurtosis", "lasterr",
        "length",
        "linspace", "load", "log", "log10", "log2", "lower", "lu", "mat2str",
        "max", "mean", "median", "meshgrid", "milliseconds", "min", "minute", "minutes",
        "mod", "mode", "month", "nan", "norm", "normcdf", "normpdf", "not",
        "null", "num2str", "numel", "ones", "orth", "pinv", "poly", "polyfit",
        "polyval", "posixtime", "prctile", "prod", "qr", "rand", "randi", "randn",
        "rank", "readmatrix", "readtable", "real", "regexp", "regexpi", "regexprep", "rem",
        "repelem", "remove", "repmat", "reshape", "rmfield", "rng", "roots", "round",
        "setdiff", "second", "seconds", "sign", "sin", "size", "skewness", "sort",
        "sprintf", "sqrt", "startsWith", "std", "str2double", "str2num", "strcmp",
        "strcmpi",
        "strjoin", "strrep", "strsplit", "strtrim", "sub2ind", "sum", "svd", "tan",
        "tic", "toc", "trace", "tril", "triu", "union", "unique", "upper",
        "values", "var", "writetable", "xor", "year", "years", "zeros", "zscore",
    ];
    NAMES
}
/// Levenshtein edit distance between `a` and `b`.
///
/// Insertions, deletions and substitutions each cost 1, and distances are
/// counted in `char`s (Unicode scalar values), not bytes. Only one DP row is
/// kept, so memory is proportional to the length of `b`.
fn levenshtein(a: &str, b: &str) -> usize {
    let target: Vec<char> = b.chars().collect();
    let mut dist: Vec<usize> = (0..=target.len()).collect();
    for (i, ca) in a.chars().enumerate() {
        // `diag` is the previous row's value one column to the left.
        let mut diag = dist[0];
        dist[0] = i + 1;
        for (j, &cb) in target.iter().enumerate() {
            let above = dist[j + 1];
            let cost = if ca == cb {
                diag
            } else {
                1 + diag.min(above).min(dist[j])
            };
            diag = above;
            dist[j + 1] = cost;
        }
    }
    dist[target.len()]
}
/// Proposes a "did you mean" replacement for an unknown `name`.
///
/// Candidates are the keys of `env` followed by `builtin_names()`. The closest
/// candidate within an edit distance of 2 (see `levenshtein`) is returned. On a
/// tie, the first candidate examined wins, so environment keys take precedence
/// over builtins. Returns `None` when nothing is close enough.
fn suggest_similar(name: &str, env: &Env) -> Option<String> {
    const MAX_DIST: usize = 2;
    let mut best_name: Option<String> = None;
    // Start one past the limit so only candidates within MAX_DIST can win.
    let mut best_dist = MAX_DIST + 1;
    let mut consider = |candidate: &str| {
        let dist = levenshtein(name, candidate);
        if dist < best_dist {
            best_dist = dist;
            best_name = Some(candidate.to_string());
        }
    };
    for key in env.keys() {
        consider(key);
    }
    for &builtin in builtin_names() {
        consider(builtin);
    }
    best_name
}
/// Implements the comparison behind `assert(a, b)` / `assert(a, b, tol)`.
///
/// Real numeric operands (`Scalar` and `Matrix`, where a scalar has shape 1×1)
/// are compared element by element. The shapes must match. Two elements match
/// when:
/// * they are exactly equal, which includes equal infinities;
/// * both are NaN (as in GNU Octave's `assert`); or
/// * a tolerance `tol` is given and `|x - y| <= tol`.
///
/// Any other pair of values must compare equal with `==`. A tolerance is
/// rejected for such non-numeric comparisons.
///
/// Returns `Ok(Value::Void)` on success.
///
/// # Errors
/// Returns a message describing the first mismatch: size mismatch, a
/// tolerance exceeded, `x ~= y`, or "values not equal".
fn assert_values_equal(a: &Value, b: &Value, tol: Option<f64>) -> Result<Value, String> {
    /// Shape and elements of a real numeric value, in `ndarray` iteration order.
    fn real_elems(v: &Value) -> Option<((usize, usize), Vec<f64>)> {
        match v {
            Value::Scalar(x) => Some(((1, 1), vec![*x])),
            Value::Matrix(m) => Some(((m.nrows(), m.ncols()), m.iter().copied().collect())),
            _ => None,
        }
    }

    /// Element comparison. The exact-equality test comes first because
    /// `inf - inf` is NaN, which would make equal infinities fail any tolerance.
    fn elems_match(x: f64, y: f64, tol: Option<f64>) -> bool {
        if x == y || (x.is_nan() && y.is_nan()) {
            return true;
        }
        tol.is_some_and(|t| (x - y).abs() <= t)
    }

    let (Some((shape_a, xs)), Some((shape_b, ys))) = (real_elems(a), real_elems(b)) else {
        if tol.is_some() {
            return Err("assert: tolerance requires numeric arguments".to_string());
        }
        return if a == b {
            Ok(Value::Void)
        } else {
            Err("assert: values not equal".to_string())
        };
    };

    if shape_a != shape_b {
        return Err(format!(
            "assert: size mismatch [{}×{}] vs [{}×{}]",
            shape_a.0, shape_a.1, shape_b.0, shape_b.1
        ));
    }
    // Keep the scalar-specific wording when both operands are plain scalars.
    let both_scalars = matches!((a, b), (Value::Scalar(_), Value::Scalar(_)));
    for (&x, &y) in xs.iter().zip(&ys) {
        if elems_match(x, y, tol) {
            continue;
        }
        return Err(match (tol, both_scalars) {
            (Some(t), true) => format!(
                "assert: |{x} - {y}| = {} exceeds tolerance {t}",
                (x - y).abs()
            ),
            (Some(t), false) => format!(
                "assert: difference {} exceeds tolerance {t}",
                (x - y).abs()
            ),
            (None, _) => format!("assert: {x} ~= {y}"),
        });
    }
    Ok(Value::Void)
}
pub(crate) fn call_builtin(
name: &str,
args: &[Value],
env: &Env,
mut io: Option<&mut IoContext>,
) -> Result<Value, String> {
if let Some(result) = crate::plugin::call_plugin(name, args, env) {
return result;
}
match (name, args.len()) {
("sqrt", 1) => match &args[0] {
Value::Scalar(x) if *x < 0.0 => Ok(make_complex(0.0, (-x).sqrt())),
Value::Complex(re, im) => {
let mag = (*re * *re + *im * *im).sqrt();
let sqrt_mag = mag.sqrt();
let arg = (*im).atan2(*re) / 2.0;
Ok(make_complex(sqrt_mag * arg.cos(), sqrt_mag * arg.sin()))
}
_ => apply_elem(&args[0], |x| x.sqrt()),
},
("floor", 1) => apply_elem(&args[0], |x| x.floor()),
("ceil", 1) => apply_elem(&args[0], |x| x.ceil()),
("round", 1) => apply_elem(&args[0], |x| x.round()),
("sign", 1) => apply_elem(&args[0], |x| x.signum()),
("log", 1) => apply_elem(&args[0], |x| x.ln()),
("log2", 1) => apply_elem(&args[0], |x| x.log2()),
("log10", 1) => apply_elem(&args[0], |x| x.log10()),
("exp", 1) => match &args[0] {
Value::Complex(re, im) => {
let e = re.exp();
Ok(make_complex(e * im.cos(), e * im.sin()))
}
_ => apply_elem(&args[0], |x| x.exp()),
},
("sin", 1) => apply_elem(&args[0], |x| x.sin()),
("cos", 1) => apply_elem(&args[0], |x| x.cos()),
("tan", 1) => apply_elem(&args[0], |x| x.tan()),
("asin", 1) => apply_elem(&args[0], |x| x.asin()),
("acos", 1) => apply_elem(&args[0], |x| x.acos()),
("atan", 1) => apply_elem(&args[0], |x| x.atan()),
("erf", 1) => apply_elem(&args[0], libm::erf),
("erfc", 1) => apply_elem(&args[0], libm::erfc),
("normcdf", 1) => apply_elem(&args[0], |x| {
0.5 * (1.0 + libm::erf(x / std::f64::consts::SQRT_2))
}),
("normcdf", 3) => {
let mu = scalar_arg(&args[1], name, 2)?;
let s = scalar_arg(&args[2], name, 3)?;
if s <= 0.0 {
return Err("normcdf: sigma must be positive".to_string());
}
apply_elem(&args[0], move |x| {
0.5 * (1.0 + libm::erf((x - mu) / (s * std::f64::consts::SQRT_2)))
})
}
("normpdf", 1) => apply_elem(&args[0], |x| {
(-0.5 * x * x).exp() / (2.0 * std::f64::consts::PI).sqrt()
}),
("normpdf", 3) => {
let mu = scalar_arg(&args[1], name, 2)?;
let s = scalar_arg(&args[2], name, 3)?;
if s <= 0.0 {
return Err("normpdf: sigma must be positive".to_string());
}
apply_elem(&args[0], move |x| {
let z = (x - mu) / s;
(-0.5 * z * z).exp() / (s * (2.0 * std::f64::consts::PI).sqrt())
})
}
("atan2", 2) => Ok(Value::Scalar(
scalar_arg(&args[0], name, 1)?.atan2(scalar_arg(&args[1], name, 2)?),
)),
("mod", 2) => {
let a = scalar_arg(&args[0], name, 1)?;
let b = scalar_arg(&args[1], name, 2)?;
Ok(Value::Scalar(a - b * (a / b).floor()))
}
("rem", 2) => {
let a = scalar_arg(&args[0], name, 1)?;
let b = scalar_arg(&args[1], name, 2)?;
Ok(Value::Scalar(a - b * (a / b).trunc()))
}
("max", 2) => Ok(Value::Scalar(
scalar_arg(&args[0], name, 1)?.max(scalar_arg(&args[1], name, 2)?),
)),
("min", 2) => Ok(Value::Scalar(
scalar_arg(&args[0], name, 1)?.min(scalar_arg(&args[1], name, 2)?),
)),
("hypot", 2) => Ok(Value::Scalar(
scalar_arg(&args[0], name, 1)?.hypot(scalar_arg(&args[1], name, 2)?),
)),
("log", 2) => Ok(Value::Scalar(
scalar_arg(&args[0], name, 1)?.log(scalar_arg(&args[1], name, 2)?),
)),
("zeros", 1) => {
let (r, c) = size_arg(&args[0], name)?;
Ok(Value::Matrix(Box::new(Array2::zeros((r, c)))))
}
("zeros", 2) => {
let r = scalar_arg(&args[0], name, 1)? as usize;
let c = scalar_arg(&args[1], name, 2)? as usize;
Ok(Value::Matrix(Box::new(Array2::zeros((r, c)))))
}
("ones", 1) => {
let (r, c) = size_arg(&args[0], name)?;
Ok(Value::Matrix(Box::new(Array2::ones((r, c)))))
}
("ones", 2) => {
let r = scalar_arg(&args[0], name, 1)? as usize;
let c = scalar_arg(&args[1], name, 2)? as usize;
Ok(Value::Matrix(Box::new(Array2::ones((r, c)))))
}
("eye", 1) => {
let n = scalar_arg(&args[0], name, 1)? as usize;
let mut m = Array2::<f64>::zeros((n, n));
for i in 0..n {
m[[i, i]] = 1.0;
}
Ok(Value::Matrix(Box::new(m)))
}
("size", 1) => match &args[0] {
Value::Void => Err("size: not applicable to void".to_string()),
Value::Scalar(_) | Value::Complex(_, _) | Value::Struct(_) => Ok(Value::Matrix(
Box::new(Array2::from_shape_vec((1, 2), vec![1.0, 1.0]).unwrap()),
)),
Value::Matrix(m) => Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, 2), vec![m.nrows() as f64, m.ncols() as f64]).unwrap(),
))),
Value::ComplexMatrix(m) => Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, 2), vec![m.nrows() as f64, m.ncols() as f64]).unwrap(),
))),
Value::Str(s) => Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, 2), vec![1.0, s.chars().count() as f64]).unwrap(),
))),
Value::StringObj(s) => Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, 2), vec![1.0, s.chars().count() as f64]).unwrap(),
))),
Value::Cell(v) => Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, 2), vec![1.0, v.len() as f64]).unwrap(),
))),
Value::StructArray(arr) => Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, 2), vec![1.0, arr.len() as f64]).unwrap(),
))),
Value::Lambda(_)
| Value::Function(_)
| Value::Tuple(_)
| Value::DateTime(_)
| Value::Duration(_)
| Value::DateTimeArray(_)
| Value::DurationArray(_)
| Value::Map(_) => Err("size: not applicable to this type".to_string()),
},
("size", 2) => {
let dim = scalar_arg(&args[1], name, 2)? as usize;
match &args[0] {
Value::Void => Err("size: not applicable to void".to_string()),
Value::Scalar(_) | Value::Complex(_, _) | Value::Struct(_) => {
Ok(Value::Scalar(1.0))
}
Value::Matrix(m) => match dim {
1 => Ok(Value::Scalar(m.nrows() as f64)),
2 => Ok(Value::Scalar(m.ncols() as f64)),
_ => Err(format!("size: invalid dimension {dim}, must be 1 or 2")),
},
Value::ComplexMatrix(m) => match dim {
1 => Ok(Value::Scalar(m.nrows() as f64)),
2 => Ok(Value::Scalar(m.ncols() as f64)),
_ => Err(format!("size: invalid dimension {dim}, must be 1 or 2")),
},
Value::Str(s) => match dim {
1 => Ok(Value::Scalar(1.0)),
2 => Ok(Value::Scalar(s.chars().count() as f64)),
_ => Err(format!("size: invalid dimension {dim}")),
},
Value::StringObj(s) => match dim {
1 => Ok(Value::Scalar(1.0)),
2 => Ok(Value::Scalar(s.chars().count() as f64)),
_ => Err(format!("size: invalid dimension {dim}")),
},
Value::Cell(v) => match dim {
1 => Ok(Value::Scalar(1.0)),
2 => Ok(Value::Scalar(v.len() as f64)),
_ => Err(format!("size: invalid dimension {dim}")),
},
Value::StructArray(arr) => match dim {
1 => Ok(Value::Scalar(1.0)),
2 => Ok(Value::Scalar(arr.len() as f64)),
_ => Err(format!("size: invalid dimension {dim}")),
},
Value::Lambda(_)
| Value::Function(_)
| Value::Tuple(_)
| Value::DateTime(_)
| Value::Duration(_)
| Value::DateTimeArray(_)
| Value::DurationArray(_)
| Value::Map(_) => Err("size: not applicable to this type".to_string()),
}
}
("length", 1) => match &args[0] {
Value::Void => Err("length: not applicable to void".to_string()),
Value::Scalar(_) | Value::Complex(_, _) | Value::Struct(_) => Ok(Value::Scalar(1.0)),
Value::Matrix(m) => Ok(Value::Scalar(m.nrows().max(m.ncols()) as f64)),
Value::ComplexMatrix(m) => Ok(Value::Scalar(m.nrows().max(m.ncols()) as f64)),
Value::Str(s) => Ok(Value::Scalar(s.chars().count() as f64)),
Value::StringObj(s) => Ok(Value::Scalar(s.chars().count() as f64)),
Value::Cell(v) => Ok(Value::Scalar(v.len() as f64)),
Value::StructArray(arr) => Ok(Value::Scalar(arr.len() as f64)),
Value::DateTimeArray(v) | Value::DurationArray(v) => Ok(Value::Scalar(v.len() as f64)),
Value::DateTime(_) | Value::Duration(_) => Ok(Value::Scalar(1.0)),
Value::Map(m) => Ok(Value::Scalar(m.len() as f64)),
Value::Lambda(_) | Value::Function(_) | Value::Tuple(_) => {
Err("length: not applicable to function values".to_string())
}
},
("numel", 1) => match &args[0] {
Value::Void => Err("numel: not applicable to void".to_string()),
Value::Scalar(_) | Value::Complex(_, _) | Value::Struct(_) => Ok(Value::Scalar(1.0)),
Value::Matrix(m) => Ok(Value::Scalar(m.len() as f64)),
Value::ComplexMatrix(m) => Ok(Value::Scalar(m.len() as f64)),
Value::Str(s) => Ok(Value::Scalar(s.chars().count() as f64)),
Value::StringObj(s) => Ok(Value::Scalar(s.chars().count() as f64)),
Value::Cell(v) => Ok(Value::Scalar(v.len() as f64)),
Value::StructArray(arr) => Ok(Value::Scalar(arr.len() as f64)),
Value::DateTimeArray(v) | Value::DurationArray(v) => Ok(Value::Scalar(v.len() as f64)),
Value::DateTime(_) | Value::Duration(_) => Ok(Value::Scalar(1.0)),
Value::Map(m) => Ok(Value::Scalar(m.len() as f64)),
Value::Lambda(_) | Value::Function(_) | Value::Tuple(_) => {
Err("numel: not applicable to function values".to_string())
}
},
("trace", 1) => match &args[0] {
Value::Void => Err("trace: not applicable to void".to_string()),
Value::Scalar(n) => Ok(Value::Scalar(*n)),
Value::Complex(re, _) => Ok(Value::Scalar(*re)),
Value::Matrix(m) => {
let n = m.nrows().min(m.ncols());
Ok(Value::Scalar((0..n).map(|i| m[[i, i]]).sum()))
}
Value::ComplexMatrix(m) => {
let n = m.nrows().min(m.ncols());
let s: Complex<f64> = (0..n).map(|i| m[[i, i]]).sum();
Ok(if s.im == 0.0 {
Value::Scalar(s.re)
} else {
Value::Complex(s.re, s.im)
})
}
Value::Str(_)
| Value::StringObj(_)
| Value::Lambda(_)
| Value::Function(_)
| Value::Tuple(_)
| Value::Cell(_)
| Value::Struct(_)
| Value::StructArray(_)
| Value::DateTime(_)
| Value::Duration(_)
| Value::DateTimeArray(_)
| Value::DurationArray(_)
| Value::Map(_) => Err("trace: not applicable to non-numeric values".to_string()),
},
("det", 1) => match &args[0] {
Value::Void => Err("det: not applicable to void".to_string()),
Value::Scalar(n) => Ok(Value::Scalar(*n)),
Value::Complex(_, _) => Err("det: not applicable to complex scalars".to_string()),
Value::ComplexMatrix(_) => Err("det: not supported for complex matrices".to_string()),
Value::Matrix(m) => Ok(Value::Scalar(det_matrix(m)?)),
Value::Str(_)
| Value::StringObj(_)
| Value::Lambda(_)
| Value::Function(_)
| Value::Tuple(_)
| Value::Cell(_)
| Value::Struct(_)
| Value::StructArray(_)
| Value::DateTime(_)
| Value::Duration(_)
| Value::DateTimeArray(_)
| Value::DurationArray(_)
| Value::Map(_) => Err("det: not applicable to non-numeric values".to_string()),
},
("inv", 1) => match &args[0] {
Value::Void => Err("inv: not applicable to void".to_string()),
Value::Scalar(n) => {
if *n == 0.0 {
Err("inv: singular (zero scalar)".to_string())
} else {
Ok(Value::Scalar(1.0 / n))
}
}
Value::Complex(re, im) => {
let denom = re * re + im * im;
if denom == 0.0 {
Err("inv: singular (zero complex)".to_string())
} else {
Ok(make_complex(re / denom, -im / denom))
}
}
Value::Matrix(m) => Ok(Value::Matrix(Box::new(inv_matrix(m)?))),
Value::ComplexMatrix(_) => Err("inv: not supported for complex matrices".to_string()),
Value::Str(_)
| Value::StringObj(_)
| Value::Lambda(_)
| Value::Function(_)
| Value::Tuple(_)
| Value::Cell(_)
| Value::Struct(_)
| Value::StructArray(_)
| Value::DateTime(_)
| Value::Duration(_)
| Value::DateTimeArray(_)
| Value::DurationArray(_)
| Value::Map(_) => Err("inv: not applicable to non-numeric values".to_string()),
},
("linspace", 3) => {
let a = scalar_arg(&args[0], name, 1)?;
let b = scalar_arg(&args[1], name, 2)?;
let n = scalar_arg(&args[2], name, 3)? as usize;
if n == 0 {
return Ok(Value::Matrix(Box::new(Array2::zeros((1, 0)))));
}
if n == 1 {
return Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, 1), vec![b]).unwrap(),
)));
}
let vals: Vec<f64> = (0..n)
.map(|i| a + (b - a) * i as f64 / (n - 1) as f64)
.collect();
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, n), vals).unwrap(),
)))
}
("bitand", 2) => {
let a = to_bits(scalar_arg(&args[0], name, 1)?, name, 1)?;
let b = to_bits(scalar_arg(&args[1], name, 2)?, name, 2)?;
Ok(Value::Scalar((a & b) as f64))
}
("bitor", 2) => {
let a = to_bits(scalar_arg(&args[0], name, 1)?, name, 1)?;
let b = to_bits(scalar_arg(&args[1], name, 2)?, name, 2)?;
Ok(Value::Scalar((a | b) as f64))
}
("bitxor", 2) => {
let a = to_bits(scalar_arg(&args[0], name, 1)?, name, 1)?;
let b = to_bits(scalar_arg(&args[1], name, 2)?, name, 2)?;
Ok(Value::Scalar((a ^ b) as f64))
}
("bitshift", 2) => {
let a = to_bits(scalar_arg(&args[0], name, 1)?, name, 1)?;
let n = scalar_arg(&args[1], name, 2)?;
if n.fract() != 0.0 {
return Err("bitshift: shift amount must be an integer".to_string());
}
let n = n as i64;
let result: u64 = if n >= 64 || n <= -64 {
0
} else if n >= 0 {
a.wrapping_shl(n as u32)
} else {
a.wrapping_shr((-n) as u32)
};
Ok(Value::Scalar(result as f64))
}
("bitnot", 1) => {
let a = to_bits(scalar_arg(&args[0], name, 1)?, name, 1)?;
let mask: u64 = 0xFFFF_FFFF;
Ok(Value::Scalar(((a ^ mask) & mask) as f64))
}
("bitnot", 2) => {
let a = to_bits(scalar_arg(&args[0], name, 1)?, name, 1)?;
let bits = scalar_arg(&args[1], name, 2)?;
if bits.fract() != 0.0 || !(1.0..=53.0).contains(&bits) {
return Err(format!(
"bitnot: bit-width must be an integer in [1, 53], got {bits}"
));
}
let mask: u64 = (1u64 << bits as u32) - 1;
Ok(Value::Scalar(((a ^ mask) & mask) as f64))
}
("isnan", 1) => apply_elem(&args[0], |x| if x.is_nan() { 1.0 } else { 0.0 }),
("isinf", 1) => apply_elem(&args[0], |x| if x.is_infinite() { 1.0 } else { 0.0 }),
("isfinite", 1) => apply_elem(&args[0], |x| if x.is_finite() { 1.0 } else { 0.0 }),
("nan", 1) => {
let (r, c) = size_arg(&args[0], name)?;
Ok(Value::Matrix(Box::new(Array2::from_elem((r, c), f64::NAN))))
}
("nan", 2) => {
let r = scalar_arg(&args[0], name, 1)? as usize;
let c = scalar_arg(&args[1], name, 2)? as usize;
Ok(Value::Matrix(Box::new(Array2::from_elem((r, c), f64::NAN))))
}
("rand", 0) => Ok(Value::Scalar(rand_uniform())),
("rand", 1) => {
let (r, c) = size_arg(&args[0], name)?;
let data: Vec<f64> = (0..r * c).map(|_| rand_uniform()).collect();
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((r, c), data).unwrap(),
)))
}
("rand", 2) => {
let r = scalar_arg(&args[0], name, 1)? as usize;
let c = scalar_arg(&args[1], name, 2)? as usize;
let data: Vec<f64> = (0..r * c).map(|_| rand_uniform()).collect();
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((r, c), data).unwrap(),
)))
}
("randn", 0) => Ok(Value::Scalar(rand_normal())),
("randn", 1) => {
let (r, c) = size_arg(&args[0], name)?;
let data: Vec<f64> = (0..r * c).map(|_| rand_normal()).collect();
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((r, c), data).unwrap(),
)))
}
("randn", 2) => {
let r = scalar_arg(&args[0], name, 1)? as usize;
let c = scalar_arg(&args[1], name, 2)? as usize;
let data: Vec<f64> = (0..r * c).map(|_| rand_normal()).collect();
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((r, c), data).unwrap(),
)))
}
("randi", 1) => {
let (lo, hi) = randi_range(&args[0])?;
let v = RNG.with(|r| r.borrow_mut().gen_range(lo..=hi)) as f64;
Ok(Value::Scalar(v))
}
("randi", 2) => {
let (lo, hi) = randi_range(&args[0])?;
let n = scalar_arg(&args[1], name, 2)? as usize;
let data: Vec<f64> = (0..n * n)
.map(|_| RNG.with(|r| r.borrow_mut().gen_range(lo..=hi)) as f64)
.collect();
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((n, n), data).unwrap(),
)))
}
("randi", 3) => {
let (lo, hi) = randi_range(&args[0])?;
let r = scalar_arg(&args[1], name, 2)? as usize;
let c = scalar_arg(&args[2], name, 3)? as usize;
let data: Vec<f64> = (0..r * c)
.map(|_| RNG.with(|rng| rng.borrow_mut().gen_range(lo..=hi)) as f64)
.collect();
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((r, c), data).unwrap(),
)))
}
("rng", 1) => match &args[0] {
Value::Scalar(n) => {
rng_seed(*n as u64);
Ok(Value::Void)
}
Value::Str(s) | Value::StringObj(s) if s == "shuffle" => {
rng_shuffle();
Ok(Value::Void)
}
_ => Err("rng: argument must be a numeric seed or 'shuffle'".to_string()),
},
("sum", 1) => {
if matches!(&args[0], Value::Complex(_, _) | Value::ComplexMatrix(_)) {
apply_cm_reduction(&args[0], |v| v.iter().copied().sum())
} else {
apply_reduction(&args[0], |v| v.iter().copied().sum())
}
}
("prod", 1) => {
if matches!(&args[0], Value::Complex(_, _) | Value::ComplexMatrix(_)) {
apply_cm_reduction(&args[0], |v| v.iter().copied().product())
} else {
apply_reduction(&args[0], |v| v.iter().copied().product())
}
}
("any", 1) => apply_reduction(&args[0], |v| {
if v.iter().any(|&x| x != 0.0) {
1.0
} else {
0.0
}
}),
("all", 1) => apply_reduction(&args[0], |v| {
if v.iter().all(|&x| x != 0.0) {
1.0
} else {
0.0
}
}),
("mean", 1) => {
if matches!(&args[0], Value::Complex(_, _) | Value::ComplexMatrix(_)) {
apply_cm_reduction(&args[0], |v| {
if v.is_empty() {
Complex::new(f64::NAN, 0.0)
} else {
v.iter().copied().sum::<Complex<f64>>() / v.len() as f64
}
})
} else {
apply_reduction(&args[0], |v| {
if v.is_empty() {
f64::NAN
} else {
v.iter().copied().sum::<f64>() / v.len() as f64
}
})
}
}
("min", 1) => apply_reduction(&args[0], |v| {
v.iter().copied().fold(f64::INFINITY, f64::min)
}),
("max", 1) => apply_reduction(&args[0], |v| {
v.iter().copied().fold(f64::NEG_INFINITY, f64::max)
}),
("norm", 1) => match &args[0] {
Value::Void => Err("norm: not applicable to void".to_string()),
Value::Scalar(n) => Ok(Value::Scalar(n.abs())),
Value::Complex(re, im) => Ok(Value::Scalar((re * re + im * im).sqrt())),
Value::Matrix(m) => {
if m.nrows() <= 1 || m.ncols() <= 1 {
Ok(Value::Scalar(m.iter().map(|x| x * x).sum::<f64>().sqrt()))
} else {
let (_, s, _) = svd_compute(m)?;
Ok(Value::Scalar(s.first().copied().unwrap_or(0.0)))
}
}
Value::ComplexMatrix(m) => Ok(Value::Scalar(
m.iter().map(|c| c.norm_sqr()).sum::<f64>().sqrt(),
)),
Value::Str(_)
| Value::StringObj(_)
| Value::Lambda(_)
| Value::Function(_)
| Value::Tuple(_)
| Value::Cell(_)
| Value::Struct(_)
| Value::StructArray(_)
| Value::DateTime(_)
| Value::Duration(_)
| Value::DateTimeArray(_)
| Value::DurationArray(_)
| Value::Map(_) => Err("norm: not applicable to non-numeric values".to_string()),
},
("norm", 2) => match &args[1] {
Value::Str(s) | Value::StringObj(s) => match s.as_str() {
"fro" => match &args[0] {
Value::Scalar(n) => Ok(Value::Scalar(n.abs())),
Value::Matrix(m) => {
Ok(Value::Scalar(m.iter().map(|x| x * x).sum::<f64>().sqrt()))
}
_ => Err("norm: first argument must be numeric".to_string()),
},
other => Err(format!("norm: unknown norm type '{other}'")),
},
_ => {
let p = scalar_arg(&args[1], name, 2)?;
match &args[0] {
Value::Void => Err("norm: not applicable to void".to_string()),
Value::Scalar(n) => Ok(Value::Scalar(n.abs())),
Value::Complex(re, im) => Ok(Value::Scalar((re * re + im * im).sqrt().powf(p))),
Value::Matrix(m) => {
if m.nrows() > 1 && m.ncols() > 1 {
if (p - 2.0).abs() < 1e-15 {
let (_, s, _) = svd_compute(m)?;
return Ok(Value::Scalar(s.first().copied().unwrap_or(0.0)));
} else if (p - 1.0).abs() < 1e-15 {
let v = (0..m.ncols())
.map(|j| m.column(j).iter().map(|&x| x.abs()).sum::<f64>())
.fold(0.0_f64, f64::max);
return Ok(Value::Scalar(v));
} else if p == f64::INFINITY {
let v = (0..m.nrows())
.map(|i| m.row(i).iter().map(|&x| x.abs()).sum::<f64>())
.fold(0.0_f64, f64::max);
return Ok(Value::Scalar(v));
}
}
if p == f64::INFINITY {
Ok(Value::Scalar(
m.iter().copied().fold(0.0_f64, |acc, x| acc.max(x.abs())),
))
} else {
Ok(Value::Scalar(
m.iter().map(|x| x.abs().powf(p)).sum::<f64>().powf(1.0 / p),
))
}
}
Value::ComplexMatrix(m) => Ok(Value::Scalar(
m.iter().map(|c| c.norm_sqr()).sum::<f64>().sqrt().powf(p),
)),
Value::Str(_)
| Value::StringObj(_)
| Value::Lambda(_)
| Value::Function(_)
| Value::Tuple(_)
| Value::Cell(_)
| Value::Struct(_)
| Value::StructArray(_)
| Value::DateTime(_)
| Value::Duration(_)
| Value::DateTimeArray(_)
| Value::DurationArray(_)
| Value::Map(_) => {
Err("norm: not applicable to non-numeric values".to_string())
}
}
}
},
("cumsum", 1) => apply_cumulative(&args[0], |acc, x| acc + x),
("cumprod", 1) => apply_cumulative(&args[0], |acc, x| acc * x),
("sort", 1) => match &args[0] {
Value::Void => Err("sort: not applicable to void".to_string()),
Value::Scalar(n) => Ok(Value::Scalar(*n)),
Value::Complex(_, _) => Err("sort: not applicable to complex values".to_string()),
Value::ComplexMatrix(_) => Err("sort: not applicable to complex values".to_string()),
Value::Str(_)
| Value::StringObj(_)
| Value::Lambda(_)
| Value::Function(_)
| Value::Tuple(_)
| Value::Cell(_)
| Value::Struct(_)
| Value::StructArray(_)
| Value::DateTime(_)
| Value::Duration(_)
| Value::DateTimeArray(_)
| Value::DurationArray(_)
| Value::Map(_) => Err("sort: not applicable to non-numeric values".to_string()),
Value::Matrix(m) => {
if m.nrows() > 1 && m.ncols() > 1 {
return Err("sort: input must be a vector".to_string());
}
let mut vals: Vec<f64> = m.iter().copied().collect();
vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec(m.raw_dim(), vals).unwrap(),
)))
}
},
("reshape", 3) => {
let r = scalar_arg(&args[1], name, 2)? as usize;
let c = scalar_arg(&args[2], name, 3)? as usize;
match &args[0] {
Value::Void => Err("reshape: not applicable to void".to_string()),
Value::Scalar(n) => {
if r * c != 1 {
return Err(format!("reshape: cannot reshape 1 element into {r}x{c}"));
}
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, 1), vec![*n]).unwrap(),
)))
}
Value::Complex(_, _) => {
Err("reshape: not applicable to complex values".to_string())
}
Value::ComplexMatrix(_) => {
Err("reshape: not supported for complex matrices".to_string())
}
Value::Str(_)
| Value::StringObj(_)
| Value::Lambda(_)
| Value::Function(_)
| Value::Tuple(_)
| Value::Cell(_)
| Value::Struct(_)
| Value::StructArray(_)
| Value::DateTime(_)
| Value::Duration(_)
| Value::DateTimeArray(_)
| Value::DurationArray(_)
| Value::Map(_) => Err("reshape: not applicable to non-numeric values".to_string()),
Value::Matrix(m) => {
let total = m.len();
if r * c != total {
return Err(format!(
"reshape: cannot reshape {total} elements into {r}x{c}"
));
}
let flat: Vec<f64> = (0..m.ncols())
.flat_map(|col| (0..m.nrows()).map(move |row| m[[row, col]]))
.collect();
let mut result = Array2::<f64>::zeros((r, c));
for (i, &v) in flat.iter().enumerate() {
result[[i % r, i / r]] = v;
}
Ok(Value::Matrix(Box::new(result)))
}
}
}
("fliplr", 1) => match &args[0] {
Value::Void => Err(format!("{name}: not applicable to void")),
Value::Scalar(n) => Ok(Value::Scalar(*n)),
Value::Complex(re, im) => Ok(Value::Complex(*re, *im)),
Value::ComplexMatrix(_) => Err(format!("{name}: not supported for complex matrices")),
Value::Str(_)
| Value::StringObj(_)
| Value::Lambda(_)
| Value::Function(_)
| Value::Tuple(_)
| Value::Cell(_)
| Value::Struct(_)
| Value::StructArray(_)
| Value::DateTime(_)
| Value::Duration(_)
| Value::DateTimeArray(_)
| Value::DurationArray(_)
| Value::Map(_) => Err(format!("{name}: not applicable to non-numeric values")),
Value::Matrix(m) => {
let (nrows, ncols) = (m.nrows(), m.ncols());
let mut result = m.clone();
for r in 0..nrows {
for c in 0..ncols / 2 {
let tmp = result[[r, c]];
result[[r, c]] = result[[r, ncols - 1 - c]];
result[[r, ncols - 1 - c]] = tmp;
}
}
Ok(Value::Matrix(result))
}
},
("flipud", 1) => match &args[0] {
Value::Void => Err(format!("{name}: not applicable to void")),
Value::Scalar(n) => Ok(Value::Scalar(*n)),
Value::Complex(re, im) => Ok(Value::Complex(*re, *im)),
Value::ComplexMatrix(_) => Err(format!("{name}: not supported for complex matrices")),
Value::Str(_)
| Value::StringObj(_)
| Value::Lambda(_)
| Value::Function(_)
| Value::Tuple(_)
| Value::Cell(_)
| Value::Struct(_)
| Value::StructArray(_)
| Value::DateTime(_)
| Value::Duration(_)
| Value::DateTimeArray(_)
| Value::DurationArray(_)
| Value::Map(_) => Err(format!("{name}: not applicable to non-numeric values")),
Value::Matrix(m) => {
let (nrows, ncols) = (m.nrows(), m.ncols());
let mut result = m.clone();
for c in 0..ncols {
for r in 0..nrows / 2 {
let tmp = result[[r, c]];
result[[r, c]] = result[[nrows - 1 - r, c]];
result[[nrows - 1 - r, c]] = tmp;
}
}
Ok(Value::Matrix(result))
}
},
("find", 1) => find_nonzero(&args[0], usize::MAX),
("find", 2) => {
let k = scalar_arg(&args[1], name, 2)?;
if k < 0.0 {
return Err("find: k must be non-negative".to_string());
}
find_nonzero(&args[0], k as usize)
}
("unique", 1) => match &args[0] {
    Value::Void => Err("unique: not applicable to void".to_string()),
    Value::Scalar(n) => Ok(Value::Scalar(*n)),
    // Sorted distinct values, returned as a 1xN row vector.
    Value::Matrix(m) => {
        let mut vals: Vec<f64> = m.iter().copied().collect();
        // `total_cmp` is a true total order (NaNs sort last). The former comparator,
        // `partial_cmp(..).unwrap_or(Equal)`, is not once NaNs are present, and then
        // `sort_by` may return an arbitrary order or panic.
        vals.sort_unstable_by(f64::total_cmp);
        // `dedup` compares with `==`. Since NaN != NaN, every NaN stays a separate entry,
        // while -0.0 and +0.0 collapse into one value.
        vals.dedup();
        let n = vals.len();
        Ok(Value::Matrix(Box::new(
            Array2::from_shape_vec((1, n), vals).expect("vector length equals 1 * n"),
        )))
    }
    Value::Complex(_, _) => Err("unique: not applicable to complex values".to_string()),
    Value::ComplexMatrix(_) => Err("unique: not applicable to complex values".to_string()),
    Value::Str(_)
    | Value::StringObj(_)
    | Value::Lambda(_)
    | Value::Function(_)
    | Value::Tuple(_)
    | Value::Cell(_)
    | Value::Struct(_)
    | Value::StructArray(_)
    | Value::DateTime(_)
    | Value::Duration(_)
    | Value::DateTimeArray(_)
    | Value::DurationArray(_)
    | Value::Map(_) => Err("unique: not applicable to non-numeric values".to_string()),
},
// std(x): square root of stat_var_vec(.., population = false), applied through apply_stat.
("std", 1) => apply_stat(&args[0], |s| stat_var_vec(s, false).sqrt(), "std"),
// std(x, w): any nonzero w passes population = true to stat_var_vec.
("std", 2) => {
let w = scalar_arg(&args[1], name, 2)?;
let population = w != 0.0;
apply_stat(&args[0], |s| stat_var_vec(s, population).sqrt(), "std")
}
// var(x) / var(x, w): same normalization switch as std, without the square root.
("var", 1) => apply_stat(&args[0], |s| stat_var_vec(s, false), "var"),
("var", 2) => {
let w = scalar_arg(&args[1], name, 2)?;
let population = w != 0.0;
apply_stat(&args[0], |s| stat_var_vec(s, population), "var")
}
// cov(x): a scalar gives 0; a vector gives its variance (population = false).
// For a matrix, rows are observations and columns are variables. The result is the
// nvars x nvars covariance of the mean-centered columns, divided by (nobs - 1).
("cov", 1) => match &args[0] {
Value::Scalar(_) => Ok(Value::Scalar(0.0)),
Value::Matrix(m) => {
if m.nrows() == 1 || m.ncols() == 1 {
let vals: Vec<f64> = m.iter().copied().collect();
Ok(Value::Scalar(stat_var_vec(&vals, false)))
} else {
let (nobs, nvars) = (m.nrows(), m.ncols());
if nobs < 2 {
return Err("cov: need at least 2 observations".to_string());
}
let mut centered = m.clone();
for c in 0..nvars {
let col_mean: f64 = m.column(c).iter().sum::<f64>() / nobs as f64;
for r in 0..nobs {
centered[[r, c]] -= col_mean;
}
}
let denom = (nobs - 1) as f64;
let mut cov_mat = Array2::<f64>::zeros((nvars, nvars));
for i in 0..nvars {
for j in 0..nvars {
let dot: f64 =
(0..nobs).map(|r| centered[[r, i]] * centered[[r, j]]).sum();
cov_mat[[i, j]] = dot / denom;
}
}
Ok(Value::Matrix(Box::new(cov_mat)))
}
}
_ => Err("cov: argument must be numeric".to_string()),
},
("median", 1) => apply_stat(
    &args[0],
    |s| {
        // Empty input has no median, and any NaN propagates to the result
        // (as MATLAB's default 'includenan' does).
        if s.is_empty() || s.iter().any(|x| x.is_nan()) {
            return f64::NAN;
        }
        let mut v = s.to_vec();
        // `total_cmp` is a real total order. The old `partial_cmp(..).unwrap_or(Equal)`
        // was not once NaNs appeared, which lets `sort_by` misbehave or panic.
        v.sort_unstable_by(f64::total_cmp);
        let mid = v.len() / 2;
        if v.len() % 2 == 0 {
            // Even count: average the two middle order statistics.
            (v[mid - 1] + v[mid]) / 2.0
        } else {
            v[mid]
        }
    },
    "median",
),
("mode", 1) => apply_stat(
    &args[0],
    |s| {
        if s.is_empty() {
            return f64::NAN;
        }
        // Count occurrences keyed by the exact bit pattern, so NaNs with the same
        // payload are grouped together and -0.0 / +0.0 are counted separately.
        let mut counts: std::collections::HashMap<u64, usize> =
            std::collections::HashMap::new();
        for &x in s {
            *counts.entry(x.to_bits()).or_default() += 1;
        }
        let max_count = counts.values().copied().max().unwrap_or(0);
        // On ties, return the smallest of the most frequent values. `total_cmp` makes
        // the choice deterministic and never violates the comparator contract; the
        // old `partial_cmp(..).unwrap_or(Equal)` sort did not guarantee either.
        counts
            .iter()
            .filter(|&(_, &c)| c == max_count)
            .map(|(&bits, _)| f64::from_bits(bits))
            .min_by(f64::total_cmp)
            .unwrap_or(f64::NAN)
    },
    "mode",
),
// skewness(x): m3 / m2^1.5, using biased (divide-by-n) central moments.
// Empty input gives NaN, a single sample gives 0.0, and zero spread gives NaN.
// NOTE(review): kurtosis below returns NaN for n < 2, while this returns 0.0 for n == 1.
// Confirm which is intended.
("skewness", 1) => apply_stat(
&args[0],
|s| {
let n = s.len();
if n == 0 {
return f64::NAN;
}
if n == 1 {
return 0.0;
}
let mean = s.iter().sum::<f64>() / n as f64;
let m2 = s.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
if m2 == 0.0 {
return f64::NAN;
}
let m3 = s.iter().map(|&x| (x - mean).powi(3)).sum::<f64>() / n as f64;
m3 / m2.powf(1.5)
},
"skewness",
),
// kurtosis(x): m4 / m2^2 with biased central moments. No 3 is subtracted, so this is
// plain (not excess) kurtosis. Returns NaN for n < 2 or zero spread.
("kurtosis", 1) => apply_stat(
&args[0],
|s| {
let n = s.len();
if n < 2 {
return f64::NAN;
}
let mean = s.iter().sum::<f64>() / n as f64;
let m2 = s.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
if m2 == 0.0 {
return f64::NAN;
}
let m4 = s.iter().map(|&x| (x - mean).powi(4)).sum::<f64>() / n as f64;
m4 / m2.powi(2)
},
"kurtosis",
),
("histc", 2) => {
    let vals = numeric_vec(&args[0], name)?;
    let edges = numeric_vec(&args[1], name)?;
    // The final edge forms its own bin, which counts only exact matches.
    let Some(&top_edge) = edges.last() else {
        return Err("histc: edges must not be empty".to_string());
    };
    let top_bin = edges.len() - 1;
    let mut counts = vec![0.0f64; edges.len()];
    for &v in &vals {
        // Bin i is the half-open interval [edges[i], edges[i + 1]) and the first match
        // wins. Values that land in no bin are not counted.
        let bin = if v == top_edge {
            Some(top_bin)
        } else {
            edges.windows(2).position(|pair| v >= pair[0] && v < pair[1])
        };
        if let Some(i) = bin {
            counts[i] += 1.0;
        }
    }
    let width = counts.len();
    Ok(Value::Matrix(Box::new(
        Array2::from_shape_vec((1, width), counts).unwrap(),
    )))
}
("prctile", 2) => {
    let p_vals = numeric_vec(&args[1], name)?;
    let n_p = p_vals.len();
    // Sort one sample, then evaluate every requested percentile on it.
    // `total_cmp` is a total order (NaNs sort to the end). The former
    // `partial_cmp(..).unwrap_or(Equal)` comparator is not, which can make the
    // sort misbehave or panic on NaN input.
    let compute_col = |vals: &[f64]| -> Vec<f64> {
        let mut s = vals.to_vec();
        s.sort_unstable_by(f64::total_cmp);
        p_vals.iter().map(|&p| percentile_sorted(&s, p)).collect()
    };
    // Results for a single sample: one percentile becomes a scalar, several
    // (or none) become a 1 x n_p row vector.
    let pack_row = |pr: Vec<f64>| -> Value {
        if n_p == 1 {
            Value::Scalar(pr[0])
        } else {
            Value::Matrix(Box::new(Array2::from_shape_vec((1, n_p), pr).unwrap()))
        }
    };
    match &args[0] {
        Value::Scalar(n) => Ok(pack_row(compute_col(&[*n]))),
        Value::Matrix(m) if m.nrows() == 1 || m.ncols() == 1 => {
            let vals: Vec<f64> = m.iter().copied().collect();
            Ok(pack_row(compute_col(&vals)))
        }
        // Non-vector matrix: each column is a separate sample. The result is
        // n_p x ncols, which is already 1 x ncols when one percentile is requested.
        Value::Matrix(m) => {
            let ncols = m.ncols();
            let mut result = Array2::<f64>::zeros((n_p, ncols));
            for j in 0..ncols {
                let col: Vec<f64> = m.column(j).iter().copied().collect();
                for (i, v) in compute_col(&col).into_iter().enumerate() {
                    result[[i, j]] = v;
                }
            }
            Ok(Value::Matrix(Box::new(result)))
        }
        _ => Err("prctile: first argument must be numeric".to_string()),
    }
}
("iqr", 1) => apply_stat(
    &args[0],
    |s| {
        // Interquartile range: 75th minus 25th percentile of the sorted sample.
        // `total_cmp` is a total order (NaNs sort last). The former
        // `partial_cmp(..).unwrap_or(Equal)` is not, which can break the sort on NaN input.
        let mut sorted = s.to_vec();
        sorted.sort_unstable_by(f64::total_cmp);
        percentile_sorted(&sorted, 75.0) - percentile_sorted(&sorted, 25.0)
    },
    "iqr",
),
// zscore(x): standardize as (x - mean) / s, where s = sqrt(stat_var_vec(.., false)).
// A scalar yields 0. Vectors are standardized as a whole and matrices column by column.
// Elements with zero spread (s == 0) become 0 instead of dividing by zero.
("zscore", 1) => match &args[0] {
Value::Scalar(_) => Ok(Value::Scalar(0.0)),
Value::Matrix(m) => {
if m.nrows() == 1 || m.ncols() == 1 {
let vals: Vec<f64> = m.iter().copied().collect();
let n = vals.len() as f64;
let mean = vals.iter().sum::<f64>() / n;
let s = stat_var_vec(&vals, false).sqrt();
let result: Vec<f64> = vals
.iter()
.map(|&x| if s == 0.0 { 0.0 } else { (x - mean) / s })
.collect();
// For a vector, row-major order equals linear order, so the input shape can be reused.
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec(m.raw_dim(), result).unwrap(),
)))
} else {
let (nrows, ncols) = (m.nrows(), m.ncols());
let mut result = m.clone();
for j in 0..ncols {
let col: Vec<f64> = m.column(j).iter().copied().collect();
let mean = col.iter().sum::<f64>() / col.len() as f64;
let s = stat_var_vec(&col, false).sqrt();
for i in 0..nrows {
result[[i, j]] = if s == 0.0 {
0.0
} else {
(m[[i, j]] - mean) / s
};
}
}
Ok(Value::Matrix(result))
}
}
_ => Err("zscore: argument must be numeric".to_string()),
},
// diag(v): a vector (real or complex) becomes a square matrix with v on its main diagonal.
// diag(A): a non-vector matrix yields its main diagonal, min(rows, cols) long, as a column vector.
// A real or complex scalar becomes a 1x1 matrix.
("diag", 1) => match &args[0] {
Value::Scalar(n) => Ok(Value::Matrix(Box::new(Array2::from_elem((1, 1), *n)))),
Value::Matrix(m) => {
let (rows, cols) = (m.nrows(), m.ncols());
if rows == 1 || cols == 1 {
let v: Vec<f64> = m.iter().copied().collect();
let n = v.len();
let mut result = Array2::<f64>::zeros((n, n));
for (i, &val) in v.iter().enumerate() {
result[[i, i]] = val;
}
Ok(Value::Matrix(Box::new(result)))
} else {
let n = rows.min(cols);
let d: Vec<f64> = (0..n).map(|i| m[[i, i]]).collect();
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((n, 1), d).unwrap(),
)))
}
}
Value::Void => Err("diag: not applicable to void".to_string()),
Value::Complex(re, im) => {
let mut result = Array2::<Complex<f64>>::zeros((1, 1));
result[[0, 0]] = Complex::new(*re, *im);
Ok(Value::ComplexMatrix(Box::new(result)))
}
Value::ComplexMatrix(m) => {
let (rows, cols) = (m.nrows(), m.ncols());
if rows == 1 || cols == 1 {
let v: Vec<Complex<f64>> = m.iter().copied().collect();
let n = v.len();
let mut result = Array2::<Complex<f64>>::zeros((n, n));
for (i, &val) in v.iter().enumerate() {
result[[i, i]] = val;
}
Ok(Value::ComplexMatrix(Box::new(result)))
} else {
let n = rows.min(cols);
let d: Vec<Complex<f64>> = (0..n).map(|i| m[[i, i]]).collect();
Ok(Value::ComplexMatrix(Box::new(
Array2::from_shape_vec((n, 1), d).unwrap(),
)))
}
}
Value::Str(_)
| Value::StringObj(_)
| Value::Lambda(_)
| Value::Function(_)
| Value::Tuple(_)
| Value::Cell(_)
| Value::Struct(_)
| Value::StructArray(_)
| Value::DateTime(_)
| Value::Duration(_)
| Value::DateTimeArray(_)
| Value::DurationArray(_)
| Value::Map(_) => Err("diag: not applicable to non-numeric values".to_string()),
},
// real(x): real part; real inputs pass through unchanged and complex inputs become real values.
("real", 1) => match &args[0] {
Value::Void => Err("real: not applicable to void".to_string()),
Value::Scalar(n) => Ok(Value::Scalar(*n)),
Value::Complex(re, _) => Ok(Value::Scalar(*re)),
Value::Matrix(m) => Ok(Value::Matrix(m.clone())),
Value::ComplexMatrix(m) => Ok(Value::Matrix(Box::new(m.mapv(|c| c.re)))),
Value::Str(_)
| Value::StringObj(_)
| Value::Lambda(_)
| Value::Function(_)
| Value::Tuple(_)
| Value::Cell(_)
| Value::Struct(_)
| Value::StructArray(_)
| Value::DateTime(_)
| Value::Duration(_)
| Value::DateTimeArray(_)
| Value::DurationArray(_)
| Value::Map(_) => Err("real: not applicable to non-numeric values".to_string()),
},
// imag(x): imaginary part; real inputs yield 0 (a zero matrix of the same shape for matrices).
("imag", 1) => match &args[0] {
Value::Void => Err("imag: not applicable to void".to_string()),
Value::Scalar(_) => Ok(Value::Scalar(0.0)),
Value::Complex(_, im) => Ok(Value::Scalar(*im)),
Value::Matrix(m) => Ok(Value::Matrix(Box::new(Array2::zeros(m.raw_dim())))),
Value::ComplexMatrix(m) => Ok(Value::Matrix(Box::new(m.mapv(|c| c.im)))),
Value::Str(_)
| Value::StringObj(_)
| Value::Lambda(_)
| Value::Function(_)
| Value::Tuple(_)
| Value::Cell(_)
| Value::Struct(_)
| Value::StructArray(_)
| Value::DateTime(_)
| Value::Duration(_)
| Value::DateTimeArray(_)
| Value::DurationArray(_)
| Value::Map(_) => Err("imag: not applicable to non-numeric values".to_string()),
},
// abs(x): absolute value for reals, modulus |z| for complex values.
("abs", 1) => match &args[0] {
    Value::Void => Err("abs: not applicable to void".to_string()),
    Value::Scalar(n) => Ok(Value::Scalar(n.abs())),
    // `hypot` avoids the overflow of sqrt(re*re + im*im), which turns into inf once
    // |re| or |im| exceeds ~1e154. It also matches `Complex::norm` in the matrix branch.
    Value::Complex(re, im) => Ok(Value::Scalar(re.hypot(*im))),
    Value::Matrix(m) => Ok(Value::Matrix(Box::new(m.mapv(|x| x.abs())))),
    Value::ComplexMatrix(m) => Ok(Value::Matrix(Box::new(m.mapv(|c| c.norm())))),
    Value::Str(_)
    | Value::StringObj(_)
    | Value::Lambda(_)
    | Value::Function(_)
    | Value::Tuple(_)
    | Value::Cell(_)
    | Value::Struct(_)
    | Value::StructArray(_)
    | Value::DateTime(_)
    | Value::Duration(_)
    | Value::DateTimeArray(_)
    | Value::DurationArray(_)
    | Value::Map(_) => Err("abs: not applicable to non-numeric values".to_string()),
},
("angle", 1) => {
    // Phase angle of a real number: 0 for x >= 0, pi for x < 0, and NaN for NaN
    // (the value atan2(0, NaN) gives). Previously NaN fell through to pi.
    let real_angle = |x: f64| -> f64 {
        if x.is_nan() {
            f64::NAN
        } else if x >= 0.0 {
            0.0
        } else {
            std::f64::consts::PI
        }
    };
    match &args[0] {
        Value::Void => Err("angle: not applicable to void".to_string()),
        Value::Scalar(n) => Ok(Value::Scalar(real_angle(*n))),
        Value::Complex(re, im) => Ok(Value::Scalar(im.atan2(*re))),
        Value::Matrix(m) => Ok(Value::Matrix(Box::new(m.mapv(real_angle)))),
        Value::ComplexMatrix(m) => Ok(Value::Matrix(Box::new(m.mapv(|c| c.im.atan2(c.re))))),
        Value::Str(_)
        | Value::StringObj(_)
        | Value::Lambda(_)
        | Value::Function(_)
        | Value::Tuple(_)
        | Value::Cell(_)
        | Value::Struct(_)
        | Value::StructArray(_)
        | Value::DateTime(_)
        | Value::Duration(_)
        | Value::DateTimeArray(_)
        | Value::DurationArray(_)
        | Value::Map(_) => Err("angle: not applicable to non-numeric values".to_string()),
    }
}
// conj(x): complex conjugate; real inputs pass through. The scalar result is built via make_complex.
("conj", 1) => match &args[0] {
Value::Void => Err("conj: not applicable to void".to_string()),
Value::Scalar(n) => Ok(Value::Scalar(*n)),
Value::Complex(re, im) => Ok(make_complex(*re, -*im)),
Value::Matrix(m) => Ok(Value::Matrix(m.clone())),
Value::ComplexMatrix(m) => Ok(Value::ComplexMatrix(Box::new(m.mapv(|c| c.conj())))),
Value::Str(_)
| Value::StringObj(_)
| Value::Lambda(_)
| Value::Function(_)
| Value::Tuple(_)
| Value::Cell(_)
| Value::Struct(_)
| Value::StructArray(_)
| Value::DateTime(_)
| Value::Duration(_)
| Value::DateTimeArray(_)
| Value::DurationArray(_)
| Value::Map(_) => Err("conj: not applicable to non-numeric values".to_string()),
},
// complex(a, b): builds a + b*i from two scalar arguments via make_complex.
("complex", 2) => {
let re = scalar_arg(&args[0], name, 1)?;
let im = scalar_arg(&args[1], name, 2)?;
Ok(make_complex(re, im))
}
// isreal(x): 1 for real scalars and matrices, and for Complex scalars whose imaginary
// part is exactly 0. Returns 0 for everything else (void, complex matrices, strings, containers).
("isreal", 1) => match &args[0] {
Value::Void => Ok(Value::Scalar(0.0)),
Value::Scalar(_) => Ok(Value::Scalar(1.0)),
Value::Complex(_, im) => Ok(Value::Scalar(if *im == 0.0 { 1.0 } else { 0.0 })),
Value::Matrix(_) => Ok(Value::Scalar(1.0)),
Value::ComplexMatrix(_) => Ok(Value::Scalar(0.0)),
Value::Str(_) | Value::StringObj(_) => Ok(Value::Scalar(0.0)),
Value::Lambda(_)
| Value::Function(_)
| Value::Tuple(_)
| Value::Cell(_)
| Value::Struct(_)
| Value::StructArray(_)
| Value::DateTime(_)
| Value::Duration(_)
| Value::DateTimeArray(_)
| Value::DurationArray(_)
| Value::Map(_) => Ok(Value::Scalar(0.0)),
},
// num2str(x): strings pass through (always returned as Str). Numbers are formatted with
// fmt_auto_sig(.., 5). Matrix elements are joined with single spaces in `m.iter()` order,
// i.e. row by row.
("num2str", 1) => match &args[0] {
Value::Void => Err("num2str: not applicable to void".to_string()),
Value::Str(s) => Ok(Value::Str(s.clone())),
Value::StringObj(s) => Ok(Value::Str(s.clone())),
Value::Scalar(n) => Ok(Value::Str(fmt_auto_sig(*n, 5))),
Value::Complex(re, im) => Ok(Value::Str(format_complex(*re, *im, &FormatMode::Short))),
Value::Matrix(m) => {
let s = m
.iter()
.map(|x| fmt_auto_sig(*x, 5))
.collect::<Vec<_>>()
.join(" ");
Ok(Value::Str(s))
}
Value::ComplexMatrix(_) => {
Err("num2str: not supported for complex matrices".to_string())
}
Value::Lambda(_)
| Value::Function(_)
| Value::Tuple(_)
| Value::Cell(_)
| Value::Struct(_)
| Value::StructArray(_)
| Value::DateTime(_)
| Value::Duration(_)
| Value::DateTimeArray(_)
| Value::DurationArray(_)
| Value::Map(_) => Err("num2str: not applicable to this type".to_string()),
},
// num2str(x, n): same as num2str(x), with n passed as the precision instead of 5.
// A negative n saturates to 0 through the `as usize` cast.
("num2str", 2) => {
let n = scalar_arg(&args[1], name, 2)? as usize;
match &args[0] {
Value::Void => Err("num2str: not applicable to void".to_string()),
Value::Str(s) => Ok(Value::Str(s.clone())),
Value::StringObj(s) => Ok(Value::Str(s.clone())),
Value::Scalar(v) => Ok(Value::Str(fmt_auto_sig(*v, n))),
Value::Complex(re, im) => {
Ok(Value::Str(format_complex(*re, *im, &FormatMode::Custom(n))))
}
Value::Matrix(m) => {
let s = m
.iter()
.map(|x| fmt_auto_sig(*x, n))
.collect::<Vec<_>>()
.join(" ");
Ok(Value::Str(s))
}
Value::ComplexMatrix(_) => {
Err("num2str: not supported for complex matrices".to_string())
}
Value::Lambda(_)
| Value::Function(_)
| Value::Tuple(_)
| Value::Cell(_)
| Value::Struct(_)
| Value::StructArray(_)
| Value::DateTime(_)
| Value::Duration(_)
| Value::DateTimeArray(_)
| Value::DurationArray(_)
| Value::Map(_) => Err("num2str: not applicable to this type".to_string()),
}
}
// str2double(s): parse the trimmed string as f64; unparsable input yields NaN rather than an error.
("str2double", 1) => {
let s = string_arg(&args[0], name, 1)?;
match s.trim().parse::<f64>() {
Ok(n) => Ok(Value::Scalar(n)),
Err(_) => Ok(Value::Scalar(f64::NAN)),
}
}
// str2num(s): parse the trimmed string as f64; unparsable input is an error.
("str2num", 1) => {
let s = string_arg(&args[0], name, 1)?;
s.trim()
.parse::<f64>()
.map(Value::Scalar)
.map_err(|_| format!("str2num: cannot convert '{}' to number", s.trim()))
}
// strcat(a, b, ...): concatenate. Char (Str) arguments lose trailing whitespace, while
// StringObj arguments are kept verbatim. If any argument is a StringObj, so is the result.
("strcat", n) if n >= 2 => {
let mut result = String::new();
let mut any_obj = false;
for (i, arg) in args.iter().enumerate() {
match arg {
Value::Str(s) => result.push_str(s.trim_end()),
Value::StringObj(s) => {
result.push_str(s);
any_obj = true;
}
_ => return Err(format!("strcat: argument {} must be a string", i + 1)),
}
}
if any_obj {
Ok(Value::StringObj(result))
} else {
Ok(Value::Str(result))
}
}
// ischar(x) / isstring(x): 1 if x is a Str / StringObj respectively, else 0.
("ischar", 1) => Ok(Value::Scalar(if matches!(&args[0], Value::Str(_)) {
1.0
} else {
0.0
})),
("isstring", 1) => Ok(Value::Scalar(if matches!(&args[0], Value::StringObj(_)) {
1.0
} else {
0.0
})),
// struct(name1, val1, ...): build a struct from name/value pairs. IndexMap keeps insertion
// order, and a repeated name overwrites the earlier value. This arm accepts any arity,
// so it must stay ahead of nothing that expects a specific "struct" arity.
("struct", _) => {
if !args.len().is_multiple_of(2) {
return Err(
"struct: requires an even number of arguments (name, value, ...)".to_string(),
);
}
let mut map = IndexMap::new();
for pair in args.chunks(2) {
let key = match &pair[0] {
Value::Str(s) | Value::StringObj(s) => s.clone(),
_ => return Err("struct: field names must be strings".to_string()),
};
map.insert(key, pair[1].clone());
}
Ok(Value::Struct(Box::new(map)))
}
// fieldnames(s): field names as a cell of Str. For a struct array the first element's
// keys are used, and an empty array gives an empty cell.
("fieldnames", 1) => match &args[0] {
Value::Struct(map) => {
let names: Vec<Value> = map.keys().map(|k| Value::Str(k.clone())).collect();
Ok(Value::Cell(Box::new(names)))
}
Value::StructArray(arr) => {
let names: Vec<Value> = arr
.first()
.map(|m| m.keys().map(|k| Value::Str(k.clone())).collect())
.unwrap_or_default();
Ok(Value::Cell(Box::new(names)))
}
_ => Err("fieldnames: argument must be a struct".to_string()),
},
// isfield(s, name): 1 if the struct (or the first element of a struct array) has the field.
// Non-struct first arguments give 0 rather than an error.
("isfield", 2) => {
let field = match &args[1] {
Value::Str(s) | Value::StringObj(s) => s.clone(),
_ => return Err("isfield: second argument must be a string".to_string()),
};
Ok(Value::Scalar(match &args[0] {
Value::Struct(map) if map.contains_key(&field) => 1.0,
Value::StructArray(arr) if arr.first().is_some_and(|m| m.contains_key(&field)) => {
1.0
}
_ => 0.0,
}))
}
// isKey(map, key): 1 if the containers.Map contains the string key.
("isKey", 2) => {
let key = match &args[1] {
Value::Str(s) | Value::StringObj(s) => s.clone(),
_ => return Err("isKey: second argument must be a string key".to_string()),
};
match &args[0] {
Value::Map(map) => Ok(Value::Scalar(if map.contains_key(&key) {
1.0
} else {
0.0
})),
_ => Err("isKey: first argument must be a containers.Map".to_string()),
}
}
// keys(map): all keys as a cell of Str, sorted.
("keys", 1) => match &args[0] {
Value::Map(map) => {
let mut sorted_keys: Vec<&String> = map.keys().collect();
sorted_keys.sort();
Ok(Value::Cell(Box::new(
sorted_keys
.into_iter()
.map(|k| Value::Str(k.clone()))
.collect(),
)))
}
_ => Err("keys: argument must be a containers.Map".to_string()),
},
// values(map): all values as a cell, ordered by sorted key so they line up with keys(map).
("values", 1) => match &args[0] {
Value::Map(map) => {
let mut pairs: Vec<(&String, &Value)> = map.iter().collect();
pairs.sort_by_key(|(k, _)| *k);
Ok(Value::Cell(Box::new(
pairs.into_iter().map(|(_, v)| v.clone()).collect(),
)))
}
_ => Err("values: argument must be a containers.Map".to_string()),
},
// rmfield(s, name): copy of s without the named field. It is an error if the field is
// missing (for struct arrays, missing from any element). shift_remove keeps the order
// of the remaining fields.
("rmfield", 2) => {
let field = match &args[1] {
Value::Str(s) | Value::StringObj(s) => s.clone(),
_ => return Err("rmfield: second argument must be a string".to_string()),
};
match &args[0] {
Value::Struct(map) => {
if !map.contains_key(&field) {
return Err(format!("rmfield: field '{field}' does not exist"));
}
let mut updated = map.clone();
updated.shift_remove(&field);
Ok(Value::Struct(updated))
}
Value::StructArray(arr) => {
let updated: Result<Vec<_>, _> = arr
.iter()
.map(|m| {
if !m.contains_key(&field) {
return Err(format!("rmfield: field '{field}' does not exist"));
}
let mut m2 = m.clone();
m2.shift_remove(&field);
Ok(m2)
})
.collect();
Ok(Value::StructArray(Box::new(updated?)))
}
_ => Err("rmfield: first argument must be a struct".to_string()),
}
}
// isstruct(x): 1 for a Struct or StructArray, else 0.
("isstruct", 1) => Ok(Value::Scalar(
if matches!(&args[0], Value::Struct(_) | Value::StructArray(_)) {
1.0
} else {
0.0
},
)),
// isempty(x): 1 when x holds no elements, else 0. Void counts as empty.
// Value kinds without a notion of emptiness (scalars, lambdas, ...) are never empty.
("isempty", 1) => {
    let empty = match &args[0] {
        Value::Matrix(m) => m.is_empty(),
        // Previously fell through to `false` even for 0xN complex matrices.
        Value::ComplexMatrix(m) => m.is_empty(),
        Value::Str(s) | Value::StringObj(s) => s.is_empty(),
        Value::Cell(v) => v.is_empty(),
        // An empty struct array has no elements either.
        Value::StructArray(arr) => arr.is_empty(),
        Value::Void => true,
        _ => false,
    };
    Ok(Value::Scalar(if empty { 1.0 } else { 0.0 }))
}
// iscell(x): 1 if x is a Cell, else 0.
("iscell", 1) => Ok(Value::Scalar(if matches!(&args[0], Value::Cell(_)) {
1.0
} else {
0.0
})),
// cell(n) / cell(m, n): a flat cell of n (resp. m*n) elements, each Scalar(0.0).
// Negative or NaN sizes saturate to 0 through `as usize`.
// NOTE(review): MATLAB's cell(n) is n-by-n; here it has n elements. Confirm this is intended.
("cell", 1) => {
let n = scalar_arg(&args[0], name, 1)? as usize;
Ok(Value::Cell(Box::new(vec![Value::Scalar(0.0); n])))
}
("cell", 2) => {
let m = scalar_arg(&args[0], name, 1)? as usize;
let n = scalar_arg(&args[1], name, 2)? as usize;
Ok(Value::Cell(Box::new(vec![Value::Scalar(0.0); m * n])))
}
// cellfun(f, c): apply f to every element of cell c, in order.
// If every result is a Scalar, they are packed into a 1xN row vector (1x0 for an empty cell).
// Otherwise the raw results are returned as a cell. The first error from f is propagated.
("cellfun", 2) => {
    // Borrow instead of deep-cloning the function and the whole cell: `args` is only
    // read here, so it can be borrowed alongside the mutable `io` reborrow.
    let f = &args[0];
    let Value::Cell(elems) = &args[1] else {
        return Err("cellfun: second argument must be a cell array".to_string());
    };
    let mut results = Vec::with_capacity(elems.len());
    for elem in elems.iter() {
        results.push(call_function_value(f, std::slice::from_ref(elem), io.as_deref_mut())?);
    }
    // `Some(values)` only if every result is a scalar.
    let scalars: Option<Vec<f64>> = results
        .iter()
        .map(|v| match v {
            Value::Scalar(n) => Some(*n),
            _ => None,
        })
        .collect();
    match scalars {
        Some(vals) => {
            let n = vals.len();
            Ok(Value::Matrix(Box::new(
                Array2::from_shape_vec((1, n), vals).unwrap(),
            )))
        }
        None => Ok(Value::Cell(Box::new(results))),
    }
}
// arrayfun(f, A): apply f to every element of A, which must return a scalar each time.
// The result has the same shape as A. A scalar A calls f once and returns its result unchanged.
("arrayfun", 2) => {
    let f = &args[0];
    match &args[1] {
        Value::Matrix(m) => {
            // Write each result straight into its (row, col) slot. The previous code pushed
            // results in column-major order into a flat Vec and rebuilt the matrix with
            // row-major `from_shape_vec`, which transposed/scrambled every non-vector result.
            let mut out = Array2::<f64>::zeros(m.raw_dim());
            // Calls are still made in column-major (MATLAB linear-index) order, so any
            // side effects of f happen in the same order as before.
            for col in 0..m.ncols() {
                for row in 0..m.nrows() {
                    let elem = Value::Scalar(m[[row, col]]);
                    let value = match call_function_value(f, &[elem], io.as_deref_mut())? {
                        Value::Scalar(n) => n,
                        _ => {
                            return Err("arrayfun: function must return a scalar".to_string());
                        }
                    };
                    out[[row, col]] = value;
                }
            }
            Ok(Value::Matrix(Box::new(out)))
        }
        Value::Scalar(n) => {
            let elem = Value::Scalar(*n);
            let result = call_function_value(f, &[elem], io.as_deref_mut())?;
            Ok(result)
        }
        _ => Err("arrayfun: second argument must be a numeric matrix or scalar".to_string()),
    }
}
// lower / upper / strtrim: Unicode case conversion and whitespace trimming.
// The string kind (Str vs StringObj) is preserved.
("lower", 1) => match &args[0] {
Value::Str(s) => Ok(Value::Str(s.to_lowercase())),
Value::StringObj(s) => Ok(Value::StringObj(s.to_lowercase())),
_ => Err("lower: argument must be a string".to_string()),
},
("upper", 1) => match &args[0] {
Value::Str(s) => Ok(Value::Str(s.to_uppercase())),
Value::StringObj(s) => Ok(Value::StringObj(s.to_uppercase())),
_ => Err("upper: argument must be a string".to_string()),
},
("strtrim", 1) => match &args[0] {
Value::Str(s) => Ok(Value::Str(s.trim().to_string())),
Value::StringObj(s) => Ok(Value::StringObj(s.trim().to_string())),
_ => Err("strtrim: argument must be a string".to_string()),
},
// strrep(s, old, new): replace every non-overlapping occurrence of `old` with `new`.
// The result keeps the kind of `s` (StringObj stays StringObj, anything else becomes Str).
("strrep", 3) => {
    let s = string_arg(&args[0], name, 1)?;
    let old = string_arg(&args[1], name, 2)?;
    let new = string_arg(&args[2], name, 3)?;
    // An empty pattern matches nothing, so the input is returned unchanged (as in MATLAB).
    // `str::replace` with "" would instead insert `new` before every character and at the end.
    let result = if old.is_empty() {
        s.to_string()
    } else {
        s.replace(old, new)
    };
    match &args[0] {
        Value::StringObj(_) => Ok(Value::StringObj(result)),
        _ => Ok(Value::Str(result)),
    }
}
// strcmp(a, b): 1 if the two strings are equal. Both arguments go through string_arg.
("strcmp", 2) => {
let a = string_arg(&args[0], name, 1)?;
let b = string_arg(&args[1], name, 2)?;
Ok(Value::Scalar(bool_to_f64(a == b)))
}
// strcmpi(a, b): same comparison after Unicode lowercasing both sides.
("strcmpi", 2) => {
let a = string_arg(&args[0], name, 1)?.to_lowercase();
let b = string_arg(&args[1], name, 2)?.to_lowercase();
Ok(Value::Scalar(bool_to_f64(a == b)))
}
// disp(x): strings print verbatim followed by a newline. Other values use
// format_value_full (block form followed by a blank line) and fall back to format_value.
// Output goes to fd 1 of the IoContext when one is supplied, otherwise to stdout.
("disp", 1) => {
use std::io::Write;
let mode = get_display_fmt();
let output = match &args[0] {
Value::Str(s) | Value::StringObj(s) => format!("{s}\n"),
v => match format_value_full(v, &mode) {
Some(block) => format!("{block}\n\n"),
None => format!("{}\n", format_value(v, get_display_base(), &mode)),
},
};
match io {
Some(ctx) => ctx.write_to_fd(1, &output)?,
None => {
print!("{output}");
// Every branch above appends '\n', so in practice stdout is always flushed here.
if output.contains('\n') {
std::io::stdout().flush().ok();
}
}
}
Ok(Value::Void)
}
// sprintf(fmt, ...): format the remaining arguments with format_printf and return the string.
("sprintf", n) if n >= 1 => {
let fmt = string_arg(&args[0], name, 1)?.to_string();
let result = format_printf(&fmt, &args[1..])?;
Ok(Value::Str(result))
}
// fprintf([fd,] fmt, ...): a leading scalar selects the file descriptor (default 1).
// Without an IoContext only fd 1 (stdout) is supported; other fds are an error.
("fprintf", n) if n >= 1 => {
let (fd, fmt_idx) = match &args[0] {
Value::Scalar(n) => (*n as i32, 1),
_ => (1, 0),
};
if fmt_idx >= args.len() {
return Err("fprintf: missing format string".to_string());
}
let fmt = string_arg(&args[fmt_idx], name, fmt_idx + 1)?.to_string();
let output = format_printf(&fmt, &args[fmt_idx + 1..])?;
match io {
Some(ctx) => ctx.write_to_fd(fd, &output)?,
None => {
if fd == 1 {
use std::io::Write;
print!("{output}");
if output.contains('\n') {
std::io::stdout().flush().ok();
}
} else {
return Err("fprintf: file I/O not available in this context".to_string());
}
}
}
Ok(Value::Void)
}
// fopen(path, mode): delegates to IoContext::fopen; returns its result as a scalar.
("fopen", 2) => {
let path = string_arg(&args[0], name, 1)?;
let mode = string_arg(&args[1], name, 2)?;
match io {
Some(ctx) => Ok(Value::Scalar(ctx.fopen(path, mode) as f64)),
None => Err("fopen: file I/O not available in this context".to_string()),
}
}
// fclose('all') closes every handle (a no-op without an IoContext) and returns 0.
// fclose(fd) returns IoContext::fclose's result and errors without an IoContext.
("fclose", 1) => match &args[0] {
Value::Str(s) if s == "all" => {
if let Some(ctx) = io {
ctx.fclose_all();
}
Ok(Value::Scalar(0.0))
}
_ => {
let fd = scalar_arg(&args[0], name, 1)? as i32;
match io {
Some(ctx) => Ok(Value::Scalar(ctx.fclose(fd) as f64)),
None => Err("fclose: file I/O not available in this context".to_string()),
}
}
},
// fgetl(fd) / fgets(fd): next line from the IoContext, or -1 when it yields none.
// Any difference in newline handling lives in IoContext::fgetl vs fgets.
("fgetl", 1) => {
let fd = scalar_arg(&args[0], name, 1)? as i32;
match io {
Some(ctx) => match ctx.fgetl(fd) {
Some(line) => Ok(Value::Str(line)),
None => Ok(Value::Scalar(-1.0)),
},
None => Err("fgetl: file I/O not available in this context".to_string()),
}
}
("fgets", 1) => {
let fd = scalar_arg(&args[0], name, 1)? as i32;
match io {
Some(ctx) => match ctx.fgets(fd) {
Some(line) => Ok(Value::Str(line)),
None => Ok(Value::Scalar(-1.0)),
},
None => Err("fgets: file I/O not available in this context".to_string()),
}
}
// isfile(path) / isfolder(path): filesystem metadata checks. Paths whose metadata
// cannot be read count as false.
("isfile", 1) => {
let path = string_arg(&args[0], name, 1)?;
let is_file = std::fs::metadata(path)
.map(|m| m.is_file())
.unwrap_or(false);
Ok(Value::Scalar(bool_to_f64(is_file)))
}
("isfolder", 1) => {
let path = string_arg(&args[0], name, 1)?;
let is_dir = std::fs::metadata(path).map(|m| m.is_dir()).unwrap_or(false);
Ok(Value::Scalar(bool_to_f64(is_dir)))
}
// dir([path]): delegates to dir_impl; with no argument, lists ".".
("dir", _) => {
let path = if args.is_empty() {
"."
} else {
string_arg(&args[0], "dir", 1)?
};
Ok(dir_impl(path))
}
// genpath(root): root plus every subdirectory, in pre-order depth-first order with children
// visited in sorted order. Entries are joined with ';' on Windows and ':' elsewhere.
// Unreadable directories are listed but not descended into.
// NOTE(review): `is_dir` follows symlinks, so a symlink cycle would make this loop forever.
("genpath", 1) => {
let root = string_arg(&args[0], name, 1)?;
let sep = if cfg!(windows) { ';' } else { ':' };
let mut dirs: Vec<String> = Vec::new();
let mut stack = vec![std::path::PathBuf::from(root)];
while let Some(dir) = stack.pop() {
if !dir.is_dir() {
continue;
}
dirs.push(dir.to_string_lossy().into_owned());
if let Ok(entries) = std::fs::read_dir(&dir) {
let mut children: Vec<std::path::PathBuf> = entries
.filter_map(|e| e.ok())
.map(|e| e.path())
.filter(|p| p.is_dir())
.collect();
// Push in reverse sorted order so the stack pops the smallest child first.
children.sort();
children.reverse();
stack.extend(children);
}
}
Ok(Value::Str(dirs.join(&sep.to_string())))
}
// pwd: current working directory, or an empty string if it cannot be determined.
("pwd", _) => {
let cwd = std::env::current_dir()
.map(|p| p.to_string_lossy().into_owned())
.unwrap_or_default();
Ok(Value::Str(cwd))
}
// exist(name): 1 if name is a variable in env, else 2 if it is an existing file path, else 0.
("exist", 1) => {
let name_arg = string_arg(&args[0], name, 1)?;
if env.contains_key(name_arg) {
Ok(Value::Scalar(1.0))
} else if std::path::Path::new(name_arg).is_file() {
Ok(Value::Scalar(2.0))
} else {
Ok(Value::Scalar(0.0))
}
}
// exist(name, kind): the same check restricted to 'var' (returns 1) or 'file' (returns 2).
("exist", 2) => {
let name_arg = string_arg(&args[0], name, 1)?;
let kind = string_arg(&args[1], name, 2)?;
match kind {
"var" => Ok(Value::Scalar(if env.contains_key(name_arg) {
1.0
} else {
0.0
})),
"file" => Ok(Value::Scalar(if std::path::Path::new(name_arg).is_file() {
2.0
} else {
0.0
})),
other => Err(format!(
"exist: unknown type '{other}', expected 'var' or 'file'"
)),
}
}
// dlmread(path[, delim]) / dlmwrite(path, data[, delim]): thin wrappers over
// dlmread_impl / dlmwrite_impl. An explicit delimiter goes through interpret_delim first.
("dlmread", 1) => {
let path = string_arg(&args[0], name, 1)?.to_string();
dlmread_impl(&path, None)
}
("dlmread", 2) => {
let path = string_arg(&args[0], name, 1)?.to_string();
let delim = interpret_delim(string_arg(&args[1], name, 2)?);
dlmread_impl(&path, Some(delim))
}
("dlmwrite", 2) => {
let path = string_arg(&args[0], name, 1)?.to_string();
dlmwrite_impl(&path, &args[1], None)
}
("dlmwrite", 3) => {
let path = string_arg(&args[0], name, 1)?.to_string();
let delim = interpret_delim(string_arg(&args[2], name, 3)?);
dlmwrite_impl(&path, &args[1], Some(delim))
}
// readmatrix / readtable(path[, opt, val]) and writetable(table, path[, opt, val]):
// the optional trailing pair is parsed by parse_delimiter_opt, starting at the given index.
// (Presumably a 'Delimiter' name/value pair; see parse_delimiter_opt.)
("readmatrix", n) if n == 1 || n == 3 => {
let path = string_arg(&args[0], name, 1)?.to_string();
let delim = parse_delimiter_opt(name, args, 1)?;
readmatrix_impl(&path, delim)
}
("readtable", n) if n == 1 || n == 3 => {
let path = string_arg(&args[0], name, 1)?.to_string();
let delim = parse_delimiter_opt(name, args, 1)?;
readtable_impl(&path, delim)
}
("writetable", n) if n == 2 || n == 4 => {
let path = string_arg(&args[1], name, 2)?.to_string();
let delim = parse_delimiter_opt(name, args, 2)?;
writetable_impl(&args[0], &path, delim)
}
// xor(a, b): element-wise logical XOR, where nonzero means true. Handles scalar/matrix
// combinations. Two matrices must have identical shapes; there is no broadcasting.
("xor", 2) => {
let a = &args[0];
let b = &args[1];
match (a, b) {
(Value::Scalar(x), Value::Scalar(y)) => {
Ok(Value::Scalar(bool_to_f64((*x != 0.0) ^ (*y != 0.0))))
}
(Value::Matrix(mx), Value::Matrix(my)) => {
if mx.shape() != my.shape() {
return Err("xor: matrices must have the same dimensions".to_string());
}
Ok(Value::Matrix(Box::new(
ndarray::Zip::from(&**mx)
.and(&**my)
.map_collect(|a, b| bool_to_f64((*a != 0.0) ^ (*b != 0.0))),
)))
}
(Value::Scalar(s), Value::Matrix(m)) => {
let sv = *s != 0.0;
Ok(Value::Matrix(Box::new(
m.mapv(|x| bool_to_f64(sv ^ (x != 0.0))),
)))
}
(Value::Matrix(m), Value::Scalar(s)) => {
let sv = *s != 0.0;
Ok(Value::Matrix(Box::new(
m.mapv(|x| bool_to_f64((x != 0.0) ^ sv)),
)))
}
_ => Err("xor: arguments must be numeric".to_string()),
}
}
// not(x): 1 where x == 0, otherwise 0, applied element-wise via apply_elem.
("not", 1) => apply_elem(&args[0], |x| if x == 0.0 { 1.0 } else { 0.0 }),
// int2str(x): round to nearest and format as an integer. Matrix elements are space-joined
// in row-major order.
// NOTE(review): `as i64` saturates out-of-range values and maps NaN to 0.
("int2str", 1) => match &args[0] {
Value::Scalar(n) => Ok(Value::Str(format!("{}", n.round() as i64))),
Value::Matrix(m) => {
let parts: Vec<String> =
m.iter().map(|x| format!("{}", x.round() as i64)).collect();
Ok(Value::Str(parts.join(" ")))
}
_ => Err("int2str: argument must be numeric".to_string()),
},
// mat2str(x): literal text such as "[1 2;3 4]", using f64 Display formatting for elements.
// An empty matrix gives "[]".
("mat2str", 1) => match &args[0] {
Value::Scalar(n) => Ok(Value::Str(format!("{n}"))),
Value::Matrix(m) => {
if m.nrows() == 0 || m.ncols() == 0 {
return Ok(Value::Str("[]".to_string()));
}
let mut s = String::from("[");
for (r, row) in m.rows().into_iter().enumerate() {
if r > 0 {
s.push(';');
}
for (c, val) in row.iter().enumerate() {
if c > 0 {
s.push(' ');
}
s.push_str(&format!("{val}"));
}
}
s.push(']');
Ok(Value::Str(s))
}
_ => Err("mat2str: argument must be numeric".to_string()),
},
// strsplit(s, delim): split on the literal delimiter string, keeping empty fields.
("strsplit", 2) => {
let s = string_arg(&args[0], name, 1)?.to_string();
let delim = string_arg(&args[1], name, 2)?.to_string();
let parts: Vec<Value> = s
.split(delim.as_str())
.map(|p| Value::Str(p.to_string()))
.collect();
Ok(Value::Cell(Box::new(parts)))
}
// strsplit(s): split on runs of Unicode whitespace; leading and trailing whitespace produce no empty fields.
("strsplit", 1) => {
let s = string_arg(&args[0], name, 1)?.to_string();
let parts: Vec<Value> = s
.split_whitespace()
.map(|p| Value::Str(p.to_string()))
.collect();
Ok(Value::Cell(Box::new(parts)))
}
// strjoin(c[, delim]): join a cell of strings with delim (default " "); returns a Str.
("strjoin", n) if n == 1 || n == 2 => {
let cells = match &args[0] {
Value::Cell(v) => v,
_ => {
return Err(
"strjoin: first argument must be a cell array of strings".to_string()
);
}
};
let delim = if n == 2 {
string_arg(&args[1], name, 2)?.to_string()
} else {
" ".to_string()
};
let mut parts: Vec<String> = Vec::with_capacity(cells.len());
for (i, v) in cells.iter().enumerate() {
match v {
Value::Str(s) | Value::StringObj(s) => parts.push(s.clone()),
_ => return Err(format!("strjoin: element {} must be a string", i + 1)),
}
}
Ok(Value::Str(parts.join(&delim)))
}
// contains(s, pat): 1 if pat occurs as a substring of s.
("contains", 2) => {
let s = string_arg(&args[0], name, 1)?;
let pat = string_arg(&args[1], name, 2)?;
Ok(Value::Scalar(bool_to_f64(s.contains(pat))))
}
// contains(s, pat, 'IgnoreCase', flag): a nonzero scalar flag lowercases both sides before
// the test. The option name itself is matched case-sensitively.
("contains", 4) => {
let s = string_arg(&args[0], name, 1)?;
let pat = string_arg(&args[1], name, 2)?;
let key = string_arg(&args[2], name, 3)?;
if key != "IgnoreCase" {
return Err(format!(
"contains: unknown option '{key}'; expected 'IgnoreCase'"
));
}
let ignore = match &args[3] {
Value::Scalar(n) => *n != 0.0,
_ => return Err("contains: 'IgnoreCase' value must be a scalar".to_string()),
};
if ignore {
Ok(Value::Scalar(bool_to_f64(
s.to_lowercase().contains(&pat.to_lowercase()),
)))
} else {
Ok(Value::Scalar(bool_to_f64(s.contains(pat))))
}
}
// startsWith(s, pat) / endsWith(s, pat): case-sensitive prefix / suffix test.
("startsWith", 2) => {
let s = string_arg(&args[0], name, 1)?;
let pat = string_arg(&args[1], name, 2)?;
Ok(Value::Scalar(bool_to_f64(s.starts_with(pat))))
}
("endsWith", 2) => {
let s = string_arg(&args[0], name, 1)?;
let pat = string_arg(&args[1], name, 2)?;
Ok(Value::Scalar(bool_to_f64(s.ends_with(pat))))
}
// regexp / regexpi (s, pat[, 'match']): delegate to regexp_impl. Its 4th argument is true
// for regexpi (case-insensitive), and its 5th is true when the 'match' option is given.
// Any other option string is rejected.
("regexp", 2) => {
let s = string_arg(&args[0], name, 1)?.to_string();
let pat = string_arg(&args[1], name, 2)?.to_string();
regexp_impl("regexp", &s, &pat, false, false)
}
("regexp", 3) => {
let s = string_arg(&args[0], name, 1)?.to_string();
let pat = string_arg(&args[1], name, 2)?.to_string();
let opt = string_arg(&args[2], name, 3)?;
if opt != "match" {
return Err(format!("regexp: unknown option '{opt}'; expected 'match'"));
}
regexp_impl("regexp", &s, &pat, false, true)
}
("regexpi", 2) => {
let s = string_arg(&args[0], name, 1)?.to_string();
let pat = string_arg(&args[1], name, 2)?.to_string();
regexp_impl("regexpi", &s, &pat, true, false)
}
("regexpi", 3) => {
let s = string_arg(&args[0], name, 1)?.to_string();
let pat = string_arg(&args[1], name, 2)?.to_string();
let opt = string_arg(&args[2], name, 3)?;
if opt != "match" {
return Err(format!("regexpi: unknown option '{opt}'; expected 'match'"));
}
regexp_impl("regexpi", &s, &pat, true, true)
}
// regexprep(s, pat, rep): delegates to regexprep_impl.
("regexprep", 3) => {
let s = string_arg(&args[0], name, 1)?.to_string();
let pat = string_arg(&args[1], name, 2)?.to_string();
let rep = string_arg(&args[2], name, 3)?.to_string();
regexprep_impl(&s, &pat, &rep)
}
// error(fmt, args...): format the message printf-style and fail with it.
// A formatting failure is propagated in place of the message.
("error", _) if !args.is_empty() => {
    let (Value::Str(template) | Value::StringObj(template)) = &args[0] else {
        return Err("error: first argument must be a format string".to_string());
    };
    Err(format_printf(template, &args[1..])?)
}
// warning(fmt, args...): format printf-style, print "warning: <msg>" to
// stderr, and continue without producing a value.
("warning", _) if !args.is_empty() => {
    let (Value::Str(template) | Value::StringObj(template)) = &args[0] else {
        return Err("warning: first argument must be a format string".to_string());
    };
    let message = format_printf(template, &args[1..])?;
    eprintln!("warning: {message}");
    Ok(Value::Void)
}
// lasterr(): the most recently recorded error message.
("lasterr", 0) => Ok(Value::Str(get_last_err())),
// lasterr(msg): replace the recorded message and return the previous one.
// The previous message is read before the argument is validated.
("lasterr", 1) => {
    let previous = get_last_err();
    let (Value::Str(replacement) | Value::StringObj(replacement)) = &args[0] else {
        return Err("lasterr: argument must be a string".to_string());
    };
    set_last_err(replacement);
    Ok(Value::Str(previous))
}
// pcall(f, args...): protected call. Invokes the callable `f` with the
// remaining arguments and turns a failed call into a value instead of an error:
//   success -> Tuple [1, result]
//   failure -> Tuple [0, message]; the message is also stored via
//              set_last_err, so `lasterr` can read it back.
// Only the call itself is protected. A first argument that is neither a
// Lambda nor a Function is returned as a hard Err before anything runs.
("pcall", _) if !args.is_empty() => {
let callable = args[0].clone();
let call_args = &args[1..];
let result = match &callable {
// Lambdas: the closure in field 0 is called directly with the `io` this
// builtin received.
Value::Lambda(f) => {
let f = f.clone();
f.0(call_args, io)
}
// User functions go through the thread-local FN_CALL_HOOK (installed by
// set_fn_call_hook) under the name "<pcall>" with the caller's env. The
// hook needs a `&mut IoContext`; when none was supplied, a fresh
// IoContext::new() is used for this call only. A missing hook produces an
// Err that flows into `result`, so it is reported as [0, message] rather
// than aborting.
Value::Function(_) => match io {
Some(io_ref) => FN_CALL_HOOK.with(|c| match c.get() {
Some(hook) => hook("<pcall>", &callable, call_args, env, io_ref),
None => Err("pcall: function execution not initialized".to_string()),
}),
None => {
let mut tmp_io = IoContext::new();
FN_CALL_HOOK.with(|c| match c.get() {
Some(hook) => hook("<pcall>", &callable, call_args, env, &mut tmp_io),
None => Err("pcall: function execution not initialized".to_string()),
})
}
},
_ => {
return Err(
"pcall: first argument must be a function handle (@func)".to_string()
);
}
};
// Convert the call outcome into the [status, value-or-message] tuple.
match result {
Ok(v) => Ok(Value::Tuple(vec![Value::Scalar(1.0), v])),
Err(msg) => {
set_last_err(&msg);
Ok(Value::Tuple(vec![Value::Scalar(0.0), Value::Str(msg)]))
}
}
}
// eig(A): eigenvalues as a column vector, or [V, D] = eig(A) when more than
// one output is requested (get_nargout() > 1).
("eig", 1) => match &args[0] {
// 1x1 case: the eigenvalue is the value itself, with eigenvector [1].
Value::Scalar(n) => {
if get_nargout() <= 1 {
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, 1), vec![*n]).unwrap(),
)))
} else {
Ok(Value::Tuple(vec![
Value::Matrix(Box::new(Array2::eye(1))),
Value::Matrix(Box::new(Array2::from_elem((1, 1), *n))),
]))
}
}
Value::Matrix(m) => {
// eig_compute yields complex-valued eigenvalues (their `.re`/`.im` are
// read below) and a real eigenvector matrix (stored as Value::Matrix).
let (evals, evecs) = eig_compute(m)?;
let nn = evals.len();
// The spectrum counts as complex only if some |imag| exceeds 1e-14;
// otherwise imaginary parts are dropped as round-off.
let has_imag = evals.iter().any(|c| c.im.abs() > 1e-14);
if get_nargout() <= 1 {
// Single output: an nn x 1 column, complex only when needed.
if has_imag {
Ok(Value::ComplexMatrix(Box::new(
Array2::from_shape_vec((nn, 1), evals).unwrap(),
)))
} else {
let reals: Vec<f64> = evals.iter().map(|c| c.re).collect();
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((nn, 1), reals).unwrap(),
)))
}
} else if has_imag {
// The two-output form is only supported for real spectra here.
Err("eig: [V,D] form not supported when eigenvalues are complex".to_string())
} else {
// D = diag(eigenvalues), in the order eig_compute returned them.
let reals: Vec<f64> = evals.iter().map(|c| c.re).collect();
let mut d = Array2::<f64>::zeros((nn, nn));
for (i, &e) in reals.iter().enumerate() {
d[[i, i]] = e;
}
Ok(Value::Tuple(vec![
Value::Matrix(Box::new(evecs)),
Value::Matrix(Box::new(d)),
]))
}
}
_ => Err("eig: argument must be a real numeric matrix".to_string()),
},
// svd(A): singular values of A as a column vector, or the full decomposition
// [U, S, V] with A = U * S * V' when more than one output is requested.
// In the full form U is m x m, S is m x n with the singular values on its
// diagonal, and V is n x n.
("svd", 1) => match &args[0] {
    Value::Scalar(n) => {
        let sv = n.abs();
        if get_nargout() <= 1 {
            Ok(Value::Matrix(Box::new(
                Array2::from_shape_vec((1, 1), vec![sv]).unwrap(),
            )))
        } else {
            // The singular value is |n|, so the sign of n has to live in U for
            // U * S * V' to reproduce n (U = -1 when n is negative).
            let sign = if *n < 0.0 { -1.0 } else { 1.0 };
            Ok(Value::Tuple(vec![
                Value::Matrix(Box::new(Array2::from_elem((1, 1), sign))),
                Value::Matrix(Box::new(Array2::from_elem((1, 1), sv))),
                Value::Matrix(Box::new(Array2::eye(1))),
            ]))
        }
    }
    Value::Matrix(m) => {
        let mm = m.nrows();
        let nn = m.ncols();
        let (u_c, s_v, v_c) = svd_compute(m)?;
        let k = s_v.len();
        if get_nargout() <= 1 {
            Ok(Value::Matrix(Box::new(
                Array2::from_shape_vec((k, 1), s_v).unwrap(),
            )))
        } else {
            // Extend the factors to complete orthonormal bases. U always is;
            // V is extended only if svd_compute returned fewer than n columns
            // (e.g. a thin V for a wide matrix). NOTE(review): assumes
            // complete_orthonormal_basis keeps the existing columns first, as
            // the U path already relies on.
            let u_full = complete_orthonormal_basis(&u_c);
            let v_full = if v_c.ncols() < nn {
                complete_orthonormal_basis(&v_c)
            } else {
                v_c
            };
            let mut s_mat = Array2::<f64>::zeros((mm, nn));
            for (i, &sv) in s_v.iter().enumerate() {
                s_mat[[i, i]] = sv;
            }
            Ok(Value::Tuple(vec![
                Value::Matrix(Box::new(u_full)),
                Value::Matrix(Box::new(s_mat)),
                Value::Matrix(Box::new(v_full)),
            ]))
        }
    }
    _ => Err("svd: argument must be a real numeric matrix".to_string()),
},
// svd(A, 'econ'): economy-size [U, S, V] with U m x k, S k x k and V n x k,
// where k is the number of singular values returned by svd_compute.
("svd", 2) => match (&args[0], &args[1]) {
    (Value::Matrix(m), Value::Str(opt) | Value::StringObj(opt)) if opt == "econ" => {
        let (u_c, s_v, v_c) = svd_compute(m)?;
        let k = s_v.len();
        // Trim factor columns beyond the k singular values, so the economy
        // shapes hold whether svd_compute returns thin or full factors.
        let u_econ = if u_c.ncols() > k {
            u_c.slice(ndarray::s![.., ..k]).to_owned()
        } else {
            u_c
        };
        let v_econ = if v_c.ncols() > k {
            v_c.slice(ndarray::s![.., ..k]).to_owned()
        } else {
            v_c
        };
        let mut s_mat = Array2::<f64>::zeros((k, k));
        for (i, &sv) in s_v.iter().enumerate() {
            s_mat[[i, i]] = sv;
        }
        Ok(Value::Tuple(vec![
            Value::Matrix(Box::new(u_econ)),
            Value::Matrix(Box::new(s_mat)),
            Value::Matrix(Box::new(v_econ)),
        ]))
    }
    _ => Err("svd: expected svd(A, 'econ')".to_string()),
},
// lu(A): with one output, the upper-triangular factor U; with more, the
// triple [L, U, P] as returned by lu_decompose.
("lu", 1) => match &args[0] {
    Value::Matrix(mat) => {
        let (lower, upper, perm) = lu_decompose(mat)?;
        if get_nargout() > 1 {
            Ok(Value::Tuple(vec![
                Value::Matrix(Box::new(lower)),
                Value::Matrix(Box::new(upper)),
                Value::Matrix(Box::new(perm)),
            ]))
        } else {
            Ok(Value::Matrix(Box::new(upper)))
        }
    }
    // 1x1 case: L = P = [1] and U = [n].
    Value::Scalar(n) if get_nargout() > 1 => Ok(Value::Tuple(vec![
        Value::Matrix(Box::new(Array2::eye(1))),
        Value::Matrix(Box::new(Array2::from_elem((1, 1), *n))),
        Value::Matrix(Box::new(Array2::eye(1))),
    ])),
    Value::Scalar(n) => Ok(Value::Scalar(*n)),
    _ => Err("lu: argument must be a real numeric matrix".to_string()),
},
// qr(A): with one output, the factor R; with more, [Q, R] from qr_decompose.
("qr", 1) => match &args[0] {
    Value::Matrix(mat) => {
        let (q, r) = qr_decompose(mat)?;
        if get_nargout() > 1 {
            Ok(Value::Tuple(vec![
                Value::Matrix(Box::new(q)),
                Value::Matrix(Box::new(r)),
            ]))
        } else {
            Ok(Value::Matrix(Box::new(r)))
        }
    }
    // 1x1 case: Q carries the sign (+1 for n >= 0, otherwise -1) and R = |n|.
    Value::Scalar(n) if get_nargout() > 1 => {
        let sign = if *n >= 0.0 { 1.0 } else { -1.0 };
        Ok(Value::Tuple(vec![
            Value::Matrix(Box::new(Array2::from_elem((1, 1), sign))),
            Value::Matrix(Box::new(Array2::from_elem((1, 1), n.abs()))),
        ]))
    }
    Value::Scalar(n) => Ok(Value::Scalar(*n)),
    _ => Err("qr: argument must be a real numeric matrix".to_string()),
},
// chol(A): Cholesky factor via chol_decompose. For a scalar, its square root;
// only negative scalars are rejected here.
("chol", 1) => match &args[0] {
    Value::Matrix(mat) => {
        let factor = chol_decompose(mat)?;
        Ok(Value::Matrix(Box::new(factor)))
    }
    Value::Scalar(x) if *x < 0.0 => Err("chol: value is not positive definite".to_string()),
    Value::Scalar(x) => Ok(Value::Scalar(x.sqrt())),
    _ => Err("chol: argument must be a real numeric matrix".to_string()),
},
// rank(A): number of singular values above max(m, n) * s_first * eps * 2.
// A scalar has rank 1 when |x| > 1e-15 and rank 0 otherwise.
("rank", 1) => match &args[0] {
    Value::Matrix(mat) => {
        let (_, sigma, _) = svd_compute(mat)?;
        let largest = sigma.first().copied().unwrap_or(0.0);
        let dim = (mat.nrows().max(mat.ncols())) as f64;
        let tol = dim * largest * f64::EPSILON * 2.0;
        let rank = sigma.iter().filter(|&&s| s > tol).count();
        Ok(Value::Scalar(rank as f64))
    }
    Value::Scalar(x) => {
        let rank = if x.abs() > 1e-15 { 1.0 } else { 0.0 };
        Ok(Value::Scalar(rank))
    }
    _ => Err("rank: argument must be a real numeric matrix".to_string()),
},
// null(A): orthonormal basis for the null space of A, one basis vector per
// column (n x k for an m x n input). Singular values above
// max(m, n) * s_first * eps * 2 count toward the rank r; the right singular
// vectors with index r..n span the null space.
("null", 1) => match &args[0] {
    // A scalar is a 1x1 matrix. Its null space is everything (basis [1]) when
    // it is zero, and empty (1 x 0) otherwise. The zero test uses the same
    // 1e-15 rule as rank's scalar case, so rank + nullity = 1.
    Value::Scalar(x) => {
        if x.abs() > 1e-15 {
            Ok(Value::Matrix(Box::new(Array2::zeros((1, 0)))))
        } else {
            Ok(Value::Matrix(Box::new(Array2::from_elem((1, 1), 1.0))))
        }
    }
    Value::Matrix(m) => {
        let nn = m.ncols();
        let (_, s_v, v_c) = svd_compute(m)?;
        let tol = (m.nrows().max(nn)) as f64
            * s_v.first().copied().unwrap_or(0.0)
            * f64::EPSILON
            * 2.0;
        let r = s_v.iter().filter(|&&s| s > tol).count();
        let null_k = nn.saturating_sub(r);
        if null_k == 0 {
            return Ok(Value::Matrix(Box::new(Array2::zeros((nn, 0)))));
        }
        // Columns r..n of a full n x n V are needed. If svd_compute returned
        // fewer than n right singular vectors (possible for wide inputs), extend
        // V to a complete orthonormal basis, so the trailing null-space
        // directions are not silently replaced by zero columns.
        // NOTE(review): assumes complete_orthonormal_basis keeps the existing
        // columns first, as svd's full form relies on.
        let v_full = if v_c.ncols() < nn {
            complete_orthonormal_basis(&v_c)
        } else {
            v_c
        };
        let mut result = Array2::<f64>::zeros((nn, null_k));
        for j in 0..null_k {
            let col_idx = r + j;
            if col_idx < v_full.ncols() {
                for i in 0..nn {
                    result[[i, j]] = v_full[[i, col_idx]];
                }
            }
        }
        Ok(Value::Matrix(Box::new(result)))
    }
    _ => Err("null: argument must be a real numeric matrix".to_string()),
},
// orth(A): orthonormal basis for the range of A, taken as the leading r left
// singular vectors. r counts singular values above max(m, n) * s_first * eps * 2.
("orth", 1) => match &args[0] {
    Value::Matrix(mat) => {
        let rows = mat.nrows();
        let (u, sigma, _) = svd_compute(mat)?;
        let tol = (rows.max(mat.ncols())) as f64
            * sigma.first().copied().unwrap_or(0.0)
            * f64::EPSILON
            * 2.0;
        let rank = sigma.iter().filter(|&&s| s > tol).count();
        if rank == 0 {
            return Ok(Value::Matrix(Box::new(Array2::zeros((rows, 0)))));
        }
        let mut basis = Array2::<f64>::zeros((rows, rank));
        // Copy only the columns U actually has; any others stay zero.
        for col in 0..rank.min(u.ncols()) {
            for row in 0..rows {
                basis[[row, col]] = u[[row, col]];
            }
        }
        Ok(Value::Matrix(Box::new(basis)))
    }
    // Scalar: basis [1] for a nonzero value (|x| > 1e-15), else an empty 1 x 0.
    Value::Scalar(x) => {
        let basis = if x.abs() > 1e-15 {
            Array2::from_elem((1, 1), 1.0)
        } else {
            Array2::zeros((1, 0))
        };
        Ok(Value::Matrix(Box::new(basis)))
    }
    _ => Err("orth: argument must be a real numeric matrix".to_string()),
},
// cond(A): 2-norm condition number, first / last singular value. A last
// singular value below 1e-15 (absolute threshold) is reported as infinite;
// an input with no singular values gives 1.
("cond", 1) => match &args[0] {
    Value::Matrix(mat) => {
        let (_, sigma, _) = svd_compute(mat)?;
        let (Some(&s_max), Some(&s_min)) = (sigma.first(), sigma.last()) else {
            return Ok(Value::Scalar(1.0));
        };
        let ratio = if s_min < 1e-15 { f64::INFINITY } else { s_max / s_min };
        Ok(Value::Scalar(ratio))
    }
    Value::Scalar(x) => {
        let ratio = if x.abs() < 1e-15 { f64::INFINITY } else { 1.0 };
        Ok(Value::Scalar(ratio))
    }
    _ => Err("cond: argument must be a real numeric matrix".to_string()),
},
// pinv(A): Moore-Penrose pseudo-inverse, the sum of V[:, j] * (1 / s_j) * U[:, j]'
// over singular values s_j above max(m, n) * s_first * eps * 2. The result is n x m.
// A scalar inverts to 1 / x, or 0 when |x| < 1e-15.
("pinv", 1) => match &args[0] {
    Value::Matrix(mat) => {
        let (rows, cols) = (mat.nrows(), mat.ncols());
        let (u, sigma, v) = svd_compute(mat)?;
        let tol = (rows.max(cols)) as f64
            * sigma.first().copied().unwrap_or(0.0)
            * f64::EPSILON
            * 2.0;
        let mut pseudo = Array2::<f64>::zeros((cols, rows));
        for (j, &s) in sigma.iter().enumerate().filter(|&(_, &s)| s > tol) {
            let inv = 1.0 / s;
            for r in 0..cols {
                for c in 0..rows {
                    pseudo[[r, c]] += v[[r, j]] * inv * u[[c, j]];
                }
            }
        }
        Ok(Value::Matrix(Box::new(pseudo)))
    }
    Value::Scalar(x) => {
        let inverse = if x.abs() < 1e-15 { 0.0 } else { 1.0 / x };
        Ok(Value::Scalar(inverse))
    }
    _ => Err("pinv: argument must be a real numeric matrix".to_string()),
},
// fft(x) / fft(x, n) / ifft(x): thin wrappers over fft_call / ifft_call.
("fft", 1) => fft_call(&args[0], None),
("fft", 2) => {
    // The length is truncated by `as usize`, which also maps negative and NaN
    // inputs to 0; a zero length is rejected.
    let len = scalar_arg(&args[1], "fft", 2)? as usize;
    if len == 0 {
        return Err("fft: length must be positive".to_string());
    }
    fft_call(&args[0], Some(len))
}
("ifft", 1) => ifft_call(&args[0]),
// fftshift(x): circularly rotate each axis forward by floor(len / 2), which
// moves the zero-frequency entry to the centre. ifftshift(x) rotates by
// ceil(len / 2) and so undoes it. Rotating a length-1 axis has no effect, so
// row vectors, column vectors and matrices all use the same 2-D rotation.
// Scalars pass through unchanged.
("fftshift", 1) => match &args[0] {
    Value::Scalar(s) => Ok(Value::Scalar(*s)),
    Value::Matrix(m) => {
        let (rows, cols) = (m.nrows(), m.ncols());
        let (dr, dc) = (rows / 2, cols / 2);
        // Output (i, j) reads input ((i - dr) mod rows, (j - dc) mod cols).
        // Adding rows/cols first keeps the arithmetic unsigned. Empty shapes
        // never call the closure, so the modulo never sees 0.
        let shifted = Array2::from_shape_fn((rows, cols), |(i, j)| {
            m[[(i + rows - dr) % rows, (j + cols - dc) % cols]]
        });
        Ok(Value::Matrix(Box::new(shifted)))
    }
    _ => Err("fftshift: argument must be a numeric matrix".to_string()),
},
("ifftshift", 1) => match &args[0] {
    Value::Scalar(s) => Ok(Value::Scalar(*s)),
    Value::Matrix(m) => {
        let (rows, cols) = (m.nrows(), m.ncols());
        let (dr, dc) = (rows.div_ceil(2), cols.div_ceil(2));
        // Same rotation as fftshift, but by ceil(len / 2). dr <= rows and
        // dc <= cols, so the subtractions cannot underflow.
        let shifted = Array2::from_shape_fn((rows, cols), |(i, j)| {
            m[[(i + rows - dr) % rows, (j + cols - dc) % cols]]
        });
        Ok(Value::Matrix(Box::new(shifted)))
    }
    _ => Err("ifftshift: argument must be a numeric matrix".to_string()),
},
// fftfreq(n, d): sample frequencies of an n-point DFT with sample spacing d,
// as a 1 x n row:
//   [0, 1, ..., ceil(n/2) - 1, -floor(n/2), ..., -1] / (n * d).
// n must be a finite positive integer (within 1e-9 of its truncation) and d nonzero.
("fftfreq", 2) => {
    let n = match &args[0] {
        Value::Scalar(s) => {
            let n = *s as usize;
            // NaN fails every ordered comparison, so without the explicit
            // finiteness check fftfreq(NaN, d) would be accepted as n = 0 and
            // the `n - 1` below would underflow.
            if !s.is_finite() || *s < 1.0 || (*s - n as f64).abs() > 1e-9 {
                return Err("fftfreq: n must be a positive integer".to_string());
            }
            n
        }
        _ => return Err("fftfreq: first argument must be a scalar integer".to_string()),
    };
    let d = scalar_arg(&args[1], "fftfreq", 2)?;
    if d == 0.0 {
        return Err("fftfreq: sample spacing d must be nonzero".to_string());
    }
    // Bins [0, pos_count) are non-negative frequencies. The remaining bins
    // wrap to k - n, i.e. -floor(n/2) ..= -1.
    let pos_count = (n - 1) / 2 + 1;
    let factor = 1.0 / (n as f64 * d);
    let freqs: Vec<f64> = (0..n)
        .map(|k| {
            let bin = if k < pos_count {
                k as f64
            } else {
                k as f64 - n as f64
            };
            bin * factor
        })
        .collect();
    Ok(Value::Matrix(Box::new(
        Array2::from_shape_vec((1, n), freqs).unwrap(),
    )))
}
// jsondecode(text) / jsonencode(value): delegate to the JSON helpers.
("jsondecode", 1) => jsondecode_impl(&args[0]),
("jsonencode", 1) => jsonencode_impl(&args[0]),
// load(path) in function form: only paths ending in ".mat" are handled here;
// any other file has to use the command form `load path`.
("load", 1) => {
    let (Value::Str(path) | Value::StringObj(path)) = &args[0] else {
        return Err("load: argument must be a string path".to_string());
    };
    if !path.ends_with(".mat") {
        return Err("load: use bare 'load path' syntax for non-.mat files".to_string());
    }
    load_mat_file(path)
}
// assert(cond): succeed with no value when `cond` is truthy, otherwise fail.
// Truthy means:
//   - a scalar, and every matrix element, is nonzero and not NaN (an empty
//     matrix is truthy);
//   - a complex value has a nonzero part;
//   - a string is non-empty;
//   - anything else is falsy.
("assert", 1) => {
    let holds = match &args[0] {
        Value::Scalar(x) => !(x.is_nan() || *x == 0.0),
        Value::Matrix(m) => !m.iter().any(|&x| x == 0.0 || x.is_nan()),
        Value::Complex(re, im) => !(*re == 0.0 && *im == 0.0),
        Value::Str(s) | Value::StringObj(s) => !s.is_empty(),
        _ => false,
    };
    if holds {
        Ok(Value::Void)
    } else {
        Err("assert: condition is false".to_string())
    }
}
// assert(actual, expected [, tol]): equality check delegated to assert_values_equal.
("assert", 2) => assert_values_equal(&args[0], &args[1], None),
("assert", 3) => {
    let Value::Scalar(tol) = &args[2] else {
        return Err("assert: tolerance must be a scalar".to_string());
    };
    assert_values_equal(&args[0], &args[1], Some(*tol))
}
// datetime(text): 'now', 'today', or an ISO-8601 string.
("datetime", 1) => match &args[0] {
    Value::Str(text) | Value::StringObj(text) => match text.as_str() {
        "now" => Ok(Value::DateTime(crate::datetime::now_timestamp())),
        "today" => Ok(Value::DateTime(crate::datetime::today_timestamp())),
        iso => crate::datetime::parse_iso8601(iso).map(Value::DateTime),
    },
    _ => Err("datetime: expected a string or numeric constructor arguments".to_string()),
},
// datetime(t, 'ConvertFrom', 'posixtime'): both keywords case-insensitive;
// the number is used directly as the timestamp.
("datetime", 3) if matches!(&args[1], Value::Str(_) | Value::StringObj(_)) => {
    let stamp = scalar_arg(&args[0], "datetime", 1)?;
    let is_posix = matches!(
        (&args[1], &args[2]),
        (Value::Str(k) | Value::StringObj(k), Value::Str(v) | Value::StringObj(v))
            if k.eq_ignore_ascii_case("convertfrom") && v.eq_ignore_ascii_case("posixtime")
    );
    if is_posix {
        Ok(Value::DateTime(stamp))
    } else {
        Err("datetime: unsupported arguments".to_string())
    }
}
// datetime(y, mo, d): midnight of the given civil date.
("datetime", 3) => {
    let year = scalar_arg(&args[0], "datetime", 1)? as i64;
    let month = scalar_arg(&args[1], "datetime", 2)? as u32;
    let day = scalar_arg(&args[2], "datetime", 3)? as u32;
    let stamp = crate::datetime::civil_to_timestamp(year, month, day, 0, 0, 0.0);
    Ok(Value::DateTime(stamp))
}
// datetime(y, mo, d, h, mi, s): full civil date and time (s may be fractional).
("datetime", 6) => {
    let year = scalar_arg(&args[0], "datetime", 1)? as i64;
    let month = scalar_arg(&args[1], "datetime", 2)? as u32;
    let day = scalar_arg(&args[2], "datetime", 3)? as u32;
    let hour = scalar_arg(&args[3], "datetime", 4)? as u32;
    let minute = scalar_arg(&args[4], "datetime", 5)? as u32;
    let second = scalar_arg(&args[5], "datetime", 6)?;
    let stamp = crate::datetime::civil_to_timestamp(year, month, day, hour, minute, second);
    Ok(Value::DateTime(stamp))
}
// Calendar-field accessors. Each takes a datetime (giving a scalar) or a
// datetime array (giving an n x 1 column), and reads one field of the
// (year, month, day, hour, minute, second) tuple from
// crate::datetime::timestamp_to_civil.
("year", 1) => match &args[0] {
    Value::DateTime(ts) => {
        let year = crate::datetime::timestamp_to_civil(*ts).0;
        Ok(Value::Scalar(year as f64))
    }
    Value::DateTimeArray(stamps) => {
        let column: Vec<f64> = stamps
            .iter()
            .map(|ts| crate::datetime::timestamp_to_civil(*ts).0 as f64)
            .collect();
        ndarray::Array2::from_shape_vec((column.len(), 1), column)
            .map(|a| Value::Matrix(Box::new(a)))
            .map_err(|e| e.to_string())
    }
    _ => Err("year: argument must be a datetime".to_string()),
},
("month", 1) => match &args[0] {
    Value::DateTime(ts) => {
        let month = crate::datetime::timestamp_to_civil(*ts).1;
        Ok(Value::Scalar(month as f64))
    }
    Value::DateTimeArray(stamps) => {
        let column: Vec<f64> = stamps
            .iter()
            .map(|ts| crate::datetime::timestamp_to_civil(*ts).1 as f64)
            .collect();
        ndarray::Array2::from_shape_vec((column.len(), 1), column)
            .map(|a| Value::Matrix(Box::new(a)))
            .map_err(|e| e.to_string())
    }
    _ => Err("month: argument must be a datetime".to_string()),
},
("day", 1) => match &args[0] {
    Value::DateTime(ts) => {
        let day = crate::datetime::timestamp_to_civil(*ts).2;
        Ok(Value::Scalar(day as f64))
    }
    Value::DateTimeArray(stamps) => {
        let column: Vec<f64> = stamps
            .iter()
            .map(|ts| crate::datetime::timestamp_to_civil(*ts).2 as f64)
            .collect();
        ndarray::Array2::from_shape_vec((column.len(), 1), column)
            .map(|a| Value::Matrix(Box::new(a)))
            .map_err(|e| e.to_string())
    }
    _ => Err("day: argument must be a datetime".to_string()),
},
("hour", 1) => match &args[0] {
    Value::DateTime(ts) => {
        let hour = crate::datetime::timestamp_to_civil(*ts).3;
        Ok(Value::Scalar(hour as f64))
    }
    Value::DateTimeArray(stamps) => {
        let column: Vec<f64> = stamps
            .iter()
            .map(|ts| crate::datetime::timestamp_to_civil(*ts).3 as f64)
            .collect();
        ndarray::Array2::from_shape_vec((column.len(), 1), column)
            .map(|a| Value::Matrix(Box::new(a)))
            .map_err(|e| e.to_string())
    }
    _ => Err("hour: argument must be a datetime or duration".to_string()),
},
("minute", 1) => match &args[0] {
    Value::DateTime(ts) => {
        let minute = crate::datetime::timestamp_to_civil(*ts).4;
        Ok(Value::Scalar(minute as f64))
    }
    Value::DateTimeArray(stamps) => {
        let column: Vec<f64> = stamps
            .iter()
            .map(|ts| crate::datetime::timestamp_to_civil(*ts).4 as f64)
            .collect();
        ndarray::Array2::from_shape_vec((column.len(), 1), column)
            .map(|a| Value::Matrix(Box::new(a)))
            .map_err(|e| e.to_string())
    }
    _ => Err("minute: argument must be a datetime or duration".to_string()),
},
("second", 1) => match &args[0] {
    Value::DateTime(ts) => Ok(Value::Scalar(crate::datetime::timestamp_to_civil(*ts).5)),
    Value::DateTimeArray(stamps) => {
        let column: Vec<f64> = stamps
            .iter()
            .map(|ts| crate::datetime::timestamp_to_civil(*ts).5)
            .collect();
        ndarray::Array2::from_shape_vec((column.len(), 1), column)
            .map(|a| Value::Matrix(Box::new(a)))
            .map_err(|e| e.to_string())
    }
    _ => Err("second: argument must be a datetime or duration".to_string()),
},
// isdatetime(x): 1 for a datetime or datetime array, else 0.
("isdatetime", 1) => {
    let is_datetime = matches!(&args[0], Value::DateTime(_) | Value::DateTimeArray(_));
    Ok(Value::Scalar(bool_to_f64(is_datetime)))
}
// isduration(x): 1 for a duration or duration array, else 0.
("isduration", 1) => {
    let is_duration = matches!(&args[0], Value::Duration(_) | Value::DurationArray(_));
    Ok(Value::Scalar(bool_to_f64(is_duration)))
}
// isnat(x): marks NaN ("not-a-time") timestamps. An array gives an n x 1
// column of 1/0; any non-datetime value gives 0.
("isnat", 1) => match &args[0] {
    Value::DateTimeArray(stamps) => {
        let flags: Vec<f64> = stamps
            .iter()
            .map(|t| if t.is_nan() { 1.0 } else { 0.0 })
            .collect();
        ndarray::Array2::from_shape_vec((flags.len(), 1), flags)
            .map(|a| Value::Matrix(Box::new(a)))
            .map_err(|e| e.to_string())
    }
    Value::DateTime(ts) => Ok(Value::Scalar(bool_to_f64(ts.is_nan()))),
    _ => Ok(Value::Scalar(0.0)),
},
// Unit conversions for durations, which are stored as seconds. Each function
// works in both directions:
//   - a duration (or duration array, giving an n x 1 column) is converted to
//     a number in the named unit;
//   - a number in the named unit is converted to a duration.
("hours", 1) => match &args[0] {
    Value::Duration(secs) => Ok(Value::Scalar(*secs / 3600.0)),
    Value::DurationArray(spans) => {
        let column: Vec<f64> = spans.iter().map(|s| s / 3600.0).collect();
        ndarray::Array2::from_shape_vec((column.len(), 1), column)
            .map(|a| Value::Matrix(Box::new(a)))
            .map_err(|e| e.to_string())
    }
    other => {
        let count = scalar_arg(other, "hours", 1)?;
        Ok(Value::Duration(count * 3600.0))
    }
},
("minutes", 1) => match &args[0] {
    Value::Duration(secs) => Ok(Value::Scalar(*secs / 60.0)),
    Value::DurationArray(spans) => {
        let column: Vec<f64> = spans.iter().map(|s| s / 60.0).collect();
        ndarray::Array2::from_shape_vec((column.len(), 1), column)
            .map(|a| Value::Matrix(Box::new(a)))
            .map_err(|e| e.to_string())
    }
    other => {
        let count = scalar_arg(other, "minutes", 1)?;
        Ok(Value::Duration(count * 60.0))
    }
},
("seconds", 1) => match &args[0] {
    Value::Duration(secs) => Ok(Value::Scalar(*secs)),
    Value::DurationArray(spans) => {
        let column = spans.to_vec();
        ndarray::Array2::from_shape_vec((column.len(), 1), column)
            .map(|a| Value::Matrix(Box::new(a)))
            .map_err(|e| e.to_string())
    }
    other => {
        let count = scalar_arg(other, "seconds", 1)?;
        Ok(Value::Duration(count))
    }
},
("days", 1) => match &args[0] {
    Value::Duration(secs) => Ok(Value::Scalar(*secs / 86400.0)),
    Value::DurationArray(spans) => {
        let column: Vec<f64> = spans.iter().map(|s| s / 86400.0).collect();
        ndarray::Array2::from_shape_vec((column.len(), 1), column)
            .map(|a| Value::Matrix(Box::new(a)))
            .map_err(|e| e.to_string())
    }
    other => {
        let count = scalar_arg(other, "days", 1)?;
        Ok(Value::Duration(count * 86400.0))
    }
},
("milliseconds", 1) => match &args[0] {
    Value::Duration(secs) => Ok(Value::Scalar(*secs * 1000.0)),
    Value::DurationArray(spans) => {
        let column: Vec<f64> = spans.iter().map(|s| s * 1000.0).collect();
        ndarray::Array2::from_shape_vec((column.len(), 1), column)
            .map(|a| Value::Matrix(Box::new(a)))
            .map_err(|e| e.to_string())
    }
    other => {
        let count = scalar_arg(other, "milliseconds", 1)?;
        Ok(Value::Duration(count / 1000.0))
    }
},
// years use the mean Gregorian year of 365.2425 days.
("years", 1) => match &args[0] {
    Value::Duration(secs) => Ok(Value::Scalar(*secs / (365.2425 * 86400.0))),
    Value::DurationArray(spans) => {
        let column: Vec<f64> = spans.iter().map(|s| s / (365.2425 * 86400.0)).collect();
        ndarray::Array2::from_shape_vec((column.len(), 1), column)
            .map(|a| Value::Matrix(Box::new(a)))
            .map_err(|e| e.to_string())
    }
    other => {
        let count = scalar_arg(other, "years", 1)?;
        Ok(Value::Duration(count * 365.2425 * 86400.0))
    }
},
// duration(h, m, s): a duration of h hours, m minutes and s seconds, held as
// a total number of seconds.
("duration", 3) => {
    let hrs = scalar_arg(&args[0], "duration", 1)?;
    let mins = scalar_arg(&args[1], "duration", 2)?;
    let secs = scalar_arg(&args[2], "duration", 3)?;
    let total = hrs * 3600.0 + mins * 60.0 + secs;
    Ok(Value::Duration(total))
}
// datestr(t): format with the default pattern "dd-MMM-yyyy HH:mm:ss".
// A datetime array gives a cell of strings.
("datestr", 1) => {
    let default_fmt = "dd-MMM-yyyy HH:mm:ss";
    match &args[0] {
        Value::DateTime(ts) => Ok(Value::Str(crate::datetime::format_datestr(
            *ts,
            default_fmt,
        ))),
        Value::DateTimeArray(stamps) => Ok(Value::Cell(Box::new(
            stamps
                .iter()
                .map(|ts| Value::Str(crate::datetime::format_datestr(*ts, default_fmt)))
                .collect(),
        ))),
        _ => Err("datestr: argument must be a datetime".to_string()),
    }
}
// datestr(t, fmt): format with a caller-supplied pattern. The format argument
// is validated before the datetime argument.
("datestr", 2) => {
    let (Value::Str(fmt) | Value::StringObj(fmt)) = &args[1] else {
        return Err("datestr: second argument must be a format string".to_string());
    };
    match &args[0] {
        Value::DateTime(ts) => Ok(Value::Str(crate::datetime::format_datestr(*ts, fmt))),
        Value::DateTimeArray(stamps) => Ok(Value::Cell(Box::new(
            stamps
                .iter()
                .map(|ts| Value::Str(crate::datetime::format_datestr(*ts, fmt)))
                .collect(),
        ))),
        _ => Err("datestr: first argument must be a datetime".to_string()),
    }
}
// datevec(t): 1 x 6 row [year month day hour minute second]; the seconds are
// rounded down to a whole number.
("datevec", 1) => match &args[0] {
    Value::DateTime(ts) => {
        let (y, mo, d, h, mi, s) = crate::datetime::timestamp_to_civil(*ts);
        let whole_secs = s.floor() as u32;
        let row = vec![
            y as f64,
            mo as f64,
            d as f64,
            h as f64,
            mi as f64,
            whole_secs as f64,
        ];
        ndarray::Array2::from_shape_vec((1, 6), row)
            .map(|a| Value::Matrix(Box::new(a)))
            .map_err(|e| e.to_string())
    }
    _ => Err("datevec: argument must be a datetime".to_string()),
},
// datenum(t): serial date number of a datetime.
("datenum", 1) => match &args[0] {
    Value::DateTime(ts) => Ok(Value::Scalar(crate::datetime::to_datenum(*ts))),
    _ => Err("datenum: argument must be a datetime".to_string()),
},
// datenum(y, mo, d): serial date number of midnight on that civil date.
("datenum", 3) => {
    let year = scalar_arg(&args[0], "datenum", 1)? as i64;
    let month = scalar_arg(&args[1], "datenum", 2)? as u32;
    let day = scalar_arg(&args[2], "datenum", 3)? as u32;
    let stamp = crate::datetime::civil_to_timestamp(year, month, day, 0, 0, 0.0);
    Ok(Value::Scalar(crate::datetime::to_datenum(stamp)))
}
// posixtime(t): the stored timestamp itself.
("posixtime", 1) => match &args[0] {
    Value::DateTime(ts) => Ok(Value::Scalar(*ts)),
    _ => Err("posixtime: argument must be a datetime".to_string()),
},
// diff(x): first differences.
//   - datetime/duration arrays with >= 2 entries -> duration array;
//   - a row vector with >= 2 columns -> 1 x (n-1) row;
//   - a matrix with >= 2 rows -> row-wise differences down each column.
// Other matrices give a length error; arrays that are too short and other
// types give "unsupported argument type".
("diff", 1) => match &args[0] {
    Value::DateTimeArray(stamps) if stamps.len() >= 2 => Ok(Value::DurationArray(
        stamps.windows(2).map(|pair| pair[1] - pair[0]).collect(),
    )),
    Value::DurationArray(spans) if spans.len() >= 2 => Ok(Value::DurationArray(
        spans.windows(2).map(|pair| pair[1] - pair[0]).collect(),
    )),
    Value::Matrix(m) => {
        let (rows, cols) = (m.nrows(), m.ncols());
        if rows == 1 && cols > 1 {
            let steps: Vec<f64> = (1..cols).map(|j| m[[0, j]] - m[[0, j - 1]]).collect();
            ndarray::Array2::from_shape_vec((1, steps.len()), steps)
                .map(|a| Value::Matrix(Box::new(a)))
                .map_err(|e| e.to_string())
        } else if rows > 1 {
            let steps: Vec<f64> = (1..rows)
                .flat_map(|i| (0..cols).map(move |j| m[[i, j]] - m[[i - 1, j]]))
                .collect();
            ndarray::Array2::from_shape_vec((rows - 1, cols), steps)
                .map(|a| Value::Matrix(Box::new(a)))
                .map_err(|e| e.to_string())
        } else {
            Err("diff: input must have at least 2 elements".to_string())
        }
    }
    _ => Err("diff: unsupported argument type".to_string()),
},
// triu(A [, k]) / tril(A [, k]): keep the entries on and above (triu) or on
// and below (tril) the k-th diagonal and zero the rest. Entry (i, j) lies on
// diagonal j - i, so k = 0 is the main diagonal and k > 0 is above it. k is
// truncated by `as isize`.
// A scalar behaves as a 1x1 matrix whose single entry lies on diagonal 0,
// with or without k.
("triu", 1) => match &args[0] {
    Value::Matrix(m) => {
        let mut r = m.clone();
        for ((i, j), x) in r.indexed_iter_mut() {
            if j < i {
                *x = 0.0;
            }
        }
        Ok(Value::Matrix(r))
    }
    Value::Scalar(n) => Ok(Value::Scalar(*n)),
    _ => Err("triu: argument must be a numeric matrix".to_string()),
},
("triu", 2) => match (&args[0], &args[1]) {
    (Value::Matrix(m), Value::Scalar(k)) => {
        let k = *k as isize;
        let mut r = m.clone();
        for ((i, j), x) in r.indexed_iter_mut() {
            if (j as isize) - (i as isize) < k {
                *x = 0.0;
            }
        }
        Ok(Value::Matrix(r))
    }
    // Diagonal 0 is zeroed exactly when 0 < k.
    (Value::Scalar(n), Value::Scalar(k)) => {
        Ok(Value::Scalar(if (*k as isize) > 0 { 0.0 } else { *n }))
    }
    _ => Err("triu: expects (matrix, scalar)".to_string()),
},
("tril", 1) => match &args[0] {
    Value::Matrix(m) => {
        let mut r = m.clone();
        for ((i, j), x) in r.indexed_iter_mut() {
            if j > i {
                *x = 0.0;
            }
        }
        Ok(Value::Matrix(r))
    }
    Value::Scalar(n) => Ok(Value::Scalar(*n)),
    _ => Err("tril: argument must be a numeric matrix".to_string()),
},
("tril", 2) => match (&args[0], &args[1]) {
    (Value::Matrix(m), Value::Scalar(k)) => {
        let k = *k as isize;
        let mut r = m.clone();
        for ((i, j), x) in r.indexed_iter_mut() {
            if (j as isize) - (i as isize) > k {
                *x = 0.0;
            }
        }
        Ok(Value::Matrix(r))
    }
    // Diagonal 0 is zeroed exactly when 0 > k.
    (Value::Scalar(n), Value::Scalar(k)) => {
        Ok(Value::Scalar(if (*k as isize) < 0 { 0.0 } else { *n }))
    }
    _ => Err("tril: expects (matrix, scalar)".to_string()),
},
// repmat(A, m, n): tile A into an m-by-n grid of copies, giving a
// (m * rows(A)) x (n * cols(A)) matrix. Counts are truncated by `as usize`,
// which maps negative and NaN counts to 0. A zero count gives an empty
// result that keeps the shape of the other dimension, as the scalar branch
// already did.
("repmat", 3) => match (&args[0], &args[1], &args[2]) {
    (Value::Matrix(a), Value::Scalar(rm), Value::Scalar(cn)) => {
        let (ar, ac) = (a.nrows(), a.ncols());
        // checked_mul stops absurd counts from wrapping around in release builds.
        let dims = (*rm as usize)
            .checked_mul(ar)
            .zip((*cn as usize).checked_mul(ac));
        let Some((out_rows, out_cols)) = dims else {
            return Err("repmat: requested size is too large".to_string());
        };
        // Entry (i, j) of the tiling is A[i mod rows(A), j mod cols(A)]. If
        // either output dimension is 0 there are no entries, so the closure
        // (and its modulo) never runs.
        let tiled = Array2::from_shape_fn((out_rows, out_cols), |(i, j)| a[[i % ar, j % ac]]);
        Ok(Value::Matrix(Box::new(tiled)))
    }
    (Value::Scalar(s), Value::Scalar(rm), Value::Scalar(cn)) => {
        let rm = *rm as usize;
        let cn = *cn as usize;
        Ok(Value::Matrix(Box::new(Array2::from_elem((rm, cn), *s))))
    }
    _ => Err("repmat: expects (matrix, m, n)".to_string()),
},
// kron(A, B): Kronecker product. Block (i, j) of the result is A[i, j] * B;
// a scalar operand simply scales the other.
("kron", 2) => match (&args[0], &args[1]) {
    (Value::Scalar(a), Value::Scalar(b)) => Ok(Value::Scalar(a * b)),
    (Value::Scalar(s), Value::Matrix(b)) => Ok(Value::Matrix(Box::new(b.mapv(|x| x * s)))),
    (Value::Matrix(a), Value::Scalar(s)) => Ok(Value::Matrix(Box::new(a.mapv(|x| x * s)))),
    (Value::Matrix(a), Value::Matrix(b)) => {
        let (rb, cb) = (b.nrows(), b.ncols());
        // Output (r, c) is A[r / rows(B), c / cols(B)] * B[r % rows(B), c % cols(B)].
        // An empty B gives an empty output, so the closure never divides by 0.
        let product = Array2::from_shape_fn((a.nrows() * rb, a.ncols() * cb), |(r, c)| {
            a[[r / rb, c / cb]] * b[[r % rb, c % cb]]
        });
        Ok(Value::Matrix(Box::new(product)))
    }
    _ => Err("kron: arguments must be numeric matrices".to_string()),
},
// meshgrid(x) / meshgrid(x, y): coordinate grids. X repeats x along every
// row and Y repeats y down every column, both len(y) x len(x); meshgrid(x)
// uses x for both axes. Only X is returned unless two outputs are requested.
("meshgrid", 1) => {
    let axis = numeric_vec(&args[0], "meshgrid")?;
    let len = axis.len();
    let grid_x = Array2::from_shape_fn((len, len), |(_, col)| axis[col]);
    let grid_y = Array2::from_shape_fn((len, len), |(row, _)| axis[row]);
    if get_nargout() < 2 {
        Ok(Value::Matrix(Box::new(grid_x)))
    } else {
        Ok(Value::Tuple(vec![
            Value::Matrix(Box::new(grid_x)),
            Value::Matrix(Box::new(grid_y)),
        ]))
    }
}
("meshgrid", 2) => {
    let xs = numeric_vec(&args[0], "meshgrid")?;
    let ys = numeric_vec(&args[1], "meshgrid")?;
    let shape = (ys.len(), xs.len());
    let grid_x = Array2::from_shape_fn(shape, |(_, col)| xs[col]);
    let grid_y = Array2::from_shape_fn(shape, |(row, _)| ys[row]);
    if get_nargout() < 2 {
        Ok(Value::Matrix(Box::new(grid_x)))
    } else {
        Ok(Value::Tuple(vec![
            Value::Matrix(Box::new(grid_x)),
            Value::Matrix(Box::new(grid_y)),
        ]))
    }
}
// cross(a, b): 3-D cross product of two 3-element vectors. The result is a
// row when `a` is a row (1 x 3), otherwise a column.
("cross", 2) => {
    // Flatten a matrix argument into exactly three components.
    fn as_triple(v: &Value, argn: usize) -> Result<[f64; 3], String> {
        let Value::Matrix(m) = v else {
            return Err(format!(
                "cross: argument {} must be a 3-element vector",
                argn
            ));
        };
        let flat: Vec<f64> = m.iter().copied().collect();
        match flat.as_slice() {
            &[x, y, z] => Ok([x, y, z]),
            _ => Err(format!(
                "cross: argument {} must have exactly 3 elements",
                argn
            )),
        }
    }
    let [ax, ay, az] = as_triple(&args[0], 1)?;
    let [bx, by, bz] = as_triple(&args[1], 2)?;
    let components = vec![ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx];
    let shape = if matches!(&args[0], Value::Matrix(m) if m.nrows() == 1) {
        (1, 3)
    } else {
        (3, 1)
    };
    Ok(Value::Matrix(Box::new(
        Array2::from_shape_vec(shape, components).unwrap(),
    )))
}
// dot(a, b): sum of element-wise products of two equally long numeric
// vectors. Scalars count as length-1 vectors; matrices are flattened.
("dot", 2) => {
    fn flatten(v: &Value, argn: usize) -> Result<Vec<f64>, String> {
        match v {
            Value::Scalar(s) => Ok(vec![*s]),
            Value::Matrix(m) => Ok(m.iter().copied().collect()),
            _ => Err(format!("dot: argument {} must be a numeric vector", argn)),
        }
    }
    let lhs = flatten(&args[0], 1)?;
    let rhs = flatten(&args[1], 2)?;
    if lhs.len() != rhs.len() {
        return Err(format!(
            "dot: vectors must have the same length ({} vs {})",
            lhs.len(),
            rhs.len()
        ));
    }
    let total = lhs.iter().zip(&rhs).map(|(x, y)| x * y).sum::<f64>();
    Ok(Value::Scalar(total))
}
// intersect(a, b): sorted 1 x n row of the distinct values present in both a
// and b (scalars and any-shaped matrices are flattened). NaN never matches.
// +0 and -0 count as the same value because they compare equal.
("intersect", 2) => {
    // Sort order: numbers ascending, every NaN after every number. Unlike
    // `partial_cmp(..).unwrap_or(Equal)`, this is a total order, which
    // `sort_by` requires (it may panic on an inconsistent comparator).
    fn nan_last(a: &f64, b: &f64) -> std::cmp::Ordering {
        a.is_nan()
            .cmp(&b.is_nan())
            .then_with(|| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
    }
    // Hash key under which -0.0 and +0.0 coincide (their bit patterns differ).
    fn key(x: f64) -> u64 {
        if x == 0.0 {
            0
        } else {
            x.to_bits()
        }
    }
    fn to_sorted_vec(v: &Value, fname: &str) -> Result<Vec<f64>, String> {
        match v {
            Value::Matrix(m) => {
                let mut vals: Vec<f64> = m.iter().copied().collect();
                vals.sort_by(nan_last);
                Ok(vals)
            }
            Value::Scalar(s) => Ok(vec![*s]),
            _ => Err(format!("{fname}: arguments must be numeric vectors")),
        }
    }
    let a = to_sorted_vec(&args[0], "intersect")?;
    let b = to_sorted_vec(&args[1], "intersect")?;
    let b_set: std::collections::HashSet<u64> = b
        .iter()
        .filter(|x| !x.is_nan())
        .map(|&x| key(x))
        .collect();
    let mut result: Vec<f64> = Vec::new();
    for &x in &a {
        // `a` is sorted, so comparing with the last kept value removes duplicates.
        if !x.is_nan()
            && b_set.contains(&key(x))
            && result.last().is_none_or(|&last| last != x)
        {
            result.push(x);
        }
    }
    let n = result.len();
    Ok(Value::Matrix(Box::new(
        Array2::from_shape_vec((1, n), result).expect("a 1 x n shape holds n elements"),
    )))
}
// union(a, b): sorted 1 x n row of the distinct values in a or b (scalars and
// any-shaped matrices are flattened). NaNs sort last and are all kept,
// because NaN != NaN.
("union", 2) => {
    // Numbers ascending, NaN after every number. This is a total order, as
    // `sort_by` requires; the previous `partial_cmp(..).unwrap_or(Equal)`
    // was not once NaN was present.
    fn nan_last(a: &f64, b: &f64) -> std::cmp::Ordering {
        a.is_nan()
            .cmp(&b.is_nan())
            .then_with(|| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
    }
    fn collect_vals(v: &Value, fname: &str) -> Result<Vec<f64>, String> {
        match v {
            Value::Matrix(m) => Ok(m.iter().copied().collect()),
            Value::Scalar(s) => Ok(vec![*s]),
            _ => Err(format!("{fname}: arguments must be numeric vectors")),
        }
    }
    let mut combined = collect_vals(&args[0], "union")?;
    combined.extend(collect_vals(&args[1], "union")?);
    combined.sort_by(nan_last);
    // Drop adjacent equal values, keeping the first of each run. NaNs
    // never compare equal, so each one survives.
    combined.dedup();
    let n = combined.len();
    Ok(Value::Matrix(Box::new(
        Array2::from_shape_vec((1, n), combined).expect("a 1 x n shape holds n elements"),
    )))
}
// setdiff(a, b): sorted 1 x n row of the distinct values of a that do not
// occur in b. NaN entries of a are dropped. +0 and -0 are the same value.
("setdiff", 2) => {
    // Numbers ascending, NaN last: a total order, as `sort_by` requires.
    fn nan_last(a: &f64, b: &f64) -> std::cmp::Ordering {
        a.is_nan()
            .cmp(&b.is_nan())
            .then_with(|| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
    }
    // Hash key under which -0.0 and +0.0 coincide (their bit patterns differ).
    fn key(x: f64) -> u64 {
        if x == 0.0 {
            0
        } else {
            x.to_bits()
        }
    }
    fn collect_vals2(v: &Value, fname: &str) -> Result<Vec<f64>, String> {
        match v {
            Value::Matrix(m) => Ok(m.iter().copied().collect()),
            Value::Scalar(s) => Ok(vec![*s]),
            _ => Err(format!("{fname}: arguments must be numeric vectors")),
        }
    }
    let mut a = collect_vals2(&args[0], "setdiff")?;
    let b = collect_vals2(&args[1], "setdiff")?;
    let b_set: std::collections::HashSet<u64> = b
        .iter()
        .filter(|x| !x.is_nan())
        .map(|&x| key(x))
        .collect();
    a.sort_by(nan_last);
    a.retain(|x| !x.is_nan() && !b_set.contains(&key(*x)));
    a.dedup();
    let n = a.len();
    Ok(Value::Matrix(Box::new(
        Array2::from_shape_vec((1, n), a).expect("a 1 x n shape holds n elements"),
    )))
}
// ismember(a, s): 1 where an element of a occurs in s, else 0, with the same
// shape as a. NaN is never a member. +0 and -0 are the same value, matching
// `==` (their raw bit patterns differ, so they are normalised before hashing).
("ismember", 2) => {
    fn key(x: f64) -> u64 {
        if x == 0.0 {
            0
        } else {
            x.to_bits()
        }
    }
    fn collect_vals3(v: &Value, fname: &str) -> Result<Vec<f64>, String> {
        match v {
            Value::Matrix(m) => Ok(m.iter().copied().collect()),
            Value::Scalar(s) => Ok(vec![*s]),
            _ => Err(format!("{fname}: arguments must be numeric")),
        }
    }
    let set: std::collections::HashSet<u64> = collect_vals3(&args[1], "ismember")?
        .into_iter()
        .filter(|x| !x.is_nan())
        .map(key)
        .collect();
    let is_member = |x: f64| {
        if !x.is_nan() && set.contains(&key(x)) {
            1.0
        } else {
            0.0
        }
    };
    match &args[0] {
        Value::Scalar(s) => Ok(Value::Scalar(is_member(*s))),
        Value::Matrix(m) => Ok(Value::Matrix(Box::new(m.mapv(is_member)))),
        _ => Err("ismember: first argument must be numeric".to_string()),
    }
}
("sub2ind", 3) => {
let sz = match &args[0] {
Value::Matrix(m) if m.len() == 2 => (m[[0, 0]] as usize, m[[0, 1]] as usize),
_ => return Err("sub2ind: first argument must be [rows cols]".to_string()),
};
let rows = sz.0;
fn idx_vals(v: &Value, argn: usize) -> Result<Vec<f64>, String> {
match v {
Value::Scalar(s) => Ok(vec![*s]),
Value::Matrix(m) => Ok(m.iter().copied().collect()),
_ => Err(format!("sub2ind: argument {} must be numeric", argn)),
}
}
let r = idx_vals(&args[1], 2)?;
let c = idx_vals(&args[2], 3)?;
if r.len() != c.len() {
return Err(
"sub2ind: row and column index vectors must have the same length".to_string(),
);
}
if r.len() == 1 {
let idx = (c[0] as usize - 1) * rows + r[0] as usize;
Ok(Value::Scalar(idx as f64))
} else {
let vals: Vec<f64> = r
.iter()
.zip(c.iter())
.map(|(&ri, &ci)| ((ci as usize - 1) * rows + ri as usize) as f64)
.collect();
let n = vals.len();
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, n), vals).unwrap(),
)))
}
}
("ind2sub", 2) => {
let sz = match &args[0] {
Value::Matrix(m) if m.len() == 2 => (m[[0, 0]] as usize, m[[0, 1]] as usize),
_ => return Err("ind2sub: first argument must be [rows cols]".to_string()),
};
let rows = sz.0;
fn idx_vals2(v: &Value, argn: usize) -> Result<Vec<f64>, String> {
match v {
Value::Scalar(s) => Ok(vec![*s]),
Value::Matrix(m) => Ok(m.iter().copied().collect()),
_ => Err(format!("ind2sub: argument {} must be numeric", argn)),
}
}
let indices = idx_vals2(&args[1], 2)?;
if indices.len() == 1 {
let idx = indices[0] as usize;
let r = ((idx - 1) % rows + 1) as f64;
let c = ((idx - 1) / rows + 1) as f64;
Ok(Value::Tuple(vec![Value::Scalar(r), Value::Scalar(c)]))
} else {
let n = indices.len();
let rs: Vec<f64> = indices
.iter()
.map(|&idx| ((idx as usize - 1) % rows + 1) as f64)
.collect();
let cs: Vec<f64> = indices
.iter()
.map(|&idx| ((idx as usize - 1) / rows + 1) as f64)
.collect();
let rm = Value::Matrix(Box::new(Array2::from_shape_vec((1, n), rs).unwrap()));
let cm = Value::Matrix(Box::new(Array2::from_shape_vec((1, n), cs).unwrap()));
Ok(Value::Tuple(vec![rm, cm]))
}
}
("repelem", 2) => match (&args[0], &args[1]) {
(Value::Matrix(a), Value::Scalar(n)) => {
let n = *n as usize;
let flat: Vec<f64> = a.iter().flat_map(|&x| std::iter::repeat_n(x, n)).collect();
let total = flat.len();
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, total), flat).unwrap(),
)))
}
(Value::Matrix(a), Value::Matrix(ns)) => {
let av: Vec<f64> = a.iter().copied().collect();
let nv: Vec<f64> = ns.iter().copied().collect();
if av.len() != nv.len() {
return Err(
"repelem: element count vector must match source vector length".to_string(),
);
}
let flat: Vec<f64> = av
.iter()
.zip(nv.iter())
.flat_map(|(&x, &n)| std::iter::repeat_n(x, n as usize))
.collect();
let total = flat.len();
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, total), flat).unwrap(),
)))
}
(Value::Scalar(s), Value::Scalar(n)) => {
let n = *n as usize;
Ok(Value::Matrix(Box::new(Array2::from_elem((1, n), *s))))
}
_ => Err("repelem: unsupported argument types".to_string()),
},
("repelem", 3) => match (&args[0], &args[1], &args[2]) {
(Value::Matrix(a), Value::Scalar(rm), Value::Scalar(cn)) => {
let rm = *rm as usize;
let cn = *cn as usize;
let (nrows, ncols) = (a.nrows(), a.ncols());
let mut result = Array2::<f64>::zeros((nrows * rm, ncols * cn));
for i in 0..nrows {
for j in 0..ncols {
let v = a[[i, j]];
for di in 0..rm {
for dj in 0..cn {
result[[i * rm + di, j * cn + dj]] = v;
}
}
}
}
Ok(Value::Matrix(Box::new(result)))
}
(Value::Scalar(s), Value::Scalar(rm), Value::Scalar(cn)) => Ok(Value::Matrix(
Box::new(Array2::from_elem((*rm as usize, *cn as usize), *s)),
)),
_ => Err("repelem: expects (matrix, m, n) for 2D repetition".to_string()),
},
("polyval", 2) => {
let coeffs = poly_coeffs(&args[0], "polyval")?;
if coeffs.is_empty() {
return Err("polyval: polynomial vector is empty".to_string());
}
match &args[1] {
Value::Scalar(x) => Ok(Value::Scalar(horner(&coeffs, *x))),
Value::Matrix(m) => Ok(Value::Matrix(Box::new(m.mapv(|x| horner(&coeffs, x))))),
_ => Err("polyval: second argument must be a real numeric value".to_string()),
}
}
("polyfit", 3) => {
let xv = poly_coeffs(&args[0], "polyfit")?;
let yv = poly_coeffs(&args[1], "polyfit")?;
let deg = match &args[2] {
Value::Scalar(n) => {
let d = *n as usize;
if *n < 0.0 || (*n - d as f64).abs() > 1e-9 {
return Err("polyfit: degree must be a non-negative integer".to_string());
}
d
}
_ => return Err("polyfit: degree must be a scalar".to_string()),
};
if xv.len() != yv.len() {
return Err("polyfit: x and y must have the same length".to_string());
}
let m = xv.len();
let ncols = deg + 1;
if ncols > m {
return Err(format!(
"polyfit: not enough data points ({m}) for degree-{deg} fit"
));
}
let mut vander = Array2::<f64>::zeros((m, ncols));
for (i, &xi) in xv.iter().enumerate() {
for j in 0..ncols {
vander[[i, j]] = xi.powi((deg - j) as i32);
}
}
let (q, r) = qr_decompose(&vander)?;
let qty: Vec<f64> = (0..ncols)
.map(|i| (0..m).map(|k| q[[k, i]] * yv[k]).sum())
.collect();
let mut r_sq = Array2::<f64>::zeros((ncols, ncols));
for i in 0..ncols {
for j in 0..ncols {
r_sq[[i, j]] = r[[i, j]];
}
}
let coeffs = poly_back_sub(&r_sq, &qty)?;
let result = Array2::from_shape_vec((1, ncols), coeffs)
.map_err(|e| format!("polyfit: internal error: {e}"))?;
Ok(Value::Matrix(Box::new(result)))
}
("roots", 1) => {
let raw = poly_coeffs(&args[0], "roots")?;
let start = raw.iter().position(|&c| c != 0.0).unwrap_or(raw.len());
let coeffs = &raw[start..];
if coeffs.len() <= 1 {
return Ok(Value::Matrix(Box::new(Array2::zeros((0, 1)))));
}
let roots = durand_kerner(coeffs)?;
Ok(roots_to_value(&roots))
}
("poly", 1) => match &args[0] {
Value::Scalar(r) => {
let data = vec![1.0, -*r];
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, 2), data).unwrap(),
)))
}
Value::Matrix(m) => {
if m.nrows() == 1 || m.ncols() == 1 {
let roots: Vec<f64> = if m.nrows() == 1 {
m.row(0).iter().copied().collect()
} else {
m.column(0).iter().copied().collect()
};
let mut p = vec![1.0_f64];
for &r in &roots {
p = poly_conv(&p, &[1.0, -r]);
}
let ncols = p.len();
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, ncols), p).unwrap(),
)))
} else {
let coeffs = characteristic_poly(m)?;
let ncols = coeffs.len();
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, ncols), coeffs).unwrap(),
)))
}
}
_ => Err("poly: argument must be a numeric vector or square matrix".to_string()),
},
("conv", 2) => {
let a = poly_coeffs(&args[0], "conv")?;
let b = poly_coeffs(&args[1], "conv")?;
if a.is_empty() || b.is_empty() {
return Ok(Value::Matrix(Box::new(Array2::zeros((1, 0)))));
}
let c = poly_conv(&a, &b);
let len = c.len();
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, len), c).unwrap(),
)))
}
("deconv", 2) => {
let c = poly_coeffs(&args[0], "deconv")?;
let b = poly_coeffs(&args[1], "deconv")?;
let (q, r) = poly_deconv(&c, &b)?;
let qn = q.len();
let rn = r.len();
let q_val = Value::Matrix(Box::new(Array2::from_shape_vec((1, qn), q).unwrap()));
let r_val = Value::Matrix(Box::new(Array2::from_shape_vec((1, rn), r).unwrap()));
Ok(Value::Tuple(vec![q_val, r_val]))
}
("interp1", 3) => {
let xv = poly_coeffs(&args[0], "interp1")?;
let yv = poly_coeffs(&args[1], "interp1")?;
if xv.len() != yv.len() {
return Err("interp1: x and y must have the same length".to_string());
}
if xv.len() < 2 {
return Err("interp1: requires at least two knot points".to_string());
}
match &args[2] {
Value::Scalar(xi) => Ok(Value::Scalar(interp1_at(&xv, &yv, *xi, "linear"))),
Value::Matrix(xi_m) => Ok(Value::Matrix(Box::new(
xi_m.mapv(|xi| interp1_at(&xv, &yv, xi, "linear")),
))),
_ => Err("interp1: query points must be numeric".to_string()),
}
}
("interp1", 4) => {
let xv = poly_coeffs(&args[0], "interp1")?;
let yv = poly_coeffs(&args[1], "interp1")?;
let method = match &args[3] {
Value::Str(s) | Value::StringObj(s) => s.clone(),
_ => return Err("interp1: method argument must be a string".to_string()),
};
if !matches!(method.as_str(), "linear" | "nearest" | "previous" | "next") {
return Err(format!(
"interp1: unknown method '{method}'; supported: linear nearest previous next"
));
}
if xv.len() != yv.len() {
return Err("interp1: x and y must have the same length".to_string());
}
if xv.len() < 2 {
return Err("interp1: requires at least two knot points".to_string());
}
match &args[2] {
Value::Scalar(xi) => Ok(Value::Scalar(interp1_at(&xv, &yv, *xi, &method))),
Value::Matrix(xi_m) => {
let m_str = method.as_str();
Ok(Value::Matrix(Box::new(
xi_m.mapv(|xi| interp1_at(&xv, &yv, xi, m_str)),
)))
}
_ => Err("interp1: query points must be numeric".to_string()),
}
}
("tic", 0) => {
TIC_TIME.with(|t| t.set(Some(std::time::Instant::now())));
Ok(Value::Void)
}
("toc", 0) => {
let elapsed = TIC_TIME.with(|t| t.get().map(|s| s.elapsed().as_secs_f64()));
match elapsed {
Some(t) => Ok(Value::Scalar(t)),
None => Err("toc: tic must be called before toc".to_string()),
}
}
("eval", 1) => {
let code = match &args[0] {
Value::Str(s) | Value::StringObj(s) => s.clone(),
_ => return Err("eval: argument must be a string".to_string()),
};
call_eval_str_hook(&code, env)
}
("eval", 2) => {
let code = match &args[0] {
Value::Str(s) | Value::StringObj(s) => s.clone(),
_ => return Err("eval: argument must be a string".to_string()),
};
match call_eval_str_hook(&code, env) {
Err(e) => {
set_last_err(&e);
let catch = match &args[1] {
Value::Str(s) | Value::StringObj(s) => s.clone(),
_ => return Err("eval: catch argument must be a string".to_string()),
};
call_eval_str_hook(&catch, env)
}
ok => ok,
}
}
_ => {
let hint = suggest_similar(name, env);
match hint {
Some(s) => Err(format!("Unknown function '{name}'; did you mean '{s}'?")),
None => Err(format!("Unknown function: '{name}'")),
}
}
}
}
/// Resolves the escape spellings users type for delimiters.
///
/// The two-character sequences `\t` and `\n` (backslash + letter) become a real
/// tab / newline; any other text is returned unchanged.
fn interpret_delim(s: &str) -> String {
    let resolved = match s {
        r"\t" => "\t",
        r"\n" => "\n",
        verbatim => verbatim,
    };
    resolved.to_owned()
}
/// Reports whether `delim` splits every line into the same number of fields, and that
/// number is at least two.
///
/// Splitting is plain `str::split` (no quote handling). An empty `lines` slice is
/// considered consistent.
fn delim_consistent(lines: &[&str], delim: char) -> bool {
    let mut width: Option<usize> = None;
    for line in lines {
        let fields = line.split(delim).count();
        if fields < 2 || width.is_some_and(|w| w != fields) {
            return false;
        }
        width = Some(fields);
    }
    true
}
/// Reads a numeric delimited text file into a matrix (`dlmread`).
///
/// Blank (whitespace-only) lines are ignored. Without `explicit_delim`, a comma is
/// used when every line splits into the same number (>1) of comma fields, else a tab
/// under the same rule, else runs of whitespace. Empty fields read as `0.0`.
///
/// # Errors
/// Fails when the file cannot be read, a field is not a number, or rows have
/// differing field counts. An empty or all-blank file yields a 0×0 matrix.
fn dlmread_impl(path: &str, explicit_delim: Option<String>) -> Result<Value, String> {
    let text = std::fs::read_to_string(path)
        .map_err(|e| format!("dlmread: cannot read '{path}': {e}"))?;
    let lines: Vec<&str> = text.lines().filter(|l| !l.trim().is_empty()).collect();
    if lines.is_empty() {
        return Ok(Value::Matrix(Box::new(Array2::zeros((0, 0)))));
    }
    // An explicit delimiter always wins; otherwise try comma, then tab.
    let delim = explicit_delim.or_else(|| {
        [',', '\t']
            .iter()
            .copied()
            .find(|&c| delim_consistent(&lines, c))
            .map(String::from)
    });
    let mut rows: Vec<Vec<f64>> = Vec::with_capacity(lines.len());
    for (idx, line) in lines.iter().enumerate() {
        let fields: Vec<&str> = match delim.as_deref() {
            Some(d) => line.split(d).collect(),
            None => line.split_whitespace().collect(),
        };
        let row = fields
            .iter()
            .map(|field| match field.trim() {
                "" => Ok(0.0),
                t => t.parse::<f64>().map_err(|_| {
                    format!("dlmread: non-numeric value '{t}' on line {}", idx + 1)
                }),
            })
            .collect::<Result<Vec<f64>, String>>()?;
        if !row.is_empty() {
            rows.push(row);
        }
    }
    let ncols = match rows.first() {
        Some(first) => first.len(),
        None => return Ok(Value::Matrix(Box::new(Array2::zeros((0, 0))))),
    };
    if let Some((i, bad)) = rows.iter().enumerate().find(|(_, r)| r.len() != ncols) {
        return Err(format!(
            "dlmread: row {} has {} fields, expected {ncols}",
            i + 1,
            bad.len()
        ));
    }
    let nrows = rows.len();
    Array2::from_shape_vec((nrows, ncols), rows.concat())
        .map(|m| Value::Matrix(Box::new(m)))
        .map_err(|e| format!("dlmread: shape error: {e}"))
}
/// Formats a number for delimited-text output.
///
/// Finite integral values with magnitude below 1e15 are printed through `i64`, so
/// they have no decimal point. Everything else (fractions, huge values, NaN, ±inf)
/// uses `f64`'s `Display` form.
fn fmt_dlm_number(n: f64) -> String {
    let integral = n.is_finite() && n.trunc() == n && n.abs() < 1e15;
    if !integral {
        return n.to_string();
    }
    (n as i64).to_string()
}
fn dlmwrite_impl(path: &str, val: &Value, explicit_delim: Option<String>) -> Result<Value, String> {
let delim = explicit_delim.unwrap_or_else(|| ",".to_string());
let content = match val {
Value::Scalar(n) => format!("{}\n", fmt_dlm_number(*n)),
Value::Matrix(m) => {
let mut out = String::new();
for row in m.rows() {
let parts: Vec<String> = row.iter().map(|n| fmt_dlm_number(*n)).collect();
out.push_str(&parts.join(&delim));
out.push('\n');
}
out
}
_ => {
return Err("dlmwrite: second argument must be a numeric scalar or matrix".to_string());
}
};
std::fs::write(path, content).map_err(|e| format!("dlmwrite: cannot write '{path}': {e}"))?;
Ok(Value::Void)
}
/// Guesses the delimiter of a text table.
///
/// Comma wins when every line has the same number (>1) of quote-aware comma fields
/// (via `split_csv_row`). Otherwise tab is chosen if `delim_consistent` accepts it.
/// Otherwise `None` is returned, meaning split on whitespace.
fn auto_detect_delim(lines: &[&str]) -> Option<String> {
    let comma_widths: Vec<usize> = lines.iter().map(|l| split_csv_row(l, ",").len()).collect();
    let uniform_commas = comma_widths
        .first()
        .is_none_or(|&w| w > 1 && comma_widths.iter().all(|&c| c == w));
    if uniform_commas {
        return Some(",".to_string());
    }
    delim_consistent(lines, '\t').then(|| "\t".to_string())
}
/// Splits one line of delimited text into fields.
///
/// For a single-character delimiter, double-quote quoting is honoured in the lenient
/// RFC 4180 style:
/// * a `"` opens a quoted section only at the start of a field (leading whitespace
///   allowed);
/// * inside quotes the delimiter is literal and `""` is an escaped quote;
/// * a `"` appearing mid-field (e.g. `12" pizza`) is an ordinary character.
///
/// Multi-character or empty delimiters fall back to plain `str::split` without quote
/// handling. Fields are returned untrimmed.
fn split_csv_row(line: &str, delim: &str) -> Vec<String> {
    let mut delim_chars = delim.chars();
    let (Some(sep), None) = (delim_chars.next(), delim_chars.next()) else {
        return line.split(delim).map(str::to_string).collect();
    };
    let mut fields: Vec<String> = Vec::new();
    let mut field = String::new();
    let mut in_quotes = false;
    let mut chars = line.chars().peekable();
    while let Some(c) = chars.next() {
        if in_quotes {
            if c != '"' {
                field.push(c);
            } else if chars.peek() == Some(&'"') {
                // `""` inside a quoted section is one literal quote.
                chars.next();
                field.push('"');
            } else {
                in_quotes = false;
            }
        } else if c == '"' && field.trim().is_empty() {
            // Only a quote at the start of a field opens quoting. A stray quote
            // mid-field must not swallow the following delimiters.
            in_quotes = true;
        } else if c == sep {
            fields.push(std::mem::take(&mut field));
        } else {
            field.push(c);
        }
    }
    fields.push(field);
    fields
}
/// Splits `line` with `split_csv_row` when a delimiter is given, or on runs of
/// whitespace when `delim` is `None`.
fn split_csv_row_opt(line: &str, delim: &Option<String>) -> Vec<String> {
    if let Some(d) = delim {
        return split_csv_row(line, d);
    }
    line.split_whitespace().map(String::from).collect()
}
/// Treats a row as a header when at least one trimmed, non-empty field does not parse
/// as `f64`.
fn row_is_header(fields: &[String]) -> bool {
    let all_numeric_or_blank = fields
        .iter()
        .map(|f| f.trim())
        .all(|t| t.is_empty() || t.parse::<f64>().is_ok());
    !all_numeric_or_blank
}
/// Turns a raw header cell into a usable struct field name.
///
/// Letters, digits and `_` are kept. Any other character becomes `_`, unless the name
/// built so far already ends with `_`. Then:
/// * trailing `_` are removed;
/// * a leading ASCII digit gets an `x` prefix;
/// * an empty result falls back to `x<col_1based>`.
fn sanitize_header(s: &str, col_1based: usize) -> String {
    let mut name = String::with_capacity(s.len());
    for ch in s.trim().chars() {
        if ch.is_alphanumeric() || ch == '_' {
            name.push(ch);
        } else if !name.ends_with('_') {
            name.push('_');
        }
    }
    let name = name.trim_end_matches('_');
    match name.chars().next() {
        None => format!("x{col_1based}"),
        Some(first) if first.is_ascii_digit() => format!("x{name}"),
        Some(_) => name.to_string(),
    }
}
/// Makes header names unique so they can serve as struct field names.
///
/// Names that occur once are kept unchanged. Each occurrence of a repeated name `h`
/// becomes `h_1`, `h_2`, … in order. A suffix is skipped when the resulting name
/// already exists among the input headers or was generated earlier. Without that
/// check, `["a", "a", "a_1"]` would yield `a_1` twice and one column would be
/// silently overwritten downstream.
fn deduplicate_headers(headers: Vec<String>) -> Vec<String> {
    let mut count: HashMap<String, usize> = HashMap::new();
    for h in &headers {
        *count.entry(h.clone()).or_insert(0) += 1;
    }
    // Every name that is already spoken for: all input names plus each generated one.
    let mut taken: HashSet<String> = headers.iter().cloned().collect();
    let mut next_suffix: HashMap<String, usize> = HashMap::new();
    headers
        .into_iter()
        .map(|h| {
            if count[&h] == 1 {
                return h;
            }
            let suffix = next_suffix.entry(h.clone()).or_insert(0);
            loop {
                *suffix += 1;
                let candidate = format!("{h}_{suffix}");
                if taken.insert(candidate.clone()) {
                    return candidate;
                }
            }
        })
        .collect()
}
/// Parses an optional trailing `'Delimiter', <value>` argument pair starting at
/// index `start`.
///
/// Returns `Ok(None)` when no argument is present at `start`. The key is matched
/// case-insensitively, and the value goes through `interpret_delim`, so `\t` and `\n`
/// escapes work.
///
/// # Errors
/// Fails when the key is not `Delimiter`, when the value is missing, or when either
/// argument is not a string (as reported by `string_arg`).
fn parse_delimiter_opt(
    fn_name: &str,
    args: &[Value],
    start: usize,
) -> Result<Option<String>, String> {
    let Some(key_arg) = args.get(start) else {
        return Ok(None);
    };
    let key = string_arg(key_arg, fn_name, start + 1)?;
    if !key.eq_ignore_ascii_case("delimiter") {
        return Err(format!(
            "{fn_name}: expected 'Delimiter' option at argument {}, got '{key}'",
            start + 1
        ));
    }
    let Some(value_arg) = args.get(start + 1) else {
        return Err(format!("{fn_name}: 'Delimiter' option requires a value"));
    };
    Ok(Some(interpret_delim(string_arg(value_arg, fn_name, start + 2)?)))
}
/// Reads a numeric table file into a matrix (`readmatrix`).
///
/// Blank lines are ignored. The delimiter is `explicit_delim`, or else the result of
/// `auto_detect_delim` (`None` means whitespace). A first row containing any
/// non-numeric, non-empty field is treated as a header and skipped. Empty fields
/// read as NaN.
///
/// # Errors
/// Fails when the file cannot be read, a data field is not a number, or rows have
/// differing field counts. Files with no data rows yield a 0×0 matrix.
fn readmatrix_impl(path: &str, explicit_delim: Option<String>) -> Result<Value, String> {
    let empty = || -> Result<Value, String> {
        Ok(Value::Matrix(Box::new(Array2::<f64>::zeros((0, 0)))))
    };
    let text = std::fs::read_to_string(path)
        .map_err(|e| format!("readmatrix: cannot read '{path}': {e}"))?;
    let lines: Vec<&str> = text.lines().filter(|l| !l.trim().is_empty()).collect();
    if lines.is_empty() {
        return empty();
    }
    let delim = explicit_delim.or_else(|| auto_detect_delim(&lines));
    let header_rows = usize::from(row_is_header(&split_csv_row_opt(lines[0], &delim)));
    let data_lines = &lines[header_rows..];
    if data_lines.is_empty() {
        return empty();
    }
    let rows = data_lines
        .iter()
        .enumerate()
        .map(|(i, line)| {
            split_csv_row_opt(line, &delim)
                .iter()
                .map(|f| match f.trim() {
                    "" => Ok(f64::NAN),
                    t => t.parse::<f64>().map_err(|_| {
                        format!(
                            "readmatrix: non-numeric value '{t}' on line {}",
                            i + 1 + header_rows
                        )
                    }),
                })
                .collect::<Result<Vec<f64>, String>>()
        })
        .collect::<Result<Vec<Vec<f64>>, String>>()?;
    let ncols = match rows.first() {
        Some(first) => first.len(),
        None => return empty(),
    };
    if let Some((i, bad)) = rows.iter().enumerate().find(|(_, r)| r.len() != ncols) {
        return Err(format!(
            "readmatrix: row {} has {} fields, expected {ncols}",
            i + 1,
            bad.len()
        ));
    }
    let nrows = rows.len();
    Array2::from_shape_vec((nrows, ncols), rows.concat())
        .map(|m| Value::Matrix(Box::new(m)))
        .map_err(|e| format!("readmatrix: shape error: {e}"))
}
/// Reads a delimited text file into a struct of columns (`readtable`).
///
/// The first non-blank line supplies the headers, which are sanitised and
/// de-duplicated into field names. A column becomes an N×1 matrix when every trimmed
/// cell is empty (read as NaN) or numeric; otherwise it becomes a cell array of
/// strings. With a header but no data rows, every field is a 0×1 matrix. An empty
/// file yields an empty struct.
///
/// # Errors
/// Fails when the file cannot be read or a data row's field count differs from the
/// header's.
fn readtable_impl(path: &str, explicit_delim: Option<String>) -> Result<Value, String> {
    let text = std::fs::read_to_string(path)
        .map_err(|e| format!("readtable: cannot read '{path}': {e}"))?;
    let lines: Vec<&str> = text.lines().filter(|l| !l.trim().is_empty()).collect();
    let Some((&header_line, data_lines)) = lines.split_first() else {
        return Ok(Value::Struct(Box::new(IndexMap::new())));
    };
    let delim = explicit_delim.or_else(|| auto_detect_delim(&lines));
    let raw_headers = split_csv_row_opt(header_line, &delim);
    let ncols = raw_headers.len();
    let headers: Vec<String> = deduplicate_headers(
        raw_headers
            .iter()
            .enumerate()
            .map(|(i, h)| sanitize_header(h.trim(), i + 1))
            .collect(),
    );
    let mut table: IndexMap<String, Value> = IndexMap::new();
    if data_lines.is_empty() {
        for name in headers {
            table.insert(name, Value::Matrix(Box::new(Array2::<f64>::zeros((0, 1)))));
        }
        return Ok(Value::Struct(Box::new(table)));
    }
    let mut grid: Vec<Vec<String>> = Vec::with_capacity(data_lines.len());
    for (i, line) in data_lines.iter().enumerate() {
        let fields = split_csv_row_opt(line, &delim);
        if fields.len() != ncols {
            // +2: 1-based numbering plus the header line.
            return Err(format!(
                "readtable: row {} has {} fields, expected {ncols}",
                i + 2,
                fields.len()
            ));
        }
        grid.push(fields.iter().map(|f| f.trim().to_string()).collect());
    }
    let nrows = grid.len();
    for (col, name) in headers.iter().enumerate() {
        let column: Vec<&str> = grid.iter().map(|row| row[col].as_str()).collect();
        let numeric = column.iter().all(|t| t.is_empty() || t.parse::<f64>().is_ok());
        let value = if numeric {
            let nums: Vec<f64> = column
                .iter()
                .map(|t| if t.is_empty() { f64::NAN } else { t.parse::<f64>().unwrap() })
                .collect();
            let mat = Array2::from_shape_vec((nrows, 1), nums)
                .map_err(|e| format!("readtable: shape error: {e}"))?;
            Value::Matrix(Box::new(mat))
        } else {
            let items: Vec<Value> = column.iter().map(|t| Value::Str(t.to_string())).collect();
            Value::Cell(Box::new(items))
        };
        table.insert(name.clone(), value);
    }
    Ok(Value::Struct(Box::new(table)))
}
/// Quotes one CSV cell when required (RFC 4180 style).
///
/// The cell is wrapped in double quotes, with embedded quotes doubled, if it contains
/// any of:
/// * a double quote;
/// * a line-break character (`\n` or `\r`);
/// * the delimiter.
///
/// Otherwise it is returned unchanged.
fn csv_quote_cell(s: &str, delim: &str) -> String {
    // `\r` must be quoted too: an unquoted trailing CR is indistinguishable from a
    // CRLF line ending and is stripped by line-based readers.
    if s.contains(['"', '\n', '\r']) || s.contains(delim) {
        format!("\"{}\"", s.replace('"', "\"\""))
    } else {
        s.to_string()
    }
}
/// Number of rows a value contributes as a `writetable` column, or `None` if it cannot
/// be a column.
///
/// N×1 matrices (and any 0-row matrix) give their row count, and cells give their
/// length. Scalars and strings count as a single row.
fn col_nrows(v: &Value) -> Option<usize> {
    let rows = match v {
        Value::Matrix(m) if m.ncols() == 1 || m.nrows() == 0 => m.nrows(),
        Value::Cell(c) => c.len(),
        Value::Scalar(_) | Value::Str(_) | Value::StringObj(_) => 1,
        _ => return None,
    };
    Some(rows)
}
/// Renders entry `row` of one `writetable` column as a CSV-quoted cell.
///
/// Numbers are formatted by `fmt_dlm_number` and strings are used verbatim.
/// Matrix columns read element `[row, 0]`, while scalar and string columns ignore
/// `row`.
///
/// # Errors
/// Fails for cell elements or column values that are neither numeric nor string.
fn col_cell_str(v: &Value, row: usize, delim: &str) -> Result<String, String> {
    let quote_num = |x: f64| csv_quote_cell(&fmt_dlm_number(x), delim);
    match v {
        Value::Scalar(x) => Ok(quote_num(*x)),
        Value::Matrix(m) => Ok(quote_num(m[[row, 0]])),
        Value::Str(s) | Value::StringObj(s) => Ok(csv_quote_cell(s, delim)),
        Value::Cell(items) => match &items[row] {
            Value::Scalar(x) => Ok(quote_num(*x)),
            Value::Str(s) | Value::StringObj(s) => Ok(csv_quote_cell(s, delim)),
            _ => Err(format!(
                "writetable: cell element at row {} has unsupported type",
                row + 1
            )),
        },
        _ => Err(format!(
            "writetable: unsupported column type at row {}",
            row + 1
        )),
    }
}
/// Writes a struct of equal-length columns to `path` as delimited text (`writetable`).
///
/// The first line holds the quoted field names, followed by one line per row. Cells
/// are joined by `explicit_delim` (default `,`). An empty struct produces an empty
/// file.
///
/// # Errors
/// Fails when `tbl` is not a struct, a column has an unsupported type, columns differ
/// in length, or the file cannot be written.
fn writetable_impl(
    tbl: &Value,
    path: &str,
    explicit_delim: Option<String>,
) -> Result<Value, String> {
    let delim = explicit_delim.unwrap_or_else(|| ",".to_string());
    let Value::Struct(fields) = tbl else {
        return Err("writetable: first argument must be a struct".to_string());
    };
    let save = |text: &str| {
        std::fs::write(path, text).map_err(|e| format!("writetable: cannot write '{path}': {e}"))
    };
    let Some((first_name, first_val)) = fields.iter().next() else {
        save("")?;
        return Ok(Value::Void);
    };
    let bad_column = |name: &str| {
        format!("writetable: column '{name}' must be a Matrix (N×1), Cell, or scalar")
    };
    let nrows = col_nrows(first_val).ok_or_else(|| bad_column(first_name.as_str()))?;
    for (cname, cval) in fields.iter() {
        let n = col_nrows(cval).ok_or_else(|| bad_column(cname.as_str()))?;
        if n != nrows {
            return Err(format!(
                "writetable: column '{cname}' has {n} rows, expected {nrows}"
            ));
        }
    }
    let header: Vec<String> = fields.keys().map(|k| csv_quote_cell(k, &delim)).collect();
    let mut out = header.join(&delim);
    out.push('\n');
    for row in 0..nrows {
        let cells: Vec<String> = fields
            .values()
            .map(|cval| col_cell_str(cval, row, &delim))
            .collect::<Result<_, _>>()?;
        out.push_str(&cells.join(&delim));
        out.push('\n');
    }
    save(&out)?;
    Ok(Value::Void)
}
/// Matches the UTF-8 bytes of `name` against a shell-style glob `pat`.
///
/// `*` matches any run of characters, including none, and `?` matches exactly one
/// character. Every other byte must match literally, and the whole name must be
/// consumed.
///
/// Implementation notes:
/// * `?` and `*`-backtracking step over whole UTF-8 characters, so `caf?` matches
///   `café`.
/// * It uses the iterative "resume after the most recent `*`" algorithm. The worst
///   case is O(len(pat)·len(name)) rather than the exponential cost of naive
///   recursion.
fn glob_match_inner(pat: &[u8], name: &[u8]) -> bool {
    // Byte length of the UTF-8 character starting at `name[i]` (clamped to the slice).
    let char_len = |i: usize| -> usize {
        let len = match name[i] {
            0xF0..=0xF7 => 4,
            0xE0..=0xEF => 3,
            0xC0..=0xDF => 2,
            _ => 1,
        };
        len.min(name.len() - i)
    };
    let (mut p, mut n) = (0usize, 0usize);
    // (pattern index just after the last `*`, name index where that `*` stops matching)
    let mut star: Option<(usize, usize)> = None;
    while n < name.len() {
        match pat.get(p) {
            Some(&b'?') => {
                p += 1;
                n += char_len(n);
            }
            Some(&b'*') => {
                p += 1;
                star = Some((p, n));
            }
            Some(&c) if c == name[n] => {
                p += 1;
                n += 1;
            }
            _ => match star {
                // Let the last `*` absorb one more character and retry from there.
                Some((sp, sn)) => {
                    let next = sn + char_len(sn);
                    star = Some((sp, next));
                    p = sp;
                    n = next;
                }
                None => return false,
            },
        }
    }
    // Name exhausted: only trailing `*`s may remain in the pattern.
    pat[p..].iter().all(|&c| c == b'*')
}
/// Glob-matches a file name against a pattern via `glob_match_inner`.
///
/// On Windows both sides are lowercased first, presumably to mirror its
/// case-insensitive file names. Elsewhere the bytes are compared as given, without
/// copying.
fn glob_match(pattern: &str, name: &str) -> bool {
    #[cfg(windows)]
    let (pattern, name) = (pattern.to_lowercase(), name.to_lowercase());
    glob_match_inner(pattern.as_bytes(), name.as_bytes())
}
/// Lists a directory for `dir`, as a struct array with `name`, `folder`, `isdir` and
/// `bytes` fields.
///
/// Behaviour depends on whether `path_arg` contains `*` or `?`:
/// * With a glob, its last component is a pattern matched against the entries of
///   the parent directory (`.` when there is none).
/// * Without one, the path itself is listed and `.`/`..` rows are prepended.
///
/// `folder` is the absolute directory path without a trailing separator. Entries are
/// sorted by name, and an unreadable directory yields an empty struct array.
fn dir_impl(path_arg: &str) -> Value {
    let has_glob = path_arg.contains(['*', '?']);
    let target = std::path::Path::new(path_arg);
    let (dir_path, pattern): (std::path::PathBuf, String) = if has_glob {
        let parent = match target.parent() {
            Some(d) if d != std::path::Path::new("") => d,
            _ => std::path::Path::new("."),
        };
        let pat = target
            .file_name()
            .map(|f| f.to_string_lossy().into_owned())
            .unwrap_or_default();
        (parent.to_path_buf(), pat)
    } else {
        (target.to_path_buf(), String::new())
    };
    let absolute = if dir_path.is_absolute() {
        dir_path.to_string_lossy().into_owned()
    } else {
        std::env::current_dir()
            .unwrap_or_else(|_| ".".into())
            .join(&dir_path)
            .to_string_lossy()
            .into_owned()
    };
    #[cfg(windows)]
    let absolute = absolute.replace('/', "\\");
    let mut folder_str = absolute;
    if folder_str.len() > 1 && folder_str.ends_with(['/', '\\']) {
        folder_str.pop();
    }
    let make_row = |name: String, is_dir: f64, bytes: f64| {
        let mut row = IndexMap::new();
        row.insert("name".to_string(), Value::Str(name));
        row.insert("folder".to_string(), Value::Str(folder_str.clone()));
        row.insert("isdir".to_string(), Value::Scalar(is_dir));
        row.insert("bytes".to_string(), Value::Scalar(bytes));
        row
    };
    let mut entries: Vec<IndexMap<String, Value>> = if has_glob {
        Vec::new()
    } else {
        [".", ".."]
            .iter()
            .map(|dot| make_row(dot.to_string(), 1.0, 0.0))
            .collect()
    };
    let Ok(listing) = std::fs::read_dir(&dir_path) else {
        return Value::StructArray(Box::default());
    };
    let mut found: Vec<(String, IndexMap<String, Value>)> = listing
        .filter_map(Result::ok)
        .filter_map(|entry| {
            let file_name = entry.file_name().to_string_lossy().into_owned();
            if has_glob && !glob_match(&pattern, &file_name) {
                return None;
            }
            let meta = entry.metadata().ok()?;
            let is_dir = if meta.is_dir() { 1.0 } else { 0.0 };
            let bytes = if meta.is_file() { meta.len() as f64 } else { 0.0 };
            Some((file_name.clone(), make_row(file_name, is_dir, bytes)))
        })
        .collect();
    found.sort_by(|a, b| a.0.cmp(&b.0));
    entries.extend(found.into_iter().map(|(_, row)| row));
    Value::StructArray(Box::new(entries))
}
/// Converts a numeric argument of a bitwise builtin to `u64`.
///
/// `fname` and the 1-based `pos` are only used in error messages.
///
/// # Errors
/// Fails when `v` is:
/// * negative;
/// * not integral (which includes NaN and ±inf, since their `fract()` is NaN);
/// * `>= 2^64`, i.e. not representable as `u64`.
fn to_bits(v: f64, fname: &str, pos: usize) -> Result<u64, String> {
    // 2^64 exactly; this is also what `u64::MAX as f64` rounds to.
    const TWO_POW_64: f64 = 18_446_744_073_709_551_616.0;
    if v < 0.0 {
        return Err(format!(
            "{fname}: argument {pos} must be non-negative, got {v}"
        ));
    }
    if v.fract() != 0.0 {
        return Err(format!(
            "{fname}: argument {pos} must be an integer, got {v}"
        ));
    }
    // `>=`, not `>`: 2^64 itself does not fit in u64, and `v as u64` would silently
    // saturate it to u64::MAX.
    if v >= TWO_POW_64 {
        return Err(format!(
            "{fname}: argument {pos} is too large for bitwise operations"
        ));
    }
    Ok(v as u64)
}
/// Determinant of a square matrix by Gaussian elimination with partial pivoting.
///
/// Special cases:
/// * A 0×0 matrix has determinant `1.0`.
/// * If a column's best pivot has magnitude below `1e-15` (an absolute threshold),
///   the matrix is treated as singular and `0.0` is returned.
/// * NaN entries no longer panic. NaN sorts above every number, so it is chosen as
///   pivot and the result is NaN.
///
/// # Errors
/// Returns `"det: matrix must be square"` for a non-square input.
fn det_matrix(m: &Array2<f64>) -> Result<f64, String> {
    let n = m.nrows();
    if m.ncols() != n {
        return Err("det: matrix must be square".to_string());
    }
    if n == 0 {
        return Ok(1.0);
    }
    let mut a = m.clone();
    let mut sign: f64 = 1.0;
    for col in 0..n {
        // `total_cmp` is a total order: `partial_cmp(..).unwrap()` panicked on NaN.
        let pivot = (col..n)
            .max_by(|&r1, &r2| a[[r1, col]].abs().total_cmp(&a[[r2, col]].abs()))
            .expect("col < n, so the pivot range is non-empty");
        if a[[pivot, col]].abs() < 1e-15 {
            return Ok(0.0);
        }
        if pivot != col {
            for j in 0..n {
                let tmp = a[[pivot, j]];
                a[[pivot, j]] = a[[col, j]];
                a[[col, j]] = tmp;
            }
            // Each row swap flips the determinant's sign.
            sign = -sign;
        }
        let pv = a[[col, col]];
        for row in (col + 1)..n {
            let factor = a[[row, col]] / pv;
            for j in col..n {
                let val = a[[col, j]] * factor;
                a[[row, j]] -= val;
            }
        }
    }
    // The matrix is now upper-triangular: det = sign · product of the diagonal.
    Ok(sign * (0..n).map(|i| a[[i, i]]).product::<f64>())
}
/// Inverse of a square matrix by Gauss–Jordan elimination with partial pivoting.
///
/// The work is done on a row-major augmented buffer `[A | I]` of width `2n`. Once
/// the left half is reduced to `I`, the right half holds `A⁻¹`. A 0×0 input yields a
/// 0×0 result.
///
/// # Errors
/// * `"inv: matrix must be square"` for a non-square input.
/// * `"inv: matrix is singular"` when a column's best pivot has magnitude
///   `<= 1e-12` (absolute threshold) or is NaN.
fn inv_matrix(m: &Array2<f64>) -> Result<Array2<f64>, String> {
    let n = m.nrows();
    if m.ncols() != n {
        return Err("inv: matrix must be square".to_string());
    }
    let cols = 2 * n;
    let mut aug = vec![0.0f64; n * cols];
    for i in 0..n {
        for j in 0..n {
            aug[i * cols + j] = m[[i, j]];
        }
        aug[i * cols + n + i] = 1.0;
    }
    for col in 0..n {
        // Uses `total_cmp`, because `partial_cmp(..).unwrap()` panicked on NaN.
        // A NaN pivot now wins the max and is then rejected by the filter, since
        // NaN comparisons are false.
        let pivot = (col..n)
            .max_by(|&r1, &r2| {
                aug[r1 * cols + col]
                    .abs()
                    .total_cmp(&aug[r2 * cols + col].abs())
            })
            .filter(|&r| aug[r * cols + col].abs() > 1e-12)
            .ok_or_else(|| "inv: matrix is singular".to_string())?;
        if pivot != col {
            for j in 0..cols {
                aug.swap(col * cols + j, pivot * cols + j);
            }
        }
        // Normalise the pivot row, then clear this column in every other row.
        let pv = aug[col * cols + col];
        for j in 0..cols {
            aug[col * cols + j] /= pv;
        }
        for row in 0..n {
            if row == col {
                continue;
            }
            let factor = aug[row * cols + col];
            for j in 0..cols {
                let val = aug[col * cols + j] * factor;
                aug[row * cols + j] -= val;
            }
        }
    }
    let mut result = Array2::<f64>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            result[[i, j]] = aug[i * cols + n + j];
        }
    }
    Ok(result)
}
/// Solves `A·X = B` (the `\` operator) for a square `A` by Gauss–Jordan elimination
/// with partial pivoting on the augmented buffer `[A | B]`.
///
/// `B` may have any number `k` of right-hand-side columns; the result is `n×k`. An
/// empty system (`n == 0`) returns a 0×k matrix.
///
/// # Errors
/// Fails in these cases:
/// * `A` is not square.
/// * `B`'s row count differs from `A`'s.
/// * A column's best pivot has magnitude `<= 1e-12` (absolute threshold) or is NaN,
///   in which case the matrix is reported as singular.
fn solve_linear(a: &Array2<f64>, b: &Array2<f64>) -> Result<Array2<f64>, String> {
    let n = a.nrows();
    if a.ncols() != n {
        return Err(format!(
            "\\: coefficient matrix must be square, got {}×{}",
            n,
            a.ncols()
        ));
    }
    let k = b.ncols();
    if b.nrows() != n {
        return Err(format!(
            "\\: size mismatch — A is {}×{} but b has {} rows",
            n,
            n,
            b.nrows()
        ));
    }
    if n == 0 {
        return Ok(Array2::zeros((0, k)));
    }
    let cols = n + k;
    let mut aug = vec![0.0f64; n * cols];
    for i in 0..n {
        for j in 0..n {
            aug[i * cols + j] = a[[i, j]];
        }
        for j in 0..k {
            aug[i * cols + n + j] = b[[i, j]];
        }
    }
    for col in 0..n {
        // Uses `total_cmp`, because `partial_cmp(..).unwrap()` panicked on NaN.
        // A NaN pivot now wins the max and is then rejected by the filter.
        let pivot = (col..n)
            .max_by(|&r1, &r2| {
                aug[r1 * cols + col]
                    .abs()
                    .total_cmp(&aug[r2 * cols + col].abs())
            })
            .filter(|&r| aug[r * cols + col].abs() > 1e-12)
            .ok_or_else(|| "\\: matrix is singular or nearly singular".to_string())?;
        if pivot != col {
            for j in 0..cols {
                aug.swap(col * cols + j, pivot * cols + j);
            }
        }
        // Columns left of `col` are already zero in the pivot row, so start at `col`.
        let pv = aug[col * cols + col];
        for j in col..cols {
            aug[col * cols + j] /= pv;
        }
        for row in 0..n {
            if row == col {
                continue;
            }
            let factor = aug[row * cols + col];
            if factor == 0.0 {
                continue;
            }
            for j in col..cols {
                let val = aug[col * cols + j] * factor;
                aug[row * cols + j] -= val;
            }
        }
    }
    let mut result = Array2::<f64>::zeros((n, k));
    for i in 0..n {
        for j in 0..k {
            result[[i, j]] = aug[i * cols + n + j];
        }
    }
    Ok(result)
}
/// Householder QR factorisation of an `m×n` matrix.
///
/// Returns `(Q, R)` such that `A = Q·R`, where:
/// * `Q` is an `m×m` orthogonal matrix;
/// * `R` is `m×n` and upper-triangular, with exact zeros below the diagonal.
///
/// The `Result` wrapper is kept for callers; this implementation always returns
/// `Ok`.
fn qr_decompose(a: &Array2<f64>) -> Result<(Array2<f64>, Array2<f64>), String> {
    let m = a.nrows();
    let n = a.ncols();
    let k = m.min(n);
    let mut r = a.clone();
    let mut q = Array2::<f64>::eye(m);
    for j in 0..k {
        let col_len = m - j;
        // Householder vector v = x + sign(x0)·‖x‖·e1 for the sub-column x = R[j.., j].
        // Adding with x0's sign avoids cancellation.
        let mut v: Vec<f64> = (j..m).map(|i| r[[i, j]]).collect();
        let norm_x = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
        // Only an exactly-zero sub-column needs no reflection. An absolute cutoff
        // would break scale invariance: `qr(1e-15 * A)` would skip every column and
        // return a non-triangular R.
        if norm_x == 0.0 {
            continue;
        }
        v[0] += if v[0] >= 0.0 { norm_x } else { -norm_x };
        let v_sq: f64 = v.iter().map(|&x| x * x).sum();
        if v_sq == 0.0 {
            // Only reachable through underflow; the reflector is undefined.
            continue;
        }
        // R ← H·R on rows j.., with H = I − 2·v·vᵀ / (vᵀv).
        for col in j..n {
            let dot: f64 = (0..col_len).map(|i| v[i] * r[[j + i, col]]).sum();
            let fac = 2.0 * dot / v_sq;
            for i in 0..col_len {
                r[[j + i, col]] -= fac * v[i];
            }
        }
        // Q ← Q·H, which keeps A = Q·R.
        for row in 0..m {
            let dot: f64 = (0..col_len).map(|i| q[[row, j + i]] * v[i]).sum();
            let fac = 2.0 * dot / v_sq;
            for i in 0..col_len {
                q[[row, j + i]] -= fac * v[i];
            }
        }
        // H maps the sub-column onto the first axis, so the entries below the
        // diagonal are zero in exact arithmetic. Store exact zeros, not rounding
        // residue.
        for i in (j + 1)..m {
            r[[i, j]] = 0.0;
        }
    }
    Ok((q, r))
}
type LuResult = Result<(Array2<f64>, Array2<f64>, Array2<f64>), String>;
/// LU factorisation with partial (row) pivoting.
///
/// Returns `(L, U, P)` with `P·A = L·U`, where:
/// * `L` is unit lower-triangular;
/// * `U` is upper-triangular, with exact zeros below the diagonal;
/// * `P` is a permutation matrix.
///
/// # Errors
/// Returns `"lu: matrix must be square"` for a non-square input.
fn lu_decompose(a: &Array2<f64>) -> LuResult {
    let n = a.nrows();
    if a.ncols() != n {
        return Err("lu: matrix must be square".to_string());
    }
    let mut u = a.clone();
    let mut l = Array2::<f64>::eye(n);
    // perm[i] = index of the original row currently stored in row i.
    let mut perm: Vec<usize> = (0..n).collect();
    for j in 0..n {
        // Pick the largest-magnitude entry in column j at or below the diagonal.
        // `total_cmp` gives a total order, so NaN (which sorts highest) is picked
        // deterministically.
        let pivot = (j..n)
            .max_by(|&r1, &r2| u[[r1, j]].abs().total_cmp(&u[[r2, j]].abs()))
            .expect("j < n, so the pivot range is non-empty");
        if pivot != j {
            for col in 0..n {
                let tmp = u[[j, col]];
                u[[j, col]] = u[[pivot, col]];
                u[[pivot, col]] = tmp;
            }
            // Only the multipliers already computed (columns < j) travel with the row.
            for col in 0..j {
                let tmp = l[[j, col]];
                l[[j, col]] = l[[pivot, col]];
                l[[pivot, col]] = tmp;
            }
            perm.swap(j, pivot);
        }
        // The pivot is the largest |entry| of the sub-column, so it is zero only when
        // the whole sub-column is, and then there is nothing to eliminate.
        // Partial pivoting already bounds every multiplier by 1 in magnitude, so no
        // absolute cutoff is needed. A cutoff would leave U non-triangular for
        // small-scale inputs.
        if u[[j, j]] == 0.0 {
            continue;
        }
        for i in (j + 1)..n {
            let factor = u[[i, j]] / u[[j, j]];
            l[[i, j]] = factor;
            for k in (j + 1)..n {
                let val = factor * u[[j, k]];
                u[[i, k]] -= val;
            }
            // Zero by construction; avoid leaving rounding residue below the diagonal.
            u[[i, j]] = 0.0;
        }
    }
    let mut p = Array2::<f64>::zeros((n, n));
    for (i, &j) in perm.iter().enumerate() {
        p[[i, j]] = 1.0;
    }
    Ok((l, u, p))
}
/// Cholesky factorisation: returns the upper-triangular `R` with `Rᵀ·R = A`.
///
/// Only the diagonal and upper triangle of `a` are read.
///
/// # Errors
/// Fails when `a` is not square, or when a diagonal value at or below zero appears
/// during the factorisation (the matrix is not positive definite).
fn chol_decompose(a: &Array2<f64>) -> Result<Array2<f64>, String> {
    let n = a.nrows();
    if a.ncols() != n {
        return Err("chol: matrix must be square".to_string());
    }
    let mut r = Array2::<f64>::zeros((n, n));
    for j in 0..n {
        // r_jj² = a_jj − Σ_{k<j} r_kj²
        let diag = (0..j).fold(a[[j, j]], |acc, k| acc - r[[k, j]] * r[[k, j]]);
        if diag <= 0.0 {
            return Err("chol: matrix is not positive definite".to_string());
        }
        let rjj = diag.sqrt();
        r[[j, j]] = rjj;
        for i in (j + 1)..n {
            // r_ji = (a_ji − Σ_{k<j} r_kj·r_ki) / r_jj
            let off = (0..j).fold(a[[j, i]], |acc, k| acc - r[[k, j]] * r[[k, i]]);
            r[[j, i]] = off / rjj;
        }
    }
    Ok(r)
}
type SvdResult = Result<(Array2<f64>, Vec<f64>, Array2<f64>), String>;
/// Thin singular value decomposition by one-sided (Hestenes) Jacobi rotations.
///
/// Returns `(U, sigma, V)` with `A ≈ U·diag(sigma)·Vᵀ`. Let `k = min(m, n)`:
/// * `U` is `m×k`;
/// * `sigma` has length `k`, sorted in descending order;
/// * `V` is `n×k`.
///
/// Wide inputs (`m < n`) are handled by decomposing `Aᵀ` and swapping the factors.
///
/// Column `j` of `U` is `B[:, j] / sigma[j]`. It is left as zeros only when
/// `sigma[j]` is exactly zero (or NaN).
fn svd_compute(a: &Array2<f64>) -> SvdResult {
    let m = a.nrows();
    let n = a.ncols();
    if m < n {
        let (v, s, u) = svd_compute(&a.t().to_owned())?;
        return Ok((u, s, v));
    }
    let k = n;
    let mut b = a.clone();
    let mut v = Array2::<f64>::eye(k);
    const MAX_ITER: usize = 200;
    const EPS: f64 = 1e-14;
    // Sweep over all column pairs, rotating each pair until it is orthogonal relative
    // to its norms. Stop after a sweep with no rotations, or after MAX_ITER sweeps.
    'outer: for _ in 0..MAX_ITER {
        let mut changed = false;
        for p in 0..k {
            for q in (p + 1)..k {
                let alpha: f64 = (0..m).map(|i| b[[i, p]] * b[[i, p]]).sum();
                let beta: f64 = (0..m).map(|i| b[[i, q]] * b[[i, q]]).sum();
                let gamma: f64 = (0..m).map(|i| b[[i, p]] * b[[i, q]]).sum();
                if gamma.abs() <= EPS * (alpha * beta).sqrt() {
                    continue;
                }
                changed = true;
                let zeta = (beta - alpha) / (2.0 * gamma);
                let t = zeta.signum() / (zeta.abs() + (1.0 + zeta * zeta).sqrt());
                let c = 1.0 / (1.0 + t * t).sqrt();
                let s = c * t;
                for i in 0..m {
                    let bp = b[[i, p]];
                    let bq = b[[i, q]];
                    b[[i, p]] = c * bp - s * bq;
                    b[[i, q]] = s * bp + c * bq;
                }
                for i in 0..k {
                    let vp = v[[i, p]];
                    let vq = v[[i, q]];
                    v[[i, p]] = c * vp - s * vq;
                    v[[i, q]] = s * vp + c * vq;
                }
            }
        }
        if !changed {
            break 'outer;
        }
    }
    // Singular values are the column norms of the rotated matrix B = A·V.
    let mut sigma: Vec<f64> = (0..k)
        .map(|j| (0..m).map(|i| b[[i, j]] * b[[i, j]]).sum::<f64>().sqrt())
        .collect();
    let mut u_mat = Array2::<f64>::zeros((m, k));
    for j in 0..k {
        // The convergence test above is relative, so any nonzero column is already
        // orthogonal to the others. An absolute cutoff here would zero out U for
        // small-scale inputs such as `svd(1e-15 * eye(2))`.
        if sigma[j] > 0.0 {
            for i in 0..m {
                u_mat[[i, j]] = b[[i, j]] / sigma[j];
            }
        }
    }
    // Reorder everything by descending singular value.
    let mut order: Vec<usize> = (0..k).collect();
    order.sort_by(|&a, &b| {
        sigma[b]
            .partial_cmp(&sigma[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let sigma_s: Vec<f64> = order.iter().map(|&i| sigma[i]).collect();
    let mut u_s = Array2::<f64>::zeros((m, k));
    let mut v_s = Array2::<f64>::zeros((n, k));
    for (ni, &oi) in order.iter().enumerate() {
        for r in 0..m {
            u_s[[r, ni]] = u_mat[[r, oi]];
        }
        for r in 0..k {
            v_s[[r, ni]] = v[[r, oi]];
        }
    }
    sigma = sigma_s;
    Ok((u_s, sigma, v_s))
}
/// Extends the columns of `u` (m×k, expected to be orthonormal) to a full
/// orthonormal basis of Rᵐ and returns it as an m×m matrix.
///
/// Non-zero input columns keep their position. All-zero input columns (such as
/// the left singular vectors `svd_compute` leaves at zero for vanishing singular
/// values) and the trailing `m - k` columns are filled with unit vectors obtained
/// by modified Gram–Schmidt on e₁, e₂, … against every vector already placed.
/// Previously zero columns were counted as basis members, which left zero
/// columns in the result.
///
/// For inputs without zero columns the output is identical to the previous
/// implementation. If the standard basis is exhausted before every slot is
/// filled, the remaining columns stay zero.
fn complete_orthonormal_basis(u: &Array2<f64>) -> Array2<f64> {
    // Vectors with an L2 norm at or below this are treated as zero / dependent.
    const TOL: f64 = 1e-10;
    /// Euclidean norm of a vector.
    fn l2(v: &[f64]) -> f64 {
        v.iter().map(|&x| x * x).sum::<f64>().sqrt()
    }
    let m = u.nrows();
    let k = u.ncols();
    // One slot per output column; `None` marks a column that still needs a vector.
    let mut slots: Vec<Option<Vec<f64>>> = (0..k)
        .map(|j| {
            let col = u.column(j).to_vec();
            (l2(&col) > TOL).then_some(col)
        })
        .collect();
    while slots.len() < m {
        slots.push(None);
    }
    let mut ei = 0usize;
    for j in 0..slots.len() {
        if slots[j].is_some() {
            continue;
        }
        while ei < m {
            let mut v: Vec<f64> = vec![0.0; m];
            v[ei] = 1.0;
            ei += 1;
            // Modified Gram–Schmidt against every filled slot, in column order.
            for b in slots.iter().flatten() {
                let dot: f64 = v.iter().zip(b.iter()).map(|(&a, &b)| a * b).sum();
                for (vi, &bi) in v.iter_mut().zip(b.iter()) {
                    *vi -= dot * bi;
                }
            }
            let norm = l2(&v);
            if norm > TOL {
                for vi in &mut v {
                    *vi /= norm;
                }
                slots[j] = Some(v);
                break;
            }
        }
    }
    let mut result = Array2::<f64>::zeros((m, m));
    for (j, slot) in slots.iter().enumerate() {
        if let Some(col) = slot {
            for (i, &val) in col.iter().enumerate() {
                result[[i, j]] = val;
            }
        }
    }
    result
}
/// Eigenvalues of a square matrix via shifted QR iteration.
///
/// Each step subtracts a shift `μ` from the diagonal, factors `A − μI = QR`
/// (via `qr_decompose`, assumed to return `(Q, R)` with `Q·R` equal to its input;
/// confirm), forms `RQ + μI` and accumulates `Q`. Iteration stops once every
/// sub-diagonal entry is below `1e-12`, or after 2000 steps.
///
/// The shift is computed from the trailing 2×2 block using `b²` with
/// `b = A[n-2, n-1]`. This is the Wilkinson shift when `A` is symmetric and only
/// an approximation of it otherwise.
///
/// Eigenvalues are read off the diagonal. A sub-diagonal entry above `1e-8`
/// marks a 2×2 block, whose two eigenvalues (a complex-conjugate pair when its
/// discriminant is negative) are emitted together.
///
/// The returned matrix is the product of all `Q`s. NOTE(review): this holds
/// eigenvectors only for symmetric input; for non-symmetric input these are
/// Schur vectors.
///
/// # Errors
/// Non-square input, or any error forwarded from `qr_decompose`.
fn eig_compute(a: &Array2<f64>) -> Result<(Vec<Complex<f64>>, Array2<f64>), String> {
    let n = a.nrows();
    if a.ncols() != n {
        return Err("eig: matrix must be square".to_string());
    }
    if n == 0 {
        return Ok((vec![], Array2::zeros((0, 0))));
    }
    if n == 1 {
        return Ok((vec![Complex::new(a[[0, 0]], 0.0)], Array2::eye(1)));
    }
    // From here on n >= 2, so the trailing 2×2 block used for the shift always
    // exists. The former `n >= 2` check inside the loop was unreachable.
    let mut ak = a.clone();
    let mut evecs = Array2::<f64>::eye(n);
    const MAX_ITER: usize = 2000;
    const EPS: f64 = 1e-12;
    for _ in 0..MAX_ITER {
        let mu = {
            let d = ak[[n - 1, n - 1]];
            let top = ak[[n - 2, n - 2]];
            let b = ak[[n - 2, n - 1]];
            let delta = (top - d) / 2.0;
            if delta.abs() < 1e-30 {
                // Tie between the two diagonal entries: shift by |b|.
                d - b.abs()
            } else {
                d - b * b / (delta + delta.signum() * (delta * delta + b * b).sqrt())
            }
        };
        for i in 0..n {
            ak[[i, i]] -= mu;
        }
        let (q, r) = qr_decompose(&ak)?;
        ak = r.dot(&q);
        for i in 0..n {
            ak[[i, i]] += mu;
        }
        evecs = evecs.dot(&q);
        let max_sub = (0..(n - 1))
            .map(|i| ak[[i + 1, i]].abs())
            .fold(0.0_f64, f64::max);
        if max_sub < EPS {
            break;
        }
    }
    // Looser than EPS: decides whether a sub-diagonal entry marks a 2×2 block.
    const EPS_BLOCK: f64 = 1e-8;
    let mut evals: Vec<Complex<f64>> = Vec::with_capacity(n);
    let mut i = 0;
    while i < n {
        if i + 1 < n && ak[[i + 1, i]].abs() > EPS_BLOCK {
            let (a_ii, b, c, d_ii) = (
                ak[[i, i]],
                ak[[i, i + 1]],
                ak[[i + 1, i]],
                ak[[i + 1, i + 1]],
            );
            // Eigenvalues of [[a, b], [c, d]]: p ± sqrt(((a − d)/2)² + bc).
            let p = (a_ii + d_ii) / 2.0;
            let disc = ((a_ii - d_ii) / 2.0).powi(2) + b * c;
            if disc < 0.0 {
                let q = (-disc).sqrt();
                evals.push(Complex::new(p, q));
                evals.push(Complex::new(p, -q));
            } else {
                let q = disc.sqrt();
                evals.push(Complex::new(p + q, 0.0));
                evals.push(Complex::new(p - q, 0.0));
            }
            i += 2;
        } else {
            evals.push(Complex::new(ak[[i, i]], 0.0));
            i += 1;
        }
    }
    Ok((evals, evecs))
}
/// Returns a copy of `env` in which `end` is bound to `dim_size`. This is the
/// length of the dimension currently being indexed.
fn env_with_end(env: &Env, dim_size: usize) -> Env {
    let mut scoped = env.clone();
    let end_value = Value::Scalar(dim_size as f64);
    scoped.insert(String::from("end"), end_value);
    scoped
}
/// Reports whether the keyword `end` appears anywhere inside `expr`, including
/// inside lambda bodies. Index evaluation uses it to decide whether `end` needs
/// binding.
pub(crate) fn contains_end(expr: &Expr) -> bool {
    match expr {
        // Leaves: only a variable literally named `end` counts.
        Expr::Var(ident) => ident == "end",
        Expr::NaT
        | Expr::Colon
        | Expr::Number(_)
        | Expr::FuncHandle(_)
        | Expr::StrLiteral(_)
        | Expr::StringObjLiteral(_) => false,
        // Nodes with a single sub-expression.
        Expr::FieldGet(child, _)
        | Expr::Transpose(child)
        | Expr::PlainTranspose(child)
        | Expr::UnaryNot(child)
        | Expr::UnaryMinus(child) => contains_end(child),
        Expr::Lambda { body, .. } => contains_end(body),
        // Nodes with two or three sub-expressions.
        Expr::BinOp(lhs, _, rhs) => contains_end(lhs) || contains_end(rhs),
        Expr::DynFieldGet(base, key) => contains_end(base) || contains_end(key),
        Expr::CellIndex(base, key) => contains_end(base) || contains_end(key),
        Expr::Range(from, step, to) => {
            contains_end(from) || step.as_deref().is_some_and(contains_end) || contains_end(to)
        }
        // Argument / element lists.
        Expr::Call(_, items) | Expr::DotCall(_, items) => items.iter().any(contains_end),
        Expr::CellLiteral(items) => items.iter().any(contains_end),
        Expr::Matrix(rows) => rows.iter().any(|row| row.iter().any(contains_end)),
    }
}
fn make_containers_map(
args: &[Expr],
env: &Env,
io: Option<&mut IoContext>,
) -> Result<Value, String> {
if args.is_empty() {
return Ok(Value::Map(Box::new(IndexMap::new())));
}
if args.len() != 2 {
return Err(
"containers.Map: expected 0 or 2 arguments (keys cell, values cell)".to_string(),
);
}
let mut io_wrap = io;
let keys_val = eval_inner(&args[0], env, io_wrap.as_deref_mut())?;
let vals_val = eval_inner(&args[1], env, io_wrap)?;
let keys = match keys_val {
Value::Cell(v) => v,
_ => {
return Err(
"containers.Map: first argument must be a cell array of strings".to_string(),
);
}
};
let values = match vals_val {
Value::Cell(v) => v,
_ => {
return Err(
"containers.Map: second argument must be a cell array of values".to_string(),
);
}
};
if keys.len() != values.len() {
return Err(format!(
"containers.Map: key count ({}) does not match value count ({})",
keys.len(),
values.len()
));
}
let keys = *keys;
let values = *values;
let mut map = IndexMap::new();
for (k, v) in keys.into_iter().zip(values) {
let key = match k {
Value::Str(s) | Value::StringObj(s) => s,
_ => return Err("containers.Map: all keys must be strings".to_string()),
};
map.insert(key, v);
}
Ok(Value::Map(Box::new(map)))
}
fn eval_index(val: &Value, args: &[Expr], env: &Env) -> Result<Value, String> {
match args.len() {
0 => Err("Indexing requires at least one index".to_string()),
1 => {
match val {
Value::Void => Err("Cannot index into void".to_string()),
Value::Lambda(_) | Value::Function(_) | Value::Tuple(_) => {
Err("Cannot index into a function value".to_string())
}
Value::Cell(_) => Err("Use c{i} to index into a cell array, not c(i)".to_string()),
Value::Struct(_) => {
Err("Use s.field to access struct fields, not s(i)".to_string())
}
Value::Map(map) => {
let key_val = eval_inner(&args[0], env, None)?;
let key = match key_val {
Value::Str(s) | Value::StringObj(s) => s,
_ => return Err("Map key must be a string".to_string()),
};
map.get(&key)
.cloned()
.ok_or_else(|| format!("Map key '{key}' not found"))
}
Value::StructArray(arr) => {
let total = arr.len();
let _owned_env;
let env1: &Env = if contains_end(&args[0]) {
_owned_env = env_with_end(env, total);
&_owned_env
} else {
env
};
match resolve_dim(&args[0], total, env1)? {
DimIdx::All => {
Ok(Value::StructArray(arr.clone()))
}
DimIdx::Indices(idxs) => {
if idxs.len() == 1 {
let i = idxs[0];
if i >= total {
return Err(format!(
"Index {} out of range (1..{})",
i + 1,
total
));
}
Ok(Value::Struct(Box::new(arr[i].clone())))
} else {
let mut selected = Vec::with_capacity(idxs.len());
for &i in &idxs {
if i >= total {
return Err(format!(
"Index {} out of range (1..{})",
i + 1,
total
));
}
selected.push(arr[i].clone());
}
Ok(Value::StructArray(Box::new(selected)))
}
}
}
}
Value::Scalar(n) => {
let _owned_env;
let env1: &Env = if contains_end(&args[0]) {
_owned_env = env_with_end(env, 1);
&_owned_env
} else {
env
};
match resolve_dim(&args[0], 1, env1)? {
DimIdx::All | DimIdx::Indices(_) => Ok(Value::Scalar(*n)),
}
}
Value::Complex(re, im) => {
let _owned_env;
let env1: &Env = if contains_end(&args[0]) {
_owned_env = env_with_end(env, 1);
&_owned_env
} else {
env
};
match resolve_dim(&args[0], 1, env1)? {
DimIdx::All | DimIdx::Indices(_) => Ok(Value::Complex(*re, *im)),
}
}
Value::ComplexMatrix(m) => {
let total = m.nrows() * m.ncols();
let _owned_env;
let env1: &Env = if contains_end(&args[0]) {
_owned_env = env_with_end(env, total);
&_owned_env
} else {
env
};
match resolve_dim(&args[0], total, env1)? {
DimIdx::All => {
let mut flat: Vec<Complex<f64>> = Vec::with_capacity(total);
for col in 0..m.ncols() {
for row in 0..m.nrows() {
flat.push(m[[row, col]]);
}
}
Ok(Value::ComplexMatrix(Box::new(
Array2::from_shape_vec((total, 1), flat).unwrap(),
)))
}
DimIdx::Indices(idxs) => {
let nrows = m.nrows();
let ncols_m = m.ncols();
let vals: Result<Vec<Complex<f64>>, String> = idxs
.iter()
.map(|&i| {
let row = i % nrows;
let col = i / nrows;
if col >= ncols_m {
Err(format!("Index {} out of range (1..{})", i + 1, total))
} else {
Ok(m[[row, col]])
}
})
.collect();
let vals = vals?;
if vals.len() == 1 {
let c = vals[0];
Ok(make_complex(c.re, c.im))
} else {
let n = vals.len();
Ok(Value::ComplexMatrix(Box::new(
Array2::from_shape_vec((1, n), vals).unwrap(),
)))
}
}
}
}
Value::Matrix(m) => {
let total = m.nrows() * m.ncols();
let _owned_env;
let env1: &Env = if contains_end(&args[0]) {
_owned_env = env_with_end(env, total);
&_owned_env
} else {
env
};
match resolve_dim(&args[0], total, env1)? {
DimIdx::All => {
let mut flat = Vec::with_capacity(total);
for col in 0..m.ncols() {
for row in 0..m.nrows() {
flat.push(m[[row, col]]);
}
}
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((total, 1), flat).unwrap(),
)))
}
DimIdx::Indices(idxs) => {
let nrows = m.nrows();
let ncols_m = m.ncols();
let vals: Result<Vec<f64>, String> = idxs
.iter()
.map(|&i| {
let row = i % nrows;
let col = i / nrows;
if col >= ncols_m {
Err(format!("Index {} out of range (1..{})", i + 1, total))
} else {
Ok(m[[row, col]])
}
})
.collect();
let vals = vals?;
if vals.len() == 1 {
Ok(Value::Scalar(vals[0]))
} else {
let n = vals.len();
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, n), vals).unwrap(),
)))
}
}
}
}
Value::Str(s) => {
let chars: Vec<char> = s.chars().collect();
let total = chars.len();
let _owned_env;
let env1: &Env = if contains_end(&args[0]) {
_owned_env = env_with_end(env, total);
&_owned_env
} else {
env
};
match resolve_dim(&args[0], total, env1)? {
DimIdx::All => {
let codes: Vec<f64> = chars.iter().map(|&c| c as u32 as f64).collect();
if codes.len() == 1 {
Ok(Value::Scalar(codes[0]))
} else {
let n = codes.len();
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((1, n), codes).unwrap(),
)))
}
}
DimIdx::Indices(idxs) => {
let mut selected = String::new();
for &i in &idxs {
if i >= chars.len() {
return Err(format!("Index {} out of range", i + 1));
}
selected.push(chars[i]);
}
if selected.chars().count() == 1 {
Ok(Value::Scalar(selected.chars().next().unwrap() as u32 as f64))
} else {
Ok(Value::Str(selected))
}
}
}
}
Value::StringObj(s) => {
let _owned_env;
let env1: &Env = if contains_end(&args[0]) {
_owned_env = env_with_end(env, 1);
&_owned_env
} else {
env
};
match resolve_dim(&args[0], 1, env1)? {
DimIdx::All | DimIdx::Indices(_) => Ok(Value::StringObj(s.clone())),
}
}
Value::DateTimeArray(v) => {
let total = v.len();
let _owned_env;
let env1: &Env = if contains_end(&args[0]) {
_owned_env = env_with_end(env, total);
&_owned_env
} else {
env
};
match resolve_dim(&args[0], total, env1)? {
DimIdx::All => Ok(Value::DateTimeArray(v.clone())),
DimIdx::Indices(idxs) => {
if idxs.len() == 1 {
let i = idxs[0];
if i >= total {
return Err(format!(
"Index {} out of range (1..{})",
i + 1,
total
));
}
Ok(Value::DateTime(v[i]))
} else {
let mut sel = Vec::with_capacity(idxs.len());
for &i in &idxs {
if i >= total {
return Err(format!(
"Index {} out of range (1..{})",
i + 1,
total
));
}
sel.push(v[i]);
}
Ok(Value::DateTimeArray(sel))
}
}
}
}
Value::DurationArray(v) => {
let total = v.len();
let _owned_env;
let env1: &Env = if contains_end(&args[0]) {
_owned_env = env_with_end(env, total);
&_owned_env
} else {
env
};
match resolve_dim(&args[0], total, env1)? {
DimIdx::All => Ok(Value::DurationArray(v.clone())),
DimIdx::Indices(idxs) => {
if idxs.len() == 1 {
let i = idxs[0];
if i >= total {
return Err(format!(
"Index {} out of range (1..{})",
i + 1,
total
));
}
Ok(Value::Duration(v[i]))
} else {
let mut sel = Vec::with_capacity(idxs.len());
for &i in &idxs {
if i >= total {
return Err(format!(
"Index {} out of range (1..{})",
i + 1,
total
));
}
sel.push(v[i]);
}
Ok(Value::DurationArray(sel))
}
}
}
}
Value::DateTime(_) | Value::Duration(_) => {
let _owned_env;
let env1: &Env = if contains_end(&args[0]) {
_owned_env = env_with_end(env, 1);
&_owned_env
} else {
env
};
match resolve_dim(&args[0], 1, env1)? {
DimIdx::All | DimIdx::Indices(_) => Ok(val.clone()),
}
}
}
}
2 => {
if matches!(
val,
Value::Void
| Value::Str(_)
| Value::StringObj(_)
| Value::Lambda(_)
| Value::Function(_)
| Value::Tuple(_)
| Value::Cell(_)
| Value::Struct(_)
| Value::StructArray(_)
| Value::DateTime(_)
| Value::Duration(_)
| Value::DateTimeArray(_)
| Value::DurationArray(_)
) {
return Err("2D indexing not supported for this type".to_string());
}
let (nrows, ncols) = match val {
Value::Scalar(_) | Value::Complex(_, _) => (1, 1),
Value::Matrix(m) => (m.nrows(), m.ncols()),
Value::ComplexMatrix(m) => (m.nrows(), m.ncols()),
_ => unreachable!(),
};
let _owned_r;
let env_r: &Env = if contains_end(&args[0]) {
_owned_r = env_with_end(env, nrows);
&_owned_r
} else {
env
};
let _owned_c;
let env_c: &Env = if contains_end(&args[1]) {
_owned_c = env_with_end(env, ncols);
&_owned_c
} else {
env
};
let row_idx = resolve_dim(&args[0], nrows, env_r)?;
let col_idx = resolve_dim(&args[1], ncols, env_c)?;
let rows: Vec<usize> = match row_idx {
DimIdx::All => (0..nrows).collect(),
DimIdx::Indices(v) => v,
};
let cols: Vec<usize> = match col_idx {
DimIdx::All => (0..ncols).collect(),
DimIdx::Indices(v) => v,
};
if rows.len() == 1 && cols.len() == 1 {
match val {
Value::Scalar(n) => Ok(Value::Scalar(*n)),
Value::Complex(re, im) => Ok(Value::Complex(*re, *im)),
Value::Matrix(m) => Ok(Value::Scalar(m[[rows[0], cols[0]]])),
Value::ComplexMatrix(m) => {
let c = m[[rows[0], cols[0]]];
Ok(make_complex(c.re, c.im))
}
_ => unreachable!(),
}
} else {
let out_r = rows.len();
let out_c = cols.len();
match val {
Value::ComplexMatrix(m) => {
let flat: Vec<Complex<f64>> = rows
.iter()
.flat_map(|&r| cols.iter().map(move |&c| m[[r, c]]))
.collect();
Ok(Value::ComplexMatrix(Box::new(
Array2::from_shape_vec((out_r, out_c), flat).unwrap(),
)))
}
_ => {
let flat: Vec<f64> = rows
.iter()
.flat_map(|&r| {
cols.iter().map(move |&c| match val {
Value::Scalar(n) => *n,
Value::Complex(re, _) => *re,
Value::Matrix(m) => m[[r, c]],
_ => unreachable!(),
})
})
.collect();
Ok(Value::Matrix(Box::new(
Array2::from_shape_vec((out_r, out_c), flat).unwrap(),
)))
}
}
}
}
n => Err(format!(
"Indexing with {n} indices is not supported (max 2)"
)),
}
}
/// The resolved selection along one indexed dimension.
#[derive(Debug, Clone, PartialEq, Eq)]
enum DimIdx {
    /// `:`, meaning every element of the dimension.
    All,
    /// Explicit zero-based positions, in the order they were requested.
    Indices(Vec<usize>),
}
/// Converts one index expression into zero-based positions along a dimension of
/// `dim_size` elements.
///
/// * `:` selects everything;
/// * a scalar, a real complex number or a vector gives 1-based indices (rounded);
/// * a matrix is accepted only when it has exactly `dim_size` elements and is
///   read in column-major order;
/// * a 0/1 vector whose length equals a non-zero `dim_size` is treated as a
///   logical mask.
fn resolve_dim(expr: &Expr, dim_size: usize, env: &Env) -> Result<DimIdx, String> {
    if let Expr::Colon = expr {
        return Ok(DimIdx::All);
    }
    let raw: Vec<f64> = match eval(expr, env)? {
        Value::Scalar(x) => vec![x],
        Value::Complex(re, im) => {
            if im == 0.0 {
                vec![re]
            } else {
                return Err("Index must be real, not complex".to_string());
            }
        }
        Value::Matrix(m) => {
            let (rows, cols) = (m.nrows(), m.ncols());
            if rows <= 1 || cols <= 1 {
                m.iter().copied().collect()
            } else if rows * cols == dim_size {
                // Iterating the transposed view walks the original column by column.
                m.t().iter().copied().collect()
            } else {
                return Err("Index must be a scalar or vector, not a matrix".to_string());
            }
        }
        Value::Void => return Err("Index must be numeric, not void".to_string()),
        Value::Str(_) | Value::StringObj(_) => {
            return Err("Index must be numeric, not a string".to_string());
        }
        Value::ComplexMatrix(_) => {
            return Err("Index must be real, not a complex matrix".to_string());
        }
        Value::Lambda(_)
        | Value::Function(_)
        | Value::Tuple(_)
        | Value::Cell(_)
        | Value::Struct(_)
        | Value::StructArray(_)
        | Value::DateTime(_)
        | Value::Duration(_)
        | Value::DateTimeArray(_)
        | Value::DurationArray(_)
        | Value::Map(_) => {
            return Err("Index must be numeric, not a function or datetime".to_string());
        }
    };
    let looks_like_mask =
        dim_size > 0 && raw.len() == dim_size && raw.iter().all(|&f| f == 0.0 || f == 1.0);
    if looks_like_mask {
        let kept = raw
            .iter()
            .enumerate()
            .filter_map(|(pos, &flag)| (flag == 1.0).then_some(pos))
            .collect();
        return Ok(DimIdx::Indices(kept));
    }
    raw.into_iter()
        .map(|f| {
            let one_based = f.round() as i64;
            if one_based < 1 || one_based as usize > dim_size {
                Err(format!("Index {one_based} out of range (1..{dim_size})"))
            } else {
                Ok(one_based as usize - 1)
            }
        })
        .collect::<Result<Vec<usize>, String>>()
        .map(DimIdx::Indices)
}
/// Formats a number compactly.
///
/// Integers below 1e15 in magnitude are printed without a fraction. Non-zero
/// values at or above 1e15, or below 1e-9, use trimmed scientific notation. All
/// others are printed with up to 10 decimals and trailing zeros removed.
pub fn format_number(n: f64) -> String {
    let magnitude = n.abs();
    if n.fract() == 0.0 && magnitude < 1e15 {
        return (n as i64).to_string();
    }
    let scientific = n != 0.0 && (magnitude >= 1e15 || magnitude < 1e-9);
    if scientific {
        return trim_sci(&format!("{n:.15e}"));
    }
    let fixed = format!("{n:.10}");
    fixed.trim_end_matches('0').trim_end_matches('.').to_string()
}
/// Formats a real scalar in the requested radix. `format hex` mode (IEEE-754
/// bit display) takes precedence over any non-decimal radix.
pub fn format_scalar(n: f64, base: Base, mode: &FormatMode) -> String {
    let use_decimal = matches!(mode, FormatMode::Hex) || matches!(base, Base::Dec);
    if use_decimal {
        format_decimal(n, mode)
    } else {
        format_non_dec(n, base)
    }
}
/// Formats `re + im·i`. A zero imaginary part prints just the real part, a zero
/// real part prints just the imaginary term, and a unit imaginary coefficient
/// is written as a bare `i`.
pub fn format_complex(re: f64, im: f64, mode: &FormatMode) -> String {
    if im == 0.0 {
        return format_decimal(re, mode);
    }
    let magnitude = im.abs();
    let coeff = if magnitude == 1.0 {
        String::new()
    } else {
        format_decimal(magnitude, mode)
    };
    let negative = im < 0.0;
    match (re == 0.0, negative) {
        (true, true) => format!("-{coeff}i"),
        (true, false) => format!("{coeff}i"),
        (false, true) => format!("{} - {coeff}i", format_decimal(re, mode)),
        (false, false) => format!("{} + {coeff}i", format_decimal(re, mode)),
    }
}
pub fn expr_to_string(e: &Expr) -> String {
match e {
Expr::Number(n) => {
if n.is_nan() {
"nan".to_string()
} else if n.is_infinite() {
if *n > 0.0 {
"inf".to_string()
} else {
"-inf".to_string()
}
} else {
format!("{n}")
}
}
Expr::Var(name) => name.clone(),
Expr::UnaryMinus(e) => format!("-{}", expr_to_string(e)),
Expr::UnaryNot(e) => format!("~{}", expr_to_string(e)),
Expr::BinOp(l, op, r) => {
let op_str = match op {
Op::Add => "+",
Op::Sub => "-",
Op::Mul => "*",
Op::Div => "/",
Op::Pow => "^",
Op::ElemMul => ".*",
Op::ElemDiv => "./",
Op::ElemPow => ".^",
Op::Eq => "==",
Op::NotEq => "~=",
Op::Lt => "<",
Op::Gt => ">",
Op::LtEq => "<=",
Op::GtEq => ">=",
Op::And => "&&",
Op::Or => "||",
Op::ElemAnd => "&",
Op::ElemOr => "|",
Op::LDiv => "\\",
};
format!("{} {op_str} {}", expr_to_string(l), expr_to_string(r))
}
Expr::Call(name, args) => {
let args_str = args
.iter()
.map(expr_to_string)
.collect::<Vec<_>>()
.join(", ");
format!("{name}({args_str})")
}
Expr::Transpose(e) => format!("{}'", expr_to_string(e)),
Expr::PlainTranspose(e) => format!("{}.'", expr_to_string(e)),
Expr::Range(start, step, stop) => {
if let Some(step) = step {
format!(
"{}:{}:{}",
expr_to_string(start),
expr_to_string(step),
expr_to_string(stop)
)
} else {
format!("{}:{}", expr_to_string(start), expr_to_string(stop))
}
}
Expr::StrLiteral(s) => format!("'{s}'"),
Expr::StringObjLiteral(s) => format!("\"{s}\""),
Expr::Lambda { params, body, .. } => {
format!("@({}) {}", params.join(", "), expr_to_string(body))
}
Expr::FuncHandle(name) => format!("@{name}"),
Expr::Matrix(_) => "[...]".to_string(),
Expr::CellLiteral(_) => "{...}".to_string(),
Expr::CellIndex(e, i) => format!("{}{{{}}}", expr_to_string(e), expr_to_string(i)),
Expr::Colon => ":".to_string(),
Expr::NaT => "NaT".to_string(),
Expr::FieldGet(base, field) => format!("{}.{field}", expr_to_string(base)),
Expr::DynFieldGet(base, field_expr) => {
format!("{}.({})", expr_to_string(base), expr_to_string(field_expr))
}
Expr::DotCall(segs, args) => {
let args_str = args
.iter()
.map(expr_to_string)
.collect::<Vec<_>>()
.join(", ");
format!("{}({args_str})", segs.join("."))
}
}
}
/// One-line summary of a value. Scalars and strings are rendered in full;
/// containers and arrays as `[r×c kind]`-style placeholders.
pub fn format_value(v: &Value, base: Base, mode: &FormatMode) -> String {
    match v {
        Value::Void => String::new(),
        Value::Scalar(x) => format_scalar(*x, base, mode),
        Value::Complex(re, im) => format_complex(*re, *im, mode),
        Value::Str(text) | Value::StringObj(text) => text.clone(),
        Value::Lambda(lf) => lf.1.clone(),
        Value::Function(fd) => {
            let outputs = match &fd.outputs[..] {
                [] => String::new(),
                [single] => format!("{single} = "),
                many => format!("[{}] = ", many.join(", ")),
            };
            format!("@function {outputs}f({})", fd.params.join(", "))
        }
        Value::Tuple(items) => {
            let rendered = items
                .iter()
                .map(|item| format_value(item, base, mode))
                .collect::<Vec<_>>()
                .join(", ");
            format!("({rendered})")
        }
        Value::Matrix(m) => format!("[{}x{} double]", m.nrows(), m.ncols()),
        Value::ComplexMatrix(m) => format!("[{}×{} complex]", m.nrows(), m.ncols()),
        Value::Cell(items) => format!("{{1×{} cell}}", items.len()),
        Value::Struct(_) => "[1×1 struct]".to_string(),
        Value::StructArray(elems) => format!("[1×{} struct]", elems.len()),
        Value::DateTime(ts) => crate::datetime::format_datetime(*ts),
        Value::Duration(secs) => crate::datetime::format_duration(*secs),
        Value::DateTimeArray(stamps) => format!("[{}×1 datetime]", stamps.len()),
        Value::DurationArray(spans) => format!("[{}×1 duration]", spans.len()),
        Value::Map(entries) => format!("[Map with {} entries]", entries.len()),
    }
}
/// Multi-line display for container and array values. Returns `None` for values
/// whose one-line [`format_value`] summary is already complete.
pub fn format_value_full(v: &Value, mode: &FormatMode) -> Option<String> {
    let rendered = match v {
        Value::Matrix(m) => format_matrix(m, mode),
        Value::ComplexMatrix(m) => format_complex_matrix(m, mode),
        Value::Cell(elems) => format_cell(elems, mode),
        Value::Struct(fields) => format_struct(fields, mode),
        Value::StructArray(arr) => format_struct_array(arr, mode),
        Value::DateTimeArray(stamps) => format_datetime_array(stamps),
        Value::DurationArray(spans) => format_duration_array(spans),
        Value::Map(entries) => format_map(entries, mode),
        Value::Void
        | Value::Scalar(_)
        | Value::Complex(_, _)
        | Value::Str(_)
        | Value::StringObj(_)
        | Value::Lambda(_)
        | Value::Function(_)
        | Value::Tuple(_)
        | Value::DateTime(_)
        | Value::Duration(_) => return None,
    };
    Some(rendered)
}
/// Multi-line display of a cell array. Matrix elements are expanded below their
/// `[1,i]:` label; every other element gets an inline one-line summary.
fn format_cell(elems: &[Value], mode: &FormatMode) -> String {
    if elems.is_empty() {
        return " {}".to_string();
    }
    let mut out: Vec<String> = Vec::with_capacity(elems.len() + 2);
    out.push(" {".to_string());
    for (pos, item) in elems.iter().enumerate() {
        let label = format!(" [1,{}]", pos + 1);
        if let Value::Matrix(_) = item {
            out.push(format!("{label}:"));
            if let Some(body) = format_value_full(item, mode) {
                out.extend(body.lines().map(|text| format!(" {text}")));
            }
        } else {
            out.push(format!("{label}: {}", format_value(item, Base::Dec, mode)));
        }
    }
    out.push(" }".to_string());
    out.join("\n")
}
/// Multi-line display of a scalar struct: a header, then one `name: summary`
/// line per field in insertion order.
fn format_struct(map: &IndexMap<String, Value>, mode: &FormatMode) -> String {
    let header = [
        String::new(),
        " scalar structure containing the fields:".to_string(),
        String::new(),
    ];
    let fields = map.iter().map(|(name, field)| {
        let summary = match field {
            Value::Struct(_) => "[1×1 struct]".to_string(),
            Value::StructArray(elems) => format!("[1×{} struct]", elems.len()),
            Value::Matrix(m) => format!("[{}×{} double]", m.nrows(), m.ncols()),
            Value::Cell(items) => format!("{{1×{} cell}}", items.len()),
            other => format_value(other, Base::Dec, mode),
        };
        format!(" {name}: {summary}")
    });
    header.into_iter().chain(fields).collect::<Vec<_>>().join("\n")
}
fn format_struct_array(arr: &[IndexMap<String, Value>], mode: &FormatMode) -> String {
let n = arr.len();
let mut lines = vec![
String::new(),
format!(" 1×{n} struct array with fields:"),
String::new(),
];
if let Some(first) = arr.first() {
for key in first.keys() {
lines.push(format!(" {key}"));
}
}
if n == 1
&& let Some(first) = arr.first()
{
lines.clear();
lines.push(String::new());
lines.push(" scalar structure containing the fields:".to_string());
lines.push(String::new());
for (key, val) in first {
let val_str = match val {
Value::Struct(_) => "[1×1 struct]".to_string(),
Value::StructArray(a) => format!("[1×{} struct]", a.len()),
Value::Matrix(m) => format!("[{}×{} double]", m.nrows(), m.ncols()),
Value::Cell(v) => format!("{{1×{} cell}}", v.len()),
_ => format_value(val, Base::Dec, mode),
};
lines.push(format!(" {key}: {val_str}"));
}
}
lines.join("\n")
}
/// Multi-line display of a `containers.Map`: a header with the entry count,
/// then one `'key' → summary` line per entry in insertion order.
fn format_map(map: &IndexMap<String, Value>, mode: &FormatMode) -> String {
    let mut out: Vec<String> = Vec::with_capacity(map.len() + 3);
    out.push(String::new());
    out.push(format!(" Map with {} entries:", map.len()));
    out.push(String::new());
    for (key, entry) in map.iter() {
        let summary = match entry {
            Value::Struct(_) => String::from("[1×1 struct]"),
            Value::Matrix(m) => format!("[{}×{} double]", m.nrows(), m.ncols()),
            Value::Cell(items) => format!("{{1×{} cell}}", items.len()),
            other => format_value(other, Base::Dec, mode),
        };
        out.push(format!(" '{key}' → {summary}"));
    }
    out.join("\n")
}
/// One line per timestamp, each rendered by `crate::datetime::format_datetime`.
fn format_datetime_array(v: &[f64]) -> String {
    v.iter()
        .map(|&stamp| format!(" {}", crate::datetime::format_datetime(stamp)))
        .collect::<Vec<_>>()
        .join("\n")
}
/// One line per duration, each rendered by `crate::datetime::format_duration`.
fn format_duration_array(v: &[f64]) -> String {
    v.iter()
        .map(|&span| format!(" {}", crate::datetime::format_duration(span)))
        .collect::<Vec<_>>()
        .join("\n")
}
/// Renders a complex matrix as aligned `re ± imi` columns.
///
/// Real parts are right-aligned within each column. The imaginary part of a
/// column is padded on its right, just before the next column starts, so the
/// last column carries no trailing padding.
fn format_complex_matrix(m: &Array2<Complex<f64>>, mode: &FormatMode) -> String {
    if m.nrows() == 0 || m.ncols() == 0 {
        return " []".to_string();
    }
    let width = m.ncols();
    // Each entry becomes (real text, " + "/" - ", |imag| text).
    let table: Vec<Vec<(String, &'static str, String)>> = m
        .rows()
        .into_iter()
        .map(|row| {
            row.iter()
                .map(|z| {
                    let sign = if z.im < 0.0 { " - " } else { " + " };
                    (format_decimal(z.re, mode), sign, format_decimal(z.im.abs(), mode))
                })
                .collect()
        })
        .collect();
    let mut re_w = vec![0usize; width];
    let mut im_w = vec![0usize; width];
    for row in &table {
        for (col, (re, _, im)) in row.iter().enumerate() {
            re_w[col] = re_w[col].max(re.len());
            im_w[col] = im_w[col].max(im.len());
        }
    }
    table
        .iter()
        .map(|row| {
            let mut line = String::from(" ");
            for (col, (re, sign, im)) in row.iter().enumerate() {
                if col > 0 {
                    let prev_im_len = row[col - 1].2.len();
                    line.push_str(&" ".repeat(im_w[col - 1].saturating_sub(prev_im_len)));
                    line.push_str(" ");
                }
                line.push_str(&" ".repeat(re_w[col].saturating_sub(re.len())));
                line.push_str(re);
                line.push_str(sign);
                line.push_str(im);
                line.push('i');
            }
            line
        })
        .collect::<Vec<_>>()
        .join("\n")
}
/// Renders a real matrix with right-aligned columns.
///
/// In `format +` mode each entry is shown as `+`, `-` or `0`. Anything that is
/// neither positive nor negative, including NaN, shows as `0`.
fn format_matrix(m: &Array2<f64>, mode: &FormatMode) -> String {
    if m.nrows() == 0 || m.ncols() == 0 {
        return " []".to_string();
    }
    if let FormatMode::Plus = mode {
        return m
            .rows()
            .into_iter()
            .map(|row| {
                let signs: String = row
                    .iter()
                    .map(|&x| match x {
                        x if x > 0.0 => '+',
                        x if x < 0.0 => '-',
                        _ => '0',
                    })
                    .collect();
                format!(" {signs}")
            })
            .collect::<Vec<_>>()
            .join("\n");
    }
    let text: Vec<Vec<String>> = m
        .rows()
        .into_iter()
        .map(|row| row.iter().map(|&x| format_decimal(x, mode)).collect())
        .collect();
    let mut widths = vec![0usize; m.ncols()];
    for row in &text {
        for (col, cell) in row.iter().enumerate() {
            widths[col] = widths[col].max(cell.len());
        }
    }
    text.iter()
        .map(|row| {
            let mut line = String::from(" ");
            for (col, cell) in row.iter().enumerate() {
                if col > 0 {
                    line.push_str(" ");
                }
                line.push_str(&" ".repeat(widths[col].saturating_sub(cell.len())));
                line.push_str(cell);
            }
            line
        })
        .collect::<Vec<_>>()
        .join("\n")
}
/// Formats `n`, rounded to the nearest integer, in hexadecimal (`0x`), binary
/// (`0b`) or octal (`0o`), with a leading `-` for negatives. `Base::Dec` falls
/// back to the default decimal display.
pub fn format_non_dec(n: f64, base: Base) -> String {
    let whole = n.round() as i64;
    let sign = if whole < 0 { "-" } else { "" };
    let mag = whole.unsigned_abs();
    let (prefix, digits) = match base {
        Base::Hex => ("0x", format!("{mag:X}")),
        Base::Bin => ("0b", format!("{mag:b}")),
        Base::Oct => ("0o", format!("{mag:o}")),
        Base::Dec => return format_decimal(n, &FormatMode::default()),
    };
    format!("{sign}{prefix}{digits}")
}
/// Formats a real number according to the active `format` mode. NaN and ±Inf
/// are spelled `NaN`, `Inf` and `-Inf` in every mode.
fn format_decimal(n: f64, mode: &FormatMode) -> String {
    if n.is_nan() {
        return String::from("NaN");
    }
    if n.is_infinite() {
        return String::from(if n > 0.0 { "Inf" } else { "-Inf" });
    }
    match mode {
        FormatMode::Short | FormatMode::ShortG => fmt_auto_sig(n, 5),
        FormatMode::Long | FormatMode::LongG => fmt_auto_sig(n, 15),
        FormatMode::ShortE => fmt_sci_dp(n, 4),
        FormatMode::LongE => fmt_sci_dp(n, 14),
        FormatMode::Custom(digits) => fmt_custom_prec(n, *digits),
        FormatMode::Bank => format!("{n:.2}"),
        FormatMode::Rat => fmt_rat(n),
        FormatMode::Hex => fmt_hex_ieee754(n),
        FormatMode::Plus => fmt_plus_sign(n),
    }
}
/// True when `n` has no fractional part and `|n| < 1e15`, so it can be printed
/// losslessly through an `i64`. False for NaN and ±Inf.
#[inline]
fn is_exact_int(n: f64) -> bool {
    n.abs() < 1e15 && n.trunc() == n
}
/// `format short`/`long` style: integers print plainly. Otherwise a value with
/// decimal exponent in `[-3, sig)` prints in fixed notation with `sig`
/// significant digits and trailing zeros removed. Anything else uses trimmed
/// scientific notation with `sig` significant digits.
fn fmt_auto_sig(n: f64, sig: usize) -> String {
    if is_exact_int(n) {
        return (n as i64).to_string();
    }
    let magnitude = n.abs();
    let exponent: i32 = if magnitude == 0.0 {
        0
    } else {
        magnitude.log10().floor() as i32
    };
    let sig_digits = sig as i32;
    if !(-3..sig_digits).contains(&exponent) {
        return trim_sci(&format!("{:.*e}", sig - 1, n));
    }
    let decimals = (sig_digits - 1 - exponent) as usize;
    let fixed = format!("{:.*}", decimals, n);
    match fixed.contains('.') {
        true => fixed.trim_end_matches('0').trim_end_matches('.').to_string(),
        false => fixed,
    }
}
/// Scientific notation with `dp` mantissa decimals, normalised by `trim_sci`.
fn fmt_sci_dp(n: f64, dp: usize) -> String {
    let raw = format!("{:.*e}", dp, n);
    trim_sci(&raw)
}
/// Formats `n` with a user-chosen number of decimals (`prec`).
///
/// Integers print plainly. Non-zero magnitudes at or above 1e15, or below 1e-9,
/// use trimmed scientific notation. Other values use fixed notation with
/// trailing fractional zeros removed.
fn fmt_custom_prec(n: f64, prec: usize) -> String {
    if is_exact_int(n) {
        return format!("{}", n as i64);
    }
    if n.abs() >= 1e15 || (n != 0.0 && n.abs() < 1e-9) {
        let s = format!("{:.prec$e}", n, prec = prec);
        trim_sci(&s)
    } else {
        let s = format!("{:.prec$}", n, prec = prec);
        // With `prec == 0` there is no decimal point, and stripping zeros would
        // eat the integer part (10.4 -> "10" -> "1"). Only trim a real fraction,
        // as `fmt_auto_sig` does.
        if s.contains('.') {
            s.trim_end_matches('0').trim_end_matches('.').to_string()
        } else {
            s
        }
    }
}
/// `format rat`: approximates `n` by a continued-fraction convergent `p/q`.
///
/// The denominator is capped at 10 000. Iteration stops once the remainder
/// vanishes or the approximation is within 1e-6 of `|n|`. Integral values print
/// without a denominator.
fn fmt_rat(n: f64) -> String {
    if is_exact_int(n) {
        return format!("{}", n as i64);
    }
    // Integral values outside `is_exact_int`'s range (|n| >= 1e15) can exceed
    // i64. With 1e20, `floor() as i64` saturated and the loop printed "1/0".
    // Print such values exactly instead. Every finite f64 with |n| >= 2^52 is
    // integral, so past this point each partial quotient fits in an i64.
    if n.fract() == 0.0 {
        return format!("{:.0}", n);
    }
    let sign = if n < 0.0 { -1i64 } else { 1i64 };
    let x = n.abs();
    // (h1/k1) is the latest convergent and (h2/k2) the one before it.
    let (mut h1, mut h2): (i64, i64) = (1, 0);
    let (mut k1, mut k2): (i64, i64) = (0, 1);
    let mut b = x;
    for _ in 0..64 {
        let a = b.floor() as i64;
        let (nh, nk) = (a * h1 + h2, a * k1 + k2);
        if nk > 10_000 {
            break;
        }
        h2 = h1;
        h1 = nh;
        k2 = k1;
        k1 = nk;
        let frac = b - a as f64;
        if frac < 1e-12 || (h1 as f64 / k1 as f64 - x).abs() < 1e-6 {
            break;
        }
        b = 1.0 / frac;
    }
    let p = sign * h1;
    if k1 == 1 {
        format!("{}", p)
    } else {
        format!("{}/{}", p, k1)
    }
}
/// `format hex`: the raw IEEE-754 bit pattern as 16 upper-case hex digits.
fn fmt_hex_ieee754(n: f64) -> String {
    let bits: u64 = n.to_bits();
    format!("{bits:016X}")
}
/// `format +`: `+` for positive, `-` for negative, a blank for zero or NaN.
fn fmt_plus_sign(n: f64) -> String {
    use std::cmp::Ordering;
    let glyph = match n.partial_cmp(&0.0) {
        Some(Ordering::Greater) => "+",
        Some(Ordering::Less) => "-",
        _ => " ",
    };
    glyph.to_string()
}
/// Normalises Rust's `{:e}` output to MATLAB style.
///
/// Trailing zeros of a fractional mantissa are removed, and the exponent gets an
/// explicit sign and at least two digits (`1.2300e4` -> `1.23e+04`). Strings
/// without an `e` (e.g. `inf`) are returned unchanged.
///
/// Zeros are stripped only when the mantissa has a decimal point. Otherwise an
/// integral mantissa such as `0` or `10` would lose digits (`0e0` used to become
/// `e+00`).
fn trim_sci(s: &str) -> String {
    let Some((mantissa, exp_str)) = s.split_once('e') else {
        return s.to_string();
    };
    let mantissa = if mantissa.contains('.') {
        mantissa.trim_end_matches('0').trim_end_matches('.')
    } else {
        mantissa
    };
    let (sign, digits) = if let Some(d) = exp_str.strip_prefix('-') {
        ("-", d)
    } else if let Some(d) = exp_str.strip_prefix('+') {
        ("+", d)
    } else {
        ("+", exp_str)
    };
    // Unparseable exponents degrade to 0 rather than failing.
    let exp_num: i32 = digits.parse().unwrap_or(0);
    format!("{}e{}{:02}", mantissa, sign, exp_num)
}
/// Loads a `.mat` file at `path` and returns the loaded value.
///
/// Delegates to `load_mat_file_impl`, whose implementation depends on whether
/// the crate is built with the `mat` feature.
pub fn load_mat_file(path: &str) -> Result<Value, String> {
load_mat_file_impl(path)
}
/// `mat` feature enabled: forwards to `crate::mat::mat_load`.
#[cfg(feature = "mat")]
fn load_mat_file_impl(path: &str) -> Result<Value, String> {
crate::mat::mat_load(path)
}
/// `mat` feature disabled: always returns an error telling the user how to rebuild.
#[cfg(not(feature = "mat"))]
fn load_mat_file_impl(_path: &str) -> Result<Value, String> {
Err("load: .mat support not available — rebuild with --features mat".to_string())
}
/// Shared implementation of `regexp`/`regexpi`; `fname` prefixes error messages.
///
/// `ignore_case` prepends the `(?i)` flag to `pat`. With `return_match` set, it
/// returns a cell array with the text of every non-overlapping match. Otherwise
/// it returns the 1-based character position of the first match, or an empty
/// 0×0 matrix when nothing matches.
#[cfg(feature = "regex")]
fn regexp_impl(
    fname: &str,
    s: &str,
    pat: &str,
    ignore_case: bool,
    return_match: bool,
) -> Result<Value, String> {
    let pattern = match ignore_case {
        true => format!("(?i){pat}"),
        false => pat.to_owned(),
    };
    let re =
        regex::Regex::new(&pattern).map_err(|e| format!("{fname}: invalid pattern: {e}"))?;
    if return_match {
        let found = re
            .find_iter(s)
            .map(|m| Value::Str(m.as_str().to_owned()))
            .collect::<Vec<_>>();
        return Ok(Value::Cell(Box::new(found)));
    }
    // Byte offsets from the regex engine are converted to character positions.
    Ok(re.find(s).map_or_else(
        || Value::Matrix(Box::new(Array2::zeros((0, 0)))),
        |m| Value::Scalar((s[..m.start()].chars().count() + 1) as f64),
    ))
}
/// Without the `regex` feature, `regexp`/`regexpi` always fail with a rebuild hint.
#[cfg(not(feature = "regex"))]
fn regexp_impl(
    fname: &str,
    _s: &str,
    _pat: &str,
    _ignore_case: bool,
    _return_match: bool,
) -> Result<Value, String> {
    Err(format!(
        "{fname}: not available — rebuild with --features regex"
    ))
}
/// Replaces every match of `pat` in `s` with `rep`, taken literally. `NoExpand`
/// disables `$1`/`${name}` group expansion.
#[cfg(feature = "regex")]
fn regexprep_impl(s: &str, pat: &str, rep: &str) -> Result<Value, String> {
    regex::Regex::new(pat)
        .map(|re| Value::Str(re.replace_all(s, regex::NoExpand(rep)).into_owned()))
        .map_err(|e| format!("regexprep: invalid pattern: {e}"))
}
/// Without the `regex` feature, `regexprep` always fails with a rebuild hint.
#[cfg(not(feature = "regex"))]
fn regexprep_impl(_s: &str, _pat: &str, _rep: &str) -> Result<Value, String> {
    Err("regexprep: not available — rebuild with --features regex".to_string())
}
/// Extracts a real vector from a scalar or a 1×n / n×1 matrix. `name` prefixes
/// the error messages.
#[cfg(feature = "fft")]
fn extract_real_vec(v: &Value, name: &str) -> Result<Vec<f64>, String> {
    match v {
        Value::Scalar(x) => Ok(vec![*x]),
        Value::Matrix(m) => {
            let (r, c) = (m.nrows(), m.ncols());
            if r == 1 || c == 1 {
                Ok(m.iter().copied().collect())
            } else {
                Err(format!(
                    "{name}: input must be a vector (got {r}×{c} matrix)"
                ))
            }
        }
        _ => Err(format!("{name}: input must be a real numeric vector")),
    }
}
/// Packs `(re, im)` pairs into a 1×n complex row. Empty input gives a 1×0 matrix.
#[cfg(feature = "fft")]
fn complex_pairs_to_complex_matrix(data: Vec<(f64, f64)>) -> Value {
    if data.is_empty() {
        return Value::ComplexMatrix(Box::new(Array2::zeros((1, 0))));
    }
    let len = data.len();
    let row: Vec<Complex<f64>> = data.into_iter().map(|(re, im)| Complex { re, im }).collect();
    Value::ComplexMatrix(Box::new(
        Array2::from_shape_vec((1, len), row).expect("row shape matches element count"),
    ))
}
/// Extracts `(re, im)` pairs from a scalar, a real or complex matrix (flattened
/// in `iter()` order), or a cell of real/complex scalars. `name` prefixes the
/// error messages.
#[cfg(feature = "fft")]
fn extract_complex_vec(v: &Value, name: &str) -> Result<Vec<(f64, f64)>, String> {
    let pairs = match v {
        Value::Scalar(x) => vec![(*x, 0.0)],
        Value::Matrix(m) => m.iter().map(|&x| (x, 0.0)).collect(),
        Value::ComplexMatrix(m) => m.iter().map(|z| (z.re, z.im)).collect(),
        Value::Cell(elems) => {
            let mut out = Vec::with_capacity(elems.len());
            for (pos, elem) in elems.iter().enumerate() {
                let pair = match elem {
                    Value::Complex(re, im) => (*re, *im),
                    Value::Scalar(x) => (*x, 0.0),
                    _ => {
                        return Err(format!(
                            "{name}: cell element {} must be a complex or real number",
                            pos + 1
                        ));
                    }
                };
                out.push(pair);
            }
            out
        }
        _ => {
            return Err(format!(
                "{name}: input must be a complex matrix, cell array, or numeric vector"
            ));
        }
    };
    Ok(pairs)
}
#[cfg(feature = "fft")]
fn fft_call(v: &Value, n_opt: Option<usize>) -> Result<Value, String> {
    // Forward FFT of a real vector. `n_opt` overrides the transform length
    // (defaults to the input length). How a differing length is handled is up
    // to `crate::fft::fft_forward` (presumably pad/truncate — not visible here).
    let samples = extract_real_vec(v, "fft")?;
    let len = match n_opt {
        Some(k) => k,
        None => samples.len(),
    };
    if len == 0 {
        return Err("fft: length must be positive".to_string());
    }
    let spectrum = crate::fft::fft_forward(&samples, len);
    Ok(complex_pairs_to_complex_matrix(spectrum))
}
#[cfg(not(feature = "fft"))]
fn fft_call(_v: &Value, _n_opt: Option<usize>) -> Result<Value, String> {
    // Built without the `fft` feature: always report that it is unavailable.
    Err(String::from("fft: not available — rebuild with --features fft"))
}
#[cfg(feature = "fft")]
/// Inverse FFT of a real or complex vector (see `extract_complex_vec` for the
/// accepted input forms).
///
/// Returns a 1×N real row vector when every imaginary part of the result is
/// negligible. Otherwise returns a 1×N complex row vector. Empty input gives
/// a 1×0 real matrix.
///
/// # Errors
/// Propagates input-conversion errors from `extract_complex_vec`.
fn ifft_call(v: &Value) -> Result<Value, String> {
    let complex = extract_complex_vec(v, "ifft")?;
    if complex.is_empty() {
        return Ok(Value::Matrix(Box::new(Array2::zeros((1, 0)))));
    }
    let out = crate::fft::fft_inverse(&complex);
    // Round-off in the transform scales with the data's magnitude. A fixed
    // absolute threshold would leave e.g. ifft(fft(x)) complex for large real
    // signals. Make the threshold relative to the largest component, floored
    // at 1.0 so data of magnitude <= 1 keeps the previous absolute 1e-12 test.
    let scale = out
        .iter()
        .fold(1.0_f64, |m, &(re, im)| m.max(re.abs()).max(im.abs()));
    let tol = 1e-12 * scale;
    if out.iter().all(|(_, im)| im.abs() < tol) {
        let real: Vec<f64> = out.iter().map(|&(re, _)| re).collect();
        let n = real.len();
        let row = Array2::from_shape_vec((1, n), real)
            .expect("a 1×N shape always matches a vector of length N");
        Ok(Value::Matrix(Box::new(row)))
    } else {
        Ok(complex_pairs_to_complex_matrix(out))
    }
}
#[cfg(not(feature = "fft"))]
fn ifft_call(_v: &Value) -> Result<Value, String> {
    // Built without the `fft` feature: always report that it is unavailable.
    Err("ifft: not available — rebuild with --features fft".into())
}
#[cfg(feature = "json")]
fn jsondecode_impl(arg: &Value) -> Result<Value, String> {
    // Accept either string representation, parse it as JSON, and convert the
    // parsed document into an interpreter value.
    let text = if let Value::Str(t) | Value::StringObj(t) = arg {
        t.as_str()
    } else {
        return Err("jsondecode: argument must be a string".to_string());
    };
    let parsed = serde_json::from_str::<serde_json::Value>(text)
        .map_err(|e| format!("jsondecode: invalid JSON: {e}"))?;
    Ok(crate::json::json_to_value(&parsed))
}
#[cfg(not(feature = "json"))]
fn jsondecode_impl(_arg: &Value) -> Result<Value, String> {
    // Built without the `json` feature: always report that it is unavailable.
    Err(String::from(
        "jsondecode: not available — rebuild with --features json",
    ))
}
#[cfg(feature = "json")]
fn jsonencode_impl(arg: &Value) -> Result<Value, String> {
    // Convert the value to a JSON tree, then serialize it compactly
    // (`to_string`, not the pretty printer).
    let tree = crate::json::value_to_json(arg)?;
    match serde_json::to_string(&tree) {
        Ok(text) => Ok(Value::Str(text)),
        Err(e) => Err(format!("jsonencode: serialization error: {e}")),
    }
}
#[cfg(not(feature = "json"))]
fn jsonencode_impl(_arg: &Value) -> Result<Value, String> {
    // Built without the `json` feature: always report that it is unavailable.
    Err(String::from(
        "jsonencode: not available — rebuild with --features json",
    ))
}
/// Evaluates a real-coefficient polynomial at the complex point `z = (re, im)`
/// using Horner's rule.
///
/// `coeffs` are ordered highest degree first. Returns `(re, im)` of `p(z)`.
/// An empty coefficient list evaluates to `(0.0, 0.0)`.
fn cpoly_eval(coeffs: &[f64], z: (f64, f64)) -> (f64, f64) {
    let (zr, zi) = z;
    // acc <- acc * z + c, with the complex product expanded by hand.
    coeffs.iter().fold((0.0_f64, 0.0_f64), |(ar, ai), &c| {
        (ar * zr - ai * zi + c, ar * zi + ai * zr)
    })
}
/// Evaluates a real polynomial at `x` with Horner's scheme.
///
/// `coeffs` are ordered highest degree first. An empty slice evaluates to `0.0`.
fn horner(coeffs: &[f64], x: f64) -> f64 {
    let mut value = 0.0;
    for &c in coeffs {
        value = value * x + c;
    }
    value
}
/// Extracts polynomial coefficients from a scalar or a row/column vector.
///
/// A 1×N matrix yields its single row and an N×1 matrix its single column.
/// Any other matrix shape, or a non-numeric value, is an error prefixed with
/// `fname`.
fn poly_coeffs(v: &Value, fname: &str) -> Result<Vec<f64>, String> {
    let m = match v {
        Value::Scalar(s) => return Ok(vec![*s]),
        Value::Matrix(m) => m,
        _ => return Err(format!("{fname}: argument must be a real numeric vector")),
    };
    let (rows, cols) = (m.nrows(), m.ncols());
    if rows == 1 {
        Ok(m.row(0).iter().copied().collect())
    } else if cols == 1 {
        Ok(m.column(0).iter().copied().collect())
    } else {
        Err(format!(
            "{fname}: argument must be a vector, got {rows}×{cols}"
        ))
    }
}
/// Full linear convolution of two coefficient vectors, i.e. the coefficients
/// of the product polynomial.
///
/// The result has `a.len() + b.len() - 1` entries. It is empty if either input
/// is empty.
fn poly_conv(a: &[f64], b: &[f64]) -> Vec<f64> {
    if a.is_empty() || b.is_empty() {
        return Vec::new();
    }
    let mut out = vec![0.0_f64; a.len() + b.len() - 1];
    for (shift, &x) in a.iter().enumerate() {
        // Add x * b into the output, starting at position `shift`.
        for (slot, &y) in out[shift..].iter_mut().zip(b) {
            *slot += x * y;
        }
    }
    out
}
/// Polynomial long division: finds `q` and `r` such that `c = conv(b, q) + r`.
///
/// Coefficients are ordered highest degree first. `r` has the same length as
/// `c`. Remainder entries with magnitude below `1e-10 * max(1, max|c_i|)` are
/// flushed to exactly zero, which hides cancellation round-off.
///
/// Leading zeros in `b` are ignored because they do not change the polynomial.
/// Without this, the first step would divide by zero and fill `q` with inf/NaN.
/// If `b` (after stripping) has more coefficients than `c`, the quotient is
/// `[0.0]` and the remainder is `c` itself.
///
/// # Errors
/// Returns an error if `b` is empty or all zeros.
fn poly_deconv(c: &[f64], b: &[f64]) -> Result<(Vec<f64>, Vec<f64>), String> {
    let b = match b.iter().position(|&x| x != 0.0) {
        Some(first) => &b[first..],
        None => return Err("deconv: divisor polynomial must not be zero".to_string()),
    };
    let mc = c.len();
    let mb = b.len();
    if mb > mc {
        return Ok((vec![0.0], c.to_vec()));
    }
    let q_len = mc - mb + 1;
    let lead = b[0]; // nonzero by construction above
    let mut remainder = c.to_vec();
    let mut q = vec![0.0_f64; q_len];
    for (i, qi) in q.iter_mut().enumerate() {
        let coeff = remainder[i] / lead;
        *qi = coeff;
        // Subtract coeff * b aligned at position i.
        for (r, &bj) in remainder[i..i + mb].iter_mut().zip(b) {
            *r -= coeff * bj;
        }
    }
    let scale = c.iter().map(|v| v.abs()).fold(0.0_f64, f64::max).max(1.0);
    for x in &mut remainder {
        if x.abs() < 1e-10 * scale {
            *x = 0.0;
        }
    }
    Ok((q, remainder))
}
/// Finds all complex roots of a real polynomial with the Durand–Kerner
/// (Weierstrass) simultaneous iteration.
///
/// `coeffs` are ordered highest degree first. Returns one `(re, im)` pair per
/// root, degree-many in total. Results are sorted by descending real part,
/// then by descending imaginary part. An empty or constant coefficient list has
/// no roots and yields an empty vector.
///
/// Iteration stops when every correction is below `1e-12`, or after 2000 sweeps.
///
/// # Errors
/// Returns an error when the leading coefficient is exactly zero. Callers
/// that allow leading zeros must strip them first.
fn durand_kerner(coeffs: &[f64]) -> Result<Vec<(f64, f64)>, String> {
    // Polynomial degree. `checked_sub` avoids the usize underflow that
    // `coeffs.len() - 1` hits on an empty slice.
    let n = match coeffs.len().checked_sub(1) {
        None | Some(0) => return Ok(vec![]),
        Some(n) => n,
    };
    let lc = coeffs[0];
    if lc == 0.0 {
        return Err("roots: leading coefficient must not be zero".to_string());
    }
    // Normalize to a monic polynomial; the Weierstrass correction assumes it.
    let monic: Vec<f64> = coeffs.iter().map(|&c| c / lc).collect();
    // Cauchy bound: all roots of the monic polynomial satisfy |z| <= r.
    let r = 1.0 + monic[1..].iter().map(|c| c.abs()).fold(0.0_f64, f64::max);
    // Initial guesses are spread evenly on the circle of radius r. The
    // quarter-step angular offset means no guess is exactly real. With real
    // coefficients, an all-real start could never move off the real axis.
    let mut z: Vec<(f64, f64)> = (0..n)
        .map(|k| {
            let angle = 2.0 * std::f64::consts::PI * (k as f64 + 0.25) / n as f64;
            (r * angle.cos(), r * angle.sin())
        })
        .collect();
    const MAX_ITER: usize = 2000;
    const EPS: f64 = 1e-12;
    for _ in 0..MAX_ITER {
        // Jacobi-style sweep: every correction is computed from the previous iterate.
        let z_old = z.clone();
        let mut max_corr = 0.0_f64;
        for i in 0..n {
            // p(z_i)
            let (pre, pim) = cpoly_eval(&monic, z_old[i]);
            // prod_{j != i} (z_i - z_j)
            let mut dre = 1.0_f64;
            let mut dim = 0.0_f64;
            for j in 0..n {
                if j == i {
                    continue;
                }
                let (dr, di) = (z_old[i].0 - z_old[j].0, z_old[i].1 - z_old[j].1);
                (dre, dim) = (dre * dr - dim * di, dre * di + dim * dr);
            }
            // Correction p(z_i) / prod. If two estimates coincide the product
            // vanishes, so the raw residual is used instead.
            let d2 = dre * dre + dim * dim;
            let (cre, cim) = if d2 > 0.0 {
                ((pre * dre + pim * dim) / d2, (pim * dre - pre * dim) / d2)
            } else {
                (pre, pim)
            };
            let corr_abs = (cre * cre + cim * cim).sqrt();
            max_corr = max_corr.max(corr_abs);
            z[i] = (z_old[i].0 - cre, z_old[i].1 - cim);
        }
        if max_corr < EPS {
            break;
        }
    }
    // Deterministic output order: descending real part, then descending imaginary part.
    z.sort_by(|a, b| {
        b.0.partial_cmp(&a.0)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then(b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal))
    });
    Ok(z)
}
/// Converts `(re, im)` root pairs into an interpreter value.
///
/// If every imaginary part is below `1e-9` in magnitude, the result is an n×1
/// real column matrix (0×1 for no roots). Otherwise the result is a cell array
/// holding a `Scalar` for each near-real root and a `Complex` for the rest.
fn roots_to_value(roots: &[(f64, f64)]) -> Value {
    const IMAG_TOL: f64 = 1e-9;
    let is_real = |im: f64| im.abs() < IMAG_TOL;
    if roots.iter().all(|&(_, im)| is_real(im)) {
        let column: Vec<f64> = roots.iter().map(|&(re, _)| re).collect();
        let rows = column.len();
        let m = Array2::from_shape_vec((rows, 1), column)
            .expect("an n×1 shape always matches n elements");
        return Value::Matrix(Box::new(m));
    }
    let cells: Vec<Value> = roots
        .iter()
        .map(|&(re, im)| {
            if is_real(im) {
                Value::Scalar(re)
            } else {
                Value::Complex(re, im)
            }
        })
        .collect();
    Value::Cell(Box::new(cells))
}
/// Characteristic polynomial of a square matrix via the Faddeev–LeVerrier recurrence.
///
/// Returns `n + 1` coefficients, highest degree first, with leading coefficient
/// `1.0`. These are the coefficients of det(λI − A). A 0×0 matrix yields `[1.0]`.
///
/// # Errors
/// Returns an error if `a` is not square.
fn characteristic_poly(a: &Array2<f64>) -> Result<Vec<f64>, String> {
    let n = a.nrows();
    if a.ncols() != n {
        return Err("poly: matrix must be square".to_string());
    }
    let mut coeffs = Vec::with_capacity(n + 1);
    coeffs.push(1.0_f64);
    if n == 0 {
        return Ok(coeffs);
    }
    // `m` is the auxiliary matrix M_k of the recurrence, starting at M_1 = I.
    let mut m = Array2::<f64>::eye(n);
    for k in 1..=n {
        // c_k = -tr(A · M_k) / k
        let am = a.dot(&m);
        let trace: f64 = (0..n).map(|i| am[[i, i]]).sum();
        let ck = -trace / k as f64;
        coeffs.push(ck);
        // M_{k+1} = A · M_k + c_k · I
        m = am;
        for i in 0..n {
            m[[i, i]] += ck;
        }
    }
    Ok(coeffs)
}
/// Solves the upper-triangular system `r · x = b` by back substitution.
///
/// `r` is presumably the R factor of a QR decomposition of polyfit's
/// Vandermonde matrix; the error text is phrased for polyfit. Only the upper
/// triangle of the leading `n`×`n` block is read, where `n = r.nrows()`. `b`
/// must have at least `n` entries.
///
/// # Errors
/// Returns an error when a diagonal pivot has magnitude below `1e-14`.
fn poly_back_sub(r: &Array2<f64>, b: &[f64]) -> Result<Vec<f64>, String> {
    let n = r.nrows();
    let mut x = vec![0.0_f64; n];
    for row in (0..n).rev() {
        // b[row] minus the contribution of the already-solved unknowns.
        let rhs = ((row + 1)..n).fold(b[row], |acc, col| acc - r[[row, col]] * x[col]);
        let pivot = r[[row, row]];
        if pivot.abs() < 1e-14 {
            return Err(
                "polyfit: Vandermonde matrix is rank-deficient; reduce polynomial degree"
                    .to_string(),
            );
        }
        x[row] = rhs / pivot;
    }
    Ok(x)
}
/// Interpolates the single query point `xi` against the samples `(x, y)`.
///
/// `x` is assumed to be sorted ascending, and `y` is assumed to have the same
/// length as `x` (the caller is expected to validate both). Supported methods:
/// - `"nearest"`: an exact midpoint resolves to the lower sample;
/// - `"previous"`;
/// - `"next"`;
/// - any other string: linear interpolation.
///
/// Returns `NaN` when `xi` is `NaN`, when `x` is empty, or when `xi` lies
/// outside `[x[0], x[n-1]]`; there is no extrapolation.
fn interp1_at(x: &[f64], y: &[f64], xi: f64, method: &str) -> f64 {
    let n = x.len();
    // NaN must be rejected explicitly. Every comparison with NaN is false, so it
    // would otherwise pass the range check and be snapped to a sample value.
    if n == 0 || xi.is_nan() || xi < x[0] || xi > x[n - 1] {
        return f64::NAN;
    }
    // Single sample: the range check guarantees xi == x[0]. Handling it here
    // also avoids the `n - 2` underflow below.
    if n == 1 {
        return y[0];
    }
    // Index of the last sample with x[lo] <= xi; one exists since x[0] <= xi.
    let lo = x.partition_point(|&xk| xk <= xi).saturating_sub(1);
    // Left end of the bracketing interval [x[lo2], x[lo2 + 1]]. It is clamped so
    // that lo2 + 1 stays in bounds when xi sits exactly on the last sample.
    let lo2 = lo.min(n - 2);
    match method {
        "nearest" => {
            if lo == n - 1 {
                return y[n - 1];
            }
            if (xi - x[lo2]) <= (x[lo2 + 1] - xi) {
                y[lo2]
            } else {
                y[lo2 + 1]
            }
        }
        "previous" => y[lo],
        "next" => {
            if lo == n - 1 || xi == x[lo] {
                y[lo]
            } else {
                y[lo2 + 1]
            }
        }
        _ => {
            if lo == n - 1 {
                return y[n - 1];
            }
            let t = (xi - x[lo2]) / (x[lo2 + 1] - x[lo2]);
            y[lo2] + t * (y[lo2 + 1] - y[lo2])
        }
    }
}
// Unit tests for this module live in `eval_tests.rs` (wired in via `#[path]`)
// rather than inline, and are compiled only for test builds.
#[cfg(test)]
#[path = "eval_tests.rs"]
mod tests;