1use super::tensor::Tensor;
4use super::model::{Model, Sequential, DenseLayer, Conv2DLayer, Layer};
5
/// Requested speed/quality trade-off for an upscaling operation.
///
/// This is a plain tag; nothing in this section maps a variant to a specific
/// interpolation method. `Balanced` is the value chosen by
/// `UpscaleConfig::default()`.
///
/// `Eq` and `Hash` are derived so the tag can be used as a map/set key
/// (e.g. to cache one model per quality level).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UpscaleQuality {
    /// Fastest setting (NOTE(review): semantics are defined by callers).
    Fast,
    /// Middle ground; the default in `UpscaleConfig`.
    Balanced,
    /// Highest-quality setting (NOTE(review): semantics are defined by callers).
    HighQuality,
}
13
14#[derive(Debug, Clone)]
16pub struct UpscaleConfig {
17 pub factor: u32,
18 pub model_path: Option<String>,
19 pub quality: UpscaleQuality,
20}
21
22impl Default for UpscaleConfig {
23 fn default() -> Self {
24 Self { factor: 2, model_path: None, quality: UpscaleQuality::Balanced }
25 }
26}
27
28pub struct Upscaler {
30 pub model: Model,
31 pub scale_factor: u32,
32}
33
34impl Upscaler {
35 pub fn new(model: Model, scale_factor: u32) -> Self {
36 Self { model, scale_factor }
37 }
38
39 pub fn upscale(&self, input: &Tensor) -> Tensor {
42 assert_eq!(input.shape.len(), 3);
43 let upscaled = bilinear_upscale(input, self.scale_factor);
45 let c = upscaled.shape[0];
47 let h = upscaled.shape[1];
48 let w = upscaled.shape[2];
49 let flat = upscaled.flatten();
50 let refined = self.model.forward(&flat);
51 let data: Vec<f32> = refined.data.iter().map(|&v| v.clamp(0.0, 1.0)).collect();
53 if data.len() == c * h * w {
54 Tensor { shape: vec![c, h, w], data }
55 } else {
56 upscaled
58 }
59 }
60}
61
62pub fn bilinear_upscale(input: &Tensor, factor: u32) -> Tensor {
64 assert_eq!(input.shape.len(), 3);
65 let c = input.shape[0];
66 let h = input.shape[1];
67 let w = input.shape[2];
68 let f = factor as usize;
69 let new_h = h * f;
70 let new_w = w * f;
71 let mut data = vec![0.0f32; c * new_h * new_w];
72
73 for ch in 0..c {
74 for ny in 0..new_h {
75 for nx in 0..new_w {
76 let src_y = ny as f32 / f as f32;
77 let src_x = nx as f32 / f as f32;
78
79 let y0 = (src_y.floor() as usize).min(h - 1);
80 let y1 = (y0 + 1).min(h - 1);
81 let x0 = (src_x.floor() as usize).min(w - 1);
82 let x1 = (x0 + 1).min(w - 1);
83
84 let fy = src_y - src_y.floor();
85 let fx = src_x - src_x.floor();
86
87 let v00 = input.data[ch * h * w + y0 * w + x0];
88 let v01 = input.data[ch * h * w + y0 * w + x1];
89 let v10 = input.data[ch * h * w + y1 * w + x0];
90 let v11 = input.data[ch * h * w + y1 * w + x1];
91
92 let val = v00 * (1.0 - fy) * (1.0 - fx)
93 + v01 * (1.0 - fy) * fx
94 + v10 * fy * (1.0 - fx)
95 + v11 * fy * fx;
96
97 data[ch * new_h * new_w + ny * new_w + nx] = val;
98 }
99 }
100 }
101 Tensor { shape: vec![c, new_h, new_w], data }
102}
103
104pub fn bicubic_upscale(input: &Tensor, factor: u32) -> Tensor {
106 assert_eq!(input.shape.len(), 3);
107 let c = input.shape[0];
108 let h = input.shape[1];
109 let w = input.shape[2];
110 let f = factor as usize;
111 let new_h = h * f;
112 let new_w = w * f;
113 let mut data = vec![0.0f32; c * new_h * new_w];
114
115 fn cubic(t: f32) -> [f32; 4] {
117 let a = -0.5f32;
118 let t2 = t * t;
119 let t3 = t2 * t;
120 [
121 a * t3 - 2.0 * a * t2 + a * t,
122 (a + 2.0) * t3 - (a + 3.0) * t2 + 1.0,
123 -(a + 2.0) * t3 + (2.0 * a + 3.0) * t2 - a * t,
124 -a * t3 + a * t2,
125 ]
126 }
127
128 fn clamp_idx(v: isize, max: usize) -> usize {
129 v.max(0).min(max as isize - 1) as usize
130 }
131
132 for ch in 0..c {
133 for ny in 0..new_h {
134 for nx in 0..new_w {
135 let src_y = ny as f32 / f as f32;
136 let src_x = nx as f32 / f as f32;
137
138 let iy = src_y.floor() as isize;
139 let ix = src_x.floor() as isize;
140 let fy = src_y - src_y.floor();
141 let fx = src_x - src_x.floor();
142
143 let wy = cubic(fy);
144 let wx = cubic(fx);
145
146 let mut val = 0.0f32;
147 for dy in 0..4isize {
148 for dx in 0..4isize {
149 let sy = clamp_idx(iy + dy - 1, h);
150 let sx = clamp_idx(ix + dx - 1, w);
151 val += wy[dy as usize] * wx[dx as usize]
152 * input.data[ch * h * w + sy * w + sx];
153 }
154 }
155 data[ch * new_h * new_w + ny * new_w + nx] = val;
156 }
157 }
158 }
159 Tensor { shape: vec![c, new_h, new_w], data }
160}
161
162pub fn create_simple_upscaler(factor: u32) -> Upscaler {
165 let f = factor as usize;
170 let model = Sequential::new("espcn_upscaler")
173 .dense(64, 128)
174 .relu()
175 .dense(128, 256)
176 .relu()
177 .dense(256, 256)
178 .build();
179 Upscaler::new(model, factor)
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bilinear_upscale_shape() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![1, 2, 2]);
        let up = bilinear_upscale(&input, 2);
        assert_eq!(up.shape, vec![1, 4, 4]);
    }

    #[test]
    fn test_bilinear_upscale_corners() {
        // Source: left column is 0, right column is 1.
        let input = Tensor::from_vec(vec![0.0, 1.0, 0.0, 1.0], vec![1, 2, 2]);
        let up = bilinear_upscale(&input, 2);
        assert!(up.get(&[0, 0, 0]).abs() < 0.01);
        assert!(up.get(&[0, 3, 0]).abs() < 0.01);
        assert!((up.get(&[0, 0, 3]) - 1.0).abs() < 0.01);
        assert!((up.get(&[0, 3, 3]) - 1.0).abs() < 0.01);
        // Half-way between the two source columns.
        assert!((up.get(&[0, 0, 1]) - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_bilinear_factor_one_is_identity() {
        // With factor 1 every output pixel lands exactly on a source pixel.
        let input = Tensor::rand(vec![2, 3, 5], 7);
        let up = bilinear_upscale(&input, 1);
        assert_eq!(up.shape, input.shape);
        assert_eq!(up.data, input.data);
    }

    #[test]
    fn test_bicubic_upscale_shape() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![1, 2, 2]);
        let up = bicubic_upscale(&input, 3);
        assert_eq!(up.shape, vec![1, 6, 6]);
    }

    #[test]
    fn test_bicubic_constant_input() {
        // The cubic weights sum to one, so a constant image stays constant up
        // to float rounding.
        let input = Tensor::from_vec(vec![0.5; 9], vec![1, 3, 3]);
        let up = bicubic_upscale(&input, 2);
        for &v in &up.data {
            assert!((v - 0.5).abs() < 1e-5, "bicubic of constant deviated: {v}");
        }
    }

    #[test]
    fn test_bicubic_reproduces_source_samples() {
        // Output pixels at multiples of the factor sit exactly on source
        // pixels, where the kernel weights are [0, 1, 0, 0].
        let values: Vec<f32> = (0..9).map(|i| i as f32 * 0.1).collect();
        let input = Tensor::from_vec(values, vec![1, 3, 3]);
        let up = bicubic_upscale(&input, 2);
        for y in 0..3 {
            for x in 0..3 {
                let diff = (up.get(&[0, 2 * y, 2 * x]) - input.get(&[0, y, x])).abs();
                assert!(diff < 1e-6, "sample ({y}, {x}) deviated by {diff}");
            }
        }
    }

    #[test]
    fn test_bicubic_multichannel_shape() {
        let input = Tensor::rand(vec![3, 2, 2], 1);
        let up = bicubic_upscale(&input, 2);
        assert_eq!(up.shape, vec![3, 4, 4]);
    }

    #[test]
    fn test_create_simple_upscaler() {
        let upscaler = create_simple_upscaler(2);
        assert_eq!(upscaler.scale_factor, 2);
        assert!(upscaler.model.parameter_count() > 0);
    }

    #[test]
    fn test_upscaler_upscale() {
        let upscaler = create_simple_upscaler(2);
        let input = Tensor::rand(vec![1, 4, 4], 42);
        let out = upscaler.upscale(&input);
        assert_eq!(out.shape, vec![1, 8, 8]);
    }

    #[test]
    fn test_upscale_config_default() {
        let cfg = UpscaleConfig::default();
        assert_eq!(cfg.factor, 2);
        assert_eq!(cfg.quality, UpscaleQuality::Balanced);
        assert!(cfg.model_path.is_none());
    }

    #[test]
    fn test_bilinear_multichannel() {
        let input = Tensor::rand(vec![3, 4, 4], 123);
        let up = bilinear_upscale(&input, 2);
        assert_eq!(up.shape, vec![3, 8, 8]);
    }
}