Skip to main content

rustledger_ops/
llm.rs

//! LLM prompt building for transaction categorization.
//!
//! Provides utilities for building prompts that ask an LLM to categorize
//! transactions, parsing the LLM's reply, and collecting candidate accounts
//! from a ledger's directives. Designed to be used via the MCP server, where
//! the LLM acts as the third tier of categorization (after rules and ML).
6
7use rustledger_plugin_types::{DirectiveData, DirectiveWrapper};
8
/// A request to categorize a transaction via LLM.
///
/// Every value is carried as a plain string because it is interpolated
/// directly into prompt text by [`build_categorization_prompt`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CategorizationRequest {
    /// The payee (if available).
    pub payee: Option<String>,
    /// The narration/description.
    pub narration: String,
    /// The amount, already formatted as a string (e.g. `"-85.23"`), if known.
    pub amount: Option<String>,
    /// The currency of `amount` (e.g. `"USD"`), if known.
    pub currency: Option<String>,
    /// The transaction date, already formatted as a string (e.g. `"2024-01-15"`).
    pub date: String,
    /// List of known accounts in the ledger the LLM should choose from
    /// (for constrained prediction).
    pub known_accounts: Vec<String>,
}
25
/// A parsed categorization response from the LLM.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CategorizationResponse {
    /// The predicted account (e.g. `Expenses:Groceries`).
    pub account: String,
    /// Brief reasoning for the prediction; empty when the LLM gave none.
    pub reasoning: String,
}
34
35/// Build a prompt for transaction categorization.
36///
37/// The prompt includes the transaction details and a list of known accounts,
38/// asking the LLM to select the most appropriate account and explain why.
39#[must_use]
40pub fn build_categorization_prompt(request: &CategorizationRequest) -> String {
41    let mut prompt = String::new();
42
43    prompt.push_str("Categorize this financial transaction into the most appropriate account.\n\n");
44    prompt.push_str("Transaction:\n");
45    prompt.push_str(&format!("  Date: {}\n", request.date));
46    if let Some(ref payee) = request.payee {
47        prompt.push_str(&format!("  Payee: {payee}\n"));
48    }
49    prompt.push_str(&format!("  Description: {}\n", request.narration));
50    if let Some(ref amount) = request.amount {
51        let currency = request.currency.as_deref().unwrap_or("USD");
52        prompt.push_str(&format!("  Amount: {amount} {currency}\n"));
53    }
54
55    prompt.push_str("\nAvailable accounts:\n");
56    for account in &request.known_accounts {
57        prompt.push_str(&format!("  - {account}\n"));
58    }
59
60    prompt.push_str("\nRespond with ONLY the account name on the first line, ");
61    prompt.push_str("followed by a brief reason on the second line.\n");
62    prompt.push_str("Example:\n");
63    prompt.push_str("Expenses:Groceries\n");
64    prompt.push_str("Whole Foods is a grocery store\n");
65
66    prompt
67}
68
69/// Parse an LLM response into a structured categorization.
70///
71/// Expects the account name on the first line and reasoning on the second.
72/// Returns `None` if the response can't be parsed.
73#[must_use]
74pub fn parse_categorization_response(response: &str) -> Option<CategorizationResponse> {
75    let mut lines = response.trim().lines();
76    let account = lines.next()?.trim().to_string();
77
78    // Validate it looks like an account (contains ':')
79    if !account.contains(':') {
80        return None;
81    }
82
83    let reasoning = lines.next().unwrap_or("").trim().to_string();
84
85    Some(CategorizationResponse { account, reasoning })
86}
87
88/// Extract known expense/income accounts from directives for prompt building.
89#[must_use]
90pub fn extract_known_accounts(directives: &[DirectiveWrapper]) -> Vec<String> {
91    let mut accounts = std::collections::BTreeSet::new();
92
93    for d in directives {
94        match &d.data {
95            DirectiveData::Transaction(txn) => {
96                for posting in &txn.postings {
97                    if posting.account.starts_with("Expenses:")
98                        || posting.account.starts_with("Income:")
99                    {
100                        accounts.insert(posting.account.clone());
101                    }
102                }
103            }
104            DirectiveData::Open(open)
105                if (open.account.starts_with("Expenses:")
106                    || open.account.starts_with("Income:")) =>
107            {
108                accounts.insert(open.account.clone());
109            }
110            _ => {}
111        }
112    }
113
114    accounts.into_iter().collect()
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use rustledger_plugin_types::OpenData;

    /// Build a minimal `open` directive for `account`, dated 2024-01-01.
    fn open_directive(account: &str) -> DirectiveWrapper {
        DirectiveWrapper {
            directive_type: "open".to_string(),
            date: "2024-01-01".to_string(),
            filename: None,
            lineno: None,
            data: DirectiveData::Open(OpenData {
                account: account.to_string(),
                currencies: vec![],
                booking: None,
                metadata: vec![],
            }),
        }
    }

    #[test]
    fn build_prompt_basic() {
        let accounts = ["Expenses:Groceries", "Expenses:Dining", "Expenses:Transport"];
        let request = CategorizationRequest {
            payee: Some("WHOLE FOODS MARKET".to_string()),
            narration: "Groceries".to_string(),
            amount: Some("-85.23".to_string()),
            currency: Some("USD".to_string()),
            date: "2024-01-15".to_string(),
            known_accounts: accounts.iter().map(ToString::to_string).collect(),
        };

        let prompt = build_categorization_prompt(&request);

        for needle in [
            "WHOLE FOODS MARKET",
            "-85.23 USD",
            "Expenses:Groceries",
            "Expenses:Dining",
        ] {
            assert!(prompt.contains(needle));
        }
    }

    #[test]
    fn parse_response_valid() {
        let CategorizationResponse { account, reasoning } =
            parse_categorization_response("Expenses:Groceries\nWhole Foods is a grocery store")
                .unwrap();
        assert_eq!(account, "Expenses:Groceries");
        assert_eq!(reasoning, "Whole Foods is a grocery store");
    }

    #[test]
    fn parse_response_no_reasoning() {
        let CategorizationResponse { account, reasoning } =
            parse_categorization_response("Expenses:Dining\n").unwrap();
        assert_eq!(account, "Expenses:Dining");
        assert_eq!(reasoning, "");
    }

    #[test]
    fn parse_response_invalid() {
        let parsed = parse_categorization_response("This is not an account");
        assert!(parsed.is_none());
    }

    #[test]
    fn extract_accounts() {
        let directives = [
            open_directive("Expenses:Groceries"),
            open_directive("Assets:Bank"),
        ];
        // Assets:Bank is excluded (not Expenses: or Income:)
        assert_eq!(extract_known_accounts(&directives), vec!["Expenses:Groceries"]);
    }
}