rustledger_loader/
cache.rs

1//! Binary cache for parsed ledgers.
2//!
3//! This module provides a caching layer that can dramatically speed up
4//! subsequent loads of unchanged beancount files by serializing the parsed
5//! directives to a binary format using rkyv.
6//!
7//! # How it works
8//!
//! 1. When loading a file, compute a hash over the path, modification time, and size of every source file
10//! 2. Check if a cache file exists with a matching hash
11//! 3. If yes, deserialize and return immediately (typically <1ms)
12//! 4. If no, parse normally, serialize to cache, and return
13//!
14//! # Cache location
15//!
16//! Cache files are stored alongside the main ledger file with a `.cache` extension.
17//! For example, `ledger.beancount` would have cache at `ledger.beancount.cache`.
18
19use crate::Options;
20use rust_decimal::Decimal;
21use rustledger_core::Directive;
22use rustledger_core::intern::StringInterner;
23use rustledger_parser::Spanned;
24use sha2::{Digest, Sha256};
25use std::fs;
26use std::io::{Read, Write};
27use std::path::{Path, PathBuf};
28use std::str::FromStr;
29
/// Cached plugin information.
///
/// A plain-data record of one plugin declaration that rkyv can archive;
/// instances are stored in `CacheEntry::plugins`.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct CachedPlugin {
    /// Plugin module name.
    pub name: String,
    /// Optional configuration string.
    pub config: Option<String>,
}
38
/// Cached options - a serializable subset of Options.
///
/// Excludes parsing-time fields like `set_options` and `warnings`.
/// These fields mirror the Options struct and inherit their meaning.
///
/// Values are copied out of `Options` by `From<&Options>` and written back
/// onto a fresh `Options::new()` by `From<CachedOptions>`. Decimal-valued
/// options are carried as strings (see the individual fields).
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
// Fields map one-to-one onto `Options`, so they are not documented individually.
#[allow(missing_docs)]
pub struct CachedOptions {
    pub title: Option<String>,
    pub filename: Option<String>,
    pub operating_currency: Vec<String>,
    pub name_assets: String,
    pub name_liabilities: String,
    pub name_equity: String,
    pub name_income: String,
    pub name_expenses: String,
    pub account_rounding: Option<String>,
    pub account_previous_balances: String,
    pub account_previous_earnings: String,
    pub account_previous_conversions: String,
    pub account_current_earnings: String,
    pub account_current_conversions: Option<String>,
    pub account_unrealized_gains: Option<String>,
    pub conversion_currency: Option<String>,
    /// Stored as (currency, `tolerance_string`) pairs since Decimal needs special handling
    pub inferred_tolerance_default: Vec<(String, String)>,
    /// Decimal rendered with `to_string`; parsed back with `Decimal::from_str`
    /// when the options are restored.
    pub inferred_tolerance_multiplier: String,
    pub infer_tolerance_from_cost: bool,
    pub use_legacy_fixed_tolerances: bool,
    pub experiment_explicit_tolerances: bool,
    pub booking_method: String,
    pub render_commas: bool,
    pub allow_pipe_separator: bool,
    pub long_string_maxlines: u32,
    pub documents: Vec<String>,
    /// Entries of `Options::custom`, flattened to (key, value) pairs.
    pub custom: Vec<(String, String)>,
}
75
76impl From<&Options> for CachedOptions {
77    fn from(opts: &Options) -> Self {
78        Self {
79            title: opts.title.clone(),
80            filename: opts.filename.clone(),
81            operating_currency: opts.operating_currency.clone(),
82            name_assets: opts.name_assets.clone(),
83            name_liabilities: opts.name_liabilities.clone(),
84            name_equity: opts.name_equity.clone(),
85            name_income: opts.name_income.clone(),
86            name_expenses: opts.name_expenses.clone(),
87            account_rounding: opts.account_rounding.clone(),
88            account_previous_balances: opts.account_previous_balances.clone(),
89            account_previous_earnings: opts.account_previous_earnings.clone(),
90            account_previous_conversions: opts.account_previous_conversions.clone(),
91            account_current_earnings: opts.account_current_earnings.clone(),
92            account_current_conversions: opts.account_current_conversions.clone(),
93            account_unrealized_gains: opts.account_unrealized_gains.clone(),
94            conversion_currency: opts.conversion_currency.clone(),
95            inferred_tolerance_default: opts
96                .inferred_tolerance_default
97                .iter()
98                .map(|(k, v)| (k.clone(), v.to_string()))
99                .collect(),
100            inferred_tolerance_multiplier: opts.inferred_tolerance_multiplier.to_string(),
101            infer_tolerance_from_cost: opts.infer_tolerance_from_cost,
102            use_legacy_fixed_tolerances: opts.use_legacy_fixed_tolerances,
103            experiment_explicit_tolerances: opts.experiment_explicit_tolerances,
104            booking_method: opts.booking_method.clone(),
105            render_commas: opts.render_commas,
106            allow_pipe_separator: opts.allow_pipe_separator,
107            long_string_maxlines: opts.long_string_maxlines,
108            documents: opts.documents.clone(),
109            custom: opts
110                .custom
111                .iter()
112                .map(|(k, v)| (k.clone(), v.clone()))
113                .collect(),
114        }
115    }
116}
117
118impl From<CachedOptions> for Options {
119    fn from(cached: CachedOptions) -> Self {
120        let mut opts = Self::new();
121        opts.title = cached.title;
122        opts.filename = cached.filename;
123        opts.operating_currency = cached.operating_currency;
124        opts.name_assets = cached.name_assets;
125        opts.name_liabilities = cached.name_liabilities;
126        opts.name_equity = cached.name_equity;
127        opts.name_income = cached.name_income;
128        opts.name_expenses = cached.name_expenses;
129        opts.account_rounding = cached.account_rounding;
130        opts.account_previous_balances = cached.account_previous_balances;
131        opts.account_previous_earnings = cached.account_previous_earnings;
132        opts.account_previous_conversions = cached.account_previous_conversions;
133        opts.account_current_earnings = cached.account_current_earnings;
134        opts.account_current_conversions = cached.account_current_conversions;
135        opts.account_unrealized_gains = cached.account_unrealized_gains;
136        opts.conversion_currency = cached.conversion_currency;
137        opts.inferred_tolerance_default = cached
138            .inferred_tolerance_default
139            .into_iter()
140            .filter_map(|(k, v)| Decimal::from_str(&v).ok().map(|d| (k, d)))
141            .collect();
142        opts.inferred_tolerance_multiplier =
143            Decimal::from_str(&cached.inferred_tolerance_multiplier)
144                .unwrap_or_else(|_| Decimal::new(5, 1));
145        opts.infer_tolerance_from_cost = cached.infer_tolerance_from_cost;
146        opts.use_legacy_fixed_tolerances = cached.use_legacy_fixed_tolerances;
147        opts.experiment_explicit_tolerances = cached.experiment_explicit_tolerances;
148        opts.booking_method = cached.booking_method;
149        opts.render_commas = cached.render_commas;
150        opts.allow_pipe_separator = cached.allow_pipe_separator;
151        opts.long_string_maxlines = cached.long_string_maxlines;
152        opts.documents = cached.documents;
153        opts.custom = cached.custom.into_iter().collect();
154        opts
155    }
156}
157
/// Complete cache entry containing all data needed to restore a `LoadResult`.
///
/// This is the value that `save_cache_entry` serializes with rkyv and that
/// `load_cache_entry` deserializes. The `files` list is also what the cache
/// hash is computed from, on both save and load.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct CacheEntry {
    /// All parsed directives.
    pub directives: Vec<Spanned<Directive>>,
    /// Parsed options.
    pub options: CachedOptions,
    /// Plugin declarations.
    pub plugins: Vec<CachedPlugin>,
    /// All files that were loaded (as strings, for serialization).
    pub files: Vec<String>,
}
170
171impl CacheEntry {
172    /// Get files as `PathBuf` references.
173    pub fn file_paths(&self) -> Vec<PathBuf> {
174        self.files.iter().map(PathBuf::from).collect()
175    }
176}
177
/// Magic bytes to identify cache files.
///
/// Stored in the first 8 bytes of the header; `load_cache_entry` ignores any
/// file whose header magic differs from this value.
const CACHE_MAGIC: &[u8; 8] = b"RLEDGER\0";

