use linfa::prelude::*;
use linfa_bayes::MultinomialNb;
use ndarray::{Array1, Array2};
use rustledger_plugin_types::{DirectiveData, DirectiveWrapper};
use std::collections::HashMap;
/// Multinomial Naive Bayes categorizer trained on transaction payee/narration text.
///
/// Text is converted into TF-IDF feature vectors over a vocabulary learned during training;
/// the classifier's predicted class id is an index into `labels`.
pub struct CategorizationModel {
    /// Sorted, de-duplicated account names observed as categories during training.
    labels: Vec<String>,
    /// Maps each known token to its feature column.
    vocabulary: HashMap<String, usize>,
    /// Inverse-document-frequency weight for each feature column.
    idf: Vec<f64>,
    /// The fitted classifier over TF-IDF features, predicting indices into `labels`.
    model: MultinomialNb<f64, usize>,
}
/// Errors that can occur while training a [`CategorizationModel`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MlError {
    /// The input did not contain enough usable examples; the message says what was missing.
    InsufficientData(String),
    /// The underlying classifier failed to fit; carries the classifier's error text.
    TrainingFailed(String),
}
impl std::fmt::Display for MlError {
    /// Renders the error as `"<category>: <detail>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (category, detail) = match self {
            Self::InsufficientData(detail) => ("insufficient training data", detail),
            Self::TrainingFailed(detail) => ("training failed", detail),
        };
        write!(f, "{category}: {detail}")
    }
}
impl std::error::Error for MlError {}
impl CategorizationModel {
pub fn train(directives: &[DirectiveWrapper]) -> Result<Self, MlError> {
let mut samples: Vec<(String, String)> = Vec::new();
for d in directives {
if let DirectiveData::Transaction(txn) = &d.data {
if txn.postings.len() < 2 {
continue;
}
let account = &txn.postings[1].account;
let mut text = String::new();
if let Some(ref payee) = txn.payee {
text.push_str(payee);
text.push(' ');
}
text.push_str(&txn.narration);
if !text.trim().is_empty() {
samples.push((text.to_lowercase(), account.clone()));
}
}
}
if samples.len() < 2 {
return Err(MlError::InsufficientData(format!(
"need at least 2 transactions, got {}",
samples.len()
)));
}
let mut label_set: Vec<String> = samples.iter().map(|(_, a)| a.clone()).collect();
label_set.sort();
label_set.dedup();
if label_set.len() < 2 {
return Err(MlError::InsufficientData(
"need at least 2 distinct categories".to_string(),
));
}
let label_to_idx: HashMap<&str, usize> = label_set
.iter()
.enumerate()
.map(|(i, s)| (s.as_str(), i))
.collect();
let mut vocab: HashMap<String, usize> = HashMap::new();
let tokenized: Vec<Vec<String>> = samples.iter().map(|(text, _)| tokenize(text)).collect();
for tokens in &tokenized {
for token in tokens {
let len = vocab.len();
vocab.entry(token.clone()).or_insert(len);
}
}
if vocab.is_empty() {
return Err(MlError::InsufficientData(
"no tokens found in training data".to_string(),
));
}
let n_docs = samples.len() as f64;
let mut doc_freq = vec![0u32; vocab.len()];
for tokens in &tokenized {
let mut seen = std::collections::HashSet::new();
for token in tokens {
if let Some(&idx) = vocab.get(token)
&& seen.insert(idx)
{
doc_freq[idx] += 1;
}
}
}
let idf: Vec<f64> = doc_freq
.iter()
.map(|&df| (n_docs / (1.0 + f64::from(df))).ln() + 1.0)
.collect();
let n_samples = samples.len();
let n_features = vocab.len();
let mut features = Array2::<f64>::zeros((n_samples, n_features));
let mut targets = Array1::<usize>::zeros(n_samples);
for (i, (tokens, (_, account))) in tokenized.iter().zip(samples.iter()).enumerate() {
let mut tf = vec![0u32; n_features];
for token in tokens {
if let Some(&idx) = vocab.get(token) {
tf[idx] += 1;
}
}
for (j, &count) in tf.iter().enumerate() {
features[[i, j]] = f64::from(count) * idf[j];
}
targets[i] = label_to_idx[account.as_str()];
}
let dataset = DatasetBase::new(features, targets);
let model = MultinomialNb::params()
.fit(&dataset)
.map_err(|e| MlError::TrainingFailed(format!("{e}")))?;
Ok(Self {
model,
vocabulary: vocab,
idf,
labels: label_set,
})
}
#[must_use]
pub fn predict(&self, narration: &str, payee: Option<&str>) -> Vec<(String, f64)> {
let mut text = String::new();
if let Some(p) = payee {
text.push_str(p);
text.push(' ');
}
text.push_str(narration);
let features = self.vectorize(&text.to_lowercase());
let features_2d = features.insert_axis(ndarray::Axis(0));
let prediction = self.model.predict(&features_2d);
let predicted_idx = prediction[0];
let mut results: Vec<(String, f64)> = self
.labels
.iter()
.enumerate()
.map(|(i, label)| {
let conf = if i == predicted_idx { 0.8 } else { 0.0 };
(label.clone(), conf)
})
.filter(|(_, conf)| *conf > 0.0)
.collect();
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
results
}
fn vectorize(&self, text: &str) -> Array1<f64> {
let tokens = tokenize(text);
let n_features = self.vocabulary.len();
let mut tf = vec![0u32; n_features];
for token in &tokens {
if let Some(&idx) = self.vocabulary.get(token) {
tf[idx] += 1;
}
}
let mut features = Array1::<f64>::zeros(n_features);
for (j, &count) in tf.iter().enumerate() {
features[j] = f64::from(count) * self.idf[j];
}
features
}
#[must_use]
pub const fn num_categories(&self) -> usize {
self.labels.len()
}
#[must_use]
pub fn vocab_size(&self) -> usize {
self.vocabulary.len()
}
}
/// Splits `text` into lowercase word tokens.
///
/// Every non-alphanumeric character acts as a separator. Tokens shorter than two characters
/// are dropped as noise; length is counted in chars (Unicode scalar values), not bytes.
fn tokenize(text: &str) -> Vec<String> {
    text.split(|c: char| !c.is_alphanumeric())
        // `str::len` counts UTF-8 bytes, which would let a single multi-byte letter such as
        // "é" through. Require a second char instead.
        .filter(|s| s.chars().nth(1).is_some())
        .map(str::to_lowercase)
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use rustledger_plugin_types::{AmountData, PostingData, TransactionData};

    /// Account every training transaction is paid from; only the category varies.
    const SOURCE_ACCOUNT: &str = "Assets:Bank";

    /// `(payee, narration, category)` rows used to build the shared training ledger.
    const TRAINING_ROWS: [(&str, &str, &str); 11] = [
        ("Whole Foods", "Groceries", "Expenses:Groceries"),
        ("Trader Joe's", "Weekly groceries", "Expenses:Groceries"),
        ("Safeway", "Food shopping", "Expenses:Groceries"),
        ("Kroger", "Groceries", "Expenses:Groceries"),
        ("Starbucks", "Coffee", "Expenses:Dining"),
        ("McDonald's", "Lunch", "Expenses:Dining"),
        ("Chipotle", "Dinner", "Expenses:Dining"),
        ("Panera", "Coffee and sandwich", "Expenses:Dining"),
        ("Shell", "Gas", "Expenses:Transport"),
        ("Chevron", "Fuel", "Expenses:Transport"),
        ("Uber", "Ride to airport", "Expenses:Transport"),
    ];

    /// Builds a posting with no cost, price, flag, or metadata.
    fn posting(account: &str, units: Option<AmountData>) -> PostingData {
        PostingData {
            account: account.to_string(),
            units,
            cost: None,
            price: None,
            flag: None,
            metadata: vec![],
        }
    }

    /// Builds a two-posting `*` transaction dated 2024-01-15.
    ///
    /// The first posting takes -50.00 USD from `from_account`. The second posting, on
    /// `to_account`, has no amount.
    fn make_txn(
        payee: Option<&str>,
        narration: &str,
        from_account: &str,
        to_account: &str,
    ) -> DirectiveWrapper {
        let withdrawal = AmountData {
            number: "-50.00".to_string(),
            currency: "USD".to_string(),
        };
        let txn = TransactionData {
            flag: "*".to_string(),
            payee: payee.map(String::from),
            narration: narration.to_string(),
            tags: vec![],
            links: vec![],
            metadata: vec![],
            postings: vec![
                posting(from_account, Some(withdrawal)),
                posting(to_account, None),
            ],
        };
        DirectiveWrapper {
            directive_type: "transaction".to_string(),
            date: "2024-01-15".to_string(),
            filename: None,
            lineno: None,
            data: DirectiveData::Transaction(txn),
        }
    }

    /// Expands [`TRAINING_ROWS`] into transactions paid from [`SOURCE_ACCOUNT`].
    fn training_data() -> Vec<DirectiveWrapper> {
        TRAINING_ROWS
            .iter()
            .map(|&(payee, narration, category)| {
                make_txn(Some(payee), narration, SOURCE_ACCOUNT, category)
            })
            .collect()
    }

    /// Trains on the shared ledger and returns the top-ranked category for the input.
    fn top_category(narration: &str, payee: Option<&str>) -> String {
        let model = CategorizationModel::train(&training_data()).unwrap();
        let predictions = model.predict(narration, payee);
        assert!(!predictions.is_empty());
        predictions[0].0.clone()
    }

    #[test]
    fn train_and_predict() {
        let model = CategorizationModel::train(&training_data()).unwrap();
        assert_eq!(model.num_categories(), 3);
        assert!(model.vocab_size() > 5);
        let predictions = model.predict("Weekly food shopping at the store", None);
        assert!(!predictions.is_empty());
        assert_eq!(predictions[0].0, "Expenses:Groceries");
    }

    #[test]
    fn predict_dining() {
        assert_eq!(top_category("Coffee", Some("Starbucks")), "Expenses:Dining");
    }

    #[test]
    fn predict_transport() {
        assert_eq!(top_category("Fuel for car", Some("Shell")), "Expenses:Transport");
    }

    #[test]
    fn insufficient_data() {
        let single = [make_txn(
            Some("Store"),
            "Stuff",
            SOURCE_ACCOUNT,
            "Expenses:Misc",
        )];
        assert!(CategorizationModel::train(&single).is_err());
    }

    #[test]
    fn insufficient_categories() {
        let same_category = [
            make_txn(Some("Store"), "Stuff", SOURCE_ACCOUNT, "Expenses:Misc"),
            make_txn(Some("Shop"), "Things", SOURCE_ACCOUNT, "Expenses:Misc"),
        ];
        assert!(CategorizationModel::train(&same_category).is_err());
    }

    #[test]
    fn tokenize_basic() {
        let tokens = tokenize("WHOLE FOODS MARKET #1234");
        for expected in ["whole", "foods", "market", "1234"] {
            assert!(
                tokens.iter().any(|t| t == expected),
                "missing token {expected:?}"
            );
        }
    }
}