mod analyse;
mod compensate;
mod pyramid;
pub(crate) use analyse::run_analyse;
pub(crate) use compensate::run_compensate;
use cubecl::prelude::*;
use cubecl::server::Handle;
pub(crate) use pyramid::{pyramid_pixels_per_frame, run_pyramid_build};
/// Motion-compensation configuration.
///
/// [`MotionCompensationMode::None`] (the default) turns motion compensation off, while
/// `Mvtools` carries block-based motion-search parameters. Use
/// [`MotionCompensationMode::validate`] to check those parameters against the supported ranges.
#[non_exhaustive]
// All payload fields are `u32`, so full equality and hashing are well-defined; deriving them
// lets the mode be compared with `==` under `Eq` bounds and used as a map/set key.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MotionCompensationMode {
    /// Motion compensation disabled.
    #[default]
    None,
    /// Block-based motion search with the given geometry.
    Mvtools {
        /// Block side length in pixels; `validate` requires it to be even and in `4..=MAX_BLKSIZE`.
        blksize: u32,
        /// Pixels shared by adjacent blocks (block step is `blksize - overlap`); must be `< blksize`.
        overlap: u32,
        /// Motion-search radius; `validate` requires it to be in `1..=MAX_SEARCH_RADIUS`.
        search_radius: u32,
        /// Pyramid depth; `validate` requires `1..=MAX_PYRAMID_LEVELS`. A value of `1` means no
        /// coarse levels are built.
        pyramid_levels: u32,
    },
}
/// Block side length, in pixels, used by `MotionCompensationMode::mvtools_default`.
pub const DEFAULT_BLKSIZE: u32 = 16;
/// Largest block side length accepted by `MotionCompensationMode::validate`.
pub const MAX_BLKSIZE: u32 = 32;

/// Overlap between adjacent blocks, in pixels, used by `mvtools_default`.
pub const DEFAULT_OVERLAP: u32 = 8;

/// Motion-search radius used by `mvtools_default`.
pub const DEFAULT_SEARCH_RADIUS: u32 = 4;
/// Largest search radius accepted by `MotionCompensationMode::validate`.
pub const MAX_SEARCH_RADIUS: u32 = 8;

/// Pyramid depth used by `mvtools_default`.
pub const DEFAULT_PYRAMID_LEVELS: u32 = 2;
/// Largest pyramid depth accepted by `MotionCompensationMode::validate`.
pub const MAX_PYRAMID_LEVELS: u32 = 3;
impl MotionCompensationMode {
    /// Returns an `Mvtools` configuration populated from the `DEFAULT_*` constants.
    pub fn mvtools_default() -> Self {
        let blksize = DEFAULT_BLKSIZE;
        let overlap = DEFAULT_OVERLAP;
        let search_radius = DEFAULT_SEARCH_RADIUS;
        let pyramid_levels = DEFAULT_PYRAMID_LEVELS;
        Self::Mvtools {
            blksize,
            overlap,
            search_radius,
            pyramid_levels,
        }
    }

    /// Reports whether motion compensation is requested, i.e. the mode is anything but `None`.
    pub(crate) fn is_active(self) -> bool {
        match self {
            Self::None => false,
            Self::Mvtools { .. } => true,
        }
    }

