use crate::{AlignError, AlignResult};
/// Timing description of a rolling-shutter sensor.
#[derive(Debug, Clone)]
pub struct RollingShutterParams {
    /// Total readout duration spread evenly over all lines of a frame. Offsets returned by
    /// [`RollingShutterParams::compute_scanline_time`] are expressed in the same unit.
    pub readout_time: f64,
    /// Nominal frame rate (not used by `compute_scanline_time`).
    pub frame_rate: f64,
    /// Order in which lines (rows or columns) are read out.
    pub direction: ReadoutDirection,
}

/// Order in which a rolling-shutter sensor reads its lines.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadoutDirection {
    /// Row 0 is read first.
    TopToBottom,
    /// The last row is read first.
    BottomToTop,
    /// Column 0 is read first.
    LeftToRight,
    /// The last column is read first.
    RightToLeft,
}

impl RollingShutterParams {
    /// Creates a parameter set from its three components.
    #[must_use]
    pub fn new(readout_time: f64, frame_rate: f64, direction: ReadoutDirection) -> Self {
        Self {
            readout_time,
            frame_rate,
            direction,
        }
    }

    /// Readout-time offset of line `scanline` out of `total_lines`.
    ///
    /// For the horizontal directions `scanline` is a column index and `total_lines` the width.
    /// The line read first maps to `0.0`, and each subsequent line adds
    /// `readout_time / total_lines`. The line read last therefore maps to
    /// `(total_lines - 1) / total_lines * readout_time`. This holds identically for forward and
    /// reversed readout directions. Returns `0.0` when `total_lines` is zero.
    #[must_use]
    pub fn compute_scanline_time(&self, scanline: usize, total_lines: usize) -> f64 {
        if total_lines == 0 {
            // Avoid 0/0 = NaN for an empty frame.
            return 0.0;
        }
        let n = total_lines as f64;
        let s = scanline as f64;
        let lines_read_before = match self.direction {
            ReadoutDirection::TopToBottom | ReadoutDirection::LeftToRight => s,
            // Reversed readout: index `total_lines - 1` is read first, so it must map to 0.
            ReadoutDirection::BottomToTop | ReadoutDirection::RightToLeft => n - 1.0 - s,
        };
        lines_read_before / n * self.readout_time
    }
}
/// Displacement of an image region between two frames, with an associated confidence.
#[derive(Debug, Clone, Copy)]
pub struct MotionVector {
    /// Horizontal displacement in pixels.
    pub dx: f32,
    /// Vertical displacement in pixels.
    pub dy: f32,
    /// Reliability score; higher is more reliable.
    pub confidence: f32,
}

impl MotionVector {
    /// Builds a vector from its components.
    #[must_use]
    pub fn new(dx: f32, dy: f32, confidence: f32) -> Self {
        Self { dx, dy, confidence }
    }

    /// No displacement, with full confidence (`1.0`).
    #[must_use]
    pub fn zero() -> Self {
        Self::new(0.0, 0.0, 1.0)
    }

    /// Euclidean length of the `(dx, dy)` displacement.
    #[must_use]
    pub fn magnitude(&self) -> f32 {
        let Self { dx, dy, .. } = *self;
        let squared_length = dx * dx + dy * dy;
        squared_length.sqrt()
    }
}
/// Block-matching motion estimator that produces one motion vector per band of rows.
pub struct RollingShutterEstimator {
    /// Side length, in pixels, of the block compared for each band; also the band height.
    pub block_size: usize,
    /// Largest displacement, in pixels, tried in each direction along both axes.
    pub search_range: isize,
}

impl Default for RollingShutterEstimator {
    fn default() -> Self {
        Self {
            block_size: 16,
            search_range: 16,
        }
    }
}

impl RollingShutterEstimator {
    /// Creates an estimator with the given block size and search range.
    #[must_use]
    pub fn new(block_size: usize, search_range: isize) -> Self {
        Self {
            block_size,
            search_range,
        }
    }

