use ndarray::{Array2, arr1, stack, Axis, ArrayBase, OwnedRepr, Dim, array};
use argmin::{
core::{observers::ObserverMode, CostFunction, Error, Executor},
solver::{goldensectionsearch::GoldenSectionSearch},
};
use crate::errors::LSQError;
use pyo3::prelude::*;
use ndarray_linalg::Inverse;
/// A rotated ellipse in the plane, exposed to Python as `Ellipsoid`.
///
/// Despite the name, the type is strictly two-dimensional. It is described by
/// a centre `(x, y)`, two semi-axis lengths, and the rotation `angle` of the
/// major axis.
#[derive(Debug, Clone, PartialEq)]
#[pyclass]
pub struct Ellipsoid {
    /// x coordinate of the centre.
    #[pyo3(get, set)]
    x: f64,
    /// y coordinate of the centre.
    #[pyo3(get, set)]
    y: f64,
    /// Semi-axis length along the direction given by `angle`.
    /// NOTE(review): nothing enforces `major_axis >= minor_axis`.
    #[pyo3(get, set)]
    major_axis: f64,
    /// Semi-axis length perpendicular to the direction given by `angle`.
    #[pyo3(get, set)]
    minor_axis: f64,
    /// Counter-clockwise rotation of the major axis from the x axis, in radians.
    #[pyo3(get, set)]
    angle: f64,
}
#[pymethods]
impl Ellipsoid {
#[new]
pub fn py_new(x: f64, y: f64, major_axis: f64, minor_axis: f64, angle: f64) -> Self {
Ellipsoid {
x,
y,
major_axis,
minor_axis,
angle
}
}
}
impl Ellipsoid {
    /// Creates an ellipse centred at `(x, y)` with the given semi-axis lengths,
    /// whose major axis is rotated by `angle` radians from the x axis.
    pub fn new(x: f64, y: f64, major_axis: f64, minor_axis: f64, angle: f64) -> Self {
        Ellipsoid {
            x,
            y,
            major_axis,
            minor_axis,
            angle,
        }
    }

    /// Returns the two principal directions of the ellipse as unit vectors.
    ///
    /// The first vector `(cos θ, sin θ)` points along `major_axis`. The second,
    /// `(-sin θ, cos θ)`, is the first rotated by +90° and points along
    /// `minor_axis`.
    pub fn get_eigenvectors(&self) -> ([f64; 2], [f64; 2]) {
        let (sin, cos) = self.angle.sin_cos();
        let first_eigen = [cos, sin];
        let second_eigen = [-sin, cos];
        (first_eigen, second_eigen)
    }

    /// Returns the eigenvalues `(1 / major_axis², 1 / minor_axis²)` of the
    /// quadratic-form matrix. They are ordered to match [`Ellipsoid::get_eigenvectors`].
    pub fn get_eigenvalues(&self) -> (f64, f64) {
        let first_eigen = 1.0 / self.major_axis.powi(2);
        let second_eigen = 1.0 / self.minor_axis.powi(2);
        (first_eigen, second_eigen)
    }

    /// Returns the symmetric 2×2 matrix `A = Eᵀ · D · E`.
    ///
    /// The rows of `E` are the eigenvectors and `D = diag(1/major², 1/minor²)`.
    /// With this `A`, the points `p` satisfying `(p - c)ᵀ A (p - c) <= 1`
    /// (where `c` is the centre) are exactly the points of the ellipse.
    pub fn get_matrix_representation(&self) -> ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>> {
        let (first, second) = self.get_eigenvectors();
        let (lambda_major, lambda_minor) = self.get_eigenvalues();
        let eigen_matrix = stack![Axis(0), arr1(&first), arr1(&second)];
        let diag_matrix = Array2::from_diag(&arr1(&[lambda_major, lambda_minor]));
        eigen_matrix.t().dot(&diag_matrix).dot(&eigen_matrix)
    }
}
/// The pair of ellipses whose overlap is being tested.
///
/// It acts as the argmin problem whose [`CostFunction`] implementation is
/// minimised over `s ∈ (0, 1)`.
#[derive(Debug, Clone)]
pub struct EllipsoidIntersection {
    ellipse_a: Ellipsoid,
    ellipse_b: Ellipsoid,
}
impl EllipsoidIntersection {
pub fn new(ellipse_a: Ellipsoid, ellipse_b: Ellipsoid) -> Self {
EllipsoidIntersection {
ellipse_a,
ellipse_b,
}
}
}
impl CostFunction for EllipsoidIntersection {
    type Param = f64;
    type Output = f64;

