// proof_engine/ml/upscale.rs
1//! Neural and classical upscaling for textures and glyph maps.
2
3use super::tensor::Tensor;
4use super::model::{Model, Sequential, DenseLayer, Conv2DLayer, Layer};
5
/// Quality preset for upscaling.
///
/// Equality is total for a fieldless enum, so `Eq` and `Hash` are derived
/// alongside `PartialEq` (allows use as a map key and in exhaustive matching).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UpscaleQuality {
    /// Fastest path, lowest fidelity.
    Fast,
    /// Default speed/fidelity trade-off.
    Balanced,
    /// Slowest path, highest fidelity.
    HighQuality,
}
13
/// Configuration for the upscaler.
#[derive(Debug, Clone)]
pub struct UpscaleConfig {
    // Integer scale factor applied to both spatial dimensions (e.g. 2 => 2x).
    pub factor: u32,
    // Optional path to an external model; `None` means no external model.
    // NOTE(review): how this path is loaded is not visible in this file — confirm with callers.
    pub model_path: Option<String>,
    // Speed/fidelity preset; see `UpscaleQuality`.
    pub quality: UpscaleQuality,
}
21
22impl Default for UpscaleConfig {
23    fn default() -> Self {
24        Self { factor: 2, model_path: None, quality: UpscaleQuality::Balanced }
25    }
26}
27
/// Neural upscaler wrapping a model.
pub struct Upscaler {
    // Refinement model applied to the flattened, bilinearly pre-upscaled input.
    pub model: Model,
    // Integer spatial scale factor applied to H and W.
    pub scale_factor: u32,
}
33
34impl Upscaler {
35    pub fn new(model: Model, scale_factor: u32) -> Self {
36        Self { model, scale_factor }
37    }
38
39    /// Run the upscaling model on input. Input shape: (C, H, W).
40    /// Output shape: (C, H*factor, W*factor).
41    pub fn upscale(&self, input: &Tensor) -> Tensor {
42        assert_eq!(input.shape.len(), 3);
43        // First bilinear upscale to target size, then refine with model
44        let upscaled = bilinear_upscale(input, self.scale_factor);
45        // Flatten, run through model, reshape back
46        let c = upscaled.shape[0];
47        let h = upscaled.shape[1];
48        let w = upscaled.shape[2];
49        let flat = upscaled.flatten();
50        let refined = self.model.forward(&flat);
51        // Clamp to valid range and reshape
52        let data: Vec<f32> = refined.data.iter().map(|&v| v.clamp(0.0, 1.0)).collect();
53        if data.len() == c * h * w {
54            Tensor { shape: vec![c, h, w], data }
55        } else {
56            // If model output size doesn't match, return bilinear result
57            upscaled
58        }
59    }
60}
61
62/// Bilinear upscaling fallback. Input shape: (C, H, W).
63pub fn bilinear_upscale(input: &Tensor, factor: u32) -> Tensor {
64    assert_eq!(input.shape.len(), 3);
65    let c = input.shape[0];
66    let h = input.shape[1];
67    let w = input.shape[2];
68    let f = factor as usize;
69    let new_h = h * f;
70    let new_w = w * f;
71    let mut data = vec![0.0f32; c * new_h * new_w];
72
73    for ch in 0..c {
74        for ny in 0..new_h {
75            for nx in 0..new_w {
76                let src_y = ny as f32 / f as f32;
77                let src_x = nx as f32 / f as f32;
78
79                let y0 = (src_y.floor() as usize).min(h - 1);
80                let y1 = (y0 + 1).min(h - 1);
81                let x0 = (src_x.floor() as usize).min(w - 1);
82                let x1 = (x0 + 1).min(w - 1);
83
84                let fy = src_y - src_y.floor();
85                let fx = src_x - src_x.floor();
86
87                let v00 = input.data[ch * h * w + y0 * w + x0];
88                let v01 = input.data[ch * h * w + y0 * w + x1];
89                let v10 = input.data[ch * h * w + y1 * w + x0];
90                let v11 = input.data[ch * h * w + y1 * w + x1];
91
92                let val = v00 * (1.0 - fy) * (1.0 - fx)
93                    + v01 * (1.0 - fy) * fx
94                    + v10 * fy * (1.0 - fx)
95                    + v11 * fy * fx;
96
97                data[ch * new_h * new_w + ny * new_w + nx] = val;
98            }
99        }
100    }
101    Tensor { shape: vec![c, new_h, new_w], data }
102}
103
/// Bicubic upscaling fallback. Input shape: (C, H, W).
/// Output shape: (C, H*factor, W*factor). Uses the Keys cubic-convolution
/// kernel with a = -0.5 (Catmull-Rom-like); border samples are clamped.
pub fn bicubic_upscale(input: &Tensor, factor: u32) -> Tensor {
    assert_eq!(input.shape.len(), 3);
    let c = input.shape[0];
    let h = input.shape[1];
    let w = input.shape[2];
    let f = factor as usize;
    let new_h = h * f;
    let new_w = w * f;
    let mut data = vec![0.0f32; c * new_h * new_w];

    // Cubic interpolation kernel: returns the four tap weights for source
    // offsets -1, 0, +1, +2 at fractional position t in [0, 1).
    // These are the expanded forms of the Keys kernel W(x) with a = -0.5;
    // the weights sum to 1 for any t, so a constant image stays constant.
    fn cubic(t: f32) -> [f32; 4] {
        let a = -0.5f32;
        let t2 = t * t;
        let t3 = t2 * t;
        [
            a * t3 - 2.0 * a * t2 + a * t,
            (a + 2.0) * t3 - (a + 3.0) * t2 + 1.0,
            -(a + 2.0) * t3 + (2.0 * a + 3.0) * t2 - a * t,
            -a * t3 + a * t2,
        ]
    }

    // Clamp a possibly-negative source index into [0, max-1] (border replication).
    fn clamp_idx(v: isize, max: usize) -> usize {
        v.max(0).min(max as isize - 1) as usize
    }

    for ch in 0..c {
        for ny in 0..new_h {
            for nx in 0..new_w {
                // Map the output pixel back into source coordinates.
                let src_y = ny as f32 / f as f32;
                let src_x = nx as f32 / f as f32;

                // Integer cell and fractional offset within it.
                let iy = src_y.floor() as isize;
                let ix = src_x.floor() as isize;
                let fy = src_y - src_y.floor();
                let fx = src_x - src_x.floor();

                let wy = cubic(fy);
                let wx = cubic(fx);

                // Separable 4x4 weighted sum over the neighbourhood
                // (rows iy-1..iy+2, columns ix-1..ix+2, clamped at borders).
                let mut val = 0.0f32;
                for dy in 0..4isize {
                    for dx in 0..4isize {
                        let sy = clamp_idx(iy + dy - 1, h);
                        let sx = clamp_idx(ix + dx - 1, w);
                        val += wy[dy as usize] * wx[dx as usize]
                            * input.data[ch * h * w + sy * w + sx];
                    }
                }
                data[ch * new_h * new_w + ny * new_w + nx] = val;
            }
        }
    }
    Tensor { shape: vec![c, new_h, new_w], data }
}
161
162/// Create a simple ESPCN-style upscaler. The model learns a mapping from
163/// low-res features to high-res via sub-pixel convolution (simulated with dense layers).
164pub fn create_simple_upscaler(factor: u32) -> Upscaler {
165    // For a simple upscaler, we use dense layers that map
166    // flattened low-res input to flattened high-res output.
167    // In practice this would be a convolutional model, but
168    // we approximate with dense layers for simplicity.
169    let f = factor as usize;
170    // Assume small patches: the model processes the entire flattened image.
171    // We build a generic model; actual I/O sizes depend on usage.
172    let model = Sequential::new("espcn_upscaler")
173        .dense(64, 128)
174        .relu()
175        .dense(128, 256)
176        .relu()
177        .dense(256, 256)
178        .build();
179    Upscaler::new(model, factor)
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    // Bilinear 2x on a 1x2x2 tensor produces a 1x4x4 tensor.
    #[test]
    fn test_bilinear_upscale_shape() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![1, 2, 2]);
        let up = bilinear_upscale(&input, 2);
        assert_eq!(up.shape, vec![1, 4, 4]);
    }

    // Output pixel (0,0) maps exactly onto source pixel (0,0), so the
    // corner value should be preserved (here: 0.0).
    #[test]
    fn test_bilinear_upscale_corners() {
        let input = Tensor::from_vec(vec![0.0, 1.0, 0.0, 1.0], vec![1, 2, 2]);
        let up = bilinear_upscale(&input, 2);
        // Top-left corner should be close to 0.0
        assert!(up.get(&[0, 0, 0]).abs() < 0.01);
    }

    // Bicubic 3x on a 1x2x2 tensor produces a 1x6x6 tensor.
    #[test]
    fn test_bicubic_upscale_shape() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![1, 2, 2]);
        let up = bicubic_upscale(&input, 3);
        assert_eq!(up.shape, vec![1, 6, 6]);
    }

    // The cubic kernel weights sum to 1, so a constant image must upscale
    // to (approximately) the same constant everywhere.
    #[test]
    fn test_bicubic_constant_input() {
        // Constant image should upscale to constant
        let input = Tensor::from_vec(vec![0.5; 9], vec![1, 3, 3]);
        let up = bicubic_upscale(&input, 2);
        for &v in &up.data {
            assert!((v - 0.5).abs() < 0.1, "bicubic of constant deviated: {v}");
        }
    }

    // Factory stores the scale factor and builds a non-empty model.
    #[test]
    fn test_create_simple_upscaler() {
        let upscaler = create_simple_upscaler(2);
        assert_eq!(upscaler.scale_factor, 2);
        assert!(upscaler.model.parameter_count() > 0);
    }

    // End-to-end: output shape is correct even when the model output is
    // rejected and the bilinear fallback is returned.
    #[test]
    fn test_upscaler_upscale() {
        // The neural upscaler may not produce perfect results with random weights,
        // but it should not panic and output the correct shape (via bilinear fallback).
        let upscaler = create_simple_upscaler(2);
        let input = Tensor::rand(vec![1, 4, 4], 42);
        let out = upscaler.upscale(&input);
        assert_eq!(out.shape, vec![1, 8, 8]);
    }

    // Default config: factor 2, balanced quality, no model path.
    #[test]
    fn test_upscale_config_default() {
        let cfg = UpscaleConfig::default();
        assert_eq!(cfg.factor, 2);
        assert_eq!(cfg.quality, UpscaleQuality::Balanced);
        assert!(cfg.model_path.is_none());
    }

    // Bilinear handles more than one channel; each channel scales independently.
    #[test]
    fn test_bilinear_multichannel() {
        let input = Tensor::rand(vec![3, 4, 4], 123);
        let up = bilinear_upscale(&input, 2);
        assert_eq!(up.shape, vec![3, 8, 8]);
    }
}
249}