use super::energy::{compute_cumulative_energy, EnergyFunction, EnergyMap};
use crate::error::{CvError, CvResult};
/// Selects how `SeamCarver::find_vertical_seam` scores candidate seams.
///
/// `Backward` is the default variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EnergyMode {
/// Score seams from the energy map produced by the carver's configured
/// `EnergyFunction` (with the optional protection mask applied) and pick the
/// seam through `find_min_vertical_seam`.
#[default]
Backward,
/// Score seams with `compute_forward_energy_map` and trace them with
/// `find_vertical_seam_forward`; the configured `EnergyFunction` and the
/// protection mask are not consulted on this path.
Forward,
}
/// Builds the cumulative forward-energy cost map used to pick vertical seams.
///
/// `image` is read as a row-major, one-byte-per-pixel buffer of `width * height`
/// pixels (extra trailing bytes are ignored). Samples that fall outside the
/// image are read as `0.0`.
///
/// For every pixel three transition costs are formed from its left (`west`),
/// right (`east`) and upper (`north`) neighbours:
///
/// * straight up:   `|east - west|`
/// * from up-left:  `|east - west| + |west - north|`
/// * from up-right: `|east - west| + |east - north|`
///
/// The first row stores the straight-up cost only; every later cell stores the
/// cheapest of "predecessor cost + transition cost" over the predecessors that
/// exist in the row above.
///
/// # Errors
///
/// * `CvError::invalid_dimensions` when `width` or `height` is zero.
/// * `CvError::insufficient_data` when `image` holds fewer than
///   `width * height` bytes.
pub fn compute_forward_energy_map(image: &[u8], width: u32, height: u32) -> CvResult<Vec<f32>> {
    if width == 0 || height == 0 {
        return Err(CvError::invalid_dimensions(width, height));
    }
    let cols = width as usize;
    let rows = height as usize;
    let needed = cols * rows;
    if image.len() < needed {
        return Err(CvError::insufficient_data(needed, image.len()));
    }

    // Zero-padded pixel lookup; signed coordinates let callers step off the edge.
    let sample = |col: i32, row: i32| -> f32 {
        let inside = col >= 0 && row >= 0 && col < width as i32 && row < height as i32;
        if inside {
            image[row as usize * cols + col as usize] as f32
        } else {
            0.0
        }
    };

    let mut costs = vec![0.0f32; needed];

    // Top row has no predecessors: its cost is just the horizontal difference.
    for (x, cell) in costs[..cols].iter_mut().enumerate() {
        let xi = x as i32;
        *cell = (sample(xi + 1, 0) - sample(xi - 1, 0)).abs();
    }

    for y in 1..rows {
        let yi = y as i32;
        // Split so the finished row above can be read while this row is written.
        let (finished, pending) = costs.split_at_mut(y * cols);
        let above = &finished[(y - 1) * cols..];

        for (x, cell) in pending[..cols].iter_mut().enumerate() {
            let xi = x as i32;
            let west = sample(xi - 1, yi);
            let east = sample(xi + 1, yi);
            let north = sample(xi, yi - 1);

            let straight = (east - west).abs();
            let via_left = straight + (west - north).abs();
            let via_right = straight + (east - north).abs();

            // Predecessors outside the image are simply not considered.
            let mut best = above[x] + straight;
            if x > 0 {
                best = best.min(above[x - 1] + via_left);
            }
            if x + 1 < cols {
                best = best.min(above[x + 1] + via_right);
            }
            *cell = best;
        }
    }

    Ok(costs)
}
/// Traces a vertical seam through a cumulative cost map.
///
/// `cost_map` is read as a row-major `width * height` grid. The seam ends at
/// the cheapest cell of the last row (the left-most one on ties) and is traced
/// upward: in each earlier row the cheapest of the up-to-three cells adjacent
/// to the current column is chosen. On ties the same column is preferred, then
/// the left neighbour.
///
/// The returned seam is vertical, `path[y]` is the column chosen in row `y`,
/// and `energy` is the cost of the chosen last-row cell.
///
/// NOTE(review): the trace compares the stored cumulative costs of the
/// neighbours only; whether that always reproduces the predecessor that
/// produced the cell's cost depends on how `cost_map` was built — confirm.
///
/// # Errors
///
/// * `CvError::invalid_dimensions` when `width` or `height` is zero.
/// * `CvError::insufficient_data` when `cost_map` holds fewer than
///   `width * height` values.
pub fn find_vertical_seam_forward(cost_map: &[f32], width: u32, height: u32) -> CvResult<Seam> {
    if width == 0 || height == 0 {
        return Err(CvError::invalid_dimensions(width, height));
    }
    let cols = width as usize;
    let rows = height as usize;
    let needed = cols * rows;
    if cost_map.len() < needed {
        return Err(CvError::insufficient_data(needed, cost_map.len()));
    }

    // Cheapest cell of the bottom row; a candidate only wins when strictly cheaper,
    // so the left-most minimum is kept.
    let bottom_start = (rows - 1) * cols;
    let (end_col, end_cost) = cost_map[bottom_start..bottom_start + cols]
        .iter()
        .copied()
        .enumerate()
        .reduce(|best, candidate| if candidate.1 < best.1 { candidate } else { best })
        .unwrap_or((0, 0.0));

    let mut path = vec![0u32; rows];
    path[rows - 1] = end_col as u32;

    // Walk upward, carrying the column chosen in the row below.
    let mut col = end_col;
    for (y, slot) in path[..rows - 1].iter_mut().enumerate().rev() {
        let row = &cost_map[y * cols..(y + 1) * cols];

        let mut pick = col;
        let mut pick_cost = row[col];
        if col > 0 && row[col - 1] < pick_cost {
            pick = col - 1;
            pick_cost = row[col - 1];
        }
        if col + 1 < cols && row[col + 1] < pick_cost {
            pick = col + 1;
        }

        *slot = pick as u32;
        col = pick;
    }

    Ok(Seam::new(path, end_cost as f64, true))
}
/// A path of pixel indices crossing an image, one entry per row or column.
#[derive(Debug, Clone)]
pub struct Seam {
    /// For a vertical seam `path[y]` is the column chosen in row `y`; for a
    /// horizontal seam `path[x]` is the row chosen in column `x`.
    pub path: Vec<u32>,
    /// Cost value recorded by whichever routine produced the seam.
    pub energy: f64,
    /// `true` for a vertical (top-to-bottom) seam, `false` for a horizontal one.
    pub vertical: bool,
}

