use std::panic::{catch_unwind, AssertUnwindSafe};
use std::ptr;
use super::types::{SciResult, SciVector};
/// Borrows the elements described by a caller-supplied `SciVector` as an `f64` slice.
///
/// `name` is used as the prefix of every error message so the caller can tell
/// which entry point (and which argument) was rejected.
///
/// # Errors
///
/// Returns `Err` with a descriptive message when:
/// - `v` is null;
/// - `len > 0` but the `data` pointer is null;
/// - `len > 0` but the `data` pointer is not aligned for `f64`;
/// - `len` describes more than `isize::MAX` bytes, which no valid slice can span.
///
/// A vector with `len == 0` yields an empty slice regardless of `data`.
///
/// # Safety
///
/// - If `v` is non-null it must point to a readable, properly aligned `SciVector`.
/// - If `len > 0`, `data` must point to `len` initialized `f64` values that stay
///   valid and unmodified for as long as the returned slice is used.
/// - The returned lifetime is nominally `'static` but is really unbounded: the
///   caller must not let the slice outlive the buffer it points into.
unsafe fn validate_vector(v: *const SciVector, name: &str) -> Result<&'static [f64], String> {
    if v.is_null() {
        return Err(format!("{}: null vector pointer", name));
    }
    // SAFETY: `v` is non-null, and the caller guarantees it points to a valid `SciVector`.
    let vec = unsafe { &*v };
    if vec.len == 0 {
        return Ok(&[]);
    }
    if vec.data.is_null() {
        return Err(format!("{}: data pointer is null but len > 0", name));
    }
    // `slice::from_raw_parts` requires an aligned pointer; reject a misaligned
    // buffer with an error instead of invoking undefined behavior.
    if (vec.data as usize) % std::mem::align_of::<f64>() != 0 {
        return Err(format!("{}: data pointer is not aligned for f64", name));
    }
    // `slice::from_raw_parts` also requires the slice to span at most `isize::MAX` bytes.
    if vec.len > isize::MAX as usize / std::mem::size_of::<f64>() {
        return Err(format!(
            "{}: len {} exceeds the maximum slice size",
            name, vec.len
        ));
    }
    // SAFETY: `data` is non-null and aligned, `len` is within the size limit, and
    // the caller guarantees `data` points to `len` readable `f64` values.
    Ok(unsafe { std::slice::from_raw_parts(vec.data, vec.len) })
}
/// C ABI entry point: writes the arithmetic mean of `vec` to `*out`.
///
/// Returns `SciResult::ok()` after storing the mean. Returns `SciResult::err`
/// with a message when `out` is null, when `validate_vector` rejects `vec`, or
/// when the vector is empty. A panic raised during the computation is caught
/// and converted with `SciResult::from_panic`. `*out` is only written on success.
///
/// # Safety
///
/// `vec` must satisfy the requirements of `validate_vector`, and a non-null
/// `out` must be valid for writing one `f64`.
#[no_mangle]
pub unsafe extern "C" fn sci_mean(vec: *const SciVector, out: *mut f64) -> SciResult {
    if out.is_null() {
        return SciResult::err("sci_mean: out pointer is null");
    }

    // The computation runs behind `catch_unwind` so that a panic is reported
    // through `SciResult::from_panic` instead of unwinding out of this function.
    let outcome = catch_unwind(AssertUnwindSafe(|| -> Result<f64, String> {
        // SAFETY: forwarded from this function's own contract on `vec`.
        let samples = unsafe { validate_vector(vec, "sci_mean") }?;
        match samples.len() {
            0 => Err("sci_mean: empty vector".to_string()),
            count => {
                let total: f64 = samples.iter().sum();
                Ok(total / count as f64)
            }
        }
    }));

    let mean = match outcome {
        Ok(Ok(value)) => value,
        Ok(Err(message)) => return SciResult::err(&message),
        Err(payload) => return SciResult::from_panic(payload),
    };
    // SAFETY: `out` was checked to be non-null above; the caller guarantees it is writable.
    unsafe { ptr::write(out, mean) };
    SciResult::ok()
}
/// C ABI entry point: writes the standard deviation of `vec` to `*out`.
///
/// The variance is the sum of squared deviations from the mean divided by
/// `n - ddof`; the standard deviation is its square root.
///
/// Returns `SciResult::ok()` after storing the result. Returns `SciResult::err`
/// with a message when `out` is null, when `validate_vector` rejects `vec`,
/// when the vector is empty, or when it has `ddof` or fewer elements. A panic
/// raised during the computation is caught and converted with
/// `SciResult::from_panic`. `*out` is only written on success.
///
/// # Safety
///
/// `vec` must satisfy the requirements of `validate_vector`, and a non-null
/// `out` must be valid for writing one `f64`.
#[no_mangle]
pub unsafe extern "C" fn sci_std(vec: *const SciVector, ddof: usize, out: *mut f64) -> SciResult {
    if out.is_null() {
        return SciResult::err("sci_std: out pointer is null");
    }
    let result = catch_unwind(AssertUnwindSafe(|| -> Result<f64, String> {
        // SAFETY: forwarded from this function's own contract on `vec`.
        let data = unsafe { validate_vector(vec, "sci_std") }?;
        let n = data.len();
        if n == 0 {
            return Err("sci_std: empty vector".to_string());
        }
        if n <= ddof {
            return Err(format!(
                "sci_std: need at least {} elements for ddof={}, got {}",
                // Saturate: `ddof + 1` would overflow when `ddof == usize::MAX`
                // (panic in debug builds, wrap to 0 in release builds).
                ddof.saturating_add(1),
                ddof,
                n
            ));
        }
        let mean: f64 = data.iter().sum::<f64>() / n as f64;
        // `n > ddof` was established above, so `n - ddof` is at least 1.
        let variance: f64 =
            data.iter().map(|&x| (x - mean) * (x - mean)).sum::<f64>() / (n - ddof) as f64;
        Ok(variance.sqrt())
    }));
    match result {
        Ok(Ok(std_dev)) => {
            // SAFETY: `out` was checked to be non-null above; the caller guarantees it is writable.
            unsafe { ptr::write(out, std_dev) };
            SciResult::ok()
        }
        Ok(Err(msg)) => SciResult::err(&msg),
        Err(e) => SciResult::from_panic(e),
    }
}
/// C ABI entry point: writes the `q`-th percentile of `vec` to `*out`.
///
/// The data is copied and sorted, the fractional rank `q / 100 * (n - 1)` is
/// computed, and the result is linearly interpolated between the two
/// neighbouring sorted values.
///
/// Sorting uses `f64::total_cmp`, so NaN elements are ordered deterministically
/// (negative NaNs before all numbers, positive NaNs after) rather than leaving
/// the sort order unspecified.
///
/// Returns `SciResult::ok()` after storing the result. Returns `SciResult::err`
/// with a message when `out` is null, when `validate_vector` rejects `vec`,
/// when the vector is empty, when `q` is NaN, or when `q` is outside
/// `[0, 100]`. A panic raised during the computation is caught and converted
/// with `SciResult::from_panic`. `*out` is only written on success.
///
/// # Safety
///
/// `vec` must satisfy the requirements of `validate_vector`, and a non-null
/// `out` must be valid for writing one `f64`.
#[no_mangle]
pub unsafe extern "C" fn sci_percentile(vec: *const SciVector, q: f64, out: *mut f64) -> SciResult {
    if out.is_null() {
        return SciResult::err("sci_percentile: out pointer is null");
    }
    let result = catch_unwind(AssertUnwindSafe(|| -> Result<f64, String> {
        // SAFETY: forwarded from this function's own contract on `vec`.
        let data = unsafe { validate_vector(vec, "sci_percentile") }?;
        if data.is_empty() {
            return Err("sci_percentile: empty vector".to_string());
        }
        // NaN must be tested before the range check: NaN is not contained in
        // any range, so the range check would otherwise shadow this message.
        if q.is_nan() {
            return Err("sci_percentile: q is NaN".to_string());
        }
        if !(0.0..=100.0).contains(&q) {
            return Err(format!("sci_percentile: q must be in [0, 100], got {}", q));
        }
        let mut sorted: Vec<f64> = data.to_vec();
        // `total_cmp` is a total order even with NaN present; a comparator that
        // is not a total order leaves the sort result unspecified and may panic.
        sorted.sort_unstable_by(f64::total_cmp);
        let last = sorted.len() - 1;
        let rank = (q / 100.0) * last as f64;
        // `rank` lies in `[0, last]` for a valid `q`; the clamp is purely defensive.
        let lower = (rank.floor() as usize).min(last);
        let upper = (rank.ceil() as usize).min(last);
        if lower == upper {
            Ok(sorted[lower])
        } else {
            let frac = rank - lower as f64;
            Ok(sorted[lower] + frac * (sorted[upper] - sorted[lower]))
        }
    }));
    match result {
        Ok(Ok(pct)) => {
            // SAFETY: `out` was checked to be non-null above; the caller guarantees it is writable.
            unsafe { ptr::write(out, pct) };
            SciResult::ok()
        }
        Ok(Err(msg)) => SciResult::err(&msg),
        Err(e) => SciResult::from_panic(e),
    }
}
/// C ABI entry point: writes the Pearson correlation coefficient of `x` and `y` to `*out`.
///
/// The coefficient is the sum of products of deviations from the means divided
/// by the square roots of the two sums of squared deviations. The result is
/// clamped to `[-1, 1]` so rounding error cannot push it outside that range;
/// a NaN result is passed through unchanged.
///
/// Returns `SciResult::ok()` after storing the result. Returns `SciResult::err`
/// with a message when `out` is null, when `validate_vector` rejects `x` or
/// `y`, when the lengths differ, when there are fewer than 2 elements, or when
/// either vector has zero variance. A panic raised during the computation is
/// caught and converted with `SciResult::from_panic`. `*out` is only written
/// on success.
///
/// # Safety
///
/// `x` and `y` must each satisfy the requirements of `validate_vector`, and a
/// non-null `out` must be valid for writing one `f64`.
#[no_mangle]
pub unsafe extern "C" fn sci_correlation(
    x: *const SciVector,
    y: *const SciVector,
    out: *mut f64,
) -> SciResult {
    if out.is_null() {
        return SciResult::err("sci_correlation: out pointer is null");
    }
    let result = catch_unwind(AssertUnwindSafe(|| -> Result<f64, String> {
        // SAFETY: forwarded from this function's own contract on `x` and `y`.
        let x_data = unsafe { validate_vector(x, "sci_correlation (x)") }?;
        let y_data = unsafe { validate_vector(y, "sci_correlation (y)") }?;
        if x_data.len() != y_data.len() {
            return Err(format!(
                "sci_correlation: x and y must have the same length, got {} and {}",
                x_data.len(),
                y_data.len()
            ));
        }
        let n = x_data.len();
        if n < 2 {
            return Err(format!(
                "sci_correlation: need at least 2 elements, got {}",
                n
            ));
        }
        let x_mean: f64 = x_data.iter().sum::<f64>() / n as f64;
        let y_mean: f64 = y_data.iter().sum::<f64>() / n as f64;
        // Accumulate the co-moment and both sums of squared deviations in one pass.
        let (cov, var_x, var_y) = x_data.iter().zip(y_data).fold(
            (0.0f64, 0.0f64, 0.0f64),
            |(cov, var_x, var_y), (&xi, &yi)| {
                let dx = xi - x_mean;
                let dy = yi - y_mean;
                (cov + dx * dy, var_x + dx * dx, var_y + dy * dy)
            },
        );
        // Test each variance on its own: the product `var_x * var_y` can
        // underflow to zero (or overflow to infinity) even when both factors
        // are finite and non-zero, which would misreport the inputs.
        if var_x == 0.0 || var_y == 0.0 {
            return Err("sci_correlation: one or both vectors have zero variance".to_string());
        }
        // Divide by the two square roots separately for the same reason.
        let corr = cov / var_x.sqrt() / var_y.sqrt();
        Ok(corr.clamp(-1.0, 1.0))
    }));
    match result {
        Ok(Ok(corr)) => {
            // SAFETY: `out` was checked to be non-null above; the caller guarantees it is writable.
            unsafe { ptr::write(out, corr) };
            SciResult::ok()
        }
        Ok(Err(msg)) => SciResult::err(&msg),
        Err(e) => SciResult::from_panic(e),
    }
}
/// C ABI entry point: writes the variance of `vec` to `*out`.
///
/// The variance is the sum of squared deviations from the mean divided by
/// `n - ddof`.
///
/// Returns `SciResult::ok()` after storing the result. Returns `SciResult::err`
/// with a message when `out` is null, when `validate_vector` rejects `vec`,
/// when the vector is empty, or when it has `ddof` or fewer elements. A panic
/// raised during the computation is caught and converted with
/// `SciResult::from_panic`. `*out` is only written on success.
///
/// # Safety
///
/// `vec` must satisfy the requirements of `validate_vector`, and a non-null
/// `out` must be valid for writing one `f64`.
#[no_mangle]
pub unsafe extern "C" fn sci_variance(
    vec: *const SciVector,
    ddof: usize,
    out: *mut f64,
) -> SciResult {
    if out.is_null() {
        return SciResult::err("sci_variance: out pointer is null");
    }
    let result = catch_unwind(AssertUnwindSafe(|| -> Result<f64, String> {
        // SAFETY: forwarded from this function's own contract on `vec`.
        let data = unsafe { validate_vector(vec, "sci_variance") }?;
        let n = data.len();
        if n == 0 {
            return Err("sci_variance: empty vector".to_string());
        }
        if n <= ddof {
            return Err(format!(
                "sci_variance: need at least {} elements for ddof={}, got {}",
                // Saturate: `ddof + 1` would overflow when `ddof == usize::MAX`
                // (panic in debug builds, wrap to 0 in release builds).
                ddof.saturating_add(1),
                ddof,
                n
            ));
        }
        let mean: f64 = data.iter().sum::<f64>() / n as f64;
        // `n > ddof` was established above, so `n - ddof` is at least 1.
        let variance: f64 =
            data.iter().map(|&x| (x - mean) * (x - mean)).sum::<f64>() / (n - ddof) as f64;
        Ok(variance)
    }));
    match result {
        Ok(Ok(var)) => {
            // SAFETY: `out` was checked to be non-null above; the caller guarantees it is writable.
            unsafe { ptr::write(out, var) };
            SciResult::ok()
        }
        Ok(Err(msg)) => SciResult::err(&msg),
        Err(e) => SciResult::from_panic(e),
    }
}
/// C ABI entry point: writes the median of `vec` to `*out`.
///
/// The data is copied and sorted. For an odd element count the middle value is
/// returned; for an even count, the average of the two middle values.
///
/// Sorting uses `f64::total_cmp`, so NaN elements are ordered deterministically
/// (negative NaNs before all numbers, positive NaNs after) rather than leaving
/// the sort order unspecified.
///
/// Returns `SciResult::ok()` after storing the result. Returns `SciResult::err`
/// with a message when `out` is null, when `validate_vector` rejects `vec`, or
/// when the vector is empty. A panic raised during the computation is caught
/// and converted with `SciResult::from_panic`. `*out` is only written on success.
///
/// # Safety
///
/// `vec` must satisfy the requirements of `validate_vector`, and a non-null
/// `out` must be valid for writing one `f64`.
#[no_mangle]
pub unsafe extern "C" fn sci_median(vec: *const SciVector, out: *mut f64) -> SciResult {
    if out.is_null() {
        return SciResult::err("sci_median: out pointer is null");
    }
    let result = catch_unwind(AssertUnwindSafe(|| -> Result<f64, String> {
        // SAFETY: forwarded from this function's own contract on `vec`.
        let data = unsafe { validate_vector(vec, "sci_median") }?;
        if data.is_empty() {
            return Err("sci_median: empty vector".to_string());
        }
        let mut sorted: Vec<f64> = data.to_vec();
        // `total_cmp` is a total order even with NaN present; a comparator that
        // is not a total order leaves the sort result unspecified and may panic.
        sorted.sort_unstable_by(f64::total_cmp);
        let n = sorted.len();
        let mid = n / 2;
        if n % 2 == 1 {
            Ok(sorted[mid])
        } else {
            // `n` is even and non-zero here, so `mid >= 1`.
            Ok((sorted[mid - 1] + sorted[mid]) / 2.0)
        }
    }));
    match result {
        Ok(Ok(med)) => {
            // SAFETY: `out` was checked to be non-null above; the caller guarantees it is writable.
            unsafe { ptr::write(out, med) };
            SciResult::ok()
        }
        Ok(Err(msg)) => SciResult::err(&msg),
        Err(e) => SciResult::from_panic(e),
    }
}