1use nalgebra::{DMatrix, DVector};
4use runmat_builtins::{
5 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7};
8use runmat_builtins::{StructValue, Value};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::spec::{
12 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13 ReductionNaN, ResidencyPolicy, ShapeRequirements,
14};
15use crate::builtins::math::optim::common::{
16 call_function, initial_guess, option_f64, option_string, option_usize, value_to_real_vector,
17 vector_to_value,
18};
19use crate::builtins::math::optim::type_resolvers::nonlinear_solve_type;
20use crate::{build_runtime_error, BuiltinResult, RuntimeError};
21
/// Builtin name used when tagging errors and delegating to shared optim helpers.
const NAME: &str = "fsolve";
/// Step tolerance applied when the options struct does not set `TolX`.
const DEFAULT_TOL_X: f64 = 1e-6;
/// Residual tolerance applied when the options struct does not set `TolFun`.
const DEFAULT_TOL_FUN: f64 = 1e-6;
/// Outer-iteration cap applied when the options struct does not set `MaxIter`.
const DEFAULT_MAX_ITER: usize = 400;
26
27const FSOLVE_OUTPUT_X: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
28 name: "x",
29 ty: BuiltinParamType::NumericArray,
30 arity: BuiltinParamArity::Required,
31 default: None,
32 description: "Approximate solution vector/scalar.",
33}];
34
35const FSOLVE_INPUTS_CORE: [BuiltinParamDescriptor; 2] = [
36 BuiltinParamDescriptor {
37 name: "fun",
38 ty: BuiltinParamType::Any,
39 arity: BuiltinParamArity::Required,
40 default: None,
41 description: "System residual callback.",
42 },
43 BuiltinParamDescriptor {
44 name: "x0",
45 ty: BuiltinParamType::Any,
46 arity: BuiltinParamArity::Required,
47 default: None,
48 description: "Initial guess scalar/vector.",
49 },
50];
51
52const FSOLVE_INPUTS_WITH_OPTIONS: [BuiltinParamDescriptor; 3] = [
53 BuiltinParamDescriptor {
54 name: "fun",
55 ty: BuiltinParamType::Any,
56 arity: BuiltinParamArity::Required,
57 default: None,
58 description: "System residual callback.",
59 },
60 BuiltinParamDescriptor {
61 name: "x0",
62 ty: BuiltinParamType::Any,
63 arity: BuiltinParamArity::Required,
64 default: None,
65 description: "Initial guess scalar/vector.",
66 },
67 BuiltinParamDescriptor {
68 name: "options",
69 ty: BuiltinParamType::Any,
70 arity: BuiltinParamArity::Optional,
71 default: None,
72 description: "Options struct from optimset.",
73 },
74];
75
76const FSOLVE_SIGNATURES: [BuiltinSignatureDescriptor; 2] = [
77 BuiltinSignatureDescriptor {
78 label: "x = fsolve(fun, x0)",
79 inputs: &FSOLVE_INPUTS_CORE,
80 outputs: &FSOLVE_OUTPUT_X,
81 },
82 BuiltinSignatureDescriptor {
83 label: "x = fsolve(fun, x0, options)",
84 inputs: &FSOLVE_INPUTS_WITH_OPTIONS,
85 outputs: &FSOLVE_OUTPUT_X,
86 },
87];
88
89const FSOLVE_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
90 code: "RM.FSOLVE.INVALID_ARGUMENT",
91 identifier: Some("RunMat:fsolve:InvalidArgument"),
92 when: "Argument grammar/options configuration is invalid.",
93 message: "fsolve: invalid argument",
94};
95
96const FSOLVE_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
97 code: "RM.FSOLVE.INVALID_INPUT",
98 identifier: Some("RunMat:fsolve:InvalidInput"),
99 when: "Initial guess/callback/iteration semantics are invalid.",
100 message: "fsolve: invalid input",
101};
102
103const FSOLVE_ERRORS: [BuiltinErrorDescriptor; 2] =
104 [FSOLVE_ERROR_INVALID_ARGUMENT, FSOLVE_ERROR_INVALID_INPUT];
105
106pub const FSOLVE_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
107 signatures: &FSOLVE_SIGNATURES,
108 output_mode: BuiltinOutputMode::Fixed,
109 completion_policy: BuiltinCompletionPolicy::Public,
110 errors: &FSOLVE_ERRORS,
111};
112
113fn fsolve_error_with_detail(
114 error: &'static BuiltinErrorDescriptor,
115 detail: impl AsRef<str>,
116) -> RuntimeError {
117 let detail = detail.as_ref();
118 let message = if detail.starts_with("fsolve:") {
119 detail.to_string()
120 } else {
121 format!("{}: {detail}", error.message)
122 };
123 let mut builder = build_runtime_error(message).with_builtin(NAME);
124 if let Some(identifier) = error.identifier {
125 builder = builder.with_identifier(identifier);
126 }
127 builder.build()
128}
129
130fn fsolve_map_error(err: RuntimeError, fallback: &'static BuiltinErrorDescriptor) -> RuntimeError {
131 if err.identifier().is_some() {
132 err
133 } else {
134 fsolve_error_with_detail(fallback, err.message())
135 }
136}
137
138#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::optim::fsolve")]
139pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
140 name: "fsolve",
141 op_kind: GpuOpKind::Custom("nonlinear-solve"),
142 supported_precisions: &[],
143 broadcast: BroadcastSemantics::None,
144 provider_hooks: &[],
145 constant_strategy: ConstantStrategy::InlineLiteral,
146 residency: ResidencyPolicy::GatherImmediately,
147 nan_mode: ReductionNaN::Include,
148 two_pass_threshold: None,
149 workgroup_size: None,
150 accepts_nan_mode: false,
151 notes: "Host finite-difference Levenberg-Marquardt solver. Callback computations may use GPU-aware builtins.",
152};
153
154#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::optim::fsolve")]
155pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
156 name: "fsolve",
157 shape: ShapeRequirements::Any,
158 constant_strategy: ConstantStrategy::InlineLiteral,
159 elementwise: None,
160 reduction: None,
161 emits_nan: false,
162 notes: "Nonlinear solving repeatedly invokes user code and terminates fusion planning.",
163};
164
165#[runtime_builtin(
166 name = "fsolve",
167 category = "math/optim",
168 summary = "Solve nonlinear equation systems.",
169 keywords = "fsolve,nonlinear solve,root finding,levenberg-marquardt,jacobian",
170 accel = "sink",
171 type_resolver(nonlinear_solve_type),
172 descriptor(crate::builtins::math::optim::fsolve::FSOLVE_DESCRIPTOR),
173 builtin_path = "crate::builtins::math::optim::fsolve"
174)]
175async fn fsolve_builtin(function: Value, x0: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
176 if rest.len() > 1 {
177 return Err(fsolve_error_with_detail(
178 &FSOLVE_ERROR_INVALID_ARGUMENT,
179 "too many input arguments",
180 ));
181 }
182 let options = parse_options(rest.first())
183 .map_err(|err| fsolve_map_error(err, &FSOLVE_ERROR_INVALID_ARGUMENT))?;
184 let opts = FsolveOptions::from_struct(options.as_ref())
185 .map_err(|err| fsolve_map_error(err, &FSOLVE_ERROR_INVALID_ARGUMENT))?;
186 let guess = initial_guess(NAME, x0)
187 .await
188 .map_err(|err| fsolve_map_error(err, &FSOLVE_ERROR_INVALID_INPUT))?;
189 let solution = solve(&function, guess.values, &guess.shape, guess.scalar, &opts)
190 .await
191 .map_err(|err| fsolve_map_error(err, &FSOLVE_ERROR_INVALID_INPUT))?;
192 vector_to_value(NAME, solution, &guess.shape, guess.scalar)
193 .map_err(|err| fsolve_map_error(err, &FSOLVE_ERROR_INVALID_INPUT))
194}
195
196fn parse_options(value: Option<&Value>) -> BuiltinResult<Option<StructValue>> {
197 match value {
198 None => Ok(None),
199 Some(Value::Struct(options)) => Ok(Some(options.clone())),
200 Some(other) => Err(fsolve_error_with_detail(
201 &FSOLVE_ERROR_INVALID_ARGUMENT,
202 format!("options must be a struct, got {other:?}"),
203 )),
204 }
205}
206
207#[derive(Clone, Copy)]
208struct FsolveOptions {
209 tol_x: f64,
210 tol_fun: f64,
211 max_iter: usize,
212 max_fun_evals: usize,
213}
214
215impl FsolveOptions {
216 fn from_struct(options: Option<&StructValue>) -> BuiltinResult<Self> {
217 let display = option_string(options, "Display", "off")?;
218 if !matches!(display.as_str(), "off" | "none" | "final" | "iter") {
219 return Err(fsolve_error_with_detail(
220 &FSOLVE_ERROR_INVALID_ARGUMENT,
221 "option Display must be 'off', 'none', 'final', or 'iter'",
222 ));
223 }
224 let tol_x = option_f64(NAME, options, "TolX", DEFAULT_TOL_X)?;
225 let tol_fun = option_f64(NAME, options, "TolFun", DEFAULT_TOL_FUN)?;
226 if tol_x <= 0.0 || tol_fun <= 0.0 {
227 return Err(fsolve_error_with_detail(
228 &FSOLVE_ERROR_INVALID_ARGUMENT,
229 "options TolX and TolFun must be positive",
230 ));
231 }
232 let max_iter = option_usize(NAME, options, "MaxIter", DEFAULT_MAX_ITER)?.max(1);
233 let max_fun_evals = option_usize(NAME, options, "MaxFunEvals", 100 * max_iter)?.max(1);
234 Ok(Self {
235 tol_x,
236 tol_fun,
237 max_iter,
238 max_fun_evals,
239 })
240 }
241}
242
243async fn solve(
244 function: &Value,
245 mut x: Vec<f64>,
246 shape: &[usize],
247 scalar: bool,
248 options: &FsolveOptions,
249) -> BuiltinResult<Vec<f64>> {
250 let n = x.len();
251 if n == 0 {
252 return Err(fsolve_error_with_detail(
253 &FSOLVE_ERROR_INVALID_INPUT,
254 "initial guess cannot be empty",
255 ));
256 }
257
258 let mut residual = eval_residual(function, &x, shape, scalar).await?;
259 let mut evals = 1usize;
260 let mut lambda = 1.0e-3;
261
262 if residual_norm_inf(&residual) <= options.tol_fun {
263 return Ok(x);
264 }
265
266 for _ in 0..options.max_iter {
267 if evals >= options.max_fun_evals {
268 return Err(fsolve_error_with_detail(
269 &FSOLVE_ERROR_INVALID_INPUT,
270 "exceeded maximum function evaluations",
271 ));
272 }
273 let jacobian =
274 finite_difference_jacobian(function, &x, shape, scalar, &residual, &mut evals, options)
275 .await?;
276 let j = DMatrix::from_row_slice(residual.len(), n, &jacobian);
277 let f = DVector::from_column_slice(&residual);
278 let gradient = j.transpose() * &f;
279 let mut accepted = false;
280
281 for _ in 0..8 {
282 let normal = j.transpose() * &j + DMatrix::<f64>::identity(n, n) * lambda;
283 let rhs = -&gradient;
284 let Some(delta) = normal.lu().solve(&rhs) else {
285 lambda *= 10.0;
286 continue;
287 };
288 let trial = x
289 .iter()
290 .zip(delta.iter())
291 .map(|(xi, di)| xi + di)
292 .collect::<Vec<_>>();
293 let trial_residual = eval_residual(function, &trial, shape, scalar).await?;
294 evals += 1;
295
296 if norm2(&trial_residual) < norm2(&residual) {
297 let step_norm = delta
298 .iter()
299 .fold(0.0_f64, |acc, value| acc.max(value.abs()));
300 let x_norm = x.iter().fold(0.0_f64, |acc, value| acc.max(value.abs()));
301 x = trial;
302 residual = trial_residual;
303 lambda = (lambda * 0.3).max(1.0e-12);
304 accepted = true;
305 if residual_norm_inf(&residual) <= options.tol_fun
306 || step_norm <= options.tol_x * (1.0 + x_norm)
307 {
308 return Ok(x);
309 }
310 break;
311 }
312
313 lambda *= 10.0;
314 if evals >= options.max_fun_evals {
315 return Err(fsolve_error_with_detail(
316 &FSOLVE_ERROR_INVALID_INPUT,
317 "exceeded maximum function evaluations",
318 ));
319 }
320 }
321
322 if !accepted {
323 return Err(fsolve_error_with_detail(
324 &FSOLVE_ERROR_INVALID_INPUT,
325 "iteration stalled before convergence",
326 ));
327 }
328 }
329
330 Err(fsolve_error_with_detail(
331 &FSOLVE_ERROR_INVALID_INPUT,
332 "exceeded maximum iterations",
333 ))
334}
335
336async fn eval_residual(
337 function: &Value,
338 x: &[f64],
339 shape: &[usize],
340 scalar: bool,
341) -> BuiltinResult<Vec<f64>> {
342 let arg = if scalar {
343 Value::Num(x[0])
344 } else {
345 Value::Tensor(
346 runmat_builtins::Tensor::new(x.to_vec(), shape.to_vec())
347 .map_err(|e| fsolve_error_with_detail(&FSOLVE_ERROR_INVALID_INPUT, e))?,
348 )
349 };
350 let value = call_function(function, vec![arg]).await?;
351 let residual = value_to_real_vector(NAME, value).await?;
352 if residual.is_empty() {
353 Err(fsolve_error_with_detail(
354 &FSOLVE_ERROR_INVALID_INPUT,
355 "function value must not be empty",
356 ))
357 } else {
358 Ok(residual)
359 }
360}
361
362async fn finite_difference_jacobian(
363 function: &Value,
364 x: &[f64],
365 shape: &[usize],
366 scalar: bool,
367 residual: &[f64],
368 evals: &mut usize,
369 options: &FsolveOptions,
370) -> BuiltinResult<Vec<f64>> {
371 let m = residual.len();
372 let n = x.len();
373 let mut jacobian = vec![0.0; m * n];
374
375 for col in 0..n {
376 if *evals >= options.max_fun_evals {
377 return Err(fsolve_error_with_detail(
378 &FSOLVE_ERROR_INVALID_INPUT,
379 "exceeded maximum function evaluations",
380 ));
381 }
382 let mut perturbed = x.to_vec();
383 let step = f64::EPSILON.sqrt() * (x[col].abs() + 1.0);
384 perturbed[col] += step;
385 let next = eval_residual(function, &perturbed, shape, scalar).await?;
386 *evals += 1;
387 if next.len() != m {
388 return Err(fsolve_error_with_detail(
389 &FSOLVE_ERROR_INVALID_INPUT,
390 "function output size changed during finite differencing",
391 ));
392 }
393 for row in 0..m {
394 jacobian[row * n + col] = (next[row] - residual[row]) / step;
395 }
396 }
397
398 Ok(jacobian)
399}
400
/// Euclidean (2-) norm of `values`: the square root of the sum of squares.
fn norm2(values: &[f64]) -> f64 {
    let sum_of_squares: f64 = values.iter().map(|&value| value * value).sum();
    sum_of_squares.sqrt()
}
404
/// Infinity norm of `values`: the largest absolute entry, or 0.0 for an empty slice.
///
/// Returns NaN if any entry is NaN. `f64::max` silently discards NaN operands, so a plain
/// max-fold would report a NaN residual as having norm 0.0 and make it look converged.
fn residual_norm_inf(values: &[f64]) -> f64 {
    values.iter().fold(0.0_f64, |acc, value| {
        let magnitude = value.abs();
        if acc.is_nan() || magnitude.is_nan() {
            f64::NAN
        } else {
            acc.max(magnitude)
        }
    })
}
410
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;
    use std::sync::{Arc, Mutex};

    /// Absolute tolerance used when comparing computed roots with expected values.
    const ROOT_TOL: f64 = 1.0e-5;

    /// Runs `fsolve` without an options struct and unwraps the result.
    fn run_fsolve(function: Value, x0: Value) -> Value {
        block_on(fsolve_builtin(function, x0, Vec::new())).unwrap()
    }

    /// Unwraps a scalar numeric result.
    fn expect_num(value: Value) -> f64 {
        match value {
            Value::Num(n) => n,
            other => panic!("unexpected value {other:?}"),
        }
    }

    /// Unwraps a tensor result.
    fn expect_tensor(value: Value) -> Tensor {
        match value {
            Value::Tensor(t) => t,
            other => panic!("unexpected value {other:?}"),
        }
    }

    /// Solves `x - targets = 0` from a zero guess of the given shape and checks that the
    /// callback always sees that shape and that the result keeps it.
    fn assert_shape_preserved(handle_name: &str, expected_shape: Vec<usize>, targets: Vec<f64>) {
        let seen_shapes = Arc::new(Mutex::new(Vec::new()));
        let seen_shapes_for_invoker = Arc::clone(&seen_shapes);
        let shape_for_invoker = expected_shape.clone();
        let targets_for_invoker = targets.clone();

        let _resolver =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|_name| {
                Some(0)
            })));
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            move |_function, args, _requested_outputs| {
                let (x, shape) = match &args[0] {
                    Value::Tensor(t) => (t.data.clone(), t.shape.clone()),
                    other => panic!("expected tensor input, got {other:?}"),
                };
                assert_eq!(shape, shape_for_invoker);
                seen_shapes_for_invoker.lock().unwrap().push(shape.clone());
                let residual: Vec<f64> = targets_for_invoker
                    .iter()
                    .enumerate()
                    .map(|(index, target)| x[index] - target)
                    .collect();
                Box::pin(async move { Ok(Value::Tensor(Tensor::new(residual, shape).unwrap())) })
            },
        )));

        let x0 = Tensor::new(vec![0.0; targets.len()], expected_shape.clone()).unwrap();
        let solved = expect_tensor(run_fsolve(
            Value::FunctionHandle(handle_name.into()),
            Value::Tensor(x0),
        ));

        assert_eq!(solved.shape, expected_shape);
        for (index, target) in targets.iter().enumerate() {
            assert!((solved.data[index] - target).abs() < ROOT_TOL);
        }
        assert!(!seen_shapes.lock().unwrap().is_empty());
    }

    #[test]
    fn fsolve_scalar_builtin_handle() {
        let root = expect_num(run_fsolve(
            Value::FunctionHandle("sin".into()),
            Value::Num(3.0),
        ));
        assert!((root - std::f64::consts::PI).abs() < ROOT_TOL);
    }

    #[test]
    fn fsolve_vector_system_via_semantic_resolver() {
        let _resolver =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|_name| {
                Some(0)
            })));
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            |_function, args, _requested_outputs| {
                let x = match &args[0] {
                    Value::Tensor(t) => t.data.clone(),
                    _ => panic!("expected tensor input"),
                };
                Box::pin(async move {
                    let circle = x[0] * x[0] + x[1] * x[1] - 4.0;
                    let hyperbola = x[0] * x[1] - 1.0;
                    Ok(Value::Tensor(
                        Tensor::new(vec![circle, hyperbola], vec![2, 1]).unwrap(),
                    ))
                })
            },
        )));

        let x0 = Tensor::new(vec![1.0, 1.0], vec![2, 1]).unwrap();
        let solved = expect_tensor(run_fsolve(
            Value::FunctionHandle("system".into()),
            Value::Tensor(x0),
        ));

        let (a, b) = (solved.data[0], solved.data[1]);
        assert!((a * a + b * b - 4.0).abs() < ROOT_TOL);
        assert!((a * b - 1.0).abs() < ROOT_TOL);
    }

    #[test]
    fn fsolve_preserves_row_vector_shape_for_callback() {
        assert_shape_preserved("row_system", vec![1, 2], vec![3.0, 4.0]);
    }

    #[test]
    fn fsolve_preserves_matrix_shape_for_callback() {
        assert_shape_preserved("matrix_system", vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn fsolve_accepts_semantic_function_handle_callback() {
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            |function, args, requested_outputs| {
                assert_eq!(function, 43);
                assert_eq!(requested_outputs, 1);
                let x = match &args[0] {
                    Value::Num(value) => *value,
                    other => panic!("expected scalar numeric argument, got {other:?}"),
                };
                Box::pin(async move { Ok(Value::Num(x - 3.0)) })
            },
        )));

        let handle = Value::BoundFunctionHandle {
            name: "system_function".to_string(),
            function: 43,
        };
        let root = expect_num(run_fsolve(handle, Value::Num(1.0)));
        assert!((root - 3.0).abs() < ROOT_TOL);
    }

    #[test]
    fn fsolve_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = FSOLVE_DESCRIPTOR
            .signatures
            .iter()
            .map(|signature| signature.label)
            .collect();
        assert_eq!(
            labels,
            ["x = fsolve(fun, x0)", "x = fsolve(fun, x0, options)"]
        );

        let codes: Vec<&str> = FSOLVE_DESCRIPTOR
            .errors
            .iter()
            .map(|error| error.code)
            .collect();
        assert_eq!(
            codes,
            ["RM.FSOLVE.INVALID_ARGUMENT", "RM.FSOLVE.INVALID_INPUT"]
        );
    }

    #[test]
    fn fsolve_too_many_args_uses_stable_identifier() {
        let extra_args = vec![
            Value::Struct(StructValue::new()),
            Value::Struct(StructValue::new()),
        ];
        let err = block_on(fsolve_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Num(1.0),
            extra_args,
        ))
        .unwrap_err();
        assert_eq!(err.identifier(), Some("RunMat:fsolve:InvalidArgument"));
    }
}