Skip to main content

rustledger_ops/
ml.rs

1//! ML-based transaction categorization.
2//!
3//! Trains a Multinomial Naive Bayes classifier on existing ledger transactions
4//! to predict the expense/income account for new transactions based on their
5//! payee and narration text.
6//!
7//! Uses TF-IDF vectorization with linfa-bayes for classification.
8//!
9//! # Example
10//!
11//! ```rust,ignore
12//! let model = CategorizationModel::train(&existing_directives)?;
//! let predictions = model.predict("groceries", Some("WHOLE FOODS"));
14//! // → [("Expenses:Groceries", 0.80)]
15//! ```
16
17use linfa::prelude::*;
18use linfa_bayes::MultinomialNb;
19use ndarray::{Array1, Array2};
20use rustledger_plugin_types::{DirectiveData, DirectiveWrapper};
21use std::collections::HashMap;
22
23/// A trained categorization model.
24///
25/// Wraps a Multinomial Naive Bayes classifier trained on TF-IDF features
26/// extracted from transaction payee/narration text.
27pub struct CategorizationModel {
28    /// The trained classifier.
29    model: MultinomialNb<f64, usize>,
30    /// Vocabulary: word → column index in the feature matrix.
31    vocabulary: HashMap<String, usize>,
32    /// IDF weights for each word in the vocabulary.
33    idf: Vec<f64>,
34    /// Label map: index → account name.
35    labels: Vec<String>,
36}
37
/// Failures reported while building a categorization model.
#[derive(Debug)]
pub enum MlError {
    /// The ledger did not yield enough usable samples, categories, or tokens.
    InsufficientData(String),
    /// The classifier's fit step returned an error; carries its message.
    TrainingFailed(String),
}

impl std::fmt::Display for MlError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Each variant is rendered as "<fixed prefix>: <detail message>".
        let (prefix, detail) = match self {
            Self::InsufficientData(detail) => ("insufficient training data", detail),
            Self::TrainingFailed(detail) => ("training failed", detail),
        };
        write!(f, "{prefix}: {detail}")
    }
}

