use thiserror::Error;
#[derive(Debug, Error)]
pub enum ScalingError {
#[error("invalid dimensions: {0}")]
InvalidDimensions(String),
#[error("buffer too small: expected {expected} bytes, got {actual}")]
InsufficientBuffer {
expected: usize,
actual: usize,
},
#[error("unsupported channel count: {0} (must be 1, 3, or 4)")]
UnsupportedChannels(usize),
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnergyFunction {
Gradient,
ForwardEnergy,
}
#[derive(Debug, Clone)]
pub struct SeamCarvingConfig {
pub target_width: u32,
pub target_height: u32,
pub energy_function: EnergyFunction,
pub protect_mask: Option<Vec<u8>>,
pub remove_mask: Option<Vec<u8>>,
}
impl SeamCarvingConfig {
pub fn new(target_width: u32, target_height: u32) -> Self {
Self {
target_width,
target_height,
energy_function: EnergyFunction::Gradient,
protect_mask: None,
remove_mask: None,
}
}
pub fn with_energy_function(mut self, func: EnergyFunction) -> Self {
self.energy_function = func;
self
}
pub fn with_protect_mask(mut self, mask: Vec<u8>) -> Self {
self.protect_mask = Some(mask);
self
}
pub fn with_remove_mask(mut self, mask: Vec<u8>) -> Self {
self.remove_mask = Some(mask);
self
}
}
pub struct SeamCarver {
config: SeamCarvingConfig,
}
impl SeamCarver {
pub fn new(config: SeamCarvingConfig) -> Self {
Self { config }
}
pub fn config(&self) -> &SeamCarvingConfig {
&self.config
}
pub fn carve(
&self,
image: &[u8],
width: u32,
height: u32,
channels: usize,
) -> Result<(Vec<u8>, u32, u32), ScalingError> {
if channels != 1 && channels != 3 && channels != 4 {
return Err(ScalingError::UnsupportedChannels(channels));
}
if width == 0 || height == 0 {
return Err(ScalingError::InvalidDimensions(
"source width and height must be non-zero".into(),
));
}
let expected = width as usize * height as usize * channels;
if image.len() < expected {
return Err(ScalingError::InsufficientBuffer {
expected,
actual: image.len(),
});
}
let tw = self.config.target_width;
let th = self.config.target_height;
if tw == 0 || th == 0 {
return Err(ScalingError::InvalidDimensions(
"target width and height must be non-zero".into(),
));
}
if tw > width {
return Err(ScalingError::InvalidDimensions(format!(
"target width {tw} exceeds source width {width}; seam carving can only reduce"
)));
}
if th > height {
return Err(ScalingError::InvalidDimensions(format!(
"target height {th} exceeds source height {height}; seam carving can only reduce"
)));
}
let mut pixels = image.to_vec();
let mut cur_w = width as usize;
let mut cur_h = height as usize;
let seams_x = width as usize - tw as usize;
for _ in 0..seams_x {
let gray = to_grayscale(&pixels, cur_w, cur_h, channels);
let energy = compute_energy(
&gray,
cur_w,
cur_h,
self.config.energy_function,
self.config.protect_mask.as_deref(),
self.config.remove_mask.as_deref(),
);
let seam = find_vertical_seam(&energy, cur_w, cur_h);
pixels = remove_vertical_seam(&pixels, cur_w, cur_h, channels, &seam);
cur_w -= 1;
}
let seams_y = height as usize - th as usize;
for _ in 0..seams_y {
pixels = transpose_image(&pixels, cur_w, cur_h, channels);
let t_w = cur_h;
let t_h = cur_w;
let gray = to_grayscale(&pixels, t_w, t_h, channels);
let energy = compute_energy(
&gray,
t_w,
t_h,
self.config.energy_function,
None,
None,
);
let seam = find_vertical_seam(&energy, t_w, t_h);
pixels = remove_vertical_seam(&pixels, t_w, t_h, channels, &seam);
pixels = transpose_image(&pixels, t_w - 1, t_h, channels);
cur_h -= 1;
}
Ok((pixels, cur_w as u32, cur_h as u32))
}
}
fn to_grayscale(pixels: &[u8], width: usize, height: usize, channels: usize) -> Vec<f64> {
let n = width * height;
let mut gray = vec![0.0f64; n];
match channels {
1 => {
for (i, &v) in pixels.iter().take(n).enumerate() {
gray[i] = v as f64;
}
}
3 | 4 => {
for i in 0..n {
let base = i * channels;
let r = pixels[base] as f64;
let g = pixels[base + 1] as f64;
let b = pixels[base + 2] as f64;
gray[i] = 0.299 * r + 0.587 * g + 0.114 * b;
}
}
_ => { }
}
gray
}
#[inline]
fn clamp_idx(val: isize, max: usize) -> usize {
val.clamp(0, max as isize - 1) as usize
}
fn compute_energy(
gray: &[f64],
width: usize,
height: usize,
func: EnergyFunction,
protect_mask: Option<&[u8]>,
remove_mask: Option<&[u8]>,
) -> Vec<f64> {
let mut energy = match func {
EnergyFunction::Gradient => gradient_energy(gray, width, height),
EnergyFunction::ForwardEnergy => forward_energy(gray, width, height),
};
const PROTECT_BIAS: f64 = 1_000_000.0;
const REMOVE_BIAS: f64 = -1_000_000.0;
if let Some(mask) = protect_mask {
for (i, &m) in mask.iter().take(width * height).enumerate() {
if m > 0 {
energy[i] += PROTECT_BIAS;
}
}
}
if let Some(mask) = remove_mask {
for (i, &m) in mask.iter().take(width * height).enumerate() {
if m > 0 {
energy[i] = (energy[i] + REMOVE_BIAS).max(-PROTECT_BIAS);
}
}
}
energy
}
fn gradient_energy(gray: &[f64], width: usize, height: usize) -> Vec<f64> {
let mut e = vec![0.0f64; width * height];
for row in 0..height {
for col in 0..width {
let left = gray[row * width + clamp_idx(col as isize - 1, width)];
let right = gray[row * width + clamp_idx(col as isize + 1, width)];
let up = gray[clamp_idx(row as isize - 1, height) * width + col];
let down = gray[clamp_idx(row as isize + 1, height) * width + col];
e[row * width + col] = (right - left).abs() + (down - up).abs();
}
}
e
}
fn forward_energy(gray: &[f64], width: usize, height: usize) -> Vec<f64> {
let mut e = vec![0.0f64; width * height];
for row in 0..height {
for col in 0..width {
let (c_u, c_l, c_r) = forward_costs(gray, width, height, row, col);
e[row * width + col] = c_u.min(c_l).min(c_r);
}
}
e
}
#[inline]
fn forward_costs(
gray: &[f64],
width: usize,
height: usize,
row: usize,
col: usize,
) -> (f64, f64, f64) {
let left = clamp_idx(col as isize - 1, width);
let right = clamp_idx(col as isize + 1, width);
let up = clamp_idx(row as isize - 1, height);
let i_left = gray[row * width + left];
let i_right = gray[row * width + right];
let i_up = gray[up * width + col];
let c_u = (i_right - i_left).abs();
let c_l = c_u + (i_up - i_left).abs();
let c_r = c_u + (i_up - i_right).abs();
(c_u, c_l, c_r)
}
fn find_vertical_seam(energy: &[f64], width: usize, height: usize) -> Vec<usize> {
let mut dp = energy[..width].to_vec();
let mut back = vec![vec![0usize; width]; height];
for row in 1..height {
let prev_dp = dp.clone();
for col in 0..width {
let left = if col > 0 {
prev_dp[col - 1]
} else {
f64::INFINITY
};
let center = prev_dp[col];
let right = if col + 1 < width {
prev_dp[col + 1]
} else {
f64::INFINITY
};
let (min_val, min_col) = if left <= center && left <= right {
(left, col.saturating_sub(1))
} else if center <= right {
(center, col)
} else {
(right, (col + 1).min(width - 1))
};
dp[col] = energy[row * width + col] + min_val;
back[row][col] = min_col;
}
}
let mut min_col = 0;
let mut min_val = dp[0];
for col in 1..width {
if dp[col] < min_val {
min_val = dp[col];
min_col = col;
}
}
let mut seam = vec![0usize; height];
seam[height - 1] = min_col;
for row in (0..height - 1).rev() {
seam[row] = back[row + 1][seam[row + 1]];
}
seam
}
fn remove_vertical_seam(
pixels: &[u8],
width: usize,
height: usize,
channels: usize,
seam: &[usize],
) -> Vec<u8> {
let new_w = width - 1;
let mut out = Vec::with_capacity(new_w * height * channels);
for row in 0..height {
let remove_col = seam[row];
for col in 0..width {
if col == remove_col {
continue;
}
let base = (row * width + col) * channels;
out.extend_from_slice(&pixels[base..base + channels]);
}
}
out
}
fn transpose_image(pixels: &[u8], width: usize, height: usize, channels: usize) -> Vec<u8> {
let mut out = vec![0u8; height * width * channels];
for row in 0..height {
for col in 0..width {
let src = (row * width + col) * channels;
let dst = (col * height + row) * channels;
out[dst..dst + channels].copy_from_slice(&pixels[src..src + channels]);
}
}
out
}
#[cfg(test)]
mod tests {
use super::*;
fn solid_gray(width: usize, height: usize, value: u8) -> Vec<u8> {
vec![value; width * height]
}
fn horizontal_gradient(width: usize, height: usize) -> Vec<u8> {
let mut img = vec![0u8; width * height];
for row in 0..height {
for col in 0..width {
img[row * width + col] = (col * 255 / width.max(1)) as u8;
}
}
img
}
fn solid_rgb(width: usize, height: usize, r: u8, g: u8, b: u8) -> Vec<u8> {
let mut img = vec![0u8; width * height * 3];
for i in 0..width * height {
img[i * 3] = r;
img[i * 3 + 1] = g;
img[i * 3 + 2] = b;
}
img
}
#[test]
fn error_invalid_dimensions_zero_src() {
let config = SeamCarvingConfig::new(2, 2);
let carver = SeamCarver::new(config);
let result = carver.carve(&[], 0, 4, 1);
assert!(matches!(result, Err(ScalingError::InvalidDimensions(_))));
}
#[test]
fn error_invalid_dimensions_zero_target() {
let config = SeamCarvingConfig::new(0, 2);
let carver = SeamCarver::new(config);
let img = solid_gray(4, 4, 128);
let result = carver.carve(&img, 4, 4, 1);
assert!(matches!(result, Err(ScalingError::InvalidDimensions(_))));
}
#[test]
fn error_target_exceeds_source_width() {
let config = SeamCarvingConfig::new(10, 2); let carver = SeamCarver::new(config);
let img = solid_gray(4, 4, 0);
let result = carver.carve(&img, 4, 4, 1);
assert!(matches!(result, Err(ScalingError::InvalidDimensions(_))));
}
#[test]
fn error_target_exceeds_source_height() {
let config = SeamCarvingConfig::new(2, 10); let carver = SeamCarver::new(config);
let img = solid_gray(4, 4, 0);
let result = carver.carve(&img, 4, 4, 1);
assert!(matches!(result, Err(ScalingError::InvalidDimensions(_))));
}
#[test]
fn error_insufficient_buffer() {
let config = SeamCarvingConfig::new(2, 2);
let carver = SeamCarver::new(config);
let short = vec![0u8; 3]; let result = carver.carve(&short, 4, 4, 1);
assert!(matches!(
result,
Err(ScalingError::InsufficientBuffer { .. })
));
}
#[test]
fn error_unsupported_channels() {
let config = SeamCarvingConfig::new(2, 2);
let carver = SeamCarver::new(config);
let img = vec![0u8; 16 * 2]; let result = carver.carve(&img, 4, 4, 2);
assert!(matches!(result, Err(ScalingError::UnsupportedChannels(2))));
}
#[test]
fn identity_no_seams_removed() {
let config = SeamCarvingConfig::new(4, 4);
let carver = SeamCarver::new(config);
let img = horizontal_gradient(4, 4);
let (out, w, h) = carver.carve(&img, 4, 4, 1).expect("carve should succeed");
assert_eq!(w, 4);
assert_eq!(h, 4);
assert_eq!(out, img);
}
#[test]
fn reduce_width_by_one_gray() {
let width = 6usize;
let height = 4usize;
let config = SeamCarvingConfig::new((width - 1) as u32, height as u32);
let carver = SeamCarver::new(config);
let img = horizontal_gradient(width, height);
let (out, w, h) = carver
.carve(&img, width as u32, height as u32, 1)
.expect("carve should succeed");
assert_eq!(w as usize, width - 1);
assert_eq!(h as usize, height);
assert_eq!(out.len(), (width - 1) * height);
}
#[test]
fn reduce_width_by_two_rgb() {
let width = 8usize;
let height = 6usize;
let config = SeamCarvingConfig::new((width - 2) as u32, height as u32);
let carver = SeamCarver::new(config);
let img = solid_rgb(width, height, 200, 100, 50);
let (out, w, h) = carver
.carve(&img, width as u32, height as u32, 3)
.expect("carve should succeed");
assert_eq!(w as usize, width - 2);
assert_eq!(h as usize, height);
assert_eq!(out.len(), (width - 2) * height * 3);
}
#[test]
fn reduce_height_by_one_gray() {
let width = 6usize;
let height = 6usize;
let config = SeamCarvingConfig::new(width as u32, (height - 1) as u32);
let carver = SeamCarver::new(config);
let img = horizontal_gradient(width, height);
let (out, w, h) = carver
.carve(&img, width as u32, height as u32, 1)
.expect("carve should succeed");
assert_eq!(w as usize, width);
assert_eq!(h as usize, height - 1);
assert_eq!(out.len(), width * (height - 1));
}
#[test]
fn reduce_both_dimensions() {
let width = 10usize;
let height = 8usize;
let tw = 7u32;
let th = 5u32;
let config = SeamCarvingConfig::new(tw, th);
let carver = SeamCarver::new(config);
let img = horizontal_gradient(width, height);
let (out, w, h) = carver
.carve(&img, width as u32, height as u32, 1)
.expect("carve should succeed");
assert_eq!(w, tw);
assert_eq!(h, th);
assert_eq!(out.len(), tw as usize * th as usize);
}
#[test]
fn forward_energy_reduces_width() {
let width = 8usize;
let height = 6usize;
let config = SeamCarvingConfig::new((width - 2) as u32, height as u32)
.with_energy_function(EnergyFunction::ForwardEnergy);
let carver = SeamCarver::new(config);
let img = horizontal_gradient(width, height);
let (out, w, h) = carver
.carve(&img, width as u32, height as u32, 1)
.expect("forward energy carve should succeed");
assert_eq!(w as usize, width - 2);
assert_eq!(h as usize, height);
assert_eq!(out.len(), (width - 2) * height);
}
#[test]
fn gradient_energy_on_solid_image_removes_any_seam() {
let width = 5usize;
let height = 4usize;
let config = SeamCarvingConfig::new((width - 1) as u32, height as u32)
.with_energy_function(EnergyFunction::Gradient);
let carver = SeamCarver::new(config);
let img = solid_gray(width, height, 128);
let (out, w, h) = carver
.carve(&img, width as u32, height as u32, 1)
.expect("solid gradient carve should succeed");
assert_eq!(w as usize, width - 1);
assert_eq!(h as usize, height);
assert_eq!(out.len(), (width - 1) * height);
assert!(out.iter().all(|&v| v == 128), "all pixels must remain 128");
}
#[test]
fn reduce_width_rgba() {
let width = 6usize;
let height = 4usize;
let config = SeamCarvingConfig::new((width - 1) as u32, height as u32);
let carver = SeamCarver::new(config);
let mut img = vec![0u8; width * height * 4];
for i in 0..width * height {
img[i * 4] = (i % 255) as u8;
img[i * 4 + 1] = 100;
img[i * 4 + 2] = 50;
img[i * 4 + 3] = 255;
}
let (out, w, h) = carver
.carve(&img, width as u32, height as u32, 4)
.expect("RGBA carve should succeed");
assert_eq!(w as usize, width - 1);
assert_eq!(h as usize, height);
assert_eq!(out.len(), (width - 1) * height * 4);
}
#[test]
fn protect_mask_prevents_column_removal() {
let width = 6usize;
let height = 4usize;
let mut img = vec![0u8; width * height];
for row in 0..height {
img[row * width] = 255;
}
let mut protect = vec![0u8; width * height];
for row in 0..height {
protect[row * width] = 1; }
let config =
SeamCarvingConfig::new((width - 1) as u32, height as u32).with_protect_mask(protect);
let carver = SeamCarver::new(config);
let (out, w, h) = carver
.carve(&img, width as u32, height as u32, 1)
.expect("protect mask carve should succeed");
assert_eq!(w as usize, width - 1);
assert_eq!(h as usize, height);
for row in 0..h as usize {
assert_eq!(
out[row * w as usize],
255,
"row {row}: protected column should remain at leftmost position"
);
}
}
#[test]
fn remove_mask_prefers_masked_column() {
let width = 6usize;
let height = 4usize;
let mut img = vec![128u8; width * height];
for row in 0..height {
img[row * width + (width - 1)] = 200;
}
let mut removal = vec![0u8; width * height];
for row in 0..height {
removal[row * width + (width - 1)] = 1;
}
let config =
SeamCarvingConfig::new((width - 1) as u32, height as u32).with_remove_mask(removal);
let carver = SeamCarver::new(config);
let (out, w, _h) = carver
.carve(&img, width as u32, height as u32, 1)
.expect("remove mask carve should succeed");
let has_removed_col_val = out.iter().any(|&v| v == 200);
assert!(
!has_removed_col_val || w as usize == width - 1,
"removal mask should have forced those pixels out"
);
}
#[test]
fn config_builder() {
let mask = vec![0u8; 16];
let config = SeamCarvingConfig::new(3, 3)
.with_energy_function(EnergyFunction::ForwardEnergy)
.with_protect_mask(mask.clone())
.with_remove_mask(mask);
assert_eq!(config.target_width, 3);
assert_eq!(config.target_height, 3);
assert_eq!(config.energy_function, EnergyFunction::ForwardEnergy);
assert!(config.protect_mask.is_some());
assert!(config.remove_mask.is_some());
}
#[test]
fn to_grayscale_single_channel() {
let pixels = vec![10u8, 20, 30, 40];
let gray = to_grayscale(&pixels, 2, 2, 1);
assert_eq!(gray, vec![10.0, 20.0, 30.0, 40.0]);
}
#[test]
fn to_grayscale_rgb() {
let pixels = vec![255u8, 0, 0];
let gray = to_grayscale(&pixels, 1, 1, 3);
let expected = 0.299 * 255.0;
assert!((gray[0] - expected).abs() < 0.01, "gray={}", gray[0]);
}
#[test]
fn find_vertical_seam_prefers_minimum_energy() {
let energy = vec![100.0f64, 1.0, 100.0, 100.0, 1.0, 100.0];
let seam = find_vertical_seam(&energy, 3, 2);
assert_eq!(seam, vec![1, 1], "seam should run through column 1");
}
#[test]
fn transpose_roundtrip() {
let width = 3usize;
let height = 2usize;
let channels = 3usize;
let img: Vec<u8> = (0..(width * height * channels) as u8).collect();
let transposed = transpose_image(&img, width, height, channels);
let back = transpose_image(&transposed, height, width, channels);
assert_eq!(back, img, "double transpose should be identity");
}
#[test]
fn remove_vertical_seam_correct_size() {
let width = 4usize;
let height = 3usize;
let channels = 1usize;
let img = vec![0u8; width * height * channels];
let seam = vec![1usize, 2, 0]; let out = remove_vertical_seam(&img, width, height, channels, &seam);
assert_eq!(out.len(), (width - 1) * height * channels);
}
#[test]
fn gradient_energy_flat_image_is_zero() {
let gray = vec![128.0f64; 4 * 4];
let e = gradient_energy(&gray, 4, 4);
for v in e {
assert!(v.abs() < 1e-9, "flat image energy should be zero");
}
}
#[test]
fn forward_energy_flat_image_is_zero() {
let gray = vec![64.0f64; 4 * 4];
let e = forward_energy(&gray, 4, 4);
for v in e {
assert!(v.abs() < 1e-9, "flat image forward energy should be zero");
}
}
}