use rustledger_plugin_types::{DirectiveData, DirectiveWrapper};
use std::collections::HashMap;
/// Text classifier that ranks account names for a transaction's payee and narration.
///
/// Built by `CategorizationModel::train`; all state needed to vectorize new text
/// (vocabulary and IDF weights) is stored alongside the fitted classifier.
pub struct CategorizationModel {
/// Fitted naive Bayes classifier that turns a feature row into class probabilities.
model: MultinomialNB,
/// Token -> feature index lookup, assigned in order of first appearance during training.
vocabulary: HashMap<String, usize>,
/// IDF weight for each feature, indexed by the feature index from `vocabulary`.
idf: Vec<f64>,
/// Sorted, de-duplicated account names; a label's position is its class index.
labels: Vec<String>,
}
/// Errors produced while building a categorization model.
#[derive(Debug)]
#[non_exhaustive]
pub enum MlError {
    /// The training input did not contain enough usable data; the payload is a
    /// human-readable explanation of what was missing.
    InsufficientData(String),
}

impl std::fmt::Display for MlError {
    /// Renders the error as `insufficient training data: <reason>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InsufficientData(reason) => {
                write!(f, "insufficient training data: {reason}")
            }
        }
    }
}

// No underlying cause is carried, so the default `source()` (None) is kept.
impl std::error::Error for MlError {}
impl CategorizationModel {
pub fn train(directives: &[DirectiveWrapper]) -> Result<Self, MlError> {
let mut samples: Vec<(String, String)> = Vec::new();
for d in directives {
if let DirectiveData::Transaction(txn) = &d.data {
if txn.postings.len() < 2 {
continue;
}
let account = &txn.postings[1].account;
let mut text = String::new();
if let Some(ref payee) = txn.payee {
text.push_str(payee);
text.push(' ');
}
text.push_str(&txn.narration);
if !text.trim().is_empty() {
samples.push((text.to_lowercase(), account.clone()));
}
}
}
if samples.len() < 2 {
return Err(MlError::InsufficientData(format!(
"need at least 2 transactions, got {}",
samples.len()
)));
}
let mut label_set: Vec<String> = samples.iter().map(|(_, a)| a.clone()).collect();
label_set.sort();
label_set.dedup();
if label_set.len() < 2 {
return Err(MlError::InsufficientData(
"need at least 2 distinct categories".to_string(),
));
}
let label_to_idx: HashMap<&str, usize> = label_set
.iter()
.enumerate()
.map(|(i, s)| (s.as_str(), i))
.collect();
let mut vocab: HashMap<String, usize> = HashMap::new();
let tokenized: Vec<Vec<String>> = samples.iter().map(|(text, _)| tokenize(text)).collect();
for tokens in &tokenized {
for token in tokens {
let len = vocab.len();
vocab.entry(token.clone()).or_insert(len);
}
}
if vocab.is_empty() {
return Err(MlError::InsufficientData(
"no tokens found in training data".to_string(),
));
}
let n_docs = samples.len() as f64;
let mut doc_freq = vec![0u32; vocab.len()];
for tokens in &tokenized {
let mut seen = std::collections::HashSet::new();
for token in tokens {
if let Some(&idx) = vocab.get(token)
&& seen.insert(idx)
{
doc_freq[idx] += 1;
}
}
}
let idf: Vec<f64> = doc_freq
.iter()
.map(|&df| (n_docs / (1.0 + f64::from(df))).ln() + 1.0)
.collect();
let n_features = vocab.len();
let mut features: Vec<Vec<(usize, f64)>> = Vec::with_capacity(samples.len());
let mut targets: Vec<usize> = Vec::with_capacity(samples.len());
for (tokens, (_, account)) in tokenized.iter().zip(samples.iter()) {
features.push(tfidf_row(tokens, &vocab, &idf));
targets.push(label_to_idx[account.as_str()]);
}
let model = MultinomialNB::fit(&features, &targets, label_set.len(), n_features);
Ok(Self {
model,
vocabulary: vocab,
idf,
labels: label_set,
})
}
#[must_use]
pub fn predict(&self, narration: &str, payee: Option<&str>) -> Vec<(String, f64)> {
let mut text = String::new();
if let Some(p) = payee {
text.push_str(p);
text.push(' ');
}
text.push_str(narration);
let features = self.vectorize(&text.to_lowercase());
let mut results: Vec<(String, f64)> = self
.labels
.iter()
.cloned()
.zip(self.model.predict_proba(&features))
.collect();
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
results
}
fn vectorize(&self, text: &str) -> Vec<(usize, f64)> {
tfidf_row(&tokenize(text), &self.vocabulary, &self.idf)
}
#[must_use]
pub const fn num_categories(&self) -> usize {
self.labels.len()
}
#[must_use]
pub fn vocab_size(&self) -> usize {
self.vocabulary.len()
}
}
/// Multinomial naive Bayes classifier over sparse `(feature index, weight)` rows.
struct MultinomialNB {
    /// `ln(samples in class / total samples)`, indexed by class.
    class_log_prior: Vec<f64>,
    /// Log of the smoothed per-class feature likelihoods, indexed `[class][feature]`.
    feature_log_prob: Vec<Vec<f64>>,
}

