use crate::{AlignError, AlignResult};
/// Timing model of a rolling-shutter sensor: how long a full readout takes and in which
/// order the scanlines are read.
#[derive(Debug, Clone)]
pub struct RollingShutterParams {
    /// Time needed to read out the whole frame. Presumably seconds (the tests pair `0.033`
    /// with a 30 fps frame rate) — TODO confirm the unit with callers.
    pub readout_time: f64,
    /// Capture frame rate. NOTE(review): not read by the methods in this block.
    pub frame_rate: f64,
    /// Order in which scanlines are read out.
    pub direction: ReadoutDirection,
}

/// Order in which a rolling shutter reads its scanlines.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadoutDirection {
    /// Scanline 0 is read first, the last scanline last.
    TopToBottom,
    /// The last scanline is read first, scanline 0 last.
    BottomToTop,
    /// Like `TopToBottom`; the scanline index is presumably a column index here.
    LeftToRight,
    /// Like `BottomToTop`; the scanline index is presumably a column index here.
    RightToLeft,
}

impl RollingShutterParams {
    /// Creates a parameter set from its components.
    #[must_use]
    pub fn new(readout_time: f64, frame_rate: f64, direction: ReadoutDirection) -> Self {
        Self {
            readout_time,
            frame_rate,
            direction,
        }
    }

    /// Returns the time, measured from the first scanline read, at which `scanline` is read.
    ///
    /// For the forward directions scanline `i` of `N` is read at `i / N * readout_time`.
    /// The reverse directions are the exact mirror: scanline `N - 1` is read at time 0 and
    /// scanline 0 at `(N - 1) / N * readout_time`, so both orientations span the same interval.
    ///
    /// Returns `0.0` when `total_lines` is zero (instead of the NaN that `0 / 0` would give).
    #[must_use]
    pub fn compute_scanline_time(&self, scanline: usize, total_lines: usize) -> f64 {
        if total_lines == 0 {
            return 0.0;
        }
        let total = total_lines as f64;
        let index = scanline as f64;
        let progress = match self.direction {
            ReadoutDirection::TopToBottom | ReadoutDirection::LeftToRight => index / total,
            // Mirror of the forward case. `1 - i / N` would be shifted by one line: the first
            // line read would land at `1 / N` and the last at a full `readout_time`.
            ReadoutDirection::BottomToTop | ReadoutDirection::RightToLeft => {
                (total - 1.0 - index) / total
            }
        };
        progress * self.readout_time
    }
}
/// A displacement estimate together with how much it can be trusted.
#[derive(Debug, Clone, Copy)]
pub struct MotionVector {
    /// Horizontal displacement in pixels.
    pub dx: f32,
    /// Vertical displacement in pixels.
    pub dy: f32,
    /// Reliability score; larger values mean a more trustworthy estimate.
    pub confidence: f32,
}

impl MotionVector {
    /// Builds a vector from its two components and a confidence score.
    #[must_use]
    pub fn new(dx: f32, dy: f32, confidence: f32) -> Self {
        Self { confidence, dx, dy }
    }

    /// A vector with no displacement and full confidence (`1.0`).
    #[must_use]
    pub fn zero() -> Self {
        Self::new(0.0, 0.0, 1.0)
    }

    /// Euclidean length of the `(dx, dy)` displacement.
    #[must_use]
    pub fn magnitude(&self) -> f32 {
        let Self { dx, dy, .. } = *self;
        let squared_length = dx * dx + dy * dy;
        squared_length.sqrt()
    }
}
/// Block-matching motion estimator producing one motion vector per band of `block_size` rows.
pub struct RollingShutterEstimator {
    /// Height of each row band and side length of the matched block, in pixels.
    pub block_size: usize,
    /// Largest displacement searched in each direction, in pixels.
    pub search_range: isize,
}

impl Default for RollingShutterEstimator {
    fn default() -> Self {
        Self {
            block_size: 16,
            search_range: 16,
        }
    }
}

impl RollingShutterEstimator {
    /// Creates an estimator with the given block size and search range.
    #[must_use]
    pub fn new(block_size: usize, search_range: isize) -> Self {
        Self {
            block_size,
            search_range,
        }
    }

