use crate::error::{CvError, CvResult};
use crate::interpolate::optical_flow::FlowField;
use bytes::Bytes;
use oximedia_codec::{Plane, VideoFrame};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum WarpMode {
Nearest,
#[default]
Bilinear,
Bicubic,
}
pub struct Warper {
mode: WarpMode,
}
impl Warper {
#[must_use]
pub const fn new(mode: WarpMode) -> Self {
Self { mode }
}
pub fn set_mode(&mut self, mode: WarpMode) {
self.mode = mode;
}
pub fn warp_forward(
&self,
frame: &VideoFrame,
flow: &FlowField,
t: f32,
) -> CvResult<VideoFrame> {
self.validate_frame_and_flow(frame, flow)?;
let mut result = VideoFrame::new(frame.format, frame.width, frame.height);
result.allocate();
result.timestamp = frame.timestamp;
result.frame_type = frame.frame_type;
result.color_info = frame.color_info;
for plane_idx in 0..frame.planes.len() {
let (plane_w, plane_h) = frame.plane_dimensions(plane_idx);
self.warp_plane_forward(
&frame.planes[plane_idx],
&mut result.planes[plane_idx],
flow,
plane_w,
plane_h,
t,
)?;
}
Ok(result)
}
pub fn warp_backward(
&self,
frame: &VideoFrame,
flow: &FlowField,
t: f32,
) -> CvResult<VideoFrame> {
self.validate_frame_and_flow(frame, flow)?;
let mut result = VideoFrame::new(frame.format, frame.width, frame.height);
result.allocate();
result.timestamp = frame.timestamp;
result.frame_type = frame.frame_type;
result.color_info = frame.color_info;
for plane_idx in 0..frame.planes.len() {
let (plane_w, plane_h) = frame.plane_dimensions(plane_idx);
self.warp_plane_backward(
&frame.planes[plane_idx],
&mut result.planes[plane_idx],
flow,
plane_w,
plane_h,
t,
)?;
}
Ok(result)
}
#[allow(clippy::too_many_arguments)]
fn warp_plane_forward(
&self,
src_plane: &Plane,
dst_plane: &mut Plane,
flow: &FlowField,
width: u32,
height: u32,
t: f32,
) -> CvResult<()> {
let mut dst_data = vec![0u8; (width * height) as usize];
let mut weight_map = vec![0.0f32; (width * height) as usize];
let scale_x = width as f32 / flow.width as f32;
let scale_y = height as f32 / flow.height as f32;
for y in 0..height {
for x in 0..width {
let src_idx = (y * width + x) as usize;
if src_idx >= src_plane.data.len() {
continue;
}
let src_value = src_plane.data[src_idx];
let flow_x = (x as f32 / scale_x) as u32;
let flow_y = (y as f32 / scale_y) as u32;
let (dx, dy) = flow.get(flow_x, flow_y);
let dest_x = x as f32 + dx * scale_x * t;
let dest_y = y as f32 + dy * scale_y * t;
self.splat_pixel(
&mut dst_data,
&mut weight_map,
width,
height,
dest_x,
dest_y,
src_value,
);
}
}
for i in 0..dst_data.len() {
if weight_map[i] > 0.0 {
dst_data[i] = (dst_data[i] as f32 / weight_map[i]) as u8;
} else {
let x = (i % width as usize) as u32;
let y = (i / width as usize) as u32;
dst_data[i] = self.get_pixel_nearest(src_plane, width, height, x as f32, y as f32);
}
}
dst_plane.data = dst_data;
dst_plane.stride = width as usize;
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn warp_plane_backward(
&self,
src_plane: &Plane,
dst_plane: &mut Plane,
flow: &FlowField,
width: u32,
height: u32,
t: f32,
) -> CvResult<()> {
let mut dst_data = vec![0u8; (width * height) as usize];
let scale_x = width as f32 / flow.width as f32;
let scale_y = height as f32 / flow.height as f32;
for y in 0..height {
for x in 0..width {
let flow_x = (x as f32 / scale_x) as u32;
let flow_y = (y as f32 / scale_y) as u32;
let (dx, dy) = flow.get(flow_x, flow_y);
let src_x = x as f32 - dx * scale_x * t;
let src_y = y as f32 - dy * scale_y * t;
let value = match self.mode {
WarpMode::Nearest => {
self.get_pixel_nearest(src_plane, width, height, src_x, src_y)
}
WarpMode::Bilinear => {
self.get_pixel_bilinear(src_plane, width, height, src_x, src_y)
}
WarpMode::Bicubic => {
self.get_pixel_bicubic(src_plane, width, height, src_x, src_y)
}
};
let dst_idx = (y * width + x) as usize;
dst_data[dst_idx] = value;
}
}
dst_plane.data = dst_data;
dst_plane.stride = width as usize;
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn splat_pixel(
&self,
dst_data: &mut [u8],
weight_map: &mut [f32],
width: u32,
height: u32,
x: f32,
y: f32,
value: u8,
) {
let x0 = x.floor() as i32;
let y0 = y.floor() as i32;
let fx = x - x0 as f32;
let fy = y - y0 as f32;
let weights = [
(1.0 - fx) * (1.0 - fy),
fx * (1.0 - fy),
(1.0 - fx) * fy,
fx * fy,
];
let positions = [(x0, y0), (x0 + 1, y0), (x0, y0 + 1), (x0 + 1, y0 + 1)];
for i in 0..4 {
let (px, py) = positions[i];
if px >= 0 && px < width as i32 && py >= 0 && py < height as i32 {
let idx = (py as u32 * width + px as u32) as usize;
if idx < dst_data.len() {
dst_data[idx] += (value as f32 * weights[i]) as u8;
weight_map[idx] += weights[i];
}
}
}
}
fn get_pixel_nearest(&self, plane: &Plane, width: u32, height: u32, x: f32, y: f32) -> u8 {
let xi = x.round().clamp(0.0, (width - 1) as f32) as u32;
let yi = y.round().clamp(0.0, (height - 1) as f32) as u32;
let idx = (yi * width + xi) as usize;
if idx < plane.data.len() {
plane.data[idx]
} else {
0
}
}
fn get_pixel_bilinear(&self, plane: &Plane, width: u32, height: u32, x: f32, y: f32) -> u8 {
let x = x.clamp(0.0, (width - 1) as f32);
let y = y.clamp(0.0, (height - 1) as f32);
let x0 = x.floor() as u32;
let y0 = y.floor() as u32;
let x1 = (x0 + 1).min(width - 1);
let y1 = (y0 + 1).min(height - 1);
let fx = x - x0 as f32;
let fy = y - y0 as f32;
let p00 = self.get_pixel_at(plane, width, x0, y0);
let p10 = self.get_pixel_at(plane, width, x1, y0);
let p01 = self.get_pixel_at(plane, width, x0, y1);
let p11 = self.get_pixel_at(plane, width, x1, y1);
let v0 = p00 as f32 * (1.0 - fx) + p10 as f32 * fx;
let v1 = p01 as f32 * (1.0 - fx) + p11 as f32 * fx;
let result = v0 * (1.0 - fy) + v1 * fy;
result.round().clamp(0.0, 255.0) as u8
}
fn get_pixel_bicubic(&self, plane: &Plane, width: u32, height: u32, x: f32, y: f32) -> u8 {
let x = x.clamp(0.0, (width - 1) as f32);
let y = y.clamp(0.0, (height - 1) as f32);
let x0 = x.floor() as i32;
let y0 = y.floor() as i32;
let fx = x - x0 as f32;
let fy = y - y0 as f32;
let mut result = 0.0f32;
for j in -1..=2 {
let mut row_result = 0.0f32;
for i in -1..=2 {
let px = (x0 + i).clamp(0, width as i32 - 1) as u32;
let py = (y0 + j).clamp(0, height as i32 - 1) as u32;
let pixel = self.get_pixel_at(plane, width, px, py) as f32;
row_result += pixel * cubic_weight(fx - i as f32);
}
result += row_result * cubic_weight(fy - j as f32);
}
result.round().clamp(0.0, 255.0) as u8
}
fn get_pixel_at(&self, plane: &Plane, width: u32, x: u32, y: u32) -> u8 {
let idx = (y * width + x) as usize;
if idx < plane.data.len() {
plane.data[idx]
} else {
0
}
}
fn validate_frame_and_flow(&self, frame: &VideoFrame, flow: &FlowField) -> CvResult<()> {
if frame.width != flow.width || frame.height != flow.height {
return Err(CvError::invalid_dimensions(frame.width, frame.height));
}
if frame.planes.is_empty() {
return Err(CvError::insufficient_data(1, 0));
}
Ok(())
}
}
impl Default for Warper {
fn default() -> Self {
Self::new(WarpMode::Bilinear)
}
}
fn cubic_weight(x: f32) -> f32 {
let x = x.abs();
if x < 1.0 {
(1.5 * x - 2.5) * x * x + 1.0
} else if x < 2.0 {
((-0.5 * x + 2.5) * x - 4.0) * x + 2.0
} else {
0.0
}
}