Skip to main content

runmat_runtime/builtins/math/optim/
fzero.rs

1//! MATLAB-compatible `fzero` builtin for scalar nonlinear root finding.
2//!
//! `fzero` searches for a scalar zero from either a two-point bracket or a
//! scalar initial guess. It supports one to four outputs via these call forms:
5//!
6//! * `x = fzero(fun, x0)`
7//! * `x = fzero(fun, x0, options)`
8//! * `[x, fval] = fzero(...)`
9//! * `[x, fval, exitflag] = fzero(...)`
10//! * `[x, fval, exitflag, output] = fzero(...)`
11
12use runmat_builtins::{
13    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
14    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
15};
16use runmat_builtins::{StructValue, Value};
17use runmat_macros::runtime_builtin;
18
19use crate::builtins::common::spec::{
20    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
21    ReductionNaN, ResidencyPolicy, ShapeRequirements,
22};
23use crate::builtins::math::optim::brent::{
24    brent_zero, BrentParams, BrentZeroBracket, BrentZeroObserver, BrentZeroResult,
25    BrentZeroStepKind,
26};
27use crate::builtins::math::optim::common::{
28    call_scalar_function, option_f64, option_string, option_usize,
29};
30use crate::builtins::math::optim::type_resolvers::scalar_root_type;
31use crate::{build_runtime_error, BuiltinResult, RuntimeError};
32
/// Builtin name attached to errors and passed to the shared optim helpers.
const NAME: &str = "fzero";
/// Text stored in the `algorithm` field of the `output` struct.
const ALGORITHM: &str = "bisection, interpolation";
/// Value handed to `option_f64` as the fallback for the `TolX` option.
const DEFAULT_TOL_X: f64 = 1.0e-6;
/// Value handed to `option_usize` as the fallback for the `MaxIter` option.
const DEFAULT_MAX_ITER: usize = 400;
/// Value handed to `option_usize` as the fallback for the `MaxFunEvals` option.
const DEFAULT_MAX_FUN_EVALS: usize = 500;
38
/// Output descriptors for the single-output form: just the root estimate `x`.
const FZERO_OUTPUT_X: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "x",
    ty: BuiltinParamType::NumericScalar,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Estimated root location.",
}];
46
/// Output descriptors for the two-output form: `x` followed by `fval`.
const FZERO_OUTPUT_X_FVAL: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericScalar,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Estimated root location.",
    },
    BuiltinParamDescriptor {
        name: "fval",
        ty: BuiltinParamType::NumericScalar,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Function value at x.",
    },
];
63
/// Output descriptors for the three-output form: `x`, `fval`, and `exitflag`.
const FZERO_OUTPUT_X_FVAL_EXITFLAG: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericScalar,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Estimated root location.",
    },
    BuiltinParamDescriptor {
        name: "fval",
        ty: BuiltinParamType::NumericScalar,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Function value at x.",
    },
    BuiltinParamDescriptor {
        name: "exitflag",
        ty: BuiltinParamType::NumericScalar,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Convergence status code.",
    },
];
87
/// Output descriptors for the four-output form: `x`, `fval`, `exitflag`, and the
/// `output` metadata struct (typed `Any` because it is not a numeric scalar).
const FZERO_OUTPUT_ALL: [BuiltinParamDescriptor; 4] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericScalar,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Estimated root location.",
    },
    BuiltinParamDescriptor {
        name: "fval",
        ty: BuiltinParamType::NumericScalar,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Function value at x.",
    },
    BuiltinParamDescriptor {
        name: "exitflag",
        ty: BuiltinParamType::NumericScalar,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Convergence status code.",
    },
    BuiltinParamDescriptor {
        name: "output",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Iteration/function-count metadata struct.",
    },
];
118
/// Input descriptors for calls without an options struct: the callback `fun` and the
/// starting point or bracket `x0`. Both are typed `Any`.
const FZERO_INPUTS_CORE: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "fun",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Scalar-valued callback.",
    },
    BuiltinParamDescriptor {
        name: "x0",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Initial point or two-element bracket.",
    },
];
135
/// Input descriptors for calls that also pass `options`: `fun`, `x0`, and an
/// optional options struct.
const FZERO_INPUTS_WITH_OPTIONS: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "fun",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Scalar-valued callback.",
    },
    BuiltinParamDescriptor {
        name: "x0",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Initial point or two-element bracket.",
    },
    BuiltinParamDescriptor {
        name: "options",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Optional,
        default: None,
        description: "Options struct from optimset.",
    },
];
159
/// Every advertised call form: each of the four output arities paired with the
/// two input lists (with and without `options`), ordered by output count.
const FZERO_SIGNATURES: [BuiltinSignatureDescriptor; 8] = [
    // One output.
    BuiltinSignatureDescriptor {
        label: "x = fzero(fun, x0)",
        inputs: &FZERO_INPUTS_CORE,
        outputs: &FZERO_OUTPUT_X,
    },
    BuiltinSignatureDescriptor {
        label: "x = fzero(fun, x0, options)",
        inputs: &FZERO_INPUTS_WITH_OPTIONS,
        outputs: &FZERO_OUTPUT_X,
    },
    // Two outputs.
    BuiltinSignatureDescriptor {
        label: "[x, fval] = fzero(fun, x0)",
        inputs: &FZERO_INPUTS_CORE,
        outputs: &FZERO_OUTPUT_X_FVAL,
    },
    BuiltinSignatureDescriptor {
        label: "[x, fval] = fzero(fun, x0, options)",
        inputs: &FZERO_INPUTS_WITH_OPTIONS,
        outputs: &FZERO_OUTPUT_X_FVAL,
    },
    // Three outputs.
    BuiltinSignatureDescriptor {
        label: "[x, fval, exitflag] = fzero(fun, x0)",
        inputs: &FZERO_INPUTS_CORE,
        outputs: &FZERO_OUTPUT_X_FVAL_EXITFLAG,
    },
    BuiltinSignatureDescriptor {
        label: "[x, fval, exitflag] = fzero(fun, x0, options)",
        inputs: &FZERO_INPUTS_WITH_OPTIONS,
        outputs: &FZERO_OUTPUT_X_FVAL_EXITFLAG,
    },
    // Four outputs.
    BuiltinSignatureDescriptor {
        label: "[x, fval, exitflag, output] = fzero(fun, x0)",
        inputs: &FZERO_INPUTS_CORE,
        outputs: &FZERO_OUTPUT_ALL,
    },
    BuiltinSignatureDescriptor {
        label: "[x, fval, exitflag, output] = fzero(fun, x0, options)",
        inputs: &FZERO_INPUTS_WITH_OPTIONS,
        outputs: &FZERO_OUTPUT_ALL,
    },
];
202
/// Error raised for problems with the call shape or the options struct
/// (too many inputs, non-struct options, bad option values).
const FZERO_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.FZERO.INVALID_ARGUMENT",
    identifier: Some("RunMat:fzero:InvalidArgument"),
    when: "Argument grammar/options struct are invalid.",
    message: "fzero: invalid argument",
};