impl MultinomialNB {
    /// Additive smoothing term added to every per-class feature total.
    const ALPHA: f64 = 1.0;

    /// Fits the classifier.
    ///
    /// `features[i]` is the sparse row for sample `i` and `targets[i]` its class.
    /// `n_classes` and `n_features` size the internal tables.
    ///
    /// # Panics
    ///
    /// Panics if a target is `>= n_classes` or a feature index is
    /// `>= n_features`. Debug builds also assert that `features` and `targets`
    /// have the same length.
    fn fit(
        features: &[Vec<(usize, f64)>],
        targets: &[usize],
        n_classes: usize,
        n_features: usize,
    ) -> Self {
        debug_assert_eq!(
            features.len(),
            targets.len(),
            "features and targets must be parallel"
        );

        let total_samples = features.len() as f64;
        let mut samples_per_class = vec![0.0_f64; n_classes];
        let mut weight_per_class = vec![vec![0.0_f64; n_features]; n_classes];

        // Accumulate, per class, how many samples it has and the summed weight
        // of every feature across those samples.
        for (&class, row) in targets.iter().zip(features) {
            samples_per_class[class] += 1.0;
            let totals = &mut weight_per_class[class];
            row.iter()
                .for_each(|&(feature, weight)| totals[feature] += weight);
        }

        let mut class_log_prior: Vec<f64> = Vec::with_capacity(n_classes);
        let mut feature_log_prob: Vec<Vec<f64>> = Vec::with_capacity(n_classes);
        for (&count, totals) in samples_per_class.iter().zip(&weight_per_class) {
            class_log_prior.push((count / total_samples).ln());

            // Smoothed denominator: ALPHA * n_features + total weight in the class.
            let mass: f64 = totals.iter().sum();
            let denom: f64 = Self::ALPHA.mul_add(n_features as f64, mass);
            let log_probs: Vec<f64> = totals
                .iter()
                .map(|&total| ((total + Self::ALPHA) / denom).ln())
                .collect();
            feature_log_prob.push(log_probs);
        }

        Self {
            class_log_prior,
            feature_log_prob,
        }
    }