    /// Estimates motion between two frames, returning one vector per band of `block_size` rows
    /// (top band first).
    ///
    /// Both frames must hold `width * height` pixels of 3 interleaved bytes each. A returned
    /// vector `(dx, dy)` means the block at `(x, y)` in `frame1` best matches the block at
    /// `(x + dx, y + dy)` in `frame2`.
    ///
    /// # Errors
    ///
    /// Returns `AlignError::InvalidConfig` if either frame does not have `width * height * 3`
    /// bytes, or if `block_size` is zero.
    pub fn estimate_motion(
        &self,
        frame1: &[u8],
        frame2: &[u8],
        width: usize,
        height: usize,
    ) -> AlignResult<Vec<MotionVector>> {
        if frame1.len() != width * height * 3 || frame2.len() != width * height * 3 {
            return Err(AlignError::InvalidConfig("Frame size mismatch".to_string()));
        }
        if self.block_size == 0 {
            // `step_by(0)` would panic; report the bad configuration instead.
            return Err(AlignError::InvalidConfig(
                "block_size must be non-zero".to_string(),
            ));
        }
        let motion_vectors: Vec<MotionVector> = (0..height)
            .step_by(self.block_size)
            .map(|y| self.estimate_row_motion(frame1, frame2, width, height, y))
            .collect();
        Ok(motion_vectors)
    }

    /// Searches `[-search_range, search_range]` on both axes for the displacement of the band
    /// starting at row `y` with the lowest sum of absolute differences (SAD).
    ///
    /// NOTE(review): only the leftmost `block_size` columns of the band are compared (`x = 0`),
    /// so motion elsewhere in the row does not influence the estimate.
    ///
    /// Ties are broken toward the smallest displacement (L1 norm). Otherwise a flat band, where
    /// every candidate scores the same, would report the first corner of the search window with
    /// full confidence instead of zero motion.
    ///
    /// Confidence is `1 / (1 + SAD / 1000)`, i.e. `1.0` for a perfect match.
    fn estimate_row_motion(
        &self,
        frame1: &[u8],
        frame2: &[u8],
        width: usize,
        height: usize,
        y: usize,
    ) -> MotionVector {
        // Seed with the zero-displacement candidate so it wins every tie.
        let mut best_dx: isize = 0;
        let mut best_dy: isize = 0;
        let mut best_sad = self.compute_sad(frame1, frame2, width, height, 0, y, 0, 0);
        for dy in -self.search_range..=self.search_range {
            for dx in -self.search_range..=self.search_range {
                if dx == 0 && dy == 0 {
                    continue; // already scored as the seed
                }
                let sad = self.compute_sad(frame1, frame2, width, height, 0, y, dx, dy);
                let is_closer = dx.abs() + dy.abs() < best_dx.abs() + best_dy.abs();
                if sad < best_sad || (sad == best_sad && is_closer) {
                    best_sad = sad;
                    best_dx = dx;
                    best_dy = dy;
                }
            }
        }
        let confidence = if best_sad == 0 {
            1.0
        } else {
            1.0 / (1.0 + (best_sad as f32 / 1000.0))
        };
        MotionVector::new(best_dx as f32, best_dy as f32, confidence)
    }

