Skip to main content

runmat_runtime/builtins/math/optim/
fzero.rs

1//! MATLAB-compatible `fzero` builtin for scalar nonlinear root finding.
2
3use runmat_builtins::{
4    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6};
7use runmat_builtins::{StructValue, Value};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::spec::{
11    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
12    ReductionNaN, ResidencyPolicy, ShapeRequirements,
13};
14use crate::builtins::math::optim::brent::{brent_zero, BrentParams, BrentZeroBracket};
15use crate::builtins::math::optim::common::{
16    call_scalar_function, option_f64, option_string, option_usize,
17};
18use crate::builtins::math::optim::type_resolvers::scalar_root_type;
19use crate::{build_runtime_error, BuiltinResult, RuntimeError};
20
/// Builtin name passed to error builders, callback helpers and the Brent solver.
const NAME: &str = "fzero";
/// Default `TolX` used when the options struct does not provide one.
/// NOTE(review): MATLAB documents `eps` as fzero's default TolX — confirm the looser
/// 1e-6 here is intentional.
const DEFAULT_TOL_X: f64 = 1e-6;
/// Default `MaxIter` cap forwarded to the Brent solver.
const DEFAULT_MAX_ITER: usize = 400;
/// Default `MaxFunEvals` budget used by the bracket search and forwarded to the solver.
const DEFAULT_MAX_FUN_EVALS: usize = 500;
25
26const FZERO_OUTPUT_ROOT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
27    name: "x",
28    ty: BuiltinParamType::NumericScalar,
29    arity: BuiltinParamArity::Required,
30    default: None,
31    description: "Estimated root location.",
32}];
33
34const FZERO_INPUTS_CORE: [BuiltinParamDescriptor; 2] = [
35    BuiltinParamDescriptor {
36        name: "fun",
37        ty: BuiltinParamType::Any,
38        arity: BuiltinParamArity::Required,
39        default: None,
40        description: "Scalar-valued callback.",
41    },
42    BuiltinParamDescriptor {
43        name: "x0",
44        ty: BuiltinParamType::Any,
45        arity: BuiltinParamArity::Required,
46        default: None,
47        description: "Initial point or two-element bracket.",
48    },
49];
50
51const FZERO_INPUTS_WITH_OPTIONS: [BuiltinParamDescriptor; 3] = [
52    BuiltinParamDescriptor {
53        name: "fun",
54        ty: BuiltinParamType::Any,
55        arity: BuiltinParamArity::Required,
56        default: None,
57        description: "Scalar-valued callback.",
58    },
59    BuiltinParamDescriptor {
60        name: "x0",
61        ty: BuiltinParamType::Any,
62        arity: BuiltinParamArity::Required,
63        default: None,
64        description: "Initial point or two-element bracket.",
65    },
66    BuiltinParamDescriptor {
67        name: "options",
68        ty: BuiltinParamType::Any,
69        arity: BuiltinParamArity::Optional,
70        default: None,
71        description: "Options struct from optimset.",
72    },
73];
74
75const FZERO_SIGNATURES: [BuiltinSignatureDescriptor; 2] = [
76    BuiltinSignatureDescriptor {
77        label: "x = fzero(fun, x0)",
78        inputs: &FZERO_INPUTS_CORE,
79        outputs: &FZERO_OUTPUT_ROOT,
80    },
81    BuiltinSignatureDescriptor {
82        label: "x = fzero(fun, x0, options)",
83        inputs: &FZERO_INPUTS_WITH_OPTIONS,
84        outputs: &FZERO_OUTPUT_ROOT,
85    },
86];
87
88const FZERO_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
89    code: "RM.FZERO.INVALID_ARGUMENT",
90    identifier: Some("RunMat:fzero:InvalidArgument"),
91    when: "Argument grammar/options struct are invalid.",
92    message: "fzero: invalid argument",
93};
94
95const FZERO_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
96    code: "RM.FZERO.INVALID_INPUT",
97    identifier: Some("RunMat:fzero:InvalidInput"),
98    when: "Callback/bracket/initial-point semantics are invalid.",
99    message: "fzero: invalid input",
100};
101
102const FZERO_ERRORS: [BuiltinErrorDescriptor; 2] =
103    [FZERO_ERROR_INVALID_ARGUMENT, FZERO_ERROR_INVALID_INPUT];
104
105pub const FZERO_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
106    signatures: &FZERO_SIGNATURES,
107    output_mode: BuiltinOutputMode::Fixed,
108    completion_policy: BuiltinCompletionPolicy::Public,
109    errors: &FZERO_ERRORS,
110};
111
112fn fzero_error_with_detail(
113    error: &'static BuiltinErrorDescriptor,
114    detail: impl AsRef<str>,
115) -> RuntimeError {
116    let detail = detail.as_ref();
117    let message = if detail.starts_with("fzero:") {
118        detail.to_string()
119    } else {
120        format!("{}: {detail}", error.message)
121    };
122    let mut builder = build_runtime_error(message).with_builtin(NAME);
123    if let Some(identifier) = error.identifier {
124        builder = builder.with_identifier(identifier);
125    }
126    builder.build()
127}
128
129fn fzero_map_error(err: RuntimeError, fallback: &'static BuiltinErrorDescriptor) -> RuntimeError {
130    if err.identifier().is_some() {
131        err
132    } else {
133        fzero_error_with_detail(fallback, err.message())
134    }
135}
136
137#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::optim::fzero")]
138pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
139    name: "fzero",
140    op_kind: GpuOpKind::Custom("scalar-root-find"),
141    supported_precisions: &[],
142    broadcast: BroadcastSemantics::None,
143    provider_hooks: &[],
144    constant_strategy: ConstantStrategy::InlineLiteral,
145    residency: ResidencyPolicy::GatherImmediately,
146    nan_mode: ReductionNaN::Include,
147    two_pass_threshold: None,
148    workgroup_size: None,
149    accepts_nan_mode: false,
150    notes: "Host iterative solver. Callback values may use GPU-aware builtins, but the root search runs on the CPU.",
151};
152
153#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::optim::fzero")]
154pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
155    name: "fzero",
156    shape: ShapeRequirements::Any,
157    constant_strategy: ConstantStrategy::InlineLiteral,
158    elementwise: None,
159    reduction: None,
160    emits_nan: false,
161    notes: "Root finding repeatedly invokes user code and terminates fusion planning.",
162};
163
164#[runtime_builtin(
165    name = "fzero",
166    category = "math/optim",
167    summary = "Find scalar function zeros with bracketed root-finding.",
168    keywords = "fzero,root finding,zero,brent,optimization",
169    accel = "sink",
170    type_resolver(scalar_root_type),
171    descriptor(crate::builtins::math::optim::fzero::FZERO_DESCRIPTOR),
172    builtin_path = "crate::builtins::math::optim::fzero"
173)]
174async fn fzero_builtin(function: Value, x: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
175    if rest.len() > 1 {
176        return Err(fzero_error_with_detail(
177            &FZERO_ERROR_INVALID_ARGUMENT,
178            "too many input arguments",
179        ));
180    }
181    let options = parse_options(rest.first())
182        .map_err(|err| fzero_map_error(err, &FZERO_ERROR_INVALID_ARGUMENT))?;
183    let opts = FzeroOptions::from_struct(options.as_ref())
184        .map_err(|err| fzero_map_error(err, &FZERO_ERROR_INVALID_ARGUMENT))?;
185    let bracket = initial_bracket(&function, x, &opts)
186        .await
187        .map_err(|err| fzero_map_error(err, &FZERO_ERROR_INVALID_INPUT))?;
188    let root = brent_zero(
189        NAME,
190        &function,
191        BrentZeroBracket {
192            a: bracket.a,
193            b: bracket.b,
194            fa: bracket.fa,
195            fb: bracket.fb,
196            evals: bracket.evals,
197        },
198        BrentParams {
199            tol_x: opts.tol_x,
200            max_iter: opts.max_iter,
201            max_fun_evals: opts.max_fun_evals,
202        },
203    )
204    .await
205    .map_err(|err| fzero_map_error(err, &FZERO_ERROR_INVALID_INPUT))?;
206    Ok(Value::Num(root))
207}
208
209fn parse_options(value: Option<&Value>) -> BuiltinResult<Option<StructValue>> {
210    match value {
211        None => Ok(None),
212        Some(Value::Struct(options)) => Ok(Some(options.clone())),
213        Some(other) => Err(fzero_error_with_detail(
214            &FZERO_ERROR_INVALID_ARGUMENT,
215            format!("options must be a struct, got {other:?}"),
216        )),
217    }
218}
219
220#[derive(Clone, Copy)]
221struct FzeroOptions {
222    tol_x: f64,
223    max_iter: usize,
224    max_fun_evals: usize,
225}
226
227impl FzeroOptions {
228    fn from_struct(options: Option<&StructValue>) -> BuiltinResult<Self> {
229        let display = option_string(options, "Display", "off")?;
230        if !matches!(display.as_str(), "off" | "none" | "final" | "iter") {
231            return Err(fzero_error_with_detail(
232                &FZERO_ERROR_INVALID_ARGUMENT,
233                "option Display must be 'off', 'none', 'final', or 'iter'",
234            ));
235        }
236        let tol_x = option_f64(NAME, options, "TolX", DEFAULT_TOL_X)?;
237        if tol_x <= 0.0 {
238            return Err(fzero_error_with_detail(
239                &FZERO_ERROR_INVALID_ARGUMENT,
240                "option TolX must be positive",
241            ));
242        }
243        let max_iter = option_usize(NAME, options, "MaxIter", DEFAULT_MAX_ITER)?;
244        let max_fun_evals = option_usize(NAME, options, "MaxFunEvals", DEFAULT_MAX_FUN_EVALS)?;
245        Ok(Self {
246            tol_x,
247            max_iter: max_iter.max(1),
248            max_fun_evals: max_fun_evals.max(1),
249        })
250    }
251}
252
/// A starting interval for the Brent solver, together with the callback values already
/// computed at its endpoints.
///
/// The bracketing helpers return a degenerate bracket (`a == b`, `fa == fb == 0.0`)
/// when they land on an exact root.
#[derive(Clone, Copy)]
struct Bracket {
    /// First endpoint. Not necessarily the smaller one when supplied by the user.
    a: f64,
    /// Callback value at `a`.
    fa: f64,
    /// Second endpoint.
    b: f64,
    /// Callback value at `b`.
    fb: f64,
    /// Number of callback evaluations spent while building this bracket.
    evals: usize,
}
261
262async fn initial_bracket(
263    function: &Value,
264    x: Value,
265    options: &FzeroOptions,
266) -> BuiltinResult<Bracket> {
267    let x = crate::dispatcher::gather_if_needed_async(&x).await?;
268    match x {
269        Value::Tensor(tensor) if tensor.data.len() == 2 => {
270            let a = tensor.data[0];
271            let b = tensor.data[1];
272            bracket_from_endpoints(function, a, b).await
273        }
274        Value::Tensor(tensor) if tensor.data.len() == 1 => {
275            expand_bracket(function, tensor.data[0], options).await
276        }
277        Value::Num(n) => expand_bracket(function, n, options).await,
278        Value::Int(i) => expand_bracket(function, i.to_f64(), options).await,
279        Value::Bool(b) => expand_bracket(function, if b { 1.0 } else { 0.0 }, options).await,
280        other => Err(fzero_error_with_detail(
281            &FZERO_ERROR_INVALID_INPUT,
282            format!("initial point must be a scalar or two-element bracket, got {other:?}"),
283        )),
284    }
285}
286
287async fn bracket_from_endpoints(function: &Value, a: f64, b: f64) -> BuiltinResult<Bracket> {
288    if !a.is_finite() || !b.is_finite() || a == b {
289        return Err(fzero_error_with_detail(
290            &FZERO_ERROR_INVALID_INPUT,
291            "bracket endpoints must be finite and distinct",
292        ));
293    }
294    let fa = call_scalar_function(NAME, function, a).await?;
295    if fa == 0.0 {
296        return Ok(Bracket {
297            a,
298            b: a,
299            fa,
300            fb: fa,
301            evals: 1,
302        });
303    }
304    let fb = call_scalar_function(NAME, function, b).await?;
305    if fb == 0.0 || fa.signum() != fb.signum() {
306        Ok(Bracket {
307            a,
308            b,
309            fa,
310            fb,
311            evals: 2,
312        })
313    } else {
314        Err(fzero_error_with_detail(
315            &FZERO_ERROR_INVALID_INPUT,
316            "function values at bracket endpoints must differ in sign",
317        ))
318    }
319}
320
321async fn expand_bracket(
322    function: &Value,
323    x0: f64,
324    options: &FzeroOptions,
325) -> BuiltinResult<Bracket> {
326    if !x0.is_finite() {
327        return Err(fzero_error_with_detail(
328            &FZERO_ERROR_INVALID_INPUT,
329            "initial point must be finite",
330        ));
331    }
332    let f0 = call_scalar_function(NAME, function, x0).await?;
333    if f0 == 0.0 {
334        return Ok(Bracket {
335            a: x0,
336            b: x0,
337            fa: f0,
338            fb: f0,
339            evals: 1,
340        });
341    }
342
343    let mut evals = 1usize;
344    let mut step = (x0.abs() * 0.01).max(0.01);
345    while evals + 2 <= options.max_fun_evals {
346        let a = x0 - step;
347        let b = x0 + step;
348        let fa = call_scalar_function(NAME, function, a).await?;
349        let fb = call_scalar_function(NAME, function, b).await?;
350        evals += 2;
351        if fa == 0.0 {
352            return Ok(Bracket {
353                a,
354                b: a,
355                fa,
356                fb: fa,
357                evals,
358            });
359        }
360        if fa.signum() != f0.signum() {
361            return Ok(Bracket {
362                a,
363                b: x0,
364                fa,
365                fb: f0,
366                evals,
367            });
368        }
369        if fb.signum() != f0.signum() {
370            return Ok(Bracket {
371                a: x0,
372                b,
373                fa: f0,
374                fb,
375                evals,
376            });
377        }
378        if fb == 0.0 || fa.signum() != fb.signum() {
379            return Ok(Bracket {
380                a,
381                b,
382                fa,
383                fb,
384                evals,
385            });
386        }
387        step *= 1.6;
388    }
389
390    Err(fzero_error_with_detail(
391        &FZERO_ERROR_INVALID_INPUT,
392        "could not find a sign-changing bracket around the initial point",
393    ))
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::math::optim::brent::interpolation_step_accepted;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;
    use std::sync::Arc;

    /// Runs `fzero(function, x0)` to completion and returns the numeric root.
    fn solve(function: Value, x0: Value) -> f64 {
        match block_on(fzero_builtin(function, x0, Vec::new())).unwrap() {
            Value::Num(n) => n,
            other => panic!("unexpected value {other:?}"),
        }
    }

    /// Runs `fzero(function, x0, rest...)` and returns the error it is expected to raise.
    fn solve_err(function: Value, x0: Value, rest: Vec<Value>) -> RuntimeError {
        block_on(fzero_builtin(function, x0, rest)).unwrap_err()
    }

    #[test]
    fn fzero_bracketed_builtin_handle() {
        let bracket = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let root = solve(Value::FunctionHandle("sin".into()), Value::Tensor(bracket));
        assert!((root - std::f64::consts::PI).abs() < 1.0e-6);
    }

    #[test]
    fn fzero_scalar_initial_guess_expands_bracket() {
        let root = solve(Value::FunctionHandle("cos".into()), Value::Num(1.0));
        assert!((root - std::f64::consts::FRAC_PI_2).abs() < 1.0e-6);
    }

    #[test]
    fn fzero_scalar_initial_guess_uses_center_sign_for_bracket() {
        let root = solve(
            Value::FunctionHandle("sin".into()),
            Value::Num(std::f64::consts::FRAC_PI_2),
        );
        assert!(root.abs() < 1.0e-6);
    }

    #[test]
    fn fzero_accepts_semantic_function_handle_callback() {
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            |function, args, requested_outputs| {
                assert_eq!(function, 42);
                assert_eq!(requested_outputs, 1);
                let x = match &args[0] {
                    Value::Num(value) => *value,
                    other => panic!("expected scalar numeric argument, got {other:?}"),
                };
                Box::pin(async move { Ok(Value::Num(x - 2.0)) })
            },
        )));

        let root = solve(
            Value::BoundFunctionHandle {
                name: "root_function".to_string(),
                function: 42,
            },
            Value::Num(0.0),
        );
        assert!((root - 2.0).abs() < 1.0e-6);
    }

    #[test]
    fn brent_interpolation_acceptance_uses_signed_q() {
        assert!(!interpolation_step_accepted(1.0, -2.0, 1.0, 0.1, 10.0));
        assert!(interpolation_step_accepted(1.0, -2.0, -1.0, 0.1, 10.0));
    }

    #[test]
    fn fzero_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = FZERO_DESCRIPTOR
            .signatures
            .iter()
            .map(|signature| signature.label)
            .collect();
        assert_eq!(
            labels,
            vec!["x = fzero(fun, x0)", "x = fzero(fun, x0, options)"]
        );

        let codes: Vec<&str> = FZERO_DESCRIPTOR
            .errors
            .iter()
            .map(|error| error.code)
            .collect();
        assert_eq!(
            codes,
            vec!["RM.FZERO.INVALID_ARGUMENT", "RM.FZERO.INVALID_INPUT"]
        );
    }

    #[test]
    fn fzero_too_many_args_uses_stable_identifier() {
        let err = solve_err(
            Value::FunctionHandle("sin".into()),
            Value::Num(0.0),
            vec![
                Value::Struct(StructValue::new()),
                Value::Struct(StructValue::new()),
            ],
        );
        assert_eq!(err.identifier(), Some("RunMat:fzero:InvalidArgument"));
    }

    #[test]
    fn fzero_bracket_without_sign_change_is_invalid_input() {
        // sin is positive on both 1 and 2, so [1, 2] does not bracket a root.
        let bracket = Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap();
        let err = solve_err(
            Value::FunctionHandle("sin".into()),
            Value::Tensor(bracket),
            Vec::new(),
        );
        assert_eq!(err.identifier(), Some("RunMat:fzero:InvalidInput"));
    }

    #[test]
    fn fzero_non_struct_options_is_invalid_argument() {
        let err = solve_err(
            Value::FunctionHandle("sin".into()),
            Value::Num(3.0),
            vec![Value::Num(1.0)],
        );
        assert_eq!(err.identifier(), Some("RunMat:fzero:InvalidArgument"));
    }
}