1use runmat_builtins::{StructValue, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::spec::{
7 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8 ReductionNaN, ResidencyPolicy, ShapeRequirements,
9};
10use crate::builtins::math::optim::common::{
11 call_scalar_function, optim_error, option_f64, option_string, option_usize,
12};
13use crate::builtins::math::optim::type_resolvers::scalar_root_type;
14use crate::BuiltinResult;
15
16const NAME: &str = "fzero";
17const DEFAULT_TOL_X: f64 = 1.0e-6;
18const DEFAULT_MAX_ITER: usize = 400;
19const DEFAULT_MAX_FUN_EVALS: usize = 500;
20
21#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::optim::fzero")]
22pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
23 name: "fzero",
24 op_kind: GpuOpKind::Custom("scalar-root-find"),
25 supported_precisions: &[],
26 broadcast: BroadcastSemantics::None,
27 provider_hooks: &[],
28 constant_strategy: ConstantStrategy::InlineLiteral,
29 residency: ResidencyPolicy::GatherImmediately,
30 nan_mode: ReductionNaN::Include,
31 two_pass_threshold: None,
32 workgroup_size: None,
33 accepts_nan_mode: false,
34 notes: "Host iterative solver. Callback values may use GPU-aware builtins, but the root search runs on the CPU.",
35};
36
37#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::optim::fzero")]
38pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
39 name: "fzero",
40 shape: ShapeRequirements::Any,
41 constant_strategy: ConstantStrategy::InlineLiteral,
42 elementwise: None,
43 reduction: None,
44 emits_nan: false,
45 notes: "Root finding repeatedly invokes user code and terminates fusion planning.",
46};
47
48#[runtime_builtin(
49 name = "fzero",
50 category = "math/optim",
51 summary = "Find a zero of a scalar nonlinear function using bracket expansion and Brent's method.",
52 keywords = "fzero,root finding,zero,brent,optimization",
53 accel = "sink",
54 type_resolver(scalar_root_type),
55 builtin_path = "crate::builtins::math::optim::fzero"
56)]
57async fn fzero_builtin(function: Value, x: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
58 if rest.len() > 1 {
59 return Err(optim_error(NAME, "fzero: too many input arguments"));
60 }
61 let options = parse_options(rest.first())?;
62 let opts = FzeroOptions::from_struct(options.as_ref())?;
63 let bracket = initial_bracket(&function, x, &opts).await?;
64 let root = brent(&function, bracket, &opts).await?;
65 Ok(Value::Num(root))
66}
67
68fn parse_options(value: Option<&Value>) -> BuiltinResult<Option<StructValue>> {
69 match value {
70 None => Ok(None),
71 Some(Value::Struct(options)) => Ok(Some(options.clone())),
72 Some(other) => Err(optim_error(
73 NAME,
74 format!("fzero: options must be a struct, got {other:?}"),
75 )),
76 }
77}
78
/// Validated solver settings derived from the user's options struct.
#[derive(Clone, Copy)]
struct FzeroOptions {
    /// Upper bound on Brent iterations (never below 1).
    max_iter: usize,
    /// Budget of function evaluations, shared by bracketing and Brent (never below 1).
    max_fun_evals: usize,
    /// Absolute tolerance on `x` used in Brent's convergence test (positive).
    tol_x: f64,
}
85
86impl FzeroOptions {
87 fn from_struct(options: Option<&StructValue>) -> BuiltinResult<Self> {
88 let display = option_string(options, "Display", "off")?;
89 if !matches!(display.as_str(), "off" | "none" | "final" | "iter") {
90 return Err(optim_error(
91 NAME,
92 "fzero: option Display must be 'off', 'none', 'final', or 'iter'",
93 ));
94 }
95 let tol_x = option_f64(NAME, options, "TolX", DEFAULT_TOL_X)?;
96 if tol_x <= 0.0 {
97 return Err(optim_error(NAME, "fzero: option TolX must be positive"));
98 }
99 let max_iter = option_usize(NAME, options, "MaxIter", DEFAULT_MAX_ITER)?;
100 let max_fun_evals = option_usize(NAME, options, "MaxFunEvals", DEFAULT_MAX_FUN_EVALS)?;
101 Ok(Self {
102 tol_x,
103 max_iter: max_iter.max(1),
104 max_fun_evals: max_fun_evals.max(1),
105 })
106 }
107}
108
/// An interval handed to Brent's method, plus the evaluations spent producing it.
///
/// When `a == b` the bracket is degenerate: the function was exactly zero at `a`.
#[derive(Clone, Copy)]
struct Bracket {
    /// First endpoint.
    a: f64,
    /// Function value at `a`.
    fa: f64,
    /// Second endpoint.
    b: f64,
    /// Function value at `b`.
    fb: f64,
    /// Number of function evaluations already consumed.
    evals: usize,
}
117
118async fn initial_bracket(
119 function: &Value,
120 x: Value,
121 options: &FzeroOptions,
122) -> BuiltinResult<Bracket> {
123 let x = crate::dispatcher::gather_if_needed_async(&x).await?;
124 match x {
125 Value::Tensor(tensor) if tensor.data.len() == 2 => {
126 let a = tensor.data[0];
127 let b = tensor.data[1];
128 bracket_from_endpoints(function, a, b).await
129 }
130 Value::Tensor(tensor) if tensor.data.len() == 1 => {
131 expand_bracket(function, tensor.data[0], options).await
132 }
133 Value::Num(n) => expand_bracket(function, n, options).await,
134 Value::Int(i) => expand_bracket(function, i.to_f64(), options).await,
135 Value::Bool(b) => expand_bracket(function, if b { 1.0 } else { 0.0 }, options).await,
136 other => Err(optim_error(
137 NAME,
138 format!("fzero: initial point must be a scalar or two-element bracket, got {other:?}"),
139 )),
140 }
141}
142
143async fn bracket_from_endpoints(function: &Value, a: f64, b: f64) -> BuiltinResult<Bracket> {
144 if !a.is_finite() || !b.is_finite() || a == b {
145 return Err(optim_error(
146 NAME,
147 "fzero: bracket endpoints must be finite and distinct",
148 ));
149 }
150 let fa = call_scalar_function(NAME, function, a).await?;
151 if fa == 0.0 {
152 return Ok(Bracket {
153 a,
154 b: a,
155 fa,
156 fb: fa,
157 evals: 1,
158 });
159 }
160 let fb = call_scalar_function(NAME, function, b).await?;
161 if fb == 0.0 || fa.signum() != fb.signum() {
162 Ok(Bracket {
163 a,
164 b,
165 fa,
166 fb,
167 evals: 2,
168 })
169 } else {
170 Err(optim_error(
171 NAME,
172 "fzero: function values at bracket endpoints must differ in sign",
173 ))
174 }
175}
176
177async fn expand_bracket(
178 function: &Value,
179 x0: f64,
180 options: &FzeroOptions,
181) -> BuiltinResult<Bracket> {
182 if !x0.is_finite() {
183 return Err(optim_error(NAME, "fzero: initial point must be finite"));
184 }
185 let f0 = call_scalar_function(NAME, function, x0).await?;
186 if f0 == 0.0 {
187 return Ok(Bracket {
188 a: x0,
189 b: x0,
190 fa: f0,
191 fb: f0,
192 evals: 1,
193 });
194 }
195
196 let mut evals = 1usize;
197 let mut step = (x0.abs() * 0.01).max(0.01);
198 while evals + 2 <= options.max_fun_evals {
199 let a = x0 - step;
200 let b = x0 + step;
201 let fa = call_scalar_function(NAME, function, a).await?;
202 let fb = call_scalar_function(NAME, function, b).await?;
203 evals += 2;
204 if fa == 0.0 {
205 return Ok(Bracket {
206 a,
207 b: a,
208 fa,
209 fb: fa,
210 evals,
211 });
212 }
213 if fa.signum() != f0.signum() {
214 return Ok(Bracket {
215 a,
216 b: x0,
217 fa,
218 fb: f0,
219 evals,
220 });
221 }
222 if fb.signum() != f0.signum() {
223 return Ok(Bracket {
224 a: x0,
225 b,
226 fa: f0,
227 fb,
228 evals,
229 });
230 }
231 if fb == 0.0 || fa.signum() != fb.signum() {
232 return Ok(Bracket {
233 a,
234 b,
235 fa,
236 fb,
237 evals,
238 });
239 }
240 step *= 1.6;
241 }
242
243 Err(optim_error(
244 NAME,
245 "fzero: could not find a sign-changing bracket around the initial point",
246 ))
247}
248
249async fn brent(function: &Value, bracket: Bracket, options: &FzeroOptions) -> BuiltinResult<f64> {
250 if bracket.fa == 0.0 || bracket.a == bracket.b {
251 return Ok(bracket.a);
252 }
253 if bracket.fb == 0.0 {
254 return Ok(bracket.b);
255 }
256
257 let mut a = bracket.a;
258 let mut b = bracket.b;
259 let mut c = a;
260 let mut fa = bracket.fa;
261 let mut fb = bracket.fb;
262 let mut fc = fa;
263 let mut d = b - a;
264 let mut e = d;
265 let mut evals = bracket.evals;
266
267 for _ in 0..options.max_iter {
268 if fb.signum() == fc.signum() {
269 c = a;
270 fc = fa;
271 d = b - a;
272 e = d;
273 }
274 if fc.abs() < fb.abs() {
275 let old_b = b;
276 let old_fb = fb;
277 a = b;
278 fa = fb;
279 b = c;
280 fb = fc;
281 c = old_b;
282 fc = old_fb;
283 }
284
285 let tol = 2.0 * f64::EPSILON * b.abs() + 0.5 * options.tol_x;
286 let midpoint = 0.5 * (c - b);
287 if midpoint.abs() <= tol || fb == 0.0 {
288 return Ok(b);
289 }
290 if evals >= options.max_fun_evals {
291 return Err(optim_error(
292 NAME,
293 "fzero: exceeded maximum function evaluations",
294 ));
295 }
296
297 if e.abs() >= tol && fa.abs() > fb.abs() {
298 let s = fb / fa;
299 let (mut p, mut q) = if a == c {
300 (2.0 * midpoint * s, 1.0 - s)
301 } else {
302 let q = fa / fc;
303 let r = fb / fc;
304 (
305 s * (2.0 * midpoint * q * (q - r) - (b - a) * (r - 1.0)),
306 (q - 1.0) * (r - 1.0) * (s - 1.0),
307 )
308 };
309 if p > 0.0 {
310 q = -q;
311 }
312 p = p.abs();
313 if interpolation_step_accepted(p, q, midpoint, tol, e) {
314 e = d;
315 d = p / q;
316 } else {
317 d = midpoint;
318 e = d;
319 }
320 } else {
321 d = midpoint;
322 e = d;
323 }
324
325 a = b;
326 fa = fb;
327 b += if d.abs() > tol {
328 d
329 } else if midpoint >= 0.0 {
330 tol
331 } else {
332 -tol
333 };
334 fb = call_scalar_function(NAME, function, b).await?;
335 evals += 1;
336 }
337
338 Err(optim_error(NAME, "fzero: exceeded maximum iterations"))
339}
340
/// Decides whether Brent's interpolated step `p / q` may replace a bisection step.
///
/// The caller passes `p >= 0`, with the sign of `q` adjusted accordingly. The step is
/// accepted when `2 * p` is strictly below the smaller of two limits:
/// - `3 * midpoint * q - |tol * q|`, which keeps the new point inside the bracket.
/// - `|e * q|`, which requires the step to shrink relative to the step before last.
///
/// The limits are combined with `f64::min`, so a NaN limit is ignored in favour of the
/// other one.
fn interpolation_step_accepted(p: f64, q: f64, midpoint: f64, tol: f64, e: f64) -> bool {
    let limit = f64::min(3.0 * midpoint * q - (tol * q).abs(), (e * q).abs());
    limit > 2.0 * p
}
346
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;

    /// Runs `fzero` on the named builtin handle with no options and returns the root.
    fn solve(handle: &str, x: Value) -> f64 {
        let result = block_on(fzero_builtin(
            Value::FunctionHandle(handle.into()),
            x,
            Vec::new(),
        ))
        .unwrap();
        match result {
            Value::Num(n) => n,
            other => panic!("unexpected value {other:?}"),
        }
    }

    #[test]
    fn fzero_bracketed_builtin_handle() {
        let endpoints = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let root = solve("sin", Value::Tensor(endpoints));
        assert!((root - std::f64::consts::PI).abs() < 1.0e-6);
    }

    #[test]
    fn fzero_scalar_initial_guess_expands_bracket() {
        let root = solve("cos", Value::Num(1.0));
        assert!((root - std::f64::consts::FRAC_PI_2).abs() < 1.0e-6);
    }

    #[test]
    fn fzero_scalar_initial_guess_uses_center_sign_for_bracket() {
        let root = solve("sin", Value::Num(std::f64::consts::FRAC_PI_2));
        assert!(root.abs() < 1.0e-6);
    }

    #[test]
    fn brent_interpolation_acceptance_uses_signed_q() {
        let (p, q, tol, e) = (1.0, -2.0, 0.1, 10.0);
        assert!(!interpolation_step_accepted(p, q, 1.0, tol, e));
        assert!(interpolation_step_accepted(p, q, -1.0, tol, e));
    }
}