    /// Estimates one motion vector per band of `block_size` rows between two packed RGB frames.
    ///
    /// Each returned vector `(dx, dy)` is the displacement minimising the sum of absolute
    /// differences between `frame1(p)` and `frame2(p + (dx, dy))`. It is found by exhaustive
    /// search over `[-search_range, search_range]²`.
    ///
    /// # Errors
    ///
    /// Returns `AlignError::InvalidConfig` in two cases:
    /// - either frame's length is not `width * height * 3`;
    /// - `block_size` is zero, which would make the row stepping panic.
    pub fn estimate_motion(
        &self,
        frame1: &[u8],
        frame2: &[u8],
        width: usize,
        height: usize,
    ) -> AlignResult<Vec<MotionVector>> {
        if frame1.len() != width * height * 3 || frame2.len() != width * height * 3 {
            return Err(AlignError::InvalidConfig("Frame size mismatch".to_string()));
        }
        if self.block_size == 0 {
            // `step_by(0)` panics; report the misconfiguration instead.
            return Err(AlignError::InvalidConfig(
                "Block size must be non-zero".to_string(),
            ));
        }
        Ok((0..height)
            .step_by(self.block_size)
            .map(|y| self.estimate_row_motion(frame1, frame2, width, height, y))
            .collect())
    }

    /// Finds the best displacement for the band of rows starting at `y`.
    ///
    /// NOTE(review): only the leftmost `block_size` columns of the band are compared
    /// (`x = 0` is passed to `compute_sad`), not the full row width.
    fn estimate_row_motion(
        &self,
        frame1: &[u8],
        frame2: &[u8],
        width: usize,
        height: usize,
        y: usize,
    ) -> MotionVector {
        let mut best_dx: isize = 0;
        let mut best_dy: isize = 0;
        let mut best_sad = u32::MAX;
        for dy in -self.search_range..=self.search_range {
            for dx in -self.search_range..=self.search_range {
                let sad = self.compute_sad(frame1, frame2, width, height, 0, y, dx, dy);
                // Break ties toward the smallest displacement. Otherwise flat or ambiguous
                // regions report the first corner of the search window (-range, -range)
                // instead of zero motion.
                let is_better = sad < best_sad
                    || (sad == best_sad
                        && dx.abs() + dy.abs() < best_dx.abs() + best_dy.abs());
                if is_better {
                    best_sad = sad;
                    best_dx = dx;
                    best_dy = dy;
                }
            }
        }
        let confidence = if best_sad == 0 {
            1.0
        } else {
            1.0 / (1.0 + (best_sad as f32 / 1000.0))
        };
        MotionVector::new(best_dx as f32, best_dy as f32, confidence)
    }

    /// Sum of absolute RGB differences between the block at `(x, y)` in `frame1` and the same
    /// block displaced by `(dx, dy)` in `frame2`. Displaced coordinates are clamped to the frame.
    ///
    /// The sum saturates at `u32::MAX` rather than overflowing for very large blocks.
    // Private helper: the arguments are the search loop's state; bundling them in a struct
    // would only add ceremony.
    #[allow(clippy::too_many_arguments)]
    fn compute_sad(
        &self,
        frame1: &[u8],
        frame2: &[u8],
        width: usize,
        height: usize,
        x: usize,
        y: usize,
        dx: isize,
        dy: isize,
    ) -> u32 {
        let mut sad = 0u32;
        let block_height = self.block_size.min(height - y);
        for by in 0..block_height {
            for bx in 0..self.block_size.min(width) {
                let x1 = x + bx;
                let y1 = y + by;
                let x2 = (x1 as isize + dx).max(0).min((width - 1) as isize) as usize;
                let y2 = (y1 as isize + dy).max(0).min((height - 1) as isize) as usize;
                let idx1 = (y1 * width + x1) * 3;
                let idx2 = (y2 * width + x2) * 3;
                if idx1 + 2 < frame1.len() && idx2 + 2 < frame2.len() {
                    for c in 0..3 {
                        let diff = (i16::from(frame1[idx1 + c]) - i16::from(frame2[idx2 + c]))
                            .unsigned_abs();
                        sad = sad.saturating_add(u32::from(diff));
                    }
                }
            }
        }
        sad
    }
}
pub struct RollingShutterCorrector {
pub params: RollingShutterParams,
estimator: RollingShutterEstimator,
}
impl RollingShutterCorrector {
#[must_use]
pub fn new(params: RollingShutterParams) -> Self {
Self {
params,
estimator: RollingShutterEstimator::default(),
}
}
pub fn correct(
&self,
frame: &[u8],
motion_vectors: &[MotionVector],
width: usize,
height: usize,
) -> AlignResult<Vec<u8>> {
if frame.len() != width * height * 3 {
return Err(AlignError::InvalidConfig("Frame size mismatch".to_string()));
}
let mut corrected = vec![0u8; width * height * 3];
for (block_idx, mv) in motion_vectors.iter().enumerate() {
let y_start = block_idx * self.estimator.block_size;
let y_end = (y_start + self.estimator.block_size).min(height);
for y in y_start..y_end {
self.correct_scanline(frame, &mut corrected, width, y, mv);
}
}
Ok(corrected)
}
fn correct_scanline(
&self,
input: &[u8],
output: &mut [u8],
width: usize,
y: usize,
mv: &MotionVector,
) {
for x in 0..width {
let src_x = (x as f32 - mv.dx).round() as isize;
let src_y = (y as f32 - mv.dy).round() as isize;
if src_x >= 0 && src_x < width as isize && src_y >= 0 {
let src_idx = (src_y as usize * width + src_x as usize) * 3;
let dst_idx = (y * width + x) * 3;
if src_idx + 2 < input.len() && dst_idx + 2 < output.len() {
output[dst_idx..dst_idx + 3].copy_from_slice(&input[src_idx..src_idx + 3]);
}
}
}
}
pub fn estimate_and_correct(
&self,
frame1: &[u8],
frame2: &[u8],
width: usize,
height: usize,
) -> AlignResult<Vec<u8>> {
let motion_vectors = self
.estimator
.estimate_motion(frame1, frame2, width, height)?;
self.correct(frame2, &motion_vectors, width, height)
}
}
/// Detects oscillating ("wobble") patterns in a sequence of per-band motion vectors.
pub struct WobbleDetector {
    /// Minimum absolute horizontal step that must precede a direction reversal for it to count.
    pub threshold: f32,
}

