use chess_corners_core::{CornerDescriptor, ImageView};
use serde::{Deserialize, Serialize};
/// Upscaling setting: either off, or a fixed integer scale factor.
///
/// Serialized with `snake_case` variant names (`disabled`, `fixed`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum UpscaleConfig {
    /// No upscaling; the effective factor is 1. This is the default.
    #[default]
    Disabled,
    /// Upscale by the contained integer factor. Only 2, 3 and 4 are accepted by
    /// [`UpscaleConfig::validate`].
    Fixed(u32),
}
impl UpscaleConfig {
pub fn disabled() -> Self {
Self::Disabled
}
pub fn fixed(factor: u32) -> Self {
Self::Fixed(factor)
}
#[inline]
pub fn effective_factor(&self) -> u32 {
match *self {
Self::Disabled => 1,
Self::Fixed(k) => k,
}
}
pub fn validate(&self) -> Result<(), UpscaleError> {
match *self {
Self::Disabled => Ok(()),
Self::Fixed(2..=4) => Ok(()),
Self::Fixed(k) => Err(UpscaleError::InvalidFactor(k)),
}
}
}
/// Errors reported by [`UpscaleConfig::validate`] and the bilinear upscaler.
#[derive(Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum UpscaleError {
    /// The requested factor is not one of the supported values (2, 3 or 4).
    InvalidFactor(u32),
    /// The upscaled dimensions do not fit in `usize`.
    DimensionOverflow {
        /// Source `(width, height)` in pixels.
        src: (usize, usize),
        /// Requested upscale factor.
        factor: u32,
    },
    /// The source buffer length differs from `width * height`.
    DimensionMismatch {
        /// Length of the buffer that was supplied.
        actual: usize,
        /// Length implied by the source dimensions.
        expected: usize,
    },
}

impl core::fmt::Display for UpscaleError {
    /// Writes a human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match *self {
            UpscaleError::InvalidFactor(factor) => write!(
                f,
                "upscale factor {factor} not supported (expected 2, 3, or 4)"
            ),
            UpscaleError::DimensionOverflow {
                src: (w, h),
                factor,
            } => write!(
                f,
                "upscaled dimensions overflow: {w}x{h} * {factor} exceeds usize"
            ),
            UpscaleError::DimensionMismatch { actual, expected } => write!(
                f,
                "image buffer length mismatch: expected {expected} bytes (src_w*src_h), got {actual}"
            ),
        }
    }
}

impl std::error::Error for UpscaleError {}
/// Reusable output storage for [`upscale_bilinear_u8`].
///
/// The backing vector only ever grows, so repeated requests for the same or a
/// smaller size reuse the existing allocation.
#[derive(Debug, Default, Clone)]
pub struct UpscaleBuffers {
    /// Pixel storage; may be longer than `w * h` after an earlier, larger request.
    buf: Vec<u8>,
    /// Width recorded by the most recent `ensure` call.
    w: usize,
    /// Height recorded by the most recent `ensure` call.
    h: usize,
}

impl UpscaleBuffers {
    /// Creates an empty buffer set.
    pub fn new() -> Self {
        UpscaleBuffers::default()
    }

    /// Records `w x h` as the current dimensions and grows the storage to at least
    /// `w * h` bytes (saturating on overflow). Newly added bytes are zero; existing
    /// bytes and any extra length are kept.
    fn ensure(&mut self, w: usize, h: usize) {
        (self.w, self.h) = (w, h);
        let required = w.saturating_mul(h);
        if required > self.buf.len() {
            self.buf.resize(required, 0);
        }
    }

    /// Width, in pixels, recorded by the most recent request.
    pub fn width(&self) -> usize {
        self.w
    }

