Skip to main content

runmat_runtime/builtins/math/optim/
fzero.rs

1//! MATLAB-compatible `fzero` builtin for scalar nonlinear root finding.
2
3use runmat_builtins::{StructValue, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::spec::{
7    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8    ReductionNaN, ResidencyPolicy, ShapeRequirements,
9};
10use crate::builtins::math::optim::common::{
11    call_scalar_function, optim_error, option_f64, option_string, option_usize,
12};
13use crate::builtins::math::optim::type_resolvers::scalar_root_type;
14use crate::BuiltinResult;
15
/// Builtin name; passed as the context label to `optim_error`, `call_scalar_function`,
/// `option_f64`, and `option_usize` throughout this file.
16const NAME: &str = "fzero";
/// Fallback for the `TolX` option when the options struct does not supply one.
17const DEFAULT_TOL_X: f64 = 1.0e-6;
/// Fallback for the `MaxIter` option (upper bound on Brent iterations).
18const DEFAULT_MAX_ITER: usize = 400;
/// Fallback for the `MaxFunEvals` option (upper bound on callback evaluations).
19const DEFAULT_MAX_FUN_EVALS: usize = 500;
20
/// GPU metadata registered for `fzero`.
///
/// As the `notes` field states, the root search is a host-side iterative solver; only the
/// user callback may itself use GPU-aware builtins. No precisions or provider hooks are
/// declared.
// NOTE(review): `ResidencyPolicy::GatherImmediately` presumably means device-resident inputs
// are pulled to the host before the builtin runs — confirm against the spec definitions.
21#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::optim::fzero")]
22pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
23    name: "fzero",
24    op_kind: GpuOpKind::Custom("scalar-root-find"),
25    supported_precisions: &[],
26    broadcast: BroadcastSemantics::None,
27    provider_hooks: &[],
28    constant_strategy: ConstantStrategy::InlineLiteral,
29    residency: ResidencyPolicy::GatherImmediately,
30    nan_mode: ReductionNaN::Include,
31    two_pass_threshold: None,
32    workgroup_size: None,
33    accepts_nan_mode: false,
34    notes: "Host iterative solver. Callback values may use GPU-aware builtins, but the root search runs on the CPU.",
35};
36
/// Fusion metadata registered for `fzero`.
///
/// Neither an elementwise nor a reduction kernel is declared; per the `notes` field, the
/// builtin repeatedly invokes user code and terminates fusion planning.
37#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::optim::fzero")]
38pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
39    name: "fzero",
40    shape: ShapeRequirements::Any,
41    constant_strategy: ConstantStrategy::InlineLiteral,
42    elementwise: None,
43    reduction: None,
44    emits_nan: false,
45    notes: "Root finding repeatedly invokes user code and terminates fusion planning.",
46};
47
48#[runtime_builtin(
49    name = "fzero",
50    category = "math/optim",
51    summary = "Find a zero of a scalar nonlinear function using bracket expansion and Brent's method.",
52    keywords = "fzero,root finding,zero,brent,optimization",
53    accel = "sink",
54    type_resolver(scalar_root_type),
55    builtin_path = "crate::builtins::math::optim::fzero"
56)]
57async fn fzero_builtin(function: Value, x: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
58    if rest.len() > 1 {
59        return Err(optim_error(NAME, "fzero: too many input arguments"));
60    }
61    let options = parse_options(rest.first())?;
62    let opts = FzeroOptions::from_struct(options.as_ref())?;
63    let bracket = initial_bracket(&function, x, &opts).await?;
64    let root = brent(&function, bracket, &opts).await?;
65    Ok(Value::Num(root))
66}
67
68fn parse_options(value: Option<&Value>) -> BuiltinResult<Option<StructValue>> {
69    match value {
70        None => Ok(None),
71        Some(Value::Struct(options)) => Ok(Some(options.clone())),
72        Some(other) => Err(optim_error(
73            NAME,
74            format!("fzero: options must be a struct, got {other:?}"),
75        )),
76    }
77}
78
79#[derive(Clone, Copy)]
80struct FzeroOptions {
81    tol_x: f64,
82    max_iter: usize,
83    max_fun_evals: usize,
84}
85
86impl FzeroOptions {
87    fn from_struct(options: Option<&StructValue>) -> BuiltinResult<Self> {
88        let display = option_string(options, "Display", "off")?;
89        if !matches!(display.as_str(), "off" | "none" | "final" | "iter") {
90            return Err(optim_error(
91                NAME,
92                "fzero: option Display must be 'off', 'none', 'final', or 'iter'",
93            ));
94        }
95        let tol_x = option_f64(NAME, options, "TolX", DEFAULT_TOL_X)?;
96        if tol_x <= 0.0 {
97            return Err(optim_error(NAME, "fzero: option TolX must be positive"));
98        }
99        let max_iter = option_usize(NAME, options, "MaxIter", DEFAULT_MAX_ITER)?;
100        let max_fun_evals = option_usize(NAME, options, "MaxFunEvals", DEFAULT_MAX_FUN_EVALS)?;
101        Ok(Self {
102            tol_x,
103            max_iter: max_iter.max(1),
104            max_fun_evals: max_fun_evals.max(1),
105        })
106    }
107}
108
/// An interval `[a, b]` with the callback values at its endpoints.
///
/// A degenerate bracket (`a == b`) is used by the bracketing routines to signal that an
/// exact zero was already found at `a`.
///
/// `Debug` is derived so that `Result<Bracket, _>` can be used with `unwrap_err`/`{:?}` in
/// tests and diagnostics.
#[derive(Clone, Copy, Debug)]
struct Bracket {
    /// First endpoint.
    a: f64,
    /// Second endpoint.
    b: f64,
    /// Callback value at `a`.
    fa: f64,
    /// Callback value at `b`.
    fb: f64,
    /// Number of callback evaluations spent producing this bracket.
    evals: usize,
}
117
118async fn initial_bracket(
119    function: &Value,
120    x: Value,
121    options: &FzeroOptions,
122) -> BuiltinResult<Bracket> {
123    let x = crate::dispatcher::gather_if_needed_async(&x).await?;
124    match x {
125        Value::Tensor(tensor) if tensor.data.len() == 2 => {
126            let a = tensor.data[0];
127            let b = tensor.data[1];
128            bracket_from_endpoints(function, a, b).await
129        }
130        Value::Tensor(tensor) if tensor.data.len() == 1 => {
131            expand_bracket(function, tensor.data[0], options).await
132        }
133        Value::Num(n) => expand_bracket(function, n, options).await,
134        Value::Int(i) => expand_bracket(function, i.to_f64(), options).await,
135        Value::Bool(b) => expand_bracket(function, if b { 1.0 } else { 0.0 }, options).await,
136        other => Err(optim_error(
137            NAME,
138            format!("fzero: initial point must be a scalar or two-element bracket, got {other:?}"),
139        )),
140    }
141}
142
143async fn bracket_from_endpoints(function: &Value, a: f64, b: f64) -> BuiltinResult<Bracket> {
144    if !a.is_finite() || !b.is_finite() || a == b {
145        return Err(optim_error(
146            NAME,
147            "fzero: bracket endpoints must be finite and distinct",
148        ));
149    }
150    let fa = call_scalar_function(NAME, function, a).await?;
151    if fa == 0.0 {
152        return Ok(Bracket {
153            a,
154            b: a,
155            fa,
156            fb: fa,
157            evals: 1,
158        });
159    }
160    let fb = call_scalar_function(NAME, function, b).await?;
161    if fb == 0.0 || fa.signum() != fb.signum() {
162        Ok(Bracket {
163            a,
164            b,
165            fa,
166            fb,
167            evals: 2,
168        })
169    } else {
170        Err(optim_error(
171            NAME,
172            "fzero: function values at bracket endpoints must differ in sign",
173        ))
174    }
175}
176
177async fn expand_bracket(
178    function: &Value,
179    x0: f64,
180    options: &FzeroOptions,
181) -> BuiltinResult<Bracket> {
182    if !x0.is_finite() {
183        return Err(optim_error(NAME, "fzero: initial point must be finite"));
184    }
185    let f0 = call_scalar_function(NAME, function, x0).await?;
186    if f0 == 0.0 {
187        return Ok(Bracket {
188            a: x0,
189            b: x0,
190            fa: f0,
191            fb: f0,
192            evals: 1,
193        });
194    }
195
196    let mut evals = 1usize;
197    let mut step = (x0.abs() * 0.01).max(0.01);
198    while evals + 2 <= options.max_fun_evals {
199        let a = x0 - step;
200        let b = x0 + step;
201        let fa = call_scalar_function(NAME, function, a).await?;
202        let fb = call_scalar_function(NAME, function, b).await?;
203        evals += 2;
204        if fa == 0.0 {
205            return Ok(Bracket {
206                a,
207                b: a,
208                fa,
209                fb: fa,
210                evals,
211            });
212        }
213        if fa.signum() != f0.signum() {
214            return Ok(Bracket {
215                a,
216                b: x0,
217                fa,
218                fb: f0,
219                evals,
220            });
221        }
222        if fb.signum() != f0.signum() {
223            return Ok(Bracket {
224                a: x0,
225                b,
226                fa: f0,
227                fb,
228                evals,
229            });
230        }
231        if fb == 0.0 || fa.signum() != fb.signum() {
232            return Ok(Bracket {
233                a,
234                b,
235                fa,
236                fb,
237                evals,
238            });
239        }
240        step *= 1.6;
241    }
242
243    Err(optim_error(
244        NAME,
245        "fzero: could not find a sign-changing bracket around the initial point",
246    ))
247}
248
249async fn brent(function: &Value, bracket: Bracket, options: &FzeroOptions) -> BuiltinResult<f64> {
250    if bracket.fa == 0.0 || bracket.a == bracket.b {
251        return Ok(bracket.a);
252    }
253    if bracket.fb == 0.0 {
254        return Ok(bracket.b);
255    }
256
257    let mut a = bracket.a;
258    let mut b = bracket.b;
259    let mut c = a;
260    let mut fa = bracket.fa;
261    let mut fb = bracket.fb;
262    let mut fc = fa;
263    let mut d = b - a;
264    let mut e = d;
265    let mut evals = bracket.evals;
266
267    for _ in 0..options.max_iter {
268        if fb.signum() == fc.signum() {
269            c = a;
270            fc = fa;
271            d = b - a;
272            e = d;
273        }
274        if fc.abs() < fb.abs() {
275            let old_b = b;
276            let old_fb = fb;
277            a = b;
278            fa = fb;
279            b = c;
280            fb = fc;
281            c = old_b;
282            fc = old_fb;
283        }
284
285        let tol = 2.0 * f64::EPSILON * b.abs() + 0.5 * options.tol_x;
286        let midpoint = 0.5 * (c - b);
287        if midpoint.abs() <= tol || fb == 0.0 {
288            return Ok(b);
289        }
290        if evals >= options.max_fun_evals {
291            return Err(optim_error(
292                NAME,
293                "fzero: exceeded maximum function evaluations",
294            ));
295        }
296
297        if e.abs() >= tol && fa.abs() > fb.abs() {
298            let s = fb / fa;
299            let (mut p, mut q) = if a == c {
300                (2.0 * midpoint * s, 1.0 - s)
301            } else {
302                let q = fa / fc;
303                let r = fb / fc;
304                (
305                    s * (2.0 * midpoint * q * (q - r) - (b - a) * (r - 1.0)),
306                    (q - 1.0) * (r - 1.0) * (s - 1.0),
307                )
308            };
309            if p > 0.0 {
310                q = -q;
311            }
312            p = p.abs();
313            if interpolation_step_accepted(p, q, midpoint, tol, e) {
314                e = d;
315                d = p / q;
316            } else {
317                d = midpoint;
318                e = d;
319            }
320        } else {
321            d = midpoint;
322            e = d;
323        }
324
325        a = b;
326        fa = fb;
327        b += if d.abs() > tol {
328            d
329        } else if midpoint >= 0.0 {
330            tol
331        } else {
332            -tol
333        };
334        fb = call_scalar_function(NAME, function, b).await?;
335        evals += 1;
336    }
337
338    Err(optim_error(NAME, "fzero: exceeded maximum iterations"))
339}
340
/// Decides whether an interpolation step `p / q` may replace bisection.
///
/// The step is accepted when `2 * p` is strictly below both `3 * midpoint * q - |tol * q|`
/// and `|e * q|`, where `e` is the step taken before the last one. `q` is used with its
/// sign, so the first bound can be negative and reject the step outright.
fn interpolation_step_accepted(p: f64, q: f64, midpoint: f64, tol: f64, e: f64) -> bool {
    let gap_bound = 3.0 * midpoint * q - (tol * q).abs();
    let history_bound = (e * q).abs();
    let limit = f64::min(gap_bound, history_bound);
    2.0 * p < limit
}
346
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;

    /// Runs `fzero` on the named function handle from `start` and unwraps the numeric root.
    fn solve(handle: &'static str, start: Value) -> f64 {
        let outcome = block_on(fzero_builtin(
            Value::FunctionHandle(handle.into()),
            start,
            Vec::new(),
        ))
        .unwrap();
        match outcome {
            Value::Num(n) => n,
            other => panic!("unexpected value {other:?}"),
        }
    }

    #[test]
    fn fzero_bracketed_builtin_handle() {
        let endpoints = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let root = solve("sin", Value::Tensor(endpoints));
        assert!((root - std::f64::consts::PI).abs() < 1.0e-6);
    }

    #[test]
    fn fzero_scalar_initial_guess_expands_bracket() {
        let root = solve("cos", Value::Num(1.0));
        assert!((root - std::f64::consts::FRAC_PI_2).abs() < 1.0e-6);
    }

    #[test]
    fn fzero_scalar_initial_guess_uses_center_sign_for_bracket() {
        // Starting at pi/2, both samples change sign in the same round; the assertion
        // expects the root at 0, on the left of the start.
        let root = solve("sin", Value::Num(std::f64::consts::FRAC_PI_2));
        assert!(root.abs() < 1.0e-6);
    }

    #[test]
    fn brent_interpolation_acceptance_uses_signed_q() {
        let accepted_for = |midpoint: f64| interpolation_step_accepted(1.0, -2.0, midpoint, 0.1, 10.0);
        assert!(!accepted_for(1.0));
        assert!(accepted_for(-1.0));
    }
}
401}