impl Default for WobbleDetector {
    fn default() -> Self {
        Self::new(5.0)
    }
}

impl WobbleDetector {
    /// Creates a detector with the given reversal threshold.
    #[must_use]
    pub fn new(threshold: f32) -> Self {
        Self { threshold }
    }

    /// Returns `true` when horizontal motion reverses direction often enough to count as wobble.
    ///
    /// A reversal is a consecutive pair of `dx` steps with opposite signs, where the earlier step
    /// exceeds `threshold` in magnitude. Wobble means more reversals than a quarter of the vector
    /// count. Fewer than three vectors can never wobble: `windows(3)` yields nothing and the
    /// count stays zero.
    #[must_use]
    pub fn detect_wobble(&self, motion_vectors: &[MotionVector]) -> bool {
        let reversals = motion_vectors
            .windows(3)
            .filter(|w| {
                let before = w[1].dx - w[0].dx;
                let after = w[2].dx - w[1].dx;
                before * after < 0.0 && before.abs() > self.threshold
            })
            .count();
        reversals > motion_vectors.len() / 4
    }

    /// Mean Euclidean step between consecutive vectors, divided by 20 and capped at `1.0`.
    ///
    /// Returns `0.0` for fewer than two vectors.
    #[must_use]
    pub fn compute_wobble_metric(&self, motion_vectors: &[MotionVector]) -> f32 {
        let steps = motion_vectors.len().saturating_sub(1);
        if steps == 0 {
            return 0.0;
        }
        let total_variation = motion_vectors.windows(2).fold(0.0f32, |acc, pair| {
            let ddx = pair[1].dx - pair[0].dx;
            let ddy = pair[1].dy - pair[0].dy;
            acc + (ddx * ddx + ddy * ddy).sqrt()
        });
        (total_variation / steps as f32 / 20.0).min(1.0)
    }
}
/// Undoes the horizontal skew that a rolling shutter produces under a constant angular velocity.
pub struct SkewCorrector {
    /// Angular velocity, multiplied by each scanline's readout time to get its rotation angle.
    pub angular_velocity: f64,
}

impl SkewCorrector {
    /// Creates a corrector for the given angular velocity.
    #[must_use]
    pub fn new(angular_velocity: f64) -> Self {
        Self { angular_velocity }
    }

    /// Shifts every row of a packed RGB `frame` horizontally by an amount proportional to the
    /// time at which that row was read.
    ///
    /// Source columns falling outside the row are clamped to the nearest edge pixel.
    ///
    /// # Errors
    ///
    /// Returns `AlignError::InvalidConfig` if `frame.len() != width * height * 3`.
    pub fn correct(
        &self,
        frame: &[u8],
        width: usize,
        height: usize,
        params: &RollingShutterParams,
    ) -> AlignResult<Vec<u8>> {
        if frame.len() != width * height * 3 {
            return Err(AlignError::InvalidConfig("Frame size mismatch".to_string()));
        }
        let mut corrected = vec![0u8; width * height * 3];
        for y in 0..height {
            let time = params.compute_scanline_time(y, height);
            let angle = self.angular_velocity * time;
            // NOTE(review): the angle is scaled by half the frame *height* to obtain a horizontal
            // pixel shift; confirm this lever arm is intended. The float-to-int `as` cast
            // saturates out-of-range values and maps NaN to 0.
            let offset = (angle * (height as f64 / 2.0)) as isize;
            self.shift_scanline(frame, &mut corrected, width, y, offset);
        }
        Ok(corrected)
    }

