/// A single labelled training sample stored by the classifier.
///
/// `PartialEq` is derived so that samples can be compared with `==` and used
/// directly in `assert_eq!`; the comparison is field-by-field, which means a
/// sample containing `NaN` is never equal to anything (including itself).
#[derive(Debug, Clone, PartialEq)]
pub struct KnnItem {
    /// Class label attached to this sample.
    pub label: String,
    /// Feature vector of this sample.
    pub data: Vec<f64>,
}
/// A k-nearest-neighbours classifier that keeps every training sample in memory.
#[derive(Debug, Clone)]
pub struct KnnClassifier {
    /// How many of the closest stored samples take part in a prediction vote.
    pub k: usize,
    /// Every sample added so far, in insertion order.
    pub items: Vec<KnnItem>,
}
impl KnnClassifier {
    /// Creates an empty classifier that votes among the `k` nearest samples.
    ///
    /// `k` is normalised to a positive odd number: `0` falls back to `5`, and
    /// an even value is raised to the next odd one.
    pub fn new(k: usize) -> KnnClassifier {
        let k = if k > 0 { k } else { 5 };
        // `usize::MAX` is odd, so `k + 1` cannot overflow for an even `k`.
        let k = if k % 2 == 1 { k } else { k + 1 };
        KnnClassifier { k, items: Vec::new() }
    }

    /// Adds a batch of samples, pairing `data[i]` with `labels[i]`.
    ///
    /// If the two slices differ in length, only the pairs that exist in both
    /// are added; the surplus entries of the longer slice are ignored.
    pub fn fit(&mut self, data: &[&[f64]], labels: &[&str]) {
        self.items.extend(data.iter().zip(labels).map(|(features, label)| KnnItem {
            label: label.to_string(),
            data: features.to_vec(),
        }));
    }

    /// Adds a single sample with the given feature vector and label.
    pub fn fit_one(&mut self, data: &[f64], label: &str) {
        self.items.push(KnnItem { label: label.to_string(), data: data.to_vec() });
    }

    /// Predicts the label of `item` by majority vote among the `k` stored
    /// samples closest to it (distance as computed by `calc_distance`).
    ///
    /// Samples at equal distance keep their insertion order. When several
    /// labels share the highest vote count, the one whose closest sample is
    /// nearest to `item` wins, so the result is deterministic. A sample whose
    /// distance is `NaN` is ranked after every sample with a real distance.
    ///
    /// # Panics
    ///
    /// Panics if no sample has been added yet, or if the public `k` field has
    /// been set to `0`, because there is then nothing to vote on.
    pub fn predict_one(&self, item: &[f64]) -> String {
        let mut neighbours: Vec<(f64, &str)> = self
            .items
            .iter()
            .map(|it| (calc_distance(&it.data, item), it.label.as_str()))
            .collect();
        // Stable sort; NaN compares greater than any number and equal to NaN,
        // which keeps the ordering total instead of panicking.
        neighbours.sort_by(|a, b| {
            a.0.partial_cmp(&b.0)
                .unwrap_or_else(|| a.0.is_nan().cmp(&b.0.is_nan()))
        });
        neighbours.truncate(self.k);

        let mut votes: std::collections::HashMap<&str, usize> = std::collections::HashMap::new();
        for &(_, label) in &neighbours {
            *votes.entry(label).or_insert(0) += 1;
        }
        let top = votes
            .values()
            .copied()
            .max()
            .expect("predict_one needs at least one fitted item and k > 0");
        // Walk the neighbours nearest-first instead of iterating the map, whose
        // order is arbitrary: the first label holding `top` votes is the winner.
        neighbours
            .iter()
            .map(|&(_, label)| label)
            .find(|&label| votes[label] == top)
            .expect("some neighbour carries the winning label")
            .to_string()
    }

    /// Predicts a label for every feature vector in `items`, in order.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`KnnClassifier::predict_one`] when
    /// `items` is not empty.
    pub fn predict(&self, items: &[Vec<f64>]) -> Vec<String> {
        items.iter().map(|it| self.predict_one(it)).collect()
    }

