use super::motion_compensation::apply_luma_mv_block;
use super::reference_buffer::ReconFrame;
use super::transform::forward_hadamard_4x4;
/// Motion vector in quarter-pel units (4 units = 1 full luma pixel).
///
/// Positive components reference pixels to the right of / below the block.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct MotionVector {
    /// Horizontal displacement in quarter-pels.
    pub mv_x: i16,
    /// Vertical displacement in quarter-pels.
    pub mv_y: i16,
}

impl MotionVector {
    /// The null displacement.
    pub const ZERO: MotionVector = MotionVector { mv_x: 0, mv_y: 0 };
}

/// Outcome of a motion search for one block.
#[derive(Clone, Copy, Debug)]
pub struct MotionSearchResult {
    /// Selected motion vector, in quarter-pel units.
    pub mv: MotionVector,
    /// Distortion reported by the search for `mv`.
    pub cost: u32,
}
/// Block-based luma motion estimator.
///
/// The type holds no state; the private unit field only prevents
/// construction by struct literal outside this module.
#[derive(Default, Debug)]
pub struct MotionEstimator {
    _private: (),
}
impl MotionEstimator {
    /// Creates an estimator. It carries no state.
    pub fn new() -> Self {
        Self { _private: () }
    }

    /// Searches one luma block using `predicted_mv` as the only seed
    /// candidate; see [`Self::search_block_with_candidates`].
    #[allow(clippy::too_many_arguments)]
    pub fn search_block(
        &mut self,
        source: &[u8],
        source_stride: usize,
        reference: &ReconFrame,
        block_x: u32,
        block_y: u32,
        block_w: u32,
        block_h: u32,
        predicted_mv: MotionVector,
    ) -> MotionSearchResult {
        self.search_block_with_candidates(
            source, source_stride, reference, block_x, block_y, block_w, block_h,
            predicted_mv, &[predicted_mv],
        )
    }

    /// Searches for the best quarter-pel motion vector of one luma block.
    ///
    /// `source` holds the block's pixels with row stride `source_stride`;
    /// the block sits at (`block_x`, `block_y`) with size `block_w`x`block_h`.
    ///
    /// Stages:
    /// 1. Seed: each of `candidates` is floored to full-pel, clipped to the
    ///    frame and scored with SATD + `lambda * mv_bit_cost(.., predicted_mv)`;
    ///    the cheapest wins. With no candidates, the full-pel, clipped
    ///    `predicted_mv` is the seed.
    /// 2. Unless `PHASM_ME_UMH=0`: cross search (±16 px) then multi-hexagon.
    /// 3. Full-pel hexagon search.
    /// 4. Half-pel (step 2) then quarter-pel (step 1) refinement.
    ///
    /// The returned `cost` is the plain SATD at the chosen vector; the MV
    /// rate term used while searching is not included.
    #[allow(clippy::too_many_arguments)]
    pub fn search_block_with_candidates(
        &mut self,
        source: &[u8],
        source_stride: usize,
        reference: &ReconFrame,
        block_x: u32,
        block_y: u32,
        block_w: u32,
        block_h: u32,
        predicted_mv: MotionVector,
        candidates: &[MotionVector],
    ) -> MotionSearchResult {
        let lambda = me_lambda();
        // Quarter-pel units: clearing the two low bits floors to full-pel
        // (identical to `(v >> 2) << 2` for two's-complement i16).
        let to_fullpel = |mv: MotionVector| MotionVector {
            mv_x: mv.mv_x & !3,
            mv_y: mv.mv_y & !3,
        };
        let clip = |mv: MotionVector| {
            clip_mv_to_frame(mv, reference, block_x, block_y, block_w, block_h)
        };
        let satd = |mv: MotionVector| {
            satd_at_mv(source, source_stride, reference, block_x, block_y, block_w, block_h, mv)
        };

        // The fallback is floored like the candidates so the integer stages
        // always start on the full-pel grid.
        let mut start = clip(to_fullpel(predicted_mv));
        let mut start_cost = u32::MAX;
        for &candidate in candidates {
            let candidate = clip(to_fullpel(candidate));
            let cost = satd(candidate) + lambda * mv_bit_cost(candidate, predicted_mv);
            if cost < start_cost {
                start_cost = cost;
                start = candidate;
            }
        }

        let integer_seed = if umh_enabled() {
            let after_cross = cross_search_r16(
                source, source_stride, reference, block_x, block_y, block_w, block_h,
                start, predicted_mv, lambda,
            );
            multi_hex_search(
                source, source_stride, reference, block_x, block_y, block_w, block_h,
                after_cross, predicted_mv, lambda,
            )
        } else {
            start
        };
        let integer_mv = integer_hex_search(
            source, source_stride, reference, block_x, block_y, block_w, block_h,
            integer_seed, predicted_mv, lambda,
        );
        let halfpel_mv = refine_5point(
            source, source_stride, reference, block_x, block_y, block_w, block_h,
            integer_mv, 2, predicted_mv, lambda,
        );
        let qpel_mv = refine_5point(
            source, source_stride, reference, block_x, block_y, block_w, block_h,
            halfpel_mv, 1, predicted_mv, lambda,
        );
        MotionSearchResult { mv: qpel_mv, cost: satd(qpel_mv) }
    }