impl Seam {
    /// Bundles a path, its recorded cost and its orientation into a seam.
    #[must_use]
    pub fn new(path: Vec<u32>, energy: f64, vertical: bool) -> Self {
        Seam {
            path,
            energy,
            vertical,
        }
    }

    /// Number of entries in the path.
    #[must_use]
    pub fn len(&self) -> usize {
        let Seam { path, .. } = self;
        path.len()
    }

    /// Whether the path has no entries.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let Seam { path, .. } = self;
        path.is_empty()
    }
}
#[derive(Debug, Clone)]
pub struct SeamCarver {
energy_function: EnergyFunction,
energy_mode: EnergyMode,
protection_mask: Option<Vec<u8>>,
protection_scale: f64,
}
impl SeamCarver {
#[must_use]
pub fn new(energy_function: EnergyFunction) -> Self {
Self {
energy_function,
energy_mode: EnergyMode::Backward,
protection_mask: None,
protection_scale: 1000.0,
}
}
#[must_use]
pub fn new_with_mode(energy_function: EnergyFunction, energy_mode: EnergyMode) -> Self {
Self {
energy_function,
energy_mode,
protection_mask: None,
protection_scale: 1000.0,
}
}
pub fn set_energy_mode(&mut self, mode: EnergyMode) {
self.energy_mode = mode;
}
pub fn set_protection_mask(&mut self, mask: Vec<u8>) {
self.protection_mask = Some(mask);
}
pub fn set_protection_scale(&mut self, scale: f64) {
self.protection_scale = scale;
}
pub fn find_vertical_seam(&self, image: &[u8], width: u32, height: u32) -> CvResult<Seam> {
match self.energy_mode {
EnergyMode::Forward => {
let cost_map = compute_forward_energy_map(image, width, height)?;
find_vertical_seam_forward(&cost_map, width, height)
}
EnergyMode::Backward => {
let energy = self.compute_energy(image, width, height)?;
Ok(find_min_vertical_seam(&energy))
}
}
}
pub fn find_horizontal_seam(&self, image: &[u8], width: u32, height: u32) -> CvResult<Seam> {
let energy = self.compute_energy(image, width, height)?;
Ok(find_min_horizontal_seam(&energy))
}
pub fn remove_vertical_seam(
&self,
image: &[u8],
width: u32,
height: u32,
seam: &Seam,
) -> CvResult<Vec<u8>> {
if !seam.vertical {
return Err(CvError::invalid_parameter("seam", "expected vertical seam"));
}
if seam.path.len() != height as usize {
return Err(CvError::invalid_parameter(
"seam.path.len()",
format!("expected {}, got {}", height, seam.path.len()),
));
}
let new_width = width - 1;
let mut result = vec![0u8; new_width as usize * height as usize];
for y in 0..height as usize {
let seam_x = seam.path[y] as usize;
let src_row_start = y * width as usize;
let dst_row_start = y * new_width as usize;
for x in 0..seam_x {
result[dst_row_start + x] = image[src_row_start + x];
}
for x in seam_x + 1..width as usize {
result[dst_row_start + x - 1] = image[src_row_start + x];
}
}
Ok(result)
}
pub fn remove_horizontal_seam(
&self,
image: &[u8],
width: u32,
height: u32,
seam: &Seam,
) -> CvResult<Vec<u8>> {
if seam.vertical {
return Err(CvError::invalid_parameter(
"seam",
"expected horizontal seam",
));
}
if seam.path.len() != width as usize {
return Err(CvError::invalid_parameter(
"seam.path.len()",
format!("expected {}, got {}", width, seam.path.len()),
));
}
let new_height = height - 1;
let mut result = vec![0u8; width as usize * new_height as usize];
for x in 0..width as usize {
let seam_y = seam.path[x] as usize;
let mut dst_y = 0;
for y in 0..seam_y {
result[dst_y * width as usize + x] = image[y * width as usize + x];
dst_y += 1;
}
for y in seam_y + 1..height as usize {
result[dst_y * width as usize + x] = image[y * width as usize + x];
dst_y += 1;
}
}
Ok(result)
}
pub fn insert_vertical_seam(
&self,
image: &[u8],
width: u32,
height: u32,
seam: &Seam,
) -> CvResult<Vec<u8>> {
if !seam.vertical {
return Err(CvError::invalid_parameter("seam", "expected vertical seam"));
}
if seam.path.len() != height as usize {
return Err(CvError::invalid_parameter(
"seam.path.len()",
format!("expected {}, got {}", height, seam.path.len()),
));
}
let new_width = width + 1;
let mut result = vec![0u8; new_width as usize * height as usize];
for y in 0..height as usize {
let seam_x = seam.path[y] as usize;
let src_row_start = y * width as usize;
let dst_row_start = y * new_width as usize;
for x in 0..seam_x {
result[dst_row_start + x] = image[src_row_start + x];
}
result[dst_row_start + seam_x] = image[src_row_start + seam_x];
if seam_x < width as usize - 1 {
let left = image[src_row_start + seam_x] as u16;
let right = image[src_row_start + seam_x + 1] as u16;
result[dst_row_start + seam_x + 1] = ((left + right) / 2) as u8;
} else {
result[dst_row_start + seam_x + 1] = image[src_row_start + seam_x];
}
for x in seam_x + 1..width as usize {
result[dst_row_start + x + 1] = image[src_row_start + x];
}
}
Ok(result)
}
pub fn insert_horizontal_seam(
&self,
image: &[u8],
width: u32,
height: u32,
seam: &Seam,
) -> CvResult<Vec<u8>> {
if seam.vertical {
return Err(CvError::invalid_parameter(
"seam",
"expected horizontal seam",
));
}
if seam.path.len() != width as usize {
return Err(CvError::invalid_parameter(
"seam.path.len()",
format!("expected {}, got {}", width, seam.path.len()),
));
}
let new_height = height + 1;
let mut result = vec![0u8; width as usize * new_height as usize];
for x in 0..width as usize {
let seam_y = seam.path[x] as usize;
let mut dst_y = 0;
for y in 0..seam_y {
result[dst_y * width as usize + x] = image[y * width as usize + x];
dst_y += 1;
}
result[dst_y * width as usize + x] = image[seam_y * width as usize + x];
dst_y += 1;
if seam_y < height as usize - 1 {
let top = image[seam_y * width as usize + x] as u16;
let bottom = image[(seam_y + 1) * width as usize + x] as u16;
result[dst_y * width as usize + x] = ((top + bottom) / 2) as u8;
} else {
result[dst_y * width as usize + x] = image[seam_y * width as usize + x];
}
dst_y += 1;
for y in seam_y + 1..height as usize {
result[dst_y * width as usize + x] = image[y * width as usize + x];
dst_y += 1;
}
}
Ok(result)
}
pub fn reduce_width(
&self,
image: &[u8],
width: u32,
height: u32,
target_width: u32,
) -> CvResult<Vec<u8>> {
if target_width >= width {
return Err(CvError::invalid_parameter(
"target_width",
"must be less than current width",
));
}
let mut current_image = image.to_vec();
let mut current_width = width;
while current_width > target_width {
let seam = self.find_vertical_seam(¤t_image, current_width, height)?;
current_image =
self.remove_vertical_seam(¤t_image, current_width, height, &seam)?;
current_width -= 1;
}
Ok(current_image)
}
pub fn reduce_height(
&self,
image: &[u8],
width: u32,
height: u32,
target_height: u32,
) -> CvResult<Vec<u8>> {
if target_height >= height {
return Err(CvError::invalid_parameter(
"target_height",
"must be less than current height",
));
}
let mut current_image = image.to_vec();
let mut current_height = height;
while current_height > target_height {
let seam = self.find_horizontal_seam(¤t_image, width, current_height)?;
current_image =
self.remove_horizontal_seam(¤t_image, width, current_height, &seam)?;
current_height -= 1;
}
Ok(current_image)
}
pub fn enlarge_width(
&self,
image: &[u8],
width: u32,
height: u32,
target_width: u32,
) -> CvResult<Vec<u8>> {
if target_width <= width {
return Err(CvError::invalid_parameter(
"target_width",
"must be greater than current width",
));
}
let num_seams = target_width - width;
let mut seams = Vec::new();
let mut temp_image = image.to_vec();
let mut temp_width = width;
for _ in 0..num_seams {
let seam = self.find_vertical_seam(&temp_image, temp_width, height)?;
seams.push(seam.clone());
temp_image = self.remove_vertical_seam(&temp_image, temp_width, height, &seam)?;
temp_width -= 1;
}
let mut result = image.to_vec();
let mut current_width = width;
for (i, seam) in seams.iter().enumerate() {
let mut adjusted_path = seam.path.clone();
let path_len = adjusted_path.len();
for idx in 0..path_len {
let current_val = adjusted_path[idx];
let mut offset = 0;
for prev_seam in &seams[..i] {
let prev_path_len = prev_seam.path.len();
if prev_path_len > 0
&& idx < prev_path_len
&& current_val >= prev_seam.path[idx]
{
offset += 1;
}
}
adjusted_path[idx] += offset;
}
let adjusted_seam = Seam::new(adjusted_path, seam.energy, true);
result = self.insert_vertical_seam(&result, current_width, height, &adjusted_seam)?;
current_width += 1;
}
Ok(result)
}
fn compute_energy(&self, image: &[u8], width: u32, height: u32) -> CvResult<EnergyMap> {
let energy_data = self.energy_function.compute(image, width, height)?;
let mut energy_map = EnergyMap::from_data(energy_data, width, height)?;
if let Some(ref mask) = self.protection_mask {
energy_map.apply_protection_mask(mask, self.protection_scale);
}
Ok(energy_map)
}
}
/// Picks a vertical seam from the cumulative energy of `energy`.
///
/// The cumulative map comes from `compute_cumulative_energy(energy, true)`.
/// The seam ends at the cheapest cell of the last row (left-most on ties) and
/// is traced upward, choosing in each earlier row the cheapest of the
/// up-to-three cells adjacent to the current column. On ties the same column
/// wins, then the left neighbour. The seam's `energy` is the cumulative value
/// of the chosen last-row cell.
fn find_min_vertical_seam(energy: &EnergyMap) -> Seam {
    let cumulative = compute_cumulative_energy(energy, true);
    let cols = energy.width as usize;
    let rows = energy.height as usize;

    // A candidate replaces the running best only when strictly cheaper.
    let bottom_start = (rows - 1) * cols;
    let (end_col, total) = cumulative.data[bottom_start..bottom_start + cols]
        .iter()
        .copied()
        .enumerate()
        .reduce(|best, candidate| if candidate.1 < best.1 { candidate } else { best })
        .unwrap_or((0, 0.0));

    let mut path = vec![0u32; rows];
    path[rows - 1] = end_col as u32;

    // Walk upward, carrying the column chosen in the row below.
    let mut col = end_col;
    for y in (0..rows - 1).rev() {
        let base = y * cols;
        let mut choice = col;
        let mut choice_energy = cumulative.data[base + col];

        if col > 0 {
            let left = cumulative.data[base + col - 1];
            if left < choice_energy {
                choice = col - 1;
                choice_energy = left;
            }
        }
        if col < cols - 1 {
            let right = cumulative.data[base + col + 1];
            if right < choice_energy {
                choice = col + 1;
            }
        }

        path[y] = choice as u32;
        col = choice;
    }

    Seam::new(path, total, true)
}
/// Picks a horizontal seam from the cumulative energy of `energy`.
///
/// The cumulative map comes from `compute_cumulative_energy(energy, false)`.
/// The seam ends at the cheapest cell of the last column (top-most on ties)
/// and is traced leftward, choosing in each earlier column the cheapest of the
/// up-to-three cells adjacent to the current row. On ties the same row wins,
/// then the row above. The seam's `energy` is the cumulative value of the
/// chosen last-column cell.
fn find_min_horizontal_seam(energy: &EnergyMap) -> Seam {
    let cumulative = compute_cumulative_energy(energy, false);
    let cols = energy.width as usize;
    let rows = energy.height as usize;

    // A candidate replaces the running best only when strictly cheaper.
    let (end_row, total) = (0..rows)
        .map(|y| (y, cumulative.data[y * cols + cols - 1]))
        .reduce(|best, candidate| if candidate.1 < best.1 { candidate } else { best })
        .unwrap_or((0, 0.0));

    let mut path = vec![0u32; cols];
    path[cols - 1] = end_row as u32;

    // Walk leftward, carrying the row chosen in the column to the right.
    let mut row = end_row;
    for x in (0..cols - 1).rev() {
        let mut choice = row;
        let mut choice_energy = cumulative.data[row * cols + x];

        if row > 0 {
            let above = cumulative.data[(row - 1) * cols + x];
            if above < choice_energy {
                choice = row - 1;
                choice_energy = above;
            }
        }
        if row < rows - 1 {
            let below = cumulative.data[(row + 1) * cols + x];
            if below < choice_energy {
                choice = row + 1;
            }
        }

        path[x] = choice as u32;
        row = choice;
    }

    Seam::new(path, total, false)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// On a flat image the forward seam must still be connected and in bounds.
    #[test]
    fn test_forward_energy_uniform_image_valid_seam() {
        let (w, h) = (12u32, 8u32);
        let image = vec![128u8; (w * h) as usize];

        let cost_map = compute_forward_energy_map(&image, w, h)
            .expect("compute_forward_energy_map should succeed on uniform image");
        let seam = find_vertical_seam_forward(&cost_map, w, h)
            .expect("find_vertical_seam_forward should succeed");

        assert_eq!(seam.path.len(), h as usize);
        assert!(seam.vertical);
        // Adjacent rows may differ by at most one column.
        for (above, pair) in seam.path.windows(2).enumerate() {
            let row = above + 1;
            let diff = (pair[1] as i32 - pair[0] as i32).unsigned_abs();
            assert!(diff <= 1, "column jump > 1 at row {row}");
        }
        for col in seam.path.iter().copied() {
            assert!(col < w, "column {col} >= width {w}");
        }
    }

    /// Same validity checks on a left-to-right brightness ramp.
    #[test]
    fn test_forward_energy_gradient_image_valid_seam() {
        let (w, h) = (20u32, 10u32);
        let mut image: Vec<u8> = Vec::with_capacity((w * h) as usize);
        for _y in 0..h {
            image.extend((0..w).map(|x| (x * 255 / (w - 1)) as u8));
        }

        let cost_map = compute_forward_energy_map(&image, w, h)
            .expect("compute_forward_energy_map should succeed");
        let seam = find_vertical_seam_forward(&cost_map, w, h)
            .expect("find_vertical_seam_forward should succeed");

        assert_eq!(seam.path.len(), h as usize);
        for (above, pair) in seam.path.windows(2).enumerate() {
            let row = above + 1;
            let diff = (pair[1] as i32 - pair[0] as i32).unsigned_abs();
            assert!(diff <= 1, "column jump > 1 at row {row}");
        }
        for col in seam.path.iter().copied() {
            assert!(col < w);
        }
    }

    /// Both modes must yield connected seams on a hard vertical edge; whether
    /// they differ is only reported, not asserted.
    #[test]
    fn test_forward_vs_backward_seams_differ_on_edge_pattern() {
        let (w, h) = (16u32, 8u32);
        let one_row: Vec<u8> = (0..w).map(|x| if x < 8 { 0u8 } else { 255u8 }).collect();
        let image: Vec<u8> = one_row.repeat(h as usize);

        let bwd_seam = SeamCarver::new_with_mode(EnergyFunction::Gradient, EnergyMode::Backward)
            .find_vertical_seam(&image, w, h)
            .expect("backward seam");
        let fwd_seam = SeamCarver::new_with_mode(EnergyFunction::Gradient, EnergyMode::Forward)
            .find_vertical_seam(&image, w, h)
            .expect("forward seam");

        for seam in [&bwd_seam, &fwd_seam] {
            assert_eq!(seam.path.len(), h as usize);
            for pair in seam.path.windows(2) {
                let diff = (pair[1] as i32 - pair[0] as i32).unsigned_abs();
                assert!(diff <= 1);
            }
        }

        if bwd_seam.path == fwd_seam.path {
            eprintln!(
                "[WARN] backward and forward seams happened to coincide on edge-pattern image"
            );
        }
    }

    /// Removing a forward-mode seam drops exactly one pixel per row.
    #[test]
    fn test_forward_energy_remove_seam_reduces_width() {
        let (w, h) = (10u32, 6u32);
        let image: Vec<u8> = (0..w * h).map(|i| (i % 256) as u8).collect();

        let carver = SeamCarver::new_with_mode(EnergyFunction::Gradient, EnergyMode::Forward);
        let seam = carver
            .find_vertical_seam(&image, w, h)
            .expect("find_vertical_seam forward");
        let result = carver
            .remove_vertical_seam(&image, w, h, &seam)
            .expect("remove_vertical_seam");

        assert_eq!(result.len(), ((w - 1) * h) as usize);
    }

    /// The constructor stores its arguments unchanged.
    #[test]
    fn test_seam_new() {
        let seam = Seam::new(vec![0, 1, 2], 10.0, true);
        assert_eq!(seam.len(), 3);
        assert!(seam.vertical);
        assert_eq!(seam.energy, 10.0);
    }

    /// The default carver returns a vertical seam with one entry per row.
    #[test]
    fn test_find_vertical_seam() {
        let image = vec![128u8; 100];
        let seam = SeamCarver::new(EnergyFunction::Gradient)
            .find_vertical_seam(&image, 10, 10)
            .expect("find_vertical_seam should succeed");
        assert_eq!(seam.len(), 10);
        assert!(seam.vertical);
    }

    /// 10x10 minus one column leaves 90 pixels.
    #[test]
    fn test_remove_vertical_seam() {
        let image = vec![128u8; 100];
        let carver = SeamCarver::new(EnergyFunction::Gradient);
        let seam = carver
            .find_vertical_seam(&image, 10, 10)
            .expect("find_vertical_seam should succeed");
        let result = carver
            .remove_vertical_seam(&image, 10, 10, &seam)
            .expect("remove_vertical_seam should succeed");
        assert_eq!(result.len(), 90);
    }

    /// 10x10 plus one column gives 110 pixels.
    #[test]
    fn test_insert_vertical_seam() {
        let image = vec![128u8; 100];
        let carver = SeamCarver::new(EnergyFunction::Gradient);
        let seam = carver
            .find_vertical_seam(&image, 10, 10)
            .expect("find_vertical_seam should succeed");
        let result = carver
            .insert_vertical_seam(&image, 10, 10, &seam)
            .expect("insert_vertical_seam should succeed");
        assert_eq!(result.len(), 110);
    }

    /// Shrinking 10x10 to width 8 leaves 80 pixels.
    #[test]
    fn test_reduce_width() {
        let image = vec![128u8; 100];
        let result = SeamCarver::new(EnergyFunction::Gradient)
            .reduce_width(&image, 10, 10, 8)
            .expect("reduce_width should succeed");
        assert_eq!(result.len(), 80);
    }

    /// Shrinking 10x10 to height 8 leaves 80 pixels.
    #[test]
    fn test_reduce_height() {
        let image = vec![128u8; 100];
        let result = SeamCarver::new(EnergyFunction::Gradient)
            .reduce_height(&image, 10, 10, 8)
            .expect("reduce_height should succeed");
        assert_eq!(result.len(), 80);
    }
}