use super::interval::Interval;
use crate::error::{CoreError, CoreResult};
/// A dense vector of `f64` intervals.
///
/// Element-wise binary operations (`add`, `sub`, `dot`) require both operands
/// to have the same length and report a mismatch as `CoreError::InvalidInput`.
#[derive(Clone, Debug)]
pub struct IntervalVector {
/// The interval components, in index order.
data: Vec<Interval<f64>>,
}
impl IntervalVector {
#[inline]
pub fn new(data: Vec<Interval<f64>>) -> Self {
Self { data }
}
pub fn from_point_slice(values: &[f64]) -> Self {
Self {
data: values.iter().copied().map(Interval::point).collect(),
}
}
#[inline]
pub fn len(&self) -> usize {
self.data.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
#[inline]
pub fn get(&self, i: usize) -> Option<&Interval<f64>> {
self.data.get(i)
}
#[inline]
pub fn get_mut(&mut self, i: usize) -> Option<&mut Interval<f64>> {
self.data.get_mut(i)
}
#[inline]
pub fn iter(&self) -> core::slice::Iter<'_, Interval<f64>> {
self.data.iter()
}
#[inline]
pub fn iter_mut(&mut self) -> core::slice::IterMut<'_, Interval<f64>> {
self.data.iter_mut()
}
pub fn add(&self, rhs: &Self) -> CoreResult<Self> {
if self.len() != rhs.len() {
return Err(CoreError::InvalidInput(
crate::error::ErrorContext::new("IntervalVector::add: length mismatch"),
));
}
Ok(Self {
data: self
.data
.iter()
.zip(rhs.data.iter())
.map(|(&a, &b)| a + b)
.collect(),
})
}
pub fn sub(&self, rhs: &Self) -> CoreResult<Self> {
if self.len() != rhs.len() {
return Err(CoreError::InvalidInput(
crate::error::ErrorContext::new("IntervalVector::sub: length mismatch"),
));
}
Ok(Self {
data: self
.data
.iter()
.zip(rhs.data.iter())
.map(|(&a, &b)| a - b)
.collect(),
})
}
pub fn scale(&self, s: Interval<f64>) -> Self {
Self {
data: self.data.iter().map(|&x| x * s).collect(),
}
}
pub fn dot(&self, rhs: &Self) -> CoreResult<Interval<f64>> {
if self.len() != rhs.len() {
return Err(CoreError::InvalidInput(
crate::error::ErrorContext::new("IntervalVector::dot: length mismatch"),
));
}
let mut acc = Interval::point(0.0_f64);
for (&a, &b) in self.data.iter().zip(rhs.data.iter()) {
acc = acc + a * b;
}
Ok(acc)
}
pub fn norm_bound(&self) -> Interval<f64> {
let sq_hi: f64 = self.data.iter().map(|x: &Interval<f64>| x.mag().powi(2)).sum();
Interval::new(0.0, sq_hi.sqrt())
}
pub fn contained_in(&self, other: &Self) -> bool {
if self.len() != other.len() {
return false;
}
self.data
.iter()
.zip(other.data.iter())
.all(|(a, b): (&Interval<f64>, &Interval<f64>)| b.contains_interval(a))
}
#[inline]
pub fn as_slice(&self) -> &[Interval<f64>] {
&self.data
}
}
/// A dense matrix of `f64` intervals stored in row-major order.
#[derive(Clone, Debug)]
pub struct IntervalMatrix {
/// Number of rows.
rows: usize,
/// Number of columns.
cols: usize,
/// Row-major storage: element `(r, c)` is at index `r * cols + c`; the length
/// is expected to be `rows * cols` (checked by `from_flat`).
data: Vec<Interval<f64>>,
}
impl IntervalMatrix {
pub fn from_flat(rows: usize, cols: usize, data: Vec<Interval<f64>>) -> CoreResult<Self> {
if data.len() != rows * cols {
return Err(CoreError::InvalidInput(
crate::error::ErrorContext::new(&format!(
"IntervalMatrix::from_flat: expected {} elements, got {}",
rows * cols,
data.len()
)),
));
}
Ok(Self { rows, cols, data })
}
pub fn from_f64_rows(rows_data: &[Vec<f64>]) -> CoreResult<Self> {
let n_rows = rows_data.len();
if n_rows == 0 {
return Ok(Self {
rows: 0,
cols: 0,
data: Vec::new(),
});
}
let n_cols = rows_data[0].len();
for (i, row) in rows_data.iter().enumerate() {
let row: &Vec<f64> = row;
if row.len() != n_cols {
return Err(CoreError::InvalidInput(
crate::error::ErrorContext::new(&format!(
"IntervalMatrix::from_f64_rows: row {} has {} columns, expected {}",
i,
row.len(),
n_cols
)),
));
}
}
let data: Vec<Interval<f64>> = rows_data
.iter()
.flat_map(|row: &Vec<f64>| row.iter().map(|&v| Interval::point(v)))
.collect();
Ok(Self {
rows: n_rows,
cols: n_cols,
data,
})
}
#[inline]
pub fn rows(&self) -> usize {
self.rows
}
#[inline]
pub fn cols(&self) -> usize {
self.cols
}
#[inline]
pub fn get(&self, row: usize, col: usize) -> Option<&Interval<f64>> {
if row >= self.rows || col >= self.cols {
None
} else {
self.data.get(row * self.cols + col)
}
}
#[inline]
pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut Interval<f64>> {
if row >= self.rows || col >= self.cols {
None
} else {
self.data.get_mut(row * self.cols + col)
}
}
pub fn set(&mut self, row: usize, col: usize, val: Interval<f64>) -> CoreResult<()> {
if row >= self.rows || col >= self.cols {
return Err(CoreError::InvalidInput(
crate::error::ErrorContext::new("IntervalMatrix::set: index out of bounds"),
));
}
self.data[row * self.cols + col] = val;
Ok(())
}
pub fn row(&self, r: usize) -> CoreResult<IntervalVector> {
if r >= self.rows {
return Err(CoreError::InvalidInput(
crate::error::ErrorContext::new("IntervalMatrix::row: index out of bounds"),
));
}
let start = r * self.cols;
Ok(IntervalVector::new(self.data[start..start + self.cols].to_vec()))
}
pub fn mul_vec(&self, v: &IntervalVector) -> CoreResult<IntervalVector> {
if self.cols != v.len() {
return Err(CoreError::InvalidInput(
crate::error::ErrorContext::new(&format!(
"IntervalMatrix::mul_vec: matrix cols {} != vector len {}",
self.cols,
v.len()
)),
));
}
let mut result = Vec::with_capacity(self.rows);
for r in 0..self.rows {
let mut acc = Interval::point(0.0_f64);
for c in 0..self.cols {
let a = self.data[r * self.cols + c];
let b = *v.get(c).expect("index within bounds");
acc = acc + a * b;
}
result.push(acc);
}
Ok(IntervalVector::new(result))
}
pub fn mul_mat(&self, rhs: &Self) -> CoreResult<Self> {
if self.cols != rhs.rows {
return Err(CoreError::InvalidInput(
crate::error::ErrorContext::new(&format!(
"IntervalMatrix::mul_mat: self.cols {} != rhs.rows {}",
self.cols, rhs.rows
)),
));
}
let n = self.rows;
let m = rhs.cols;
let k = self.cols;
let mut out = vec![Interval::point(0.0_f64); n * m];
for i in 0..n {
for j in 0..m {
let mut acc = Interval::point(0.0_f64);
for l in 0..k {
acc = acc + self.data[i * k + l] * rhs.data[l * m + j];
}
out[i * m + j] = acc;
}
}
Self::from_flat(n, m, out)
}
pub fn transpose(&self) -> Self {
let mut data = vec![Interval::point(0.0_f64); self.rows * self.cols];
for r in 0..self.rows {
for c in 0..self.cols {
data[c * self.rows + r] = self.data[r * self.cols + c];
}
}
Self {
rows: self.cols,
cols: self.rows,
data,
}
}
}
pub fn gaussian_elimination_interval(
a: &IntervalMatrix,
b: &IntervalVector,
) -> CoreResult<IntervalVector> {
let n = a.rows();
if a.cols() != n {
return Err(CoreError::InvalidInput(
crate::error::ErrorContext::new("gaussian_elimination_interval: A must be square"),
));
}
if b.len() != n {
return Err(CoreError::InvalidInput(
crate::error::ErrorContext::new(
"gaussian_elimination_interval: b length must equal number of rows",
),
));
}
if n == 0 {
return Ok(IntervalVector::new(Vec::new()));
}
let aug_cols = n + 1;
let mut aug: Vec<Interval<f64>> = Vec::with_capacity(n * aug_cols);
for r in 0..n {
for c in 0..n {
aug.push(
*a.get(r, c)
.ok_or_else(|| CoreError::InvalidInput(
crate::error::ErrorContext::new("gaussian_elimination_interval: index error"),
))?,
);
}
aug.push(
*b.get(r)
.ok_or_else(|| CoreError::InvalidInput(
crate::error::ErrorContext::new("gaussian_elimination_interval: b index error"),
))?,
);
}
let idx = |r: usize, c: usize| -> usize { r * aug_cols + c };
for col in 0..n {
let pivot_row = {
let mut best = col;
let mut best_mag = aug[idx(col, col)].mag();
for r in (col + 1)..n {
let m = aug[idx(r, col)].mag();
if m > best_mag {
best_mag = m;
best = r;
}
}
best
};
let pivot = aug[idx(pivot_row, col)];
if pivot.mig() == 0.0 && pivot.mag() == 0.0 {
return Err(CoreError::ComputationError(
crate::error::ErrorContext::new(&format!(
"gaussian_elimination_interval: zero pivot at column {}",
col
)),
));
}
if pivot_row != col {
for c in 0..aug_cols {
aug.swap(idx(col, c), idx(pivot_row, c));
}
}
let pivot = aug[idx(col, col)];
for r in (col + 1)..n {
let factor = aug[idx(r, col)] / pivot;
aug[idx(r, col)] = Interval::point(0.0_f64); for c in (col + 1)..aug_cols {
let rhs_val = aug[idx(col, c)];
let current = aug[idx(r, c)];
aug[idx(r, c)] = current - factor * rhs_val;
}
}
}
let mut x = vec![Interval::point(0.0_f64); n];
for i in (0..n).rev() {
let mut rhs = aug[idx(i, n)]; for j in (i + 1)..n {
rhs = rhs - aug[idx(i, j)] * x[j];
}
let pivot = aug[idx(i, i)];
if pivot.mig() == 0.0 && pivot.mag() == 0.0 {
return Err(CoreError::ComputationError(
crate::error::ErrorContext::new(&format!(
"gaussian_elimination_interval: zero pivot during back-substitution at row {}",
i
)),
));
}
x[i] = rhs / pivot;
}
Ok(IntervalVector::new(x))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the closed interval `[lo, hi]`.
    fn iv(lo: f64, hi: f64) -> Interval<f64> {
        Interval::new(lo, hi)
    }

    /// Shorthand for the point interval `[x, x]`.
    fn ip(x: f64) -> Interval<f64> {
        Interval::point(x)
    }

    #[test]
    fn test_matvec() {
        let a = IntervalMatrix::from_f64_rows(&[vec![1.0, 0.0], vec![0.0, 1.0]])
            .expect("identity matrix");
        let v = IntervalVector::from_point_slice(&[2.0, 3.0]);
        let r = a.mul_vec(&v).expect("mul_vec");
        assert!(r.get(0).expect("get").contains(2.0));
        assert!(r.get(1).expect("get").contains(3.0));
    }

    #[test]
    fn test_gaussian_identity() {
        let a = IntervalMatrix::from_f64_rows(&[vec![1.0, 0.0], vec![0.0, 1.0]])
            .expect("identity");
        let b = IntervalVector::from_point_slice(&[3.0, 7.0]);
        let x = gaussian_elimination_interval(&a, &b).expect("solve");
        assert!(
            x.get(0).expect("get").contains(3.0),
            "x[0] = {:?}",
            x.get(0)
        );
        assert!(
            x.get(1).expect("get").contains(7.0),
            "x[1] = {:?}",
            x.get(1)
        );
    }

    #[test]
    fn test_gaussian_2x2() {
        let a = IntervalMatrix::from_f64_rows(&[vec![2.0, 1.0], vec![1.0, 3.0]])
            .expect("matrix");
        let b = IntervalVector::from_point_slice(&[5.0, 10.0]);
        let x = gaussian_elimination_interval(&a, &b).expect("solve");
        assert!(
            x.get(0).expect("x[0]").contains(1.0),
            "x[0] should contain 1.0, got {:?}",
            x.get(0)
        );
        assert!(
            x.get(1).expect("x[1]").contains(3.0),
            "x[1] should contain 3.0, got {:?}",
            x.get(1)
        );
    }

    #[test]
    fn test_gaussian_with_intervals() {
        let a = IntervalMatrix::from_flat(
            2,
            2,
            vec![
                iv(1.9, 2.1),
                iv(0.9, 1.1),
                iv(0.9, 1.1),
                iv(2.9, 3.1),
            ],
        )
        .expect("matrix");
        let b = IntervalVector::new(vec![iv(4.9, 5.1), iv(9.9, 10.1)]);
        let x = gaussian_elimination_interval(&a, &b).expect("solve");
        assert!(
            x.get(0).expect("x[0]").contains(1.0),
            "x[0] = {:?}",
            x.get(0)
        );
        assert!(
            x.get(1).expect("x[1]").contains(3.0),
            "x[1] = {:?}",
            x.get(1)
        );
    }

    #[test]
    fn test_gaussian_requires_pivoting() {
        // The leading entry is zero, so a row swap is mandatory.
        let a = IntervalMatrix::from_f64_rows(&[
            vec![0.0, 1.0, 0.0],
            vec![1.0, 0.0, 0.0],
            vec![0.0, 0.0, 2.0],
        ])
        .expect("matrix");
        let b = IntervalVector::from_point_slice(&[1.0, 2.0, 3.0]);
        let x = gaussian_elimination_interval(&a, &b).expect("solve");
        assert!(x.get(0).expect("x[0]").contains(2.0), "x[0] = {:?}", x.get(0));
        assert!(x.get(1).expect("x[1]").contains(1.0), "x[1] = {:?}", x.get(1));
        assert!(x.get(2).expect("x[2]").contains(1.5), "x[2] = {:?}", x.get(2));
    }

    #[test]
    fn test_gaussian_rejects_bad_input() {
        let non_square = IntervalMatrix::from_f64_rows(&[vec![1.0, 2.0]]).expect("1x2");
        let b1 = IntervalVector::from_point_slice(&[1.0]);
        assert!(gaussian_elimination_interval(&non_square, &b1).is_err());

        let id = IntervalMatrix::from_f64_rows(&[vec![1.0, 0.0], vec![0.0, 1.0]])
            .expect("identity");
        let b3 = IntervalVector::from_point_slice(&[1.0, 2.0, 3.0]);
        assert!(gaussian_elimination_interval(&id, &b3).is_err());

        let zero = IntervalMatrix::from_f64_rows(&[vec![0.0, 0.0], vec![0.0, 0.0]])
            .expect("zero matrix");
        let b2 = IntervalVector::from_point_slice(&[1.0, 1.0]);
        assert!(gaussian_elimination_interval(&zero, &b2).is_err());

        let empty = IntervalMatrix::from_f64_rows(&[]).expect("empty matrix");
        let x = gaussian_elimination_interval(&empty, &IntervalVector::new(Vec::new()))
            .expect("empty solve");
        assert!(x.is_empty());
    }

    #[test]
    fn test_dot_product() {
        let u = IntervalVector::from_point_slice(&[1.0, 2.0, 3.0]);
        let v = IntervalVector::from_point_slice(&[4.0, 5.0, 6.0]);
        let d = u.dot(&v).expect("dot");
        assert!(d.contains(32.0), "dot = {:?}", d);
    }

    #[test]
    fn test_vector_add_sub_and_length_mismatch() {
        let u = IntervalVector::new(vec![ip(1.0), ip(2.0)]);
        let v = IntervalVector::new(vec![ip(3.0), ip(5.0)]);
        let s = u.add(&v).expect("add");
        assert!(s.get(0).expect("s[0]").contains(4.0));
        assert!(s.get(1).expect("s[1]").contains(7.0));
        let d = v.sub(&u).expect("sub");
        assert!(d.get(0).expect("d[0]").contains(2.0));
        assert!(d.get(1).expect("d[1]").contains(3.0));

        let short = IntervalVector::new(vec![ip(1.0)]);
        assert!(u.add(&short).is_err());
        assert!(u.sub(&short).is_err());
        assert!(u.dot(&short).is_err());
    }

    #[test]
    fn test_norm_bound_and_scale() {
        let v = IntervalVector::from_point_slice(&[3.0, -4.0]);
        let n = v.norm_bound();
        assert!(n.contains(0.0) && n.contains(5.0), "norm = {:?}", n);
        let s = v.scale(ip(2.0));
        assert!(s.get(0).expect("s[0]").contains(6.0));
        assert!(s.get(1).expect("s[1]").contains(-8.0));
    }

    #[test]
    fn test_contained_in() {
        let inner = IntervalVector::new(vec![ip(0.5), iv(1.8, 2.2)]);
        let outer = IntervalVector::new(vec![iv(0.0, 1.0), iv(1.5, 2.5)]);
        assert!(inner.contained_in(&outer));
        assert!(!outer.contained_in(&inner));
        assert!(!inner.contained_in(&IntervalVector::new(vec![iv(0.0, 1.0)])));
    }

    #[test]
    fn test_matrix_multiply() {
        let a = IntervalMatrix::from_f64_rows(&[vec![1.0, 2.0], vec![3.0, 4.0]])
            .expect("a");
        let b = IntervalMatrix::from_f64_rows(&[vec![5.0, 6.0], vec![7.0, 8.0]])
            .expect("b");
        let c = a.mul_mat(&b).expect("mul_mat");
        assert!(c.get(0, 0).expect("00").contains(19.0));
        assert!(c.get(0, 1).expect("01").contains(22.0));
        assert!(c.get(1, 0).expect("10").contains(43.0));
        assert!(c.get(1, 1).expect("11").contains(50.0));
    }

    #[test]
    fn test_transpose_row_and_bounds() {
        let a = IntervalMatrix::from_f64_rows(&[vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]])
            .expect("matrix");
        let t = a.transpose();
        assert_eq!((t.rows(), t.cols()), (3, 2));
        assert!(t.get(0, 1).expect("t01").contains(4.0));
        assert!(t.get(2, 0).expect("t20").contains(3.0));
        assert!(t.get(2, 1).expect("t21").contains(6.0));

        let r1 = a.row(1).expect("row 1");
        assert_eq!(r1.len(), 3);
        assert!(r1.get(2).expect("r1[2]").contains(6.0));
        assert!(a.row(2).is_err());
        assert!(a.get(2, 0).is_none());
        assert!(a.get(0, 3).is_none());
    }

    #[test]
    fn test_set_and_get_mut() {
        let mut m = IntervalMatrix::from_f64_rows(&[vec![0.0, 0.0]]).expect("1x2");
        m.set(0, 1, iv(1.0, 2.0)).expect("set");
        assert!(m.get(0, 1).expect("m01").contains(1.5));
        assert!(m.set(1, 0, ip(0.0)).is_err());
        *m.get_mut(0, 0).expect("m00") = ip(3.0);
        assert!(m.get(0, 0).expect("m00").contains(3.0));
        assert!(m.get_mut(0, 2).is_none());
    }

    #[test]
    fn test_shape_validation() {
        assert!(IntervalMatrix::from_flat(2, 2, vec![ip(1.0); 3]).is_err());
        assert!(IntervalMatrix::from_f64_rows(&[vec![1.0, 2.0], vec![3.0]]).is_err());
        let a = IntervalMatrix::from_f64_rows(&[vec![1.0, 2.0]]).expect("1x2");
        let v = IntervalVector::from_point_slice(&[1.0]);
        assert!(a.mul_vec(&v).is_err());
        assert!(a.mul_mat(&a).is_err());
    }
}