Skip to main content

runmat_runtime/builtins/math/optim/
fsolve.rs

1//! MATLAB-compatible `fsolve` builtin for nonlinear systems.
2
3use nalgebra::{DMatrix, DVector};
4use runmat_builtins::{
5    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7};
8use runmat_builtins::{StructValue, Value};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::spec::{
12    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13    ReductionNaN, ResidencyPolicy, ShapeRequirements,
14};
15use crate::builtins::math::optim::common::{
16    call_function, initial_guess, option_f64, option_string, option_usize, value_to_real_vector,
17    vector_to_value,
18};
19use crate::builtins::math::optim::type_resolvers::nonlinear_solve_type;
20use crate::{build_runtime_error, BuiltinResult, RuntimeError};
21
22const NAME: &str = "fsolve";
23const DEFAULT_TOL_X: f64 = 1.0e-6;
24const DEFAULT_TOL_FUN: f64 = 1.0e-6;
25const DEFAULT_MAX_ITER: usize = 400;
26
27const FSOLVE_OUTPUT_X: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
28    name: "x",
29    ty: BuiltinParamType::NumericArray,
30    arity: BuiltinParamArity::Required,
31    default: None,
32    description: "Approximate solution vector/scalar.",
33}];
34
35const FSOLVE_INPUTS_CORE: [BuiltinParamDescriptor; 2] = [
36    BuiltinParamDescriptor {
37        name: "fun",
38        ty: BuiltinParamType::Any,
39        arity: BuiltinParamArity::Required,
40        default: None,
41        description: "System residual callback.",
42    },
43    BuiltinParamDescriptor {
44        name: "x0",
45        ty: BuiltinParamType::Any,
46        arity: BuiltinParamArity::Required,
47        default: None,
48        description: "Initial guess scalar/vector.",
49    },
50];
51
52const FSOLVE_INPUTS_WITH_OPTIONS: [BuiltinParamDescriptor; 3] = [
53    BuiltinParamDescriptor {
54        name: "fun",
55        ty: BuiltinParamType::Any,
56        arity: BuiltinParamArity::Required,
57        default: None,
58        description: "System residual callback.",
59    },
60    BuiltinParamDescriptor {
61        name: "x0",
62        ty: BuiltinParamType::Any,
63        arity: BuiltinParamArity::Required,
64        default: None,
65        description: "Initial guess scalar/vector.",
66    },
67    BuiltinParamDescriptor {
68        name: "options",
69        ty: BuiltinParamType::Any,
70        arity: BuiltinParamArity::Optional,
71        default: None,
72        description: "Options struct from optimset.",
73    },
74];
75
76const FSOLVE_SIGNATURES: [BuiltinSignatureDescriptor; 2] = [
77    BuiltinSignatureDescriptor {
78        label: "x = fsolve(fun, x0)",
79        inputs: &FSOLVE_INPUTS_CORE,
80        outputs: &FSOLVE_OUTPUT_X,
81    },
82    BuiltinSignatureDescriptor {
83        label: "x = fsolve(fun, x0, options)",
84        inputs: &FSOLVE_INPUTS_WITH_OPTIONS,
85        outputs: &FSOLVE_OUTPUT_X,
86    },
87];
88
89const FSOLVE_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
90    code: "RM.FSOLVE.INVALID_ARGUMENT",
91    identifier: Some("RunMat:fsolve:InvalidArgument"),
92    when: "Argument grammar/options configuration is invalid.",
93    message: "fsolve: invalid argument",
94};
95
96const FSOLVE_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
97    code: "RM.FSOLVE.INVALID_INPUT",
98    identifier: Some("RunMat:fsolve:InvalidInput"),
99    when: "Initial guess/callback/iteration semantics are invalid.",
100    message: "fsolve: invalid input",
101};
102
103const FSOLVE_ERRORS: [BuiltinErrorDescriptor; 2] =
104    [FSOLVE_ERROR_INVALID_ARGUMENT, FSOLVE_ERROR_INVALID_INPUT];
105
106pub const FSOLVE_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
107    signatures: &FSOLVE_SIGNATURES,
108    output_mode: BuiltinOutputMode::Fixed,
109    completion_policy: BuiltinCompletionPolicy::Public,
110    errors: &FSOLVE_ERRORS,
111};
112
113fn fsolve_error_with_detail(
114    error: &'static BuiltinErrorDescriptor,
115    detail: impl AsRef<str>,
116) -> RuntimeError {
117    let detail = detail.as_ref();
118    let message = if detail.starts_with("fsolve:") {
119        detail.to_string()
120    } else {
121        format!("{}: {detail}", error.message)
122    };
123    let mut builder = build_runtime_error(message).with_builtin(NAME);
124    if let Some(identifier) = error.identifier {
125        builder = builder.with_identifier(identifier);
126    }
127    builder.build()
128}
129
130fn fsolve_map_error(err: RuntimeError, fallback: &'static BuiltinErrorDescriptor) -> RuntimeError {
131    if err.identifier().is_some() {
132        err
133    } else {
134        fsolve_error_with_detail(fallback, err.message())
135    }
136}
137
// GPU metadata for `fsolve`. It declares no supported precisions and no provider hooks,
// and uses residency policy `GatherImmediately`. Per the notes below, the solver itself
// runs on the host and only the user callback may use GPU-aware builtins.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::optim::fsolve")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "fsolve",
    op_kind: GpuOpKind::Custom("nonlinear-solve"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Host finite-difference Levenberg-Marquardt solver. Callback computations may use GPU-aware builtins.",
};