    /// SAD over all 3 channels between two blocks: the block of `frame1` with top-left corner
    /// `(x, y)`, and the same block displaced by `(dx, dy)` in `frame2`.
    ///
    /// Displaced coordinates are clamped to the frame edges. The block is truncated at the
    /// bottom edge, and its width is `block_size.min(width)`. Requires `y < height`. The sum
    /// saturates at `u32::MAX` rather than overflowing for very large blocks.
    // The frame pair, geometry, anchor and displacement are all independent inputs.
    #[allow(clippy::too_many_arguments)]
    fn compute_sad(
        &self,
        frame1: &[u8],
        frame2: &[u8],
        width: usize,
        height: usize,
        x: usize,
        y: usize,
        dx: isize,
        dy: isize,
    ) -> u32 {
        let mut sad = 0u32;
        let block_height = self.block_size.min(height - y);
        for by in 0..block_height {
            for bx in 0..self.block_size.min(width) {
                let x1 = x + bx;
                let y1 = y + by;
                let x2 = (x1 as isize + dx).max(0).min((width - 1) as isize) as usize;
                let y2 = (y1 as isize + dy).max(0).min((height - 1) as isize) as usize;
                let idx1 = (y1 * width + x1) * 3;
                let idx2 = (y2 * width + x2) * 3;
                if idx1 + 2 < frame1.len() && idx2 + 2 < frame2.len() {
                    for c in 0..3 {
                        let diff = (i16::from(frame1[idx1 + c]) - i16::from(frame2[idx2 + c]))
                            .unsigned_abs();
                        sad = sad.saturating_add(u32::from(diff));
                    }
                }
            }
        }
        sad
    }
}
pub struct RollingShutterCorrector {
pub params: RollingShutterParams,
estimator: RollingShutterEstimator,
}
impl RollingShutterCorrector {
#[must_use]
pub fn new(params: RollingShutterParams) -> Self {
Self {
params,
estimator: RollingShutterEstimator::default(),
}
}
pub fn correct(
&self,
frame: &[u8],
motion_vectors: &[MotionVector],
width: usize,
height: usize,
) -> AlignResult<Vec<u8>> {
if frame.len() != width * height * 3 {
return Err(AlignError::InvalidConfig("Frame size mismatch".to_string()));
}
let mut corrected = vec![0u8; width * height * 3];
for (block_idx, mv) in motion_vectors.iter().enumerate() {
let y_start = block_idx * self.estimator.block_size;
let y_end = (y_start + self.estimator.block_size).min(height);
for y in y_start..y_end {
self.correct_scanline(frame, &mut corrected, width, y, mv);
}
}
Ok(corrected)
}
fn correct_scanline(
&self,
input: &[u8],
output: &mut [u8],
width: usize,
y: usize,
mv: &MotionVector,
) {
for x in 0..width {
let src_x = (x as f32 - mv.dx).round() as isize;
let src_y = (y as f32 - mv.dy).round() as isize;
if src_x >= 0 && src_x < width as isize && src_y >= 0 {
let src_idx = (src_y as usize * width + src_x as usize) * 3;
let dst_idx = (y * width + x) * 3;
if src_idx + 2 < input.len() && dst_idx + 2 < output.len() {
output[dst_idx..dst_idx + 3].copy_from_slice(&input[src_idx..src_idx + 3]);
}
}
}
}
pub fn estimate_and_correct(
&self,
frame1: &[u8],
frame2: &[u8],
width: usize,
height: usize,
) -> AlignResult<Vec<u8>> {
let motion_vectors = self
.estimator
.estimate_motion(frame1, frame2, width, height)?;
self.correct(frame2, &motion_vectors, width, height)
}
}
/// Heuristic detector for rolling-shutter "wobble": back-and-forth changes in the horizontal
/// motion of successive bands.
pub struct WobbleDetector {
    /// Minimum `dx` change (pixels) that must precede a reversal for it to be counted.
    pub threshold: f32,
}

impl Default for WobbleDetector {
    fn default() -> Self {
        Self::new(5.0)
    }
}

impl WobbleDetector {
    /// Creates a detector with the given reversal threshold.
    #[must_use]
    pub fn new(threshold: f32) -> Self {
        Self { threshold }
    }

    /// Returns `true` when the number of `dx` reversals exceeds `motion_vectors.len() / 4`
    /// (integer division).
    ///
    /// A reversal is a sign flip between two consecutive `dx` differences, where the first
    /// difference exceeds `threshold` in magnitude. Fewer than three vectors never count as
    /// wobble.
    #[must_use]
    pub fn detect_wobble(&self, motion_vectors: &[MotionVector]) -> bool {
        // `windows(3)` yields nothing for fewer than three vectors, so the count is zero and
        // `0 > len / 4` is false; no separate length guard is needed.
        let reversals = motion_vectors
            .windows(3)
            .filter(|triple| {
                let before = triple[1].dx - triple[0].dx;
                let after = triple[2].dx - triple[1].dx;
                before * after < 0.0 && before.abs() > self.threshold
            })
            .count();
        reversals > motion_vectors.len() / 4
    }

    /// Mean Euclidean change between consecutive motion vectors, divided by 20 and capped at
    /// `1.0`. Returns `0.0` for fewer than two vectors.
    #[must_use]
    pub fn compute_wobble_metric(&self, motion_vectors: &[MotionVector]) -> f32 {
        let steps = motion_vectors.len().saturating_sub(1);
        if steps == 0 {
            return 0.0;
        }
        let total_variation = motion_vectors.windows(2).fold(0.0f32, |acc, pair| {
            let ddx = pair[1].dx - pair[0].dx;
            let ddy = pair[1].dy - pair[0].dy;
            acc + (ddx * ddx + ddy * ddy).sqrt()
        });
        let mean_step = total_variation / steps as f32;
        (mean_step / 20.0).min(1.0)
    }
}
/// Removes rolling-shutter skew by shifting each row horizontally in proportion to its
/// readout-time offset.
pub struct SkewCorrector {
    /// Rate multiplied by each row's readout-time offset to obtain a per-row angle.
    /// NOTE(review): presumably radians per unit of `readout_time`; confirm with callers.
    pub angular_velocity: f64,
}

