1use runmat_builtins::{
4 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6};
7use runmat_builtins::{StructValue, Value};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::spec::{
11 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
12 ReductionNaN, ResidencyPolicy, ShapeRequirements,
13};
14use crate::builtins::math::optim::common::{
15 call_function, initial_guess, option_f64, option_string, option_usize, value_to_real_vector,
16 vector_to_value,
17};
18use crate::builtins::math::optim::least_squares::{
19 solve_least_squares, LeastSquaresBounds, LeastSquaresEvaluator, LeastSquaresOptions,
20 ResidualFuture,
21};
22use crate::builtins::math::optim::type_resolvers::nonlinear_solve_type;
23use crate::{build_runtime_error, BuiltinResult, RuntimeError};
24
/// Builtin name; used as the error prefix/builtin tag and passed to the shared optim helpers.
const NAME: &str = "fsolve";
/// `TolX` value used when the options struct does not provide one.
const DEFAULT_TOL_X: f64 = 1e-6;
/// `TolFun` value used when the options struct does not provide one.
const DEFAULT_TOL_FUN: f64 = 1e-6;
/// `MaxIter` value used when the options struct does not provide one.
const DEFAULT_MAX_ITER: usize = 400;
29
30const FSOLVE_OUTPUT_X: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
31 name: "x",
32 ty: BuiltinParamType::NumericArray,
33 arity: BuiltinParamArity::Required,
34 default: None,
35 description: "Approximate solution vector/scalar.",
36}];
37
38const FSOLVE_INPUTS_CORE: [BuiltinParamDescriptor; 2] = [
39 BuiltinParamDescriptor {
40 name: "fun",
41 ty: BuiltinParamType::Any,
42 arity: BuiltinParamArity::Required,
43 default: None,
44 description: "System residual callback.",
45 },
46 BuiltinParamDescriptor {
47 name: "x0",
48 ty: BuiltinParamType::Any,
49 arity: BuiltinParamArity::Required,
50 default: None,
51 description: "Initial guess scalar/vector.",
52 },
53];
54
55const FSOLVE_INPUTS_WITH_OPTIONS: [BuiltinParamDescriptor; 3] = [
56 BuiltinParamDescriptor {
57 name: "fun",
58 ty: BuiltinParamType::Any,
59 arity: BuiltinParamArity::Required,
60 default: None,
61 description: "System residual callback.",
62 },
63 BuiltinParamDescriptor {
64 name: "x0",
65 ty: BuiltinParamType::Any,
66 arity: BuiltinParamArity::Required,
67 default: None,
68 description: "Initial guess scalar/vector.",
69 },
70 BuiltinParamDescriptor {
71 name: "options",
72 ty: BuiltinParamType::Any,
73 arity: BuiltinParamArity::Optional,
74 default: None,
75 description: "Options struct from optimset.",
76 },
77];
78
79const FSOLVE_SIGNATURES: [BuiltinSignatureDescriptor; 2] = [
80 BuiltinSignatureDescriptor {
81 label: "x = fsolve(fun, x0)",
82 inputs: &FSOLVE_INPUTS_CORE,
83 outputs: &FSOLVE_OUTPUT_X,
84 },
85 BuiltinSignatureDescriptor {
86 label: "x = fsolve(fun, x0, options)",
87 inputs: &FSOLVE_INPUTS_WITH_OPTIONS,
88 outputs: &FSOLVE_OUTPUT_X,
89 },
90];
91
92const FSOLVE_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
93 code: "RM.FSOLVE.INVALID_ARGUMENT",
94 identifier: Some("RunMat:fsolve:InvalidArgument"),
95 when: "Argument grammar/options configuration is invalid.",
96 message: "fsolve: invalid argument",
97};
98
99const FSOLVE_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
100 code: "RM.FSOLVE.INVALID_INPUT",
101 identifier: Some("RunMat:fsolve:InvalidInput"),
102 when: "Initial guess/callback/iteration semantics are invalid.",
103 message: "fsolve: invalid input",
104};
105
106const FSOLVE_ERRORS: [BuiltinErrorDescriptor; 2] =
107 [FSOLVE_ERROR_INVALID_ARGUMENT, FSOLVE_ERROR_INVALID_INPUT];
108
109pub const FSOLVE_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
110 signatures: &FSOLVE_SIGNATURES,
111 output_mode: BuiltinOutputMode::Fixed,
112 completion_policy: BuiltinCompletionPolicy::Public,
113 errors: &FSOLVE_ERRORS,
114};
115
116fn fsolve_error_with_detail(
117 error: &'static BuiltinErrorDescriptor,
118 detail: impl AsRef<str>,
119) -> RuntimeError {
120 let detail = detail.as_ref();
121 let message = if detail.starts_with("fsolve:") {
122 detail.to_string()
123 } else {
124 format!("{}: {detail}", error.message)
125 };
126 let mut builder = build_runtime_error(message).with_builtin(NAME);
127 if let Some(identifier) = error.identifier {
128 builder = builder.with_identifier(identifier);
129 }
130 builder.build()
131}
132
133fn fsolve_map_error(err: RuntimeError, fallback: &'static BuiltinErrorDescriptor) -> RuntimeError {
134 if err.identifier().is_some() {
135 err
136 } else {
137 fsolve_error_with_detail(fallback, err.message())
138 }
139}
140
// GPU metadata for `fsolve`, registered through the attribute macro below.
// The op is a custom "nonlinear-solve" kind with no supported precisions and no
// provider hooks. Per the notes, the solver itself runs on the host and only the
// user callback may dispatch to GPU-aware builtins.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::optim::fsolve")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "fsolve",
    op_kind: GpuOpKind::Custom("nonlinear-solve"),
    // Empty: no precision or provider kernel is advertised for the solve itself.
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // NOTE(review): presumably means GPU-resident inputs are gathered to the host
    // before solving, consistent with the host-only solver in the notes; confirm
    // against the ResidencyPolicy docs.
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Host finite-difference Levenberg-Marquardt solver. Callback computations may use GPU-aware builtins.",
};
156
// Fusion metadata for `fsolve`. It declares neither an elementwise nor a
// reduction form, so the planner has nothing to fuse. The notes explain why:
// the solve repeatedly invokes user code.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::optim::fsolve")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "fsolve",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Nonlinear solving repeatedly invokes user code and terminates fusion planning.",
};
167
168#[runtime_builtin(
169 name = "fsolve",
170 category = "math/optim",
171 summary = "Solve nonlinear equation systems.",
172 keywords = "fsolve,nonlinear solve,root finding,levenberg-marquardt,jacobian",
173 accel = "sink",
174 type_resolver(nonlinear_solve_type),
175 descriptor(crate::builtins::math::optim::fsolve::FSOLVE_DESCRIPTOR),
176 builtin_path = "crate::builtins::math::optim::fsolve"
177)]
178async fn fsolve_builtin(function: Value, x0: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
179 if rest.len() > 1 {
180 return Err(fsolve_error_with_detail(
181 &FSOLVE_ERROR_INVALID_ARGUMENT,
182 "too many input arguments",
183 ));
184 }
185 let options = parse_options(rest.first())
186 .map_err(|err| fsolve_map_error(err, &FSOLVE_ERROR_INVALID_ARGUMENT))?;
187 let opts = FsolveOptions::from_struct(options.as_ref())
188 .map_err(|err| fsolve_map_error(err, &FSOLVE_ERROR_INVALID_ARGUMENT))?;
189 let guess = initial_guess(NAME, x0)
190 .await
191 .map_err(|err| fsolve_map_error(err, &FSOLVE_ERROR_INVALID_INPUT))?;
192 let solution = solve(&function, guess.values, &guess.shape, guess.scalar, &opts)
193 .await
194 .map_err(|err| fsolve_map_error(err, &FSOLVE_ERROR_INVALID_INPUT))?;
195 vector_to_value(NAME, solution, &guess.shape, guess.scalar)
196 .map_err(|err| fsolve_map_error(err, &FSOLVE_ERROR_INVALID_INPUT))
197}
198
199fn parse_options(value: Option<&Value>) -> BuiltinResult<Option<StructValue>> {
200 match value {
201 None => Ok(None),
202 Some(Value::Struct(options)) => Ok(Some(options.clone())),
203 Some(other) => Err(fsolve_error_with_detail(
204 &FSOLVE_ERROR_INVALID_ARGUMENT,
205 format!("options must be a struct, got {other:?}"),
206 )),
207 }
208}
209
210#[derive(Clone, Copy)]
211struct FsolveOptions {
212 tol_x: f64,
213 tol_fun: f64,
214 max_iter: usize,
215 max_fun_evals: usize,
216}
217
218impl FsolveOptions {
219 fn from_struct(options: Option<&StructValue>) -> BuiltinResult<Self> {
220 let display = option_string(options, "Display", "off")?;
221 if !matches!(display.as_str(), "off" | "none" | "final" | "iter") {
222 return Err(fsolve_error_with_detail(
223 &FSOLVE_ERROR_INVALID_ARGUMENT,
224 "option Display must be 'off', 'none', 'final', or 'iter'",
225 ));
226 }
227 let tol_x = option_f64(NAME, options, "TolX", DEFAULT_TOL_X)?;
228 let tol_fun = option_f64(NAME, options, "TolFun", DEFAULT_TOL_FUN)?;
229 if tol_x <= 0.0 || tol_fun <= 0.0 {
230 return Err(fsolve_error_with_detail(
231 &FSOLVE_ERROR_INVALID_ARGUMENT,
232 "options TolX and TolFun must be positive",
233 ));
234 }
235 let max_iter = option_usize(NAME, options, "MaxIter", DEFAULT_MAX_ITER)?.max(1);
236 let max_fun_evals = option_usize(NAME, options, "MaxFunEvals", 100 * max_iter)?.max(1);
237 Ok(Self {
238 tol_x,
239 tol_fun,
240 max_iter,
241 max_fun_evals,
242 })
243 }
244}
245
246async fn solve(
247 function: &Value,
248 x: Vec<f64>,
249 shape: &[usize],
250 scalar: bool,
251 options: &FsolveOptions,
252) -> BuiltinResult<Vec<f64>> {
253 let mut evaluator = FsolveEvaluator {
254 function,
255 shape: shape.to_vec(),
256 scalar,
257 };
258 let variable_len = x.len();
259 let result = solve_least_squares(
260 NAME,
261 &mut evaluator,
262 x,
263 &LeastSquaresBounds::unbounded(variable_len),
264 &LeastSquaresOptions {
265 tol_x: options.tol_x,
266 tol_fun: options.tol_fun,
267 max_iter: options.max_iter,
268 max_fun_evals: options.max_fun_evals,
269 final_jacobian: false,
270 },
271 )
272 .await?;
273 if result.exitflag > 0
274 && result
275 .residual
276 .iter()
277 .fold(0.0_f64, |acc, value| acc.max(value.abs()))
278 <= options.tol_fun
279 {
280 Ok(result.x)
281 } else {
282 Err(fsolve_error_with_detail(
283 &FSOLVE_ERROR_INVALID_INPUT,
284 format!(
285 "{} Final residual norm is above function tolerance.",
286 result.message
287 ),
288 ))
289 }
290}
291
292struct FsolveEvaluator<'a> {
293 function: &'a Value,
294 shape: Vec<usize>,
295 scalar: bool,
296}
297
298impl LeastSquaresEvaluator for FsolveEvaluator<'_> {
299 fn residual<'a>(&'a mut self, x: &'a [f64]) -> ResidualFuture<'a> {
300 Box::pin(async move {
301 let arg = if self.scalar {
302 Value::Num(x[0])
303 } else {
304 Value::Tensor(
305 runmat_builtins::Tensor::new(x.to_vec(), self.shape.clone())
306 .map_err(|e| fsolve_error_with_detail(&FSOLVE_ERROR_INVALID_INPUT, e))?,
307 )
308 };
309 let value = call_function(self.function, vec![arg]).await?;
310 let residual = value_to_real_vector(NAME, value).await?;
311 if residual.is_empty() {
312 Err(fsolve_error_with_detail(
313 &FSOLVE_ERROR_INVALID_INPUT,
314 "function value must not be empty",
315 ))
316 } else {
317 Ok(residual)
318 }
319 })
320 }
321}
322
// Unit tests for `fsolve`: scalar and vector solves, shape preservation of the
// callback argument, callback dispatch, descriptor contents and error identifiers.
//
// Several tests install a semantic function resolver/invoker and keep the
// returned value in `_resolver` / `_invoker`. A named binding (unlike `_`) keeps
// the value alive until the end of the test.
// NOTE(review): presumably these are guards that uninstall on drop; confirm in
// `crate::user_functions`.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;
    use std::sync::{Arc, Mutex};

    // `sin` has a root at pi. Starting from 3.0 the solve should land there and
    // return a scalar.
    #[test]
    fn fsolve_scalar_builtin_handle() {
        let root = block_on(fsolve_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Num(3.0),
            Vec::new(),
        ))
        .unwrap();
        match root {
            Value::Num(n) => assert!((n - std::f64::consts::PI).abs() < 1.0e-5),
            other => panic!("unexpected value {other:?}"),
        }
    }

    // `x^2 + 1` has no real root. Starting at its stationary point x = 0, fsolve
    // must report InvalidInput instead of returning a non-root.
    #[test]
    fn fsolve_rejects_stationary_non_root_residual() {
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            |_function, args, _requested_outputs| {
                let x = match &args[0] {
                    Value::Num(value) => *value,
                    other => panic!("expected scalar numeric argument, got {other:?}"),
                };
                Box::pin(async move { Ok(Value::Num(x * x + 1.0)) })
            },
        )));
        let err = block_on(fsolve_builtin(
            Value::BoundFunctionHandle {
                name: "no_real_root".to_string(),
                function: 44,
            },
            Value::Num(0.0),
            Vec::new(),
        ))
        .unwrap_err();
        assert_eq!(err.identifier(), Some("RunMat:fsolve:InvalidInput"));
        assert!(err.message().contains("Final residual norm"));
    }

    // Two-equation system x0^2 + x1^2 = 4, x0 * x1 = 1, reached through a name
    // resolver plus invoker. Only the residuals at the returned point are checked.
    #[test]
    fn fsolve_vector_system_via_semantic_resolver() {
        let _resolver =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|_name| {
                Some(0)
            })));
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(
            std::sync::Arc::new(|_function, args, _requested_outputs| {
                let x = match &args[0] {
                    Value::Tensor(t) => t.data.clone(),
                    _ => panic!("expected tensor input"),
                };
                Box::pin(async move {
                    Ok(Value::Tensor(
                        Tensor::new(
                            vec![x[0] * x[0] + x[1] * x[1] - 4.0, x[0] * x[1] - 1.0],
                            vec![2, 1],
                        )
                        .unwrap(),
                    ))
                })
            }),
        ));
        let x0 = Tensor::new(vec![1.0, 1.0], vec![2, 1]).unwrap();
        let root = block_on(fsolve_builtin(
            Value::FunctionHandle("system".into()),
            Value::Tensor(x0),
            Vec::new(),
        ))
        .unwrap();
        match root {
            Value::Tensor(t) => {
                assert!((t.data[0] * t.data[0] + t.data[1] * t.data[1] - 4.0).abs() < 1.0e-5);
                assert!((t.data[0] * t.data[1] - 1.0).abs() < 1.0e-5);
            }
            other => panic!("unexpected value {other:?}"),
        }
    }

    // A 1x2 initial guess must reach the callback as a 1x2 tensor on every call,
    // which the invoker asserts. It must also come back as 1x2.
    #[test]
    fn fsolve_preserves_row_vector_shape_for_callback() {
        let seen_shapes = Arc::new(Mutex::new(Vec::new()));
        let seen_shapes_for_invoker = Arc::clone(&seen_shapes);
        let _resolver =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|_name| {
                Some(0)
            })));
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            move |_function, args, _requested_outputs| {
                let (x, shape) = match &args[0] {
                    Value::Tensor(t) => (t.data.clone(), t.shape.clone()),
                    other => panic!("expected tensor input, got {other:?}"),
                };
                assert_eq!(shape, vec![1, 2]);
                seen_shapes_for_invoker.lock().unwrap().push(shape.clone());
                Box::pin(async move {
                    Ok(Value::Tensor(
                        Tensor::new(vec![x[0] - 3.0, x[1] - 4.0], shape).unwrap(),
                    ))
                })
            },
        )));
        let x0 = Tensor::new(vec![0.0, 0.0], vec![1, 2]).unwrap();
        let root = block_on(fsolve_builtin(
            Value::FunctionHandle("row_system".into()),
            Value::Tensor(x0),
            Vec::new(),
        ))
        .unwrap();
        match root {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 2]);
                assert!((t.data[0] - 3.0).abs() < 1.0e-5);
                assert!((t.data[1] - 4.0).abs() < 1.0e-5);
            }
            other => panic!("unexpected value {other:?}"),
        }
        // Proves the invoker (and therefore its shape assertion) actually ran.
        assert!(!seen_shapes.lock().unwrap().is_empty());
    }

    // Same check as above for a 2x2 matrix guess.
    #[test]
    fn fsolve_preserves_matrix_shape_for_callback() {
        let seen_shapes = Arc::new(Mutex::new(Vec::new()));
        let seen_shapes_for_invoker = Arc::clone(&seen_shapes);
        let _resolver =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|_name| {
                Some(0)
            })));
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            move |_function, args, _requested_outputs| {
                let (x, shape) = match &args[0] {
                    Value::Tensor(t) => (t.data.clone(), t.shape.clone()),
                    other => panic!("expected tensor input, got {other:?}"),
                };
                assert_eq!(shape, vec![2, 2]);
                seen_shapes_for_invoker.lock().unwrap().push(shape.clone());
                Box::pin(async move {
                    Ok(Value::Tensor(
                        Tensor::new(vec![x[0] - 1.0, x[1] - 2.0, x[2] - 3.0, x[3] - 4.0], shape)
                            .unwrap(),
                    ))
                })
            },
        )));
        let x0 = Tensor::new(vec![0.0, 0.0, 0.0, 0.0], vec![2, 2]).unwrap();
        let root = block_on(fsolve_builtin(
            Value::FunctionHandle("matrix_system".into()),
            Value::Tensor(x0),
            Vec::new(),
        ))
        .unwrap();
        match root {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 2]);
                assert!((t.data[0] - 1.0).abs() < 1.0e-5);
                assert!((t.data[1] - 2.0).abs() < 1.0e-5);
                assert!((t.data[2] - 3.0).abs() < 1.0e-5);
                assert!((t.data[3] - 4.0).abs() < 1.0e-5);
            }
            other => panic!("unexpected value {other:?}"),
        }
        assert!(!seen_shapes.lock().unwrap().is_empty());
    }

    // A bound handle must be dispatched to its function id (43) with exactly one
    // requested output. The root of `x - 3` is 3.
    #[test]
    fn fsolve_accepts_semantic_function_handle_callback() {
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            |function, args, requested_outputs| {
                assert_eq!(function, 43);
                assert_eq!(requested_outputs, 1);
                let x = match &args[0] {
                    Value::Num(value) => *value,
                    other => panic!("expected scalar numeric argument, got {other:?}"),
                };
                Box::pin(async move { Ok(Value::Num(x - 3.0)) })
            },
        )));
        let root = block_on(fsolve_builtin(
            Value::BoundFunctionHandle {
                name: "system_function".to_string(),
                function: 43,
            },
            Value::Num(1.0),
            Vec::new(),
        ))
        .unwrap();
        match root {
            Value::Num(n) => assert!((n - 3.0).abs() < 1.0e-5),
            other => panic!("unexpected value {other:?}"),
        }
    }

    // Pins the public descriptor: the two signature labels and the two error
    // codes, in order.
    #[test]
    fn fsolve_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = FSOLVE_DESCRIPTOR
            .signatures
            .iter()
            .map(|signature| signature.label)
            .collect();
        assert_eq!(
            labels,
            vec!["x = fsolve(fun, x0)", "x = fsolve(fun, x0, options)"]
        );

        let codes: Vec<&str> = FSOLVE_DESCRIPTOR
            .errors
            .iter()
            .map(|error| error.code)
            .collect();
        assert_eq!(
            codes,
            vec!["RM.FSOLVE.INVALID_ARGUMENT", "RM.FSOLVE.INVALID_INPUT"]
        );
    }

    // Four arguments (two trailing options structs) must be rejected with the
    // InvalidArgument identifier.
    #[test]
    fn fsolve_too_many_args_uses_stable_identifier() {
        let err = block_on(fsolve_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Num(1.0),
            vec![
                Value::Struct(StructValue::new()),
                Value::Struct(StructValue::new()),
            ],
        ))
        .unwrap_err();
        assert_eq!(err.identifier(), Some("RunMat:fsolve:InvalidArgument"));
    }
}