1use runmat_builtins::{StructValue, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::spec::{
7 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8 ReductionNaN, ResidencyPolicy, ShapeRequirements,
9};
10use crate::builtins::math::optim::brent::{brent_zero, BrentParams, BrentZeroBracket};
11use crate::builtins::math::optim::common::{
12 call_scalar_function, optim_error, option_f64, option_string, option_usize,
13};
14use crate::builtins::math::optim::type_resolvers::scalar_root_type;
15use crate::BuiltinResult;
16
/// Builtin name handed to the shared optimisation helpers (`optim_error`,
/// `option_f64`, `call_scalar_function`, `brent_zero`, ...).
const NAME: &str = "fzero";
/// Value used for the `TolX` option when the options struct does not provide one.
const DEFAULT_TOL_X: f64 = 1.0e-6;
/// Value used for the `MaxIter` option when the options struct does not provide one.
const DEFAULT_MAX_ITER: usize = 400;
/// Value used for the `MaxFunEvals` option when the options struct does not provide one.
const DEFAULT_MAX_FUN_EVALS: usize = 500;
21
// GPU metadata registered for `fzero`. The spec lists no supported precisions and no
// provider hooks, and uses the `GatherImmediately` residency policy; as the `notes`
// string states, the root search itself runs on the CPU.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::optim::fzero")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "fzero",
    op_kind: GpuOpKind::Custom("scalar-root-find"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Host iterative solver. Callback values may use GPU-aware builtins, but the root search runs on the CPU.",
};
37
// Fusion metadata registered for `fzero`. No elementwise or reduction template is
// provided; as the `notes` string states, the builtin repeatedly invokes user code and
// terminates fusion planning.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::optim::fzero")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "fzero",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Root finding repeatedly invokes user code and terminates fusion planning.",
};
48
49#[runtime_builtin(
50 name = "fzero",
51 category = "math/optim",
52 summary = "Find a zero of a scalar nonlinear function using bracket expansion and Brent's method.",
53 keywords = "fzero,root finding,zero,brent,optimization",
54 accel = "sink",
55 type_resolver(scalar_root_type),
56 builtin_path = "crate::builtins::math::optim::fzero"
57)]
58async fn fzero_builtin(function: Value, x: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
59 if rest.len() > 1 {
60 return Err(optim_error(NAME, "fzero: too many input arguments"));
61 }
62 let options = parse_options(rest.first())?;
63 let opts = FzeroOptions::from_struct(options.as_ref())?;
64 let bracket = initial_bracket(&function, x, &opts).await?;
65 let root = brent_zero(
66 NAME,
67 &function,
68 BrentZeroBracket {
69 a: bracket.a,
70 b: bracket.b,
71 fa: bracket.fa,
72 fb: bracket.fb,
73 evals: bracket.evals,
74 },
75 BrentParams {
76 tol_x: opts.tol_x,
77 max_iter: opts.max_iter,
78 max_fun_evals: opts.max_fun_evals,
79 },
80 )
81 .await?;
82 Ok(Value::Num(root))
83}
84
85fn parse_options(value: Option<&Value>) -> BuiltinResult<Option<StructValue>> {
86 match value {
87 None => Ok(None),
88 Some(Value::Struct(options)) => Ok(Some(options.clone())),
89 Some(other) => Err(optim_error(
90 NAME,
91 format!("fzero: options must be a struct, got {other:?}"),
92 )),
93 }
94}
95
96#[derive(Clone, Copy)]
97struct FzeroOptions {
98 tol_x: f64,
99 max_iter: usize,
100 max_fun_evals: usize,
101}
102
103impl FzeroOptions {
104 fn from_struct(options: Option<&StructValue>) -> BuiltinResult<Self> {
105 let display = option_string(options, "Display", "off")?;
106 if !matches!(display.as_str(), "off" | "none" | "final" | "iter") {
107 return Err(optim_error(
108 NAME,
109 "fzero: option Display must be 'off', 'none', 'final', or 'iter'",
110 ));
111 }
112 let tol_x = option_f64(NAME, options, "TolX", DEFAULT_TOL_X)?;
113 if tol_x <= 0.0 {
114 return Err(optim_error(NAME, "fzero: option TolX must be positive"));
115 }
116 let max_iter = option_usize(NAME, options, "MaxIter", DEFAULT_MAX_ITER)?;
117 let max_fun_evals = option_usize(NAME, options, "MaxFunEvals", DEFAULT_MAX_FUN_EVALS)?;
118 Ok(Self {
119 tol_x,
120 max_iter: max_iter.max(1),
121 max_fun_evals: max_fun_evals.max(1),
122 })
123 }
124}
125
/// Starting interval for the root search together with the function values at its ends.
///
/// When an exact zero is found while building the bracket, both endpoints are set to
/// that point (`a == b`) and both function values are zero.
#[derive(Clone, Copy)]
struct Bracket {
    /// First endpoint of the interval.
    a: f64,
    /// Second endpoint of the interval.
    b: f64,
    /// Function value at `a`.
    fa: f64,
    /// Function value at `b`.
    fb: f64,
    /// Number of function evaluations spent while constructing the bracket.
    evals: usize,
}
134
135async fn initial_bracket(
136 function: &Value,
137 x: Value,
138 options: &FzeroOptions,
139) -> BuiltinResult<Bracket> {
140 let x = crate::dispatcher::gather_if_needed_async(&x).await?;
141 match x {
142 Value::Tensor(tensor) if tensor.data.len() == 2 => {
143 let a = tensor.data[0];
144 let b = tensor.data[1];
145 bracket_from_endpoints(function, a, b).await
146 }
147 Value::Tensor(tensor) if tensor.data.len() == 1 => {
148 expand_bracket(function, tensor.data[0], options).await
149 }
150 Value::Num(n) => expand_bracket(function, n, options).await,
151 Value::Int(i) => expand_bracket(function, i.to_f64(), options).await,
152 Value::Bool(b) => expand_bracket(function, if b { 1.0 } else { 0.0 }, options).await,
153 other => Err(optim_error(
154 NAME,
155 format!("fzero: initial point must be a scalar or two-element bracket, got {other:?}"),
156 )),
157 }
158}
159
160async fn bracket_from_endpoints(function: &Value, a: f64, b: f64) -> BuiltinResult<Bracket> {
161 if !a.is_finite() || !b.is_finite() || a == b {
162 return Err(optim_error(
163 NAME,
164 "fzero: bracket endpoints must be finite and distinct",
165 ));
166 }
167 let fa = call_scalar_function(NAME, function, a).await?;
168 if fa == 0.0 {
169 return Ok(Bracket {
170 a,
171 b: a,
172 fa,
173 fb: fa,
174 evals: 1,
175 });
176 }
177 let fb = call_scalar_function(NAME, function, b).await?;
178 if fb == 0.0 || fa.signum() != fb.signum() {
179 Ok(Bracket {
180 a,
181 b,
182 fa,
183 fb,
184 evals: 2,
185 })
186 } else {
187 Err(optim_error(
188 NAME,
189 "fzero: function values at bracket endpoints must differ in sign",
190 ))
191 }
192}
193
194async fn expand_bracket(
195 function: &Value,
196 x0: f64,
197 options: &FzeroOptions,
198) -> BuiltinResult<Bracket> {
199 if !x0.is_finite() {
200 return Err(optim_error(NAME, "fzero: initial point must be finite"));
201 }
202 let f0 = call_scalar_function(NAME, function, x0).await?;
203 if f0 == 0.0 {
204 return Ok(Bracket {
205 a: x0,
206 b: x0,
207 fa: f0,
208 fb: f0,
209 evals: 1,
210 });
211 }
212
213 let mut evals = 1usize;
214 let mut step = (x0.abs() * 0.01).max(0.01);
215 while evals + 2 <= options.max_fun_evals {
216 let a = x0 - step;
217 let b = x0 + step;
218 let fa = call_scalar_function(NAME, function, a).await?;
219 let fb = call_scalar_function(NAME, function, b).await?;
220 evals += 2;
221 if fa == 0.0 {
222 return Ok(Bracket {
223 a,
224 b: a,
225 fa,
226 fb: fa,
227 evals,
228 });
229 }
230 if fa.signum() != f0.signum() {
231 return Ok(Bracket {
232 a,
233 b: x0,
234 fa,
235 fb: f0,
236 evals,
237 });
238 }
239 if fb.signum() != f0.signum() {
240 return Ok(Bracket {
241 a: x0,
242 b,
243 fa: f0,
244 fb,
245 evals,
246 });
247 }
248 if fb == 0.0 || fa.signum() != fb.signum() {
249 return Ok(Bracket {
250 a,
251 b,
252 fa,
253 fb,
254 evals,
255 });
256 }
257 step *= 1.6;
258 }
259
260 Err(optim_error(
261 NAME,
262 "fzero: could not find a sign-changing bracket around the initial point",
263 ))
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;

    /// Runs `fzero` on the named function handle without options and returns the numeric
    /// root, panicking if the call fails or yields a non-numeric value.
    fn solve(handle: &'static str, start: Value) -> f64 {
        let outcome = block_on(fzero_builtin(
            Value::FunctionHandle(handle.into()),
            start,
            Vec::new(),
        ))
        .unwrap();
        match outcome {
            Value::Num(n) => n,
            other => panic!("unexpected value {other:?}"),
        }
    }

    // An explicit bracket [3, 4] around pi for `sin`.
    #[test]
    fn fzero_bracketed_builtin_handle() {
        let bracket = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let root = solve("sin", Value::Tensor(bracket));
        assert!((root - std::f64::consts::PI).abs() < 1.0e-6);
    }

    // A scalar start for `cos` must be expanded until it brackets pi/2.
    #[test]
    fn fzero_scalar_initial_guess_expands_bracket() {
        let root = solve("cos", Value::Num(1.0));
        assert!((root - std::f64::consts::FRAC_PI_2).abs() < 1.0e-6);
    }

    // Starting `sin` at pi/2 is expected to converge to the root at zero.
    #[test]
    fn fzero_scalar_initial_guess_uses_center_sign_for_bracket() {
        let root = solve("sin", Value::Num(std::f64::consts::FRAC_PI_2));
        assert!(root.abs() < 1.0e-6);
    }
}