1use linfa::prelude::*;
18use linfa_bayes::MultinomialNb;
19use ndarray::{Array1, Array2};
20use rustledger_plugin_types::{DirectiveData, DirectiveWrapper};
21use std::collections::HashMap;
22
23pub struct CategorizationModel {
28 model: MultinomialNb<f64, usize>,
30 vocabulary: HashMap<String, usize>,
32 idf: Vec<f64>,
34 labels: Vec<String>,
36}
37
/// Errors that can occur while training a categorization model.
///
/// Both variants carry a human-readable explanation. The type is cheap to
/// clone and supports equality so callers and tests can match on the exact
/// failure instead of only checking `is_err()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MlError {
    /// The training input did not contain enough usable data.
    InsufficientData(String),
    /// The underlying classifier failed to fit the prepared dataset.
    TrainingFailed(String),
}
46
47impl std::fmt::Display for MlError {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 match self {
50 Self::InsufficientData(msg) => write!(f, "insufficient training data: {msg}"),
51 Self::TrainingFailed(msg) => write!(f, "training failed: {msg}"),
52 }
53 }
54}
55
56impl std::error::Error for MlError {}
57
58impl CategorizationModel {
59 pub fn train(directives: &[DirectiveWrapper]) -> Result<Self, MlError> {
70 let mut samples: Vec<(String, String)> = Vec::new();
72
73 for d in directives {
74 if let DirectiveData::Transaction(txn) = &d.data {
75 if txn.postings.len() < 2 {
77 continue;
78 }
79
80 let account = &txn.postings[1].account;
82
83 let mut text = String::new();
85 if let Some(ref payee) = txn.payee {
86 text.push_str(payee);
87 text.push(' ');
88 }
89 text.push_str(&txn.narration);
90
91 if !text.trim().is_empty() {
92 samples.push((text.to_lowercase(), account.clone()));
93 }
94 }
95 }
96
97 if samples.len() < 2 {
98 return Err(MlError::InsufficientData(format!(
99 "need at least 2 transactions, got {}",
100 samples.len()
101 )));
102 }
103
104 let mut label_set: Vec<String> = samples.iter().map(|(_, a)| a.clone()).collect();
106 label_set.sort();
107 label_set.dedup();
108
109 if label_set.len() < 2 {
110 return Err(MlError::InsufficientData(
111 "need at least 2 distinct categories".to_string(),
112 ));
113 }
114
115 let label_to_idx: HashMap<&str, usize> = label_set
116 .iter()
117 .enumerate()
118 .map(|(i, s)| (s.as_str(), i))
119 .collect();
120
121 let mut vocab: HashMap<String, usize> = HashMap::new();
123 let tokenized: Vec<Vec<String>> = samples.iter().map(|(text, _)| tokenize(text)).collect();
124
125 for tokens in &tokenized {
126 for token in tokens {
127 let len = vocab.len();
128 vocab.entry(token.clone()).or_insert(len);
129 }
130 }
131
132 if vocab.is_empty() {
133 return Err(MlError::InsufficientData(
134 "no tokens found in training data".to_string(),
135 ));
136 }
137
138 let n_docs = samples.len() as f64;
140 let mut doc_freq = vec![0u32; vocab.len()];
141 for tokens in &tokenized {
142 let mut seen = std::collections::HashSet::new();
143 for token in tokens {
144 if let Some(&idx) = vocab.get(token)
145 && seen.insert(idx)
146 {
147 doc_freq[idx] += 1;
148 }
149 }
150 }
151 let idf: Vec<f64> = doc_freq
152 .iter()
153 .map(|&df| (n_docs / (1.0 + f64::from(df))).ln() + 1.0)
154 .collect();
155
156 let n_samples = samples.len();
158 let n_features = vocab.len();
159 let mut features = Array2::<f64>::zeros((n_samples, n_features));
160 let mut targets = Array1::<usize>::zeros(n_samples);
161
162 for (i, (tokens, (_, account))) in tokenized.iter().zip(samples.iter()).enumerate() {
163 let mut tf = vec![0u32; n_features];
165 for token in tokens {
166 if let Some(&idx) = vocab.get(token) {
167 tf[idx] += 1;
168 }
169 }
170 for (j, &count) in tf.iter().enumerate() {
172 features[[i, j]] = f64::from(count) * idf[j];
173 }
174 targets[i] = label_to_idx[account.as_str()];
175 }
176
177 let dataset = DatasetBase::new(features, targets);
179 let model = MultinomialNb::params()
180 .fit(&dataset)
181 .map_err(|e| MlError::TrainingFailed(format!("{e}")))?;
182
183 Ok(Self {
184 model,
185 vocabulary: vocab,
186 idf,
187 labels: label_set,
188 })
189 }
190
191 #[must_use]
201 pub fn predict(&self, narration: &str, payee: Option<&str>) -> Vec<(String, f64)> {
202 let mut text = String::new();
203 if let Some(p) = payee {
204 text.push_str(p);
205 text.push(' ');
206 }
207 text.push_str(narration);
208
209 let features = self.vectorize(&text.to_lowercase());
210 let features_2d = features.insert_axis(ndarray::Axis(0));
211
212 let prediction = self.model.predict(&features_2d);
214 let predicted_idx = prediction[0];
215
216 let mut results: Vec<(String, f64)> = self
220 .labels
221 .iter()
222 .enumerate()
223 .map(|(i, label)| {
224 let conf = if i == predicted_idx { 0.8 } else { 0.0 };
225 (label.clone(), conf)
226 })
227 .filter(|(_, conf)| *conf > 0.0)
228 .collect();
229
230 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
231 results
232 }
233
234 fn vectorize(&self, text: &str) -> Array1<f64> {
236 let tokens = tokenize(text);
237 let n_features = self.vocabulary.len();
238 let mut tf = vec![0u32; n_features];
239
240 for token in &tokens {
241 if let Some(&idx) = self.vocabulary.get(token) {
242 tf[idx] += 1;
243 }
244 }
245
246 let mut features = Array1::<f64>::zeros(n_features);
247 for (j, &count) in tf.iter().enumerate() {
248 features[j] = f64::from(count) * self.idf[j];
249 }
250 features
251 }
252
253 #[must_use]
255 pub const fn num_categories(&self) -> usize {
256 self.labels.len()
257 }
258
259 #[must_use]
261 pub fn vocab_size(&self) -> usize {
262 self.vocabulary.len()
263 }
264}
265
/// Splits `text` into lower-cased tokens.
///
/// Any non-alphanumeric character acts as a separator. Tokens shorter than
/// two *characters* are discarded; the length is counted in Unicode scalar
/// values rather than bytes, so a lone non-ASCII letter such as "é" is
/// dropped just like a lone "a".
fn tokenize(text: &str) -> Vec<String> {
    text.split(|c: char| !c.is_alphanumeric())
        // `nth(1)` yields a value only when there are at least two characters.
        .filter(|piece| piece.chars().nth(1).is_some())
        .map(str::to_lowercase)
        .collect()
}
273
#[cfg(test)]
mod tests {
    use super::*;
    use rustledger_plugin_types::{AmountData, PostingData, TransactionData};

