use crate::error::Result;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{Rng, RngExt};
/// Tuning parameters for [`run_ransac`].
#[derive(Debug, Clone)]
pub struct RansacConfig {
    /// Upper bound on the number of random sampling iterations.
    pub max_iterations: usize,
    /// Inlier cut-off: a point counts as an inlier when the model's residual is strictly
    /// below this value. Units are those of `RansacModel::residual` (squared distances
    /// for the models in this file).
    pub threshold: f64,
    /// Lower bound on the inlier count for a candidate model to be accepted.
    pub min_inliers: usize,
    /// Target probability of having drawn at least one outlier-free sample; used to
    /// shrink the iteration budget adaptively.
    pub confidence: f64,
    /// Optional random seed for sample selection.
    pub seed: Option<u64>,
    /// Number of refit-and-rescore rounds applied to the best model.
    pub refinement_iterations: usize,
}
impl Default for RansacConfig {
    /// 1000 iterations, threshold 3.0, 8 inliers, 99% confidence, no seed and
    /// 5 refinement rounds.
    fn default() -> Self {
        let max_iterations = 1000;
        let threshold = 3.0;
        let min_inliers = 8;
        let confidence = 0.99;
        let refinement_iterations = 5;
        Self {
            max_iterations,
            threshold,
            min_inliers,
            confidence,
            seed: None,
            refinement_iterations,
        }
    }
}
/// Outcome of a successful [`run_ransac`] call.
#[derive(Debug, Clone)]
pub struct RansacResult<T> {
    /// The best (and possibly refined) model.
    pub model: T,
    /// Indices into the input data of the points classified as inliers of `model`.
    pub inliers: Vec<usize>,
    /// `inliers.len()` divided by the number of input points.
    pub inlier_ratio: f64,
    /// Number of sampling iterations actually executed.
    pub iterations: usize,
}
/// A model that can be fitted robustly with [`run_ransac`].
pub trait RansacModel: Sized + Clone {
    /// One observation, e.g. a point correspondence.
    type DataPoint: Clone;

    /// Fits a model to `samples`.
    ///
    /// # Errors
    /// Implementations should return an error for too few or degenerate samples
    /// (the implementations in this file do); `run_ransac` skips such samples.
    fn estimate(samples: &[Self::DataPoint]) -> Result<Self>;

    /// Error of `point` under this model. `run_ransac` treats a point as an inlier
    /// when this value is strictly below `RansacConfig::threshold`.
    fn residual(&self, point: &Self::DataPoint) -> f64;

    /// Number of data points drawn for one minimal sample passed to `estimate`.
    fn min_samples() -> usize;