impl std::error::Error for MlError {}
57
58impl CategorizationModel {
59    /// Train a model from existing ledger directives.
60    ///
61    /// Extracts (text, account) pairs from transactions where the second
62    /// posting's account is the categorization target. Requires at least
63    /// 2 distinct categories with at least 1 transaction each.
64    ///
65    /// # Errors
66    ///
67    /// Returns `MlError::InsufficientData` if there aren't enough transactions
68    /// or distinct categories to train a useful model.
69    pub fn train(directives: &[DirectiveWrapper]) -> Result<Self, MlError> {
70        // Extract training data: (text, account) pairs
71        let mut samples: Vec<(String, String)> = Vec::new();
72
73        for d in directives {
74            if let DirectiveData::Transaction(txn) = &d.data {
75                // Skip transactions with fewer than 2 postings
76                if txn.postings.len() < 2 {
77                    continue;
78                }
79
80                // The target account is the second posting (contra-account)
81                let account = &txn.postings[1].account;
82
83                // Build text from payee + narration
84                let mut text = String::new();
85                if let Some(ref payee) = txn.payee {
86                    text.push_str(payee);
87                    text.push(' ');
88                }
89                text.push_str(&txn.narration);
90
91                if !text.trim().is_empty() {
92                    samples.push((text.to_lowercase(), account.clone()));
93                }
94            }
95        }
96
97        if samples.len() < 2 {
98            return Err(MlError::InsufficientData(format!(
99                "need at least 2 transactions, got {}",
100                samples.len()
101            )));
102        }
103
104        // Build label map
105        let mut label_set: Vec<String> = samples.iter().map(|(_, a)| a.clone()).collect();
106        label_set.sort();
107        label_set.dedup();
108
109        if label_set.len() < 2 {
110            return Err(MlError::InsufficientData(
111                "need at least 2 distinct categories".to_string(),
112            ));
113        }
114
115        let label_to_idx: HashMap<&str, usize> = label_set
116            .iter()
117            .enumerate()
118            .map(|(i, s)| (s.as_str(), i))
119            .collect();
120
121        // Build vocabulary from all tokens
122        let mut vocab: HashMap<String, usize> = HashMap::new();
123        let tokenized: Vec<Vec<String>> = samples.iter().map(|(text, _)| tokenize(text)).collect();
124
125        for tokens in &tokenized {
126            for token in tokens {
127                let len = vocab.len();
128                vocab.entry(token.clone()).or_insert(len);
129            }
130        }
131
132        if vocab.is_empty() {
133            return Err(MlError::InsufficientData(
134                "no tokens found in training data".to_string(),
135            ));
136        }
137
138        // Compute IDF weights
139        let n_docs = samples.len() as f64;
140        let mut doc_freq = vec![0u32; vocab.len()];
141        for tokens in &tokenized {
142            let mut seen = std::collections::HashSet::new();
143            for token in tokens {
144                if let Some(&idx) = vocab.get(token)
145                    && seen.insert(idx)
146                {
147                    doc_freq[idx] += 1;
148                }
149            }
150        }
151        let idf: Vec<f64> = doc_freq
152            .iter()
153            .map(|&df| (n_docs / (1.0 + f64::from(df))).ln() + 1.0)
154            .collect();
155
156        // Build TF-IDF feature matrix
157        let n_samples = samples.len();
158        let n_features = vocab.len();
159        let mut features = Array2::<f64>::zeros((n_samples, n_features));
160        let mut targets = Array1::<usize>::zeros(n_samples);
161
162        for (i, (tokens, (_, account))) in tokenized.iter().zip(samples.iter()).enumerate() {
163            // Term frequency
164            let mut tf = vec![0u32; n_features];
165            for token in tokens {
166                if let Some(&idx) = vocab.get(token) {
167                    tf[idx] += 1;
168                }
169            }
170            // TF-IDF
171            for (j, &count) in tf.iter().enumerate() {
172                features[[i, j]] = f64::from(count) * idf[j];
173            }
174            targets[i] = label_to_idx[account.as_str()];
175        }
176
177        // Train Multinomial Naive Bayes
178        let dataset = DatasetBase::new(features, targets);
179        let model = MultinomialNb::params()
180            .fit(&dataset)
181            .map_err(|e| MlError::TrainingFailed(format!("{e}")))?;
182
183        Ok(Self {
184            model,
185            vocabulary: vocab,
186            idf,
187            labels: label_set,
188        })
189    }
190
191    /// Predict the account for a transaction.
192    ///
193    /// Returns up to `top_n` predictions sorted by confidence (highest first).
194    /// Each prediction is an `(account, probability)` pair.
195    ///
196    /// **Note:** Confidence scores are not calibrated. The predicted class always
197    /// receives a confidence of `0.8` and all other classes receive `0.0`.
198    /// Computing calibrated probabilities requires log-likelihood estimation,
199    /// which is a future enhancement.
200    #[must_use]
201    pub fn predict(&self, narration: &str, payee: Option<&str>) -> Vec<(String, f64)> {
202        let mut text = String::new();
203        if let Some(p) = payee {
204            text.push_str(p);
205            text.push(' ');
206        }
207        text.push_str(narration);
208
209        let features = self.vectorize(&text.to_lowercase());
210        let features_2d = features.insert_axis(ndarray::Axis(0));
211
212        // Get prediction
213        let prediction = self.model.predict(&features_2d);
214        let predicted_idx = prediction[0];
215
216        // For Naive Bayes, we don't get probability scores directly from linfa.
217        // Return the predicted class with confidence 1.0 and others with 0.0.
218        // A more sophisticated approach would compute log-likelihoods.
219        let mut results: Vec<(String, f64)> = self
220            .labels
221            .iter()
222            .enumerate()
223            .map(|(i, label)| {
224                let conf = if i == predicted_idx { 0.8 } else { 0.0 };
225                (label.clone(), conf)
226            })
227            .filter(|(_, conf)| *conf > 0.0)
228            .collect();
229
230        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
231        results
232    }
233
234    /// Vectorize text into a TF-IDF feature array.
235    fn vectorize(&self, text: &str) -> Array1<f64> {
236        let tokens = tokenize(text);
237        let n_features = self.vocabulary.len();
238        let mut tf = vec![0u32; n_features];
239
240        for token in &tokens {
241            if let Some(&idx) = self.vocabulary.get(token) {
242                tf[idx] += 1;
243            }
244        }
245
246        let mut features = Array1::<f64>::zeros(n_features);
247        for (j, &count) in tf.iter().enumerate() {
248            features[j] = f64::from(count) * self.idf[j];
249        }
250        features
251    }
252
253    /// Number of distinct categories the model was trained on.
254    #[must_use]
255    pub const fn num_categories(&self) -> usize {
256        self.labels.len()
257    }
258
259    /// Number of features (vocabulary size).
260    #[must_use]
261    pub fn vocab_size(&self) -> usize {
262        self.vocabulary.len()
263    }
264}
265
/// Split `text` into lowercase word tokens.
///
/// Every character that is not alphanumeric (per [`char::is_alphanumeric`])
/// acts as a separator. Tokens shorter than two *characters* are dropped.
/// Length is counted in `char`s, not UTF-8 bytes, so a lone non-ASCII letter
/// such as `é` is filtered exactly like a lone ASCII letter.
fn tokenize(text: &str) -> Vec<String> {
    const MIN_TOKEN_CHARS: usize = 2;
    text.split(|c: char| !c.is_alphanumeric())
        .filter(|word| word.chars().count() >= MIN_TOKEN_CHARS)
        .map(str::to_lowercase)
        .collect()
}
273
#[cfg(test)]
mod tests {
    use super::*;
    use rustledger_plugin_types::{AmountData, PostingData, TransactionData};

