use crate::numeric::factorize::NumericParams;
use crate::scaling::ScalingStrategy;
use crate::symbolic::supernode::SupernodeParams;
use crate::{CscMatrix, FactorStatus, Solver};
use std::panic::{catch_unwind, AssertUnwindSafe};
/// Status code returned by the C API when the operation completed successfully.
pub const FERAL_SUCCESS: i32 = 0;
/// Status code returned by `feral_factor` when the solver reports the matrix as singular.
pub const FERAL_SINGULAR: i32 = 1;
/// Status code returned by `feral_factor` when the negative-eigenvalue count differs from
/// the caller's `expected_neg` (with `check_neg != 0`), or when the solver itself reports
/// wrong inertia.
pub const FERAL_WRONG_INERTIA: i32 = 2;
/// Status code for unrecoverable failures: invalid arguments, no structure installed, a
/// solver fatal error, a failed solve, or a panic caught at the FFI boundary.
pub const FERAL_FATAL: i32 = 3;
/// Opaque handle handed to C callers: create with `feral_new`, release with `feral_free`.
pub struct FeralSolver {
// Configured solver; `feral_factor` factorizes into it and `feral_solve` solves with it.
solver: Solver,
// Matrix installed by `feral_set_structure`; `None` until a structure has been set.
// Its values are written by the caller through `feral_values_ptr`.
matrix: Option<CscMatrix>,
// Negative-eigenvalue count cached by `feral_factor` and reported by `feral_num_neg`;
// reset to 0 by `feral_set_structure`.
neg_evals: i32,
}
/// Creates a solver handle whose options are read from `FERAL_*` environment variables.
///
/// Recognized variables. Unrecognized, unparsable or out-of-range values are ignored, and
/// the default is kept:
/// - `FERAL_SCALING`: `identity`, `infnorm`, `mc64` (symmetric MC64) or `auto`.
/// - `FERAL_PIVTOL`: finite, non-negative value stored in `bk.pivot_threshold`.
/// - `FERAL_STATIC_PIVOT`: finite, strictly positive static-pivot threshold.
/// - `FERAL_CASCADE_BREAK`: `1`/`on`/`true`/`yes` enables cascade break
///   (`with_cascade_break(0.5)`, eps `1e-10`).
/// - `FERAL_AUTO_CB_BETA`: only consulted when cascade break is not forced. A finite,
///   non-negative beta (default `0.05`) enables automatic cascade break when positive.
///   `0` leaves it off.
/// - `FERAL_PARALLEL`: `0`/`off`/`false`/`no` disables parallel execution.
///
/// Returns null if construction panics. A non-null handle must be released with
/// `feral_free`.
#[no_mangle]
pub extern "C" fn feral_new() -> *mut FeralSolver {
    // Parses `name` as an f64 and keeps only finite values.
    // Unset, non-UTF-8 or unparsable values give `None`.
    fn finite_env(name: &str) -> Option<f64> {
        std::env::var(name)
            .ok()?
            .parse::<f64>()
            .ok()
            .filter(|v| v.is_finite())
    }

    // Exact, case-sensitive match of `name`'s value against `accepted`.
    // Unset or non-UTF-8 gives false.
    fn env_is_one_of(name: &str, accepted: &[&str]) -> bool {
        std::env::var(name).map_or(false, |raw| accepted.contains(&raw.as_str()))
    }

    let built = catch_unwind(|| {
        let mut params = NumericParams::default();
        match std::env::var("FERAL_SCALING").as_deref() {
            Ok("identity") => params.scaling = ScalingStrategy::Identity,
            Ok("infnorm") => params.scaling = ScalingStrategy::InfNorm,
            Ok("mc64") => params.scaling = ScalingStrategy::Mc64Symmetric,
            Ok("auto") => params.scaling = ScalingStrategy::Auto,
            _ => {}
        }
        if let Some(tol) = finite_env("FERAL_PIVTOL").filter(|&v| v >= 0.0) {
            params.bk.pivot_threshold = tol;
        }
        if let Some(threshold) = finite_env("FERAL_STATIC_PIVOT").filter(|&v| v > 0.0) {
            params.static_pivot_threshold = Some(threshold);
        }

        let base = Solver::with_params(params, SupernodeParams::default());
        let solver = if env_is_one_of("FERAL_CASCADE_BREAK", &["1", "on", "true", "yes"]) {
            base.with_cascade_break(0.5).with_cascade_break_eps(1e-10)
        } else {
            let beta = finite_env("FERAL_AUTO_CB_BETA")
                .filter(|&v| v >= 0.0)
                .unwrap_or(0.05);
            if beta > 0.0 {
                base.with_auto_cascade_break(beta)
            } else {
                base
            }
        };
        let solver = if env_is_one_of("FERAL_PARALLEL", &["0", "off", "false", "no"]) {
            solver.with_parallel(false)
        } else {
            solver
        };

        Box::into_raw(Box::new(FeralSolver {
            solver,
            matrix: None,
            neg_evals: 0,
        }))
    });
    built.unwrap_or(std::ptr::null_mut())
}
/// Releases a handle obtained from `feral_new`. Passing null is a no-op.
///
/// Any panic raised while dropping the solver is caught and discarded so it cannot unwind
/// across the FFI boundary.
///
/// # Safety
/// `s` must be null or a pointer returned by `feral_new` that has not been freed yet. The
/// pointer (and any pointer from `feral_values_ptr`) must not be used afterwards.
#[no_mangle]
pub unsafe extern "C" fn feral_free(s: *mut FeralSolver) {
    if !s.is_null() {
        let outcome = catch_unwind(AssertUnwindSafe(move || drop(Box::from_raw(s))));
        // Best effort: nothing useful can be reported to the caller from a failed drop.
        let _ = outcome;
    }
}
/// Installs the sparsity pattern of the matrix to be factorized.
///
/// `ia` holds `n + 1` column pointers and `ja` holds `nnz` row indices. Both are 0-based
/// and are copied into a `CscMatrix`. The `nnz` values are zero-initialized; fill them
/// through `feral_values_ptr` before calling `feral_factor`. On success the cached
/// negative-eigenvalue count is reset to 0.
///
/// Returns `FERAL_SUCCESS`, or `FERAL_FATAL` in any of these cases:
/// - a null pointer,
/// - a negative `n`/`nnz`,
/// - any negative entry in `ia`/`ja`,
/// - a pattern rejected by `CscMatrix::validate`,
/// - a panic.
///
/// On failure the previously installed structure, if any, is left untouched. On success
/// it is replaced, which invalidates any pointer previously returned by `feral_values_ptr`.
///
/// # Safety
/// `s` must be null or a live handle from `feral_new`. `ia` must point to at least
/// `n + 1` readable `i32`s and `ja` to at least `nnz` readable `i32`s.
#[no_mangle]
pub unsafe extern "C" fn feral_set_structure(
    s: *mut FeralSolver,
    n: i32,
    nnz: i32,
    ia: *const i32,
    ja: *const i32,
) -> i32 {
    catch_unwind(AssertUnwindSafe(|| {
        if s.is_null() || ia.is_null() || ja.is_null() || n < 0 || nnz < 0 {
            return FERAL_FATAL;
        }
        let s = &mut *s;
        let n_usize = n as usize;
        let nnz_usize = nnz as usize;
        let ia_slice = std::slice::from_raw_parts(ia, n_usize + 1);
        let ja_slice = std::slice::from_raw_parts(ja, nnz_usize);
        // `x as usize` on a negative i32 wraps to an enormous index. Reject such input up
        // front instead of relying on `validate` to notice the wrapped value.
        if ia_slice.iter().chain(ja_slice).any(|&x| x < 0) {
            return FERAL_FATAL;
        }
        // Lossless now that every entry is known to be non-negative.
        let col_ptr: Vec<usize> = ia_slice.iter().map(|&x| x as usize).collect();
        let row_idx: Vec<usize> = ja_slice.iter().map(|&x| x as usize).collect();
        let matrix = CscMatrix {
            n: n_usize,
            col_ptr,
            row_idx,
            values: vec![0.0; nnz_usize],
        };
        if matrix.validate().is_err() {
            return FERAL_FATAL;
        }
        s.matrix = Some(matrix);
        s.neg_evals = 0;
        FERAL_SUCCESS
    }))
    .unwrap_or(FERAL_FATAL)
}
/// Returns a writable pointer to the `nnz` value slots of the installed structure. The
/// slots are in the same order as the `ja` row indices.
///
/// Returns null when `s` is null, no structure has been installed, or a panic occurs. The
/// pointer dangles after the next successful `feral_set_structure` (which replaces the
/// matrix) or after `feral_free`.
///
/// # Safety
/// `s` must be null or a live handle from `feral_new`.
#[no_mangle]
pub unsafe extern "C" fn feral_values_ptr(s: *mut FeralSolver) -> *mut f64 {
    if s.is_null() {
        return std::ptr::null_mut();
    }
    let lookup = catch_unwind(AssertUnwindSafe(|| {
        (*s).matrix
            .as_mut()
            .map_or(std::ptr::null_mut(), |m| m.values.as_mut_ptr())
    }));
    lookup.unwrap_or_else(|_| std::ptr::null_mut())
}
/// Numerically factorizes the matrix currently installed in the handle.
///
/// When the factorization succeeds, the negative-eigenvalue count from the solver's
/// inertia is cached (readable via `feral_num_neg`). If `check_neg != 0` and that count
/// differs from `expected_neg`, `FERAL_WRONG_INERTIA` is returned even though the
/// factorization itself succeeded. A solver-reported wrong-inertia status also caches the
/// actual count.
///
/// Returns one of:
/// - `FERAL_SUCCESS`,
/// - `FERAL_SINGULAR`,
/// - `FERAL_WRONG_INERTIA`,
/// - `FERAL_FATAL`: null handle, no structure installed, no inertia available after a
///   successful factorization, solver fatal error, or a panic.
///
/// On `FERAL_SINGULAR` / `FERAL_FATAL` the cached count keeps its previous value.
///
/// Setting `FERAL_FACTOR_TRACE` to `1` or `on` prints timing and delayed-pivot statistics
/// to stderr after each factorization.
///
/// # Safety
/// `s` must be null or a live handle from `feral_new` that is not being used concurrently.
#[no_mangle]
pub unsafe extern "C" fn feral_factor(
    s: *mut FeralSolver,
    check_neg: i32,
    expected_neg: i32,
) -> i32 {
    catch_unwind(AssertUnwindSafe(|| {
        if s.is_null() {
            return FERAL_FATAL;
        }
        let s = &mut *s;
        // `matrix` borrows `s.matrix` while `factor` borrows `s.solver`. The fields are
        // disjoint, so there is no need to deep-copy the CSC arrays (indices and values)
        // on every factorization.
        let matrix = match &s.matrix {
            Some(m) => m,
            None => return FERAL_FATAL,
        };
        let trace_factor = matches!(
            std::env::var("FERAL_FACTOR_TRACE").as_deref(),
            Ok("1") | Ok("on"),
        );
        let factor_t0 = if trace_factor {
            Some(std::time::Instant::now())
        } else {
            None
        };
        let status = s.solver.factor(matrix, None);
        if let Some(t0) = factor_t0 {
            let ms = t0.elapsed().as_secs_f64() * 1e3;
            let (sum_delayed, max_delayed, n_snodes) = match s.solver.factors() {
                Some(f) => {
                    let sum: usize = f.node_factors.iter().map(|nf| nf.n_delayed_in).sum();
                    let max: usize = f
                        .node_factors
                        .iter()
                        .map(|nf| nf.n_delayed_in)
                        .max()
                        .unwrap_or(0);
                    (sum, max, f.node_factors.len())
                }
                None => (0, 0, 0),
            };
            eprintln!(
                "[feral factor] n={} nnz={} {:.1} ms snodes={} sum_delayed={} max_delayed={} status={:?}",
                matrix.n,
                matrix.row_idx.len(),
                ms,
                n_snodes,
                sum_delayed,
                max_delayed,
                status,
            );
        }
        match status {
            FactorStatus::Success => {
                // Only the count is needed, so read it directly instead of cloning the
                // whole inertia record.
                let negative = match s.solver.inertia() {
                    Some(inertia) => inertia.negative,
                    None => return FERAL_FATAL,
                };
                s.neg_evals = negative as i32;
                if check_neg != 0 && s.neg_evals != expected_neg {
                    FERAL_WRONG_INERTIA
                } else {
                    FERAL_SUCCESS
                }
            }
            FactorStatus::Singular => FERAL_SINGULAR,
            FactorStatus::WrongInertia { actual, .. } => {
                s.neg_evals = actual.negative as i32;
                FERAL_WRONG_INERTIA
            }
            FactorStatus::FatalError(_) => FERAL_FATAL,
        }
    }))
    .unwrap_or(FERAL_FATAL)
}
/// Solves with the handle's factorization for `nrhs` right-hand sides, in place.
///
/// `rhs` holds `n * nrhs` values (`n` from `feral_set_structure`) and is overwritten with
/// the solution on success. The layout of multiple right-hand sides is whatever
/// `Solver::solve_many` expects; presumably column-major, so confirm before relying on it.
///
/// By default the solution is refined against the handle's installed matrix
/// (`solve_many_refined`). Setting `FERAL_REFINE` to `0`, `false`, `off` or `no` uses the
/// unrefined `solve_many` instead.
///
/// NOTE(review): refinement reads the matrix values currently in the handle. If they were
/// overwritten after the last `feral_factor`, refinement uses the new values against the
/// old factors.
///
/// Returns `FERAL_SUCCESS`, or `FERAL_FATAL` in any of these cases:
/// - null pointers,
/// - `nrhs <= 0`,
/// - no structure installed,
/// - a solver error,
/// - a panic.
///
/// # Safety
/// `s` must be null or a live handle from `feral_new`. `rhs` must point to at least
/// `n * nrhs` writable `f64`s.
#[no_mangle]
pub unsafe extern "C" fn feral_solve(s: *mut FeralSolver, nrhs: i32, rhs: *mut f64) -> i32 {
    let run = AssertUnwindSafe(|| {
        if s.is_null() || rhs.is_null() || nrhs <= 0 {
            return FERAL_FATAL;
        }
        let handle = &*s;
        let a = match handle.matrix.as_ref() {
            Some(m) => m,
            None => return FERAL_FATAL,
        };
        let cols = nrhs as usize;
        let b = std::slice::from_raw_parts_mut(rhs, a.n * cols);
        let skip_refinement = matches!(
            std::env::var("FERAL_REFINE").as_deref(),
            Ok("0") | Ok("false") | Ok("off") | Ok("no"),
        );
        let result = if skip_refinement {
            handle.solver.solve_many(b, cols)
        } else {
            handle.solver.solve_many_refined(a, b, cols)
        };
        result.map_or(FERAL_FATAL, |x| {
            b.copy_from_slice(&x);
            FERAL_SUCCESS
        })
    });
    catch_unwind(run).unwrap_or(FERAL_FATAL)
}
/// Returns the cached negative-eigenvalue count, or `-1` if `s` is null or a panic occurs.
///
/// The count is:
/// - 0 after `feral_new` or a successful `feral_set_structure`,
/// - updated by `feral_factor` when it obtains inertia (successful factorization or a
///   solver-reported wrong inertia),
/// - left unchanged by singular or fatal factorizations.
///
/// # Safety
/// `s` must be null or a live handle from `feral_new`.
#[no_mangle]
pub unsafe extern "C" fn feral_num_neg(s: *const FeralSolver) -> i32 {
    match s.as_ref() {
        None => -1,
        Some(handle) => catch_unwind(AssertUnwindSafe(|| handle.neg_evals)).unwrap_or(-1),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    // The C API reads `FERAL_*` variables on every call, and the test harness runs tests
    // on parallel threads. Every test that creates a handle therefore takes this lock.
    // Without it, one test's `set_var` could silently switch another test onto a
    // different code path, and concurrent environment mutation is not thread-safe.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering from poisoning left by a previously failed test.
    fn env_lock() -> MutexGuard<'static, ()> {
        // The protected data is `()`, so a poisoned lock carries no broken state.
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Overrides an environment variable and restores its previous state on drop, so a
    /// failing assertion cannot leak the override into later tests. Create it only while
    /// holding [`env_lock`], and bind it after that lock guard so it is dropped first.
    struct EnvVarGuard {
        key: &'static str,
        prior: Option<String>,
    }

    impl EnvVarGuard {
        // `set_var` is an `unsafe fn` only from edition 2024 onward.
        #[allow(unused_unsafe)]
        fn set(key: &'static str, value: &str) -> Self {
            let prior = std::env::var(key).ok();
            // SAFETY: the caller holds ENV_LOCK, so no other test touches the environment.
            unsafe { std::env::set_var(key, value) };
            EnvVarGuard { key, prior }
        }
    }

    impl Drop for EnvVarGuard {
        // `set_var`/`remove_var` are `unsafe fn`s only from edition 2024 onward.
        #[allow(unused_unsafe)]
        fn drop(&mut self) {
            // SAFETY: the guard is bound after the `env_lock()` guard, so it is dropped
            // (and restores the variable) while the lock is still held.
            unsafe {
                match self.prior.take() {
                    Some(v) => std::env::set_var(self.key, v),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    #[test]
    fn capi_factor_and_refined_solve() {
        let _env = env_lock();
        unsafe {
            let s = feral_new();
            assert!(!s.is_null());
            let ia: [i32; 3] = [0, 2, 3];
            let ja: [i32; 3] = [0, 1, 1];
            assert_eq!(
                feral_set_structure(s, 2, 3, ia.as_ptr(), ja.as_ptr()),
                FERAL_SUCCESS
            );
            let vp = feral_values_ptr(s);
            assert!(!vp.is_null());
            std::ptr::copy_nonoverlapping([1.0_f64, 2.0, 1.0].as_ptr(), vp, 3);
            assert_eq!(feral_factor(s, 1, 1), FERAL_SUCCESS);
            assert_eq!(feral_num_neg(s), 1);
            let mut rhs = [3.0_f64, 3.0];
            assert_eq!(feral_solve(s, 1, rhs.as_mut_ptr()), FERAL_SUCCESS);
            assert!((rhs[0] - 1.0).abs() < 1e-12, "x0 = {}", rhs[0]);
            assert!((rhs[1] - 1.0).abs() < 1e-12, "x1 = {}", rhs[1]);
            feral_free(s);
        }
    }

    #[test]
    fn capi_solve_unrefined_opt_out() {
        let _env = env_lock();
        let _refine = EnvVarGuard::set("FERAL_REFINE", "0");
        unsafe {
            let s = feral_new();
            assert!(!s.is_null());
            let ia: [i32; 3] = [0, 2, 3];
            let ja: [i32; 3] = [0, 1, 1];
            assert_eq!(
                feral_set_structure(s, 2, 3, ia.as_ptr(), ja.as_ptr()),
                FERAL_SUCCESS
            );
            let vp = feral_values_ptr(s);
            std::ptr::copy_nonoverlapping([1.0_f64, 2.0, 1.0].as_ptr(), vp, 3);
            assert_eq!(feral_factor(s, 0, 0), FERAL_SUCCESS);
            let mut rhs = [3.0_f64, 3.0];
            assert_eq!(feral_solve(s, 1, rhs.as_mut_ptr()), FERAL_SUCCESS);
            assert!((rhs[0] - 1.0).abs() < 1e-12);
            assert!((rhs[1] - 1.0).abs() < 1e-12);
            feral_free(s);
        }
    }
}