    /// Refits the model on a consensus set of inliers.
    ///
    /// The default implementation returns an unchanged copy of `self`.
    fn refine(&self, _inliers: &[Self::DataPoint]) -> Result<Self> {
        Ok(self.clone())
    }
}
/// Robustly fits a model of type `M` to `data` with RANSAC.
///
/// Each iteration draws `M::min_samples()` distinct points uniformly at random and fits
/// a candidate with [`RansacModel::estimate`]; samples the model rejects are skipped.
/// Points whose residual is strictly below `config.threshold` are the candidate's
/// inliers. The candidate with the most inliers, and at least `config.min_inliers` of
/// them, wins. The iteration budget shrinks adaptively so that, with probability
/// `config.confidence`, at least one outlier-free sample has been drawn.
///
/// The winner is then refit up to `config.refinement_iterations` times on its current
/// consensus set. A refit is kept only if it retains at least as many inliers. A failed
/// refit stops refinement and keeps the model already found.
///
/// When `config.seed` is `Some`, sampling uses a deterministic SplitMix64 generator so
/// runs are reproducible; otherwise the `scirs2_core` RNG is used.
///
/// # Errors
/// * `InvalidParameter` if `data` has fewer than `M::min_samples()` points.
/// * `OperationError` if no candidate reached `config.min_inliers` inliers.
#[allow(dead_code)]
pub fn run_ransac<M: RansacModel>(
    data: &[M::DataPoint],
    config: &RansacConfig,
) -> Result<RansacResult<M>> {
    /// SplitMix64: a tiny deterministic PRNG used when `config.seed` is set.
    struct SplitMix64(u64);

    impl SplitMix64 {
        fn next_raw(&mut self) -> u64 {
            self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
            let mut z = self.0;
            z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
            z ^ (z >> 31)
        }

        /// Value in `0..bound` (`bound > 0`) via multiply-shift range reduction.
        fn below(&mut self, bound: usize) -> usize {
            ((u128::from(self.next_raw()) * bound as u128) >> 64) as usize
        }
    }

    let min_samples = M::min_samples();
    let n_points = data.len();
    if n_points < min_samples {
        return Err(crate::error::VisionError::InvalidParameter(format!(
            "Not enough data points: {n_points} < {min_samples}"
        )));
    }

    let mut entropy_rng = scirs2_core::random::rng();
    let mut seeded_rng = config.seed.map(SplitMix64);
    // Uniform index in `lo..hi`; every caller below guarantees `lo < hi`.
    let mut pick_index = |lo: usize, hi: usize| -> usize {
        match seeded_rng.as_mut() {
            Some(rng) => lo + rng.below(hi - lo),
            None => entropy_rng.random_range(lo..hi),
        }
    };
    // Indices of the points whose residual is strictly below the threshold.
    let collect_inliers = |model: &M| -> Vec<usize> {
        data.iter()
            .enumerate()
            .filter(|&(_, point)| model.residual(point) < config.threshold)
            .map(|(idx, _)| idx)
            .collect()
    };

    let mut best: Option<(M, Vec<usize>)> = None;
    let mut iterations = 0;
    let mut dynamic_iterations = config.max_iterations;
    // Reused permutation buffer. Partially shuffling the permutation left by the previous
    // iteration still yields a uniformly random subset.
    let mut indices: Vec<usize> = (0..n_points).collect();
    let mut samples: Vec<M::DataPoint> = Vec::with_capacity(min_samples);

    for iter in 0..config.max_iterations {
        if iter >= dynamic_iterations {
            break;
        }
        iterations = iter + 1;

        // Partial Fisher-Yates: only the first `min_samples` slots need to be random.
        for i in 0..min_samples {
            let j = pick_index(i, n_points);
            indices.swap(i, j);
        }
        samples.clear();
        samples.extend(indices[..min_samples].iter().map(|&idx| data[idx].clone()));

        // Degenerate samples (e.g. coincident points) are skipped, not fatal.
        let Ok(model) = M::estimate(&samples) else {
            continue;
        };
        let inliers = collect_inliers(&model);

        let improves = inliers.len() >= config.min_inliers
            && match &best {
                Some((_, best_inliers)) => inliers.len() > best_inliers.len(),
                None => true,
            };
        if !improves {
            continue;
        }
        let inlier_count = inliers.len();
        best = Some((model, inliers));

        if inlier_count > min_samples {
            let inlier_ratio = inlier_count as f64 / n_points as f64;
            let all_inlier_prob = inlier_ratio.powf(min_samples as f64);
            if all_inlier_prob > 0.0 {
                // k = ln(1 - confidence) / ln(1 - p); `ln_1p` keeps ln(1 - p) accurate
                // (and non-zero) when p is tiny.
                let k = (-config.confidence).ln_1p() / (-all_inlier_prob).ln_1p();
                if !k.is_nan() {
                    // `as` saturates: +inf -> usize::MAX, negatives -> 0.
                    dynamic_iterations = (k.ceil() as usize).min(config.max_iterations);
                }
            }
        }
    }

    let Some((mut model, mut inliers)) = best else {
        return Err(crate::error::VisionError::OperationError(
            "RANSAC failed to find a model with enough inliers".to_string(),
        ));
    };

    for _ in 0..config.refinement_iterations {
        if inliers.is_empty() {
            break;
        }
        // Rebuilt every round so each refit sees the latest consensus set.
        let inlier_data: Vec<M::DataPoint> =
            inliers.iter().map(|&idx| data[idx].clone()).collect();
        // A failed refit keeps the model RANSAC already found instead of discarding it.
        let Ok(refined) = model.refine(&inlier_data) else {
            break;
        };
        let refined_inliers = collect_inliers(&refined);
        // Never accept a refit that loses support.
        if refined_inliers.len() < inliers.len() {
            break;
        }
        model = refined;
        inliers = refined_inliers;
    }

    let inlier_ratio = if n_points == 0 {
        0.0
    } else {
        inliers.len() as f64 / n_points as f64
    };
    Ok(RansacResult {
        model,
        inliers,
        inlier_ratio,
        iterations,
    })
}
/// A planar projective transformation stored as a 3×3 matrix.
#[derive(Debug, Clone)]
pub struct Homography {
    /// Row-major 3×3 matrix applied to homogeneous `(x, y, 1)` points.
    pub matrix: Array2<f64>,
    /// Inverse of `matrix`; the constructors in this file store the identity here
    /// when `matrix` is singular.
    pub inverse: Array2<f64>,
}
impl Homography {
    /// Builds a homography from the nine entries of a row-major 3×3 matrix.
    ///
    /// The inverse is precomputed. When the matrix is singular (|det| < 1e-10) the
    /// identity is stored instead, so `inverse_transform_point` is then the identity map.
    ///
    /// # Panics
    /// Not in practice: nine values always fill a 3×3 shape.
    pub fn new(matrix_data: &[f64; 9]) -> Self {
        let matrix =
            Array2::from_shape_vec((3, 3), matrix_data.to_vec()).expect("Operation failed");
        let inverse = Self::invert_matrix(&matrix).unwrap_or_else(|_| Array2::eye(3));
        Self { matrix, inverse }
    }