/// Error raised for problems with the callback, the bracket, or the initial point.
const FZERO_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.FZERO.INVALID_INPUT",
    identifier: Some("RunMat:fzero:InvalidInput"),
    when: "Callback/bracket/initial-point semantics are invalid.",
    message: "fzero: invalid input",
};

/// Error raised when the caller requests more than four outputs.
const FZERO_ERROR_TOO_MANY_OUTPUTS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.FZERO.TOO_MANY_OUTPUTS",
    identifier: Some("RunMat:fzero:TooManyOutputs"),
    when: "`fzero` is called with more than four requested output arguments.",
    message: "fzero: too many output arguments",
};

/// All error descriptors published through `FZERO_DESCRIPTOR`.
const FZERO_ERRORS: [BuiltinErrorDescriptor; 3] = [
    FZERO_ERROR_INVALID_ARGUMENT,
    FZERO_ERROR_INVALID_INPUT,
    FZERO_ERROR_TOO_MANY_OUTPUTS,
];
229
/// Public descriptor for `fzero`, referenced from the `runtime_builtin` attribute.
/// Bundles the signature table and error table; the returned values depend on the
/// number of outputs the caller requested.
pub const FZERO_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &FZERO_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &FZERO_ERRORS,
};
236
237fn fzero_error_with_detail(
238    error: &'static BuiltinErrorDescriptor,
239    detail: impl AsRef<str>,
240) -> RuntimeError {
241    let detail = detail.as_ref();
242    let message = if detail.starts_with("fzero:") {
243        detail.to_string()
244    } else {
245        format!("{}: {detail}", error.message)
246    };
247    let mut builder = build_runtime_error(message).with_builtin(NAME);
248    if let Some(identifier) = error.identifier {
249        builder = builder.with_identifier(identifier);
250    }
251    builder.build()
252}
253
254fn fzero_map_error(err: RuntimeError, fallback: &'static BuiltinErrorDescriptor) -> RuntimeError {
255    if err.identifier().is_some() {
256        err
257    } else {
258        fzero_error_with_detail(fallback, err.message())
259    }
260}
261
262fn validate_requested_outputs() -> BuiltinResult<()> {
263    if matches!(crate::output_count::current_output_count(), Some(n) if n > 4) {
264        return Err(fzero_too_many_outputs_error());
265    }
266    Ok(())
267}
268
269fn fzero_too_many_outputs_error() -> RuntimeError {
270    fzero_error_with_detail(
271        &FZERO_ERROR_TOO_MANY_OUTPUTS,
272        "fzero: too many output arguments; maximum is 4",
273    )
274}
275
// GPU metadata for `fzero`, registered through the `register_gpu_spec` attribute.
// No precisions or provider hooks are declared; per `notes`, the root search runs on
// the CPU even when the callback itself uses GPU-aware builtins.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::optim::fzero")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "fzero",
    op_kind: GpuOpKind::Custom("scalar-root-find"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Host iterative solver. Callback values may use GPU-aware builtins, but the root search runs on the CPU.",
};
291
// Fusion metadata for `fzero`, registered through the `register_fusion_spec` attribute.
// It declares neither an elementwise nor a reduction kernel; per `notes`, calling
// `fzero` terminates fusion planning because it repeatedly invokes user code.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::optim::fzero")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "fzero",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Root finding repeatedly invokes user code and terminates fusion planning.",
};
302
303#[runtime_builtin(
304    name = "fzero",
305    category = "math/optim",
306    summary = "Find scalar function zeros with bracketed root-finding.",
307    keywords = "fzero,root finding,zero,brent,optimization",
308    accel = "sink",
309    type_resolver(scalar_root_type),
310    descriptor(crate::builtins::math::optim::fzero::FZERO_DESCRIPTOR),
311    builtin_path = "crate::builtins::math::optim::fzero"
312)]
313async fn fzero_builtin(function: Value, x: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
314    if rest.len() > 1 {
315        return Err(fzero_error_with_detail(
316            &FZERO_ERROR_INVALID_ARGUMENT,
317            "too many input arguments",
318        ));
319    }
320    validate_requested_outputs()?;
321    let options = parse_options(rest.first())
322        .map_err(|err| fzero_map_error(err, &FZERO_ERROR_INVALID_ARGUMENT))?;
323    let opts = FzeroOptions::from_struct(options.as_ref())
324        .map_err(|err| fzero_map_error(err, &FZERO_ERROR_INVALID_ARGUMENT))?;
325    let bracket = initial_bracket(&function, x, &opts)
326        .await
327        .map_err(|err| fzero_map_error(err, &FZERO_ERROR_INVALID_INPUT))?;
328    let mut iter_log = IterDisplay::new(opts.display);
329    let observer: Option<&mut dyn BrentZeroObserver> = if matches!(opts.display, DisplayMode::Iter)
330    {
331        Some(&mut iter_log)
332    } else {
333        None
334    };
335    let result = brent_zero(
336        NAME,
337        &function,
338        BrentZeroBracket {
339            a: bracket.a,
340            b: bracket.b,
341            fa: bracket.fa,
342            fb: bracket.fb,
343            evals: bracket.evals,
344        },
345        BrentParams {
346            tol_x: opts.tol_x,
347            max_iter: opts.max_iter,
348            max_fun_evals: opts.max_fun_evals,
349        },
350        observer,
351    )
352    .await
353    .map_err(|err| fzero_map_error(err, &FZERO_ERROR_INVALID_INPUT))?;
354    finalize(result, &opts)
355}
356
357fn parse_options(value: Option<&Value>) -> BuiltinResult<Option<StructValue>> {
358    match value {
359        None => Ok(None),
360        Some(Value::Struct(options)) => Ok(Some(options.clone())),
361        Some(other) => Err(fzero_error_with_detail(
362            &FZERO_ERROR_INVALID_ARGUMENT,
363            format!("options must be a struct, got {other:?}"),
364        )),
365    }
366}
367
368#[derive(Clone, Copy)]
369struct FzeroOptions {
370    tol_x: f64,
371    max_iter: usize,
372    max_fun_evals: usize,
373    display: DisplayMode,
374}
375
376impl FzeroOptions {
377    fn from_struct(options: Option<&StructValue>) -> BuiltinResult<Self> {
378        let display = DisplayMode::parse(&option_string(options, "Display", "off")?)?;
379        let tol_x = option_f64(NAME, options, "TolX", DEFAULT_TOL_X)?;
380        if tol_x <= 0.0 {
381            return Err(fzero_error_with_detail(
382                &FZERO_ERROR_INVALID_ARGUMENT,
383                "option TolX must be positive",
384            ));
385        }
386        let max_iter = option_usize(NAME, options, "MaxIter", DEFAULT_MAX_ITER)?;
387        let max_fun_evals = option_usize(NAME, options, "MaxFunEvals", DEFAULT_MAX_FUN_EVALS)?;
388        Ok(Self {
389            tol_x,
390            max_iter: max_iter.max(1),
391            max_fun_evals: max_fun_evals.max(1),
392            display,
393        })
394    }
395}
396
397#[derive(Debug, Clone, Copy, PartialEq, Eq)]
398enum DisplayMode {
399    Off,
400    Iter,
401    Final,
402}
403
404impl DisplayMode {
405    fn parse(text: &str) -> BuiltinResult<Self> {
406        match text.to_ascii_lowercase().as_str() {
407            "off" | "none" => Ok(Self::Off),
408            "iter" => Ok(Self::Iter),
409            "final" => Ok(Self::Final),
410            other => Err(fzero_error_with_detail(
411                &FZERO_ERROR_INVALID_ARGUMENT,
412                format!("option Display must be 'off', 'none', 'final', or 'iter', got '{other}'"),
413            )),
414        }
415    }
416}
417
/// Starting interval for the solver together with the function values already computed
/// at its endpoints. When an exact zero was found early, both endpoints are the same point.
#[derive(Clone, Copy)]
struct Bracket {
    /// First endpoint.
    a: f64,
    /// Second endpoint.
    b: f64,
    /// Callback value at `a`.
    fa: f64,
    /// Callback value at `b`.
    fb: f64,
    /// Number of callback evaluations spent while building the bracket.
    evals: usize,
}
426
427async fn initial_bracket(
428    function: &Value,
429    x: Value,
430    options: &FzeroOptions,
431) -> BuiltinResult<Bracket> {
432    let x = crate::dispatcher::gather_if_needed_async(&x).await?;
433    match x {
434        Value::Tensor(tensor) if tensor.data.len() == 2 => {
435            let a = tensor.data[0];
436            let b = tensor.data[1];
437            bracket_from_endpoints(function, a, b).await
438        }
439        Value::Tensor(tensor) if tensor.data.len() == 1 => {
440            expand_bracket(function, tensor.data[0], options).await
441        }
442        Value::Num(n) => expand_bracket(function, n, options).await,
443        Value::Int(i) => expand_bracket(function, i.to_f64(), options).await,
444        Value::Bool(b) => expand_bracket(function, if b { 1.0 } else { 0.0 }, options).await,
445        other => Err(fzero_error_with_detail(
446            &FZERO_ERROR_INVALID_INPUT,
447            format!("initial point must be a scalar or two-element bracket, got {other:?}"),
448        )),
449    }
450}
451
452async fn bracket_from_endpoints(function: &Value, a: f64, b: f64) -> BuiltinResult<Bracket> {
453    if !a.is_finite() || !b.is_finite() || a == b {
454        return Err(fzero_error_with_detail(
455            &FZERO_ERROR_INVALID_INPUT,
456            "bracket endpoints must be finite and distinct",
457        ));
458    }
459    let fa = call_scalar_function(NAME, function, a).await?;
460    if fa == 0.0 {
461        return Ok(Bracket {
462            a,
463            b: a,
464            fa,
465            fb: fa,
466            evals: 1,
467        });
468    }
469    let fb = call_scalar_function(NAME, function, b).await?;
470    if fb == 0.0 || fa.signum() != fb.signum() {
471        Ok(Bracket {
472            a,
473            b,
474            fa,
475            fb,
476            evals: 2,
477        })
478    } else {
479        Err(fzero_error_with_detail(
480            &FZERO_ERROR_INVALID_INPUT,
481            "function values at bracket endpoints must differ in sign",
482        ))
483    }
484}
485
486async fn expand_bracket(
487    function: &Value,
488    x0: f64,
489    options: &FzeroOptions,
490) -> BuiltinResult<Bracket> {
491    if !x0.is_finite() {
492        return Err(fzero_error_with_detail(
493            &FZERO_ERROR_INVALID_INPUT,
494            "initial point must be finite",
495        ));
496    }
497    let f0 = call_scalar_function(NAME, function, x0).await?;
498    if f0 == 0.0 {
499        return Ok(Bracket {
500            a: x0,
501            b: x0,
502            fa: f0,
503            fb: f0,
504            evals: 1,
505        });
506    }
507
508    let mut evals = 1usize;
509    let mut step = (x0.abs() * 0.01).max(0.01);
510    while evals + 2 <= options.max_fun_evals {
511        let a = x0 - step;
512        let b = x0 + step;
513        let fa = call_scalar_function(NAME, function, a).await?;
514        let fb = call_scalar_function(NAME, function, b).await?;
515        evals += 2;
516        if fa == 0.0 {
517            return Ok(Bracket {
518                a,
519                b: a,
520                fa,
521                fb: fa,
522                evals,
523            });
524        }
525        if fa.signum() != f0.signum() {
526            return Ok(Bracket {
527                a,
528                b: x0,
529                fa,
530                fb: f0,
531                evals,
532            });
533        }
534        if fb.signum() != f0.signum() {
535            return Ok(Bracket {
536                a: x0,
537                b,
538                fa: f0,
539                fb,
540                evals,
541            });
542        }
543        if fb == 0.0 || fa.signum() != fb.signum() {
544            return Ok(Bracket {
545                a,
546                b,
547                fa,
548                fb,
549                evals,
550            });
551        }
552        step *= 1.6;
553    }
554
555    Err(fzero_error_with_detail(
556        &FZERO_ERROR_INVALID_INPUT,
557        "could not find a sign-changing bracket around the initial point",
558    ))
559}
560
561fn finalize(result: BrentZeroResult, options: &FzeroOptions) -> BuiltinResult<Value> {
562    let exit_flag = if result.converged { 1 } else { 0 };
563    let message = build_message(&result);
564    emit_summary(&result, exit_flag, &message, options);
565
566    let x = Value::Num(result.x);
567    let fval = Value::Num(result.fval);
568    let exitflag = Value::Num(exit_flag as f64);
569    let output_struct = Value::Struct(build_output_struct(&result, &message));
570
571    match crate::output_count::current_output_count() {
572        None => Ok(x),
573        Some(0) => Ok(Value::OutputList(Vec::new())),
574        Some(1) => Ok(crate::output_count::output_list_with_padding(1, vec![x])),
575        Some(2) => Ok(crate::output_count::output_list_with_padding(
576            2,
577            vec![x, fval],
578        )),
579        Some(3) => Ok(crate::output_count::output_list_with_padding(
580            3,
581            vec![x, fval, exitflag],
582        )),
583        Some(4) => Ok(crate::output_count::output_list_with_padding(
584            4,
585            vec![x, fval, exitflag, output_struct],
586        )),
587        Some(_) => Err(fzero_too_many_outputs_error()),
588    }
589}
590
591fn build_output_struct(result: &BrentZeroResult, message: &str) -> StructValue {
592    let mut fields = StructValue::new();
593    fields.insert("iterations", Value::Num(result.iterations as f64));
594    fields.insert("funcCount", Value::Num(result.func_count as f64));
595    fields.insert("algorithm", Value::from(ALGORITHM));
596    fields.insert("message", Value::from(message.to_string()));
597    fields
598}
599
600fn build_message(result: &BrentZeroResult) -> String {
601    if result.converged {
602        format!(
603            "Zero found within OPTIONS.TolX. Iterations: {}, FuncCount: {}.",
604            result.iterations, result.func_count
605        )
606    } else {
607        format!(
608            "Exiting: Maximum number of function evaluations or iterations has been exceeded - increase MaxFunEvals or MaxIter. Iterations: {}, FuncCount: {}.",
609            result.iterations, result.func_count
610        )
611    }
612}
613
614fn emit_summary(result: &BrentZeroResult, exit_flag: i32, message: &str, options: &FzeroOptions) {
615    if !matches!(options.display, DisplayMode::Final | DisplayMode::Iter) {
616        return;
617    }
618    crate::console::record_console_line(
619        crate::console::ConsoleStream::Stdout,
620        format!(
621            "fzero: x = {x:.6}, fval = {fval:.6}, exitflag = {exit_flag}. {message}",
622            x = result.x,
623            fval = result.fval,
624        ),
625    );
626}
627
628struct IterDisplay {
629    mode: DisplayMode,
630    printed_header: bool,
631}
632
633impl IterDisplay {
634    fn new(mode: DisplayMode) -> Self {
635        Self {
636            mode,
637            printed_header: false,
638        }
639    }
640}
641
642impl BrentZeroObserver for IterDisplay {
643    fn on_iteration(
644        &mut self,
645        iter: usize,
646        func_count: usize,
647        x: f64,
648        fx: f64,
649        step_kind: BrentZeroStepKind,
650    ) {
651        if !matches!(self.mode, DisplayMode::Iter) {
652            return;
653        }
654        if !self.printed_header {
655            crate::console::record_console_line(
656                crate::console::ConsoleStream::Stdout,
657                " Func-count        x          f(x)          Procedure",
658            );
659            self.printed_header = true;
660        }
661        let procedure = match step_kind {
662            BrentZeroStepKind::Initial => "initial",
663            BrentZeroStepKind::Bisection => "bisection",
664            BrentZeroStepKind::Interpolation => "interpolation",
665        };
666        let line =
667            format!("    {func_count:>5}    {x:13.6e} {fx:13.6e}    {procedure}    (iter {iter})");
668        crate::console::record_console_line(crate::console::ConsoleStream::Stdout, line);
669    }
670}
671
// Unit tests for the `fzero` builtin: bracketed and scalar starts, callback dispatch,
// multi-output packaging, budget handling, display output, and descriptor contents.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::math::optim::brent::interpolation_step_accepted;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;
    use std::sync::Arc;

    // An explicit bracket [3, 4] around a zero of `sin` must converge to pi.
    #[test]
    fn fzero_bracketed_builtin_handle() {
        let bracket = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let root = block_on(fzero_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Tensor(bracket),
            Vec::new(),
        ))
        .unwrap();
        match root {
            Value::Num(n) => assert!((n - std::f64::consts::PI).abs() < 1.0e-6),
            other => panic!("unexpected value {other:?}"),
        }
    }

    // A scalar guess of 1.0 for `cos` must expand to a bracket and converge to pi/2.
    #[test]
    fn fzero_scalar_initial_guess_expands_bracket() {
        let root = block_on(fzero_builtin(
            Value::FunctionHandle("cos".into()),
            Value::Num(1.0),
            Vec::new(),
        ))
        .unwrap();
        match root {
            Value::Num(n) => assert!((n - std::f64::consts::FRAC_PI_2).abs() < 1.0e-6),
            other => panic!("unexpected value {other:?}"),
        }
    }

    // Starting `sin` at pi/2 is expected to land on the root at 0.
    #[test]
    fn fzero_scalar_initial_guess_uses_center_sign_for_bracket() {
        let root = block_on(fzero_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Num(std::f64::consts::FRAC_PI_2),
            Vec::new(),
        ))
        .unwrap();
        match root {
            Value::Num(n) => assert!(n.abs() < 1.0e-6),
            other => panic!("unexpected value {other:?}"),
        }
    }
    // A bound function handle is routed through the installed semantic invoker;
    // the stub callback computes x - 2, so the root must be 2.
    #[test]
    fn fzero_accepts_semantic_function_handle_callback() {
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            |function, args, requested_outputs| {
                assert_eq!(function, 42);
                assert_eq!(requested_outputs, 1);
                let x = match &args[0] {
                    Value::Num(value) => *value,
                    other => panic!("expected scalar numeric argument, got {other:?}"),
                };
                Box::pin(async move { Ok(Value::Num(x - 2.0)) })
            },
        )));

        let root = block_on(fzero_builtin(
            Value::BoundFunctionHandle {
                name: "root_function".to_string(),
                function: 42,
            },
            Value::Num(0.0),
            Vec::new(),
        ))
        .unwrap();
        match root {
            Value::Num(n) => assert!((n - 2.0).abs() < 1.0e-6),
            other => panic!("unexpected value {other:?}"),
        }
    }

    // With two outputs requested, the result is an output list of [x, fval].
    #[test]
    fn fzero_multi_output_two_returns_root_and_fval() {
        let _guard = crate::output_count::push_output_count(Some(2));
        let bracket = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let result = block_on(fzero_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Tensor(bracket),
            Vec::new(),
        ))
        .expect("fzero");
        match result {
            Value::OutputList(outputs) => {
                assert_eq!(outputs.len(), 2);
                match (&outputs[0], &outputs[1]) {
                    (Value::Num(x), Value::Num(fval)) => {
                        assert!((x - std::f64::consts::PI).abs() < 1.0e-6);
                        assert!(fval.abs() < 1.0e-6);
                    }
                    other => panic!("unexpected outputs {other:?}"),
                }
            }
            other => panic!("unexpected value {other:?}"),
        }
    }

    // With four outputs requested, exitflag is 1 and the output struct carries
    // iterations, funcCount, algorithm, and message.
    #[test]
    fn fzero_multi_output_four_includes_output_struct() {
        let _guard = crate::output_count::push_output_count(Some(4));
        let bracket = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let result = block_on(fzero_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Tensor(bracket),
            Vec::new(),
        ))
        .expect("fzero");
        match result {
            Value::OutputList(outputs) => {
                assert_eq!(outputs.len(), 4);
                assert!(matches!(&outputs[2], Value::Num(flag) if *flag == 1.0));
                match &outputs[3] {
                    Value::Struct(output) => {
                        assert!(matches!(
                            output.fields.get("iterations"),
                            Some(Value::Num(_))
                        ));
                        assert!(matches!(
                            output.fields.get("funcCount"),
                            Some(Value::Num(_))
                        ));
                        match output.fields.get("algorithm") {
                            Some(Value::String(text)) => assert!(text.contains("bisection")),
                            other => panic!("unexpected algorithm field {other:?}"),
                        }
                        match output.fields.get("message") {
                            Some(Value::String(text)) => assert!(text.contains("Zero found")),
                            other => panic!("unexpected message field {other:?}"),
                        }
                    }
                    other => panic!("unexpected output struct {other:?}"),
                }
            }
            other => panic!("unexpected value {other:?}"),
        }
    }

    // With MaxIter = 1 and MaxFunEvals = 2 the solver cannot converge on [3, 4],
    // so exitflag must be 0.
    #[test]
    fn fzero_reports_zero_exitflag_when_iteration_budget_exhausted() {
        let mut opts = StructValue::new();
        opts.insert("MaxIter", Value::Num(1.0));
        opts.insert("MaxFunEvals", Value::Num(2.0));
        opts.insert("Display", Value::from("off"));
        let _guard = crate::output_count::push_output_count(Some(3));
        let bracket = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let result = block_on(fzero_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Tensor(bracket),
            vec![Value::Struct(opts)],
        ))
        .expect("fzero");
        match result {
            Value::OutputList(outputs) => match &outputs[2] {
                Value::Num(flag) => assert_eq!(*flag, 0.0),
                other => panic!("unexpected exitflag {other:?}"),
            },
            other => panic!("unexpected value {other:?}"),
        }
    }

    // On the symmetric bracket [-1, 1] a single iteration lands on the root of `sin`,
    // and that must still be reported as converged (exitflag 1).
    #[test]
    fn fzero_reports_convergence_when_final_step_hits_root() {
        let mut opts = StructValue::new();
        opts.insert("MaxIter", Value::Num(1.0));
        opts.insert("Display", Value::from("off"));
        let _guard = crate::output_count::push_output_count(Some(3));
        let bracket = Tensor::new(vec![-1.0, 1.0], vec![1, 2]).unwrap();
        let result = block_on(fzero_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Tensor(bracket),
            vec![Value::Struct(opts)],
        ))
        .expect("fzero");
        match result {
            Value::OutputList(outputs) => {
                assert!(matches!(&outputs[0], Value::Num(x) if x.abs() < 1.0e-12));
                assert!(matches!(&outputs[1], Value::Num(fval) if fval.abs() < 1.0e-12));
                assert!(matches!(&outputs[2], Value::Num(flag) if *flag == 1.0));
            }
            other => panic!("unexpected value {other:?}"),
        }
    }

    // Requesting five outputs must fail with the TooManyOutputs identifier.
    #[test]
    fn fzero_rejects_more_than_four_outputs() {
        let _guard = crate::output_count::push_output_count(Some(5));
        let bracket = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let err = block_on(fzero_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Tensor(bracket),
            Vec::new(),
        ))
        .expect_err("too many outputs should fail");
        assert_eq!(err.identifier(), Some("RunMat:fzero:TooManyOutputs"));
        assert!(err.message().contains("maximum is 4"));
    }

    // Display = 'iter' must record the header, an initial row, at least one step row,
    // and the final summary line in the thread's console buffer.
    #[test]
    fn fzero_iter_display_records_iteration_rows() {
        crate::console::reset_thread_buffer();
        let mut opts = StructValue::new();
        opts.insert("Display", Value::from("iter"));
        let bracket = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let result = block_on(fzero_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Tensor(bracket),
            vec![Value::Struct(opts)],
        ))
        .expect("fzero");
        assert!(matches!(result, Value::Num(_)));

        let joined = crate::console::take_thread_buffer()
            .into_iter()
            .map(|entry| entry.text)
            .collect::<String>();
        assert!(joined.contains("Func-count"), "{joined}");
        assert!(joined.contains("initial"), "{joined}");
        assert!(
            joined.contains("interpolation") || joined.contains("bisection"),
            "{joined}"
        );
        assert!(joined.contains("exitflag = 1"), "{joined}");
    }

    // Exercises the shared Brent helper directly: the two calls differ only in the
    // sign of the third argument and must give opposite answers.
    #[test]
    fn brent_interpolation_acceptance_uses_signed_q() {
        assert!(!interpolation_step_accepted(1.0, -2.0, 1.0, 0.1, 10.0));
        assert!(interpolation_step_accepted(1.0, -2.0, -1.0, 0.1, 10.0));
    }

    // Pins the exact signature labels and error codes published by the descriptor.
    #[test]
    fn fzero_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = FZERO_DESCRIPTOR
            .signatures
            .iter()
            .map(|signature| signature.label)
            .collect();
        assert_eq!(
            labels,
            vec![
                "x = fzero(fun, x0)",
                "x = fzero(fun, x0, options)",
                "[x, fval] = fzero(fun, x0)",
                "[x, fval] = fzero(fun, x0, options)",
                "[x, fval, exitflag] = fzero(fun, x0)",
                "[x, fval, exitflag] = fzero(fun, x0, options)",
                "[x, fval, exitflag, output] = fzero(fun, x0)",
                "[x, fval, exitflag, output] = fzero(fun, x0, options)",
            ]
        );

        let codes: Vec<&str> = FZERO_DESCRIPTOR
            .errors
            .iter()
            .map(|error| error.code)
            .collect();
        assert_eq!(
            codes,
            vec![
                "RM.FZERO.INVALID_ARGUMENT",
                "RM.FZERO.INVALID_INPUT",
                "RM.FZERO.TOO_MANY_OUTPUTS",
            ]
        );
    }

    // Two trailing arguments are one too many and must map to InvalidArgument.
    #[test]
    fn fzero_too_many_args_uses_stable_identifier() {
        let err = block_on(fzero_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Num(0.0),
            vec![
                Value::Struct(StructValue::new()),
                Value::Struct(StructValue::new()),
            ],
        ))
        .unwrap_err();
        assert_eq!(err.identifier(), Some("RunMat:fzero:InvalidArgument"));
    }
}