use crate::{CscMatrix, FactorStatus, Solver};
use std::panic::{catch_unwind, AssertUnwindSafe};
/// Status code returned by the `feral_*` functions when the requested operation succeeded.
pub const FERAL_SUCCESS: i32 = 0;
/// Status code returned by `feral_factor` when the solver reports the matrix as singular.
pub const FERAL_SINGULAR: i32 = 1;
/// Status code returned by `feral_factor` when the number of negative eigenvalues does not
/// match the caller's expectation (or the solver itself reports an inertia mismatch).
pub const FERAL_WRONG_INERTIA: i32 = 2;
/// Status code for invalid arguments, missing state, solver errors, or a panic caught at the
/// FFI boundary.
pub const FERAL_FATAL: i32 = 3;
/// Opaque handle handed to C callers by `feral_new` and released with `feral_free`.
///
/// The struct is not `#[repr(C)]`; C code must treat it as an opaque pointer and only
/// manipulate it through the `feral_*` functions.
pub struct FeralSolver {
// Factorization engine used by `feral_factor` and `feral_solve`.
solver: Solver,
// Pattern plus value buffer installed by `feral_set_structure`; `None` until then.
matrix: Option<CscMatrix>,
// Negative-eigenvalue count from the last factorization that produced an inertia;
// reset to 0 whenever a new structure is installed.
neg_evals: i32,
}
/// Allocates a fresh solver handle with no matrix structure installed.
///
/// The returned pointer must eventually be released with `feral_free`. Returns a null
/// pointer if constructing the underlying solver panics, so no unwind crosses into C.
#[no_mangle]
pub extern "C" fn feral_new() -> *mut FeralSolver {
    let outcome = catch_unwind(|| {
        let state = FeralSolver {
            solver: Solver::new(),
            matrix: None,
            neg_evals: 0,
        };
        Box::into_raw(Box::new(state))
    });
    match outcome {
        Ok(handle) => handle,
        Err(_) => std::ptr::null_mut(),
    }
}
/// Destroys a handle previously returned by `feral_new`.
///
/// Passing a null pointer is a no-op. A panic raised while dropping the solver is swallowed
/// so that it never unwinds into the C caller.
///
/// # Safety
/// `s` must be null or a pointer obtained from `feral_new` that has not already been freed.
/// The handle (and any pointer obtained from `feral_values_ptr`) must not be used afterwards.
#[no_mangle]
pub unsafe extern "C" fn feral_free(s: *mut FeralSolver) {
    if !s.is_null() {
        let _ = catch_unwind(AssertUnwindSafe(move || drop(Box::from_raw(s))));
    }
}
/// Installs a new sparsity pattern on the handle.
///
/// `ia` must hold `n + 1` column pointers and `ja` must hold `nnz` row indices. Both are
/// copied as-is, in whatever index convention `CscMatrix` expects. The numeric values are
/// zero-initialised; write them through `feral_values_ptr` before calling `feral_factor`.
///
/// A successful call replaces any previous pattern. That invalidates every pointer
/// previously returned by `feral_values_ptr` and resets the recorded negative-eigenvalue
/// count to 0.
///
/// Returns `FERAL_SUCCESS`, or `FERAL_FATAL` in any of these cases:
/// - a null pointer was passed;
/// - a size or index is negative;
/// - `CscMatrix::validate` rejects the pattern;
/// - a panic was caught.
///
/// On failure the previously installed pattern (if any) is left untouched.
///
/// # Safety
/// `s` must be null or a live handle from `feral_new`. `ia` must be valid for `n + 1` reads
/// and `ja` for `nnz` reads.
#[no_mangle]
pub unsafe extern "C" fn feral_set_structure(
    s: *mut FeralSolver,
    n: i32,
    nnz: i32,
    ia: *const i32,
    ja: *const i32,
) -> i32 {
    // Converts a C index array to `usize`, rejecting negative entries. A plain `as usize`
    // would silently wrap them to enormous indices.
    fn to_indices(raw: &[i32]) -> Option<Vec<usize>> {
        raw.iter()
            .map(|&x| if x < 0 { None } else { Some(x as usize) })
            .collect()
    }

    catch_unwind(AssertUnwindSafe(|| {
        if s.is_null() || ia.is_null() || ja.is_null() || n < 0 || nnz < 0 {
            return FERAL_FATAL;
        }
        let s = &mut *s;
        let n_usize = n as usize;
        let nnz_usize = nnz as usize;
        let ia_slice = std::slice::from_raw_parts(ia, n_usize + 1);
        let ja_slice = std::slice::from_raw_parts(ja, nnz_usize);
        let (col_ptr, row_idx) = match (to_indices(ia_slice), to_indices(ja_slice)) {
            (Some(col_ptr), Some(row_idx)) => (col_ptr, row_idx),
            _ => return FERAL_FATAL,
        };
        let matrix = CscMatrix {
            n: n_usize,
            col_ptr,
            row_idx,
            values: vec![0.0; nnz_usize],
        };
        if matrix.validate().is_err() {
            return FERAL_FATAL;
        }
        s.matrix = Some(matrix);
        s.neg_evals = 0;
        FERAL_SUCCESS
    }))
    .unwrap_or(FERAL_FATAL)
}
/// Returns a pointer to the matrix value buffer (one `f64` per structural nonzero).
///
/// The caller fills this buffer before calling `feral_factor`. Returns null when `s` is
/// null, when no structure has been installed yet, or when a panic is caught. The pointer
/// becomes dangling after the next successful `feral_set_structure` or after `feral_free`.
///
/// # Safety
/// `s` must be null or a live handle from `feral_new`.
#[no_mangle]
pub unsafe extern "C" fn feral_values_ptr(s: *mut FeralSolver) -> *mut f64 {
    if s.is_null() {
        return std::ptr::null_mut();
    }
    let lookup = catch_unwind(AssertUnwindSafe(|| {
        (*s)
            .matrix
            .as_mut()
            .map_or(std::ptr::null_mut(), |m| m.values.as_mut_ptr())
    }));
    lookup.unwrap_or_else(|_| std::ptr::null_mut())
}
/// Numerically factorizes the matrix currently installed on the handle.
///
/// Values are taken from the buffer exposed by `feral_values_ptr`. The negative-eigenvalue
/// count is recorded on `Success` and on a solver-reported `WrongInertia`. It can be read
/// back with `feral_num_neg`.
///
/// When `check_neg` is nonzero, a successful factorization whose negative count differs from
/// `expected_neg` reports `FERAL_WRONG_INERTIA`. On `FERAL_SINGULAR` or `FERAL_FATAL` the
/// previously recorded count is left unchanged.
///
/// Returns `FERAL_FATAL` in any of these cases:
/// - the handle is null;
/// - no structure is installed;
/// - the solver provides no inertia after success;
/// - the solver reports a fatal error;
/// - a panic is caught.
///
/// # Safety
/// `s` must be null or a live handle from `feral_new`.
#[no_mangle]
pub unsafe extern "C" fn feral_factor(
    s: *mut FeralSolver,
    check_neg: i32,
    expected_neg: i32,
) -> i32 {
    catch_unwind(AssertUnwindSafe(|| {
        if s.is_null() {
            return FERAL_FATAL;
        }
        let s = &mut *s;
        // Borrow the stored matrix in place. `matrix` and `solver` are disjoint fields, so
        // there is no need to copy the whole pattern and value buffer on every factorization.
        let matrix = match s.matrix.as_ref() {
            Some(m) => m,
            None => return FERAL_FATAL,
        };
        match s.solver.factor(matrix, None) {
            FactorStatus::Success => {
                let negative = match s.solver.inertia() {
                    Some(i) => i.negative,
                    None => return FERAL_FATAL,
                };
                s.neg_evals = negative as i32;
                if check_neg != 0 && s.neg_evals != expected_neg {
                    FERAL_WRONG_INERTIA
                } else {
                    FERAL_SUCCESS
                }
            }
            FactorStatus::Singular => FERAL_SINGULAR,
            FactorStatus::WrongInertia { actual, .. } => {
                s.neg_evals = actual.negative as i32;
                FERAL_WRONG_INERTIA
            }
            FactorStatus::FatalError(_) => FERAL_FATAL,
        }
    }))
    .unwrap_or(FERAL_FATAL)
}
#[no_mangle]
pub unsafe extern "C" fn feral_solve(s: *mut FeralSolver, nrhs: i32, rhs: *mut f64) -> i32 {
catch_unwind(AssertUnwindSafe(|| {
if s.is_null() || rhs.is_null() || nrhs <= 0 {
return FERAL_FATAL;
}
let s = &*s;
let n = match &s.matrix {
Some(m) => m.n,
None => return FERAL_FATAL,
};
let nrhs_usize = nrhs as usize;
let rhs_slice = std::slice::from_raw_parts_mut(rhs, n * nrhs_usize);
match s.solver.solve_many(rhs_slice, nrhs_usize) {
Ok(x) => {
rhs_slice.copy_from_slice(&x);
FERAL_SUCCESS
}
Err(_) => FERAL_FATAL,
}
}))
.unwrap_or(FERAL_FATAL)
}
/// Reports the negative-eigenvalue count recorded on the handle.
///
/// The value is 0 after `feral_set_structure`; otherwise it is whatever the most recent
/// `feral_factor` stored. Returns -1 for a null handle or if a panic is caught.
///
/// # Safety
/// `s` must be null or a live handle from `feral_new`.
#[no_mangle]
pub unsafe extern "C" fn feral_num_neg(s: *const FeralSolver) -> i32 {
    match s.as_ref() {
        None => -1,
        Some(state) => catch_unwind(AssertUnwindSafe(|| state.neg_evals)).unwrap_or(-1),
    }
}