gradients/
accuracy.rs

1use custos::CDatatype;
2use custos_math::Matrix;
3
4pub fn find_idxs<T: Copy + Default + PartialEq>(
5    search_for: &Matrix<T>,
6    search_with: &Matrix<T>,
7) -> Vec<usize> {
8    let rows = search_for.rows();
9    let search_for = search_for.read();
10    let search_with = search_with.read();
11    purpur::utils::find_idxs(rows, &search_for, &search_with)
12}
13
14pub fn correct_classes<T: CDatatype>(targets: &[usize], search_for: &Matrix<T>) -> usize {
15    let search_with = search_for.max_cols();
16    let idxs = find_idxs(&search_for, &search_with);
17    let mut correct = 0;
18    for (idx, correct_idx) in idxs.iter().zip(targets) {
19        if idx == correct_idx {
20            correct += 1;
21        }
22    }
23    correct
24}