Skip to main content

oxigdal_ml/
postprocessing.rs

//! Postprocessing operations for ML results
//!
//! This module provides confidence thresholding, mask-to-polygon conversion,
//! polygon simplification/merging helpers, and GeoJSON export.
5
6use geo_types::{Coord, LineString, MultiPolygon, Polygon};
7use geojson::{Feature, FeatureCollection, Geometry, GeometryValue};
8use oxigdal_core::buffer::RasterBuffer;
9use serde_json::{Map, Value as JsonValue};
10// use std::collections::HashMap;
11use std::fs::File;
12use std::io::Write;
13use std::path::Path;
14use tracing::debug;
15
16use crate::detection::GeoDetection;
17use crate::error::{PostprocessingError, Result};
18use crate::segmentation::SegmentationMask;
19
20/// Applies confidence thresholding to a probability map
21///
22/// # Errors
23/// Returns an error if thresholding fails
24pub fn apply_threshold(probabilities: &RasterBuffer, threshold: f32) -> Result<RasterBuffer> {
25    if !(0.0..=1.0).contains(&threshold) {
26        return Err(PostprocessingError::InvalidThreshold { value: threshold }.into());
27    }
28
29    let mut result = probabilities.clone();
30
31    for y in 0..probabilities.height() {
32        for x in 0..probabilities.width() {
33            let prob =
34                probabilities
35                    .get_pixel(x, y)
36                    .map_err(|e| PostprocessingError::ExportFailed {
37                        reason: format!("Failed to get probability: {}", e),
38                    })?;
39
40            let value = if prob >= threshold as f64 { 1.0 } else { 0.0 };
41
42            result
43                .set_pixel(x, y, value)
44                .map_err(|e| PostprocessingError::ExportFailed {
45                    reason: format!("Failed to set value: {}", e),
46                })?;
47        }
48    }
49
50    Ok(result)
51}
52
53/// Converts a binary mask to polygons using marching squares algorithm
54///
55/// # Errors
56/// Returns an error if conversion fails
57pub fn mask_to_polygons(mask: &RasterBuffer, min_area: f64) -> Result<Vec<Polygon>> {
58    debug!(
59        "Converting {}x{} mask to polygons",
60        mask.width(),
61        mask.height()
62    );
63
64    let mut polygons = Vec::new();
65
66    // Simplified polygon extraction using contour tracing
67    // A production implementation would use a proper marching squares algorithm
68    let width = mask.width();
69    let height = mask.height();
70
71    let mut visited = vec![vec![false; width as usize]; height as usize];
72
73    for y in 0..height {
74        for x in 0..width {
75            if visited[y as usize][x as usize] {
76                continue;
77            }
78
79            let value =
80                mask.get_pixel(x, y)
81                    .map_err(|e| PostprocessingError::PolygonConversionFailed {
82                        reason: format!("Failed to get pixel: {}", e),
83                    })?;
84
85            if value > 0.0 {
86                let polygon = trace_contour(mask, x, y, &mut visited)?;
87                let area = calculate_polygon_area(&polygon);
88
89                if area >= min_area {
90                    polygons.push(polygon);
91                }
92            }
93        }
94    }
95
96    debug!("Extracted {} polygons", polygons.len());
97
98    Ok(polygons)
99}
100
101/// Traces a contour starting from a point
102fn trace_contour(
103    mask: &RasterBuffer,
104    start_x: u64,
105    start_y: u64,
106    visited: &mut [Vec<bool>],
107) -> Result<Polygon> {
108    let mut coords = Vec::new();
109
110    // Simplified contour tracing - just creates a bounding box
111    // A real implementation would do proper boundary following
112    let mut min_x = start_x;
113    let mut min_y = start_y;
114    let mut max_x = start_x;
115    let mut max_y = start_y;
116
117    // Find extent of connected component
118    let mut stack = vec![(start_x, start_y)];
119
120    while let Some((x, y)) = stack.pop() {
121        if x >= mask.width() || y >= mask.height() {
122            continue;
123        }
124
125        if visited[y as usize][x as usize] {
126            continue;
127        }
128
129        let value =
130            mask.get_pixel(x, y)
131                .map_err(|e| PostprocessingError::PolygonConversionFailed {
132                    reason: format!("Failed to get pixel: {}", e),
133                })?;
134
135        if value > 0.0 {
136            visited[y as usize][x as usize] = true;
137
138            min_x = min_x.min(x);
139            min_y = min_y.min(y);
140            max_x = max_x.max(x);
141            max_y = max_y.max(y);
142
143            // Add neighbors
144            if x > 0 {
145                stack.push((x - 1, y));
146            }
147            if x + 1 < mask.width() {
148                stack.push((x + 1, y));
149            }
150            if y > 0 {
151                stack.push((x, y - 1));
152            }
153            if y + 1 < mask.height() {
154                stack.push((x, y + 1));
155            }
156        }
157    }
158
159    // Create rectangle polygon
160    coords.push(Coord {
161        x: min_x as f64,
162        y: min_y as f64,
163    });
164    coords.push(Coord {
165        x: max_x as f64 + 1.0,
166        y: min_y as f64,
167    });
168    coords.push(Coord {
169        x: max_x as f64 + 1.0,
170        y: max_y as f64 + 1.0,
171    });
172    coords.push(Coord {
173        x: min_x as f64,
174        y: max_y as f64 + 1.0,
175    });
176    coords.push(Coord {
177        x: min_x as f64,
178        y: min_y as f64,
179    }); // Close the ring
180
181    Ok(Polygon::new(LineString::from(coords), vec![]))
182}
183
184/// Calculates the area of a polygon
185fn calculate_polygon_area(polygon: &Polygon) -> f64 {
186    let coords = polygon.exterior().coords().collect::<Vec<_>>();
187    if coords.len() < 3 {
188        return 0.0;
189    }
190
191    let mut area = 0.0;
192    for i in 0..coords.len() - 1 {
193        area += coords[i].x * coords[i + 1].y - coords[i + 1].x * coords[i].y;
194    }
195
196    (area / 2.0).abs()
197}
198
199/// Exports detections to GeoJSON format
200///
201/// # Errors
202/// Returns an error if export fails
203pub fn export_detections_geojson<P: AsRef<Path>>(
204    detections: &[GeoDetection],
205    output_path: P,
206) -> Result<()> {
207    debug!("Exporting {} detections to GeoJSON", detections.len());
208
209    let features: Vec<Feature> = detections.iter().map(detection_to_feature).collect();
210
211    let collection = FeatureCollection {
212        bbox: None,
213        features,
214        foreign_members: None,
215    };
216
217    let json = serde_json::to_string_pretty(&collection).map_err(|e| {
218        PostprocessingError::ExportFailed {
219            reason: format!("Failed to serialize GeoJSON: {}", e),
220        }
221    })?;
222
223    let mut file =
224        File::create(output_path.as_ref()).map_err(|e| PostprocessingError::ExportFailed {
225            reason: format!("Failed to create output file: {}", e),
226        })?;
227
228    file.write_all(json.as_bytes())
229        .map_err(|e| PostprocessingError::ExportFailed {
230            reason: format!("Failed to write GeoJSON: {}", e),
231        })?;
232
233    debug!("Successfully exported detections");
234
235    Ok(())
236}
237
238/// Converts a detection to a GeoJSON feature
239fn detection_to_feature(det: &GeoDetection) -> Feature {
240    let polygon = det.geo_bbox.to_polygon();
241
242    let mut properties = Map::new();
243    properties.insert(
244        "class_id".to_string(),
245        JsonValue::Number(det.detection.class_id.into()),
246    );
247    properties.insert(
248        "confidence".to_string(),
249        JsonValue::Number(
250            serde_json::Number::from_f64(det.detection.confidence as f64)
251                .unwrap_or_else(|| serde_json::Number::from(0)),
252        ),
253    );
254
255    if let Some(ref label) = det.detection.class_label {
256        properties.insert("class_label".to_string(), JsonValue::String(label.clone()));
257    }
258
259    for (key, value) in &det.detection.attributes {
260        properties.insert(key.clone(), JsonValue::String(value.clone()));
261    }
262
263    Feature {
264        bbox: None,
265        geometry: Some(Geometry::new(GeometryValue::from(&polygon))),
266        id: None,
267        properties: Some(properties),
268        foreign_members: None,
269    }
270}
271
272/// Exports a segmentation mask to GeoJSON
273///
274/// # Errors
275/// Returns an error if export fails
276pub fn export_segmentation_geojson<P: AsRef<Path>>(
277    mask: &SegmentationMask,
278    output_path: P,
279    min_area: f64,
280) -> Result<()> {
281    debug!("Exporting segmentation mask to GeoJSON");
282
283    let polygons = mask_to_polygons(&mask.mask, min_area)?;
284
285    let features: Vec<Feature> = polygons
286        .iter()
287        .enumerate()
288        .map(|(i, poly)| {
289            let mut properties = Map::new();
290            properties.insert("id".to_string(), JsonValue::Number(i.into()));
291
292            Feature {
293                bbox: None,
294                geometry: Some(Geometry::new(GeometryValue::from(poly))),
295                id: None,
296                properties: Some(properties),
297                foreign_members: None,
298            }
299        })
300        .collect();
301
302    let collection = FeatureCollection {
303        bbox: None,
304        features,
305        foreign_members: None,
306    };
307
308    let json = serde_json::to_string_pretty(&collection).map_err(|e| {
309        PostprocessingError::ExportFailed {
310            reason: format!("Failed to serialize GeoJSON: {}", e),
311        }
312    })?;
313
314    let mut file =
315        File::create(output_path.as_ref()).map_err(|e| PostprocessingError::ExportFailed {
316            reason: format!("Failed to create output file: {}", e),
317        })?;
318
319    file.write_all(json.as_bytes())
320        .map_err(|e| PostprocessingError::ExportFailed {
321            reason: format!("Failed to write GeoJSON: {}", e),
322        })?;
323
324    debug!("Successfully exported segmentation");
325
326    Ok(())
327}
328
329/// Simplifies polygons using the Douglas-Peucker algorithm
330///
331/// # Errors
332/// Returns an error if simplification fails
333pub fn simplify_polygons(polygons: &[Polygon], tolerance: f64) -> Result<Vec<Polygon>> {
334    if tolerance < 0.0 {
335        return Err(PostprocessingError::ExportFailed {
336            reason: "Tolerance must be non-negative".to_string(),
337        }
338        .into());
339    }
340
341    // Simplified implementation - returns copy
342    // A real implementation would use proper Douglas-Peucker algorithm
343    Ok(polygons.to_vec())
344}
345
346/// Merges overlapping polygons
347///
348/// # Errors
349/// Returns an error if merging fails
350pub fn merge_polygons(polygons: &[Polygon]) -> Result<MultiPolygon> {
351    // Simplified implementation
352    // A real implementation would use proper geometry union operations
353    Ok(MultiPolygon::new(polygons.to_vec()))
354}
355
#[cfg(test)]
mod tests {
    use super::*;
    use oxigdal_core::types::RasterDataType;
    use std::collections::HashMap;