    /// Searches a 16x16 block supplied as a row-major pixel array.
    pub fn search_16x16(
        &mut self,
        source: &[[u8; 16]; 16],
        reference: &ReconFrame,
        block_x: u32,
        block_y: u32,
        predicted_mv: MotionVector,
    ) -> MotionSearchResult {
        self.search_block(
            source.as_flattened(),
            16,
            reference,
            block_x,
            block_y,
            16,
            16,
            predicted_mv,
        )
    }
}
/// Sum of absolute differences between a `block_w`x`block_h` region of
/// `source` and of `pred`, each addressed with its own row stride.
///
/// Routed through the SIMD dispatcher, which is handed the portable scalar
/// routine (presumably used when no SIMD path applies).
fn sad_block(
    source: &[u8],
    source_stride: usize,
    pred: &[u8],
    pred_stride: usize,
    block_w: u32,
    block_h: u32,
) -> u32 {
    let scalar_fallback =
        || sad_block_scalar(source, source_stride, pred, pred_stride, block_w, block_h);
    super::simd::sad_block_dispatch(
        source, source_stride, pred, pred_stride, block_w, block_h, scalar_fallback,
    )
}
/// Portable SAD over a `block_w`x`block_h` region.
///
/// Pixel `(col, row)` is read from `source[row * source_stride + col]` and
/// `pred[row * pred_stride + col]`; panics if either slice is too short.
#[inline]
fn sad_block_scalar(
    source: &[u8],
    source_stride: usize,
    pred: &[u8],
    pred_stride: usize,
    block_w: u32,
    block_h: u32,
) -> u32 {
    let (width, height) = (block_w as usize, block_h as usize);
    (0..height)
        .flat_map(|row| (0..width).map(move |col| (row, col)))
        .map(|(row, col)| {
            let a = source[row * source_stride + col];
            let b = pred[row * pred_stride + col];
            u32::from(a.abs_diff(b))
        })
        .sum()
}
/// SATD (sum of absolute Hadamard-transformed differences) between a
/// `block_w`x`block_h` region of `source` and of `pred`.
///
/// Both dimensions are expected to be multiples of 4 (checked only in debug
/// builds). Work goes through the SIMD dispatcher, which is handed the
/// scalar routine.
fn satd_block(
    source: &[u8],
    source_stride: usize,
    pred: &[u8],
    pred_stride: usize,
    block_w: u32,
    block_h: u32,
) -> u32 {
    debug_assert!(block_w.is_multiple_of(4) && block_h.is_multiple_of(4));
    let scalar_fallback =
        || satd_block_scalar(source, source_stride, pred, pred_stride, block_w, block_h);
    super::simd::satd_block_dispatch(
        source, source_stride, pred, pred_stride, block_w, block_h, scalar_fallback,
    )
}
/// Portable SATD.
///
/// Splits the region into 4x4 tiles, transforms each tile's residual
/// (`source - pred`) with `forward_hadamard_4x4`, and accumulates the
/// absolute coefficients with saturating addition. Any partial tile at the
/// right/bottom edge (dimension not a multiple of 4) is ignored.
#[inline]
fn satd_block_scalar(
    source: &[u8],
    source_stride: usize,
    pred: &[u8],
    pred_stride: usize,
    block_w: u32,
    block_h: u32,
) -> u32 {
    let tiles_x = (block_w / 4) as usize;
    let tiles_y = (block_h / 4) as usize;
    let mut acc: u32 = 0;
    for tile_row in 0..tiles_y {
        let y0 = tile_row * 4;
        for tile_col in 0..tiles_x {
            let x0 = tile_col * 4;
            let mut diff = [[0i32; 4]; 4];
            for (dy, diff_row) in diff.iter_mut().enumerate() {
                let src_base = (y0 + dy) * source_stride + x0;
                let pred_base = (y0 + dy) * pred_stride + x0;
                for (dx, cell) in diff_row.iter_mut().enumerate() {
                    *cell = i32::from(source[src_base + dx]) - i32::from(pred[pred_base + dx]);
                }
            }
            let coeffs = forward_hadamard_4x4(&diff);
            acc = coeffs
                .iter()
                .flatten()
                .fold(acc, |sum, c| sum.saturating_add(c.unsigned_abs()));
        }
    }
    acc
}
/// Large hexagon probed by the full-pel search, as full-pel offsets
/// (scaled by 4 to quarter-pels at use). Order matters: earlier points win ties.
const HEX_PATTERN: [(i16, i16); 6] = [(-2, 0), (-1, -2), (1, -2), (2, 0), (1, 2), (-1, 2)];

