1use runmat_builtins::{
13 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
14 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
15};
16use runmat_builtins::{StructValue, Value};
17use runmat_macros::runtime_builtin;
18
19use crate::builtins::common::spec::{
20 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
21 ReductionNaN, ResidencyPolicy, ShapeRequirements,
22};
23use crate::builtins::math::optim::brent::{
24 brent_zero, BrentParams, BrentZeroBracket, BrentZeroObserver, BrentZeroResult,
25 BrentZeroStepKind,
26};
27use crate::builtins::math::optim::common::{
28 call_scalar_function, option_f64, option_string, option_usize,
29};
30use crate::builtins::math::optim::type_resolvers::scalar_root_type;
31use crate::{build_runtime_error, BuiltinResult, RuntimeError};
32
/// Builtin name used for error attribution, callback invocation and solver labelling.
const NAME: &str = "fzero";
/// Text reported in the `algorithm` field of the `output` struct.
const ALGORITHM: &str = "bisection, interpolation";
/// `TolX` used when the options struct does not provide one.
const DEFAULT_TOL_X: f64 = 1e-6;
/// `MaxIter` used when the options struct does not provide one.
const DEFAULT_MAX_ITER: usize = 400;
/// `MaxFunEvals` used when the options struct does not provide one.
const DEFAULT_MAX_FUN_EVALS: usize = 500;
38
39const FZERO_OUTPUT_X: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
40 name: "x",
41 ty: BuiltinParamType::NumericScalar,
42 arity: BuiltinParamArity::Required,
43 default: None,
44 description: "Estimated root location.",
45}];
46
47const FZERO_OUTPUT_X_FVAL: [BuiltinParamDescriptor; 2] = [
48 BuiltinParamDescriptor {
49 name: "x",
50 ty: BuiltinParamType::NumericScalar,
51 arity: BuiltinParamArity::Required,
52 default: None,
53 description: "Estimated root location.",
54 },
55 BuiltinParamDescriptor {
56 name: "fval",
57 ty: BuiltinParamType::NumericScalar,
58 arity: BuiltinParamArity::Required,
59 default: None,
60 description: "Function value at x.",
61 },
62];
63
64const FZERO_OUTPUT_X_FVAL_EXITFLAG: [BuiltinParamDescriptor; 3] = [
65 BuiltinParamDescriptor {
66 name: "x",
67 ty: BuiltinParamType::NumericScalar,
68 arity: BuiltinParamArity::Required,
69 default: None,
70 description: "Estimated root location.",
71 },
72 BuiltinParamDescriptor {
73 name: "fval",
74 ty: BuiltinParamType::NumericScalar,
75 arity: BuiltinParamArity::Required,
76 default: None,
77 description: "Function value at x.",
78 },
79 BuiltinParamDescriptor {
80 name: "exitflag",
81 ty: BuiltinParamType::NumericScalar,
82 arity: BuiltinParamArity::Required,
83 default: None,
84 description: "Convergence status code.",
85 },
86];
87
88const FZERO_OUTPUT_ALL: [BuiltinParamDescriptor; 4] = [
89 BuiltinParamDescriptor {
90 name: "x",
91 ty: BuiltinParamType::NumericScalar,
92 arity: BuiltinParamArity::Required,
93 default: None,
94 description: "Estimated root location.",
95 },
96 BuiltinParamDescriptor {
97 name: "fval",
98 ty: BuiltinParamType::NumericScalar,
99 arity: BuiltinParamArity::Required,
100 default: None,
101 description: "Function value at x.",
102 },
103 BuiltinParamDescriptor {
104 name: "exitflag",
105 ty: BuiltinParamType::NumericScalar,
106 arity: BuiltinParamArity::Required,
107 default: None,
108 description: "Convergence status code.",
109 },
110 BuiltinParamDescriptor {
111 name: "output",
112 ty: BuiltinParamType::Any,
113 arity: BuiltinParamArity::Required,
114 default: None,
115 description: "Iteration/function-count metadata struct.",
116 },
117];
118
119const FZERO_INPUTS_CORE: [BuiltinParamDescriptor; 2] = [
120 BuiltinParamDescriptor {
121 name: "fun",
122 ty: BuiltinParamType::Any,
123 arity: BuiltinParamArity::Required,
124 default: None,
125 description: "Scalar-valued callback.",
126 },
127 BuiltinParamDescriptor {
128 name: "x0",
129 ty: BuiltinParamType::Any,
130 arity: BuiltinParamArity::Required,
131 default: None,
132 description: "Initial point or two-element bracket.",
133 },
134];
135
136const FZERO_INPUTS_WITH_OPTIONS: [BuiltinParamDescriptor; 3] = [
137 BuiltinParamDescriptor {
138 name: "fun",
139 ty: BuiltinParamType::Any,
140 arity: BuiltinParamArity::Required,
141 default: None,
142 description: "Scalar-valued callback.",
143 },
144 BuiltinParamDescriptor {
145 name: "x0",
146 ty: BuiltinParamType::Any,
147 arity: BuiltinParamArity::Required,
148 default: None,
149 description: "Initial point or two-element bracket.",
150 },
151 BuiltinParamDescriptor {
152 name: "options",
153 ty: BuiltinParamType::Any,
154 arity: BuiltinParamArity::Optional,
155 default: None,
156 description: "Options struct from optimset.",
157 },
158];
159
/// Every supported call form of `fzero`: 1 to 4 outputs, each with and without `options`.
///
/// Entries are ordered by output count, and within each count the form without options
/// comes first. The descriptor test in this file asserts this exact label order.
const FZERO_SIGNATURES: [BuiltinSignatureDescriptor; 8] = [
    // One output: the root estimate.
    BuiltinSignatureDescriptor {
        label: "x = fzero(fun, x0)",
        inputs: &FZERO_INPUTS_CORE,
        outputs: &FZERO_OUTPUT_X,
    },
    BuiltinSignatureDescriptor {
        label: "x = fzero(fun, x0, options)",
        inputs: &FZERO_INPUTS_WITH_OPTIONS,
        outputs: &FZERO_OUTPUT_X,
    },
    // Two outputs: root and function value.
    BuiltinSignatureDescriptor {
        label: "[x, fval] = fzero(fun, x0)",
        inputs: &FZERO_INPUTS_CORE,
        outputs: &FZERO_OUTPUT_X_FVAL,
    },
    BuiltinSignatureDescriptor {
        label: "[x, fval] = fzero(fun, x0, options)",
        inputs: &FZERO_INPUTS_WITH_OPTIONS,
        outputs: &FZERO_OUTPUT_X_FVAL,
    },
    // Three outputs: adds the exit flag.
    BuiltinSignatureDescriptor {
        label: "[x, fval, exitflag] = fzero(fun, x0)",
        inputs: &FZERO_INPUTS_CORE,
        outputs: &FZERO_OUTPUT_X_FVAL_EXITFLAG,
    },
    BuiltinSignatureDescriptor {
        label: "[x, fval, exitflag] = fzero(fun, x0, options)",
        inputs: &FZERO_INPUTS_WITH_OPTIONS,
        outputs: &FZERO_OUTPUT_X_FVAL_EXITFLAG,
    },
    // Four outputs: adds the `output` metadata struct.
    BuiltinSignatureDescriptor {
        label: "[x, fval, exitflag, output] = fzero(fun, x0)",
        inputs: &FZERO_INPUTS_CORE,
        outputs: &FZERO_OUTPUT_ALL,
    },
    BuiltinSignatureDescriptor {
        label: "[x, fval, exitflag, output] = fzero(fun, x0, options)",
        inputs: &FZERO_INPUTS_WITH_OPTIONS,
        outputs: &FZERO_OUTPUT_ALL,
    },
];
202
/// Raised for bad call grammar: an extra input argument, a non-struct options value,
/// or an invalid option field. Also used as the fallback when option parsing fails
/// without an identifier.
const FZERO_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.FZERO.INVALID_ARGUMENT",
    identifier: Some("RunMat:fzero:InvalidArgument"),
    when: "Argument grammar/options struct are invalid.",
    message: "fzero: invalid argument",
};

