1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
extern crate ndarray;

use ndarray::Array2;
use std::error::Error;

/// Generates a one-hot encoding for a sequence of non-negative integer labels.
///
/// # Arguments
///
/// * `input_array`: Labels to be encoded. Anything that can be viewed as a `&[usize]`
///   (for example a `Vec<usize>`, a slice, or an array) is accepted.
///
/// Returns: `Array2<f64>` with `input_array.len()` rows and `max + 1` columns, where `max`
/// is the largest label in `input_array`. Row `i` is all zeros except for a `1.` in
/// column `input_array[i]`.
///
/// # Errors
///
/// Returns an error if the `input_array` is empty, if the largest label is `usize::MAX`
/// (so the column count cannot be represented), or if the total number of elements
/// overflows `usize`.
///
/// # Examples
///
/// ```
/// use ducky_learn::util::one_hot_encoding_vec;
/// use ndarray::prelude::*;
///
/// let input_array = vec![2, 0, 1];
/// let output_array = array![
///     [0., 0., 1.],
///     [1., 0., 0.],
///     [0., 1., 0.]
/// ];
///
/// assert_eq!(one_hot_encoding_vec(input_array).unwrap(), output_array);
/// ```
pub fn one_hot_encoding_vec<T: AsRef<[usize]>>(input_array: T) -> Result<Array2<f64>, Box<dyn Error>> {
    let labels = input_array.as_ref();

    // The largest label determines the width of the encoding; no labels means no width.
    let max_label = labels.iter().copied().max().ok_or("Empty input array")?;

    // Labels are zero-based, so the column count is `max + 1`. Guard the increment so a
    // pathological `usize::MAX` label yields an error instead of an overflow panic/wrap.
    let n_col = max_label
        .checked_add(1)
        .ok_or("Maximum value is too large to one-hot encode")?;
    let n_row = labels.len();
    let n_elems = n_row
        .checked_mul(n_col)
        .ok_or("One-hot encoding dimensions are too large")?;

    // Build the row-major buffer in one allocation and set a single 1 per row, rather
    // than allocating a separate Vec per row and flattening afterwards.
    let mut data = vec![0.0_f64; n_elems];
    for (row, &label) in data.chunks_exact_mut(n_col).zip(labels) {
        // `label <= max_label < n_col`, so this index is always in bounds.
        row[label] = 1.;
    }

    Ok(Array2::from_shape_vec((n_row, n_col), data)?)
}