    /// The identity homography: every point maps to itself.
    pub fn identity() -> Self {
        const EYE: [f64; 9] = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        Self::new(&EYE)
    }

    /// Maps `(x, y)` through `matrix`.
    ///
    /// If the homogeneous coordinate is within 1e-10 of zero (the point lands on the
    /// line at infinity) the input point is returned unchanged.
    pub fn transform_point(&self, x: f64, y: f64) -> (f64, f64) {
        Self::project_through(&self.matrix, x, y)
    }

    /// Maps `(x, y)` through `inverse`, with the same near-infinity fallback as
    /// `transform_point`.
    pub fn inverse_transform_point(&self, x: f64, y: f64) -> (f64, f64) {
        Self::project_through(&self.inverse, x, y)
    }

    /// Applies the projective map `m` to `(x, y)`; returns `(x, y)` itself when the
    /// homogeneous coordinate is within 1e-10 of zero.
    fn project_through(m: &Array2<f64>, x: f64, y: f64) -> (f64, f64) {
        let w = m[[2, 0]] * x + m[[2, 1]] * y + m[[2, 2]];
        if w.abs() < 1e-10 {
            (x, y)
        } else {
            let px = (m[[0, 0]] * x + m[[0, 1]] * y + m[[0, 2]]) / w;
            let py = (m[[1, 0]] * x + m[[1, 1]] * y + m[[1, 2]]) / w;
            (px, py)
        }
    }

    /// Returns the homography with matrix `self.matrix · other.matrix`, i.e. `other`
    /// is applied first. A singular product gets the identity as its stored inverse.
    pub fn compose(&self, other: &Self) -> Self {
        let (lhs, rhs) = (&self.matrix, &other.matrix);
        let matrix = Array2::from_shape_fn((3, 3), |(row, col)| {
            (0..3).fold(0.0, |acc, k| acc + lhs[[row, k]] * rhs[[k, col]])
        });
        let inverse = Self::invert_matrix(&matrix).unwrap_or_else(|_| Array2::eye(3));
        Self { matrix, inverse }
    }

