use image::GrayImage;
use crate::pyramid::build_pyramid_into;
use crate::utils::fast_gradients::compute_gradients_into;
/// Minimum-eigenvalue threshold passed to the tracker by the deprecated
/// [`calc_optical_flow`] wrapper.
///
/// The tracker compares this value against the smaller eigenvalue of a point's
/// 2x2 gradient matrix divided by the number of pixels in the window; points
/// falling below it are reported as [`TrackStatus::LowTexture`].
pub const DEFAULT_MIN_EIGEN_THRESHOLD: f32 = 1e-3;
/// Outcome of tracking a single point.
///
/// The tracker rewrites the status at every pyramid level it processes and the
/// finest level (level 0) is processed last, so the reported value describes
/// what happened at level 0.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrackStatus {
/// The iterative refinement converged (both update components dropped below
/// the convergence epsilon) without leaving the image.
Tracked,
/// The sampling window around the point in the previous image, or around
/// the current estimate in the current image, was not fully inside the image.
OutOfBounds,
/// An update step was non-finite or larger than the window size, or the
/// iteration budget ran out before convergence.
Diverged,
/// The gradient matrix of the window was too weak to solve: its normalized
/// minimum eigenvalue was below the threshold, or its determinant was
/// (nearly) zero.
LowTexture,
/// Assigned only by the forward-backward check: the forward track was
/// `Tracked`, but the backward track was not `Tracked` or ended farther than
/// `fb_threshold` from the original point.
FbInconsistent,
}
/// Suggested forward-backward consistency threshold.
///
/// The forward-backward check compares the Euclidean distance between the
/// back-tracked position and the original point against `fb_threshold`, so the
/// value is expressed in the same units as the point coordinates.
// NOTE(review): not referenced elsewhere in this file; presumably meant as the
// default `fb_threshold` for `calc_optical_flow_fb` / `TrackerContext::track_fb`
// callers — confirm.
pub const DEFAULT_FB_THRESHOLD: f32 = 0.7;
/// Per-point output of the tracker.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TrackResult {
/// Estimated position in the current image: the input point plus the
/// accumulated displacement. It is written for every point regardless of
/// `status`, so check `status` before trusting it.
pub pos: (f32, f32),
/// How tracking ended for this point.
pub status: TrackStatus,
/// Mean absolute intensity difference between the previous-image window and
/// the current-image window at the final estimate, evaluated at the finest
/// pyramid level. Stays `f32::INFINITY` when it was not (or could not be)
/// evaluated.
pub error: f32,
}
/// Tracks `prev_points` from `prev_pyramid` into `curr_pyramid` and returns only
/// the estimated positions, one per input point and in the same order.
///
/// This is a thin wrapper over [`calc_optical_flow_ex`] that passes no predicted
/// positions and uses [`DEFAULT_MIN_EIGEN_THRESHOLD`]. The per-point status and
/// error are discarded, so a returned position is not guaranteed to belong to a
/// successfully tracked point.
///
/// # Panics
///
/// Panics under the same conditions as [`calc_optical_flow_ex`].
#[deprecated(
    since = "0.3.0",
    note = "use `calc_optical_flow_ex`, which also returns per-point status and error"
)]
pub fn calc_optical_flow(
    prev_pyramid: &[GrayImage],
    curr_pyramid: &[GrayImage],
    prev_points: &[(f32, f32)],
    window_size: usize,
    max_iterations: usize,
) -> Vec<(f32, f32)> {
    let results = calc_optical_flow_ex(
        prev_pyramid,
        curr_pyramid,
        prev_points,
        None,
        window_size,
        max_iterations,
        DEFAULT_MIN_EIGEN_THRESHOLD,
    );
    // Keep only the positions; status and error are intentionally dropped.
    results.iter().map(|result| result.pos).collect()
}
/// Tracks `prev_points` from `prev_pyramid` into `curr_pyramid`, returning one
/// [`TrackResult`] (position, status, error) per input point, in input order.
///
/// * `predicted` - optional initial guess of each point's position in the
///   current image; when `None`, tracking starts from zero displacement.
/// * `window_size` - side length of the square sampling window; must be odd.
/// * `max_iterations` - iteration budget per point and pyramid level.
/// * `min_eigen_threshold` - points whose normalized minimum eigenvalue falls
///   below this are reported as [`TrackStatus::LowTexture`].
///
/// Working buffers are allocated for this call only; use [`TrackerContext`] to
/// reuse them across calls.
///
/// # Panics
///
/// Panics if the pyramids are empty or differ in length, if `window_size` is
/// even, or if `predicted` is given with a length different from `prev_points`.
pub fn calc_optical_flow_ex(
    prev_pyramid: &[GrayImage],
    curr_pyramid: &[GrayImage],
    prev_points: &[(f32, f32)],
    predicted: Option<&[(f32, f32)]>,
    window_size: usize,
    max_iterations: usize,
    min_eigen_threshold: f32,
) -> Vec<TrackResult> {
    let mut results = Vec::new();
    let mut workspace = Scratch::default();
    track_into(
        prev_pyramid,
        curr_pyramid,
        prev_points,
        predicted,
        window_size,
        max_iterations,
        min_eigen_threshold,
        &mut workspace,
        &mut results,
    );
    results
}
/// Reusable working buffers for `track_into`, kept together so repeated calls
/// can reuse their allocations.
#[derive(Default)]
struct Scratch {
/// Sample offsets of the square window relative to its centre.
offsets: Vec<(f32, f32)>,
/// Interpolated previous-image intensities of the current point's window.
prev_patch: Vec<f32>,
/// Interpolated horizontal gradient of the current point's window.
ix_patch: Vec<f32>,
/// Interpolated vertical gradient of the current point's window.
iy_patch: Vec<f32>,
/// Per-point displacement estimate, stored in level-0 coordinates and
/// carried from one pyramid level to the next.
displacements: Vec<(f32, f32)>,
/// Horizontal gradient image of the pyramid level being processed (sized
/// for level 0, sliced per level).
grad_x: Vec<i16>,
/// Vertical gradient image of the pyramid level being processed (sized for
/// level 0, sliced per level).
grad_y: Vec<i16>,
}
/// Coarse-to-fine iterative tracker shared by every public entry point.
///
/// For each pyramid level, from the coarsest (last slice element) down to level
/// 0, and for each point, the function:
///
/// 1. samples the previous image and its gradients over a `window_size` square
///    window and accumulates the 2x2 gradient matrix,
/// 2. rejects the point as [`TrackStatus::LowTexture`] when that matrix is too
///    weak (normalized minimum eigenvalue below `min_eigen_threshold`, or a
///    near-zero determinant),
/// 3. otherwise iterates up to `max_iterations` times, each time solving for a
///    displacement update from the intensity mismatch against the current
///    image, until both update components fall below the convergence epsilon.
///
/// Displacements are kept in level-0 coordinates between levels. `out` is
/// cleared and refilled with one [`TrackResult`] per entry of `prev_points`, in
/// the same order. Every level overwrites a point's status, so the final status
/// is the one produced at level 0; `error` is evaluated at level 0 only.
///
/// * `predicted` - optional initial guess of each point's position in the
///   current image; `None` starts from zero displacement.
/// * `scratch` - working buffers, resized as needed and reusable across calls.
///
/// # Panics
///
/// Panics if the pyramids are empty or differ in length, if `window_size` is
/// even, or if `predicted` is given with a length different from `prev_points`.
/// Slicing the gradient buffers also panics if a pyramid level has more pixels
/// than level 0.
#[allow(clippy::too_many_arguments)]
fn track_into(
    prev_pyramid: &[GrayImage],
    curr_pyramid: &[GrayImage],
    prev_points: &[(f32, f32)],
    predicted: Option<&[(f32, f32)]>,
    window_size: usize,
    max_iterations: usize,
    min_eigen_threshold: f32,
    scratch: &mut Scratch,
    out: &mut Vec<TrackResult>,
) {
    assert_eq!(prev_pyramid.len(), curr_pyramid.len());
    assert!(
        !prev_pyramid.is_empty(),
        "pyramid must have at least 1 level"
    );
    assert!(window_size % 2 == 1, "Window size must be odd");
    if let Some(predicted) = predicted {
        assert_eq!(
            predicted.len(),
            prev_points.len(),
            "predicted must have one entry per prev_point"
        );
    }
    let n_levels = prev_pyramid.len();
    let radius = window_size / 2;
    let n_pixels = window_size * window_size;
    // Convergence tolerance on a single update step, in pixels of the current level.
    let epsilon = 1e-3;
    // Determinants with magnitude at or below this are treated as singular.
    let det_epsilon = 1e-6;
    let Scratch {
        offsets,
        prev_patch,
        ix_patch,
        iy_patch,
        displacements,
        grad_x: grad_x_buf,
        grad_y: grad_y_buf,
    } = scratch;
    build_window_offsets_into(radius, offsets);
    prev_patch.resize(n_pixels, 0.0);
    ix_patch.resize(n_pixels, 0.0);
    iy_patch.resize(n_pixels, 0.0);

    // Seed the displacement of every point, in level-0 coordinates.
    displacements.clear();
    match predicted {
        Some(predicted) => displacements.extend(
            prev_points
                .iter()
                .zip(predicted.iter())
                .map(|((px, py), (gx, gy))| (gx - px, gy - py)),
        ),
        None => displacements.resize(prev_points.len(), (0.0, 0.0)),
    }

    // Size the gradient buffers for level 0 and slice them per level below.
    // The product is taken in `usize`: a `u32` multiply would overflow for very
    // large images and leave the buffers shorter than the image they describe.
    let (w0, h0) = prev_pyramid[0].dimensions();
    let level0_pixels = w0 as usize * h0 as usize;
    grad_x_buf.resize(level0_pixels, 0);
    grad_y_buf.resize(level0_pixels, 0);

    out.clear();
    out.extend(prev_points.iter().map(|&(x, y)| TrackResult {
        pos: (x, y),
        status: TrackStatus::Tracked,
        error: f32::INFINITY,
    }));

    for level in (0..n_levels).rev() {
        // Level `k` coordinates are level-0 coordinates divided by 2^k.
        let scale = 2f32.powi(level as i32);
        let is_finest = level == 0;
        let prev_img = &prev_pyramid[level];
        let curr_img = &curr_pyramid[level];
        let (lw, lh) = prev_img.dimensions();
        let level_pixels = lw as usize * lh as usize;
        compute_gradients_into(
            prev_img,
            &mut grad_x_buf[..level_pixels],
            &mut grad_y_buf[..level_pixels],
        );
        let grad_x = &grad_x_buf[..level_pixels];
        let grad_y = &grad_y_buf[..level_pixels];

        for (idx, (prev_x, prev_y)) in prev_points.iter().enumerate() {
            let x = *prev_x / scale;
            let y = *prev_y / scale;
            let mut dx = displacements[idx].0 / scale;
            let mut dy = displacements[idx].1 / scale;

            // The template window must lie fully inside the previous image.
            // On failure the carried displacement is left untouched.
            if !in_bounds(prev_img, x, y, radius) {
                out[idx].status = TrackStatus::OutOfBounds;
                continue;
            }

            // Sample the template and its gradients; accumulate the symmetric
            // gradient matrix [[gxx, gxy], [gxy, gyy]].
            let mut gxx = 0.0f32;
            let mut gxy = 0.0f32;
            let mut gyy = 0.0f32;
            for (i, (ox, oy)) in offsets.iter().enumerate() {
                let sample_x = x + ox;
                let sample_y = y + oy;
                // NOTE(review): the division by 32 presumably undoes a fixed-point
                // scale applied by `compute_gradients_into` — confirm.
                let ix = interpolate_i16(grad_x, lw, lh, sample_x, sample_y) / 32.0;
                let iy = interpolate_i16(grad_y, lw, lh, sample_x, sample_y) / 32.0;
                prev_patch[i] = interpolate(prev_img, sample_x, sample_y);
                ix_patch[i] = ix;
                iy_patch[i] = iy;
                gxx += ix * ix;
                gxy += ix * iy;
                gyy += iy * iy;
            }

            let min_eig = min_eigenvalue(gxx, gxy, gyy) / n_pixels as f32;
            if min_eig < min_eigen_threshold {
                out[idx].status = TrackStatus::LowTexture;
                if is_finest {
                    // Still report how well the carried estimate matches.
                    out[idx].error =
                        window_error(curr_img, prev_patch, offsets, x + dx, y + dy, radius);
                }
                continue;
            }
            let Some((inv_h00, inv_h01, inv_h11)) = invert_2x2(gxx, gxy, gyy, det_epsilon) else {
                out[idx].status = TrackStatus::LowTexture;
                continue;
            };

            let mut converged = false;
            let mut out_of_bounds = false;
            let mut diverged = false;
            for _ in 0..max_iterations {
                let curr_x = x + dx;
                let curr_y = y + dy;
                if !in_bounds(curr_img, curr_x, curr_y, radius) {
                    out_of_bounds = true;
                    break;
                }
                // Mismatch vector: gradient-weighted intensity difference.
                let mut bx = 0.0f32;
                let mut by = 0.0f32;
                for (i, (ox, oy)) in offsets.iter().enumerate() {
                    let curr = interpolate(curr_img, curr_x + ox, curr_y + oy);
                    let error = prev_patch[i] - curr;
                    bx += ix_patch[i] * error;
                    by += iy_patch[i] * error;
                }
                // Update step = inverse gradient matrix times mismatch vector.
                let ddx = inv_h00 * bx + inv_h01 * by;
                let ddy = inv_h01 * bx + inv_h11 * by;
                dx += ddx;
                dy += ddy;
                if !dx.is_finite()
                    || !dy.is_finite()
                    || ddx.abs() > window_size as f32
                    || ddy.abs() > window_size as f32
                {
                    diverged = true;
                    break;
                }
                if ddx.abs() < epsilon && ddy.abs() < epsilon {
                    converged = true;
                    break;
                }
            }

            // Exhausting the iteration budget counts as divergence.
            out[idx].status = if out_of_bounds {
                TrackStatus::OutOfBounds
            } else if diverged || !converged {
                TrackStatus::Diverged
            } else {
                TrackStatus::Tracked
            };
            // NOTE(review): the displacement is stored even after divergence, so
            // a diverged point can carry a very large or non-finite estimate into
            // finer levels and into `pos`. Callers must check `status`.
            displacements[idx] = (dx * scale, dy * scale);
            if is_finest {
                out[idx].error = if out_of_bounds {
                    f32::INFINITY
                } else {
                    window_error(curr_img, prev_patch, offsets, x + dx, y + dy, radius)
                };
            }
        }
    }

    // Final position = input point + accumulated level-0 displacement.
    for (idx, (x, y)) in prev_points.iter().enumerate() {
        let (dx, dy) = displacements[idx];
        out[idx].pos = (x + dx, y + dy);
    }
}
/// Tracks `prev_points` with a forward-backward consistency check.
///
/// The points are first tracked from `prev_pyramid` into `next_pyramid`. The
/// resulting positions are then tracked back into `prev_pyramid`, using the
/// original points as the initial guess. A forward result that was
/// [`TrackStatus::Tracked`] is downgraded to [`TrackStatus::FbInconsistent`]
/// when the backward track was not `Tracked` or ended farther than
/// `fb_threshold` from the original point.
///
/// Returns the forward results, one per input point and in input order.
///
/// # Panics
///
/// Panics if the pyramids are empty or differ in length, if `window_size` is
/// even, or if `predicted` is given with a length different from `prev_points`.
#[allow(clippy::too_many_arguments)]
pub fn calc_optical_flow_fb(
    prev_pyramid: &[GrayImage],
    next_pyramid: &[GrayImage],
    prev_points: &[(f32, f32)],
    predicted: Option<&[(f32, f32)]>,
    window_size: usize,
    max_iterations: usize,
    min_eigen_threshold: f32,
    fb_threshold: f32,
) -> Vec<TrackResult> {
    // One set of working buffers serves both passes.
    let mut workspace = Scratch::default();
    let mut results = Vec::new();
    let mut reverse = Vec::new();

    // Forward pass: prev -> next.
    track_into(
        prev_pyramid,
        next_pyramid,
        prev_points,
        predicted,
        window_size,
        max_iterations,
        min_eigen_threshold,
        &mut workspace,
        &mut results,
    );

    // Backward pass: next -> prev, starting where the forward pass ended and
    // using the original points as the initial guess.
    let landed: Vec<(f32, f32)> = results.iter().map(|result| result.pos).collect();
    track_into(
        next_pyramid,
        prev_pyramid,
        &landed,
        Some(prev_points),
        window_size,
        max_iterations,
        min_eigen_threshold,
        &mut workspace,
        &mut reverse,
    );

    mark_fb_inconsistent(&mut results, &reverse, prev_points, fb_threshold);
    results
}
/// Downgrades forward results that fail the forward-backward check.
///
/// Only entries of `forward` whose status is [`TrackStatus::Tracked`] are
/// examined. Such an entry becomes [`TrackStatus::FbInconsistent`] when the
/// matching `backward` entry is not `Tracked`, or when its position lies
/// farther than `fb_threshold` (Euclidean distance) from the matching entry of
/// `prev_points`.
///
/// # Panics
///
/// Panics if `backward` or `prev_points` is too short for the index of a
/// tracked entry in `forward`.
fn mark_fb_inconsistent(
    forward: &mut [TrackResult],
    backward: &[TrackResult],
    prev_points: &[(f32, f32)],
    fb_threshold: f32,
) {
    // Compare squared distances to avoid a square root per point.
    let max_dist_sq = fb_threshold * fb_threshold;
    let still_tracked = forward
        .iter_mut()
        .enumerate()
        .filter(|(_, result)| result.status == TrackStatus::Tracked);
    for (idx, result) in still_tracked {
        let returned = &backward[idx];
        let (origin_x, origin_y) = prev_points[idx];
        let off_x = returned.pos.0 - origin_x;
        let off_y = returned.pos.1 - origin_y;
        let drifted = off_x * off_x + off_y * off_y > max_dist_sq;
        if returned.status != TrackStatus::Tracked || drifted {
            result.status = TrackStatus::FbInconsistent;
        }
    }
}
/// Returns the smaller eigenvalue of the symmetric 2x2 matrix `[[a, b], [b, c]]`.
///
/// Uses the closed form `(trace - sqrt(trace^2 - 4 * det)) / 2`. The radicand is
/// clamped at zero so rounding error cannot produce a NaN.
fn min_eigenvalue(a: f32, b: f32, c: f32) -> f32 {
    let sum = a + c;
    let determinant = a * c - b * b;
    let radicand = sum * sum - 4.0 * determinant;
    let root = radicand.max(0.0).sqrt();
    (sum - root) / 2.0
}
/// Mean absolute intensity difference between a stored template window and the
/// window of `img` centred at `(cx, cy)`.
///
/// `prev_patch[i]` is compared with `img` sampled at the centre plus
/// `offsets[i]`. Returns `f32::INFINITY` when the window around `(cx, cy)` does
/// not lie fully inside `img`.
///
/// # Panics
///
/// Panics if `prev_patch` is shorter than `offsets`.
fn window_error(
    img: &GrayImage,
    prev_patch: &[f32],
    offsets: &[(f32, f32)],
    cx: f32,
    cy: f32,
    radius: usize,
) -> f32 {
    if in_bounds(img, cx, cy, radius) {
        let total = offsets
            .iter()
            .enumerate()
            .fold(0.0f32, |acc, (i, &(ox, oy))| {
                let sampled = interpolate(img, cx + ox, cy + oy);
                acc + (prev_patch[i] - sampled).abs()
            });
        total / offsets.len() as f32
    } else {
        f32::INFINITY
    }
}
/// Fills `offsets` with the `(dx, dy)` offsets of a square window of side
/// `2 * radius + 1`, centred on the origin.
///
/// Any previous contents are discarded. Offsets are emitted row by row: `dy`
/// runs from `-radius` to `radius` in the outer order and `dx` does the same
/// within each row.
fn build_window_offsets_into(radius: usize, offsets: &mut Vec<(f32, f32)>) {
    let reach = radius as i32;
    let side = 2 * radius + 1;
    offsets.clear();
    offsets.reserve(side * side);
    offsets.extend(
        (-reach..=reach)
            .flat_map(move |row| (-reach..=reach).map(move |col| (col as f32, row as f32))),
    );
}
/// Inverts the symmetric 2x2 matrix `[[a00, a01], [a01, a11]]`.
///
/// Returns the three distinct entries of the inverse as `(inv00, inv01, inv11)`,
/// or `None` when the magnitude of the determinant is at or below `det_epsilon`.
fn invert_2x2(a00: f32, a01: f32, a11: f32, det_epsilon: f32) -> Option<(f32, f32, f32)> {
    let determinant = a00 * a11 - a01 * a01;
    let near_singular = determinant.abs() <= det_epsilon;
    if near_singular {
        None
    } else {
        // inverse = adjugate / determinant
        let reciprocal = 1.0 / determinant;
        Some((a11 * reciprocal, -a01 * reciprocal, a00 * reciprocal))
    }
}
/// Reports whether a window of the given `radius` centred at `(x, y)` keeps its
/// centre at least `radius` pixels away from every image border.
///
/// The accepted range is `radius <= x < width - radius` and likewise for `y`.
/// A NaN coordinate is never in bounds.
fn in_bounds(img: &GrayImage, x: f32, y: f32, radius: usize) -> bool {
    let margin = radius as f32;
    let x_limit = img.width() as f32 - margin;
    let y_limit = img.height() as f32 - margin;
    (margin..x_limit).contains(&x) && (margin..y_limit).contains(&y)
}
/// Bilinearly samples `img` at the sub-pixel position `(x, y)`.
///
/// Pixels outside the image contribute an intensity of zero, so positions on or
/// beyond the border fade towards zero instead of failing.
///
/// Index arithmetic is done in `i64`. The float-to-`i32` cast saturates for
/// far-away coordinates, and computing `x0 + 1` in `i32` would then overflow:
/// a panic in debug builds, and in release builds a wrapped value that passes
/// the range check and reaches the unchecked reads.
fn interpolate(img: &GrayImage, x: f32, y: f32) -> f32 {
    let w = i64::from(img.width());
    let h = i64::from(img.height());
    let stride = img.width() as usize;
    // The saturating cast keeps the value within `i32`, so `+ 1` cannot
    // overflow once widened.
    let x0 = i64::from(x.floor() as i32);
    let y0 = i64::from(y.floor() as i32);
    let dx = x - x0 as f32;
    let dy = y - y0 as f32;
    let data = img.as_raw();

    // Fast path: all four neighbours are inside the image.
    if x0 >= 0 && y0 >= 0 && x0 + 1 < w && y0 + 1 < h {
        let base = y0 as usize * stride + x0 as usize;
        // SAFETY: `0 <= x0`, `x0 + 1 < width`, `0 <= y0` and `y0 + 1 < height`
        // were just checked, so the largest index, `base + stride + 1`, equals
        // `(y0 + 1) * width + (x0 + 1) < width * height`. The raw buffer of a
        // `GrayImage` holds at least `width * height` samples.
        let (p00, p10, p01, p11) = unsafe {
            (
                *data.get_unchecked(base) as f32,
                *data.get_unchecked(base + 1) as f32,
                *data.get_unchecked(base + stride) as f32,
                *data.get_unchecked(base + stride + 1) as f32,
            )
        };
        return p00 * (1.0 - dx) * (1.0 - dy)
            + p01 * (1.0 - dx) * dy
            + p10 * dx * (1.0 - dy)
            + p11 * dx * dy;
    }

    // Slow path: at least one neighbour is outside; treat it as zero.
    let x1 = x0 + 1;
    let y1 = y0 + 1;
    let mut sum = 0.0;
    for (sx, sy) in &[(x0, y0), (x0, y1), (x1, y0), (x1, y1)] {
        let px = if *sx >= 0 && *sy >= 0 && *sx < w && *sy < h {
            data[*sy as usize * stride + *sx as usize] as f32
        } else {
            0.0
        };
        let wx = if sx == &x0 { 1.0 - dx } else { dx };
        let wy = if sy == &y0 { 1.0 - dy } else { dy };
        sum += px * wx * wy;
    }
    sum
}
/// Bilinearly samples a row-major `width` x `height` grid of `i16` values at the
/// sub-pixel position `(x, y)`.
///
/// Cells outside the grid contribute zero, so positions on or beyond the border
/// fade towards zero instead of failing.
///
/// Index arithmetic is done in `i64`. The float-to-`i32` cast saturates for
/// far-away coordinates, and computing `x0 + 1` in `i32` would then overflow:
/// a panic in debug builds, and in release builds a wrapped value that passes
/// the range check and reaches the unchecked reads.
///
/// # Panics
///
/// Panics if `data` is too short for the cells being read, i.e. when it holds
/// fewer than `width * height` values and the sample touches the missing part.
fn interpolate_i16(data: &[i16], width: u32, height: u32, x: f32, y: f32) -> f32 {
    let w = i64::from(width);
    let h = i64::from(height);
    let stride = width as usize;
    // The saturating cast keeps the value within `i32`, so `+ 1` cannot
    // overflow once widened.
    let x0 = i64::from(x.floor() as i32);
    let y0 = i64::from(y.floor() as i32);
    let dx = x - x0 as f32;
    let dy = y - y0 as f32;

    // Fast path: all four neighbours are inside the grid.
    if x0 >= 0 && y0 >= 0 && x0 + 1 < w && y0 + 1 < h {
        let base = y0 as usize * stride + x0 as usize;
        // `width`/`height` are caller-supplied and independent of `data`, so the
        // largest index must be validated before reading without bounds checks.
        assert!(
            base + stride + 1 < data.len(),
            "gradient buffer is smaller than width * height"
        );
        // SAFETY: the assertion above guarantees `base + stride + 1`, the
        // largest of the four indices, is within `data`.
        let (p00, p10, p01, p11) = unsafe {
            (
                *data.get_unchecked(base) as f32,
                *data.get_unchecked(base + 1) as f32,
                *data.get_unchecked(base + stride) as f32,
                *data.get_unchecked(base + stride + 1) as f32,
            )
        };
        return p00 * (1.0 - dx) * (1.0 - dy)
            + p01 * (1.0 - dx) * dy
            + p10 * dx * (1.0 - dy)
            + p11 * dx * dy;
    }

    // Slow path: at least one neighbour is outside; treat it as zero.
    let x1 = x0 + 1;
    let y1 = y0 + 1;
    let mut sum = 0.0;
    for (sx, sy) in &[(x0, y0), (x0, y1), (x1, y0), (x1, y1)] {
        let px = if *sx >= 0 && *sy >= 0 && *sx < w && *sy < h {
            data[*sy as usize * stride + *sx as usize] as f32
        } else {
            0.0
        };
        let wx = if sx == &x0 { 1.0 - dx } else { dx };
        let wy = if sy == &y0 { 1.0 - dy } else { dy };
        sum += px * wx * wy;
    }
    sum
}
/// Reusable tracking state: image pyramids for a frame pair plus every working
/// and output buffer the tracker needs, so repeated calls can reuse allocations.
#[derive(Default)]
pub struct TrackerContext {
/// Pyramid of the previous frame, filled by `prepare`.
prev_pyramid: Vec<GrayImage>,
/// Pyramid of the next frame, filled by `prepare`.
next_pyramid: Vec<GrayImage>,
/// Working buffers handed to the tracker on every call.
scratch: Scratch,
/// Results of the most recent `track` / `track_fb` call (the forward pass).
results: Vec<TrackResult>,
/// Forward-pass positions, used as start points of the backward pass in
/// `track_fb`.
forward_pos: Vec<(f32, f32)>,
/// Results of the backward pass in `track_fb`.
backward: Vec<TrackResult>,
}
impl TrackerContext {
    /// Creates a context with empty pyramids and buffers.
    pub fn new() -> Self {
        Default::default()
    }

