Skip to main content

runmat_runtime/builtins/math/optim/
fsolve.rs

1//! MATLAB-compatible `fsolve` builtin for nonlinear systems.
2
3use runmat_builtins::{
4    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6};
7use runmat_builtins::{StructValue, Value};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::spec::{
11    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
12    ReductionNaN, ResidencyPolicy, ShapeRequirements,
13};
14use crate::builtins::math::optim::common::{
15    call_function, initial_guess, option_f64, option_string, option_usize, value_to_real_vector,
16    vector_to_value,
17};
18use crate::builtins::math::optim::least_squares::{
19    solve_least_squares, LeastSquaresBounds, LeastSquaresEvaluator, LeastSquaresOptions,
20    ResidualFuture,
21};
22use crate::builtins::math::optim::type_resolvers::nonlinear_solve_type;
23use crate::{build_runtime_error, BuiltinResult, RuntimeError};
24
/// Builtin name; attached to errors and forwarded to the shared optim helpers.
const NAME: &str = "fsolve";
/// Fallback handed to `option_f64` when reading the `TolX` option.
const DEFAULT_TOL_X: f64 = 1e-6;
/// Fallback handed to `option_f64` when reading the `TolFun` option.
const DEFAULT_TOL_FUN: f64 = 1e-6;
/// Fallback handed to `option_usize` when reading the `MaxIter` option.
const DEFAULT_MAX_ITER: usize = 400;
29
30const FSOLVE_OUTPUT_X: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
31    name: "x",
32    ty: BuiltinParamType::NumericArray,
33    arity: BuiltinParamArity::Required,
34    default: None,
35    description: "Approximate solution vector/scalar.",
36}];
37
38const FSOLVE_INPUTS_CORE: [BuiltinParamDescriptor; 2] = [
39    BuiltinParamDescriptor {
40        name: "fun",
41        ty: BuiltinParamType::Any,
42        arity: BuiltinParamArity::Required,
43        default: None,
44        description: "System residual callback.",
45    },
46    BuiltinParamDescriptor {
47        name: "x0",
48        ty: BuiltinParamType::Any,
49        arity: BuiltinParamArity::Required,
50        default: None,
51        description: "Initial guess scalar/vector.",
52    },
53];
54
55const FSOLVE_INPUTS_WITH_OPTIONS: [BuiltinParamDescriptor; 3] = [
56    BuiltinParamDescriptor {
57        name: "fun",
58        ty: BuiltinParamType::Any,
59        arity: BuiltinParamArity::Required,
60        default: None,
61        description: "System residual callback.",
62    },
63    BuiltinParamDescriptor {
64        name: "x0",
65        ty: BuiltinParamType::Any,
66        arity: BuiltinParamArity::Required,
67        default: None,
68        description: "Initial guess scalar/vector.",
69    },
70    BuiltinParamDescriptor {
71        name: "options",
72        ty: BuiltinParamType::Any,
73        arity: BuiltinParamArity::Optional,
74        default: None,
75        description: "Options struct from optimset.",
76    },
77];
78
79const FSOLVE_SIGNATURES: [BuiltinSignatureDescriptor; 2] = [
80    BuiltinSignatureDescriptor {
81        label: "x = fsolve(fun, x0)",
82        inputs: &FSOLVE_INPUTS_CORE,
83        outputs: &FSOLVE_OUTPUT_X,
84    },
85    BuiltinSignatureDescriptor {
86        label: "x = fsolve(fun, x0, options)",
87        inputs: &FSOLVE_INPUTS_WITH_OPTIONS,
88        outputs: &FSOLVE_OUTPUT_X,
89    },
90];
91
92const FSOLVE_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
93    code: "RM.FSOLVE.INVALID_ARGUMENT",
94    identifier: Some("RunMat:fsolve:InvalidArgument"),
95    when: "Argument grammar/options configuration is invalid.",
96    message: "fsolve: invalid argument",
97};
98
99const FSOLVE_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
100    code: "RM.FSOLVE.INVALID_INPUT",
101    identifier: Some("RunMat:fsolve:InvalidInput"),
102    when: "Initial guess/callback/iteration semantics are invalid.",
103    message: "fsolve: invalid input",
104};
105
106const FSOLVE_ERRORS: [BuiltinErrorDescriptor; 2] =
107    [FSOLVE_ERROR_INVALID_ARGUMENT, FSOLVE_ERROR_INVALID_INPUT];
108
109pub const FSOLVE_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
110    signatures: &FSOLVE_SIGNATURES,
111    output_mode: BuiltinOutputMode::Fixed,
112    completion_policy: BuiltinCompletionPolicy::Public,
113    errors: &FSOLVE_ERRORS,
114};
115
116fn fsolve_error_with_detail(
117    error: &'static BuiltinErrorDescriptor,
118    detail: impl AsRef<str>,
119) -> RuntimeError {
120    let detail = detail.as_ref();
121    let message = if detail.starts_with("fsolve:") {
122        detail.to_string()
123    } else {
124        format!("{}: {detail}", error.message)
125    };
126    let mut builder = build_runtime_error(message).with_builtin(NAME);
127    if let Some(identifier) = error.identifier {
128        builder = builder.with_identifier(identifier);
129    }
130    builder.build()
131}
132
133fn fsolve_map_error(err: RuntimeError, fallback: &'static BuiltinErrorDescriptor) -> RuntimeError {
134    if err.identifier().is_some() {
135        err
136    } else {
137        fsolve_error_with_detail(fallback, err.message())
138    }
139}
140
141#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::optim::fsolve")]
142pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
143    name: "fsolve",
144    op_kind: GpuOpKind::Custom("nonlinear-solve"),
145    supported_precisions: &[],
146    broadcast: BroadcastSemantics::None,
147    provider_hooks: &[],
148    constant_strategy: ConstantStrategy::InlineLiteral,
149    residency: ResidencyPolicy::GatherImmediately,
150    nan_mode: ReductionNaN::Include,
151    two_pass_threshold: None,
152    workgroup_size: None,
153    accepts_nan_mode: false,
154    notes: "Host finite-difference Levenberg-Marquardt solver. Callback computations may use GPU-aware builtins.",
155};
156
157#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::optim::fsolve")]
158pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
159    name: "fsolve",
160    shape: ShapeRequirements::Any,
161    constant_strategy: ConstantStrategy::InlineLiteral,
162    elementwise: None,
163    reduction: None,
164    emits_nan: false,
165    notes: "Nonlinear solving repeatedly invokes user code and terminates fusion planning.",
166};
167
168#[runtime_builtin(
169    name = "fsolve",
170    category = "math/optim",
171    summary = "Solve nonlinear equation systems.",
172    keywords = "fsolve,nonlinear solve,root finding,levenberg-marquardt,jacobian",
173    accel = "sink",
174    type_resolver(nonlinear_solve_type),
175    descriptor(crate::builtins::math::optim::fsolve::FSOLVE_DESCRIPTOR),
176    builtin_path = "crate::builtins::math::optim::fsolve"
177)]
178async fn fsolve_builtin(function: Value, x0: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
179    if rest.len() > 1 {
180        return Err(fsolve_error_with_detail(
181            &FSOLVE_ERROR_INVALID_ARGUMENT,
182            "too many input arguments",
183        ));
184    }
185    let options = parse_options(rest.first())
186        .map_err(|err| fsolve_map_error(err, &FSOLVE_ERROR_INVALID_ARGUMENT))?;
187    let opts = FsolveOptions::from_struct(options.as_ref())
188        .map_err(|err| fsolve_map_error(err, &FSOLVE_ERROR_INVALID_ARGUMENT))?;
189    let guess = initial_guess(NAME, x0)
190        .await
191        .map_err(|err| fsolve_map_error(err, &FSOLVE_ERROR_INVALID_INPUT))?;
192    let solution = solve(&function, guess.values, &guess.shape, guess.scalar, &opts)
193        .await
194        .map_err(|err| fsolve_map_error(err, &FSOLVE_ERROR_INVALID_INPUT))?;
195    vector_to_value(NAME, solution, &guess.shape, guess.scalar)
196        .map_err(|err| fsolve_map_error(err, &FSOLVE_ERROR_INVALID_INPUT))
197}
198
199fn parse_options(value: Option<&Value>) -> BuiltinResult<Option<StructValue>> {
200    match value {
201        None => Ok(None),
202        Some(Value::Struct(options)) => Ok(Some(options.clone())),
203        Some(other) => Err(fsolve_error_with_detail(
204            &FSOLVE_ERROR_INVALID_ARGUMENT,
205            format!("options must be a struct, got {other:?}"),
206        )),
207    }
208}
209
/// Solver settings extracted from the user's options struct.
#[derive(Clone, Copy)]
struct FsolveOptions {
    /// Value of the `TolX` option, forwarded to the least-squares solver.
    tol_x: f64,
    /// Value of the `TolFun` option; also the threshold for accepting the final residual.
    tol_fun: f64,
    /// Value of the `MaxIter` option.
    max_iter: usize,
    /// Value of the `MaxFunEvals` option.
    max_fun_evals: usize,
}
217
218impl FsolveOptions {
219    fn from_struct(options: Option<&StructValue>) -> BuiltinResult<Self> {
220        let display = option_string(options, "Display", "off")?;
221        if !matches!(display.as_str(), "off" | "none" | "final" | "iter") {
222            return Err(fsolve_error_with_detail(
223                &FSOLVE_ERROR_INVALID_ARGUMENT,
224                "option Display must be 'off', 'none', 'final', or 'iter'",
225            ));
226        }
227        let tol_x = option_f64(NAME, options, "TolX", DEFAULT_TOL_X)?;
228        let tol_fun = option_f64(NAME, options, "TolFun", DEFAULT_TOL_FUN)?;
229        if tol_x <= 0.0 || tol_fun <= 0.0 {
230            return Err(fsolve_error_with_detail(
231                &FSOLVE_ERROR_INVALID_ARGUMENT,
232                "options TolX and TolFun must be positive",
233            ));
234        }
235        let max_iter = option_usize(NAME, options, "MaxIter", DEFAULT_MAX_ITER)?.max(1);
236        let max_fun_evals = option_usize(NAME, options, "MaxFunEvals", 100 * max_iter)?.max(1);
237        Ok(Self {
238            tol_x,
239            tol_fun,
240            max_iter,
241            max_fun_evals,
242        })
243    }
244}
245
246async fn solve(
247    function: &Value,
248    x: Vec<f64>,
249    shape: &[usize],
250    scalar: bool,
251    options: &FsolveOptions,
252) -> BuiltinResult<Vec<f64>> {
253    let mut evaluator = FsolveEvaluator {
254        function,
255        shape: shape.to_vec(),
256        scalar,
257    };
258    let variable_len = x.len();
259    let result = solve_least_squares(
260        NAME,
261        &mut evaluator,
262        x,
263        &LeastSquaresBounds::unbounded(variable_len),
264        &LeastSquaresOptions {
265            tol_x: options.tol_x,
266            tol_fun: options.tol_fun,
267            max_iter: options.max_iter,
268            max_fun_evals: options.max_fun_evals,
269            final_jacobian: false,
270        },
271    )
272    .await?;
273    if result.exitflag > 0
274        && result
275            .residual
276            .iter()
277            .fold(0.0_f64, |acc, value| acc.max(value.abs()))
278            <= options.tol_fun
279    {
280        Ok(result.x)
281    } else {
282        Err(fsolve_error_with_detail(
283            &FSOLVE_ERROR_INVALID_INPUT,
284            format!(
285                "{} Final residual norm is above function tolerance.",
286                result.message
287            ),
288        ))
289    }
290}
291
292struct FsolveEvaluator<'a> {
293    function: &'a Value,
294    shape: Vec<usize>,
295    scalar: bool,
296}
297
298impl LeastSquaresEvaluator for FsolveEvaluator<'_> {
299    fn residual<'a>(&'a mut self, x: &'a [f64]) -> ResidualFuture<'a> {
300        Box::pin(async move {
301            let arg = if self.scalar {
302                Value::Num(x[0])
303            } else {
304                Value::Tensor(
305                    runmat_builtins::Tensor::new(x.to_vec(), self.shape.clone())
306                        .map_err(|e| fsolve_error_with_detail(&FSOLVE_ERROR_INVALID_INPUT, e))?,
307                )
308            };
309            let value = call_function(self.function, vec![arg]).await?;
310            let residual = value_to_real_vector(NAME, value).await?;
311            if residual.is_empty() {
312                Err(fsolve_error_with_detail(
313                    &FSOLVE_ERROR_INVALID_INPUT,
314                    "function value must not be empty",
315                ))
316            } else {
317                Ok(residual)
318            }
319        })
320    }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;
    use std::sync::{Arc, Mutex};

    /// Drives `fsolve_builtin` to completion on the current thread.
    fn run_fsolve(function: Value, x0: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(fsolve_builtin(function, x0, rest))
    }

    /// Returns the number a callback was invoked with, panicking on any other value kind.
    fn scalar_argument(value: &Value) -> f64 {
        match value {
            Value::Num(number) => *number,
            other => panic!("expected scalar numeric argument, got {other:?}"),
        }
    }

    /// `sin` has a root at pi; starting from 3 the solver should land on it.
    #[test]
    fn fsolve_scalar_builtin_handle() {
        let root = run_fsolve(
            Value::FunctionHandle("sin".into()),
            Value::Num(3.0),
            Vec::new(),
        )
        .unwrap();
        match root {
            Value::Num(n) => assert!((n - std::f64::consts::PI).abs() < 1.0e-5),
            other => panic!("unexpected value {other:?}"),
        }
    }

    /// `x^2 + 1` never reaches zero, so the solver must report failure instead of a root.
    #[test]
    fn fsolve_rejects_stationary_non_root_residual() {
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            |_function, args, _requested_outputs| {
                let x = scalar_argument(&args[0]);
                Box::pin(async move { Ok(Value::Num(x * x + 1.0)) })
            },
        )));
        let outcome = run_fsolve(
            Value::BoundFunctionHandle {
                name: "no_real_root".to_string(),
                function: 44,
            },
            Value::Num(0.0),
            Vec::new(),
        );
        let err = outcome.unwrap_err();
        assert_eq!(err.identifier(), Some("RunMat:fsolve:InvalidInput"));
        assert!(err.message().contains("Final residual norm"));
    }

    /// Two-equation system resolved by name through the semantic resolver.
    #[test]
    fn fsolve_vector_system_via_semantic_resolver() {
        let _resolver =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|_name| {
                Some(0)
            })));
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            |_function, args, _requested_outputs| {
                let point = match &args[0] {
                    Value::Tensor(t) => t.data.clone(),
                    _ => panic!("expected tensor input"),
                };
                Box::pin(async move {
                    let circle = point[0] * point[0] + point[1] * point[1] - 4.0;
                    let hyperbola = point[0] * point[1] - 1.0;
                    Ok(Value::Tensor(
                        Tensor::new(vec![circle, hyperbola], vec![2, 1]).unwrap(),
                    ))
                })
            },
        )));
        let start = Tensor::new(vec![1.0, 1.0], vec![2, 1]).unwrap();
        let root = run_fsolve(
            Value::FunctionHandle("system".into()),
            Value::Tensor(start),
            Vec::new(),
        )
        .unwrap();
        match root {
            Value::Tensor(t) => {
                assert!((t.data[0] * t.data[0] + t.data[1] * t.data[1] - 4.0).abs() < 1.0e-5);
                assert!((t.data[0] * t.data[1] - 1.0).abs() < 1.0e-5);
            }
            other => panic!("unexpected value {other:?}"),
        }
    }

    /// A 1x2 guess must reach the callback as 1x2 and come back as 1x2.
    #[test]
    fn fsolve_preserves_row_vector_shape_for_callback() {
        let seen_shapes = Arc::new(Mutex::new(Vec::new()));
        let recorder = Arc::clone(&seen_shapes);
        let _resolver =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|_name| {
                Some(0)
            })));
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            move |_function, args, _requested_outputs| {
                let (x, shape) = match &args[0] {
                    Value::Tensor(t) => (t.data.clone(), t.shape.clone()),
                    other => panic!("expected tensor input, got {other:?}"),
                };
                assert_eq!(shape, vec![1, 2]);
                recorder.lock().unwrap().push(shape.clone());
                Box::pin(async move {
                    let residual = vec![x[0] - 3.0, x[1] - 4.0];
                    Ok(Value::Tensor(Tensor::new(residual, shape).unwrap()))
                })
            },
        )));
        let start = Tensor::new(vec![0.0, 0.0], vec![1, 2]).unwrap();
        let root = run_fsolve(
            Value::FunctionHandle("row_system".into()),
            Value::Tensor(start),
            Vec::new(),
        )
        .unwrap();
        match root {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 2]);
                assert!((t.data[0] - 3.0).abs() < 1.0e-5);
                assert!((t.data[1] - 4.0).abs() < 1.0e-5);
            }
            other => panic!("unexpected value {other:?}"),
        }
        assert!(!seen_shapes.lock().unwrap().is_empty());
    }

    /// A 2x2 guess must reach the callback as 2x2 and come back as 2x2.
    #[test]
    fn fsolve_preserves_matrix_shape_for_callback() {
        let seen_shapes = Arc::new(Mutex::new(Vec::new()));
        let recorder = Arc::clone(&seen_shapes);
        let _resolver =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|_name| {
                Some(0)
            })));
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            move |_function, args, _requested_outputs| {
                let (x, shape) = match &args[0] {
                    Value::Tensor(t) => (t.data.clone(), t.shape.clone()),
                    other => panic!("expected tensor input, got {other:?}"),
                };
                assert_eq!(shape, vec![2, 2]);
                recorder.lock().unwrap().push(shape.clone());
                Box::pin(async move {
                    let residual = vec![x[0] - 1.0, x[1] - 2.0, x[2] - 3.0, x[3] - 4.0];
                    Ok(Value::Tensor(Tensor::new(residual, shape).unwrap()))
                })
            },
        )));
        let start = Tensor::new(vec![0.0, 0.0, 0.0, 0.0], vec![2, 2]).unwrap();
        let root = run_fsolve(
            Value::FunctionHandle("matrix_system".into()),
            Value::Tensor(start),
            Vec::new(),
        )
        .unwrap();
        match root {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 2]);
                assert!((t.data[0] - 1.0).abs() < 1.0e-5);
                assert!((t.data[1] - 2.0).abs() < 1.0e-5);
                assert!((t.data[2] - 3.0).abs() < 1.0e-5);
                assert!((t.data[3] - 4.0).abs() < 1.0e-5);
            }
            other => panic!("unexpected value {other:?}"),
        }
        assert!(!seen_shapes.lock().unwrap().is_empty());
    }

    /// A bound handle is invoked with its function id and exactly one requested output.
    #[test]
    fn fsolve_accepts_semantic_function_handle_callback() {
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            |function, args, requested_outputs| {
                assert_eq!(function, 43);
                assert_eq!(requested_outputs, 1);
                let x = scalar_argument(&args[0]);
                Box::pin(async move { Ok(Value::Num(x - 3.0)) })
            },
        )));
        let root = run_fsolve(
            Value::BoundFunctionHandle {
                name: "system_function".to_string(),
                function: 43,
            },
            Value::Num(1.0),
            Vec::new(),
        )
        .unwrap();
        match root {
            Value::Num(n) => assert!((n - 3.0).abs() < 1.0e-5),
            other => panic!("unexpected value {other:?}"),
        }
    }

    /// The published descriptor lists both call forms and both error codes, in order.
    #[test]
    fn fsolve_descriptor_signatures_cover_core_forms() {
        let mut labels: Vec<&str> = Vec::new();
        for signature in FSOLVE_DESCRIPTOR.signatures.iter() {
            labels.push(signature.label);
        }
        assert_eq!(
            labels,
            vec!["x = fsolve(fun, x0)", "x = fsolve(fun, x0, options)"]
        );

        let mut codes: Vec<&str> = Vec::new();
        for error in FSOLVE_DESCRIPTOR.errors.iter() {
            codes.push(error.code);
        }
        assert_eq!(
            codes,
            vec!["RM.FSOLVE.INVALID_ARGUMENT", "RM.FSOLVE.INVALID_INPUT"]
        );
    }

    /// Passing two trailing arguments is rejected with the invalid-argument identifier.
    #[test]
    fn fsolve_too_many_args_uses_stable_identifier() {
        let extra_arguments = vec![
            Value::Struct(StructValue::new()),
            Value::Struct(StructValue::new()),
        ];
        let err = run_fsolve(
            Value::FunctionHandle("sin".into()),
            Value::Num(1.0),
            extra_arguments,
        )
        .unwrap_err();
        assert_eq!(err.identifier(), Some("RunMat:fsolve:InvalidArgument"));
    }
}