    /// Builds a two-posting transaction directive: the first posting debits
    /// `from_account` by 50.00 USD, the second posting (no amount) names
    /// `to_account`.
    fn make_txn(
        payee: Option<&str>,
        narration: &str,
        from_account: &str,
        to_account: &str,
    ) -> DirectiveWrapper {
        let source_leg = PostingData {
            account: from_account.to_string(),
            units: Some(AmountData {
                number: "-50.00".to_string(),
                currency: "USD".to_string(),
            }),
            cost: None,
            price: None,
            flag: None,
            metadata: vec![],
        };
        let target_leg = PostingData {
            account: to_account.to_string(),
            units: None,
            cost: None,
            price: None,
            flag: None,
            metadata: vec![],
        };
        let transaction = TransactionData {
            flag: "*".to_string(),
            payee: payee.map(String::from),
            narration: narration.to_string(),
            tags: vec![],
            links: vec![],
            metadata: vec![],
            postings: vec![source_leg, target_leg],
        };

        DirectiveWrapper {
            directive_type: "transaction".to_string(),
            date: "2024-01-15".to_string(),
            filename: None,
            lineno: None,
            data: DirectiveData::Transaction(transaction),
        }
    }

    /// Eleven labelled transactions across three expense categories, all
    /// paid from the same bank account.
    fn training_data() -> Vec<DirectiveWrapper> {
        // (payee, narration, category account)
        let rows: &[(&str, &str, &str)] = &[
            ("Whole Foods", "Groceries", "Expenses:Groceries"),
            ("Trader Joe's", "Weekly groceries", "Expenses:Groceries"),
            ("Safeway", "Food shopping", "Expenses:Groceries"),
            ("Kroger", "Groceries", "Expenses:Groceries"),
            ("Starbucks", "Coffee", "Expenses:Dining"),
            ("McDonald's", "Lunch", "Expenses:Dining"),
            ("Chipotle", "Dinner", "Expenses:Dining"),
            ("Panera", "Coffee and sandwich", "Expenses:Dining"),
            ("Shell", "Gas", "Expenses:Transport"),
            ("Chevron", "Fuel", "Expenses:Transport"),
            ("Uber", "Ride to airport", "Expenses:Transport"),
        ];

        rows.iter()
            .map(|&(payee, narration, category)| {
                make_txn(Some(payee), narration, "Assets:Bank", category)
            })
            .collect()
    }