/// Cache version - increment when format changes.
///
/// `load_cache_entry` ignores files whose header version differs from this
/// value, so bumping it discards all previously written caches.
///
/// - v1: Initial release with string-based Decimal/NaiveDate
/// - v2: Binary Decimal (16 bytes) and `NaiveDate` (i32 days)
/// - v3: Fixed account type defaults in `CachedOptions`
const CACHE_VERSION: u32 = 3;
186
/// Fixed-size header written at the very start of every cache file.
///
/// Byte layout (integers are little-endian):
///
/// | bytes  | field      |
/// |--------|------------|
/// | 0..8   | `magic`    |
/// | 8..12  | `version`  |
/// | 12..44 | `hash`     |
/// | 44..52 | `data_len` |
#[derive(Debug, Clone)]
struct CacheHeader {
    /// Magic bytes for identification.
    magic: [u8; 8],
    /// Cache format version.
    version: u32,
    /// SHA-256 hash of source files.
    hash: [u8; 32],
    /// Length of the serialized data.
    data_len: u64,
}

impl CacheHeader {
    /// Encoded size in bytes: magic + version + hash + data length.
    const SIZE: usize = 8 + 4 + 32 + 8;

    /// Encode the header into its fixed-size on-disk representation.
    fn to_bytes(&self) -> [u8; Self::SIZE] {
        let version = self.version.to_le_bytes();
        let data_len = self.data_len.to_le_bytes();

        // Fields are laid out back to back in declaration order.
        let fields: [&[u8]; 4] = [&self.magic, &version, &self.hash, &data_len];

        let mut encoded = [0u8; Self::SIZE];
        let mut offset = 0;
        for field in fields {
            let end = offset + field.len();
            encoded[offset..end].copy_from_slice(field);
            offset = end;
        }
        encoded
    }