// Fusion metadata for `fsolve`: no elementwise or reduction form is declared. Per the
// notes, a call to `fsolve` ends fusion planning because it repeatedly invokes user code.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::optim::fsolve")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "fsolve",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Nonlinear solving repeatedly invokes user code and terminates fusion planning.",
};
164
165#[runtime_builtin(
166    name = "fsolve",
167    category = "math/optim",
168    summary = "Solve nonlinear equation systems.",
169    keywords = "fsolve,nonlinear solve,root finding,levenberg-marquardt,jacobian",
170    accel = "sink",
171    type_resolver(nonlinear_solve_type),
172    descriptor(crate::builtins::math::optim::fsolve::FSOLVE_DESCRIPTOR),
173    builtin_path = "crate::builtins::math::optim::fsolve"
174)]
175async fn fsolve_builtin(function: Value, x0: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
176    if rest.len() > 1 {
177        return Err(fsolve_error_with_detail(
178            &FSOLVE_ERROR_INVALID_ARGUMENT,
179            "too many input arguments",
180        ));
181    }
182    let options = parse_options(rest.first())
183        .map_err(|err| fsolve_map_error(err, &FSOLVE_ERROR_INVALID_ARGUMENT))?;
184    let opts = FsolveOptions::from_struct(options.as_ref())
185        .map_err(|err| fsolve_map_error(err, &FSOLVE_ERROR_INVALID_ARGUMENT))?;
186    let guess = initial_guess(NAME, x0)
187        .await
188        .map_err(|err| fsolve_map_error(err, &FSOLVE_ERROR_INVALID_INPUT))?;
189    let solution = solve(&function, guess.values, &guess.shape, guess.scalar, &opts)
190        .await
191        .map_err(|err| fsolve_map_error(err, &FSOLVE_ERROR_INVALID_INPUT))?;
192    vector_to_value(NAME, solution, &guess.shape, guess.scalar)
193        .map_err(|err| fsolve_map_error(err, &FSOLVE_ERROR_INVALID_INPUT))
194}
195
196fn parse_options(value: Option<&Value>) -> BuiltinResult<Option<StructValue>> {
197    match value {
198        None => Ok(None),
199        Some(Value::Struct(options)) => Ok(Some(options.clone())),
200        Some(other) => Err(fsolve_error_with_detail(
201            &FSOLVE_ERROR_INVALID_ARGUMENT,
202            format!("options must be a struct, got {other:?}"),
203        )),
204    }
205}
206
207#[derive(Clone, Copy)]
208struct FsolveOptions {
209    tol_x: f64,
210    tol_fun: f64,
211    max_iter: usize,
212    max_fun_evals: usize,
213}
214
215impl FsolveOptions {
216    fn from_struct(options: Option<&StructValue>) -> BuiltinResult<Self> {
217        let display = option_string(options, "Display", "off")?;
218        if !matches!(display.as_str(), "off" | "none" | "final" | "iter") {
219            return Err(fsolve_error_with_detail(
220                &FSOLVE_ERROR_INVALID_ARGUMENT,
221                "option Display must be 'off', 'none', 'final', or 'iter'",
222            ));
223        }
224        let tol_x = option_f64(NAME, options, "TolX", DEFAULT_TOL_X)?;
225        let tol_fun = option_f64(NAME, options, "TolFun", DEFAULT_TOL_FUN)?;
226        if tol_x <= 0.0 || tol_fun <= 0.0 {
227            return Err(fsolve_error_with_detail(
228                &FSOLVE_ERROR_INVALID_ARGUMENT,
229                "options TolX and TolFun must be positive",
230            ));
231        }
232        let max_iter = option_usize(NAME, options, "MaxIter", DEFAULT_MAX_ITER)?.max(1);
233        let max_fun_evals = option_usize(NAME, options, "MaxFunEvals", 100 * max_iter)?.max(1);
234        Ok(Self {
235            tol_x,
236            tol_fun,
237            max_iter,
238            max_fun_evals,
239        })
240    }
241}
242
243async fn solve(
244    function: &Value,
245    mut x: Vec<f64>,
246    shape: &[usize],
247    scalar: bool,
248    options: &FsolveOptions,
249) -> BuiltinResult<Vec<f64>> {
250    let n = x.len();
251    if n == 0 {
252        return Err(fsolve_error_with_detail(
253            &FSOLVE_ERROR_INVALID_INPUT,
254            "initial guess cannot be empty",
255        ));
256    }
257
258    let mut residual = eval_residual(function, &x, shape, scalar).await?;
259    let mut evals = 1usize;
260    let mut lambda = 1.0e-3;
261
262    if residual_norm_inf(&residual) <= options.tol_fun {
263        return Ok(x);
264    }
265
266    for _ in 0..options.max_iter {
267        if evals >= options.max_fun_evals {
268            return Err(fsolve_error_with_detail(
269                &FSOLVE_ERROR_INVALID_INPUT,
270                "exceeded maximum function evaluations",
271            ));
272        }
273        let jacobian =
274            finite_difference_jacobian(function, &x, shape, scalar, &residual, &mut evals, options)
275                .await?;
276        let j = DMatrix::from_row_slice(residual.len(), n, &jacobian);
277        let f = DVector::from_column_slice(&residual);
278        let gradient = j.transpose() * &f;
279        let mut accepted = false;
280
281        for _ in 0..8 {
282            let normal = j.transpose() * &j + DMatrix::<f64>::identity(n, n) * lambda;
283            let rhs = -&gradient;
284            let Some(delta) = normal.lu().solve(&rhs) else {
285                lambda *= 10.0;
286                continue;
287            };
288            let trial = x
289                .iter()
290                .zip(delta.iter())
291                .map(|(xi, di)| xi + di)
292                .collect::<Vec<_>>();
293            let trial_residual = eval_residual(function, &trial, shape, scalar).await?;
294            evals += 1;
295
296            if norm2(&trial_residual) < norm2(&residual) {
297                let step_norm = delta
298                    .iter()
299                    .fold(0.0_f64, |acc, value| acc.max(value.abs()));
300                let x_norm = x.iter().fold(0.0_f64, |acc, value| acc.max(value.abs()));
301                x = trial;
302                residual = trial_residual;
303                lambda = (lambda * 0.3).max(1.0e-12);
304                accepted = true;
305                if residual_norm_inf(&residual) <= options.tol_fun
306                    || step_norm <= options.tol_x * (1.0 + x_norm)
307                {
308                    return Ok(x);
309                }
310                break;
311            }
312
313            lambda *= 10.0;
314            if evals >= options.max_fun_evals {
315                return Err(fsolve_error_with_detail(
316                    &FSOLVE_ERROR_INVALID_INPUT,
317                    "exceeded maximum function evaluations",
318                ));
319            }
320        }
321
322        if !accepted {
323            return Err(fsolve_error_with_detail(
324                &FSOLVE_ERROR_INVALID_INPUT,
325                "iteration stalled before convergence",
326            ));
327        }
328    }
329
330    Err(fsolve_error_with_detail(
331        &FSOLVE_ERROR_INVALID_INPUT,
332        "exceeded maximum iterations",
333    ))
334}
335
336async fn eval_residual(
337    function: &Value,
338    x: &[f64],
339    shape: &[usize],
340    scalar: bool,
341) -> BuiltinResult<Vec<f64>> {
342    let arg = if scalar {
343        Value::Num(x[0])
344    } else {
345        Value::Tensor(
346            runmat_builtins::Tensor::new(x.to_vec(), shape.to_vec())
347                .map_err(|e| fsolve_error_with_detail(&FSOLVE_ERROR_INVALID_INPUT, e))?,
348        )
349    };
350    let value = call_function(function, vec![arg]).await?;
351    let residual = value_to_real_vector(NAME, value).await?;
352    if residual.is_empty() {
353        Err(fsolve_error_with_detail(
354            &FSOLVE_ERROR_INVALID_INPUT,
355            "function value must not be empty",
356        ))
357    } else {
358        Ok(residual)
359    }
360}
361
362async fn finite_difference_jacobian(
363    function: &Value,
364    x: &[f64],
365    shape: &[usize],
366    scalar: bool,
367    residual: &[f64],
368    evals: &mut usize,
369    options: &FsolveOptions,
370) -> BuiltinResult<Vec<f64>> {
371    let m = residual.len();
372    let n = x.len();
373    let mut jacobian = vec![0.0; m * n];
374
375    for col in 0..n {
376        if *evals >= options.max_fun_evals {
377            return Err(fsolve_error_with_detail(
378                &FSOLVE_ERROR_INVALID_INPUT,
379                "exceeded maximum function evaluations",
380            ));
381        }
382        let mut perturbed = x.to_vec();
383        let step = f64::EPSILON.sqrt() * (x[col].abs() + 1.0);
384        perturbed[col] += step;
385        let next = eval_residual(function, &perturbed, shape, scalar).await?;
386        *evals += 1;
387        if next.len() != m {
388            return Err(fsolve_error_with_detail(
389                &FSOLVE_ERROR_INVALID_INPUT,
390                "function output size changed during finite differencing",
391            ));
392        }
393        for row in 0..m {
394            jacobian[row * n + col] = (next[row] - residual[row]) / step;
395        }
396    }
397
398    Ok(jacobian)
399}
400
/// Euclidean (2-)norm of `values`, computed without intermediate overflow or underflow.
///
/// Each element is divided by the largest magnitude before squaring. Very large residuals
/// therefore do not collapse to `inf`, and very small ones do not collapse to `0`, either
/// of which would break the solver's accept/reject comparison.
///
/// Special values:
/// - Returns `NaN` if any element is NaN.
/// - Returns `inf` if any element is infinite.
/// - Returns `0.0` for an empty slice.
fn norm2(values: &[f64]) -> f64 {
    if values.iter().any(|value| value.is_nan()) {
        return f64::NAN;
    }
    let scale = values
        .iter()
        .fold(0.0_f64, |acc, value| acc.max(value.abs()));
    if scale == 0.0 || scale.is_infinite() {
        return scale;
    }
    let scaled_sum_sq: f64 = values
        .iter()
        .map(|value| {
            let ratio = value / scale;
            ratio * ratio
        })
        .sum();
    scale * scaled_sum_sq.sqrt()
}
404
/// Infinity norm (largest absolute value) of `values`; `0.0` for an empty slice.
///
/// NaN is propagated: `f64::max` silently discards NaN operands, which previously let an
/// all-NaN residual report a norm of 0 and pass the `TolFun` convergence test.
fn residual_norm_inf(values: &[f64]) -> f64 {
    values
        .iter()
        .map(|value| value.abs())
        .fold(0.0_f64, |acc, magnitude| {
            if acc.is_nan() || magnitude.is_nan() {
                f64::NAN
            } else {
                acc.max(magnitude)
            }
        })
}
410
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;
    use std::sync::{Arc, Mutex};

    /// Drives `fsolve_builtin` to completion on the current thread without an options argument.
    fn run_fsolve(function: Value, x0: Value) -> BuiltinResult<Value> {
        block_on(fsolve_builtin(function, x0, Vec::new()))
    }

    #[test]
    fn fsolve_scalar_builtin_handle() {
        let solution = run_fsolve(Value::FunctionHandle("sin".into()), Value::Num(3.0)).unwrap();
        let n = match solution {
            Value::Num(n) => n,
            other => panic!("unexpected value {other:?}"),
        };
        assert!((n - std::f64::consts::PI).abs() < 1.0e-5);
    }

    #[test]
    fn fsolve_vector_system_via_semantic_resolver() {
        // Every name resolves to semantic function 0, which the invoker below implements.
        let _resolver = crate::user_functions::install_semantic_function_resolver(Some(
            Arc::new(|_name| Some(0)),
        ));
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            |_function, args, _requested_outputs| {
                let x = match &args[0] {
                    Value::Tensor(t) => t.data.clone(),
                    _ => panic!("expected tensor input"),
                };
                // Residual of { x1^2 + x2^2 = 4, x1 * x2 = 1 } as a 2x1 column.
                let residual = vec![x[0] * x[0] + x[1] * x[1] - 4.0, x[0] * x[1] - 1.0];
                Box::pin(async move {
                    Ok(Value::Tensor(Tensor::new(residual, vec![2, 1]).unwrap()))
                })
            },
        )));

        let start = Tensor::new(vec![1.0, 1.0], vec![2, 1]).unwrap();
        let solution =
            run_fsolve(Value::FunctionHandle("system".into()), Value::Tensor(start)).unwrap();
        let t = match solution {
            Value::Tensor(t) => t,
            other => panic!("unexpected value {other:?}"),
        };
        assert!((t.data[0] * t.data[0] + t.data[1] * t.data[1] - 4.0).abs() < 1.0e-5);
        assert!((t.data[0] * t.data[1] - 1.0).abs() < 1.0e-5);
    }

    #[test]
    fn fsolve_preserves_row_vector_shape_for_callback() {
        let observed_shapes = Arc::new(Mutex::new(Vec::new()));
        let recorder = Arc::clone(&observed_shapes);
        let _resolver = crate::user_functions::install_semantic_function_resolver(Some(
            Arc::new(|_name| Some(0)),
        ));
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            move |_function, args, _requested_outputs| {
                let (x, shape) = match &args[0] {
                    Value::Tensor(t) => (t.data.clone(), t.shape.clone()),
                    other => panic!("expected tensor input, got {other:?}"),
                };
                // The callback must always see the 1x2 layout of the initial guess.
                assert_eq!(shape, vec![1, 2]);
                recorder.lock().unwrap().push(shape.clone());
                Box::pin(async move {
                    Ok(Value::Tensor(
                        Tensor::new(vec![x[0] - 3.0, x[1] - 4.0], shape).unwrap(),
                    ))
                })
            },
        )));

        let start = Tensor::new(vec![0.0, 0.0], vec![1, 2]).unwrap();
        let solution =
            run_fsolve(Value::FunctionHandle("row_system".into()), Value::Tensor(start)).unwrap();
        let t = match solution {
            Value::Tensor(t) => t,
            other => panic!("unexpected value {other:?}"),
        };
        assert_eq!(t.shape, vec![1, 2]);
        assert!((t.data[0] - 3.0).abs() < 1.0e-5);
        assert!((t.data[1] - 4.0).abs() < 1.0e-5);
        assert!(!observed_shapes.lock().unwrap().is_empty());
    }

    #[test]
    fn fsolve_preserves_matrix_shape_for_callback() {
        let observed_shapes = Arc::new(Mutex::new(Vec::new()));
        let recorder = Arc::clone(&observed_shapes);
        let _resolver = crate::user_functions::install_semantic_function_resolver(Some(
            Arc::new(|_name| Some(0)),
        ));
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            move |_function, args, _requested_outputs| {
                let (x, shape) = match &args[0] {
                    Value::Tensor(t) => (t.data.clone(), t.shape.clone()),
                    other => panic!("expected tensor input, got {other:?}"),
                };
                // The callback must always see the 2x2 layout of the initial guess.
                assert_eq!(shape, vec![2, 2]);
                recorder.lock().unwrap().push(shape.clone());
                Box::pin(async move {
                    Ok(Value::Tensor(
                        Tensor::new(vec![x[0] - 1.0, x[1] - 2.0, x[2] - 3.0, x[3] - 4.0], shape)
                            .unwrap(),
                    ))
                })
            },
        )));

        let start = Tensor::new(vec![0.0, 0.0, 0.0, 0.0], vec![2, 2]).unwrap();
        let solution =
            run_fsolve(Value::FunctionHandle("matrix_system".into()), Value::Tensor(start))
                .unwrap();
        let t = match solution {
            Value::Tensor(t) => t,
            other => panic!("unexpected value {other:?}"),
        };
        assert_eq!(t.shape, vec![2, 2]);
        for (actual, expected) in t.data.iter().zip([1.0, 2.0, 3.0, 4.0]) {
            assert!((actual - expected).abs() < 1.0e-5);
        }
        assert!(!observed_shapes.lock().unwrap().is_empty());
    }

    #[test]
    fn fsolve_accepts_semantic_function_handle_callback() {
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            |function, args, requested_outputs| {
                assert_eq!(function, 43);
                assert_eq!(requested_outputs, 1);
                let x = match &args[0] {
                    Value::Num(value) => *value,
                    other => panic!("expected scalar numeric argument, got {other:?}"),
                };
                Box::pin(async move { Ok(Value::Num(x - 3.0)) })
            },
        )));

        let handle = Value::BoundFunctionHandle {
            name: "system_function".to_string(),
            function: 43,
        };
        let solution = run_fsolve(handle, Value::Num(1.0)).unwrap();
        let n = match solution {
            Value::Num(n) => n,
            other => panic!("unexpected value {other:?}"),
        };
        assert!((n - 3.0).abs() < 1.0e-5);
    }

    #[test]
    fn fsolve_descriptor_signatures_cover_core_forms() {
        let labels = FSOLVE_DESCRIPTOR
            .signatures
            .iter()
            .map(|signature| signature.label)
            .collect::<Vec<&str>>();
        assert_eq!(
            labels,
            ["x = fsolve(fun, x0)", "x = fsolve(fun, x0, options)"]
        );

        let codes = FSOLVE_DESCRIPTOR
            .errors
            .iter()
            .map(|error| error.code)
            .collect::<Vec<&str>>();
        assert_eq!(
            codes,
            ["RM.FSOLVE.INVALID_ARGUMENT", "RM.FSOLVE.INVALID_INPUT"]
        );
    }

    #[test]
    fn fsolve_too_many_args_uses_stable_identifier() {
        let extra = vec![
            Value::Struct(StructValue::new()),
            Value::Struct(StructValue::new()),
        ];
        let err = block_on(fsolve_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Num(1.0),
            extra,
        ))
        .unwrap_err();
        assert_eq!(err.identifier(), Some("RunMat:fsolve:InvalidArgument"));
    }
}
620}