    /// Checks the `Mvtools` parameters against the supported ranges.
    ///
    /// `None` is always valid. For `Mvtools` the constraints are checked in this order and the
    /// first violation is reported:
    /// 1. `blksize >= 4`
    /// 2. `blksize <= MAX_BLKSIZE`
    /// 3. `blksize` is even
    /// 4. `overlap < blksize` (so the block step `blksize - overlap` is non-zero)
    /// 5. `search_radius` in `1..=MAX_SEARCH_RADIUS`
    /// 6. `pyramid_levels` in `1..=MAX_PYRAMID_LEVELS`
    ///
    /// # Errors
    /// Returns an error describing the first violated constraint.
    pub fn validate(&self) -> Result<(), anyhow::Error> {
        let (blksize, overlap, search_radius, pyramid_levels) = match *self {
            Self::None => return Ok(()),
            Self::Mvtools {
                blksize,
                overlap,
                search_radius,
                pyramid_levels,
            } => (blksize, overlap, search_radius, pyramid_levels),
        };

        if blksize < 4 {
            anyhow::bail!("motion-compensation blksize={blksize} is too small; minimum is 4 pixels per side");
        }
        if MAX_BLKSIZE < blksize {
            anyhow::bail!(
                "motion-compensation blksize={blksize} exceeds the supported maximum ({MAX_BLKSIZE})"
            );
        }
        // Low bit set => odd block size.
        if (blksize & 1) != 0 {
            anyhow::bail!(
                "motion-compensation blksize={blksize} must be even so the /2 coarse level is well-defined"
            );
        }
        if blksize <= overlap {
            anyhow::bail!(
                "motion-compensation overlap={overlap} must be strictly less than blksize ({blksize}) so step > 0"
            );
        }
        if !(1..=MAX_SEARCH_RADIUS).contains(&search_radius) {
            anyhow::bail!(
                "motion-compensation search_radius={search_radius} must be in 1..={MAX_SEARCH_RADIUS}"
            );
        }
        if !(1..=MAX_PYRAMID_LEVELS).contains(&pyramid_levels) {
            anyhow::bail!(
                "motion-compensation pyramid_levels={pyramid_levels} must be in 1..={MAX_PYRAMID_LEVELS}"
            );
        }
        Ok(())
    }
}
/// Block-grid geometry derived from an `Mvtools` configuration and the frame dimensions.
#[derive(Debug, Clone)]
pub(crate) struct MotionCtx {
    /// Block side length in pixels.
    pub blksize: u32,
    /// Stride in pixels between adjacent blocks (`blksize - overlap`, never 0).
    pub step: u32,
    /// Motion-search radius, copied from the mode.
    pub search_radius: u32,
    /// Pyramid depth, copied from the mode.
    pub pyramid_levels: u32,
    /// Number of block columns: `ceil(width / step)`, at least 1.
    pub blocks_x: u32,
    /// Number of block rows: `ceil(height / step)`, at least 1.
    pub blocks_y: u32,
}

impl MotionCtx {
    /// Derives the block grid for a `width` x `height` frame.
    ///
    /// Returns `None` when `mode` does not request motion compensation.
    ///
    /// `mode` is expected to have passed `MotionCompensationMode::validate`. If it has not and
    /// `overlap >= blksize`, debug builds fail a `debug_assert!`. Release builds clamp the step
    /// to 1 instead of wrapping around (`overlap > blksize`) or dividing by zero
    /// (`overlap == blksize`).
    pub fn new(mode: MotionCompensationMode, width: u32, height: u32) -> Option<Self> {
        let MotionCompensationMode::Mvtools {
            blksize,
            overlap,
            search_radius,
            pyramid_levels,
        } = mode
        else {
            return None;
        };
        debug_assert!(
            overlap < blksize,
            "MotionCtx::new: overlap ({overlap}) must be < blksize ({blksize}); call validate() first"
        );
        // Guard against unvalidated configs: plain subtraction could underflow, and a zero step
        // would make `div_ceil` below divide by zero.
        let step = blksize.saturating_sub(overlap).max(1);
        // A degenerate (0-sized) frame still gets one block per axis.
        let blocks_x = width.div_ceil(step).max(1);
        let blocks_y = height.div_ceil(step).max(1);
        Some(Self {
            blksize,
            step,
            search_radius,
            pyramid_levels,
            blocks_x,
            blocks_y,
        })
    }

