1use ndarray::{Array2, Array3, Array4, Axis, s};
4use ort::session::Session;
5use ort::session::builder::GraphOptimizationLevel;
6use ort::value::Value;
7use rayon::prelude::*;
8use serde::{Deserialize, Serialize};
9use std::path::Path;
10use tracing::{debug, info};
11
12use crate::error::{InferenceError, ModelError, Result};
13use oxigdal_core::buffer::RasterBuffer;
14use oxigdal_core::types::RasterDataType;
15
16use super::UpscaleFactor;
17
/// Tiling and scaling parameters used by [`SuperResolution`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SuperResConfig {
    /// Integer factor by which both the raster width and height are multiplied.
    pub scale_factor: usize,
    /// Edge length, in input pixels, of the tiles fed to the model.
    /// Tiles at the right/bottom border may be smaller.
    pub tile_size: usize,
    /// Number of input pixels shared by neighbouring tiles; the shared region
    /// is blended when tiles are merged. The tiler advances by
    /// `tile_size - overlap`, so this should be smaller than `tile_size`.
    pub overlap: usize,
    /// Requested inference batch size.
    // NOTE(review): `SuperResolution` in this file runs tiles one at a time and
    // does not currently read this value.
    pub batch_size: usize,
}
30
31impl SuperResConfig {
32 #[must_use]
51 pub fn new(scale_factor: usize, tile_size: usize, overlap: usize) -> Self {
52 Self {
53 scale_factor,
54 tile_size,
55 overlap,
56 batch_size: 1,
57 }
58 }
59
60 #[must_use]
62 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
63 self.batch_size = batch_size;
64 self
65 }
66}
67
68impl Default for SuperResConfig {
69 fn default() -> Self {
70 Self::new(2, 256, 32)
71 }
72}
73
/// Tiled super-resolution over an ONNX model.
///
/// The input raster is treated as a single band of little-endian `f32`
/// samples, cut into overlapping tiles, upscaled tile by tile, and blended
/// back together.
pub struct SuperResolution {
    /// ONNX Runtime session the model was loaded into.
    session: Session,
    /// Tiling and scaling parameters.
    config: SuperResConfig,
}
79
80impl SuperResolution {
81 pub fn from_file<P: AsRef<Path>>(path: P, config: SuperResConfig) -> Result<Self> {
104 let path = path.as_ref();
105
106 if !path.exists() {
107 return Err(ModelError::NotFound {
108 path: path.display().to_string(),
109 }
110 .into());
111 }
112
113 let session = Session::builder()
114 .map_err(|e: ort::Error| ModelError::InitializationFailed {
115 reason: e.to_string(),
116 })?
117 .with_optimization_level(GraphOptimizationLevel::Level3)
118 .map_err(|e: ort::Error| ModelError::InitializationFailed {
119 reason: e.to_string(),
120 })?
121 .commit_from_file(path)
122 .map_err(|e: ort::Error| ModelError::LoadFailed {
123 reason: e.to_string(),
124 })?;
125
126 info!("Loaded super-resolution model from {}", path.display());
127
128 Ok(Self { session, config })
129 }
130
131 pub fn upscale(&mut self, input: &RasterBuffer) -> Result<RasterBuffer> {
158 let width = input.width() as usize;
159 let height = input.height() as usize;
160
161 debug!(
162 "Starting super-resolution upscaling: {}x{} -> {}x{}",
163 width,
164 height,
165 width * self.config.scale_factor,
166 height * self.config.scale_factor
167 );
168
169 let tiles = self.extract_tiles(input)?;
171
172 let processed_tiles = self.process_batch(&tiles)?;
174
175 let merged = self.merge_tiles(
177 &processed_tiles,
178 width * self.config.scale_factor,
179 height * self.config.scale_factor,
180 )?;
181
182 RasterBuffer::new(
184 merged
185 .as_slice()
186 .ok_or_else(|| InferenceError::OutputParsingFailed {
187 reason: "Failed to convert array to slice".to_string(),
188 })?
189 .iter()
190 .flat_map(|&v: &f32| v.to_le_bytes())
191 .collect(),
192 (width * self.config.scale_factor) as u64,
193 (height * self.config.scale_factor) as u64,
194 RasterDataType::Float32,
195 input.nodata(),
196 )
197 .map_err(Into::into)
198 }
199
200 fn extract_tiles(&self, input: &RasterBuffer) -> Result<Vec<TileInfo>> {
202 let width = input.width() as usize;
203 let height = input.height() as usize;
204 let tile_size = self.config.tile_size;
205 let overlap = self.config.overlap;
206 let stride = tile_size - overlap;
207
208 let mut tiles = Vec::new();
209
210 let mut y = 0;
211 while y < height {
212 let mut x = 0;
213 while x < width {
214 let tile_w = (tile_size).min(width - x);
215 let tile_h = (tile_size).min(height - y);
216
217 tiles.push(TileInfo {
218 x,
219 y,
220 width: tile_w,
221 height: tile_h,
222 data: self.extract_tile_data(input, x, y, tile_w, tile_h)?,
223 });
224
225 if x + tile_w >= width {
226 break;
227 }
228 x += stride;
229 }
230
231 if y + tile_size >= height {
232 break;
233 }
234 y += stride;
235 }
236
237 debug!(
238 "Extracted {} tiles from {}x{} image",
239 tiles.len(),
240 width,
241 height
242 );
243
244 Ok(tiles)
245 }
246
247 fn extract_tile_data(
249 &self,
250 input: &RasterBuffer,
251 x: usize,
252 y: usize,
253 width: usize,
254 height: usize,
255 ) -> Result<Array3<f32>> {
256 let mut tile = Array3::<f32>::zeros((1, height, width));
258
259 for ty in 0..height {
261 for tx in 0..width {
262 let pixel_idx = ((y + ty) * input.width() as usize + (x + tx)) * 4; let bytes = input.as_bytes();
264
265 if pixel_idx + 4 <= bytes.len() {
266 let value = f32::from_le_bytes([
267 bytes[pixel_idx],
268 bytes[pixel_idx + 1],
269 bytes[pixel_idx + 2],
270 bytes[pixel_idx + 3],
271 ]);
272 tile[[0, ty, tx]] = value;
273 }
274 }
275 }
276
277 Ok(tile)
278 }
279
280 fn process_batch(&mut self, tiles: &[TileInfo]) -> Result<Vec<ProcessedTile>> {
282 let mut processed = Vec::with_capacity(tiles.len());
283
284 for tile in tiles {
285 let input_tensor = tile.data.clone().insert_axis(Axis(0));
287
288 let input_value =
290 Value::from_array(input_tensor.clone()).map_err(|e: ort::Error| {
291 InferenceError::Failed {
292 reason: format!("Failed to create input tensor: {}", e),
293 }
294 })?;
295
296 let outputs =
298 self.session
299 .run(ort::inputs![input_value])
300 .map_err(|e: ort::Error| InferenceError::Failed {
301 reason: e.to_string(),
302 })?;
303
304 let output_value = &outputs[0];
306 let output_tensor =
307 output_value
308 .try_extract_tensor::<f32>()
309 .map_err(|e: ort::Error| InferenceError::OutputParsingFailed {
310 reason: e.to_string(),
311 })?;
312
313 let (shape, data) = output_tensor;
315 let shape_vec: Vec<usize> = shape.iter().map(|&d| d as usize).collect();
316
317 let output: Array4<f32> = Array4::from_shape_vec(
319 (shape_vec[0], shape_vec[1], shape_vec[2], shape_vec[3]),
320 data.to_vec(),
321 )
322 .map_err(|e| InferenceError::OutputParsingFailed {
323 reason: format!("Failed to reshape output: {}", e),
324 })?;
325
326 processed.push(ProcessedTile {
327 x: tile.x * self.config.scale_factor,
328 y: tile.y * self.config.scale_factor,
329 data: output.index_axis_move(Axis(0), 0),
330 });
331 }
332
333 Ok(processed)
334 }
335
336 fn merge_tiles(
338 &self,
339 tiles: &[ProcessedTile],
340 output_width: usize,
341 output_height: usize,
342 ) -> Result<Array3<f32>> {
343 let mut output = Array3::<f32>::zeros((1, output_height, output_width));
344 let mut weight_map = Array3::<f32>::zeros((1, output_height, output_width));
345
346 let overlap = self.config.overlap * self.config.scale_factor;
347
348 for tile in tiles {
349 let tile_height = tile.data.shape()[1];
350 let tile_width = tile.data.shape()[2];
351
352 let weights = self.create_blend_weights(tile_width, tile_height, overlap);
354
355 for c in 0..1 {
357 for ty in 0..tile_height {
358 for tx in 0..tile_width {
359 let out_y = tile.y + ty;
360 let out_x = tile.x + tx;
361
362 if out_y < output_height && out_x < output_width {
363 let weight = weights[[ty, tx]];
364 output[[c, out_y, out_x]] += tile.data[[c, ty, tx]] * weight;
365 weight_map[[c, out_y, out_x]] += weight;
366 }
367 }
368 }
369 }
370 }
371
372 output.zip_mut_with(&weight_map, |out, &w| {
374 if w > 0.0 {
375 *out /= w;
376 }
377 });
378
379 Ok(output)
380 }
381
382 fn create_blend_weights(&self, width: usize, height: usize, overlap: usize) -> Array2<f32> {
384 let mut weights = Array2::<f32>::ones((height, width));
385
386 if overlap == 0 {
387 return weights;
388 }
389
390 for y in 0..height {
392 for x in 0..width {
393 let mut w = 1.0_f32;
394
395 if x < overlap {
397 w = w.min(x as f32 / overlap as f32);
398 }
399 if x >= width - overlap {
400 w = w.min((width - x) as f32 / overlap as f32);
401 }
402 if y < overlap {
403 w = w.min(y as f32 / overlap as f32);
404 }
405 if y >= height - overlap {
406 w = w.min((height - y) as f32 / overlap as f32);
407 }
408
409 weights[[y, x]] = w;
410 }
411 }
412
413 weights
414 }
415}
416
/// An input-space tile cut from the source raster.
struct TileInfo {
    /// Column of the tile's top-left pixel in the input raster.
    x: usize,
    /// Row of the tile's top-left pixel in the input raster.
    y: usize,
    /// Tile width in input pixels (may be below `tile_size` at the right edge).
    width: usize,
    /// Tile height in input pixels (may be below `tile_size` at the bottom edge).
    height: usize,
    /// Pixel values, shaped `(1, height, width)`.
    data: Array3<f32>,
}
425
/// A model output tile positioned on the upscaled output grid.
struct ProcessedTile {
    /// Column of the tile's top-left pixel in the output raster (input x * scale_factor).
    x: usize,
    /// Row of the tile's top-left pixel in the output raster (input y * scale_factor).
    y: usize,
    /// Model output with the batch axis removed; indexed as `[channel, row, column]`.
    data: Array3<f32>,
}
432
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_creation() {
        let config = SuperResConfig::new(2, 256, 32);
        assert_eq!(config.scale_factor, 2);
        assert_eq!(config.tile_size, 256);
        assert_eq!(config.overlap, 32);
    }

    #[test]
    fn test_config_default() {
        let config = SuperResConfig::default();
        assert_eq!(config.scale_factor, 2);
        assert_eq!(config.tile_size, 256);
        assert_eq!(config.overlap, 32);
        assert_eq!(config.batch_size, 1);
    }

    #[test]
    fn test_config_with_batch_size() {
        let config = SuperResConfig::new(4, 128, 16).with_batch_size(8);
        assert_eq!(config.scale_factor, 4);
        assert_eq!(config.tile_size, 128);
        assert_eq!(config.overlap, 16);
        assert_eq!(config.batch_size, 8);
    }

    /// `from_file` checks the path before touching ONNX Runtime, so this runs
    /// without a model or runtime installed.
    #[test]
    fn test_from_file_missing_model() {
        let missing = std::env::temp_dir().join("oxigdal-superres-missing-model.onnx");
        let result = SuperResolution::from_file(&missing, SuperResConfig::default());
        assert!(result.is_err());
    }

    /// Replaces the former `test_blend_weights`, which asserted nothing.
    /// Blend weights need a loaded session; this only checks that the runtime
    /// can create a session builder.
    #[test]
    #[ignore = "Requires ONNX Runtime to be installed"]
    fn test_session_builder_available() {
        assert!(
            Session::builder().is_ok(),
            "ONNX Runtime session builder could not be created"
        );
    }
}