/// Upper bound on hexagon re-centering rounds.
const MAX_HEX_ITER: usize = 16;

/// Four probes 16 full pixels away along each axis (left, right, up, down)
/// for the coarse UMH cross step.
const CROSS_PATTERN_R16: [(i16, i16); 4] = [(-16, 0), (16, 0), (0, -16), (0, 16)];

/// The 16 border points of the 5x5 grid of even full-pel offsets in
/// `[-4, 4]`, listed row by row from the top.
const MULTI_HEX_PATTERN: [(i16, i16); 16] = [
    (-4, -4),
    (-2, -4),
    (0, -4),
    (2, -4),
    (4, -4),
    (-4, -2),
    (4, -2),
    (-4, 0),
    (4, 0),
    (-4, 2),
    (4, 2),
    (-4, 4),
    (-2, 4),
    (0, 4),
    (2, 4),
    (4, 4),
];

/// Upper bound on multi-hexagon re-centering rounds.
const MAX_MULTI_HEX_ITER: usize = 4;
/// Rate weight applied to MV bit counts when `PHASM_ME_LAMBDA` is not usable.
const LAMBDA_MOTION_DEFAULT: u32 = 1;

/// Motion-search lambda: `PHASM_ME_LAMBDA` parsed as `u32` and clamped to
/// `1..=32`. Falls back to [`LAMBDA_MOTION_DEFAULT`] when the variable is
/// unset, not Unicode, or not a valid `u32`. Re-read on every call.
#[inline]
fn me_lambda() -> u32 {
    match std::env::var("PHASM_ME_LAMBDA").map(|raw| raw.parse::<u32>()) {
        Ok(Ok(value)) => value.clamp(1, 32),
        _ => LAMBDA_MOTION_DEFAULT,
    }
}
/// Length in bits of the signed Exp-Golomb code `se(v)` for `d`.
///
/// `se(v)` maps `d` to `codeNum = 2|d| - 1` (d > 0) or `2|d|` (d <= 0) and
/// writes it as `ue(v)`, which takes `2 * floor(log2(codeNum + 1)) + 1` bits.
/// For `d != 0`, `codeNum + 1` is `2|d|` or `2|d| + 1`, which have the same
/// `floor(log2)` (an odd number > 1 is never a power of two), so `2|d| + 1`
/// can be used for both signs. It also yields the correct 1 bit for `d == 0`.
/// The arithmetic is done in `u64` so `d == i32::MIN` cannot overflow.
#[inline]
fn se_bits(d: i32) -> u32 {
    let code_num_plus_1 = 2 * u64::from(d.unsigned_abs()) + 1;
    2 * (63 - code_num_plus_1.leading_zeros()) + 1
}
/// Rate estimate, in bits, of signalling `mv` relative to `predictor`: the sum of
/// `se_bits` over each component's difference (quarter-pel units).
#[inline]
fn mv_bit_cost(mv: MotionVector, predictor: MotionVector) -> u32 {
    let horizontal = se_bits(i32::from(mv.mv_x) - i32::from(predictor.mv_x));
    let vertical = se_bits(i32::from(mv.mv_y) - i32::from(predictor.mv_y));
    horizontal + vertical
}
/// Coarse UMH cross step.
///
/// Scores `seed` and the four `CROSS_PATTERN_R16` probes (16 full pixels
/// away along each axis, clipped to the frame) by SAD + `lambda * mv_bit_cost`,
/// and returns the cheapest. On ties the earlier point wins (seed first).
/// The seed itself is scored as given, without clipping.
#[allow(clippy::too_many_arguments)]
fn cross_search_r16(
    source: &[u8],
    source_stride: usize,
    reference: &ReconFrame,
    block_x: u32,
    block_y: u32,
    block_w: u32,
    block_h: u32,
    seed: MotionVector,
    predictor: MotionVector,
    lambda: u32,
) -> MotionVector {
    let score = |mv: MotionVector| {
        sad_at_mv(source, source_stride, reference, block_x, block_y, block_w, block_h, mv)
            + lambda * mv_bit_cost(mv, predictor)
    };
    let mut best = (seed, score(seed));
    for &(dx, dy) in CROSS_PATTERN_R16.iter() {
        let probe = clip_mv_to_frame(
            MotionVector {
                mv_x: seed.mv_x + dx * 4,
                mv_y: seed.mv_y + dy * 4,
            },
            reference,
            block_x,
            block_y,
            block_w,
            block_h,
        );
        let cost = score(probe);
        if cost < best.1 {
            best = (probe, cost);
        }
    }
    best.0
}
/// Iterated multi-hexagon stage of the UMH search.
///
/// Each round probes the 16 `MULTI_HEX_PATTERN` offsets (full pixels,
/// clipped to the frame) around the current center and moves to the
/// cheapest strictly better point, by SAD + `lambda * mv_bit_cost`. Stops when the
/// center wins a round or after `MAX_MULTI_HEX_ITER` rounds.
#[allow(clippy::too_many_arguments)]
fn multi_hex_search(
    source: &[u8],
    source_stride: usize,
    reference: &ReconFrame,
    block_x: u32,
    block_y: u32,
    block_w: u32,
    block_h: u32,
    seed: MotionVector,
    predictor: MotionVector,
    lambda: u32,
) -> MotionVector {
    let score = |mv: MotionVector| {
        sad_at_mv(source, source_stride, reference, block_x, block_y, block_w, block_h, mv)
            + lambda * mv_bit_cost(mv, predictor)
    };
    let (mut center, mut center_cost) = (seed, score(seed));
    for _round in 0..MAX_MULTI_HEX_ITER {
        let (winner, winner_cost) = MULTI_HEX_PATTERN.iter().fold(
            (center, center_cost),
            |(best_mv, best_cost), &(dx, dy)| {
                let probe = clip_mv_to_frame(
                    MotionVector {
                        mv_x: center.mv_x + dx * 4,
                        mv_y: center.mv_y + dy * 4,
                    },
                    reference,
                    block_x,
                    block_y,
                    block_w,
                    block_h,
                );
                let cost = score(probe);
                if cost < best_cost {
                    (probe, cost)
                } else {
                    (best_mv, best_cost)
                }
            },
        );
        if winner == center {
            break;
        }
        center = winner;
        center_cost = winner_cost;
    }
    center
}
/// Whether the UMH coarse stages run: disabled only when `PHASM_ME_UMH` is
/// exactly `"0"`. Unset or non-Unicode values leave it enabled.
#[inline]
fn umh_enabled() -> bool {
    !matches!(std::env::var("PHASM_ME_UMH").as_deref(), Ok("0"))
}
/// Full-pel hexagon search followed by a one-pixel square refinement.
///
/// `start_mv` is floored to the full-pel grid. The six `HEX_PATTERN` points
/// around the center (clipped to the frame) are then probed repeatedly,
/// moving to the cheapest strictly better point by SAD + `lambda * mv_bit_cost`,
/// until the center wins or `MAX_HEX_ITER` rounds elapse.
///
/// Every hexagon offset has an even vertical component, so the hexagon
/// lattice never contains the four one-pixel axis neighbours. This ignores
/// frame-edge clipping. The later half-pel (±2 qpel) and quarter-pel
/// (±1 qpel) refinements move at most 3 quarter-pels, so without the final
/// 8-neighbour square pass a one-pixel error could never be corrected.
#[allow(clippy::too_many_arguments)]
fn integer_hex_search(
    source: &[u8],
    source_stride: usize,
    reference: &ReconFrame,
    block_x: u32,
    block_y: u32,
    block_w: u32,
    block_h: u32,
    start_mv: MotionVector,
    predictor: MotionVector,
    lambda: u32,
) -> MotionVector {
    // 8-neighbour square (full-pel offsets) for the final refinement.
    const SQUARE_PATTERN: [(i16, i16); 8] = [
        (-1, -1),
        (0, -1),
        (1, -1),
        (-1, 0),
        (1, 0),
        (-1, 1),
        (0, 1),
        (1, 1),
    ];

    let cost_at = |mv: MotionVector| {
        sad_at_mv(source, source_stride, reference, block_x, block_y, block_w, block_h, mv)
            + lambda * mv_bit_cost(mv, predictor)
    };
    // Cheapest of `center` and its clipped `pattern` neighbours; ties keep
    // the earlier point, so `center` wins on equality.
    let best_around = |center: MotionVector, center_cost: u32, pattern: &[(i16, i16)]| {
        let mut best = (center, center_cost);
        for &(dx, dy) in pattern {
            let candidate = clip_mv_to_frame(
                MotionVector {
                    mv_x: center.mv_x + dx * 4,
                    mv_y: center.mv_y + dy * 4,
                },
                reference,
                block_x,
                block_y,
                block_w,
                block_h,
            );
            let cost = cost_at(candidate);
            if cost < best.1 {
                best = (candidate, cost);
            }
        }
        best
    };

    let mut center = MotionVector {
        mv_x: (start_mv.mv_x >> 2) << 2,
        mv_y: (start_mv.mv_y >> 2) << 2,
    };
    let mut center_cost = cost_at(center);
    for _ in 0..MAX_HEX_ITER {
        let (best_mv, best_cost) = best_around(center, center_cost, &HEX_PATTERN[..]);
        if best_mv == center {
            break;
        }
        center = best_mv;
        center_cost = best_cost;
    }
    // Reach the one-pixel neighbours the hexagon lattice cannot.
    best_around(center, center_cost, &SQUARE_PATTERN[..]).0
}
/// Whether sub-pel refinement also probes diagonal neighbours: true whenever
/// `PHASM_ME_DIAMOND` is present in the environment, regardless of its value.
#[inline]
fn diamond_enabled() -> bool {
    matches!(std::env::var_os("PHASM_ME_DIAMOND"), Some(_))
}
/// Sub-pel refinement around `center` with a step of `step_qpel` quarter-pels.
///
/// Probes `center` and its four axis neighbours at distance `step_qpel`;
/// when `PHASM_ME_DIAMOND` is set, also the four diagonal neighbours. Each
/// probe is clipped to the frame and scored by SATD + `lambda * mv_bit_cost`;
/// the cheapest wins, with earlier probes winning ties (center first).
#[allow(clippy::too_many_arguments)]
fn refine_5point(
    source: &[u8],
    source_stride: usize,
    reference: &ReconFrame,
    block_x: u32,
    block_y: u32,
    block_w: u32,
    block_h: u32,
    center: MotionVector,
    step_qpel: i16,
    predictor: MotionVector,
    lambda: u32,
) -> MotionVector {
    let s = step_qpel;
    // Center, then the four axis neighbours, then the four diagonals.
    let offsets: [(i16, i16); 9] = [
        (0, 0),
        (-s, 0),
        (s, 0),
        (0, -s),
        (0, s),
        (-s, -s),
        (s, -s),
        (-s, s),
        (s, s),
    ];
    // Skipping the diagonals when disabled avoids evaluating extra probes that
    // could never change the result.
    let probe_count = if diamond_enabled() { offsets.len() } else { 5 };

    let mut best_mv = center;
    let mut best_cost = u32::MAX;
    for &(dx, dy) in &offsets[..probe_count] {
        let candidate = clip_mv_to_frame(
            MotionVector {
                mv_x: center.mv_x + dx,
                mv_y: center.mv_y + dy,
            },
            reference,
            block_x,
            block_y,
            block_w,
            block_h,
        );
        let cost = satd_at_mv(
            source, source_stride, reference, block_x, block_y, block_w, block_h, candidate,
        ) + lambda * mv_bit_cost(candidate, predictor);
        if cost < best_cost {
            best_cost = cost;
            best_mv = candidate;
        }
    }
    best_mv
}
/// Capacity of the on-stack prediction buffer: one 16x16 luma block.
const MAX_PRED_AREA: usize = 256;