    /// Returns the posterior probability of each class for the sparse row `x`.
    ///
    /// The per-class joint log-likelihoods are shifted by their maximum before
    /// exponentiation so the largest term is exactly `exp(0)`, then normalised
    /// to sum to one.
    ///
    /// # Panics
    ///
    /// Panics if `x` contains a feature index outside the fitted feature range.
    fn predict_proba(&self, x: &[(usize, f64)]) -> Vec<f64> {
        let mut scores: Vec<f64> = Vec::with_capacity(self.class_log_prior.len());
        for (&prior, log_prob) in self.class_log_prior.iter().zip(&self.feature_log_prob) {
            let log_likelihood: f64 = x
                .iter()
                .map(|&(feature, weight)| weight * log_prob[feature])
                .sum();
            scores.push(prior + log_likelihood);
        }

        let mut peak = f64::NEG_INFINITY;
        for &score in &scores {
            peak = peak.max(score);
        }

        for score in &mut scores {
            *score = (*score - peak).exp();
        }
        let total: f64 = scores.iter().sum();
        for score in &mut scores {
            *score /= total;
        }
        scores
    }
}
/// Splits `text` into lowercase tokens.
///
/// Every non-alphanumeric character acts as a separator. Pieces shorter than
/// two *bytes* are dropped, which removes empty pieces and single ASCII
/// characters (a lone multi-byte character still passes the filter).
fn tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for piece in text.split(|c: char| !c.is_alphanumeric()) {
        if piece.len() < 2 {
            continue;
        }
        tokens.push(piece.to_lowercase());
    }
    tokens
}
/// Builds a sparse TF-IDF row for one tokenized document.
///
/// Each token found in `vocab` contributes to the count of its feature index;
/// tokens missing from `vocab` are ignored. The result holds one
/// `(feature index, count * idf[index])` pair per distinct feature, in
/// ascending index order.
///
/// # Panics
///
/// Panics if `vocab` maps a token to an index outside `idf`.
fn tfidf_row(tokens: &[String], vocab: &HashMap<String, usize>, idf: &[f64]) -> Vec<(usize, f64)> {
    // An ordered map yields the features already sorted by index.
    let mut occurrences: std::collections::BTreeMap<usize, u32> =
        std::collections::BTreeMap::new();
    for feature in tokens.iter().filter_map(|token| vocab.get(token).copied()) {
        *occurrences.entry(feature).or_insert(0) += 1;
    }
    occurrences
        .into_iter()
        .map(|(feature, count)| (feature, f64::from(count) * idf[feature]))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use rustledger_plugin_types::{AmountData, PostingData, TransactionData};

    /// Builds a two-posting transaction directive: `from_account` carries a
    /// fixed `-50.00 USD` amount and `to_account` has no amount.
    fn make_txn(
        payee: Option<&str>,
        narration: &str,
        from_account: &str,
        to_account: &str,
    ) -> DirectiveWrapper {
        let source_posting = PostingData {
            account: from_account.to_string(),
            units: Some(AmountData {
                number: "-50.00".to_string(),
                currency: "USD".to_string(),
            }),
            cost: None,
            price: None,
            flag: None,
            metadata: vec![],
            span: None,
        };
        let target_posting = PostingData {
            account: to_account.to_string(),
            units: None,
            cost: None,
            price: None,
            flag: None,
            metadata: vec![],
            span: None,
        };
        let txn = TransactionData {
            flag: "*".to_string(),
            payee: payee.map(String::from),
            narration: narration.to_string(),
            tags: vec![],
            links: vec![],
            metadata: vec![],
            postings: vec![source_posting, target_posting],
        };
        DirectiveWrapper {
            directive_type: "transaction".to_string(),
            date: "2024-01-15".to_string(),
            filename: None,
            lineno: None,
            data: DirectiveData::Transaction(txn),
        }
    }

    /// Eleven labelled transactions across three expense accounts, all paid
    /// from `Assets:Bank`.
    fn training_data() -> Vec<DirectiveWrapper> {
        // (payee, narration, target account)
        const ROWS: &[(&str, &str, &str)] = &[
            ("Whole Foods", "Groceries", "Expenses:Groceries"),
            ("Trader Joe's", "Weekly groceries", "Expenses:Groceries"),
            ("Safeway", "Food shopping", "Expenses:Groceries"),
            ("Kroger", "Groceries", "Expenses:Groceries"),
            ("Starbucks", "Coffee", "Expenses:Dining"),
            ("McDonald's", "Lunch", "Expenses:Dining"),
            ("Chipotle", "Dinner", "Expenses:Dining"),
            ("Panera", "Coffee and sandwich", "Expenses:Dining"),
            ("Shell", "Gas", "Expenses:Transport"),
            ("Chevron", "Fuel", "Expenses:Transport"),
            ("Uber", "Ride to airport", "Expenses:Transport"),
        ];
        ROWS.iter()
            .map(|&(payee, narration, account)| {
                make_txn(Some(payee), narration, "Assets:Bank", account)
            })
            .collect()
    }

    /// Trains on the shared fixture and returns the top-ranked label.
    fn top_label(narration: &str, payee: Option<&str>) -> String {
        let data = training_data();
        let model = CategorizationModel::train(&data).unwrap();
        let predictions = model.predict(narration, payee);
        assert!(!predictions.is_empty());
        predictions[0].0.clone()
    }

    #[test]
    fn train_and_predict() {
        let data = training_data();
        let model = CategorizationModel::train(&data).unwrap();
        assert_eq!(model.num_categories(), 3);
        assert!(model.vocab_size() > 5);

        let predictions = model.predict("Weekly food shopping at the store", None);
        assert!(!predictions.is_empty());
        assert_eq!(predictions[0].0, "Expenses:Groceries");
    }

    #[test]
    fn predict_dining() {
        assert_eq!(top_label("Coffee", Some("Starbucks")), "Expenses:Dining");
    }

    #[test]
    fn predict_transport() {
        assert_eq!(top_label("Fuel for car", Some("Shell")), "Expenses:Transport");
    }

    #[test]
    fn insufficient_data() {
        // A single transaction is below the two-sample minimum.
        let data = vec![make_txn(
            Some("Store"),
            "Stuff",
            "Assets:Bank",
            "Expenses:Misc",
        )];
        assert!(CategorizationModel::train(&data).is_err());
    }

    #[test]
    fn insufficient_categories() {
        // Two samples, but both share the same target account.
        let data: Vec<DirectiveWrapper> = [("Store", "Stuff"), ("Shop", "Things")]
            .iter()
            .map(|&(payee, narration)| {
                make_txn(Some(payee), narration, "Assets:Bank", "Expenses:Misc")
            })
            .collect();
        assert!(CategorizationModel::train(&data).is_err());
    }

    #[test]
    fn tokenize_basic() {
        let tokens = tokenize("WHOLE FOODS MARKET #1234");
        for expected in ["whole", "foods", "market", "1234"] {
            assert!(tokens.iter().any(|token| token.as_str() == expected));
        }
    }

    #[test]
    fn naive_bayes_known_values() {
        // Two classes, two features, one sample each: smoothed likelihoods are
        // 3/4 for a class's own feature and 1/4 for the other.
        let rows = [vec![(0, 2.0)], vec![(1, 2.0)]];
        let nb = MultinomialNB::fit(&rows, &[0, 1], 2, 2);

        let p = nb.predict_proba(&[(0, 1.0)]);
        assert!((p[0] - 0.75).abs() < 1e-9, "p[0] = {}", p[0]);
        assert!((p[1] - 0.25).abs() < 1e-9, "p[1] = {}", p[1]);
        let sum: f64 = p.iter().sum();
        assert!((sum - 1.0).abs() < 1e-12, "posteriors must sum to 1.0");

        let q = nb.predict_proba(&[(1, 1.0)]);
        assert!((q[0] - 0.25).abs() < 1e-9 && (q[1] - 0.75).abs() < 1e-9);
    }
}