    /// Evaluates the overlap function
    ///
    /// `K(s) = 1 - dᵀ · (Σ_a / (1 - s) + Σ_b / s)⁻¹ · d`
    ///
    /// Here `d` is the vector from the centre of `ellipse_a` to the centre of
    /// `ellipse_b`, and `Σ_a`, `Σ_b` are the inverses of the ellipses' matrix
    /// representations. The ellipses are disjoint exactly when some
    /// `s ∈ (0, 1)` gives `K(s) < 0`.
    ///
    /// `K(s)` tends to 1 as `s` approaches either end of `(0, 1)`. That limit is
    /// returned for `s <= 0` or `s >= 1` instead of dividing by zero. If any of
    /// the three matrix inversions fails, the error is logged to stderr and
    /// `f64::INFINITY` is returned, so the minimiser avoids that point rather
    /// than aborting.
    fn cost(&self, param: &Self::Param) -> Result<Self::Output, Error> {
        let s = *param;
        // Outside the open interval, 1/s or 1/(1-s) is infinite; use the analytic limit.
        if s <= 0.0 || s >= 1.0 {
            return Ok(1.0);
        }

        let matrix_a = self.ellipse_a.get_matrix_representation();
        let matrix_b = self.ellipse_b.get_matrix_representation();
        let subst_centers = array![
            self.ellipse_b.x - self.ellipse_a.x,
            self.ellipse_b.y - self.ellipse_a.y
        ];

        let inv_a = match matrix_a.inv() {
            Ok(inv) => inv / (1.0 - s),
            Err(e) => {
                eprintln!("Error inverting matrix A: {:?}", e);
                return Ok(f64::INFINITY);
            }
        };
        let inv_b = match matrix_b.inv() {
            Ok(inv) => inv / s,
            Err(e) => {
                // eprintln! (not eprint!) so repeated messages stay on separate lines.
                eprintln!("Error inverting matrix B: {:?}", e);
                return Ok(f64::INFINITY);
            }
        };
        // A singular sum must not panic (this runs under a Python call);
        // handle it the same way as the other inversion failures.
        let inv_a_plus_b = match (inv_a + inv_b).inv() {
            Ok(inv) => inv,
            Err(e) => {
                eprintln!("Error inverting matrix A + B: {:?}", e);
                return Ok(f64::INFINITY);
            }
        };

        let k_of_s = 1.0 - subst_centers.t().dot(&inv_a_plus_b).dot(&subst_centers);
        Ok(k_of_s)
    }
}
/// Solver settings for [`check_ellipsoid_intersection`], exposed to Python.
#[derive(Debug, Clone, PartialEq)]
#[pyclass]
pub struct EllipsoidIntersectionParameters {
    /// Convergence tolerance handed to the golden-section search.
    #[pyo3(get, set)]
    pub tolerance: f64,
    /// Iteration budget requested for the solver.
    /// NOTE(review): confirm `check_ellipsoid_intersection` forwards this to the executor.
    #[pyo3(get, set)]
    pub max_iters: u64,
}
#[pymethods]
impl EllipsoidIntersectionParameters {
#[new]
pub fn py_new() -> Self {
EllipsoidIntersectionParameters::new()
}
}
impl EllipsoidIntersectionParameters {
    /// Creates the default settings: `tolerance = 1e-4`, `max_iters = 1000`.
    pub fn new() -> Self {
        EllipsoidIntersectionParameters {
            tolerance: 1e-4,
            max_iters: 1000,
        }
    }

    /// Returns these settings with `tolerance` replaced (builder style).
    #[must_use]
    pub fn with_tolerance(mut self, tolerance: f64) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Returns these settings with `max_iters` replaced (builder style).
    #[must_use]
    pub fn with_max_iters(mut self, max_iters: u64) -> Self {
        self.max_iters = max_iters;
        self
    }
}