    /// Writes row `y` of `output` from row `y` of `input` displaced right by `offset` pixels.
    fn shift_scanline(
        &self,
        input: &[u8],
        output: &mut [u8],
        width: usize,
        y: usize,
        offset: isize,
    ) {
        if width == 0 {
            return;
        }
        let last_x = width as isize - 1;
        for x in 0..width {
            // `saturating_sub`: an extreme offset (e.g. `isize::MIN` from an infinite angle)
            // would overflow a plain subtraction, which panics in debug builds.
            let src_x = (x as isize).saturating_sub(offset).clamp(0, last_x) as usize;
            let src_idx = (y * width + src_x) * 3;
            let dst_idx = (y * width + x) * 3;
            if src_idx + 2 < input.len() && dst_idx + 2 < output.len() {
                output[dst_idx..dst_idx + 3].copy_from_slice(&input[src_idx..src_idx + 3]);
            }
        }
    }
}
/// Approximates a global-shutter image by averaging several packed RGB frames.
pub struct GlobalShutterSimulator {
    /// Configured number of sub-frames. NOTE(review): `simulate` does not read this value.
    pub sub_frames: usize,
}

impl Default for GlobalShutterSimulator {
    fn default() -> Self {
        Self { sub_frames: 10 }
    }
}

impl GlobalShutterSimulator {
    /// Creates a simulator configured for `sub_frames` sub-frames.
    #[must_use]
    pub fn new(sub_frames: usize) -> Self {
        Self { sub_frames }
    }