    /// Inverts a 3×3 matrix via its adjugate divided by the determinant.
    ///
    /// # Errors
    /// `InvalidParameter` if `matrix` is not 3×3; `OperationError` if |det| < 1e-10.
    fn invert_matrix(matrix: &Array2<f64>) -> Result<Array2<f64>> {
        if matrix.dim() != (3, 3) {
            return Err(crate::error::VisionError::InvalidParameter(
                "Matrix must be 3x3".to_string(),
            ));
        }
        let e = |r: usize, c: usize| matrix[[r, c]];

        // Cofactor expansion along the first row.
        let det = e(0, 0) * (e(1, 1) * e(2, 2) - e(1, 2) * e(2, 1))
            - e(0, 1) * (e(1, 0) * e(2, 2) - e(1, 2) * e(2, 0))
            + e(0, 2) * (e(1, 0) * e(2, 1) - e(1, 1) * e(2, 0));
        if det.abs() < 1e-10 {
            return Err(crate::error::VisionError::OperationError(
                "Matrix is singular, cannot compute inverse".to_string(),
            ));
        }

        // Adjugate (transposed cofactor matrix), row-major.
        let adjugate = [
            e(1, 1) * e(2, 2) - e(1, 2) * e(2, 1),
            e(0, 2) * e(2, 1) - e(0, 1) * e(2, 2),
            e(0, 1) * e(1, 2) - e(0, 2) * e(1, 1),
            e(1, 2) * e(2, 0) - e(1, 0) * e(2, 2),
            e(0, 0) * e(2, 2) - e(0, 2) * e(2, 0),
            e(0, 2) * e(1, 0) - e(0, 0) * e(1, 2),
            e(1, 0) * e(2, 1) - e(1, 1) * e(2, 0),
            e(0, 1) * e(2, 0) - e(0, 0) * e(2, 1),
            e(0, 0) * e(1, 1) - e(0, 1) * e(1, 0),
        ];
        Ok(Array2::from_shape_fn((3, 3), |(r, c)| {
            adjugate[r * 3 + c] / det
        }))
    }
}
/// A 2-D point correspondence; the models in this file map `point1` onto `point2`.
#[derive(Debug, Clone)]
pub struct PointMatch {
    /// Source point `(x, y)`.
    pub point1: (f64, f64),
    /// Target point `(x, y)`.
    pub point2: (f64, f64),
}
impl RansacModel for Homography {
type DataPoint = PointMatch;
fn estimate(samples: &[Self::DataPoint]) -> Result<Self> {
if samples.len() < Self::min_samples() {
return Err(crate::error::VisionError::InvalidParameter(format!(
"Not enough samples: {} < {}",
samples.len(),
Self::min_samples()
)));
}
let mut a = Array2::zeros((samples.len() * 2, 9));
for (i, match_point) in samples.iter().enumerate() {
let (x1, y1) = match_point.point1;
let (x2, y2) = match_point.point2;
a[[i * 2, 0]] = -x1;
a[[i * 2, 1]] = -y1;
a[[i * 2, 2]] = -1.0;
a[[i * 2, 6]] = x1 * x2;
a[[i * 2, 7]] = y1 * x2;
a[[i * 2, 8]] = x2;
a[[i * 2 + 1, 3]] = -x1;
a[[i * 2 + 1, 4]] = -y1;
a[[i * 2 + 1, 5]] = -1.0;
a[[i * 2 + 1, 6]] = x1 * y2;
a[[i * 2 + 1, 7]] = y1 * y2;
a[[i * 2 + 1, 8]] = y2;
}
let svd = Self::compute_svd(&a)?;
let h = Array1::from_iter(svd.into_iter().skip(8 * 9).take(9));
let _matrixdata = [h[0], h[1], h[2], h[3], h[4], h[5], h[6], h[7], h[8]];
let homography = Self::new(&_matrixdata);
Ok(homography)
}
fn residual(&self, point: &Self::DataPoint) -> f64 {
let (x1, y1) = point.point1;
let (x2, y2) = point.point2;
let (x1_transformed, y1_transformed) = self.transform_point(x1, y1);
let forward_error = (x1_transformed - x2).powi(2) + (y1_transformed - y2).powi(2);
let (x2_transformed, y2_transformed) = self.inverse_transform_point(x2, y2);
let backward_error = (x2_transformed - x1).powi(2) + (y2_transformed - y1).powi(2);
(forward_error + backward_error) / 2.0
}
fn min_samples() -> usize {
4 }
fn refine(&self, inliers: &[Self::DataPoint]) -> Result<Self> {
Self::estimate(inliers)
}
}
impl Homography {
fn compute_svd(a: &Array2<f64>) -> Result<Vec<f64>> {
let (m, n) = a.dim();
let mut ata: Array2<f64> = Array2::zeros((n, n));
for i in 0..n {
for j in 0..n {
for k in 0..m {
ata[[i, j]] += a[[k, i]] * a[[k, j]];
}
}
}
let mut v = Array1::ones(n);
let norm = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
v.mapv_inplace(|x| x / norm);
for _ in 0..50 {
let mut av = Array1::zeros(n);
for i in 0..n {
for j in 0..n {
av[i] += ata[[i, j]] * v[j];
}
}
let lambda = v
.iter()
.zip(av.iter())
.map(|(&vi, &avi): (&f64, &f64)| vi * avi)
.sum::<f64>();
for i in 0..n {
v[i] = av[i] - lambda * v[i];
}
let norm = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
if norm < 1e-10 {
let mut rng = scirs2_core::random::rng();
for i in 0..n {
v[i] = rng.random::<f64>();
}
let norm = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
v.mapv_inplace(|x| x / norm);
} else {
v.mapv_inplace(|x| x / norm);
}
}
let mut result = vec![0.0; n * n];
for i in 0..n {
result[i * n + i] = 1.0;
}
for i in 0..n {
result[i * n + n - 1] = v[i];
}
Ok(result)
}
}
/// A 2-D similarity transform: `p' = scale · R(rotation) · p + (tx, ty)`.
#[derive(Debug, Clone)]
pub struct TranslationScale {
    /// Translation along x, applied after scaling and rotation.
    pub tx: f64,
    /// Translation along y, applied after scaling and rotation.
    pub ty: f64,
    /// Uniform scale factor.
    pub scale: f64,
    /// Counter-clockwise rotation angle in radians.
    pub rotation: f64,
}
impl Default for TranslationScale {
    /// The identity similarity: no translation, unit scale, zero rotation.
    fn default() -> Self {
        let (tx, ty) = (0.0, 0.0);
        let scale = 1.0;
        let rotation = 0.0;
        Self {
            tx,
            ty,
            scale,
            rotation,
        }
    }
}
impl RansacModel for TranslationScale {
    type DataPoint = PointMatch;