    /// Decode a header from the first `SIZE` bytes of `bytes`.
    ///
    /// Returns `None` when fewer than `SIZE` bytes are supplied; any bytes
    /// past the header are ignored. The magic and version are returned as
    /// read and are not validated here.
    fn from_bytes(bytes: &[u8]) -> Option<Self> {
        let raw = bytes.get(..Self::SIZE)?;

        Some(Self {
            magic: raw[0..8].try_into().ok()?,
            version: u32::from_le_bytes(raw[8..12].try_into().ok()?),
            hash: raw[12..44].try_into().ok()?,
            data_len: u64::from_le_bytes(raw[44..52].try_into().ok()?),
        })
    }
}
235
236/// Compute a hash of the given files and their modification times.
237fn compute_hash(files: &[&Path]) -> [u8; 32] {
238    let mut hasher = Sha256::new();
239
240    for file in files {
241        // Hash the file path
242        hasher.update(file.to_string_lossy().as_bytes());
243
244        // Hash the modification time
245        if let Ok(metadata) = fs::metadata(file) {
246            if let Ok(mtime) = metadata.modified() {
247                if let Ok(duration) = mtime.duration_since(std::time::UNIX_EPOCH) {
248                    hasher.update(duration.as_secs().to_le_bytes());
249                    hasher.update(duration.subsec_nanos().to_le_bytes());
250                }
251            }
252            // Hash the file size
253            hasher.update(metadata.len().to_le_bytes());
254        }
255    }
256
257    hasher.finalize().into()
258}
259
/// Get the cache file path for a given source file.
///
/// The cache lives next to `source` and is named by appending `.cache` to the
/// source's file name (`ledger.beancount` -> `ledger.beancount.cache`). If
/// `source` has no file name component (for example `/` or a path ending in
/// `..`), `ledger.cache` is appended to the path instead.
///
/// The name is built as an `OsString`, so file names that are not valid
/// UTF-8 are preserved byte-for-byte rather than being lossily converted.
fn cache_path(source: &Path) -> std::path::PathBuf {
    let mut name = source.file_name().map_or_else(
        || std::ffi::OsString::from("ledger"),
        std::ffi::OsStr::to_os_string,
    );
    name.push(".cache");
    source.with_file_name(name)
}
270
271/// Try to load a cache entry from disk.
272///
273/// Returns `Some(CacheEntry)` if cache is valid and file hashes match,
274/// `None` if cache is missing, invalid, or outdated.
275pub fn load_cache_entry(main_file: &Path) -> Option<CacheEntry> {
276    let cache_file = cache_path(main_file);
277    let mut file = fs::File::open(&cache_file).ok()?;
278
279    // Read header
280    let mut header_bytes = [0u8; CacheHeader::SIZE];
281    file.read_exact(&mut header_bytes).ok()?;
282    let header = CacheHeader::from_bytes(&header_bytes)?;
283
284    // Validate magic and version
285    if header.magic != *CACHE_MAGIC {
286        return None;
287    }
288    if header.version != CACHE_VERSION {
289        return None;
290    }
291
292    // Read data
293    let mut data = vec![0u8; header.data_len as usize];
294    file.read_exact(&mut data).ok()?;
295
296    // Deserialize
297    let entry: CacheEntry = rkyv::from_bytes::<CacheEntry, rkyv::rancor::Error>(&data).ok()?;
298
299    // Validate hash against the files stored in the cache
300    let file_paths = entry.file_paths();
301    let file_refs: Vec<&Path> = file_paths.iter().map(PathBuf::as_path).collect();
302    let expected_hash = compute_hash(&file_refs);
303    if header.hash != expected_hash {
304        return None;
305    }
306
307    Some(entry)
308}
309
310/// Save a cache entry to disk.
311pub fn save_cache_entry(main_file: &Path, entry: &CacheEntry) -> Result<(), std::io::Error> {
312    let cache_file = cache_path(main_file);
313
314    // Compute hash from the files in the entry
315    let file_paths = entry.file_paths();
316    let file_refs: Vec<&Path> = file_paths.iter().map(PathBuf::as_path).collect();
317    let hash = compute_hash(&file_refs);
318
319    // Serialize
320    let data = rkyv::to_bytes::<rkyv::rancor::Error>(entry)
321        .map(|v| v.to_vec())
322        .map_err(|e| std::io::Error::other(e.to_string()))?;
323
324    // Write header + data
325    let header = CacheHeader {
326        magic: *CACHE_MAGIC,
327        version: CACHE_VERSION,
328        hash,
329        data_len: data.len() as u64,
330    };
331
332    let mut file = fs::File::create(&cache_file)?;
333    file.write_all(&header.to_bytes())?;
334    file.write_all(&data)?;
335
336    Ok(())
337}
338
/// Serialize directives to bytes using rkyv (for benchmarking).
///
/// Test-only helper. A serializer failure is surfaced as an `io::Error`
/// carrying the serializer's message.
#[cfg(test)]
fn serialize_directives(directives: &Vec<Spanned<Directive>>) -> Result<Vec<u8>, std::io::Error> {
    let archived = rkyv::to_bytes::<rkyv::rancor::Error>(directives)
        .map_err(|e| std::io::Error::other(e.to_string()))?;
    Ok(archived.to_vec())
}
346
/// Deserialize directives from bytes using rkyv (for benchmarking).
///
/// Test-only helper. Returns `None` when `data` cannot be deserialized.
#[cfg(test)]
fn deserialize_directives(data: &[u8]) -> Option<Vec<Spanned<Directive>>> {
    let decoded: Result<Vec<Spanned<Directive>>, rkyv::rancor::Error> = rkyv::from_bytes(data);
    decoded.ok()
}
352
353/// Invalidate the cache for a file.
354pub fn invalidate_cache(main_file: &Path) {
355    let cache_file = cache_path(main_file);
356    let _ = fs::remove_file(cache_file);
357}
358
359/// Re-intern all strings in directives to deduplicate memory.
360///
361/// After deserializing from cache, strings are not interned (each is a separate
362/// allocation). This function walks through all directives and re-interns account
363/// names and currencies using a shared `StringInterner`, deduplicating identical
364/// strings to save memory.
365///
366/// Returns the number of strings that were deduplicated (i.e., strings that
367/// were found to already exist in the interner).
368pub fn reintern_directives(directives: &mut [Spanned<Directive>]) -> usize {
369    use rustledger_core::intern::InternedStr;
370    use rustledger_core::{IncompleteAmount, PriceAnnotation};
371
372    // Intern a single string (defined before use to satisfy clippy)
373    fn do_intern(s: &mut InternedStr, interner: &mut StringInterner) -> bool {
374        let already_exists = interner.contains(s.as_str());
375        *s = interner.intern(s.as_str());
376        already_exists
377    }
378
379    let mut interner = StringInterner::with_capacity(1024);
380    let mut dedup_count = 0;
381
382    for spanned in directives.iter_mut() {
383        match &mut spanned.value {
384            Directive::Transaction(txn) => {
385                for posting in &mut txn.postings {
386                    if do_intern(&mut posting.account, &mut interner) {
387                        dedup_count += 1;
388                    }
389                    // Units
390                    if let Some(ref mut units) = posting.units {
391                        match units {
392                            IncompleteAmount::Complete(amt) => {
393                                if do_intern(&mut amt.currency, &mut interner) {
394                                    dedup_count += 1;
395                                }
396                            }
397                            IncompleteAmount::CurrencyOnly(cur) => {
398                                if do_intern(cur, &mut interner) {
399                                    dedup_count += 1;
400                                }
401                            }
402                            IncompleteAmount::NumberOnly(_) => {}
403                        }
404                    }
405                    // Cost spec
406                    if let Some(ref mut cost) = posting.cost {
407                        if let Some(ref mut cur) = cost.currency {
408                            if do_intern(cur, &mut interner) {
409                                dedup_count += 1;
410                            }
411                        }
412                    }
413                    // Price annotation
414                    if let Some(ref mut price) = posting.price {
415                        match price {
416                            PriceAnnotation::Unit(amt) | PriceAnnotation::Total(amt) => {
417                                if do_intern(&mut amt.currency, &mut interner) {
418                                    dedup_count += 1;
419                                }
420                            }
421                            PriceAnnotation::UnitIncomplete(inc)
422                            | PriceAnnotation::TotalIncomplete(inc) => match inc {
423                                IncompleteAmount::Complete(amt) => {
424                                    if do_intern(&mut amt.currency, &mut interner) {
425                                        dedup_count += 1;
426                                    }
427                                }
428                                IncompleteAmount::CurrencyOnly(cur) => {
429                                    if do_intern(cur, &mut interner) {
430                                        dedup_count += 1;
431                                    }
432                                }
433                                IncompleteAmount::NumberOnly(_) => {}
434                            },
435                            PriceAnnotation::UnitEmpty | PriceAnnotation::TotalEmpty => {}
436                        }
437                    }
438                }
439            }
440            Directive::Balance(bal) => {
441                if do_intern(&mut bal.account, &mut interner) {
442                    dedup_count += 1;
443                }
444                if do_intern(&mut bal.amount.currency, &mut interner) {
445                    dedup_count += 1;
446                }
447            }
448            Directive::Open(open) => {
449                if do_intern(&mut open.account, &mut interner) {
450                    dedup_count += 1;
451                }
452                for cur in &mut open.currencies {
453                    if do_intern(cur, &mut interner) {
454                        dedup_count += 1;
455                    }
456                }
457            }
458            Directive::Close(close) => {
459                if do_intern(&mut close.account, &mut interner) {
460                    dedup_count += 1;
461                }
462            }
463            Directive::Commodity(comm) => {
464                if do_intern(&mut comm.currency, &mut interner) {
465                    dedup_count += 1;
466                }
467            }
468            Directive::Pad(pad) => {
469                if do_intern(&mut pad.account, &mut interner) {
470                    dedup_count += 1;
471                }
472                if do_intern(&mut pad.source_account, &mut interner) {
473                    dedup_count += 1;
474                }
475            }
476            Directive::Note(note) => {
477                if do_intern(&mut note.account, &mut interner) {
478                    dedup_count += 1;
479                }
480            }
481            Directive::Document(doc) => {
482                if do_intern(&mut doc.account, &mut interner) {
483                    dedup_count += 1;
484                }
485            }
486            Directive::Price(price) => {
487                if do_intern(&mut price.currency, &mut interner) {
488                    dedup_count += 1;
489                }
490                if do_intern(&mut price.amount.currency, &mut interner) {
491                    dedup_count += 1;
492                }
493            }
494            Directive::Event(_) | Directive::Query(_) | Directive::Custom(_) => {
495                // These don't contain InternedStr fields
496            }
497        }
498    }
499
500    dedup_count
501}
502
/// Unit tests for the cache header, hashing, (de)serialization, on-disk
/// round trips and string re-interning.
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::NaiveDate;
    use rust_decimal_macros::dec;
    use rustledger_core::{Amount, Posting, Transaction};
    use rustledger_parser::Span;

    /// Create (if needed) and return a scratch directory for one test.
    ///
    /// The process id is part of the name so that concurrently running test
    /// processes do not share, or delete, each other's files.
    fn scratch_dir(tag: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!("{tag}_{}", std::process::id()));
        fs::create_dir_all(&dir).expect("failed to create scratch directory");
        dir
    }

    /// A header survives encoding and decoding unchanged.
    #[test]
    fn test_cache_header_roundtrip() {
        let header = CacheHeader {
            magic: *CACHE_MAGIC,
            version: CACHE_VERSION,
            hash: [42u8; 32],
            data_len: 12345,
        };

        let bytes = header.to_bytes();
        let parsed = CacheHeader::from_bytes(&bytes).unwrap();

        assert_eq!(parsed.magic, header.magic);
        assert_eq!(parsed.version, header.version);
        assert_eq!(parsed.hash, header.hash);
        assert_eq!(parsed.data_len, header.data_len);
    }

    /// Hashing the same (empty) file list twice gives the same digest.
    #[test]
    fn test_compute_hash_deterministic() {
        let files: Vec<&Path> = vec![];
        let hash1 = compute_hash(&files);
        let hash2 = compute_hash(&files);
        assert_eq!(hash1, hash2);
    }

    /// A transaction keeps its visible fields through rkyv serialization.
    #[test]
    fn test_serialize_deserialize_roundtrip() {
        let date = NaiveDate::from_ymd_opt(2024, 1, 15).unwrap();

        let txn = Transaction::new(date, "Test transaction")
            .with_payee("Test Payee")
            .with_posting(Posting::new(
                "Expenses:Test",
                Amount::new(dec!(100.00), "USD"),
            ))
            .with_posting(Posting::auto("Assets:Checking"));

        let directives = vec![Spanned::new(Directive::Transaction(txn), Span::new(0, 100))];

        // Serialize
        let serialized = serialize_directives(&directives).expect("serialization failed");

        // Deserialize
        let deserialized = deserialize_directives(&serialized).expect("deserialization failed");

        // Verify roundtrip
        assert_eq!(directives.len(), deserialized.len());
        let orig_txn = directives[0].value.as_transaction().unwrap();
        let deser_txn = deserialized[0].value.as_transaction().unwrap();

        assert_eq!(orig_txn.date, deser_txn.date);
        assert_eq!(orig_txn.payee, deser_txn.payee);
        assert_eq!(orig_txn.narration, deser_txn.narration);
        assert_eq!(orig_txn.postings.len(), deser_txn.postings.len());

        // Check first posting
        assert_eq!(orig_txn.postings[0].account, deser_txn.postings[0].account);
        assert_eq!(orig_txn.postings[0].units, deser_txn.postings[0].units);
    }

    /// Manual timing of serialization and deserialization for 10,000 directives.
    #[test]
    #[ignore = "manual benchmark - run with: cargo test -p rustledger-loader --release -- --ignored --nocapture"]
    fn bench_cache_performance() {
        // Generate test directives
        let date = NaiveDate::from_ymd_opt(2024, 1, 15).unwrap();
        let mut directives = Vec::with_capacity(10000);

        for i in 0..10000 {
            let txn = Transaction::new(date, format!("Transaction {i}"))
                .with_payee("Store")
                .with_posting(Posting::new(
                    "Expenses:Food",
                    Amount::new(dec!(25.00), "USD"),
                ))
                .with_posting(Posting::auto("Assets:Checking"));

            directives.push(Spanned::new(Directive::Transaction(txn), Span::new(0, 100)));
        }

        println!("\n=== Cache Benchmark (10,000 directives) ===");

        // Benchmark serialization
        let start = std::time::Instant::now();
        let serialized = serialize_directives(&directives).unwrap();
        let serialize_time = start.elapsed();
        println!(
            "Serialize: {:?} ({:.2} MB)",
            serialize_time,
            serialized.len() as f64 / 1_000_000.0
        );

        // Benchmark deserialization
        let start = std::time::Instant::now();
        let deserialized = deserialize_directives(&serialized).unwrap();
        let deserialize_time = start.elapsed();
        println!("Deserialize: {deserialize_time:?}");

        assert_eq!(directives.len(), deserialized.len());

        // Use fractional milliseconds: whole milliseconds round sub-millisecond
        // runs down to zero, which would turn the ratio into a division by zero.
        let deserialize_ms = deserialize_time.as_secs_f64() * 1000.0;
        println!(
            "\nSpeedup potential: If parsing takes 100ms, cache load would be {:.1}x faster",
            100.0 / deserialize_ms
        );
    }

    /// The cache path is the source path with `.cache` appended.
    #[test]
    fn test_cache_path() {
        let source = Path::new("/tmp/ledger.beancount");
        let cache = cache_path(source);
        assert_eq!(cache, Path::new("/tmp/ledger.beancount.cache"));

        let source2 = Path::new("relative/path/my.beancount");
        let cache2 = cache_path(source2);
        assert_eq!(cache2, Path::new("relative/path/my.beancount.cache"));
    }

    /// An entry saved to disk can be loaded back while its source file is unchanged.
    #[test]
    fn test_save_load_cache_entry_roundtrip() {
        use std::io::Write;

        // Create a temp directory
        let temp_dir = scratch_dir("rustledger_cache_test");

        // Create a temp beancount file
        let beancount_file = temp_dir.join("test.beancount");
        let mut f = fs::File::create(&beancount_file).unwrap();
        writeln!(f, "2024-01-01 open Assets:Test").unwrap();
        drop(f);

        // Create a cache entry
        let date = NaiveDate::from_ymd_opt(2024, 1, 15).unwrap();
        let txn = Transaction::new(date, "Test").with_posting(Posting::auto("Assets:Test"));
        let directives = vec![Spanned::new(Directive::Transaction(txn), Span::new(0, 50))];

        let entry = CacheEntry {
            directives,
            options: CachedOptions::from(&Options::new()),
            plugins: vec![CachedPlugin {
                name: "test_plugin".to_string(),
                config: Some("config".to_string()),
            }],
            files: vec![beancount_file.to_string_lossy().to_string()],
        };

        // Save cache
        save_cache_entry(&beancount_file, &entry).expect("save failed");

        // Load cache
        let loaded = load_cache_entry(&beancount_file).expect("load failed");

        // Verify
        assert_eq!(loaded.directives.len(), entry.directives.len());
        assert_eq!(loaded.plugins.len(), 1);
        assert_eq!(loaded.plugins[0].name, "test_plugin");
        assert_eq!(loaded.plugins[0].config, Some("config".to_string()));
        assert_eq!(loaded.files.len(), 1);

        // Cleanup
        let _ = fs::remove_file(&beancount_file);
        let _ = fs::remove_file(cache_path(&beancount_file));
        let _ = fs::remove_dir(&temp_dir);
    }

    /// `invalidate_cache` removes a previously written cache file.
    #[test]
    fn test_invalidate_cache() {
        use std::io::Write;

        let temp_dir = scratch_dir("rustledger_invalidate_test");

        let beancount_file = temp_dir.join("test.beancount");
        let mut f = fs::File::create(&beancount_file).unwrap();
        writeln!(f, "2024-01-01 open Assets:Test").unwrap();
        drop(f);

        // Create and save a cache
        let entry = CacheEntry {
            directives: vec![],
            options: CachedOptions::from(&Options::new()),
            plugins: vec![],
            files: vec![beancount_file.to_string_lossy().to_string()],
        };
        save_cache_entry(&beancount_file, &entry).unwrap();

        // Verify cache exists
        assert!(cache_path(&beancount_file).exists());

        // Invalidate
        invalidate_cache(&beancount_file);

        // Verify cache is gone
        assert!(!cache_path(&beancount_file).exists());

        // Cleanup
        let _ = fs::remove_file(&beancount_file);
        let _ = fs::remove_dir(&temp_dir);
    }

    /// Loading the cache for a path that does not exist yields `None`.
    #[test]
    fn test_load_cache_missing_file() {
        let missing = Path::new("/nonexistent/path/to/file.beancount");
        assert!(load_cache_entry(missing).is_none());
    }

    /// A cache file with the wrong magic bytes is rejected.
    #[test]
    fn test_load_cache_invalid_magic() {
        use std::io::Write;

        let temp_dir = scratch_dir("rustledger_magic_test");

        let cache_file = temp_dir.join("test.beancount.cache");
        let mut f = fs::File::create(&cache_file).unwrap();
        // Write invalid magic
        f.write_all(b"INVALID\0").unwrap();
        f.write_all(&[0u8; CacheHeader::SIZE - 8]).unwrap();
        drop(f);

        let beancount_file = temp_dir.join("test.beancount");
        assert!(load_cache_entry(&beancount_file).is_none());

        // Cleanup
        let _ = fs::remove_file(&cache_file);
        let _ = fs::remove_dir(&temp_dir);
    }

    /// Repeated accounts and currencies across transactions are counted as deduplicated.
    #[test]
    fn test_reintern_directives_deduplication() {
        let date = NaiveDate::from_ymd_opt(2024, 1, 15).unwrap();

        // Create multiple transactions with the same account
        let mut directives = vec![];
        for i in 0..5 {
            let txn = Transaction::new(date, format!("Txn {i}"))
                .with_posting(Posting::new(
                    "Expenses:Food",
                    Amount::new(dec!(10.00), "USD"),
                ))
                .with_posting(Posting::auto("Assets:Checking"));
            directives.push(Spanned::new(Directive::Transaction(txn), Span::new(0, 50)));
        }

        // Re-intern should deduplicate the repeated account names and currencies
        let dedup_count = reintern_directives(&mut directives);

        // We should have deduplicated:
        // - "Expenses:Food" appears 5 times but only first is new (4 dedup)
        // - "USD" appears 5 times but only first is new (4 dedup)
        // - "Assets:Checking" appears 5 times but only first is new (4 dedup)
        // Total: 12 deduplications
        assert_eq!(dedup_count, 12);
    }

    /// Options converted to their cached form and back keep the fields that were set.
    #[test]
    fn test_cached_options_roundtrip() {
        let mut opts = Options::new();
        opts.title = Some("Test Ledger".to_string());
        opts.operating_currency = vec!["USD".to_string(), "EUR".to_string()];
        opts.render_commas = true;

        let cached = CachedOptions::from(&opts);
        let restored: Options = cached.into();

        assert_eq!(restored.title, Some("Test Ledger".to_string()));
        assert_eq!(restored.operating_currency, vec!["USD", "EUR"]);
        assert!(restored.render_commas);
    }

    /// `file_paths` returns one `PathBuf` per stored file, in order.
    #[test]
    fn test_cache_entry_file_paths() {
        let entry = CacheEntry {
            directives: vec![],
            options: CachedOptions::from(&Options::new()),
            plugins: vec![],
            files: vec![
                "/path/to/ledger.beancount".to_string(),
                "/path/to/include.beancount".to_string(),
            ],
        };

        let paths = entry.file_paths();
        assert_eq!(paths.len(), 2);
        assert_eq!(paths[0], PathBuf::from("/path/to/ledger.beancount"));
        assert_eq!(paths[1], PathBuf::from("/path/to/include.beancount"));
    }

    /// Balance directives have both account and currency re-interned.
    #[test]
    fn test_reintern_balance_directive() {
        use rustledger_core::Balance;

        let date = NaiveDate::from_ymd_opt(2024, 1, 15).unwrap();
        let balance = Balance::new(date, "Assets:Checking", Amount::new(dec!(1000.00), "USD"));

        let mut directives = vec![
            Spanned::new(Directive::Balance(balance.clone()), Span::new(0, 50)),
            Spanned::new(Directive::Balance(balance), Span::new(51, 100)),
        ];

        let dedup_count = reintern_directives(&mut directives);
        // Second occurrence of "Assets:Checking" and "USD" should be deduplicated
        assert_eq!(dedup_count, 2);
    }

    /// Open and close directives share one interned account name.
    #[test]
    fn test_reintern_open_close_directives() {
        use rustledger_core::{Close, Open};

        let date = NaiveDate::from_ymd_opt(2024, 1, 15).unwrap();
        let open = Open::new(date, "Assets:Checking");
        let close = Close::new(date, "Assets:Checking");

        let mut directives = vec![
            Spanned::new(Directive::Open(open), Span::new(0, 50)),
            Spanned::new(Directive::Close(close), Span::new(51, 100)),
        ];

        let dedup_count = reintern_directives(&mut directives);
        // Second "Assets:Checking" should be deduplicated
        assert_eq!(dedup_count, 1);
    }
}
831        assert_eq!(dedup_count, 1);
832    }
833}