    /// Trains a model on the shared fixture, panicking if training fails.
    fn trained_model() -> CategorizationModel {
        let data = training_data();
        CategorizationModel::train(&data).unwrap()
    }

    /// Asserts that `predictions` is non-empty and led by `expected`.
    fn assert_top_category(predictions: &[(String, f64)], expected: &str) {
        assert!(!predictions.is_empty());
        assert_eq!(predictions[0].0, expected);
    }

    #[test]
    fn train_and_predict() {
        let model = trained_model();

        assert_eq!(model.num_categories(), 3);
        assert!(model.vocab_size() > 5);

        let predictions = model.predict("Weekly food shopping at the store", None);
        assert_top_category(&predictions, "Expenses:Groceries");
    }

    #[test]
    fn predict_dining() {
        let predictions = trained_model().predict("Coffee", Some("Starbucks"));
        assert_top_category(&predictions, "Expenses:Dining");
    }

    #[test]
    fn predict_transport() {
        let predictions = trained_model().predict("Fuel for car", Some("Shell"));
        assert_top_category(&predictions, "Expenses:Transport");
    }

    #[test]
    fn insufficient_data() {
        // A single transaction is below the two-sample minimum.
        let data = vec![make_txn(
            Some("Store"),
            "Stuff",
            "Assets:Bank",
            "Expenses:Misc",
        )];
        assert!(CategorizationModel::train(&data).is_err());
    }

    #[test]
    fn insufficient_categories() {
        // Two transactions, but both map to the same category.
        let data = vec![
            make_txn(Some("Store"), "Stuff", "Assets:Bank", "Expenses:Misc"),
            make_txn(Some("Shop"), "Things", "Assets:Bank", "Expenses:Misc"),
        ];
        assert!(CategorizationModel::train(&data).is_err());
    }

    #[test]
    fn tokenize_basic() {
        let tokens = tokenize("WHOLE FOODS MARKET #1234");
        for expected in ["whole", "foods", "market", "1234"] {
            assert!(tokens.contains(&expected.to_string()));
        }
    }
}