    /// Serialises the stored samples, one per line, as
    /// `label<delimiter>v1<delimiter>v2...` with every line terminated by `\n`.
    ///
    /// Labels are written verbatim; no quoting or escaping is applied.
    pub fn to_csv(&self, delimiter: char) -> String {
        let mut out = String::new();
        for it in &self.items {
            out.push_str(&it.label);
            for value in &it.data {
                out.push(delimiter);
                out.push_str(&value.to_string());
            }
            out.push('\n');
        }
        out
    }

    /// Appends the samples parsed from CSV text `s`.
    ///
    /// See [`KnnClassifier::try_from_csv`] for the accepted format; this is the
    /// panicking convenience wrapper around it.
    ///
    /// # Panics
    ///
    /// Panics if any non-label field is not a valid `f64`.
    pub fn from_csv(&mut self, s: &str, delimiter: char, label_col: usize, skip_header: bool) {
        if let Err(err) = self.try_from_csv(s, delimiter, label_col, skip_header) {
            panic!("from_csv: invalid numeric field: {}", err);
        }
    }

    /// Appends the samples parsed from CSV text `s`, reporting malformed
    /// numbers instead of panicking.
    ///
    /// Each non-blank line is split on `delimiter`. The field at index
    /// `label_col` becomes the label and every other field is parsed as an
    /// `f64` feature; fields are trimmed first. When `skip_header` is set the
    /// first line is ignored. A line with no field at `label_col` gets an
    /// empty label.
    ///
    /// # Errors
    ///
    /// Returns the parse error of the first invalid numeric field. In that
    /// case the classifier is left unchanged.
    pub fn try_from_csv(
        &mut self,
        s: &str,
        delimiter: char,
        label_col: usize,
        skip_header: bool,
    ) -> Result<(), std::num::ParseFloatError> {
        // Collect first so that a failure part-way through adds nothing.
        let mut parsed = Vec::new();
        for line in s.lines().skip(usize::from(skip_header)) {
            let line = line.trim();
            if line.is_empty() {
                continue;
            }
            let mut it = KnnItem { label: String::new(), data: Vec::new() };
            for (col, field) in line.split(delimiter).enumerate() {
                let field = field.trim();
                if col == label_col {
                    it.label = field.to_string();
                } else {
                    it.data.push(field.parse::<f64>()?);
                }
            }
            parsed.push(it);
        }
        self.items.extend(parsed);
        Ok(())
    }
}
/// Returns the Euclidean distance between the points `a` and `b`.
///
/// Coordinates are paired by position. If the slices differ in length, the
/// extra coordinates of the longer one do not contribute to the result.
pub fn calc_distance(a: &[f64], b: &[f64]) -> f64 {
    let squared_sum: f64 = a
        .iter()
        .zip(b)
        .map(|(lhs, rhs)| (lhs - rhs).powi(2))
        .sum();
    squared_sum.sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-feature training samples and their labels, shared by the prediction
    /// tests. (The values look like height/weight pairs with body-type labels;
    /// the tests only rely on the numbers themselves.)
    static SAMPLES: [([f64; 2], &str); 30] = [
        ([150.0, 80.0], "肥満"),
        ([153.0, 69.0], "肥満"),
        ([153.0, 94.0], "肥満"),
        ([189.0, 96.0], "肥満"),
        ([159.0, 74.0], "肥満"),
        ([169.0, 64.0], "標準"),
        ([171.0, 64.0], "標準"),
        ([186.0, 59.0], "痩せ"),
        ([173.0, 84.0], "肥満"),
        ([156.0, 77.0], "肥満"),
        ([174.0, 46.0], "痩せ"),
        ([174.0, 54.0], "痩せ"),
        ([162.0, 77.0], "肥満"),
        ([151.0, 76.0], "肥満"),
        ([188.0, 55.0], "痩せ"),
        ([189.0, 97.0], "肥満"),
        ([173.0, 68.0], "標準"),
        ([174.0, 80.0], "肥満"),
        ([167.0, 56.0], "標準"),
        ([187.0, 95.0], "肥満"),
        ([175.0, 100.0], "肥満"),
        ([163.0, 73.0], "肥満"),
        ([158.0, 79.0], "肥満"),
        ([159.0, 45.0], "痩せ"),
        ([170.0, 45.0], "痩せ"),
        ([166.0, 81.0], "肥満"),
        ([155.0, 98.0], "肥満"),
        ([165.0, 50.0], "痩せ"),
        ([150.0, 83.0], "肥満"),
        ([168.0, 85.0], "肥満"),
    ];

    /// Query points paired with the label the classifier must return for them.
    static QUERIES: [([f64; 2], &str); 3] = [
        ([159.0, 85.0], "肥満"),
        ([162.0, 58.0], "標準"),
        ([183.0, 48.0], "痩せ"),
    ];

    /// Loads `SAMPLES` into `c` through the batch `fit` API.
    fn fit_batch(c: &mut KnnClassifier) {
        let rows: Vec<&[f64]> = SAMPLES.iter().map(|(features, _)| &features[..]).collect();
        let labels: Vec<&str> = SAMPLES.iter().map(|&(_, label)| label).collect();
        c.fit(&rows, &labels);
    }

    /// Checks that the batch `predict` API labels every entry of `QUERIES` as expected.
    fn assert_batch_predictions(c: &KnnClassifier) {
        let queries: Vec<Vec<f64>> = QUERIES.iter().map(|(query, _)| query.to_vec()).collect();
        let expected: Vec<&str> = QUERIES.iter().map(|&(_, label)| label).collect();
        assert_eq!(c.predict(&queries), expected);
    }

    /// Training one sample at a time and predicting one query at a time.
    #[test]
    fn test_knn1() {
        let mut c = KnnClassifier::new(5);
        for (features, label) in SAMPLES.iter() {
            c.fit_one(features, label);
        }
        assert_eq!(c.items.len(), SAMPLES.len());
        for (query, expected) in QUERIES.iter() {
            assert_eq!(c.predict_one(query), *expected);
        }
    }

    /// Batch training and batch prediction.
    #[test]
    fn test_knn2() {
        let mut c = KnnClassifier::new(5);
        fit_batch(&mut c);
        assert_eq!(c.items.len(), SAMPLES.len());
        assert_batch_predictions(&c);
    }

    /// `k == 0` falls back to the default and behaves like `k == 5`.
    #[test]
    fn test_knn3() {
        let mut c = KnnClassifier::new(0);
        assert_eq!(c.k, 5);
        fit_batch(&mut c);
        assert_batch_predictions(&c);
    }

    /// `new` always ends up with a positive odd `k`.
    #[test]
    fn test_new_normalises_k() {
        assert_eq!(KnnClassifier::new(0).k, 5);
        assert_eq!(KnnClassifier::new(1).k, 1);
        assert_eq!(KnnClassifier::new(3).k, 3);
        assert_eq!(KnnClassifier::new(4).k, 5);
    }

    /// CSV export, re-import, and tolerance for whitespace around fields.
    #[test]
    fn test_to_csv() {
        let mut c = KnnClassifier::new(5);
        c.fit_one(&[150.0, 80.0], "肥満");
        c.fit_one(&[153.0, 69.0], "肥満");
        c.fit_one(&[153.0, 94.0], "肥満");
        let s = c.to_csv(',');
        assert_eq!(s, "肥満,150,80\n肥満,153,69\n肥満,153,94\n");
        let mut c = KnnClassifier::new(5);
        c.from_csv(&s, ',', 0, false);
        assert_eq!(&c.to_csv(','), "肥満,150,80\n肥満,153,69\n肥満,153,94\n");
        let mut c = KnnClassifier::new(5);
        c.from_csv("肥満, 150, 80\n肥満 , 153, 69.0\n 肥満, 153, 94.0\n", ',', 0, false);
        assert_eq!(&c.to_csv(','), "肥満,150,80\n肥満,153,69\n肥満,153,94\n");
    }

    /// Header skipping, blank lines, and a label column that is not the first one.
    #[test]
    fn test_from_csv_header_and_label_column() {
        let mut c = KnnClassifier::new(5);
        c.from_csv("height,weight,label\n150,80,肥満\n\n169,64,標準\n", ',', 2, true);
        assert_eq!(&c.to_csv(','), "肥満,150,80\n標準,169,64\n");
    }
}