    /// Build a two-posting transaction: `from_account` pays 50 USD, and
    /// `to_account` (the categorization target) carries no explicit amount.
    fn make_txn(
        payee: Option<&str>,
        narration: &str,
        from_account: &str,
        to_account: &str,
    ) -> DirectiveWrapper {
        DirectiveWrapper {
            directive_type: "transaction".to_string(),
            date: "2024-01-15".to_string(),
            filename: None,
            lineno: None,
            data: DirectiveData::Transaction(TransactionData {
                flag: "*".to_string(),
                payee: payee.map(String::from),
                narration: narration.to_string(),
                tags: vec![],
                links: vec![],
                metadata: vec![],
                postings: vec![
                    PostingData {
                        account: from_account.to_string(),
                        units: Some(AmountData {
                            number: "-50.00".to_string(),
                            currency: "USD".to_string(),
                        }),
                        cost: None,
                        price: None,
                        flag: None,
                        metadata: vec![],
                    },
                    PostingData {
                        account: to_account.to_string(),
                        units: None,
                        cost: None,
                        price: None,
                        flag: None,
                        metadata: vec![],
                    },
                ],
            }),
        }
    }

    /// A small labelled corpus covering three expense categories.
    fn training_data() -> Vec<DirectiveWrapper> {
        vec![
            make_txn(
                Some("Whole Foods"),
                "Groceries",
                "Assets:Bank",
                "Expenses:Groceries",
            ),
            make_txn(
                Some("Trader Joe's"),
                "Weekly groceries",
                "Assets:Bank",
                "Expenses:Groceries",
            ),
            make_txn(
                Some("Safeway"),
                "Food shopping",
                "Assets:Bank",
                "Expenses:Groceries",
            ),
            make_txn(
                Some("Kroger"),
                "Groceries",
                "Assets:Bank",
                "Expenses:Groceries",
            ),
            make_txn(
                Some("Starbucks"),
                "Coffee",
                "Assets:Bank",
                "Expenses:Dining",
            ),
            make_txn(
                Some("McDonald's"),
                "Lunch",
                "Assets:Bank",
                "Expenses:Dining",
            ),
            make_txn(Some("Chipotle"), "Dinner", "Assets:Bank", "Expenses:Dining"),
            make_txn(
                Some("Panera"),
                "Coffee and sandwich",
                "Assets:Bank",
                "Expenses:Dining",
            ),
            make_txn(Some("Shell"), "Gas", "Assets:Bank", "Expenses:Transport"),
            make_txn(Some("Chevron"), "Fuel", "Assets:Bank", "Expenses:Transport"),
            make_txn(
                Some("Uber"),
                "Ride to airport",
                "Assets:Bank",
                "Expenses:Transport",
            ),
        ]
    }