/// Motion-compensates the `block_w`x`block_h` block at (`block_x`, `block_y`)
/// by `mv` into the front of `storage`, packed with row stride `block_w`,
/// and returns the filled prefix.
///
/// # Panics
/// If `block_w * block_h` exceeds [`MAX_PRED_AREA`].
#[allow(clippy::too_many_arguments)]
fn predict_at_mv<'a>(
    reference: &ReconFrame,
    block_x: u32,
    block_y: u32,
    block_w: u32,
    block_h: u32,
    mv: MotionVector,
    storage: &'a mut [u8; MAX_PRED_AREA],
) -> &'a [u8] {
    let used = (block_w * block_h) as usize;
    assert!(
        used <= MAX_PRED_AREA,
        "{}x{} block exceeds the {}-pixel prediction buffer",
        block_w,
        block_h,
        MAX_PRED_AREA
    );
    let pred = &mut storage[..used];
    apply_luma_mv_block(
        reference,
        block_x,
        block_y,
        block_w,
        block_h,
        mv,
        pred,
        block_w as usize,
    );
    pred
}

/// SAD between the source block and its motion-compensated prediction at `mv`.
#[allow(clippy::too_many_arguments)]
fn sad_at_mv(
    source: &[u8],
    source_stride: usize,
    reference: &ReconFrame,
    block_x: u32,
    block_y: u32,
    block_w: u32,
    block_h: u32,
    mv: MotionVector,
) -> u32 {
    let mut storage = [0u8; MAX_PRED_AREA];
    let pred = predict_at_mv(reference, block_x, block_y, block_w, block_h, mv, &mut storage);
    sad_block(source, source_stride, pred, block_w as usize, block_w, block_h)
}