    /// Fits a similarity exactly from the first two correspondences in `samples`.
    ///
    /// Let `d1` be the source segment (`point1` of sample 1 minus `point1` of sample 0)
    /// and `d2` the matching target segment. The scale is `|d2| / |d1|` and the rotation
    /// is the signed angle from `d1` to `d2`. The translation then maps sample 0's
    /// `point1` onto its `point2`.
    ///
    /// # Errors
    /// `InvalidParameter` for fewer than 2 samples, or when either pair of points is
    /// closer than 1e-8.
    fn estimate(samples: &[Self::DataPoint]) -> Result<Self> {
        if samples.len() < Self::min_samples() {
            return Err(crate::error::VisionError::InvalidParameter(format!(
                "Not enough samples: {} < {}",
                samples.len(),
                Self::min_samples()
            )));
        }
        let (x1_1, y1_1) = samples[0].point1;
        let (x2_1, y2_1) = samples[0].point2;
        let (x1_2, y1_2) = samples[1].point1;
        let (x2_2, y2_2) = samples[1].point2;

        let (dx1, dy1) = (x1_2 - x1_1, y1_2 - y1_1);
        let (dx2, dy2) = (x2_2 - x2_1, y2_2 - y2_1);
        let len1 = (dx1 * dx1 + dy1 * dy1).sqrt();
        let len2 = (dx2 * dx2 + dy2 * dy2).sqrt();
        if len1 < 1e-8 || len2 < 1e-8 {
            return Err(crate::error::VisionError::InvalidParameter(
                "Points are too close to estimate model".to_string(),
            ));
        }

        let scale = len2 / len1;
        // Signed angle that rotates d1 onto d2: atan2(d1 × d2, d1 · d2). `residual`
        // applies R(rotation) to point1, so this must be angle(d2) − angle(d1).
        let rotation = (dx1 * dy2 - dy1 * dx2).atan2(dx1 * dx2 + dy1 * dy2);
        let (sin_rot, cos_rot) = rotation.sin_cos();
        let tx = x2_1 - scale * (cos_rot * x1_1 - sin_rot * y1_1);
        let ty = y2_1 - scale * (sin_rot * x1_1 + cos_rot * y1_1);
        Ok(Self {
            tx,
            ty,
            scale,
            rotation,
        })
    }

    /// Squared Euclidean distance between the transformed `point1` and `point2`.
    fn residual(&self, point: &Self::DataPoint) -> f64 {
        let (x1, y1) = point.point1;
        let (x2, y2) = point.point2;
        let cos_rot = self.rotation.cos();
        let sin_rot = self.rotation.sin();
        let x_transformed = self.scale * (cos_rot * x1 - sin_rot * y1) + self.tx;
        let y_transformed = self.scale * (sin_rot * x1 + cos_rot * y1) + self.ty;
        let dx = x_transformed - x2;
        let dy = y_transformed - y2;
        dx * dx + dy * dy
    }

    /// Two correspondences determine a similarity.
    fn min_samples() -> usize {
        2
    }