    /// Height, in pixels, recorded by the most recent request.
    pub fn height(&self) -> usize {
        self.h
    }
}
/// Upscales a single-channel 8-bit image by an integer `factor` using bilinear
/// interpolation.
///
/// `src` is read row-major as `src_h` rows of `src_w` bytes. Output pixel `(x, y)`
/// samples the source at `((x + 0.5) / factor - 0.5, (y + 0.5) / factor - 0.5)`
/// (pixel-centre alignment). Neighbours outside the image are clamped to the nearest
/// edge pixel, and the result is rounded half-up to the nearest integer.
/// [`rescale_descriptors_to_input`] applies the inverse of this coordinate mapping.
///
/// The output is written into `buffers`, whose storage grows as needed and is reused
/// across calls. The returned view borrows that storage and is
/// `(src_w * factor) x (src_h * factor)` pixels. A zero-sized source yields a
/// zero-sized view.
///
/// # Errors
///
/// - [`UpscaleError::InvalidFactor`] if `factor` is not 2, 3 or 4.
/// - [`UpscaleError::DimensionOverflow`] if the upscaled width, height or total pixel
///   count does not fit in `usize`.
/// - [`UpscaleError::DimensionMismatch`] if `src.len() != src_w * src_h`.
///
/// # Panics
///
/// Only if `ImageView::from_u8_slice` rejects a `dst_w x dst_h` slice of exactly
/// `dst_w * dst_h` bytes, which is not expected.
pub fn upscale_bilinear_u8<'a>(
    src: &[u8],
    src_w: usize,
    src_h: usize,
    factor: u32,
    buffers: &'a mut UpscaleBuffers,
) -> Result<ImageView<'a>, UpscaleError> {
    if !matches!(factor, 2..=4) {
        return Err(UpscaleError::InvalidFactor(factor));
    }
    let overflow = || UpscaleError::DimensionOverflow {
        src: (src_w, src_h),
        factor,
    };
    let k = factor as usize;
    let dst_w = src_w.checked_mul(k).ok_or_else(overflow)?;
    let dst_h = src_h.checked_mul(k).ok_or_else(overflow)?;
    // Checking the total output size keeps the slicing below from overflowing.
    // Because it is k² times larger, it also proves `src_w * src_h` fits in `usize`.
    // Without this check, an absurd size could wrap in release builds, pass the
    // length check, and end in a failed huge allocation instead of an error.
    let dst_len = dst_w.checked_mul(dst_h).ok_or_else(overflow)?;
    let expected = src_w * src_h;
    if src.len() != expected {
        return Err(UpscaleError::DimensionMismatch {
            actual: src.len(),
            expected,
        });
    }

    buffers.ensure(dst_w, dst_h);
    // An empty source has no pixels to sample; fall through to the empty view.
    if dst_len > 0 {
        let inv_k = 1.0f32 / factor as f32;
        // Horizontal taps are the same for every output row, so compute them once.
        let x_taps: Vec<(usize, usize, f32)> = (0..dst_w)
            .map(|x_out| upscale_tap(x_out, inv_k, src_w))
            .collect();
        let out = &mut buffers.buf[..dst_len];
        for (y_out, dst_row) in out.chunks_exact_mut(dst_w).enumerate() {
            let (y0, y1, wy) = upscale_tap(y_out, inv_k, src_h);
            let row0 = &src[y0 * src_w..][..src_w];
            let row1 = &src[y1 * src_w..][..src_w];
            for (px, &(x0, x1, wx)) in dst_row.iter_mut().zip(&x_taps) {
                let i00 = f32::from(row0[x0]);
                let i10 = f32::from(row0[x1]);
                let i01 = f32::from(row1[x0]);
                let i11 = f32::from(row1[x1]);
                let top = i00 + (i10 - i00) * wx;
                let bot = i01 + (i11 - i01) * wx;
                let v = top + (bot - top) * wy;
                // +0.5 followed by the truncating cast rounds half-up; the clamp
                // keeps the value inside the u8 range before the cast.
                *px = (v + 0.5).clamp(0.0, 255.0) as u8;
            }
        }
    }

    let slice = &buffers.buf[..dst_len];
    Ok(ImageView::from_u8_slice(dst_w, dst_h, slice).expect("dims match"))
}

