1use runmat_builtins::{
4 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6};
7use runmat_builtins::{StructValue, Value};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::spec::{
11 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
12 ReductionNaN, ResidencyPolicy, ShapeRequirements,
13};
14use crate::builtins::math::optim::brent::{brent_zero, BrentParams, BrentZeroBracket};
15use crate::builtins::math::optim::common::{
16 call_scalar_function, option_f64, option_string, option_usize,
17};
18use crate::builtins::math::optim::type_resolvers::scalar_root_type;
19use crate::{build_runtime_error, BuiltinResult, RuntimeError};
20
/// Builtin name handed to the shared optimisation helpers and attached to every error.
const NAME: &str = "fzero";
/// Default `TolX`: termination tolerance on the root location.
const DEFAULT_TOL_X: f64 = 1e-6;
/// Default `MaxIter`: iteration budget for Brent's method.
const DEFAULT_MAX_ITER: usize = 400;
/// Default `MaxFunEvals`: callback-evaluation budget.
const DEFAULT_MAX_FUN_EVALS: usize = 500;
25
26const FZERO_OUTPUT_ROOT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
27 name: "x",
28 ty: BuiltinParamType::NumericScalar,
29 arity: BuiltinParamArity::Required,
30 default: None,
31 description: "Estimated root location.",
32}];
33
34const FZERO_INPUTS_CORE: [BuiltinParamDescriptor; 2] = [
35 BuiltinParamDescriptor {
36 name: "fun",
37 ty: BuiltinParamType::Any,
38 arity: BuiltinParamArity::Required,
39 default: None,
40 description: "Scalar-valued callback.",
41 },
42 BuiltinParamDescriptor {
43 name: "x0",
44 ty: BuiltinParamType::Any,
45 arity: BuiltinParamArity::Required,
46 default: None,
47 description: "Initial point or two-element bracket.",
48 },
49];
50
51const FZERO_INPUTS_WITH_OPTIONS: [BuiltinParamDescriptor; 3] = [
52 BuiltinParamDescriptor {
53 name: "fun",
54 ty: BuiltinParamType::Any,
55 arity: BuiltinParamArity::Required,
56 default: None,
57 description: "Scalar-valued callback.",
58 },
59 BuiltinParamDescriptor {
60 name: "x0",
61 ty: BuiltinParamType::Any,
62 arity: BuiltinParamArity::Required,
63 default: None,
64 description: "Initial point or two-element bracket.",
65 },
66 BuiltinParamDescriptor {
67 name: "options",
68 ty: BuiltinParamType::Any,
69 arity: BuiltinParamArity::Optional,
70 default: None,
71 description: "Options struct from optimset.",
72 },
73];
74
75const FZERO_SIGNATURES: [BuiltinSignatureDescriptor; 2] = [
76 BuiltinSignatureDescriptor {
77 label: "x = fzero(fun, x0)",
78 inputs: &FZERO_INPUTS_CORE,
79 outputs: &FZERO_OUTPUT_ROOT,
80 },
81 BuiltinSignatureDescriptor {
82 label: "x = fzero(fun, x0, options)",
83 inputs: &FZERO_INPUTS_WITH_OPTIONS,
84 outputs: &FZERO_OUTPUT_ROOT,
85 },
86];
87
88const FZERO_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
89 code: "RM.FZERO.INVALID_ARGUMENT",
90 identifier: Some("RunMat:fzero:InvalidArgument"),
91 when: "Argument grammar/options struct are invalid.",
92 message: "fzero: invalid argument",
93};
94
95const FZERO_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
96 code: "RM.FZERO.INVALID_INPUT",
97 identifier: Some("RunMat:fzero:InvalidInput"),
98 when: "Callback/bracket/initial-point semantics are invalid.",
99 message: "fzero: invalid input",
100};
101
102const FZERO_ERRORS: [BuiltinErrorDescriptor; 2] =
103 [FZERO_ERROR_INVALID_ARGUMENT, FZERO_ERROR_INVALID_INPUT];
104
105pub const FZERO_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
106 signatures: &FZERO_SIGNATURES,
107 output_mode: BuiltinOutputMode::Fixed,
108 completion_policy: BuiltinCompletionPolicy::Public,
109 errors: &FZERO_ERRORS,
110};
111
112fn fzero_error_with_detail(
113 error: &'static BuiltinErrorDescriptor,
114 detail: impl AsRef<str>,
115) -> RuntimeError {
116 let detail = detail.as_ref();
117 let message = if detail.starts_with("fzero:") {
118 detail.to_string()
119 } else {
120 format!("{}: {detail}", error.message)
121 };
122 let mut builder = build_runtime_error(message).with_builtin(NAME);
123 if let Some(identifier) = error.identifier {
124 builder = builder.with_identifier(identifier);
125 }
126 builder.build()
127}
128
129fn fzero_map_error(err: RuntimeError, fallback: &'static BuiltinErrorDescriptor) -> RuntimeError {
130 if err.identifier().is_some() {
131 err
132 } else {
133 fzero_error_with_detail(fallback, err.message())
134 }
135}
136
// GPU capability descriptor for `fzero`, registered through `register_gpu_spec`.
//
// No precisions or provider hooks are declared, and the op kind is the custom
// "scalar-root-find" tag. The residency policy is `GatherImmediately`, which is
// consistent with the solver itself: `initial_bracket` gathers the starting
// point and the bracket search / Brent iteration run on the host.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::optim::fzero")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "fzero",
    op_kind: GpuOpKind::Custom("scalar-root-find"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Host iterative solver. Callback values may use GPU-aware builtins, but the root search runs on the CPU.",
};
152
// Fusion descriptor for `fzero`, registered through `register_fusion_spec`.
//
// Neither an elementwise nor a reduction template is provided. The notes
// record that root finding repeatedly invokes user code and terminates fusion
// planning.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::optim::fzero")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "fzero",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Root finding repeatedly invokes user code and terminates fusion planning.",
};
163
164#[runtime_builtin(
165 name = "fzero",
166 category = "math/optim",
167 summary = "Find scalar function zeros with bracketed root-finding.",
168 keywords = "fzero,root finding,zero,brent,optimization",
169 accel = "sink",
170 type_resolver(scalar_root_type),
171 descriptor(crate::builtins::math::optim::fzero::FZERO_DESCRIPTOR),
172 builtin_path = "crate::builtins::math::optim::fzero"
173)]
174async fn fzero_builtin(function: Value, x: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
175 if rest.len() > 1 {
176 return Err(fzero_error_with_detail(
177 &FZERO_ERROR_INVALID_ARGUMENT,
178 "too many input arguments",
179 ));
180 }
181 let options = parse_options(rest.first())
182 .map_err(|err| fzero_map_error(err, &FZERO_ERROR_INVALID_ARGUMENT))?;
183 let opts = FzeroOptions::from_struct(options.as_ref())
184 .map_err(|err| fzero_map_error(err, &FZERO_ERROR_INVALID_ARGUMENT))?;
185 let bracket = initial_bracket(&function, x, &opts)
186 .await
187 .map_err(|err| fzero_map_error(err, &FZERO_ERROR_INVALID_INPUT))?;
188 let root = brent_zero(
189 NAME,
190 &function,
191 BrentZeroBracket {
192 a: bracket.a,
193 b: bracket.b,
194 fa: bracket.fa,
195 fb: bracket.fb,
196 evals: bracket.evals,
197 },
198 BrentParams {
199 tol_x: opts.tol_x,
200 max_iter: opts.max_iter,
201 max_fun_evals: opts.max_fun_evals,
202 },
203 )
204 .await
205 .map_err(|err| fzero_map_error(err, &FZERO_ERROR_INVALID_INPUT))?;
206 Ok(Value::Num(root))
207}
208
209fn parse_options(value: Option<&Value>) -> BuiltinResult<Option<StructValue>> {
210 match value {
211 None => Ok(None),
212 Some(Value::Struct(options)) => Ok(Some(options.clone())),
213 Some(other) => Err(fzero_error_with_detail(
214 &FZERO_ERROR_INVALID_ARGUMENT,
215 format!("options must be a struct, got {other:?}"),
216 )),
217 }
218}
219
220#[derive(Clone, Copy)]
221struct FzeroOptions {
222 tol_x: f64,
223 max_iter: usize,
224 max_fun_evals: usize,
225}
226
227impl FzeroOptions {
228 fn from_struct(options: Option<&StructValue>) -> BuiltinResult<Self> {
229 let display = option_string(options, "Display", "off")?;
230 if !matches!(display.as_str(), "off" | "none" | "final" | "iter") {
231 return Err(fzero_error_with_detail(
232 &FZERO_ERROR_INVALID_ARGUMENT,
233 "option Display must be 'off', 'none', 'final', or 'iter'",
234 ));
235 }
236 let tol_x = option_f64(NAME, options, "TolX", DEFAULT_TOL_X)?;
237 if tol_x <= 0.0 {
238 return Err(fzero_error_with_detail(
239 &FZERO_ERROR_INVALID_ARGUMENT,
240 "option TolX must be positive",
241 ));
242 }
243 let max_iter = option_usize(NAME, options, "MaxIter", DEFAULT_MAX_ITER)?;
244 let max_fun_evals = option_usize(NAME, options, "MaxFunEvals", DEFAULT_MAX_FUN_EVALS)?;
245 Ok(Self {
246 tol_x,
247 max_iter: max_iter.max(1),
248 max_fun_evals: max_fun_evals.max(1),
249 })
250 }
251}
252
/// A starting interval for Brent's method plus the callback values at its ends.
///
/// When a callback value is found to be exactly zero, the interval collapses
/// onto that point (`a == b`, `fa == fb == 0`). The endpoints are not
/// guaranteed to be ordered: a user-supplied bracket keeps the caller's order.
#[derive(Copy, Clone)]
struct Bracket {
    /// First endpoint and the callback value there.
    a: f64,
    fa: f64,
    /// Second endpoint and the callback value there.
    b: f64,
    fb: f64,
    /// Callback evaluations spent while building the bracket.
    evals: usize,
}
261
262async fn initial_bracket(
263 function: &Value,
264 x: Value,
265 options: &FzeroOptions,
266) -> BuiltinResult<Bracket> {
267 let x = crate::dispatcher::gather_if_needed_async(&x).await?;
268 match x {
269 Value::Tensor(tensor) if tensor.data.len() == 2 => {
270 let a = tensor.data[0];
271 let b = tensor.data[1];
272 bracket_from_endpoints(function, a, b).await
273 }
274 Value::Tensor(tensor) if tensor.data.len() == 1 => {
275 expand_bracket(function, tensor.data[0], options).await
276 }
277 Value::Num(n) => expand_bracket(function, n, options).await,
278 Value::Int(i) => expand_bracket(function, i.to_f64(), options).await,
279 Value::Bool(b) => expand_bracket(function, if b { 1.0 } else { 0.0 }, options).await,
280 other => Err(fzero_error_with_detail(
281 &FZERO_ERROR_INVALID_INPUT,
282 format!("initial point must be a scalar or two-element bracket, got {other:?}"),
283 )),
284 }
285}
286
287async fn bracket_from_endpoints(function: &Value, a: f64, b: f64) -> BuiltinResult<Bracket> {
288 if !a.is_finite() || !b.is_finite() || a == b {
289 return Err(fzero_error_with_detail(
290 &FZERO_ERROR_INVALID_INPUT,
291 "bracket endpoints must be finite and distinct",
292 ));
293 }
294 let fa = call_scalar_function(NAME, function, a).await?;
295 if fa == 0.0 {
296 return Ok(Bracket {
297 a,
298 b: a,
299 fa,
300 fb: fa,
301 evals: 1,
302 });
303 }
304 let fb = call_scalar_function(NAME, function, b).await?;
305 if fb == 0.0 || fa.signum() != fb.signum() {
306 Ok(Bracket {
307 a,
308 b,
309 fa,
310 fb,
311 evals: 2,
312 })
313 } else {
314 Err(fzero_error_with_detail(
315 &FZERO_ERROR_INVALID_INPUT,
316 "function values at bracket endpoints must differ in sign",
317 ))
318 }
319}
320
321async fn expand_bracket(
322 function: &Value,
323 x0: f64,
324 options: &FzeroOptions,
325) -> BuiltinResult<Bracket> {
326 if !x0.is_finite() {
327 return Err(fzero_error_with_detail(
328 &FZERO_ERROR_INVALID_INPUT,
329 "initial point must be finite",
330 ));
331 }
332 let f0 = call_scalar_function(NAME, function, x0).await?;
333 if f0 == 0.0 {
334 return Ok(Bracket {
335 a: x0,
336 b: x0,
337 fa: f0,
338 fb: f0,
339 evals: 1,
340 });
341 }
342
343 let mut evals = 1usize;
344 let mut step = (x0.abs() * 0.01).max(0.01);
345 while evals + 2 <= options.max_fun_evals {
346 let a = x0 - step;
347 let b = x0 + step;
348 let fa = call_scalar_function(NAME, function, a).await?;
349 let fb = call_scalar_function(NAME, function, b).await?;
350 evals += 2;
351 if fa == 0.0 {
352 return Ok(Bracket {
353 a,
354 b: a,
355 fa,
356 fb: fa,
357 evals,
358 });
359 }
360 if fa.signum() != f0.signum() {
361 return Ok(Bracket {
362 a,
363 b: x0,
364 fa,
365 fb: f0,
366 evals,
367 });
368 }
369 if fb.signum() != f0.signum() {
370 return Ok(Bracket {
371 a: x0,
372 b,
373 fa: f0,
374 fb,
375 evals,
376 });
377 }
378 if fb == 0.0 || fa.signum() != fb.signum() {
379 return Ok(Bracket {
380 a,
381 b,
382 fa,
383 fb,
384 evals,
385 });
386 }
387 step *= 1.6;
388 }
389
390 Err(fzero_error_with_detail(
391 &FZERO_ERROR_INVALID_INPUT,
392 "could not find a sign-changing bracket around the initial point",
393 ))
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::math::optim::brent::interpolation_step_accepted;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;
    use std::sync::Arc;

    /// Drives `fzero_builtin` to completion on the current thread.
    fn run_fzero(function: Value, x: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(fzero_builtin(function, x, rest))
    }

    /// Unwraps a successful result that must be a scalar number.
    fn scalar(result: BuiltinResult<Value>) -> f64 {
        match result.unwrap() {
            Value::Num(n) => n,
            other => panic!("unexpected value {other:?}"),
        }
    }

    #[test]
    fn fzero_bracketed_builtin_handle() {
        let bracket = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let root = scalar(run_fzero(
            Value::FunctionHandle("sin".into()),
            Value::Tensor(bracket),
            Vec::new(),
        ));
        assert!((root - std::f64::consts::PI).abs() < 1.0e-6);
    }

    #[test]
    fn fzero_scalar_initial_guess_expands_bracket() {
        let root = scalar(run_fzero(
            Value::FunctionHandle("cos".into()),
            Value::Num(1.0),
            Vec::new(),
        ));
        assert!((root - std::f64::consts::FRAC_PI_2).abs() < 1.0e-6);
    }

    #[test]
    fn fzero_scalar_initial_guess_uses_center_sign_for_bracket() {
        let root = scalar(run_fzero(
            Value::FunctionHandle("sin".into()),
            Value::Num(std::f64::consts::FRAC_PI_2),
            Vec::new(),
        ));
        assert!(root.abs() < 1.0e-6);
    }

    #[test]
    fn fzero_accepts_semantic_function_handle_callback() {
        // The guard must stay alive for the duration of the solve.
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            |function, args, requested_outputs| {
                assert_eq!(function, 42);
                assert_eq!(requested_outputs, 1);
                let x = match &args[0] {
                    Value::Num(value) => *value,
                    other => panic!("expected scalar numeric argument, got {other:?}"),
                };
                Box::pin(async move { Ok(Value::Num(x - 2.0)) })
            },
        )));

        let handle = Value::BoundFunctionHandle {
            name: "root_function".to_string(),
            function: 42,
        };
        let root = scalar(run_fzero(handle, Value::Num(0.0), Vec::new()));
        assert!((root - 2.0).abs() < 1.0e-6);
    }

    #[test]
    fn brent_interpolation_acceptance_uses_signed_q() {
        assert!(!interpolation_step_accepted(1.0, -2.0, 1.0, 0.1, 10.0));
        assert!(interpolation_step_accepted(1.0, -2.0, -1.0, 0.1, 10.0));
    }

    #[test]
    fn fzero_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = FZERO_DESCRIPTOR
            .signatures
            .iter()
            .map(|signature| signature.label)
            .collect();
        assert_eq!(
            labels,
            ["x = fzero(fun, x0)", "x = fzero(fun, x0, options)"]
        );

        let codes: Vec<&str> = FZERO_DESCRIPTOR
            .errors
            .iter()
            .map(|error| error.code)
            .collect();
        assert_eq!(codes, ["RM.FZERO.INVALID_ARGUMENT", "RM.FZERO.INVALID_INPUT"]);
    }

    #[test]
    fn fzero_too_many_args_uses_stable_identifier() {
        let extra = vec![
            Value::Struct(StructValue::new()),
            Value::Struct(StructValue::new()),
        ];
        let err = run_fzero(Value::FunctionHandle("sin".into()), Value::Num(0.0), extra)
            .unwrap_err();
        assert_eq!(err.identifier(), Some("RunMat:fzero:InvalidArgument"));
    }
}