    /// Builds a zeroed 10x10 Float32 raster with `1.0` written at each listed pixel.
    fn mask_with(pixels: &[(u64, u64)]) -> RasterBuffer {
        let mut mask = RasterBuffer::zeros(10, 10, RasterDataType::Float32);
        for &(x, y) in pixels {
            mask.set_pixel(x, y, 1.0)
                .expect("pixel is inside the 10x10 raster");
        }
        mask
    }

    #[test]
    fn test_apply_threshold() {
        let mut probs = RasterBuffer::zeros(10, 10, RasterDataType::Float32);
        probs.set_pixel(2, 3, 0.75).expect("in-bounds write");
        probs.set_pixel(4, 4, 0.5).expect("in-bounds write");
        probs.set_pixel(6, 1, 0.25).expect("in-bounds write");

        let result = apply_threshold(&probs, 0.5).expect("0.5 is a valid threshold");

        let at = |x: u64, y: u64| result.get_pixel(x, y).expect("in-bounds read");
        assert_eq!(at(2, 3), 1.0);
        assert_eq!(at(4, 4), 1.0, "threshold comparison is inclusive");
        assert_eq!(at(6, 1), 0.0);
        assert_eq!(at(0, 0), 0.0);
    }

    #[test]
    fn test_apply_threshold_rejects_invalid_threshold() {
        let probs = RasterBuffer::zeros(2, 2, RasterDataType::Float32);
        assert!(apply_threshold(&probs, -0.1).is_err());
        assert!(apply_threshold(&probs, 1.5).is_err());
        assert!(apply_threshold(&probs, f32::NAN).is_err());
    }