/// Maps output index `out` back to the source along one axis.
///
/// Returns the two neighbouring source indices, both clamped to `0..len`, and the
/// interpolation weight of the second one. Requires `len > 0`.
fn upscale_tap(out: usize, inv_k: f32, len: usize) -> (usize, usize, f32) {
    let pos = (out as f32 + 0.5) * inv_k - 0.5;
    let lo = pos.floor() as i64;
    let weight = pos - lo as f32;
    let max = len as i64 - 1;
    (
        lo.clamp(0, max) as usize,
        (lo + 1).clamp(0, max) as usize,
        weight,
    )
}
/// Maps corner positions detected in an image upscaled by `factor` back into the
/// coordinate frame of the original input.
///
/// This inverts the pixel-centre mapping of [`upscale_bilinear_u8`]: each of `x` and
/// `y` becomes `(v + 0.5) / factor - 0.5`. No other descriptor field is modified.
/// A `factor` of 0 or 1 leaves the descriptors unchanged.
pub fn rescale_descriptors_to_input(descriptors: &mut [CornerDescriptor], factor: u32) {
    if factor > 1 {
        let scale = 1.0f32 / factor as f32;
        // (v + 0.5) * scale - 0.5 rewritten as v * scale - offset.
        let offset = 0.5 * (1.0 - scale);
        for corner in descriptors {
            corner.x = corner.x * scale - offset;
            corner.y = corner.y * scale - offset;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use chess_corners_core::AxisEstimate;

    /// Builds a descriptor at `(x, y)`; the remaining constructor arguments are fixed
    /// placeholder values that these tests do not inspect.
    fn desc(x: f32, y: f32) -> CornerDescriptor {
        CornerDescriptor::new(
            x,
            y,
            1.0,
            0.0,
            0.0,
            [AxisEstimate::new(0.0, 0.0), AxisEstimate::new(0.0, 0.0)],
        )
    }

    #[test]
    fn config_default_is_disabled() {
        let cfg = UpscaleConfig::default();
        assert_eq!(cfg, UpscaleConfig::Disabled);
        assert_eq!(cfg.effective_factor(), 1);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn config_rejects_invalid_factors() {
        for bad in [0u32, 1, 5, 8] {
            let cfg = UpscaleConfig::fixed(bad);
            assert_eq!(cfg.validate(), Err(UpscaleError::InvalidFactor(bad)));
        }
    }

    #[test]
    fn config_accepts_valid_factors() {
        for good in [2u32, 3, 4] {
            let cfg = UpscaleConfig::fixed(good);
            assert!(cfg.validate().is_ok());
            assert_eq!(cfg.effective_factor(), good);
        }
    }

    #[test]
    fn disabled_round_trips_through_serde() {
        let cfg = UpscaleConfig::Disabled;
        let json = serde_json::to_string(&cfg).expect("serialize disabled");
        assert!(json.contains("disabled"));
        let decoded: UpscaleConfig = serde_json::from_str(&json).expect("deserialize disabled");
        assert_eq!(decoded, cfg);
    }

    #[test]
    fn fixed_round_trips_through_serde() {
        let cfg = UpscaleConfig::Fixed(3);
        let json = serde_json::to_string(&cfg).expect("serialize fixed");
        assert!(json.contains("fixed"));
        let decoded: UpscaleConfig = serde_json::from_str(&json).expect("deserialize fixed");
        assert_eq!(decoded, cfg);
    }

    #[test]
    fn upscale_factor_2_uniform_image_is_uniform() {
        let src = vec![42u8; 8 * 6];
        let mut buffers = UpscaleBuffers::new();
        let view = upscale_bilinear_u8(&src, 8, 6, 2, &mut buffers).unwrap();
        assert_eq!(view.width, 16);
        assert_eq!(view.height, 12);
        assert!(view.data.iter().all(|&v| v == 42));
    }

    #[test]
    fn upscale_factor_2_of_1x1_fills_buffer() {
        let src = [77u8];
        let mut buffers = UpscaleBuffers::new();
        let view = upscale_bilinear_u8(&src, 1, 1, 2, &mut buffers).unwrap();
        assert_eq!(view.width, 2);
        assert_eq!(view.height, 2);
        assert!(view.data.iter().all(|&v| v == 77));
    }

    #[test]
    fn upscale_factor_2_matches_hand_computed_values() {
        // Source row [0, 100]. With pixel-centre alignment the four output columns
        // sample x = -0.25, 0.25, 0.75 and 1.25, giving 0, 25, 75 and 100 after
        // edge clamping.
        let src = [0u8, 100];
        let mut buffers = UpscaleBuffers::new();
        let view = upscale_bilinear_u8(&src, 2, 1, 2, &mut buffers).unwrap();
        assert_eq!((view.width, view.height), (4, 2));
        assert_eq!(&view.data[..], &[0u8, 25, 75, 100, 0, 25, 75, 100][..]);
    }

    #[test]
    fn upscale_preserves_linear_gradient_factor_2() {
        // Horizontal ramp 0, 10, ..., 70 repeated over three rows. Every output row
        // must stay non-decreasing, allowing 1 LSB of rounding slack.
        let ramp: Vec<u8> = (0..8).map(|i| i * 10).collect();
        let src = ramp.repeat(3);
        let mut buffers = UpscaleBuffers::new();
        let view = upscale_bilinear_u8(&src, 8, 3, 2, &mut buffers).unwrap();
        for r in 0..view.height {
            let row = &view.data[r * view.width..(r + 1) * view.width];
            for w in row.windows(2) {
                assert!(w[1] >= w[0].saturating_sub(1), "non-monotonic row: {row:?}");
            }
        }
    }

    #[test]
    fn upscale_factor_3_triples_dimensions() {
        let src = vec![128u8; 5 * 4];
        let mut buffers = UpscaleBuffers::new();
        let view = upscale_bilinear_u8(&src, 5, 4, 3, &mut buffers).unwrap();
        assert_eq!(view.width, 15);
        assert_eq!(view.height, 12);
        assert_eq!(view.data.len(), 180);
    }

    #[test]
    fn buffers_are_reused_across_calls() {
        let src1 = vec![10u8; 4 * 4];
        let src2 = vec![200u8; 4 * 4];
        let mut buffers = UpscaleBuffers::new();
        let _ = upscale_bilinear_u8(&src1, 4, 4, 2, &mut buffers).unwrap();
        let cap1 = buffers.buf.capacity();
        let _ = upscale_bilinear_u8(&src2, 4, 4, 2, &mut buffers).unwrap();
        assert_eq!(buffers.buf.capacity(), cap1, "buffer should be reused");
    }

    #[test]
    fn rejects_invalid_factor_at_runtime() {
        let src = vec![0u8; 4];
        let mut buffers = UpscaleBuffers::new();
        let err = upscale_bilinear_u8(&src, 2, 2, 5, &mut buffers).unwrap_err();
        assert_eq!(err, UpscaleError::InvalidFactor(5));
    }

    #[test]
    fn rejects_buffer_length_mismatch() {
        let src = vec![0u8; 3];
        let mut buffers = UpscaleBuffers::new();
        let err = upscale_bilinear_u8(&src, 2, 2, 2, &mut buffers).unwrap_err();
        assert_eq!(
            err,
            UpscaleError::DimensionMismatch {
                actual: 3,
                expected: 4
            }
        );
    }

    #[test]
    fn rescale_inverts_half_pixel_upscale() {
        for &(k, x_src, y_src) in &[
            (2u32, 7.25f32, 3.0f32),
            (3u32, 4.0f32, 8.5f32),
            (4u32, 0.5f32, 12.25f32),
        ] {
            // Forward pixel-centre mapping used by the upscaler.
            let kf = k as f32;
            let x_out = (x_src + 0.5) * kf - 0.5;
            let y_out = (y_src + 0.5) * kf - 0.5;
            let mut d = [desc(x_out, y_out)];
            rescale_descriptors_to_input(&mut d, k);
            assert!(
                (d[0].x - x_src).abs() < 1e-5,
                "k={k}: x {} != expected {x_src}",
                d[0].x
            );
            assert!(
                (d[0].y - y_src).abs() < 1e-5,
                "k={k}: y {} != expected {y_src}",
                d[0].y
            );
        }
    }

    #[test]
    fn rescale_is_noop_for_factor_1() {
        let mut d = [desc(2.5, 3.75)];
        rescale_descriptors_to_input(&mut d, 1);
        assert_eq!(d[0].x, 2.5);
        assert_eq!(d[0].y, 3.75);
    }
}