    #[test]
    fn train_and_predict() {
        let data = training_data();
        let model = CategorizationModel::train(&data).unwrap();

        assert_eq!(model.num_categories(), 3);
        assert!(model.vocab_size() > 5);

        let predictions = model.predict("Weekly food shopping at the store", None);
        assert!(!predictions.is_empty());
        // Should predict Groceries (most similar to training data)
        assert_eq!(predictions[0].0, "Expenses:Groceries");
    }

    #[test]
    fn predict_returns_single_entry_with_fixed_confidence() {
        let model = CategorizationModel::train(&training_data()).unwrap();

        let predictions = model.predict("Coffee", Some("Starbucks"));
        // Documented contract: exactly one uncalibrated 0.8 prediction.
        assert_eq!(predictions.len(), 1);
        assert!((predictions[0].1 - 0.8).abs() < f64::EPSILON);
    }

    #[test]
    fn predict_dining() {
        let data = training_data();
        let model = CategorizationModel::train(&data).unwrap();

        let predictions = model.predict("Coffee", Some("Starbucks"));
        assert!(!predictions.is_empty());
        assert_eq!(predictions[0].0, "Expenses:Dining");
    }

    #[test]
    fn predict_transport() {
        let data = training_data();
        let model = CategorizationModel::train(&data).unwrap();

        let predictions = model.predict("Fuel for car", Some("Shell"));
        assert!(!predictions.is_empty());
        assert_eq!(predictions[0].0, "Expenses:Transport");
    }

    #[test]
    fn insufficient_data() {
        let data = vec![make_txn(
            Some("Store"),
            "Stuff",
            "Assets:Bank",
            "Expenses:Misc",
        )];
        let result = CategorizationModel::train(&data);
        assert!(matches!(result, Err(MlError::InsufficientData(_))));
    }

    #[test]
    fn insufficient_categories() {
        let data = vec![
            make_txn(Some("Store"), "Stuff", "Assets:Bank", "Expenses:Misc"),
            make_txn(Some("Shop"), "Things", "Assets:Bank", "Expenses:Misc"),
        ];
        let result = CategorizationModel::train(&data);
        assert!(matches!(result, Err(MlError::InsufficientData(_))));
    }

    #[test]
    fn no_usable_tokens() {
        // Two samples in two categories, but every token is too short to keep.
        let data = vec![
            make_txn(None, "a", "Assets:Bank", "Expenses:Misc"),
            make_txn(None, "b", "Assets:Bank", "Expenses:Other"),
        ];
        let result = CategorizationModel::train(&data);
        assert!(matches!(result, Err(MlError::InsufficientData(_))));
    }

    #[test]
    fn tokenize_basic() {
        let tokens = tokenize("WHOLE FOODS MARKET #1234");
        assert!(tokens.contains(&"whole".to_string()));
        assert!(tokens.contains(&"foods".to_string()));
        assert!(tokens.contains(&"market".to_string()));
        assert!(tokens.contains(&"1234".to_string()));
    }
}