    /// Number of motion-vector slots per neighbour: one for every block in the grid.
    pub fn mv_slots_per_neighbour(&self) -> usize {
        // Widen before multiplying so a large grid cannot overflow `u32`.
        self.blocks_x as usize * self.blocks_y as usize
    }
}
/// Builds the coarse pyramid levels for one frame slot, if the configuration asks for any.
///
/// When `mc.pyramid_levels` is 0 or 1 there is nothing to build: this returns `Ok(())`
/// without calling `run_pyramid_build`. Otherwise every argument is forwarded unchanged to
/// `run_pyramid_build` and its result is returned.
// The nine parameters are exactly what `run_pyramid_build` needs; bundling them into a struct
// just for this thin forwarder would add churn without clarity.
#[allow(clippy::too_many_arguments)]
pub(crate) fn build_pyramid_for_slot<R: Runtime>(
    client: &ComputeClient<R>,
    mc: &MotionCtx,
    width: u32,
    height: u32,
    frame_count: u32,
    slot: u32,
    full_res: &Handle,
    pyramid: &Handle,
    stored_ch: u32,
) -> Result<(), anyhow::Error> {
    let has_coarse_levels = mc.pyramid_levels > 1;
    if has_coarse_levels {
        run_pyramid_build::<R>(
            client,
            mc,
            width,
            height,
            frame_count,
            slot,
            full_res,
            pyramid,
            stored_ch,
        )
    } else {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an `Mvtools` mode with explicit parameters.
    fn mv(blksize: u32, overlap: u32, search_radius: u32, pyramid_levels: u32) -> MotionCompensationMode {
        MotionCompensationMode::Mvtools {
            blksize,
            overlap,
            search_radius,
            pyramid_levels,
        }
    }

    #[test]
    fn none_is_inactive() {
        let m = MotionCompensationMode::None;
        assert!(!m.is_active());
        m.validate().unwrap();
    }

    #[test]
    fn mvtools_default_is_active() {
        let m = MotionCompensationMode::mvtools_default();
        assert!(m.is_active());
        m.validate().unwrap();
    }

    #[test]
    fn validate_rejects_tiny_blksize() {
        assert!(mv(2, 0, 4, 2).validate().is_err());
    }

    #[test]
    fn validate_accepts_minimum_config() {
        // Smallest values that satisfy every lower bound.
        mv(4, 0, 1, 1).validate().unwrap();
    }

    #[test]
    fn validate_rejects_blksize_above_max() {
        // MAX_BLKSIZE + 2 stays even so only the upper-bound check can fire.
        assert!(mv(MAX_BLKSIZE + 2, 0, 4, 2).validate().is_err());
        mv(MAX_BLKSIZE, 0, 4, 2).validate().unwrap();
    }

    #[test]
    fn validate_rejects_odd_blksize() {
        assert!(mv(9, 0, 4, 2).validate().is_err());
    }

    #[test]
    fn validate_rejects_overlap_equal_to_blksize() {
        assert!(mv(16, 16, 4, 2).validate().is_err());
    }

    #[test]
    fn validate_rejects_overlap_greater_than_blksize() {
        assert!(mv(16, 20, 4, 2).validate().is_err());
    }

    #[test]
    fn validate_accepts_half_overlap() {
        mv(16, 8, 4, 2).validate().unwrap();
    }

    #[test]
    fn validate_rejects_zero_search_radius() {
        assert!(mv(16, 4, 0, 2).validate().is_err());
    }

    #[test]
    fn validate_rejects_search_radius_above_max() {
        assert!(mv(16, 4, MAX_SEARCH_RADIUS + 1, 2).validate().is_err());
        mv(16, 4, MAX_SEARCH_RADIUS, 2).validate().unwrap();
    }

    #[test]
    fn validate_rejects_zero_pyramid_levels() {
        assert!(mv(16, 4, 4, 0).validate().is_err());
    }

    #[test]
    fn validate_rejects_pyramid_levels_above_max() {
        assert!(mv(16, 4, 4, MAX_PYRAMID_LEVELS + 1).validate().is_err());
        mv(16, 4, 4, MAX_PYRAMID_LEVELS).validate().unwrap();
    }

    #[test]
    fn motion_ctx_blocks_match_step() {
        let ctx = MotionCtx::new(mv(16, 8, 4, 2), 1920, 1080).unwrap();
        assert_eq!(ctx.step, 8);
        assert_eq!(ctx.blocks_x, 1920u32.div_ceil(8));
        assert_eq!(ctx.blocks_y, 1080u32.div_ceil(8));
        assert_eq!(
            ctx.mv_slots_per_neighbour(),
            (ctx.blocks_x as usize) * (ctx.blocks_y as usize)
        );
    }

    #[test]
    fn motion_ctx_is_none_for_inactive_mode() {
        assert!(MotionCtx::new(MotionCompensationMode::None, 1920, 1080).is_none());
    }

    #[test]
    fn motion_ctx_empty_frame_has_one_block() {
        let ctx = MotionCtx::new(mv(16, 8, 4, 2), 0, 0).unwrap();
        assert_eq!(ctx.blocks_x, 1);
        assert_eq!(ctx.blocks_y, 1);
        assert_eq!(ctx.mv_slots_per_neighbour(), 1);
    }
}