    /// Least-squares similarity over all `inliers` (closed form on centred points).
    ///
    /// With centred source `q1` and target `q2`, the optimal `a = s·cosθ` and
    /// `b = s·sinθ` are `Σ(q1·q2) / Σ|q1|²` and `Σ(q1 × q2) / Σ|q1|²`. Returns an
    /// unchanged copy of `self` when the inliers are too few or degenerate.
    fn refine(&self, inliers: &[Self::DataPoint]) -> Result<Self> {
        if inliers.len() < Self::min_samples() {
            return Ok(self.clone());
        }
        let n = inliers.len() as f64;
        let (mut c1x, mut c1y, mut c2x, mut c2y) = (0.0, 0.0, 0.0, 0.0);
        for pm in inliers {
            c1x += pm.point1.0;
            c1y += pm.point1.1;
            c2x += pm.point2.0;
            c2y += pm.point2.1;
        }
        let (c1x, c1y, c2x, c2y) = (c1x / n, c1y / n, c2x / n, c2y / n);

        let (mut dot, mut cross, mut spread) = (0.0, 0.0, 0.0);
        for pm in inliers {
            let (x1, y1) = (pm.point1.0 - c1x, pm.point1.1 - c1y);
            let (x2, y2) = (pm.point2.0 - c2x, pm.point2.1 - c2y);
            dot += x1 * x2 + y1 * y2;
            cross += x1 * y2 - y1 * x2;
            spread += x1 * x1 + y1 * y1;
        }
        // Same degeneracy scale as `estimate`'s 1e-8 length check, squared.
        if !(spread.is_finite() && spread >= 1e-16) {
            return Ok(self.clone());
        }

        let (a, b) = (dot / spread, cross / spread);
        let scale = a.hypot(b);
        if !(scale.is_finite() && scale > 0.0) {
            return Ok(self.clone());
        }
        let rotation = b.atan2(a);
        let (sin_rot, cos_rot) = rotation.sin_cos();
        let tx = c2x - scale * (cos_rot * c1x - sin_rot * c1y);
        let ty = c2y - scale * (sin_rot * c1x + cos_rot * c1y);
        Ok(Self {
            tx,
            ty,
            scale,
            rotation,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance float comparison shared by the tests below.
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    #[test]
    fn test_homography_identity() {
        let (x, y) = (10.0, 20.0);
        let (mapped_x, mapped_y) = Homography::identity().transform_point(x, y);
        assert!(close(mapped_x, x, 1e-10));
        assert!(close(mapped_y, y, 1e-10));
    }

    #[test]
    fn test_homography_translation() {
        let shift = Homography::new(&[1.0, 0.0, 5.0, 0.0, 1.0, 7.0, 0.0, 0.0, 1.0]);
        let (x, y) = (10.0, 20.0);
        let (mapped_x, mapped_y) = shift.transform_point(x, y);
        assert!(close(mapped_x, x + 5.0, 1e-10));
        assert!(close(mapped_y, y + 7.0, 1e-10));
    }

    #[test]
    fn test_ransac_simple_translation() {
        let true_model = TranslationScale {
            tx: 10.0,
            ty: 5.0,
            scale: 1.0,
            rotation: 0.0,
        };

        // 100 exact correspondences under a pure translation...
        let mut matches: Vec<PointMatch> = (0..100)
            .map(|i| {
                let (x1, y1) = (i as f64, (i % 10) as f64);
                PointMatch {
                    point1: (x1, y1),
                    point2: (x1 + true_model.tx, y1 + true_model.ty),
                }
            })
            .collect();
        // ...plus 20 copies of a single gross outlier.
        matches.extend((0..20).map(|_| PointMatch {
            point1: (100.0, 100.0),
            point2: (150.0, 200.0),
        }));

        let config = RansacConfig {
            max_iterations: 100,
            threshold: 1.0,
            min_inliers: 10,
            confidence: 0.99,
            seed: Some(42),
            refinement_iterations: 1,
        };
        let result =
            run_ransac::<TranslationScale>(&matches, &config).expect("Operation failed");
        let fitted = &result.model;

        assert!(close(fitted.tx, true_model.tx, 1.0));
        assert!(close(fitted.ty, true_model.ty, 1.0));
        assert!(close(fitted.scale, true_model.scale, 0.1));
        assert!(close(fitted.rotation, true_model.rotation, 0.1));
        assert!(result.inliers.len() >= 90);
    }
}