    #[test]
    fn test_mask_to_polygons() {
        let mask = mask_with(&[(5, 5)]);
        let polygons = mask_to_polygons(&mask, 0.0).expect("conversion succeeds");
        assert_eq!(polygons.len(), 1);
        assert!((calculate_polygon_area(&polygons[0]) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_mask_to_polygons_filters_by_min_area() {
        // A 2x2 block (rectangle area 4) and an isolated pixel (rectangle area 1).
        let mask = mask_with(&[(1, 1), (2, 1), (1, 2), (2, 2), (7, 7)]);
        let polygons = mask_to_polygons(&mask, 2.0).expect("conversion succeeds");
        assert_eq!(polygons.len(), 1);
        assert!((calculate_polygon_area(&polygons[0]) - 4.0).abs() < 1e-9);
    }

    #[test]
    fn test_mask_to_polygons_empty_mask() {
        let mask = mask_with(&[]);
        let polygons = mask_to_polygons(&mask, 0.0).expect("conversion succeeds");
        assert!(polygons.is_empty());
    }

    #[test]
    fn test_calculate_polygon_area() {
        let polygon = Polygon::new(
            LineString::from(vec![
                Coord { x: 0.0, y: 0.0 },
                Coord { x: 10.0, y: 0.0 },
                Coord { x: 10.0, y: 10.0 },
                Coord { x: 0.0, y: 10.0 },
                Coord { x: 0.0, y: 0.0 },
            ]),
            vec![],
        );

        let area = calculate_polygon_area(&polygon);
        assert!((area - 100.0).abs() < 1e-9);
    }

    #[test]
    fn test_export_detections_geojson() {
        use crate::detection::{BoundingBox, Detection, GeoBoundingBox};

        // Per-process file name so concurrent test runs do not clobber each other.
        let output_path = std::env::temp_dir()
            .join(format!("test_detections_{}.geojson", std::process::id()));

        let detections = vec![GeoDetection {
            detection: Detection {
                bbox: BoundingBox::new(0.0, 0.0, 10.0, 10.0),
                class_id: 0,
                class_label: Some("test".to_string()),
                confidence: 0.9,
                attributes: HashMap::new(),
            },
            geo_bbox: GeoBoundingBox {
                min_x: 0.0,
                min_y: 0.0,
                max_x: 10.0,
                max_y: 10.0,
            },
        }];

        let result = export_detections_geojson(&detections, &output_path);
        let written = std::fs::read_to_string(&output_path);
        // Clean up before asserting so a failure does not leave the file behind.
        let _ = std::fs::remove_file(&output_path);

        assert!(result.is_ok(), "export failed");
        let text = written.expect("export wrote the output file");
        let json: JsonValue = serde_json::from_str(&text).expect("output is valid JSON");

        assert_eq!(json["type"], "FeatureCollection");
        let features = json["features"].as_array().expect("features array");
        assert_eq!(features.len(), 1);

        let props = &features[0]["properties"];
        assert_eq!(props["class_id"], 0);
        assert_eq!(props["class_label"], "test");
        let confidence = props["confidence"].as_f64().expect("numeric confidence");
        assert!((confidence - 0.9).abs() < 1e-6);
    }
}