impl SkewCorrector {
    /// Creates a corrector for the given angular velocity.
    #[must_use]
    pub fn new(angular_velocity: f64) -> Self {
        Self { angular_velocity }
    }

    /// Returns a copy of `frame` with each row shifted right by a per-row offset.
    ///
    /// The offset for row `y` is `round(angular_velocity * t(y) * height / 2)` pixels, where
    /// `t(y) = params.compute_scanline_time(y, height)`. Pixels shifted in from beyond the row
    /// replicate the nearest edge pixel.
    ///
    /// # Errors
    ///
    /// Returns `AlignError::InvalidConfig` if `frame` does not have `width * height * 3` bytes.
    pub fn correct(
        &self,
        frame: &[u8],
        width: usize,
        height: usize,
        params: &RollingShutterParams,
    ) -> AlignResult<Vec<u8>> {
        if frame.len() != width * height * 3 {
            return Err(AlignError::InvalidConfig("Frame size mismatch".to_string()));
        }
        let mut corrected = vec![0u8; width * height * 3];
        let max_shift = width as f64;
        for y in 0..height {
            let time = params.compute_scanline_time(y, height);
            let angle = self.angular_velocity * time;
            // NOTE(review): `height / 2` appears to serve as a focal-length proxy when
            // converting the angle to pixels.
            //
            // Round to the nearest pixel: a bare `as` cast truncates toward zero and
            // under-shifts by up to a pixel.
            //
            // Any shift of at least `width` already maps every pixel to an edge pixel, so
            // clamping to ±width does not change the output. It does keep `x - offset` from
            // overflowing for extreme angles. NaN still maps to 0.
            let offset = (angle * (height as f64 / 2.0))
                .round()
                .clamp(-max_shift, max_shift) as isize;
            self.shift_scanline(frame, &mut corrected, width, y, offset);
        }
        Ok(corrected)
    }

    /// Writes row `y` of `output` as row `y` of `input` shifted right by `offset` pixels, with
    /// source columns clamped to `[0, width - 1]`.
    fn shift_scanline(
        &self,
        input: &[u8],
        output: &mut [u8],
        width: usize,
        y: usize,
        offset: isize,
    ) {
        for x in 0..width {
            let src_x = (x as isize - offset).max(0).min((width - 1) as isize) as usize;
            let src_idx = (y * width + src_x) * 3;
            let dst_idx = (y * width + x) * 3;
            if src_idx + 2 < input.len() && dst_idx + 2 < output.len() {
                output[dst_idx..dst_idx + 3].copy_from_slice(&input[src_idx..src_idx + 3]);
            }
        }
    }
}
/// Approximates a global-shutter frame by averaging several frames.
pub struct GlobalShutterSimulator {
    /// Requested number of sub-frames. NOTE(review): not currently used by `simulate`.
    pub sub_frames: usize,
}

impl Default for GlobalShutterSimulator {
    fn default() -> Self {
        Self { sub_frames: 10 }
    }
}

impl GlobalShutterSimulator {
    /// Creates a simulator with the given sub-frame count.
    #[must_use]
    pub fn new(sub_frames: usize) -> Self {
        Self { sub_frames }
    }

    /// Returns the per-byte mean, rounded to nearest, of every frame in `frames` that has
    /// exactly `width * height * 3` bytes.
    ///
    /// Frames of any other size are skipped (best effort). They are also excluded from the
    /// divisor so they do not darken the result.
    ///
    /// # Errors
    ///
    /// - `AlignError::InsufficientData` if `frames` is empty.
    /// - `AlignError::InvalidConfig` if no frame has the expected size.
    pub fn simulate(
        &self,
        frames: &[&[u8]],
        width: usize,
        height: usize,
        params: &RollingShutterParams,
    ) -> AlignResult<Vec<u8>> {
        if frames.is_empty() {
            return Err(AlignError::InsufficientData(
                "Need at least one frame".to_string(),
            ));
        }
        // NOTE(review): per-scanline timing from `params` is not applied yet; this is a plain
        // temporal mean.
        let _ = params;
        let expected_len = width * height * 3;
        let mut sums = vec![0u32; expected_len];
        let mut used = 0u32;
        for frame in frames.iter().filter(|f| f.len() == expected_len) {
            for (acc, &byte) in sums.iter_mut().zip(frame.iter()) {
                *acc += u32::from(byte);
            }
            used += 1;
        }
        if used == 0 {
            return Err(AlignError::InvalidConfig("Frame size mismatch".to_string()));
        }
        // Round half up instead of truncating, which biased the mean downward.
        // The result is at most 255, because `half < used`.
        let half = used / 2;
        Ok(sums.into_iter().map(|s| ((s + half) / used) as u8).collect())
    }
}
/// Per-block exponential moving average of motion vectors across successive frames.
pub struct TemporalSmoother {
    // Weight of the newest sample, always within [0.01, 1.0].
    alpha: f64,
    // Last smoothed vector for each block; empty until the first call to `smooth`.
    state: Vec<MotionVector>,
}

