//! eorst 1.0.1
//!
//! Earth Observation and Remote Sensing Toolkit - library for raster processing pipelines
//! ML methods for RasterDataset (LightGBM integration).
//!
//! This module requires the `use_lgbm` feature flag.

use crate::core_types::RasterType;
use crate::types::SamplingMethod;
use crate::gdal_utils::{create_rayon_pool, create_temp_file, file_stem_str, get_class, mosaic_translate_cleanup};
use crate::rasterdataset::RasterDataset;

use lgbm::Booster;
use lgbm::Dataset as LgbmDataset;
use lgbm::Field;
use lgbm::Parameters;
use lgbm::PredictType;
use lgbm::mat::{MatBuf, MatLayouts};
use ndarray::Array;
use rayon::prelude::*;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;

impl<R> RasterDataset<R>
where
    R: RasterType,
{
    /// Fits a LightGBM classifier using extracted raster values.
    ///
    /// Extracts one feature row per training point in `training_data`
    /// (keyed by the `"id"` column), pairs each row with the label found in
    /// `class_column`, trains a booster, and saves the model to `"test.mod"`
    /// in the current working directory (the path `lightgbm_predict_classifier`
    /// reads from).
    ///
    /// * `training_data` - vector file with the training samples.
    /// * `parameters` - LightGBM parameters; `num_iterations` (default 100)
    ///   bounds the number of boosting rounds.
    /// * `class_column` - attribute column holding the class label.
    ///
    /// # Panics
    /// Panics if dataset construction, training, or saving the model fails.
    #[cfg(feature = "use_lgbm")]
    pub fn lightgbm_fit_classifier(
        &self,
        training_data: &PathBuf,
        parameters: &HashMap<&str, String>,
        class_column: &str,
    ) {
        let mut params = Parameters::new();
        // Iterate by reference — no need to clone the entire map to read it.
        for (k, v) in parameters {
            params.push(*k, v.clone());
        }

        // Feature rows and labels come out of two *independent* HashMaps.
        // Independent HashMaps do not share iteration order, so zipping
        // `into_values()` with `values()` (as a naive implementation would)
        // can silently misalign rows and labels. Look each label up by its
        // key instead so every feature row is paired with its own label.
        let samples = self.extract_blockwise(training_data, "id", SamplingMethod::Value, None);
        let label_map = get_class(training_data, "id", class_column);

        let mut flat: Vec<f64> = Vec::new();
        let mut labels: Vec<f32> = Vec::new();
        let mut num_cols = 0;
        let mut num_rows = 0;
        for (id, values) in samples {
            if let Some(label) = label_map.get(&id) {
                num_cols = values.len();
                num_rows += 1;
                flat.extend(values.iter().map(|v| *v as f64));
                labels.push(*label as f32);
            }
        }
        let x = MatBuf::from_vec(flat, num_rows, num_cols, MatLayouts::RowMajor);

        log::debug!("Fit parameters are: {:?}", params);
        let mut train_dataset = LgbmDataset::from_mat(&x, None, &params).unwrap();
        train_dataset.set_field(Field::LABEL, &labels).unwrap();

        let mut booster = Booster::new(Arc::new(train_dataset), &params).unwrap();
        let num_iter: i32 = parameters
            .get("num_iterations")
            .and_then(|v| v.parse().ok())
            .unwrap_or(100);

        // `update_one_iter` returns `true` once training is finished, so we
        // stop early; `num_iter` is the upper bound on boosting rounds.
        for _ in 0..num_iter {
            if booster.update_one_iter().unwrap() {
                break;
            }
        }
        booster
            .save_model(
                0,
                None,
                lgbm::FeatureImportanceType::Gain,
                // NOTE(review): model path is hard-coded; the predict method
                // loads from the same location.
                &PathBuf::from("test.mod"),
            )
            .unwrap();
    }

    /// Predicts class labels for every raster block using the LightGBM model
    /// saved by `lightgbm_fit_classifier` (read from `"test.mod"` in the
    /// current working directory) and mosaics the per-block outputs into
    /// `out_file`.
    ///
    /// * `out_file` - destination raster path.
    /// * `parameters` - must contain a parsable `num_class` entry (the
    ///   number of per-pixel class probabilities the model emits).
    ///
    /// # Panics
    /// Panics if the model file cannot be read, `num_class` is missing or
    /// unparsable, or prediction fails.
    #[cfg(feature = "use_lgbm")]
    pub fn lightgbm_predict_classifier(&self, out_file: &Path, parameters: &HashMap<&str, String>) {
        let pool = create_rayon_pool(1);
        let tmp_file = PathBuf::from(create_temp_file("vrt"));

        // Parse once up front: the value is identical for every block, and a
        // missing/bad `num_class` should fail before any work is scheduled.
        let n_classes: usize = parameters["num_class"].parse().unwrap();

        // IMPORTANT: `collect()` must run *inside* `install`. A rayon `map`
        // is lazy — if only the iterator is built inside `install` and the
        // collect happens outside, the work runs on the global pool and the
        // dedicated pool is silently ignored.
        let block_files: Vec<PathBuf> = pool.install(|| {
            self.blocks
                .to_owned()
                .into_par_iter()
                .map(|raster_block: crate::blocks::RasterBlock<R>| -> PathBuf {
                    // Load the model per task. NOTE(review): this assumes
                    // `Booster` cannot be shared across threads — confirm,
                    // and hoist the load out of the closure if it can.
                    let booster = Booster::from_file(&PathBuf::from("test.mod")).unwrap().0;
                    let id = raster_block.block_index;
                    let file_stem = file_stem_str(&tmp_file);
                    let block_fn = tmp_file.with_file_name(format!("{}_{}.tif", file_stem, id));

                    // Reshape the (…, bands, y, x) block into one row of band
                    // values per pixel. Taking the dimensions from the array
                    // shape (rather than indexing a built Vec) also handles
                    // an empty block without panicking.
                    let block_data = self.read_block::<f64>(id);
                    let n_pixels = block_data.shape()[2] * block_data.shape()[3];
                    let n_bands = block_data.shape()[1];
                    let features = block_data
                        .t()
                        .into_shape_with_order((n_pixels, n_bands))
                        .unwrap();
                    // ndarray iterates views in logical (row-major) order, so
                    // this single pass is the row-major flattening — no
                    // per-row temporary Vecs needed.
                    let flat: Vec<f64> = features.iter().copied().collect();
                    let predict_features =
                        MatBuf::from_vec(flat, n_pixels, n_bands, MatLayouts::RowMajor);

                    let params = Parameters::new();
                    let predictions = booster
                        .predict_for_mat(predict_features, PredictType::Normal, 0, None, &params)
                        .unwrap();

                    // Each pixel yields `n_classes` probabilities; the
                    // predicted class is the arg-max of its chunk.
                    let predicted_classes: Vec<u8> = predictions
                        .values()
                        .chunks_exact(n_classes)
                        .map(|row| {
                            row.iter()
                                .enumerate()
                                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                                .unwrap()
                                .0 as u8
                        })
                        .collect();

                    let result =
                        Array::from_shape_vec((n_pixels, 1), predicted_classes).unwrap();
                    raster_block.write_samples_feature(&result, &block_fn, u8::MAX);
                    block_fn
                })
                .collect()
        });

        mosaic_translate_cleanup(&block_files, &tmp_file, out_file, self.metadata.epsg_code);
    }
}