crate::ix!();
/// A candidate category for a transaction, paired with the score it earned.
///
/// [`predict_category`] builds one of these for every category matched by a
/// vendor description. Public getters for both fields are generated by the
/// `Getters` derive.
#[derive(Debug, PartialEq, Getters)]
#[getset(get = "pub")]
pub struct TransactionCategoryPrediction<TxCat>
where
    TxCat: TransactionCategory,
{
    /// The category being proposed.
    category: TxCat,
    /// The score attached to `category`; the ordering impls rank a higher
    /// score ahead of a lower one.
    score: Decimal,
}
/// Partial ordering for predictions; always agrees with the [`Ord`] impl.
///
/// Delegating to `cmp` (rather than repeating the score comparison here)
/// guarantees the two orderings can never drift apart.
impl<TxCat: TransactionCategory> PartialOrd for TransactionCategoryPrediction<TxCat> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
// Marker impl only: `Eq` adds no methods, but it is a supertrait of `Ord`,
// so it must be present for the `Ord` impl on this type to be accepted.
impl<TxCat> Eq for TransactionCategoryPrediction<TxCat> where TxCat: TransactionCategory {}
/// Total ordering by score, reversed: the prediction with the higher score
/// compares as `Less`, so an ascending sort puts the best prediction first.
///
/// NOTE(review): only `score` takes part in the ordering, while the derived
/// `PartialEq` also compares `category`. Two predictions with equal scores
/// but different categories therefore compare as `Ordering::Equal` without
/// being `==`. Keep this in mind before using the type as a key in an
/// ordered collection.
impl<TxCat: TransactionCategory> Ord for TransactionCategoryPrediction<TxCat> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands are swapped on purpose to get descending order. `cmp` is
        // used directly so there is no `Option` to unwrap.
        other.score.cmp(&self.score)
    }
}
/// Ranks the categories in `category_map` that match a vendor description.
///
/// The description is lowercased and run through
/// `preprocess_vendor_description`; every resulting token is then looked up
/// in `category_map`. A token that maps to `n` categories adds `1 / n` to the
/// score of each of them, so a token shared by many categories carries less
/// weight than one that points at a single category.
///
/// Tokens that cannot be converted into a `StemmedToken` are skipped instead
/// of panicking: a token that is not a valid key cannot match any entry of
/// the map, so it contributes nothing to any score.
///
/// # Returns
///
/// One prediction per matched category, sorted by score with the highest
/// score first. The relative order of categories with equal scores is
/// unspecified because it follows hash-map iteration order. The result is
/// empty when no token matches.
pub fn predict_category<TxCat: TransactionCategory>(
    desc: &str,
    category_map: &CategoryMap<TxCat>,
) -> Vec<TransactionCategoryPrediction<TxCat>> {
    let desc = desc.to_lowercase();
    let stemmed_tokens = preprocess_vendor_description(&desc);

    let mut category_scores = HashMap::new();
    for token in stemmed_tokens {
        let key = match StemmedToken::from_str(&token) {
            Ok(key) => key,
            // Not representable as a key, hence cannot be present in the map.
            Err(_) => continue,
        };
        if let Some(categories) = category_map.get(&key) {
            let n_categories = categories.len();
            // Guard the hoisted division below: an empty category set has
            // nothing to score and must not divide by zero.
            if n_categories == 0 {
                continue;
            }
            // Every category behind this token receives the same share, so
            // compute it once per token rather than once per category.
            let share = Decimal::ONE / Decimal::from(n_categories);
            for category in categories {
                *category_scores.entry(category).or_insert(Decimal::ZERO) += share;
            }
        }
    }

    let mut predictions: Vec<_> = category_scores
        .into_iter()
        .map(|(category, score)| TransactionCategoryPrediction {
            category: *category,
            score,
        })
        .collect();
    // Descending by score: compare `b` against `a`.
    predictions.sort_by(|a, b| b.score.cmp(&a.score));
    predictions
}