use clarabel::algebra::CscMatrix;
use clarabel::solver::{
DefaultSettings, DefaultSolver, IPSolver, NonnegativeConeT, SolverStatus, ZeroConeT,
};
use ndarray::{Array1, Array2};
pub fn solve_qr(
x_data: &Array2<f64>,
y_data: &Array1<f64>,
tau: f64,
) -> Result<Vec<f64>, String> {
let (n_obs, n_features) = x_data.dim();
if y_data.len() != n_obs {
return Err(
"Input dimensions mismatch: X and y must have the same number of observations."
.to_string(),
);
}
if !(0.0..=1.0).contains(&tau) {
return Err("Tau must be between 0 and 1.".to_string());
}
let n_vars = n_features + 2 * n_obs;
let p = CscMatrix::new(n_vars, n_vars, vec![0; n_vars + 1], vec![], vec![]);
let mut q = vec![0.0; n_vars];
for i in 0..n_obs {
q[n_features + i] = tau; q[n_features + n_obs + i] = 1.0 - tau; }
let mut b = vec![0.0; 3 * n_obs];
for i in 0..n_obs {
b[i] = y_data[i];
}
let mut a_col_ptr = vec![0];
let mut a_row_ind = vec![];
let mut a_nz_val = vec![];
for j in 0..n_features {
for i in 0..n_obs {
a_row_ind.push(i);
a_nz_val.push(x_data[[i, j]]);
}
a_col_ptr.push(a_nz_val.len());
}
for i in 0..n_obs {
a_row_ind.push(i);
a_nz_val.push(1.0);
a_row_ind.push(n_obs + i);
a_nz_val.push(-1.0);
a_col_ptr.push(a_nz_val.len());
}
for i in 0..n_obs {
a_row_ind.push(i);
a_nz_val.push(-1.0);
a_row_ind.push(2 * n_obs + i);
a_nz_val.push(-1.0);
a_col_ptr.push(a_nz_val.len());
}
let a = CscMatrix::new(3 * n_obs, n_vars, a_col_ptr, a_row_ind, a_nz_val);
let cones = vec![ZeroConeT(n_obs), NonnegativeConeT(2 * n_obs)];
let settings = DefaultSettings::default();
let mut solver = match DefaultSolver::new(&p, &q, &a, &b, &cones, settings) {
Ok(s) => s,
Err(e) => return Err(format!("Error creating solver: {:?}", e)),
};
solver.solve();
if solver.solution.status == SolverStatus::Solved {
let coeffs = solver.solution.x[0..n_features].to_vec();
Ok(coeffs)
} else {
Err(format!(
"Solver failed with status: {:?}",
solver.solution.status
))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array1, Array2};

    /// Absolute tolerance used when comparing solver output to exact expectations.
    const TOLERANCE: f64 = 1e-4;

    /// Intercept column plus regressor 1..=5, with a response lying exactly on `y = x`.
    /// Every quantile level recovers the same coefficients on this data.
    fn exact_line_data() -> (Array2<f64>, Array1<f64>) {
        let x = array![[1.0, 1.0], [1.0, 2.0], [1.0, 3.0], [1.0, 4.0], [1.0, 5.0]];
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0];
        (x, y)
    }

    /// Asserts that `actual` matches `expected` element-wise within `TOLERANCE`.
    fn assert_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!(
                (a - e).abs() < TOLERANCE,
                "got {:?}, expected {:?}",
                actual,
                expected
            );
        }
    }

    #[test]
    fn test_solve_qr_median() {
        let (x, y) = exact_line_data();
        let result = solve_qr(&x, &y, 0.5).unwrap();
        assert_close(&result, &[0.0, 1.0]);
    }

    #[test]
    fn test_solve_qr_quartile() {
        let (x, y) = exact_line_data();
        let result = solve_qr(&x, &y, 0.25).unwrap();
        assert_close(&result, &[0.0, 1.0]);
    }

    /// With an intercept-only design, the fitted coefficient is the sample tau-quantile.
    /// For y = {1, 2, 3, 4, 5} the check-loss minimiser is unique at 2, 3 and 4 for
    /// tau = 0.25, 0.5 and 0.75. Unlike the exact-line tests, this shows that tau matters.
    #[test]
    fn test_solve_qr_intercept_only_recovers_sample_quantiles() {
        let x = Array2::<f64>::ones((5, 1));
        let y = array![3.0, 1.0, 5.0, 2.0, 4.0];
        for &(tau, expected) in &[(0.25, 2.0), (0.5, 3.0), (0.75, 4.0)] {
            let result = solve_qr(&x, &y, tau).unwrap();
            assert_close(&result, &[expected]);
        }
    }

    #[test]
    fn test_solve_qr_rejects_dimension_mismatch() {
        let x = Array2::<f64>::ones((3, 1));
        let y = array![1.0, 2.0];
        assert!(solve_qr(&x, &y, 0.5).is_err());
    }

    #[test]
    fn test_solve_qr_rejects_invalid_tau() {
        let (x, y) = exact_line_data();
        for &tau in &[-0.1, 1.5, f64::NAN] {
            assert!(solve_qr(&x, &y, tau).is_err(), "tau = {} was accepted", tau);
        }
    }
}