use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
use thiserror::Error;
pub mod vocab;
pub use vocab::Vocabulary;
pub mod opencv_utils;
#[cfg(feature = "opencv")]
pub use opencv_utils::*;
/// A single binary feature descriptor stored as 32 raw bytes (256 bits).
// NOTE(review): 32 bytes matches the ORB descriptor size, and the recall test
// reports "ORB features"; confirm no other descriptor sizes are expected.
pub type Desc = [u8; 32];
/// A bag-of-words vector of `f32` weights.
// NOTE(review): presumably one weight per vocabulary word, as produced by
// `Vocabulary::transform`; verify in `vocab`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BoW(pub Vec<f32>);
/// A list of [`IdPath`]s.
// NOTE(review): presumably a "direct index" (one node path per feature);
// confirm against its use in `vocab`.
pub type DirectIdx = Vec<IdPath>;
/// A sequence of indices, stored inline for up to 5 entries before spilling to the heap.
// NOTE(review): presumably node ids along a root-to-leaf path in the vocabulary
// tree; confirm in `vocab`.
pub type IdPath = SmallVec<[usize; 5]>;
impl BoW {
    /// Similarity score between two bag-of-words vectors, derived from their L1 distance.
    ///
    /// Computes `1 - 0.5 * Σ |a_i - b_i|` over positionally paired entries.
    /// For two L1-normalised vectors with non-negative weights the result lies
    /// in `[0, 1]`, where `1` means the vectors are identical.
    ///
    /// Entries are paired by position. If the lengths differ, the surplus entries
    /// of the longer vector are ignored.
    pub fn l1(&self, other: &Self) -> f32 {
        let mut abs_diff_total = 0.0_f32;
        for (lhs, rhs) in self.0.iter().zip(other.0.iter()) {
            abs_diff_total += (lhs - rhs).abs();
        }
        1.0 - abs_diff_total * 0.5
    }
}
/// Crate-internal shorthand for a `Result` whose error type is [`BowErr`].
type BowResult<T> = std::result::Result<T, BowErr>;
/// Errors produced by this crate.
///
/// Some variants exist only when the matching cargo feature (`bincode`,
/// `opencv`) is enabled.
///
/// The `Display` messages of wrapping variants deliberately omit the inner
/// error. thiserror's `#[from]` also marks the field as the `source()`, so
/// error reporters can print the full chain without duplicating text.
#[derive(Error, Debug)]
pub enum BowErr {
/// No features were supplied to an operation that needs at least one.
// NOTE(review): presumably returned for an empty feature set when building a
// vocabulary or transforming an image; confirm in `vocab`.
#[error("No Features Provided")]
NoFeatures,
/// Wraps a [`std::io::Error`]; `?` converts into this variant via the derived `From` impl.
#[error("Io Error")]
Io(#[from] std::io::Error),
/// Wraps a `bincode::Error`; only present with the `bincode` feature.
// NOTE(review): presumably raised while (de)serialising a vocabulary, per the message.
#[cfg(feature = "bincode")]
#[error("Vocabulary Serialization Error")]
Bincode(#[from] bincode::Error),
/// Wraps an error reported by OpenCV; only present with the `opencv` feature.
#[cfg(feature = "opencv")]
#[error("Opencv Error")]
OpenCvInternal(#[from] opencv::Error),
/// A descriptor obtained from OpenCV could not be decoded; only present with the `opencv` feature.
// NOTE(review): presumably raised when a descriptor row cannot be read as a
// `Desc`; confirm in `opencv_utils`.
#[cfg(feature = "opencv")]
#[error("Opencv Descriptor decode error")]
OpenCvDecode,
}
#[cfg(test)]
#[cfg(feature = "opencv")]
mod test {
    use super::*;
    use std::path::Path;

    /// How many top-ranked matches contribute to each query's cost.
    const NUM_MATCHES: usize = 12;
    /// Number of leading frames (in frame-number order) that are not used as queries.
    // NOTE(review): presumably keeps every query far enough from the start of the
    // sequence to have neighbours on both sides; confirm against `data/test`.
    const QUERY_SKIP: usize = 12;
    /// Maximum number of frames used as queries after the skipped ones.
    // NOTE(review): looks tied to the number of frames in `data/test`; confirm.
    const QUERY_COUNT: usize = 158;

    /// Parses the frame number from a file named `<n>.jpg`.
    ///
    /// # Panics
    ///
    /// Panics if the file name is missing or not UTF-8, does not end in `.jpg`,
    /// or its stem is not a valid `usize`.
    fn frame_number(path: &Path) -> usize {
        path.file_name()
            .and_then(|name| name.to_str())
            .and_then(|name| name.strip_suffix(".jpg"))
            .expect("test image name must look like `<n>.jpg`")
            .parse()
            .expect("test image stem must be a frame number")
    }

    /// Smallest possible summed frame distance over `num_matches` distinct frames.
    ///
    /// The best case is the query itself (distance 0) followed by its neighbours
    /// at distances 1, 1, 2, 2, 3, 3, .... For 12 matches this is 36.
    fn ideal_cost(num_matches: usize) -> i32 {
        (0..num_matches).map(|i| ((i + 1) / 2) as i32).sum()
    }

    /// Builds vocabularies for several `(k, l)` parameter pairs and prints a
    /// retrieval cost for each. `k` and `l` are passed straight to
    /// `Vocabulary::create`.
    ///
    /// For every query frame, the frame-index distances of its `NUM_MATCHES`
    /// most similar frames are summed, and the ideal (minimum) cost is
    /// subtracted. Lower totals are better.
    #[test]
    fn test_recall() {
        let features = all_kps_from_dir("data/train").unwrap();
        println!("Detected {} ORB features.", features.len());
        let base_cost = ideal_cost(NUM_MATCHES);
        for &k in &[6_usize, 8_usize, 10_usize] {
            for &l in &[3_usize, 4_usize, 5_usize] {
                // NOTE(review): each (k, l) vocabulary is built twice, presumably because
                // construction is non-deterministic; confirm in `vocab`.
                for _ in 0..2 {
                    let voc = Vocabulary::create(&features, k, l);
                    println!("Vocabulary: {:#?}", voc);

                    // (frame number, BoW) for every test image, sorted by frame number.
                    // The number is parsed once per file instead of once per comparison.
                    let mut bows: Vec<(usize, BoW)> = Vec::new();
                    for entry in Path::new("data/test")
                        .read_dir()
                        .expect("failed to read data/test")
                        .flatten()
                    {
                        let path = entry.path();
                        let new_feat = load_img_get_kps(&path).unwrap();
                        bows.push((frame_number(&path), voc.transform(&new_feat).unwrap()));
                    }
                    bows.sort_by_key(|(frame, _)| *frame);

                    assert!(
                        bows.len() >= NUM_MATCHES,
                        "need at least {} test images, found {}",
                        NUM_MATCHES,
                        bows.len()
                    );

                    let mut cost = 0;
                    for (reference, bow1) in bows.iter().skip(QUERY_SKIP).take(QUERY_COUNT) {
                        // Similarity to every frame, paired with its frame-index distance
                        // from the query (the query itself is included, at distance 0).
                        let mut scores: Vec<(f32, i32)> = bows
                            .iter()
                            .map(|(matched, bow2)| {
                                let frame_dist = (*matched as i32 - *reference as i32).abs();
                                (bow1.l1(bow2), frame_dist)
                            })
                            .collect();
                        // Most similar first; `total_cmp` avoids panicking on a NaN score.
                        scores.sort_by(|a, b| b.0.total_cmp(&a.0));
                        let top_cost: i32 = scores[..NUM_MATCHES].iter().map(|&(_, d)| d).sum();
                        cost += top_cost - base_cost;
                    }
                    println!("k: {}, l: {}. Total Cost: {}", k, l, cost);
                }
            }
        }
    }
}