    /// Returns the per-channel average (truncated) of every frame whose length is
    /// `width * height * 3`.
    ///
    /// Frames with any other length are skipped and do not count toward the divisor. Counting
    /// them would darken the result.
    ///
    /// # Errors
    ///
    /// Returns `AlignError::InsufficientData` if `frames` is empty or if no frame has the
    /// expected length.
    pub fn simulate(
        &self,
        frames: &[&[u8]],
        width: usize,
        height: usize,
        params: &RollingShutterParams,
    ) -> AlignResult<Vec<u8>> {
        if frames.is_empty() {
            return Err(AlignError::InsufficientData(
                "Need at least one frame".to_string(),
            ));
        }
        // Per-scanline readout timing is not used by this plain temporal average.
        let _ = params;
        let expected_len = width * height * 3;
        let mut sum = vec![0u32; expected_len];
        let mut used = 0u32;
        for frame in frames.iter().filter(|f| f.len() == expected_len) {
            for (acc, &px) in sum.iter_mut().zip(frame.iter()) {
                *acc += u32::from(px);
            }
            used += 1;
        }
        if used == 0 {
            return Err(AlignError::InsufficientData(
                "No frame matches the expected frame size".to_string(),
            ));
        }
        // Each frame contributes at most 255 per channel, so the mean always fits in a u8.
        Ok(sum.into_iter().map(|v| (v / used) as u8).collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps a result without requiring the error type to implement `Debug`.
    fn ok<T, E>(result: Result<T, E>) -> T {
        result.unwrap_or_else(|_| panic!("expected Ok, got Err"))
    }

    #[test]
    fn test_rolling_shutter_params() {
        let params = RollingShutterParams::new(0.033, 30.0, ReadoutDirection::TopToBottom);
        assert_eq!(params.readout_time, 0.033);
        assert_eq!(params.frame_rate, 30.0);
    }

    #[test]
    fn test_scanline_time() {
        let params = RollingShutterParams::new(0.01, 100.0, ReadoutDirection::TopToBottom);
        let time = params.compute_scanline_time(500, 1000);
        assert!((time - 0.005).abs() < 1e-10);
    }

    #[test]
    fn test_motion_vector() {
        let mv = MotionVector::new(10.0, 20.0, 0.9);
        assert_eq!(mv.dx, 10.0);
        assert_eq!(mv.dy, 20.0);
        assert_eq!(mv.confidence, 0.9);
        let mag = mv.magnitude();
        assert!((mag - (10.0f32 * 10.0 + 20.0 * 20.0).sqrt()).abs() < 1e-6);
    }

    #[test]
    fn test_zero_motion_vector() {
        let mv = MotionVector::zero();
        assert_eq!(mv.dx, 0.0);
        assert_eq!(mv.dy, 0.0);
        assert_eq!(mv.magnitude(), 0.0);
    }

    #[test]
    fn test_wobble_detector() {
        let detector = WobbleDetector::new(5.0);
        assert_eq!(detector.threshold, 5.0);
    }

    #[test]
    fn test_wobble_metric() {
        let detector = WobbleDetector::default();
        let vectors = [
            MotionVector::new(0.0, 0.0, 1.0),
            MotionVector::new(10.0, 0.0, 1.0),
            MotionVector::new(0.0, 0.0, 1.0),
            MotionVector::new(10.0, 0.0, 1.0),
        ];
        let metric = detector.compute_wobble_metric(&vectors);
        assert!(metric > 0.0);
        // Every step has length 10, so the mean is 10 and the metric is 10 / 20.
        assert!((metric - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_wobble_detection_on_alternating_motion() {
        let detector = WobbleDetector::default();
        let vectors = [
            MotionVector::new(0.0, 0.0, 1.0),
            MotionVector::new(10.0, 0.0, 1.0),
            MotionVector::new(0.0, 0.0, 1.0),
            MotionVector::new(10.0, 0.0, 1.0),
        ];
        assert!(detector.detect_wobble(&vectors));
        assert!(!detector.detect_wobble(&vectors[..2]));
    }

    #[test]
    fn test_skew_corrector() {
        let corrector = SkewCorrector::new(1.0);
        assert_eq!(corrector.angular_velocity, 1.0);
    }

    #[test]
    fn test_skew_zero_velocity_is_identity() {
        let params = RollingShutterParams::new(0.01, 100.0, ReadoutDirection::TopToBottom);
        let frame: Vec<u8> = (0..24u8).collect();
        let out = ok(SkewCorrector::new(0.0).correct(&frame, 4, 2, &params));
        assert_eq!(out, frame);
    }

    #[test]
    fn test_global_shutter_simulator() {
        let simulator = GlobalShutterSimulator::new(10);
        assert_eq!(simulator.sub_frames, 10);
    }

    #[test]
    fn test_global_shutter_averages_frames() {
        let params = RollingShutterParams::new(0.01, 100.0, ReadoutDirection::TopToBottom);
        let simulator = GlobalShutterSimulator::default();
        let a = [10u8; 12];
        let b = [30u8; 12];
        let out = ok(simulator.simulate(&[&a[..], &b[..]], 2, 2, &params));
        assert_eq!(out, vec![20u8; 12]);
        assert!(simulator.simulate(&[], 2, 2, &params).is_err());
    }

    #[test]
    fn test_estimate_motion_rejects_size_mismatch() {
        let estimator = RollingShutterEstimator::default();
        let frame = vec![0u8; 3];
        assert!(estimator.estimate_motion(&frame, &frame, 2, 2).is_err());
    }

    #[test]
    fn test_estimate_motion_identical_textured_frames() {
        let (width, height) = (8usize, 8usize);
        // Every pixel has a distinct value, so only zero displacement gives a zero SAD.
        let frame: Vec<u8> = (0..width * height)
            .flat_map(|i| {
                let v = (i * 3) as u8;
                [v, v, v]
            })
            .collect();
        let estimator = RollingShutterEstimator::new(4, 2);
        let vectors = ok(estimator.estimate_motion(&frame, &frame, width, height));
        assert_eq!(vectors.len(), 2);
        for mv in &vectors {
            assert_eq!(mv.dx, 0.0);
            assert_eq!(mv.dy, 0.0);
            assert_eq!(mv.confidence, 1.0);
        }
    }

    #[test]
    fn test_correct_zero_motion_is_identity() {
        let params = RollingShutterParams::new(0.01, 100.0, ReadoutDirection::TopToBottom);
        let corrector = RollingShutterCorrector::new(params);
        let (width, height) = (4usize, 4usize);
        let frame: Vec<u8> = (0..48u8).collect();
        let out = ok(corrector.correct(&frame, &[MotionVector::zero()], width, height));
        assert_eq!(out, frame);
    }

    #[test]
    fn test_correct_rejects_size_mismatch() {
        let params = RollingShutterParams::new(0.01, 100.0, ReadoutDirection::TopToBottom);
        let corrector = RollingShutterCorrector::new(params);
        let frame = vec![0u8; 5];
        assert!(corrector.correct(&frame, &[MotionVector::zero()], 2, 2).is_err());
    }

    #[test]
    fn test_readout_direction() {
        assert_eq!(ReadoutDirection::TopToBottom, ReadoutDirection::TopToBottom);
        assert_ne!(ReadoutDirection::TopToBottom, ReadoutDirection::BottomToTop);
    }
}