/// SATD between the source block and its motion-compensated prediction at `mv`.
#[allow(clippy::too_many_arguments)]
fn satd_at_mv(
    source: &[u8],
    source_stride: usize,
    reference: &ReconFrame,
    block_x: u32,
    block_y: u32,
    block_w: u32,
    block_h: u32,
    mv: MotionVector,
) -> u32 {
    let mut storage = [0u8; MAX_PRED_AREA];
    let pred = predict_at_mv(reference, block_x, block_y, block_w, block_h, mv, &mut storage);
    satd_block(source, source_stride, pred, block_w as usize, block_w, block_h)
}
/// Clamps a quarter-pel `mv` so the displaced block stays within a window
/// around the reference frame.
///
/// In whole pixels, the displaced block's left edge may go down to `-3`
/// (three pixels outside the frame), while its right edge must stay at
/// least three pixels inside (`x + block_w <= width - 3`). If the block
/// already violates that, the upper bound becomes zero displacement. The
/// vertical axis is treated the same way with `height`/`block_h`.
///
/// NOTE(review): the margin is asymmetric (outside allowed at left/top,
/// inside required at right/bottom). Confirm which was intended against how
/// `apply_luma_mv_block` treats out-of-frame samples.
/// NOTE(review): bounds are narrowed with `as i16`, which would wrap for
/// frame dimensions beyond roughly 8K pixels.
fn clip_mv_to_frame(
    mv: MotionVector,
    reference: &ReconFrame,
    block_x: u32,
    block_y: u32,
    block_w: u32,
    block_h: u32,
) -> MotionVector {
    const HALO: i32 = 3;
    // Quarter-pel `(min, max)` displacement for a block starting at `pos`
    // with length `len` inside a frame dimension of `extent` pixels.
    let qpel_bounds = |pos: i32, len: i32, extent: i32| -> (i32, i32) {
        let lowest = -(pos + HALO) * 4;
        let highest = (extent - pos - len - HALO).max(0) * 4;
        (lowest, highest)
    };
    let (min_x, max_x) = qpel_bounds(block_x as i32, block_w as i32, reference.width as i32);
    let (min_y, max_y) = qpel_bounds(block_y as i32, block_h as i32, reference.height as i32);
    MotionVector {
        mv_x: i32::from(mv.mv_x).clamp(min_x, max_x) as i16,
        mv_y: i32::from(mv.mv_y).clamp(min_y, max_y) as i16,
    }
}
/// Median of three integers.
///
/// Uses only `min`/`max`, so it cannot overflow. The sum-based form
/// `a + b + c - max - min` panics in debug builds (and wraps in release)
/// for large inputs such as `(i32::MAX, i32::MAX, 1)`.
fn median3(a: i32, b: i32, c: i32) -> i32 {
    // Clamp `c` into [min(a,b), max(a,b)]; the result is the middle value.
    a.min(b).max(a.max(b).min(c))
}
/// Component-wise median motion-vector predictor from neighbouring blocks.
///
/// The third neighbour `C` is `above_right`, or `above_left` when
/// `above_right` is unavailable. Over the available subset of
/// {`left`, `above`, `C`}:
/// - none: [`MotionVector::ZERO`];
/// - one: that vector unchanged;
/// - two: per-component median of the two and `0`;
/// - three: per-component median of the three.
///
/// NOTE(review): this resembles H.264 median prediction when all available
/// neighbours use the same reference picture. Reference indices are not
/// modelled here.
pub fn median_mv_predictor(
    left: Option<MotionVector>,
    above: Option<MotionVector>,
    above_right: Option<MotionVector>,
    above_left: Option<MotionVector>,
) -> MotionVector {
    // Per-component median of three vectors.
    fn median_mv(a: MotionVector, b: MotionVector, c: MotionVector) -> MotionVector {
        MotionVector {
            mv_x: median3(a.mv_x as i32, b.mv_x as i32, c.mv_x as i32) as i16,
            mv_y: median3(a.mv_y as i32, b.mv_y as i32, c.mv_y as i32) as i16,
        }
    }

    let c = above_right.or(above_left);
    // Matching on the availability pattern avoids a heap allocation per call
    // and makes every case explicit (no `unreachable!`).
    match (left, above, c) {
        (None, None, None) => MotionVector::ZERO,
        (Some(only), None, None) | (None, Some(only), None) | (None, None, Some(only)) => only,
        (Some(a), Some(b), None) | (Some(a), None, Some(b)) | (None, Some(a), Some(b)) => {
            median_mv(a, b, MotionVector::ZERO)
        }
        (Some(a), Some(b), Some(c)) => median_mv(a, b, c),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::reconstruction::ReconBuffer;
    use std::sync::{Mutex, MutexGuard};

    // The search reads `PHASM_ME_*` environment variables on every call, and
    // the test harness runs tests on parallel threads. Every test in this
    // module that runs a search or mutates those variables holds this lock,
    // so disabling UMH in one test cannot leak into a concurrent search test.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], ignoring poisoning: the lock guards no data, so a
    /// panic in another test does not invalidate it.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    fn build_ref(width: u32, height: u32, fill: impl Fn(u32, u32) -> u8) -> ReconFrame {
        let mut rb = ReconBuffer::new(width, height).unwrap();
        for y in 0..height {
            for x in 0..width {
                rb.y[(y * width + x) as usize] = fill(x, y);
            }
        }
        for v in rb.cb.iter_mut() {
            *v = 128;
        }
        for v in rb.cr.iter_mut() {
            *v = 128;
        }
        ReconFrame::snapshot(&rb)
    }

    fn extract_block(frame: &ReconFrame, x: u32, y: u32) -> [[u8; 16]; 16] {
        let mut b = [[0u8; 16]; 16];
        for dy in 0..16 {
            for dx in 0..16 {
                b[dy][dx] = frame.y_at(x + dx as u32, y + dy as u32);
            }
        }
        b
    }

    #[test]
    fn median3_basic() {
        assert_eq!(median3(1, 2, 3), 2);
        assert_eq!(median3(5, 1, 3), 3);
        assert_eq!(median3(-1, -5, 0), -1);
        assert_eq!(median3(7, 7, 2), 7);
    }

    #[test]
    fn median_predictor_no_neighbors() {
        assert_eq!(
            median_mv_predictor(None, None, None, None),
            MotionVector::ZERO
        );
    }

    #[test]
    fn median_predictor_single_neighbor() {
        let mv = MotionVector { mv_x: 12, mv_y: -4 };
        assert_eq!(
            median_mv_predictor(Some(mv), None, None, None),
            mv,
        );
    }

    #[test]
    fn median_predictor_two_neighbors_uses_zero_as_third() {
        let a = MotionVector { mv_x: 10, mv_y: -6 };
        let b = MotionVector { mv_x: 4, mv_y: 8 };
        assert_eq!(
            median_mv_predictor(Some(a), Some(b), None, None),
            MotionVector { mv_x: 4, mv_y: 0 },
        );
    }

    #[test]
    fn median_predictor_three_neighbors_component_median() {
        let a = MotionVector { mv_x: 10, mv_y: 5 };
        let b = MotionVector { mv_x: 20, mv_y: -5 };
        let c = MotionVector { mv_x: 15, mv_y: 0 };
        let med = median_mv_predictor(Some(a), Some(b), Some(c), None);
        assert_eq!(med.mv_x, 15);
        assert_eq!(med.mv_y, 0);
    }

    #[test]
    fn median_predictor_above_left_fallback() {
        let d = MotionVector { mv_x: 8, mv_y: 12 };
        // Only above-left is available: it stands in for above-right.
        assert_eq!(median_mv_predictor(None, None, None, Some(d)), d);
        // Above-right is preferred over above-left when both exist.
        let a = MotionVector { mv_x: 0, mv_y: 0 };
        let b = MotionVector { mv_x: 10, mv_y: 10 };
        let c = MotionVector { mv_x: 4, mv_y: 4 };
        let far = MotionVector { mv_x: 100, mv_y: 100 };
        assert_eq!(
            median_mv_predictor(Some(a), Some(b), Some(c), Some(far)),
            MotionVector { mv_x: 4, mv_y: 4 },
        );
    }

    #[test]
    fn me_identical_frame_finds_zero_mv() {
        let _env = env_lock();
        let reference = build_ref(64, 48, |x, y| ((x * 7 + y * 3) & 0xFF) as u8);
        let source = extract_block(&reference, 16, 16);
        let mut me = MotionEstimator::new();
        let r = me.search_16x16(&source, &reference, 16, 16, MotionVector::ZERO);
        assert_eq!(r.mv, MotionVector::ZERO);
        assert_eq!(r.cost, 0, "identical source+ref should have zero SATD");
    }

    /// Sets `PHASM_ME_UMH=0` for its lifetime while holding [`ENV_LOCK`].
    struct UmhOffGuard {
        _env: MutexGuard<'static, ()>,
    }

    impl UmhOffGuard {
        fn new() -> Self {
            let env = env_lock();
            // SAFETY: tests in this module only touch or read the ME
            // environment variables while holding ENV_LOCK, which `env`
            // keeps held. NOTE(review): code outside this module is not
            // covered by the lock.
            unsafe {
                std::env::set_var("PHASM_ME_UMH", "0");
            }
            Self { _env: env }
        }
    }

    impl Drop for UmhOffGuard {
        fn drop(&mut self) {
            // SAFETY: fields (including the `_env` lock guard) are dropped
            // only after this body returns, so ENV_LOCK is still held here.
            unsafe {
                std::env::remove_var("PHASM_ME_UMH");
            }
        }
    }

    #[test]
    fn me_integer_translation_finds_shift() {
        let _g = UmhOffGuard::new();
        let reference = build_ref(64, 48, |x, y| ((x * 11 + y * 7) & 0xFF) as u8);
        let mut source = [[0u8; 16]; 16];
        for dy in 0..16 {
            for dx in 0..16 {
                source[dy][dx] = reference.y_at(20 + dx as u32, 16 + dy as u32);
            }
        }
        let mut me = MotionEstimator::new();
        let r = me.search_16x16(&source, &reference, 16, 16, MotionVector::ZERO);
        assert_eq!(r.cost, 0, "exact-match ME should cost 0");
    }

    #[test]
    fn sad_self_equals_zero() {
        let b = [[100u8; 16]; 16];
        assert_eq!(sad_block(b.as_flattened(), 16, b.as_flattened(), 16, 16, 16), 0);
    }

    #[test]
    fn sad_constant_offset() {
        let a = [[100u8; 16]; 16];
        let b = [[103u8; 16]; 16];
        assert_eq!(
            sad_block(a.as_flattened(), 16, b.as_flattened(), 16, 16, 16),
            16 * 16 * 3
        );
    }

    #[test]
    fn mv_clipping_within_frame_is_noop() {
        let reference = build_ref(64, 48, |_, _| 0);
        let mv = MotionVector { mv_x: 4, mv_y: 8 };
        let clipped = clip_mv_to_frame(mv, &reference, 16, 16, 16, 16);
        assert_eq!(clipped, mv, "in-bounds MV should not be clipped");
    }

    #[test]
    fn mv_clipping_large_negative() {
        let reference = build_ref(64, 48, |_, _| 0);
        let mv = MotionVector { mv_x: -1000, mv_y: -1000 };
        let clipped = clip_mv_to_frame(mv, &reference, 16, 16, 16, 16);
        assert!(clipped.mv_x > -1000);
        assert!(clipped.mv_y > -1000);
    }

    #[test]
    fn search_block_matches_16x16_on_identity() {
        let _env = env_lock();
        let reference = build_ref(64, 48, |x, y| ((x * 11 + y * 7) & 0xFF) as u8);
        let source = extract_block(&reference, 16, 16);
        let mut me = MotionEstimator::new();
        let r = me.search_block(
            source.as_flattened(),
            16,
            &reference,
            16,
            16,
            16,
            16,
            MotionVector::ZERO,
        );
        assert_eq!(r.mv, MotionVector::ZERO);
        assert_eq!(r.cost, 0);
    }

    #[test]
    fn search_block_finds_shift_on_8x8() {
        let _env = env_lock();
        let reference = build_ref(64, 48, |x, y| ((x * 5 + y * 3) & 0xFF) as u8);
        let mut source = [[0u8; 8]; 8];
        for dy in 0..8 {
            for dx in 0..8 {
                source[dy][dx] = reference.y_at(20 + dx as u32, 16 + dy as u32);
            }
        }
        let mut me = MotionEstimator::new();
        let r = me.search_block(
            source.as_flattened(),
            8,
            &reference,
            16,
            16,
            8,
            8,
            MotionVector::ZERO,
        );
        assert_eq!(r.mv, MotionVector { mv_x: 16, mv_y: 0 });
        assert_eq!(r.cost, 0);
    }

    #[test]
    fn search_block_finds_shift_on_4x4() {
        let _env = env_lock();
        let reference = build_ref(64, 48, |x, y| ((x * 5 + y * 3) & 0xFF) as u8);
        let mut source = [[0u8; 4]; 4];
        for dy in 0..4 {
            for dx in 0..4 {
                source[dy][dx] = reference.y_at(20 + dx as u32, 16 + dy as u32);
            }
        }
        let mut me = MotionEstimator::new();
        let r = me.search_block(
            source.as_flattened(),
            4,
            &reference,
            16,
            16,
            4,
            4,
            MotionVector::ZERO,
        );
        assert_eq!(r.mv, MotionVector { mv_x: 16, mv_y: 0 });
        assert_eq!(r.cost, 0);
    }
}