impl TemporalSmoother {
    /// Creates a smoother whose newest-sample weight is `alpha`, clamped to `[0.01, 1.0]`.
    ///
    /// A NaN `alpha` is treated as `1.0` (no smoothing). `f64::clamp` passes NaN through, and a
    /// NaN weight would turn every smoothed vector into NaN.
    #[must_use]
    pub fn new(alpha: f64) -> Self {
        let alpha = if alpha.is_nan() {
            1.0
        } else {
            alpha.clamp(0.01, 1.0)
        };
        Self {
            alpha,
            state: Vec::new(),
        }
    }

    /// Blends `motion_vectors` into the running state and returns the smoothed vectors.
    ///
    /// Each component is smoothed as `alpha * new + (1 - alpha) * previous`. When the block
    /// count differs from the stored state (including the first call), the state is reset to
    /// the input, and the input is returned unchanged.
    pub fn smooth(&mut self, motion_vectors: &[MotionVector]) -> Vec<MotionVector> {
        if self.state.len() != motion_vectors.len() {
            self.state = motion_vectors.to_vec();
            return motion_vectors.to_vec();
        }
        let alpha = self.alpha as f32;
        let one_minus = 1.0 - alpha;
        let mut result = Vec::with_capacity(motion_vectors.len());
        for (prev, new) in self.state.iter_mut().zip(motion_vectors.iter()) {
            let dx = alpha * new.dx + one_minus * prev.dx;
            let dy = alpha * new.dy + one_minus * prev.dy;
            let conf = alpha * new.confidence + one_minus * prev.confidence;
            prev.dx = dx;
            prev.dy = dy;
            prev.confidence = conf;
            result.push(MotionVector::new(dx, dy, conf));
        }
        result
    }

    /// Forgets the running state; the next `smooth` call passes its input through.
    pub fn reset(&mut self) {
        self.state.clear();
    }

    /// The effective (clamped) newest-sample weight.
    #[must_use]
    pub fn alpha(&self) -> f64 {
        self.alpha
    }

    /// Number of blocks in the running state (0 before the first call or after `reset`).
    #[must_use]
    pub fn num_blocks(&self) -> usize {
        self.state.len()
    }
}
/// Gaussian smoother over a sliding window of the last `2 * radius + 1` motion-vector sets.
pub struct GaussianTemporalSmoother {
    // Half-width of the window, in frames.
    radius: usize,
    // Gaussian taps (normalised to sum to 1), `2 * radius + 1` long, oldest slot first.
    kernel: Vec<f64>,
    // Most recent vector sets, oldest first; never longer than `capacity`.
    history: Vec<Vec<MotionVector>>,
    // Window length, `2 * radius + 1`.
    capacity: usize,
}

impl GaussianTemporalSmoother {
    /// Builds a smoother with window half-width `radius` and Gaussian width `sigma` (in
    /// frames). `sigma` is floored at `0.1`.
    #[must_use]
    pub fn new(radius: usize, sigma: f64) -> Self {
        let sigma = sigma.max(0.1);
        let variance = sigma * sigma;
        let window = 2 * radius + 1;
        let centre_tap = radius as f64;
        let mut kernel: Vec<f64> = (0..window)
            .map(|tap| {
                let x = tap as f64 - centre_tap;
                (-0.5 * x * x / variance).exp()
            })
            .collect();
        let total: f64 = kernel.iter().sum();
        if total > 1e-15 {
            kernel.iter_mut().for_each(|w| *w /= total);
        }
        Self {
            radius,
            kernel,
            history: Vec::with_capacity(window),
            capacity: window,
        }
    }

