Skip to main content

runmat_runtime/builtins/math/optim/
fzero.rs

1//! MATLAB-compatible `fzero` builtin for scalar nonlinear root finding.
2
3use runmat_builtins::{StructValue, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::spec::{
7    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8    ReductionNaN, ResidencyPolicy, ShapeRequirements,
9};
10use crate::builtins::math::optim::brent::{brent_zero, BrentParams, BrentZeroBracket};
11use crate::builtins::math::optim::common::{
12    call_scalar_function, optim_error, option_f64, option_string, option_usize,
13};
14use crate::builtins::math::optim::type_resolvers::scalar_root_type;
15use crate::BuiltinResult;
16
/// Builtin name, used as the error-reporting prefix passed to the shared optim helpers.
const NAME: &str = "fzero";
/// Fallback value handed to `option_f64` for the `TolX` option.
const DEFAULT_TOL_X: f64 = 1e-6;
/// Fallback value handed to `option_usize` for the `MaxIter` option.
const DEFAULT_MAX_ITER: usize = 400;
/// Fallback value handed to `option_usize` for the `MaxFunEvals` option.
const DEFAULT_MAX_FUN_EVALS: usize = 500;
21
22#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::optim::fzero")]
23pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
24    name: "fzero",
25    op_kind: GpuOpKind::Custom("scalar-root-find"),
26    supported_precisions: &[],
27    broadcast: BroadcastSemantics::None,
28    provider_hooks: &[],
29    constant_strategy: ConstantStrategy::InlineLiteral,
30    residency: ResidencyPolicy::GatherImmediately,
31    nan_mode: ReductionNaN::Include,
32    two_pass_threshold: None,
33    workgroup_size: None,
34    accepts_nan_mode: false,
35    notes: "Host iterative solver. Callback values may use GPU-aware builtins, but the root search runs on the CPU.",
36};
37
38#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::optim::fzero")]
39pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
40    name: "fzero",
41    shape: ShapeRequirements::Any,
42    constant_strategy: ConstantStrategy::InlineLiteral,
43    elementwise: None,
44    reduction: None,
45    emits_nan: false,
46    notes: "Root finding repeatedly invokes user code and terminates fusion planning.",
47};
48
49#[runtime_builtin(
50    name = "fzero",
51    category = "math/optim",
52    summary = "Find a zero of a scalar nonlinear function using bracket expansion and Brent's method.",
53    keywords = "fzero,root finding,zero,brent,optimization",
54    accel = "sink",
55    type_resolver(scalar_root_type),
56    builtin_path = "crate::builtins::math::optim::fzero"
57)]
58async fn fzero_builtin(function: Value, x: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
59    if rest.len() > 1 {
60        return Err(optim_error(NAME, "fzero: too many input arguments"));
61    }
62    let options = parse_options(rest.first())?;
63    let opts = FzeroOptions::from_struct(options.as_ref())?;
64    let bracket = initial_bracket(&function, x, &opts).await?;
65    let root = brent_zero(
66        NAME,
67        &function,
68        BrentZeroBracket {
69            a: bracket.a,
70            b: bracket.b,
71            fa: bracket.fa,
72            fb: bracket.fb,
73            evals: bracket.evals,
74        },
75        BrentParams {
76            tol_x: opts.tol_x,
77            max_iter: opts.max_iter,
78            max_fun_evals: opts.max_fun_evals,
79        },
80    )
81    .await?;
82    Ok(Value::Num(root))
83}
84
85fn parse_options(value: Option<&Value>) -> BuiltinResult<Option<StructValue>> {
86    match value {
87        None => Ok(None),
88        Some(Value::Struct(options)) => Ok(Some(options.clone())),
89        Some(other) => Err(optim_error(
90            NAME,
91            format!("fzero: options must be a struct, got {other:?}"),
92        )),
93    }
94}
95
96#[derive(Clone, Copy)]
97struct FzeroOptions {
98    tol_x: f64,
99    max_iter: usize,
100    max_fun_evals: usize,
101}
102
103impl FzeroOptions {
104    fn from_struct(options: Option<&StructValue>) -> BuiltinResult<Self> {
105        let display = option_string(options, "Display", "off")?;
106        if !matches!(display.as_str(), "off" | "none" | "final" | "iter") {
107            return Err(optim_error(
108                NAME,
109                "fzero: option Display must be 'off', 'none', 'final', or 'iter'",
110            ));
111        }
112        let tol_x = option_f64(NAME, options, "TolX", DEFAULT_TOL_X)?;
113        if tol_x <= 0.0 {
114            return Err(optim_error(NAME, "fzero: option TolX must be positive"));
115        }
116        let max_iter = option_usize(NAME, options, "MaxIter", DEFAULT_MAX_ITER)?;
117        let max_fun_evals = option_usize(NAME, options, "MaxFunEvals", DEFAULT_MAX_FUN_EVALS)?;
118        Ok(Self {
119            tol_x,
120            max_iter: max_iter.max(1),
121            max_fun_evals: max_fun_evals.max(1),
122        })
123    }
124}
125
/// An interval handed to Brent's method, with the function values at both ends.
///
/// When a root has been hit exactly, `a == b` and `fa == fb` holds that zero.
#[derive(Clone, Copy)]
struct Bracket {
    /// Left endpoint.
    a: f64,
    /// Function value at `a`.
    fa: f64,
    /// Right endpoint.
    b: f64,
    /// Function value at `b`.
    fb: f64,
    /// Number of function evaluations spent building this bracket.
    evals: usize,
}
134
135async fn initial_bracket(
136    function: &Value,
137    x: Value,
138    options: &FzeroOptions,
139) -> BuiltinResult<Bracket> {
140    let x = crate::dispatcher::gather_if_needed_async(&x).await?;
141    match x {
142        Value::Tensor(tensor) if tensor.data.len() == 2 => {
143            let a = tensor.data[0];
144            let b = tensor.data[1];
145            bracket_from_endpoints(function, a, b).await
146        }
147        Value::Tensor(tensor) if tensor.data.len() == 1 => {
148            expand_bracket(function, tensor.data[0], options).await
149        }
150        Value::Num(n) => expand_bracket(function, n, options).await,
151        Value::Int(i) => expand_bracket(function, i.to_f64(), options).await,
152        Value::Bool(b) => expand_bracket(function, if b { 1.0 } else { 0.0 }, options).await,
153        other => Err(optim_error(
154            NAME,
155            format!("fzero: initial point must be a scalar or two-element bracket, got {other:?}"),
156        )),
157    }
158}
159
160async fn bracket_from_endpoints(function: &Value, a: f64, b: f64) -> BuiltinResult<Bracket> {
161    if !a.is_finite() || !b.is_finite() || a == b {
162        return Err(optim_error(
163            NAME,
164            "fzero: bracket endpoints must be finite and distinct",
165        ));
166    }
167    let fa = call_scalar_function(NAME, function, a).await?;
168    if fa == 0.0 {
169        return Ok(Bracket {
170            a,
171            b: a,
172            fa,
173            fb: fa,
174            evals: 1,
175        });
176    }
177    let fb = call_scalar_function(NAME, function, b).await?;
178    if fb == 0.0 || fa.signum() != fb.signum() {
179        Ok(Bracket {
180            a,
181            b,
182            fa,
183            fb,
184            evals: 2,
185        })
186    } else {
187        Err(optim_error(
188            NAME,
189            "fzero: function values at bracket endpoints must differ in sign",
190        ))
191    }
192}
193
194async fn expand_bracket(
195    function: &Value,
196    x0: f64,
197    options: &FzeroOptions,
198) -> BuiltinResult<Bracket> {
199    if !x0.is_finite() {
200        return Err(optim_error(NAME, "fzero: initial point must be finite"));
201    }
202    let f0 = call_scalar_function(NAME, function, x0).await?;
203    if f0 == 0.0 {
204        return Ok(Bracket {
205            a: x0,
206            b: x0,
207            fa: f0,
208            fb: f0,
209            evals: 1,
210        });
211    }
212
213    let mut evals = 1usize;
214    let mut step = (x0.abs() * 0.01).max(0.01);
215    while evals + 2 <= options.max_fun_evals {
216        let a = x0 - step;
217        let b = x0 + step;
218        let fa = call_scalar_function(NAME, function, a).await?;
219        let fb = call_scalar_function(NAME, function, b).await?;
220        evals += 2;
221        if fa == 0.0 {
222            return Ok(Bracket {
223                a,
224                b: a,
225                fa,
226                fb: fa,
227                evals,
228            });
229        }
230        if fa.signum() != f0.signum() {
231            return Ok(Bracket {
232                a,
233                b: x0,
234                fa,
235                fb: f0,
236                evals,
237            });
238        }
239        if fb.signum() != f0.signum() {
240            return Ok(Bracket {
241                a: x0,
242                b,
243                fa: f0,
244                fb,
245                evals,
246            });
247        }
248        if fb == 0.0 || fa.signum() != fb.signum() {
249            return Ok(Bracket {
250                a,
251                b,
252                fa,
253                fb,
254                evals,
255            });
256        }
257        step *= 1.6;
258    }
259
260    Err(optim_error(
261        NAME,
262        "fzero: could not find a sign-changing bracket around the initial point",
263    ))
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;
    use std::f64::consts::{FRAC_PI_2, PI};

    /// Runs `fzero` synchronously, with no options, on the builtin named `handle`.
    /// Returns the numeric root; any non-numeric result panics.
    fn solve(handle: &str, start: Value) -> f64 {
        let result = block_on(fzero_builtin(
            Value::FunctionHandle(handle.into()),
            start,
            Vec::new(),
        ))
        .unwrap();
        match result {
            Value::Num(n) => n,
            other => panic!("unexpected value {other:?}"),
        }
    }

    #[test]
    fn fzero_bracketed_builtin_handle() {
        // sin changes sign exactly once on [3, 4], at pi.
        let bracket = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let root = solve("sin", Value::Tensor(bracket));
        assert!((root - PI).abs() < 1.0e-6);
    }

    #[test]
    fn fzero_scalar_initial_guess_expands_bracket() {
        // cos(1) > 0, so the outward search has to reach past pi/2.
        let root = solve("cos", Value::Num(1.0));
        assert!((root - FRAC_PI_2).abs() < 1.0e-6);
    }

    #[test]
    fn fzero_scalar_initial_guess_uses_center_sign_for_bracket() {
        // Starting at the peak of sin, the left probe reaches 0 first.
        let root = solve("sin", Value::Num(FRAC_PI_2));
        assert!(root.abs() < 1.0e-6);
    }
}