    /// Builds `levels`-level pyramids for the `prev` and `next` frames, storing
    /// them in this context for subsequent `track` / `track_fb` calls.
    pub fn prepare(&mut self, prev: &GrayImage, next: &GrayImage, levels: usize) {
        let Self {
            prev_pyramid,
            next_pyramid,
            ..
        } = self;
        build_pyramid_into(prev, levels, prev_pyramid);
        build_pyramid_into(next, levels, next_pyramid);
    }

    /// Pyramid of the previous frame as built by the last `prepare` call.
    pub fn prev_pyramid(&self) -> &[GrayImage] {
        self.prev_pyramid.as_slice()
    }

    /// Pyramid of the next frame as built by the last `prepare` call.
    pub fn next_pyramid(&self) -> &[GrayImage] {
        self.next_pyramid.as_slice()
    }

    /// Tracks `prev_points` from the stored previous pyramid into the stored
    /// next pyramid.
    ///
    /// Returns one result per input point, in input order. The slice borrows
    /// this context's result buffer and is overwritten by the next call.
    ///
    /// # Panics
    ///
    /// Panics if the stored pyramids are empty or differ in length, if
    /// `window_size` is even, or if `predicted` is given with a length
    /// different from `prev_points`.
    pub fn track(
        &mut self,
        prev_points: &[(f32, f32)],
        predicted: Option<&[(f32, f32)]>,
        window_size: usize,
        max_iterations: usize,
        min_eigen_threshold: f32,
    ) -> &[TrackResult] {
        let Self {
            prev_pyramid,
            next_pyramid,
            scratch,
            results,
            ..
        } = self;
        track_into(
            prev_pyramid.as_slice(),
            next_pyramid.as_slice(),
            prev_points,
            predicted,
            window_size,
            max_iterations,
            min_eigen_threshold,
            scratch,
            results,
        );
        results.as_slice()
    }