    /// Appends `motion_vectors` to the window and returns one smoothed vector per input block.
    ///
    /// Once the window is full, the kernel's centre tap sits on the frame `radius` steps before
    /// the newest, so the output lags the input by `radius` frames. Before that, the oldest
    /// stored frame sits on the centre tap. In both cases only taps that land on stored frames
    /// contribute, renormalised by their total weight. If that total weight is negligible
    /// (≤ 1e-15), the block's newest value is returned as-is.
    pub fn push(&mut self, motion_vectors: &[MotionVector]) -> Vec<MotionVector> {
        self.history.push(motion_vectors.to_vec());
        if self.history.len() > self.capacity {
            self.history.remove(0);
        }
        // History slot aligned with the kernel's centre tap.
        let centre = self.history.len().saturating_sub(self.radius + 1);
        motion_vectors
            .iter()
            .enumerate()
            .map(|(block, &latest)| self.weighted_block(block, centre).unwrap_or(latest))
            .collect()
    }

    /// Kernel-weighted mean of block `block` over the history, with `history[centre]` on the
    /// centre tap. Returns `None` when the accumulated weight is at most `1e-15`.
    fn weighted_block(&self, block: usize, centre: usize) -> Option<MotionVector> {
        let mut acc_dx = 0.0_f64;
        let mut acc_dy = 0.0_f64;
        let mut acc_conf = 0.0_f64;
        let mut acc_w = 0.0_f64;
        for (slot, frame) in self.history.iter().enumerate() {
            let mv = match frame.get(block) {
                Some(mv) => mv,
                None => continue,
            };
            let tap = slot as isize - centre as isize + self.radius as isize;
            let w = match usize::try_from(tap).ok().and_then(|t| self.kernel.get(t)) {
                Some(&w) => w,
                None => continue,
            };
            acc_dx += f64::from(mv.dx) * w;
            acc_dy += f64::from(mv.dy) * w;
            acc_conf += f64::from(mv.confidence) * w;
            acc_w += w;
        }
        if acc_w > 1e-15 {
            Some(MotionVector::new(
                (acc_dx / acc_w) as f32,
                (acc_dy / acc_w) as f32,
                (acc_conf / acc_w) as f32,
            ))
        } else {
            None
        }
    }

    /// Drops all buffered frames.
    pub fn reset(&mut self) {
        self.history.clear();
    }

