Skip to main content

oxigdal_ml/superres/
model.rs

1//! Super-resolution model implementation
2
3use ndarray::{Array2, Array3, Array4, Axis, s};
4use ort::session::Session;
5use ort::session::builder::GraphOptimizationLevel;
6use ort::value::Value;
7use rayon::prelude::*;
8use serde::{Deserialize, Serialize};
9use std::path::Path;
10use tracing::{debug, info};
11
12use crate::error::{InferenceError, ModelError, Result};
13use oxigdal_core::buffer::RasterBuffer;
14use oxigdal_core::types::RasterDataType;
15
16use super::UpscaleFactor;
17
/// Configuration for super-resolution
///
/// Controls how an input raster is cut into tiles before inference and how
/// the upscaled tiles are positioned and blended afterwards.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SuperResConfig {
    /// Upscale factor (2x or 4x)
    ///
    /// Tile offsets and the output dimensions are multiplied by this value.
    /// NOTE(review): the 2x/4x restriction is not validated in this file.
    pub scale_factor: usize,
    /// Tile size (width and height in pixels)
    pub tile_size: usize,
    /// Overlap between tiles (in pixels)
    ///
    /// The tiling stride is computed as `tile_size - overlap`, so this is
    /// expected to be smaller than `tile_size`.
    pub overlap: usize,
    /// Batch size for processing multiple tiles
    ///
    /// NOTE(review): stored, but not read by the tile-processing code visible
    /// in this file (tiles are run through the model one at a time).
    pub batch_size: usize,
}
30
31impl SuperResConfig {
32    /// Creates a new super-resolution configuration
33    ///
34    /// # Arguments
35    ///
36    /// * `scale_factor` - Upscale factor (2 or 4)
37    /// * `tile_size` - Size of tiles for processing
38    /// * `overlap` - Overlap between tiles in pixels
39    ///
40    /// # Example
41    ///
42    /// ```
43    /// use oxigdal_ml::superres::SuperResConfig;
44    ///
45    /// let config = SuperResConfig::new(2, 256, 32);
46    /// assert_eq!(config.scale_factor, 2);
47    /// assert_eq!(config.tile_size, 256);
48    /// assert_eq!(config.overlap, 32);
49    /// ```
50    #[must_use]
51    pub fn new(scale_factor: usize, tile_size: usize, overlap: usize) -> Self {
52        Self {
53            scale_factor,
54            tile_size,
55            overlap,
56            batch_size: 1,
57        }
58    }
59
60    /// Set batch size for parallel processing
61    #[must_use]
62    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
63        self.batch_size = batch_size;
64        self
65    }
66}
67
68impl Default for SuperResConfig {
69    fn default() -> Self {
70        Self::new(2, 256, 32)
71    }
72}
73
/// Super-resolution model using ONNX Runtime
///
/// Owns the inference session together with the tiling configuration that is
/// applied on every call to `upscale`.
pub struct SuperResolution {
    /// ONNX Runtime session used to run each tile through the model
    session: Session,
    /// Tiling, overlap and scale settings supplied at load time
    config: SuperResConfig,
}
79
80impl SuperResolution {
81    /// Load a super-resolution model from an ONNX file
82    ///
83    /// # Arguments
84    ///
85    /// * `path` - Path to the ONNX model file
86    /// * `config` - Super-resolution configuration
87    ///
88    /// # Errors
89    ///
90    /// Returns an error if the model file cannot be loaded or is invalid
91    ///
92    /// # Example
93    ///
94    /// ```no_run
95    /// use oxigdal_ml::superres::{SuperResolution, SuperResConfig};
96    ///
97    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
98    /// let config = SuperResConfig::default();
99    /// let model = SuperResolution::from_file("real_esrgan_2x.onnx", config)?;
100    /// # Ok(())
101    /// # }
102    /// ```
103    pub fn from_file<P: AsRef<Path>>(path: P, config: SuperResConfig) -> Result<Self> {
104        let path = path.as_ref();
105
106        if !path.exists() {
107            return Err(ModelError::NotFound {
108                path: path.display().to_string(),
109            }
110            .into());
111        }
112
113        let session = Session::builder()
114            .map_err(|e: ort::Error| ModelError::InitializationFailed {
115                reason: e.to_string(),
116            })?
117            .with_optimization_level(GraphOptimizationLevel::Level3)
118            .map_err(|e: ort::Error| ModelError::InitializationFailed {
119                reason: e.to_string(),
120            })?
121            .commit_from_file(path)
122            .map_err(|e: ort::Error| ModelError::LoadFailed {
123                reason: e.to_string(),
124            })?;
125
126        info!("Loaded super-resolution model from {}", path.display());
127
128        Ok(Self { session, config })
129    }
130
131    /// Upscale a raster using the super-resolution model
132    ///
133    /// # Arguments
134    ///
135    /// * `input` - Input raster buffer
136    ///
137    /// # Errors
138    ///
139    /// Returns an error if inference fails or input validation fails
140    ///
141    /// # Example
142    ///
143    /// ```no_run
144    /// use oxigdal_ml::superres::{SuperResolution, SuperResConfig};
145    /// use oxigdal_core::buffer::RasterBuffer;
146    /// use oxigdal_core::types::RasterDataType;
147    ///
148    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
149    /// let config = SuperResConfig::default();
150    /// let mut model = SuperResolution::from_file("real_esrgan_2x.onnx", config)?;
151    ///
152    /// let input = RasterBuffer::zeros(512, 512, RasterDataType::Float32);
153    /// let output = model.upscale(&input)?;
154    /// # Ok(())
155    /// # }
156    /// ```
157    pub fn upscale(&mut self, input: &RasterBuffer) -> Result<RasterBuffer> {
158        let width = input.width() as usize;
159        let height = input.height() as usize;
160
161        debug!(
162            "Starting super-resolution upscaling: {}x{} -> {}x{}",
163            width,
164            height,
165            width * self.config.scale_factor,
166            height * self.config.scale_factor
167        );
168
169        // Extract tiles with overlap
170        let tiles = self.extract_tiles(input)?;
171
172        // Process tiles in batches (sequential for now due to &mut self requirement)
173        let processed_tiles = self.process_batch(&tiles)?;
174
175        // Merge tiles with blending
176        let merged = self.merge_tiles(
177            &processed_tiles,
178            width * self.config.scale_factor,
179            height * self.config.scale_factor,
180        )?;
181
182        // Create output buffer
183        RasterBuffer::new(
184            merged
185                .as_slice()
186                .ok_or_else(|| InferenceError::OutputParsingFailed {
187                    reason: "Failed to convert array to slice".to_string(),
188                })?
189                .iter()
190                .flat_map(|&v: &f32| v.to_le_bytes())
191                .collect(),
192            (width * self.config.scale_factor) as u64,
193            (height * self.config.scale_factor) as u64,
194            RasterDataType::Float32,
195            input.nodata(),
196        )
197        .map_err(Into::into)
198    }
199
200    /// Extract tiles from input raster with overlap
201    fn extract_tiles(&self, input: &RasterBuffer) -> Result<Vec<TileInfo>> {
202        let width = input.width() as usize;
203        let height = input.height() as usize;
204        let tile_size = self.config.tile_size;
205        let overlap = self.config.overlap;
206        let stride = tile_size - overlap;
207
208        let mut tiles = Vec::new();
209
210        let mut y = 0;
211        while y < height {
212            let mut x = 0;
213            while x < width {
214                let tile_w = (tile_size).min(width - x);
215                let tile_h = (tile_size).min(height - y);
216
217                tiles.push(TileInfo {
218                    x,
219                    y,
220                    width: tile_w,
221                    height: tile_h,
222                    data: self.extract_tile_data(input, x, y, tile_w, tile_h)?,
223                });
224
225                if x + tile_w >= width {
226                    break;
227                }
228                x += stride;
229            }
230
231            if y + tile_size >= height {
232                break;
233            }
234            y += stride;
235        }
236
237        debug!(
238            "Extracted {} tiles from {}x{} image",
239            tiles.len(),
240            width,
241            height
242        );
243
244        Ok(tiles)
245    }
246
247    /// Extract single tile data
248    fn extract_tile_data(
249        &self,
250        input: &RasterBuffer,
251        x: usize,
252        y: usize,
253        width: usize,
254        height: usize,
255    ) -> Result<Array3<f32>> {
256        // For simplicity, assume single-band input (can be extended to multi-band)
257        let mut tile = Array3::<f32>::zeros((1, height, width));
258
259        // Extract data from buffer (simplified - assumes Float32)
260        for ty in 0..height {
261            for tx in 0..width {
262                let pixel_idx = ((y + ty) * input.width() as usize + (x + tx)) * 4; // 4 bytes for Float32
263                let bytes = input.as_bytes();
264
265                if pixel_idx + 4 <= bytes.len() {
266                    let value = f32::from_le_bytes([
267                        bytes[pixel_idx],
268                        bytes[pixel_idx + 1],
269                        bytes[pixel_idx + 2],
270                        bytes[pixel_idx + 3],
271                    ]);
272                    tile[[0, ty, tx]] = value;
273                }
274            }
275        }
276
277        Ok(tile)
278    }
279
280    /// Process a batch of tiles through the model
281    fn process_batch(&mut self, tiles: &[TileInfo]) -> Result<Vec<ProcessedTile>> {
282        let mut processed = Vec::with_capacity(tiles.len());
283
284        for tile in tiles {
285            // Create input tensor
286            let input_tensor = tile.data.clone().insert_axis(Axis(0));
287
288            // Create ONNX value from array
289            let input_value =
290                Value::from_array(input_tensor.clone()).map_err(|e: ort::Error| {
291                    InferenceError::Failed {
292                        reason: format!("Failed to create input tensor: {}", e),
293                    }
294                })?;
295
296            // Run inference
297            let outputs =
298                self.session
299                    .run(ort::inputs![input_value])
300                    .map_err(|e: ort::Error| InferenceError::Failed {
301                        reason: e.to_string(),
302                    })?;
303
304            // Extract output (first output)
305            let output_value = &outputs[0];
306            let output_tensor =
307                output_value
308                    .try_extract_tensor::<f32>()
309                    .map_err(|e: ort::Error| InferenceError::OutputParsingFailed {
310                        reason: e.to_string(),
311                    })?;
312
313            // Get shape and data
314            let (shape, data) = output_tensor;
315            let shape_vec: Vec<usize> = shape.iter().map(|&d| d as usize).collect();
316
317            // Convert to ndarray
318            let output: Array4<f32> = Array4::from_shape_vec(
319                (shape_vec[0], shape_vec[1], shape_vec[2], shape_vec[3]),
320                data.to_vec(),
321            )
322            .map_err(|e| InferenceError::OutputParsingFailed {
323                reason: format!("Failed to reshape output: {}", e),
324            })?;
325
326            processed.push(ProcessedTile {
327                x: tile.x * self.config.scale_factor,
328                y: tile.y * self.config.scale_factor,
329                data: output.index_axis_move(Axis(0), 0),
330            });
331        }
332
333        Ok(processed)
334    }
335
336    /// Merge processed tiles with overlap blending
337    fn merge_tiles(
338        &self,
339        tiles: &[ProcessedTile],
340        output_width: usize,
341        output_height: usize,
342    ) -> Result<Array3<f32>> {
343        let mut output = Array3::<f32>::zeros((1, output_height, output_width));
344        let mut weight_map = Array3::<f32>::zeros((1, output_height, output_width));
345
346        let overlap = self.config.overlap * self.config.scale_factor;
347
348        for tile in tiles {
349            let tile_height = tile.data.shape()[1];
350            let tile_width = tile.data.shape()[2];
351
352            // Create weight matrix for alpha blending
353            let weights = self.create_blend_weights(tile_width, tile_height, overlap);
354
355            // Blend tile into output
356            for c in 0..1 {
357                for ty in 0..tile_height {
358                    for tx in 0..tile_width {
359                        let out_y = tile.y + ty;
360                        let out_x = tile.x + tx;
361
362                        if out_y < output_height && out_x < output_width {
363                            let weight = weights[[ty, tx]];
364                            output[[c, out_y, out_x]] += tile.data[[c, ty, tx]] * weight;
365                            weight_map[[c, out_y, out_x]] += weight;
366                        }
367                    }
368                }
369            }
370        }
371
372        // Normalize by weight map
373        output.zip_mut_with(&weight_map, |out, &w| {
374            if w > 0.0 {
375                *out /= w;
376            }
377        });
378
379        Ok(output)
380    }
381
382    /// Create blend weights for smooth tile merging
383    fn create_blend_weights(&self, width: usize, height: usize, overlap: usize) -> Array2<f32> {
384        let mut weights = Array2::<f32>::ones((height, width));
385
386        if overlap == 0 {
387            return weights;
388        }
389
390        // Create linear blend in overlap regions
391        for y in 0..height {
392            for x in 0..width {
393                let mut w = 1.0_f32;
394
395                // Blend on edges
396                if x < overlap {
397                    w = w.min(x as f32 / overlap as f32);
398                }
399                if x >= width - overlap {
400                    w = w.min((width - x) as f32 / overlap as f32);
401                }
402                if y < overlap {
403                    w = w.min(y as f32 / overlap as f32);
404                }
405                if y >= height - overlap {
406                    w = w.min((height - y) as f32 / overlap as f32);
407                }
408
409                weights[[y, x]] = w;
410            }
411        }
412
413        weights
414    }
415}
416
/// Information about an extracted tile
///
/// Built by `extract_tiles`; positions and sizes are in input-raster pixels.
struct TileInfo {
    /// Column of the tile's left edge in the input raster
    x: usize,
    /// Row of the tile's top edge in the input raster
    y: usize,
    /// Tile width in pixels (at most the configured tile size; clipped at the
    /// right raster edge)
    width: usize,
    /// Tile height in pixels (at most the configured tile size; clipped at
    /// the bottom raster edge)
    height: usize,
    /// Pixel values, allocated as `(1, height, width)` by `extract_tile_data`
    data: Array3<f32>,
}
425
/// A processed (upscaled) tile
///
/// Built by `process_batch`; positions are in output-raster pixels.
struct ProcessedTile {
    /// Source tile's `x` multiplied by the configured scale factor
    x: usize,
    /// Source tile's `y` multiplied by the configured scale factor
    y: usize,
    /// Model output for the tile with the batch axis removed; `merge_tiles`
    /// reads axis 1 as rows, axis 2 as columns and only index 0 of axis 0
    data: Array3<f32>,
}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores its arguments and starts with a batch size of 1.
    #[test]
    fn test_config_creation() {
        let config = SuperResConfig::new(2, 256, 32);
        assert_eq!(config.scale_factor, 2);
        assert_eq!(config.tile_size, 256);
        assert_eq!(config.overlap, 32);
        assert_eq!(config.batch_size, 1);
    }

    /// The default configuration is 2x, 256-pixel tiles, 32-pixel overlap.
    #[test]
    fn test_config_default() {
        let config = SuperResConfig::default();
        assert_eq!(config.scale_factor, 2);
        assert_eq!(config.tile_size, 256);
        assert_eq!(config.overlap, 32);
        assert_eq!(config.batch_size, 1);
    }

    /// `with_batch_size` replaces only the batch size.
    #[test]
    fn test_config_with_batch_size() {
        let config = SuperResConfig::new(4, 128, 16).with_batch_size(8);
        assert_eq!(config.scale_factor, 4);
        assert_eq!(config.tile_size, 128);
        assert_eq!(config.overlap, 16);
        assert_eq!(config.batch_size, 8);
    }

    /// Compile-time check of the `create_blend_weights` signature.
    ///
    /// Calling the method needs a `SuperResolution`, which in turn needs a
    /// loaded model. Coercing it to a function pointer pins the signature
    /// without touching ONNX Runtime, so the test does not have to be ignored.
    #[test]
    fn test_blend_weights() {
        let _signature: fn(&SuperResolution, usize, usize, usize) -> Array2<f32> =
            SuperResolution::create_blend_weights;
    }
}