    /// Tracks `prev_points` forward and then backward, marking forward results
    /// that fail the forward-backward check as `TrackStatus::FbInconsistent`.
    ///
    /// The backward pass starts from the forward positions and uses
    /// `prev_points` as its initial guess. Returns the forward results, one per
    /// input point and in input order; the slice borrows this context's result
    /// buffer and is overwritten by the next call.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as `track`.
    pub fn track_fb(
        &mut self,
        prev_points: &[(f32, f32)],
        predicted: Option<&[(f32, f32)]>,
        window_size: usize,
        max_iterations: usize,
        min_eigen_threshold: f32,
        fb_threshold: f32,
    ) -> &[TrackResult] {
        let Self {
            prev_pyramid,
            next_pyramid,
            scratch,
            results,
            forward_pos,
            backward,
        } = self;

        // Forward pass: prev -> next.
        track_into(
            prev_pyramid.as_slice(),
            next_pyramid.as_slice(),
            prev_points,
            predicted,
            window_size,
            max_iterations,
            min_eigen_threshold,
            scratch,
            results,
        );

        // Backward pass: next -> prev, starting where the forward pass ended.
        forward_pos.clear();
        forward_pos.extend(results.iter().map(|result| result.pos));
        track_into(
            next_pyramid.as_slice(),
            prev_pyramid.as_slice(),
            forward_pos.as_slice(),
            Some(prev_points),
            window_size,
            max_iterations,
            min_eigen_threshold,
            scratch,
            backward,
        );

        mark_fb_inconsistent(
            results.as_mut_slice(),
            backward.as_slice(),
            prev_points,
            fb_threshold,
        );
        results.as_slice()
    }
}
#[cfg(test)]
mod tests {
    use super::invert_2x2;

    /// The inverse of [[4, 1], [1, 3]] is (1 / 11) * [[3, -1], [-1, 4]].
    #[test]
    fn invert_2x2_returns_inverse_components() {
        let inverse = invert_2x2(4.0, 1.0, 3.0, 1e-6);
        let (inv00, inv01, inv11) = inverse.unwrap();
        let comparisons = [
            (inv00, 3.0 / 11.0),
            (inv01, -1.0 / 11.0),
            (inv11, 4.0 / 11.0),
        ];
        for (actual, expected) in comparisons.iter() {
            assert!((actual - expected).abs() < 1e-6);
        }
    }

    /// [[1, 2], [2, 4]] has a zero determinant and must be rejected.
    #[test]
    fn invert_2x2_rejects_singular_matrix() {
        let inverse = invert_2x2(1.0, 2.0, 4.0, 1e-6);
        assert!(inverse.is_none());
    }
}