    /// Number of buffered frames (at most `2 * radius + 1`).
    #[must_use]
    pub fn history_len(&self) -> usize {
        self.history.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Frame of `width * height` 3-byte pixels whose bytes are all distinct.
    /// Requires `width * height * 3 <= 256`.
    fn textured_frame(width: usize, height: usize) -> Vec<u8> {
        (0..width * height * 3).map(|i| i as u8).collect()
    }

    #[test]
    fn test_rolling_shutter_params() {
        let params = RollingShutterParams::new(0.033, 30.0, ReadoutDirection::TopToBottom);
        assert_eq!(params.readout_time, 0.033);
        assert_eq!(params.frame_rate, 30.0);
    }

    #[test]
    fn test_scanline_time() {
        let params = RollingShutterParams::new(0.01, 100.0, ReadoutDirection::TopToBottom);
        let time = params.compute_scanline_time(500, 1000);
        assert!((time - 0.005).abs() < 1e-10);
    }

    #[test]
    fn test_scanline_time_first_line_is_zero() {
        let params = RollingShutterParams::new(0.01, 100.0, ReadoutDirection::TopToBottom);
        assert_eq!(params.compute_scanline_time(0, 1000), 0.0);
    }

    #[test]
    fn test_motion_vector() {
        let mv = MotionVector::new(10.0, 20.0, 0.9);
        assert_eq!(mv.dx, 10.0);
        assert_eq!(mv.dy, 20.0);
        assert_eq!(mv.confidence, 0.9);
        let mag = mv.magnitude();
        assert!((mag - (10.0f32 * 10.0 + 20.0 * 20.0).sqrt()).abs() < 1e-6);
    }

    #[test]
    fn test_zero_motion_vector() {
        let mv = MotionVector::zero();
        assert_eq!(mv.dx, 0.0);
        assert_eq!(mv.dy, 0.0);
        assert_eq!(mv.magnitude(), 0.0);
    }

    #[test]
    fn test_estimate_motion_identical_frames_is_zero() {
        let estimator = RollingShutterEstimator::new(4, 2);
        let frame = textured_frame(8, 8);
        let mvs = match estimator.estimate_motion(&frame, &frame, 8, 8) {
            Ok(mvs) => mvs,
            Err(_) => panic!("estimation should succeed for well-sized frames"),
        };
        assert_eq!(mvs.len(), 2);
        for mv in &mvs {
            assert_eq!(mv.dx, 0.0);
            assert_eq!(mv.dy, 0.0);
            assert_eq!(mv.confidence, 1.0);
        }
    }

    #[test]
    fn test_estimate_motion_rejects_wrong_size() {
        let estimator = RollingShutterEstimator::default();
        assert!(estimator
            .estimate_motion(&[0u8; 5], &[0u8; 12], 2, 2)
            .is_err());
    }

    #[test]
    fn test_corrector_zero_motion_is_identity() {
        let params = RollingShutterParams::new(0.01, 30.0, ReadoutDirection::TopToBottom);
        let corrector = RollingShutterCorrector::new(params);
        let frame = textured_frame(4, 3);
        match corrector.correct(&frame, &[MotionVector::zero()], 4, 3) {
            Ok(out) => assert_eq!(out, frame),
            Err(_) => panic!("correction should succeed for a well-sized frame"),
        }
    }

    #[test]
    fn test_wobble_detector() {
        let detector = WobbleDetector::new(5.0);
        assert_eq!(detector.threshold, 5.0);
    }

    #[test]
    fn test_detect_wobble_oscillating_vs_smooth() {
        let detector = WobbleDetector::new(5.0);
        let oscillating: Vec<MotionVector> = [0.0, 10.0, 0.0, 10.0, 0.0, 10.0]
            .iter()
            .map(|&dx| MotionVector::new(dx, 0.0, 1.0))
            .collect();
        assert!(detector.detect_wobble(&oscillating));
        let smooth: Vec<MotionVector> = (0..6)
            .map(|i| MotionVector::new(i as f32, 0.0, 1.0))
            .collect();
        assert!(!detector.detect_wobble(&smooth));
        assert!(!detector.detect_wobble(&oscillating[..2]));
    }

    #[test]
    fn test_wobble_metric() {
        let detector = WobbleDetector::default();
        let vectors = vec![
            MotionVector::new(0.0, 0.0, 1.0),
            MotionVector::new(10.0, 0.0, 1.0),
            MotionVector::new(0.0, 0.0, 1.0),
            MotionVector::new(10.0, 0.0, 1.0),
        ];
        let metric = detector.compute_wobble_metric(&vectors);
        assert!(metric > 0.0);
    }

    #[test]
    fn test_wobble_metric_constant_motion_is_zero() {
        let detector = WobbleDetector::default();
        let vectors = vec![MotionVector::new(3.0, -2.0, 1.0); 4];
        assert_eq!(detector.compute_wobble_metric(&vectors), 0.0);
        assert_eq!(detector.compute_wobble_metric(&[]), 0.0);
    }

    #[test]
    fn test_skew_corrector() {
        let corrector = SkewCorrector::new(1.0);
        assert_eq!(corrector.angular_velocity, 1.0);
    }

    #[test]
    fn test_skew_corrector_zero_velocity_is_identity() {
        let params = RollingShutterParams::new(0.01, 30.0, ReadoutDirection::TopToBottom);
        let frame = textured_frame(4, 3);
        match SkewCorrector::new(0.0).correct(&frame, 4, 3, &params) {
            Ok(out) => assert_eq!(out, frame),
            Err(_) => panic!("correction should succeed for a well-sized frame"),
        }
    }

    #[test]
    fn test_global_shutter_simulator() {
        let simulator = GlobalShutterSimulator::new(10);
        assert_eq!(simulator.sub_frames, 10);
    }

    #[test]
    fn test_global_shutter_mean_of_identical_frames() {
        let params = RollingShutterParams::new(0.01, 30.0, ReadoutDirection::TopToBottom);
        let simulator = GlobalShutterSimulator::default();
        let frame = textured_frame(2, 2);
        match simulator.simulate(&[frame.as_slice(), frame.as_slice()], 2, 2, &params) {
            Ok(out) => assert_eq!(out, frame),
            Err(_) => panic!("simulation should succeed for well-sized frames"),
        }
        assert!(simulator.simulate(&[], 2, 2, &params).is_err());
    }

    #[test]
    fn test_readout_direction() {
        assert_eq!(ReadoutDirection::TopToBottom, ReadoutDirection::TopToBottom);
        assert_ne!(ReadoutDirection::TopToBottom, ReadoutDirection::BottomToTop);
    }

    #[test]
    fn test_temporal_smoother_first_frame_passthrough() {
        let mut smoother = TemporalSmoother::new(0.5);
        let mvs = vec![
            MotionVector::new(10.0, 5.0, 0.9),
            MotionVector::new(-3.0, 2.0, 0.8),
        ];
        let result = smoother.smooth(&mvs);
        assert_eq!(result.len(), 2);
        assert!((result[0].dx - 10.0).abs() < 1e-5);
        assert!((result[1].dy - 2.0).abs() < 1e-5);
    }

    #[test]
    fn test_temporal_smoother_convergence() {
        let mut smoother = TemporalSmoother::new(0.3);
        let mvs = vec![MotionVector::new(4.0, -2.0, 1.0)];
        for _ in 0..50 {
            let _ = smoother.smooth(&mvs);
        }
        let result = smoother.smooth(&mvs);
        assert!(
            (result[0].dx - 4.0).abs() < 0.01,
            "should converge to 4.0, got {}",
            result[0].dx
        );
        assert!(
            (result[0].dy + 2.0).abs() < 0.01,
            "should converge to -2.0, got {}",
            result[0].dy
        );
    }

    #[test]
    fn test_temporal_smoother_dampens_jitter() {
        let mut smoother = TemporalSmoother::new(0.2);
        let _ = smoother.smooth(&[MotionVector::new(10.0, 0.0, 1.0)]);
        for _ in 0..20 {
            let _ = smoother.smooth(&[MotionVector::new(-10.0, 0.0, 1.0)]);
            let _ = smoother.smooth(&[MotionVector::new(10.0, 0.0, 1.0)]);
        }
        let result = smoother.smooth(&[MotionVector::new(-10.0, 0.0, 1.0)]);
        assert!(
            result[0].dx.abs() < 5.0,
            "jitter should be dampened, got {}",
            result[0].dx
        );
    }

    #[test]
    fn test_temporal_smoother_alpha_clamping() {
        let s1 = TemporalSmoother::new(0.0);
        assert!((s1.alpha() - 0.01).abs() < 1e-10);
        let s2 = TemporalSmoother::new(2.0);
        assert!((s2.alpha() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_temporal_smoother_reset() {
        let mut smoother = TemporalSmoother::new(0.5);
        let _ = smoother.smooth(&[MotionVector::new(5.0, 5.0, 1.0)]);
        assert_eq!(smoother.num_blocks(), 1);
        smoother.reset();
        assert_eq!(smoother.num_blocks(), 0);
    }

    #[test]
    fn test_gaussian_smoother_constant_input() {
        let mut smoother = GaussianTemporalSmoother::new(2, 1.0);
        let mvs = vec![MotionVector::new(3.0, -1.0, 0.9)];
        for _ in 0..10 {
            let result = smoother.push(&mvs);
            assert_eq!(result.len(), 1);
            assert!((result[0].dx - 3.0).abs() < 0.5, "dx={}", result[0].dx);
        }
    }

    #[test]
    fn test_gaussian_smoother_dampens_spike() {
        let mut smoother = GaussianTemporalSmoother::new(2, 1.0);
        let normal = vec![MotionVector::new(0.0, 0.0, 1.0)];
        let spike = vec![MotionVector::new(100.0, 0.0, 1.0)];
        let _ = smoother.push(&normal);
        let _ = smoother.push(&normal);
        let result = smoother.push(&spike);
        assert!(
            result[0].dx < 100.0,
            "spike should be dampened: dx={}",
            result[0].dx
        );
    }

    #[test]
    fn test_gaussian_smoother_history_len() {
        let mut smoother = GaussianTemporalSmoother::new(1, 0.5);
        assert_eq!(smoother.history_len(), 0);
        let mvs = vec![MotionVector::zero()];
        let _ = smoother.push(&mvs);
        assert_eq!(smoother.history_len(), 1);
        let _ = smoother.push(&mvs);
        let _ = smoother.push(&mvs);
        assert_eq!(smoother.history_len(), 3);
        let _ = smoother.push(&mvs);
        assert_eq!(smoother.history_len(), 3);
    }

    #[test]
    fn test_gaussian_smoother_reset() {
        let mut smoother = GaussianTemporalSmoother::new(2, 1.0);
        let _ = smoother.push(&[MotionVector::zero()]);
        assert_eq!(smoother.history_len(), 1);
        smoother.reset();
        assert_eq!(smoother.history_len(), 0);
    }
}