/// Same as [`EllipsoidIntersectionParameters::new`]. This impl satisfies
/// clippy's `new_without_default` lint and lets callers use
/// `..Default::default()`.
impl Default for EllipsoidIntersectionParameters {
    fn default() -> Self {
        Self::new()
    }
}
/// Tests whether two ellipses overlap and returns the minimum of the overlap
/// function `K(s)`.
///
/// `K(s)` is the [`CostFunction`] of [`EllipsoidIntersection`]. It is minimised
/// over `s ∈ (0, 1)` with a golden-section search that starts at `s = 0.5`.
///
/// Interpretation of the result:
/// * `< 0`: the ellipses are disjoint.
/// * `≈ 0`: they just touch.
/// * `> 0`: they overlap.
///
/// Because the search stops at `parameters.tolerance`, the value is approximate.
///
/// `parameters` defaults to [`EllipsoidIntersectionParameters::new`]. Its
/// `max_iters` bounds the number of solver iterations.
///
/// # Errors
///
/// Returns an error if the solver cannot be built (e.g. an invalid tolerance)
/// or if the optimisation run fails. It does not panic.
#[pyfunction]
#[pyo3(signature = (ellipse_a, ellipse_b, parameters=None))]
pub fn check_ellipsoid_intersection(
    ellipse_a: Ellipsoid,
    ellipse_b: Ellipsoid,
    parameters: Option<EllipsoidIntersectionParameters>,
) -> Result<f64, LSQError> {
    let parameters = parameters.unwrap_or_else(EllipsoidIntersectionParameters::new);
    let ellipse_intersection = EllipsoidIntersection::new(ellipse_a, ellipse_b);
    // Start in the middle of the open interval on which K(s) is defined.
    let init_param = 0.5;
    let solver = GoldenSectionSearch::new(0.0, 1.0)?.with_tolerance(parameters.tolerance)?;
    let result = Executor::new(ellipse_intersection, solver)
        .configure(|state| state.param(init_param).max_iters(parameters.max_iters))
        .run()?;
    Ok(result.state.get_best_cost())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The golden-section search stops once its bracket is about 1e-4 wide, so
    /// the minimum it reports is only approximate. Results are therefore
    /// compared with a tolerance rather than bit-for-bit.
    const EPS: f64 = 1e-6;

    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {} (±{}), got {}",
            expected,
            EPS,
            actual
        );
    }

    #[test]
    fn test_not_touching() {
        let ellipse1 = Ellipsoid::new(0.0, 0.0, 2.0, 1.0, 0.0);
        let ellipse2 = Ellipsoid::new(10.0, 0.0, 2.0, 1.0, 0.0);
        let parameters = EllipsoidIntersectionParameters::new();
        let intersection = check_ellipsoid_intersection(ellipse1, ellipse2, Some(parameters));
        assert_close(intersection.unwrap(), -5.25);
    }

    #[test]
    fn test_not_touching_default_params() {
        let ellipse1 = Ellipsoid::new(0.0, 0.0, 2.0, 1.0, 0.0);
        let ellipse2 = Ellipsoid::new(10.0, 0.0, 2.0, 1.0, 0.0);
        let intersection = check_ellipsoid_intersection(ellipse1, ellipse2, None);
        assert_close(intersection.unwrap(), -5.25);
    }

    #[test]
    fn test_just_touching() {
        let ellipse1 = Ellipsoid::new(0.0, 0.0, 2.0, 1.0, 0.0);
        let ellipse2 = Ellipsoid::new(4.0, 0.0, 2.0, 1.0, 0.0);
        let parameters = EllipsoidIntersectionParameters::new();
        let intersection = check_ellipsoid_intersection(ellipse1, ellipse2, Some(parameters));
        assert_close(intersection.unwrap(), 0.0);
    }

    #[test]
    fn test_superposition() {
        let ellipse1 = Ellipsoid::new(0.0, 0.0, 2.0, 1.0, 0.0);
        let ellipse2 = Ellipsoid::new(2.0, 0.0, 2.0, 1.0, 0.0);
        let parameters = EllipsoidIntersectionParameters::new();
        let intersection = check_ellipsoid_intersection(ellipse1, ellipse2, Some(parameters));
        assert_close(intersection.unwrap(), 0.75);
    }

    #[test]
    fn test_rotated_just_touching() {
        // Rotated by 90°, ellipse1 spans x ∈ [-1, 1] and ellipse2 spans x ∈ [1, 5].
        // They touch at (1, 0). K(s) = 1 - 9 / (1/(1-s) + 4/s) has its minimum 0 at
        // s = 2/3, away from the starting point 0.5.
        let ellipse1 = Ellipsoid::new(0.0, 0.0, 2.0, 1.0, std::f64::consts::FRAC_PI_2);
        let ellipse2 = Ellipsoid::new(3.0, 0.0, 2.0, 1.0, 0.0);
        let intersection = check_ellipsoid_intersection(ellipse1, ellipse2, None);
        assert_close(intersection.unwrap(), 0.0);
    }
}