/// Raised for a bad initial point or bracket, and used as the fallback for errors from
/// bracketing or the Brent solver that carry no identifier of their own.
const FZERO_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.FZERO.INVALID_INPUT",
    identifier: Some("RunMat:fzero:InvalidInput"),
    when: "Callback/bracket/initial-point semantics are invalid.",
    message: "fzero: invalid input",
};

/// Raised when the caller requests more than four outputs.
const FZERO_ERROR_TOO_MANY_OUTPUTS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.FZERO.TOO_MANY_OUTPUTS",
    identifier: Some("RunMat:fzero:TooManyOutputs"),
    when: "`fzero` is called with more than four requested output arguments.",
    message: "fzero: too many output arguments",
};

/// All error descriptors published in `FZERO_DESCRIPTOR`, in the order the descriptor
/// test asserts.
const FZERO_ERRORS: [BuiltinErrorDescriptor; 3] = [
    FZERO_ERROR_INVALID_ARGUMENT,
    FZERO_ERROR_INVALID_INPUT,
    FZERO_ERROR_TOO_MANY_OUTPUTS,
];
229
/// Public metadata for `fzero`: call signatures, error catalogue, and completion policy.
///
/// Outputs are selected by the requested output count. This constant is referenced by
/// path from the `runtime_builtin` attribute on `fzero_builtin`.
pub const FZERO_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &FZERO_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &FZERO_ERRORS,
};
236
237fn fzero_error_with_detail(
238 error: &'static BuiltinErrorDescriptor,
239 detail: impl AsRef<str>,
240) -> RuntimeError {
241 let detail = detail.as_ref();
242 let message = if detail.starts_with("fzero:") {
243 detail.to_string()
244 } else {
245 format!("{}: {detail}", error.message)
246 };
247 let mut builder = build_runtime_error(message).with_builtin(NAME);
248 if let Some(identifier) = error.identifier {
249 builder = builder.with_identifier(identifier);
250 }
251 builder.build()
252}
253
254fn fzero_map_error(err: RuntimeError, fallback: &'static BuiltinErrorDescriptor) -> RuntimeError {
255 if err.identifier().is_some() {
256 err
257 } else {
258 fzero_error_with_detail(fallback, err.message())
259 }
260}
261
262fn validate_requested_outputs() -> BuiltinResult<()> {
263 if matches!(crate::output_count::current_output_count(), Some(n) if n > 4) {
264 return Err(fzero_too_many_outputs_error());
265 }
266 Ok(())
267}
268
269fn fzero_too_many_outputs_error() -> RuntimeError {
270 fzero_error_with_detail(
271 &FZERO_ERROR_TOO_MANY_OUTPUTS,
272 "fzero: too many output arguments; maximum is 4",
273 )
274}
275
/// GPU planning metadata for `fzero`.
///
/// No provider hooks or precisions are declared, and inputs are gathered to the host
/// immediately (`ResidencyPolicy::GatherImmediately`). As the notes state, callbacks may
/// still use GPU-aware builtins, but the iterative root search runs on the CPU.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::optim::fzero")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "fzero",
    op_kind: GpuOpKind::Custom("scalar-root-find"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Host iterative solver. Callback values may use GPU-aware builtins, but the root search runs on the CPU.",
};
291
/// Fusion planning metadata for `fzero`.
///
/// No elementwise or reduction template is provided. Per the notes, root finding
/// repeatedly invokes user code and therefore terminates fusion planning.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::optim::fzero")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "fzero",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Root finding repeatedly invokes user code and terminates fusion planning.",
};
302
303#[runtime_builtin(
304 name = "fzero",
305 category = "math/optim",
306 summary = "Find scalar function zeros with bracketed root-finding.",
307 keywords = "fzero,root finding,zero,brent,optimization",
308 accel = "sink",
309 type_resolver(scalar_root_type),
310 descriptor(crate::builtins::math::optim::fzero::FZERO_DESCRIPTOR),
311 builtin_path = "crate::builtins::math::optim::fzero"
312)]
313async fn fzero_builtin(function: Value, x: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
314 if rest.len() > 1 {
315 return Err(fzero_error_with_detail(
316 &FZERO_ERROR_INVALID_ARGUMENT,
317 "too many input arguments",
318 ));
319 }
320 validate_requested_outputs()?;
321 let options = parse_options(rest.first())
322 .map_err(|err| fzero_map_error(err, &FZERO_ERROR_INVALID_ARGUMENT))?;
323 let opts = FzeroOptions::from_struct(options.as_ref())
324 .map_err(|err| fzero_map_error(err, &FZERO_ERROR_INVALID_ARGUMENT))?;
325 let bracket = initial_bracket(&function, x, &opts)
326 .await
327 .map_err(|err| fzero_map_error(err, &FZERO_ERROR_INVALID_INPUT))?;
328 let mut iter_log = IterDisplay::new(opts.display);
329 let observer: Option<&mut dyn BrentZeroObserver> = if matches!(opts.display, DisplayMode::Iter)
330 {
331 Some(&mut iter_log)
332 } else {
333 None
334 };
335 let result = brent_zero(
336 NAME,
337 &function,
338 BrentZeroBracket {
339 a: bracket.a,
340 b: bracket.b,
341 fa: bracket.fa,
342 fb: bracket.fb,
343 evals: bracket.evals,
344 },
345 BrentParams {
346 tol_x: opts.tol_x,
347 max_iter: opts.max_iter,
348 max_fun_evals: opts.max_fun_evals,
349 },
350 observer,
351 )
352 .await
353 .map_err(|err| fzero_map_error(err, &FZERO_ERROR_INVALID_INPUT))?;
354 finalize(result, &opts)
355}
356
357fn parse_options(value: Option<&Value>) -> BuiltinResult<Option<StructValue>> {
358 match value {
359 None => Ok(None),
360 Some(Value::Struct(options)) => Ok(Some(options.clone())),
361 Some(other) => Err(fzero_error_with_detail(
362 &FZERO_ERROR_INVALID_ARGUMENT,
363 format!("options must be a struct, got {other:?}"),
364 )),
365 }
366}
367
368#[derive(Clone, Copy)]
369struct FzeroOptions {
370 tol_x: f64,
371 max_iter: usize,
372 max_fun_evals: usize,
373 display: DisplayMode,
374}
375
376impl FzeroOptions {
377 fn from_struct(options: Option<&StructValue>) -> BuiltinResult<Self> {
378 let display = DisplayMode::parse(&option_string(options, "Display", "off")?)?;
379 let tol_x = option_f64(NAME, options, "TolX", DEFAULT_TOL_X)?;
380 if tol_x <= 0.0 {
381 return Err(fzero_error_with_detail(
382 &FZERO_ERROR_INVALID_ARGUMENT,
383 "option TolX must be positive",
384 ));
385 }
386 let max_iter = option_usize(NAME, options, "MaxIter", DEFAULT_MAX_ITER)?;
387 let max_fun_evals = option_usize(NAME, options, "MaxFunEvals", DEFAULT_MAX_FUN_EVALS)?;
388 Ok(Self {
389 tol_x,
390 max_iter: max_iter.max(1),
391 max_fun_evals: max_fun_evals.max(1),
392 display,
393 })
394 }
395}
396
/// Console output mode selected by the `Display` option.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DisplayMode {
    /// No console output (`'off'` or `'none'`).
    Off,
    /// A per-iteration table followed by the final summary line (`'iter'`).
    Iter,
    /// Only the final summary line (`'final'`).
    Final,
}
403
404impl DisplayMode {
405 fn parse(text: &str) -> BuiltinResult<Self> {
406 match text.to_ascii_lowercase().as_str() {
407 "off" | "none" => Ok(Self::Off),
408 "iter" => Ok(Self::Iter),
409 "final" => Ok(Self::Final),
410 other => Err(fzero_error_with_detail(
411 &FZERO_ERROR_INVALID_ARGUMENT,
412 format!("option Display must be 'off', 'none', 'final', or 'iter', got '{other}'"),
413 )),
414 }
415 }
416}
417
/// An interval `[a, b]` with its function values, ready to hand to the Brent solver.
///
/// `a == b` (with a zero function value) marks an exact root found during bracketing.
#[derive(Clone, Copy)]
struct Bracket {
    /// Left endpoint.
    a: f64,
    /// Right endpoint.
    b: f64,
    /// Function value at `a`.
    fa: f64,
    /// Function value at `b`.
    fb: f64,
    /// Callback evaluations already spent building this bracket.
    evals: usize,
}
426
427async fn initial_bracket(
428 function: &Value,
429 x: Value,
430 options: &FzeroOptions,
431) -> BuiltinResult<Bracket> {
432 let x = crate::dispatcher::gather_if_needed_async(&x).await?;
433 match x {
434 Value::Tensor(tensor) if tensor.data.len() == 2 => {
435 let a = tensor.data[0];
436 let b = tensor.data[1];
437 bracket_from_endpoints(function, a, b).await
438 }
439 Value::Tensor(tensor) if tensor.data.len() == 1 => {
440 expand_bracket(function, tensor.data[0], options).await
441 }
442 Value::Num(n) => expand_bracket(function, n, options).await,
443 Value::Int(i) => expand_bracket(function, i.to_f64(), options).await,
444 Value::Bool(b) => expand_bracket(function, if b { 1.0 } else { 0.0 }, options).await,
445 other => Err(fzero_error_with_detail(
446 &FZERO_ERROR_INVALID_INPUT,
447 format!("initial point must be a scalar or two-element bracket, got {other:?}"),
448 )),
449 }
450}
451
452async fn bracket_from_endpoints(function: &Value, a: f64, b: f64) -> BuiltinResult<Bracket> {
453 if !a.is_finite() || !b.is_finite() || a == b {
454 return Err(fzero_error_with_detail(
455 &FZERO_ERROR_INVALID_INPUT,
456 "bracket endpoints must be finite and distinct",
457 ));
458 }
459 let fa = call_scalar_function(NAME, function, a).await?;
460 if fa == 0.0 {
461 return Ok(Bracket {
462 a,
463 b: a,
464 fa,
465 fb: fa,
466 evals: 1,
467 });
468 }
469 let fb = call_scalar_function(NAME, function, b).await?;
470 if fb == 0.0 || fa.signum() != fb.signum() {
471 Ok(Bracket {
472 a,
473 b,
474 fa,
475 fb,
476 evals: 2,
477 })
478 } else {
479 Err(fzero_error_with_detail(
480 &FZERO_ERROR_INVALID_INPUT,
481 "function values at bracket endpoints must differ in sign",
482 ))
483 }
484}
485
486async fn expand_bracket(
487 function: &Value,
488 x0: f64,
489 options: &FzeroOptions,
490) -> BuiltinResult<Bracket> {
491 if !x0.is_finite() {
492 return Err(fzero_error_with_detail(
493 &FZERO_ERROR_INVALID_INPUT,
494 "initial point must be finite",
495 ));
496 }
497 let f0 = call_scalar_function(NAME, function, x0).await?;
498 if f0 == 0.0 {
499 return Ok(Bracket {
500 a: x0,
501 b: x0,
502 fa: f0,
503 fb: f0,
504 evals: 1,
505 });
506 }
507
508 let mut evals = 1usize;
509 let mut step = (x0.abs() * 0.01).max(0.01);
510 while evals + 2 <= options.max_fun_evals {
511 let a = x0 - step;
512 let b = x0 + step;
513 let fa = call_scalar_function(NAME, function, a).await?;
514 let fb = call_scalar_function(NAME, function, b).await?;
515 evals += 2;
516 if fa == 0.0 {
517 return Ok(Bracket {
518 a,
519 b: a,
520 fa,
521 fb: fa,
522 evals,
523 });
524 }
525 if fa.signum() != f0.signum() {
526 return Ok(Bracket {
527 a,
528 b: x0,
529 fa,
530 fb: f0,
531 evals,
532 });
533 }
534 if fb.signum() != f0.signum() {
535 return Ok(Bracket {
536 a: x0,
537 b,
538 fa: f0,
539 fb,
540 evals,
541 });
542 }
543 if fb == 0.0 || fa.signum() != fb.signum() {
544 return Ok(Bracket {
545 a,
546 b,
547 fa,
548 fb,
549 evals,
550 });
551 }
552 step *= 1.6;
553 }
554
555 Err(fzero_error_with_detail(
556 &FZERO_ERROR_INVALID_INPUT,
557 "could not find a sign-changing bracket around the initial point",
558 ))
559}
560
561fn finalize(result: BrentZeroResult, options: &FzeroOptions) -> BuiltinResult<Value> {
562 let exit_flag = if result.converged { 1 } else { 0 };
563 let message = build_message(&result);
564 emit_summary(&result, exit_flag, &message, options);
565
566 let x = Value::Num(result.x);
567 let fval = Value::Num(result.fval);
568 let exitflag = Value::Num(exit_flag as f64);
569 let output_struct = Value::Struct(build_output_struct(&result, &message));
570
571 match crate::output_count::current_output_count() {
572 None => Ok(x),
573 Some(0) => Ok(Value::OutputList(Vec::new())),
574 Some(1) => Ok(crate::output_count::output_list_with_padding(1, vec![x])),
575 Some(2) => Ok(crate::output_count::output_list_with_padding(
576 2,
577 vec![x, fval],
578 )),
579 Some(3) => Ok(crate::output_count::output_list_with_padding(
580 3,
581 vec![x, fval, exitflag],
582 )),
583 Some(4) => Ok(crate::output_count::output_list_with_padding(
584 4,
585 vec![x, fval, exitflag, output_struct],
586 )),
587 Some(_) => Err(fzero_too_many_outputs_error()),
588 }
589}
590
591fn build_output_struct(result: &BrentZeroResult, message: &str) -> StructValue {
592 let mut fields = StructValue::new();
593 fields.insert("iterations", Value::Num(result.iterations as f64));
594 fields.insert("funcCount", Value::Num(result.func_count as f64));
595 fields.insert("algorithm", Value::from(ALGORITHM));
596 fields.insert("message", Value::from(message.to_string()));
597 fields
598}
599
600fn build_message(result: &BrentZeroResult) -> String {
601 if result.converged {
602 format!(
603 "Zero found within OPTIONS.TolX. Iterations: {}, FuncCount: {}.",
604 result.iterations, result.func_count
605 )
606 } else {
607 format!(
608 "Exiting: Maximum number of function evaluations or iterations has been exceeded - increase MaxFunEvals or MaxIter. Iterations: {}, FuncCount: {}.",
609 result.iterations, result.func_count
610 )
611 }
612}
613
614fn emit_summary(result: &BrentZeroResult, exit_flag: i32, message: &str, options: &FzeroOptions) {
615 if !matches!(options.display, DisplayMode::Final | DisplayMode::Iter) {
616 return;
617 }
618 crate::console::record_console_line(
619 crate::console::ConsoleStream::Stdout,
620 format!(
621 "fzero: x = {x:.6}, fval = {fval:.6}, exitflag = {exit_flag}. {message}",
622 x = result.x,
623 fval = result.fval,
624 ),
625 );
626}
627
628struct IterDisplay {
629 mode: DisplayMode,
630 printed_header: bool,
631}
632
633impl IterDisplay {
634 fn new(mode: DisplayMode) -> Self {
635 Self {
636 mode,
637 printed_header: false,
638 }
639 }
640}
641
642impl BrentZeroObserver for IterDisplay {
643 fn on_iteration(
644 &mut self,
645 iter: usize,
646 func_count: usize,
647 x: f64,
648 fx: f64,
649 step_kind: BrentZeroStepKind,
650 ) {
651 if !matches!(self.mode, DisplayMode::Iter) {
652 return;
653 }
654 if !self.printed_header {
655 crate::console::record_console_line(
656 crate::console::ConsoleStream::Stdout,
657 " Func-count x f(x) Procedure",
658 );
659 self.printed_header = true;
660 }
661 let procedure = match step_kind {
662 BrentZeroStepKind::Initial => "initial",
663 BrentZeroStepKind::Bisection => "bisection",
664 BrentZeroStepKind::Interpolation => "interpolation",
665 };
666 let line =
667 format!(" {func_count:>5} {x:13.6e} {fx:13.6e} {procedure} (iter {iter})");
668 crate::console::record_console_line(crate::console::ConsoleStream::Stdout, line);
669 }
670}
671
// Unit tests for the `fzero` builtin. They cover:
// - explicit and expanded brackets;
// - semantic (bound) function-handle callbacks;
// - multi-output packaging;
// - option-driven budgets and iterative display;
// - the static descriptor tables and error identifiers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::math::optim::brent::interpolation_step_accepted;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;
    use std::sync::Arc;

    // An explicit bracket [3, 4] around pi should converge to pi for `sin`.
    #[test]
    fn fzero_bracketed_builtin_handle() {
        let bracket = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let root = block_on(fzero_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Tensor(bracket),
            Vec::new(),
        ))
        .unwrap();
        match root {
            Value::Num(n) => assert!((n - std::f64::consts::PI).abs() < 1.0e-6),
            other => panic!("unexpected value {other:?}"),
        }
    }

    // A scalar start point triggers the outward bracket search; `cos` from 1.0 should
    // find pi/2.
    #[test]
    fn fzero_scalar_initial_guess_expands_bracket() {
        let root = block_on(fzero_builtin(
            Value::FunctionHandle("cos".into()),
            Value::Num(1.0),
            Vec::new(),
        ))
        .unwrap();
        match root {
            Value::Num(n) => assert!((n - std::f64::consts::FRAC_PI_2).abs() < 1.0e-6),
            other => panic!("unexpected value {other:?}"),
        }
    }

    // Starting `sin` at pi/2, both probes cross a root (0 and pi) in the same round. The
    // left probe is checked first, so the expected root is 0.
    #[test]
    fn fzero_scalar_initial_guess_uses_center_sign_for_bracket() {
        let root = block_on(fzero_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Num(std::f64::consts::FRAC_PI_2),
            Vec::new(),
        ))
        .unwrap();
        match root {
            Value::Num(n) => assert!(n.abs() < 1.0e-6),
            other => panic!("unexpected value {other:?}"),
        }
    }

    // A bound function handle is routed through the installed semantic invoker. The stub
    // checks the handle id and requested output count, then evaluates x - 2.
    #[test]
    fn fzero_accepts_semantic_function_handle_callback() {
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            |function, args, requested_outputs| {
                assert_eq!(function, 42);
                assert_eq!(requested_outputs, 1);
                let x = match &args[0] {
                    Value::Num(value) => *value,
                    other => panic!("expected scalar numeric argument, got {other:?}"),
                };
                Box::pin(async move { Ok(Value::Num(x - 2.0)) })
            },
        )));

        let root = block_on(fzero_builtin(
            Value::BoundFunctionHandle {
                name: "root_function".to_string(),
                function: 42,
            },
            Value::Num(0.0),
            Vec::new(),
        ))
        .unwrap();
        match root {
            Value::Num(n) => assert!((n - 2.0).abs() < 1.0e-6),
            other => panic!("unexpected value {other:?}"),
        }
    }

    // With two outputs requested, the result is an output list of [x, fval].
    #[test]
    fn fzero_multi_output_two_returns_root_and_fval() {
        let _guard = crate::output_count::push_output_count(Some(2));
        let bracket = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let result = block_on(fzero_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Tensor(bracket),
            Vec::new(),
        ))
        .expect("fzero");
        match result {
            Value::OutputList(outputs) => {
                assert_eq!(outputs.len(), 2);
                match (&outputs[0], &outputs[1]) {
                    (Value::Num(x), Value::Num(fval)) => {
                        assert!((x - std::f64::consts::PI).abs() < 1.0e-6);
                        assert!(fval.abs() < 1.0e-6);
                    }
                    other => panic!("unexpected outputs {other:?}"),
                }
            }
            other => panic!("unexpected value {other:?}"),
        }
    }

    // With four outputs requested:
    // - exitflag is 1 for a converged run;
    // - the `output` struct carries iterations, funcCount, algorithm and message.
    #[test]
    fn fzero_multi_output_four_includes_output_struct() {
        let _guard = crate::output_count::push_output_count(Some(4));
        let bracket = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let result = block_on(fzero_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Tensor(bracket),
            Vec::new(),
        ))
        .expect("fzero");
        match result {
            Value::OutputList(outputs) => {
                assert_eq!(outputs.len(), 4);
                assert!(matches!(&outputs[2], Value::Num(flag) if *flag == 1.0));
                match &outputs[3] {
                    Value::Struct(output) => {
                        assert!(matches!(
                            output.fields.get("iterations"),
                            Some(Value::Num(_))
                        ));
                        assert!(matches!(
                            output.fields.get("funcCount"),
                            Some(Value::Num(_))
                        ));
                        match output.fields.get("algorithm") {
                            Some(Value::String(text)) => assert!(text.contains("bisection")),
                            other => panic!("unexpected algorithm field {other:?}"),
                        }
                        match output.fields.get("message") {
                            Some(Value::String(text)) => assert!(text.contains("Zero found")),
                            other => panic!("unexpected message field {other:?}"),
                        }
                    }
                    other => panic!("unexpected output struct {other:?}"),
                }
            }
            other => panic!("unexpected value {other:?}"),
        }
    }

    // MaxFunEvals = 2 equals the evaluations spent on the two bracket endpoints. The
    // solver is expected to report non-convergence (exitflag 0), presumably because no
    // evaluations are left.
    #[test]
    fn fzero_reports_zero_exitflag_when_iteration_budget_exhausted() {
        let mut opts = StructValue::new();
        opts.insert("MaxIter", Value::Num(1.0));
        opts.insert("MaxFunEvals", Value::Num(2.0));
        opts.insert("Display", Value::from("off"));
        let _guard = crate::output_count::push_output_count(Some(3));
        let bracket = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let result = block_on(fzero_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Tensor(bracket),
            vec![Value::Struct(opts)],
        ))
        .expect("fzero");
        match result {
            Value::OutputList(outputs) => match &outputs[2] {
                Value::Num(flag) => assert_eq!(*flag, 0.0),
                other => panic!("unexpected exitflag {other:?}"),
            },
            other => panic!("unexpected value {other:?}"),
        }
    }

    // With MaxIter = 1 on the symmetric bracket [-1, 1], the single step is expected to
    // land on the root 0 exactly. That run must still be reported as converged.
    #[test]
    fn fzero_reports_convergence_when_final_step_hits_root() {
        let mut opts = StructValue::new();
        opts.insert("MaxIter", Value::Num(1.0));
        opts.insert("Display", Value::from("off"));
        let _guard = crate::output_count::push_output_count(Some(3));
        let bracket = Tensor::new(vec![-1.0, 1.0], vec![1, 2]).unwrap();
        let result = block_on(fzero_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Tensor(bracket),
            vec![Value::Struct(opts)],
        ))
        .expect("fzero");
        match result {
            Value::OutputList(outputs) => {
                assert!(matches!(&outputs[0], Value::Num(x) if x.abs() < 1.0e-12));
                assert!(matches!(&outputs[1], Value::Num(fval) if fval.abs() < 1.0e-12));
                assert!(matches!(&outputs[2], Value::Num(flag) if *flag == 1.0));
            }
            other => panic!("unexpected value {other:?}"),
        }
    }

    // Requesting five outputs fails with the TooManyOutputs identifier and the
    // "maximum is 4" detail.
    #[test]
    fn fzero_rejects_more_than_four_outputs() {
        let _guard = crate::output_count::push_output_count(Some(5));
        let bracket = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let err = block_on(fzero_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Tensor(bracket),
            Vec::new(),
        ))
        .expect_err("too many outputs should fail");
        assert_eq!(err.identifier(), Some("RunMat:fzero:TooManyOutputs"));
        assert!(err.message().contains("maximum is 4"));
    }

    // With Display = 'iter', the console buffer should contain:
    // - the table header;
    // - an initial row and at least one solver row;
    // - the final summary line.
    #[test]
    fn fzero_iter_display_records_iteration_rows() {
        crate::console::reset_thread_buffer();
        let mut opts = StructValue::new();
        opts.insert("Display", Value::from("iter"));
        let bracket = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let result = block_on(fzero_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Tensor(bracket),
            vec![Value::Struct(opts)],
        ))
        .expect("fzero");
        assert!(matches!(result, Value::Num(_)));

        let joined = crate::console::take_thread_buffer()
            .into_iter()
            .map(|entry| entry.text)
            .collect::<String>();
        assert!(joined.contains("Func-count"), "{joined}");
        assert!(joined.contains("initial"), "{joined}");
        assert!(
            joined.contains("interpolation") || joined.contains("bisection"),
            "{joined}"
        );
        assert!(joined.contains("exitflag = 1"), "{joined}");
    }

    // Regression check on the shared Brent helper. The two calls differ only in the sign
    // of the third argument (the signed `q`, per the test name), which flips acceptance.
    #[test]
    fn brent_interpolation_acceptance_uses_signed_q() {
        assert!(!interpolation_step_accepted(1.0, -2.0, 1.0, 0.1, 10.0));
        assert!(interpolation_step_accepted(1.0, -2.0, -1.0, 0.1, 10.0));
    }

    // Pins the published signature labels and error codes, including their order.
    #[test]
    fn fzero_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = FZERO_DESCRIPTOR
            .signatures
            .iter()
            .map(|signature| signature.label)
            .collect();
        assert_eq!(
            labels,
            vec![
                "x = fzero(fun, x0)",
                "x = fzero(fun, x0, options)",
                "[x, fval] = fzero(fun, x0)",
                "[x, fval] = fzero(fun, x0, options)",
                "[x, fval, exitflag] = fzero(fun, x0)",
                "[x, fval, exitflag] = fzero(fun, x0, options)",
                "[x, fval, exitflag, output] = fzero(fun, x0)",
                "[x, fval, exitflag, output] = fzero(fun, x0, options)",
            ]
        );

        let codes: Vec<&str> = FZERO_DESCRIPTOR
            .errors
            .iter()
            .map(|error| error.code)
            .collect();
        assert_eq!(
            codes,
            vec![
                "RM.FZERO.INVALID_ARGUMENT",
                "RM.FZERO.INVALID_INPUT",
                "RM.FZERO.TOO_MANY_OUTPUTS",
            ]
        );
    }

    // Two trailing arguments exceed the call grammar. The call is rejected with the
    // InvalidArgument identifier before any evaluation happens.
    #[test]
    fn fzero_too_many_args_uses_stable_identifier() {
        let err = block_on(fzero_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Num(0.0),
            vec![
                Value::Struct(StructValue::new()),
                Value::Struct(StructValue::new()),
            ],
        ))
        .unwrap_err();
        assert_eq!(err.identifier(), Some("RunMat:fzero:InvalidArgument"));
    }
}
958}