Skip to main content

rustledger_loader/
lib.rs

//! Beancount file loader with include resolution.
//!
//! This crate handles loading beancount files, resolving includes,
//! and collecting options. It builds on the parser to provide a
//! complete loading pipeline.
//!
//! # Features
//!
//! - Recursive include resolution with cycle detection
//! - Options collection and parsing
//! - Plugin directive collection
//! - Source map for error reporting
//! - Push/pop tag and metadata handling
//! - Automatic GPG decryption for encrypted files (`.gpg`, `.asc`)
//!
//! # Example
//!
//! ```ignore
//! use rustledger_loader::Loader;
//! use std::path::Path;
//!
//! let result = Loader::new().load(Path::new("ledger.beancount"))?;
//! for directive in result.directives {
//!     println!("{:?}", directive);
//! }
//! ```
27
28#![forbid(unsafe_code)]
29#![warn(missing_docs)]
30
31#[cfg(feature = "cache")]
32pub mod cache;
33mod options;
34mod source_map;
35
36#[cfg(feature = "cache")]
37pub use cache::{
38    CacheEntry, CachedOptions, CachedPlugin, invalidate_cache, load_cache_entry,
39    reintern_directives, save_cache_entry,
40};
41pub use options::Options;
42pub use source_map::{SourceFile, SourceMap};
43
44use rustledger_core::{Directive, DisplayContext};
45use rustledger_parser::{ParseError, Span, Spanned};
46use std::collections::HashSet;
47use std::fs;
48use std::path::{Path, PathBuf};
49use std::process::Command;
50use thiserror::Error;
51
/// Normalize a path to an absolute, lexically clean form.
///
/// Resolution strategy:
/// 1. Try `Path::canonicalize`, which resolves symlinks and yields an absolute path.
/// 2. If that fails (the path does not exist, or the platform — e.g. WASI — does not
///    support it), build the result lexically: relative paths are anchored at the
///    current directory, and `.` / `..` components are folded away without touching
///    the filesystem.
/// 3. If a relative path cannot be anchored because the current directory is
///    unavailable, return the path unchanged.
///
/// Absolute paths go through the same lexical cleanup as relative ones. This matters
/// for include path-security checks: `Path::starts_with` compares components
/// literally, so an uncleaned `/root/../secret` would still "start with" `/root`.
fn normalize_path(path: &Path) -> PathBuf {
    // Preferred route: works on most platforms and resolves symlinks.
    if let Ok(canonical) = path.canonicalize() {
        return canonical;
    }

    // Fallback: pick the anchor the components are applied to.
    let mut result = if path.is_absolute() {
        // The path's own prefix/root components rebuild the anchor below.
        PathBuf::new()
    } else if let Ok(cwd) = std::env::current_dir() {
        cwd
    } else {
        // Last resort: nothing to anchor a relative path to.
        return path.to_path_buf();
    };

    for component in path.components() {
        match component {
            std::path::Component::Prefix(prefix) => {
                result = PathBuf::from(prefix.as_os_str());
            }
            std::path::Component::RootDir => {
                // `push` with a root replaces everything except an existing prefix,
                // so `C:` + `\` becomes `C:\` and on Unix the result becomes `/`.
                result.push(component.as_os_str());
            }
            std::path::Component::CurDir => {}
            std::path::Component::ParentDir => {
                // `pop` is a no-op at the root, mirroring how `/..` resolves to `/`.
                result.pop();
            }
            std::path::Component::Normal(name) => {
                result.push(name);
            }
        }
    }
    result
}
94
95/// Errors that can occur during loading.
96#[derive(Debug, Error)]
97pub enum LoadError {
98    /// IO error reading a file.
99    #[error("failed to read file {path}: {source}")]
100    Io {
101        /// The path that failed to read.
102        path: PathBuf,
103        /// The underlying IO error.
104        #[source]
105        source: std::io::Error,
106    },
107
108    /// Include cycle detected.
109    #[error("include cycle detected: {}", .cycle.join(" -> "))]
110    IncludeCycle {
111        /// The cycle of file paths.
112        cycle: Vec<String>,
113    },
114
115    /// Parse errors occurred.
116    #[error("parse errors in {path}")]
117    ParseErrors {
118        /// The file with parse errors.
119        path: PathBuf,
120        /// The parse errors.
121        errors: Vec<ParseError>,
122    },
123
124    /// Path traversal attempt detected.
125    #[error("path traversal not allowed: {include_path} escapes base directory {base_dir}")]
126    PathTraversal {
127        /// The include path that attempted traversal.
128        include_path: String,
129        /// The base directory.
130        base_dir: PathBuf,
131    },
132
133    /// GPG decryption failed.
134    #[error("failed to decrypt {path}: {message}")]
135    Decryption {
136        /// The encrypted file path.
137        path: PathBuf,
138        /// Error message from GPG.
139        message: String,
140    },
141}
142
143/// Result of loading a beancount file.
144#[derive(Debug)]
145pub struct LoadResult {
146    /// All directives from all files, in order.
147    pub directives: Vec<Spanned<Directive>>,
148    /// Parsed options.
149    pub options: Options,
150    /// Plugins to load.
151    pub plugins: Vec<Plugin>,
152    /// Source map for error reporting.
153    pub source_map: SourceMap,
154    /// All errors encountered during loading.
155    pub errors: Vec<LoadError>,
156    /// Display context for formatting numbers (tracks precision per currency).
157    pub display_context: DisplayContext,
158}
159
160/// A plugin directive.
161#[derive(Debug, Clone)]
162pub struct Plugin {
163    /// Plugin module name.
164    pub name: String,
165    /// Optional configuration string.
166    pub config: Option<String>,
167    /// Source location.
168    pub span: Span,
169    /// File this plugin was declared in.
170    pub file_id: usize,
171}
172
/// Check if a file is GPG-encrypted based on extension or content.
///
/// Returns `true` for:
/// - Files with a `.gpg` extension (the file is not opened)
/// - Files with an `.asc` extension whose first 1024 bytes contain a PGP
///   message header
///
/// Only the leading 1024 bytes of an `.asc` file are read, and they are searched
/// as raw bytes. Slicing the decoded text at a fixed byte offset would panic when
/// that offset falls inside a multi-byte UTF-8 character. An `.asc` file that
/// cannot be opened or read is reported as not encrypted.
fn is_encrypted_file(path: &Path) -> bool {
    use std::io::Read;

    /// ASCII-armor marker that opens an encrypted PGP message.
    const PGP_HEADER: &[u8] = b"-----BEGIN PGP MESSAGE-----";
    /// Number of leading bytes inspected for the marker.
    const SNIFF_LEN: u64 = 1024;

    match path.extension().and_then(|ext| ext.to_str()) {
        Some("gpg") => true,
        Some("asc") => {
            let file = match fs::File::open(path) {
                Ok(file) => file,
                Err(_) => return false,
            };

            // `take` caps the read; `read_to_end` copes with short reads.
            let mut head = Vec::new();
            if file.take(SNIFF_LEN).read_to_end(&mut head).is_err() {
                return false;
            }

            // `windows` yields nothing when `head` is shorter than the marker.
            head.windows(PGP_HEADER.len())
                .any(|window| window == PGP_HEADER)
        }
        _ => false,
    }
}
193
194/// Decrypt a GPG-encrypted file using the system `gpg` command.
195///
196/// This uses `gpg --batch --decrypt` which will use the user's
197/// GPG keyring and gpg-agent for passphrase handling.
198fn decrypt_gpg_file(path: &Path) -> Result<String, LoadError> {
199    let output = Command::new("gpg")
200        .args(["--batch", "--decrypt"])
201        .arg(path)
202        .output()
203        .map_err(|e| LoadError::Decryption {
204            path: path.to_path_buf(),
205            message: format!("failed to run gpg: {e}"),
206        })?;
207
208    if !output.status.success() {
209        return Err(LoadError::Decryption {
210            path: path.to_path_buf(),
211            message: String::from_utf8_lossy(&output.stderr).trim().to_string(),
212        });
213    }
214
215    String::from_utf8(output.stdout).map_err(|e| LoadError::Decryption {
216        path: path.to_path_buf(),
217        message: format!("decrypted content is not valid UTF-8: {e}"),
218    })
219}
220
/// Loader for beancount files and the files they include.
#[derive(Debug, Default)]
pub struct Loader {
    /// Every file loaded so far; a file already in this set is not loaded again.
    loaded_files: HashSet<PathBuf>,
    /// Files currently being loaded, outermost first; used to detect include cycles.
    include_stack: Vec<PathBuf>,
    /// Directory that includes must stay within when path security is enforced.
    /// When `None`, no containment check is applied.
    root_dir: Option<PathBuf>,
    /// Whether include paths are checked against `root_dir`.
    enforce_path_security: bool,
}
234
235impl Loader {
236    /// Create a new loader.
237    #[must_use]
238    pub fn new() -> Self {
239        Self::default()
240    }
241
242    /// Enable path traversal protection.
243    ///
244    /// When enabled, include directives cannot escape the root directory
245    /// of the main beancount file. This prevents malicious ledger files
246    /// from accessing sensitive files outside the ledger directory.
247    ///
248    /// # Example
249    ///
250    /// ```ignore
251    /// let result = Loader::new()
252    ///     .with_path_security(true)
253    ///     .load(Path::new("ledger.beancount"))?;
254    /// ```
255    #[must_use]
256    pub const fn with_path_security(mut self, enabled: bool) -> Self {
257        self.enforce_path_security = enabled;
258        self
259    }
260
261    /// Set a custom root directory for path security.
262    ///
263    /// By default, the root directory is the parent directory of the main file.
264    /// This method allows overriding that to a custom directory.
265    #[must_use]
266    pub fn with_root_dir(mut self, root: PathBuf) -> Self {
267        self.root_dir = Some(root);
268        self.enforce_path_security = true;
269        self
270    }
271
272    /// Load a beancount file and all its includes.
273    ///
274    /// Parses the file, processes options and plugin directives, and recursively
275    /// loads any included files.
276    ///
277    /// # Errors
278    ///
279    /// Returns [`LoadError`] in the following cases:
280    ///
281    /// - [`LoadError::Io`] - Failed to read the file or an included file
282    /// - [`LoadError::IncludeCycle`] - Circular include detected
283    ///
284    /// Note: Parse errors and path traversal errors are collected in
285    /// [`LoadResult::errors`] rather than returned directly, allowing
286    /// partial results to be returned.
287    pub fn load(&mut self, path: &Path) -> Result<LoadResult, LoadError> {
288        let mut directives = Vec::new();
289        let mut options = Options::default();
290        let mut plugins = Vec::new();
291        let mut source_map = SourceMap::new();
292        let mut errors = Vec::new();
293
294        // Get normalized absolute path (WASI-compatible, doesn't require canonicalize)
295        let canonical = normalize_path(path);
296
297        // Set root directory for path security if enabled but not explicitly set
298        if self.enforce_path_security && self.root_dir.is_none() {
299            self.root_dir = canonical.parent().map(Path::to_path_buf);
300        }
301
302        self.load_recursive(
303            &canonical,
304            &mut directives,
305            &mut options,
306            &mut plugins,
307            &mut source_map,
308            &mut errors,
309        )?;
310
311        // Build display context from directives and options
312        let display_context = build_display_context(&directives, &options);
313
314        Ok(LoadResult {
315            directives,
316            options,
317            plugins,
318            source_map,
319            errors,
320            display_context,
321        })
322    }
323
324    fn load_recursive(
325        &mut self,
326        path: &Path,
327        directives: &mut Vec<Spanned<Directive>>,
328        options: &mut Options,
329        plugins: &mut Vec<Plugin>,
330        source_map: &mut SourceMap,
331        errors: &mut Vec<LoadError>,
332    ) -> Result<(), LoadError> {
333        // Check for cycles
334        let path_buf = path.to_path_buf();
335        if self.include_stack.contains(&path_buf) {
336            let mut cycle: Vec<String> = self
337                .include_stack
338                .iter()
339                .map(|p| p.display().to_string())
340                .collect();
341            cycle.push(path.display().to_string());
342            return Err(LoadError::IncludeCycle { cycle });
343        }
344
345        // Check if already loaded
346        if self.loaded_files.contains(path) {
347            return Ok(());
348        }
349
350        // Read file (decrypting if necessary)
351        // Use lossy UTF-8 decoding to handle non-UTF-8 files gracefully (like Python beancount)
352        let source: std::sync::Arc<str> = if is_encrypted_file(path) {
353            decrypt_gpg_file(path)?.into()
354        } else {
355            let bytes = fs::read(path).map_err(|e| LoadError::Io {
356                path: path.to_path_buf(),
357                source: e,
358            })?;
359            String::from_utf8_lossy(&bytes).into_owned().into()
360        };
361
362        // Add to source map (Arc::clone is cheap - just increments refcount)
363        let file_id = source_map.add_file(path.to_path_buf(), std::sync::Arc::clone(&source));
364
365        // Mark as loading
366        self.include_stack.push(path.to_path_buf());
367        self.loaded_files.insert(path.to_path_buf());
368
369        // Parse (borrows from Arc, no allocation)
370        let result = rustledger_parser::parse(&source);
371
372        // Collect parse errors
373        if !result.errors.is_empty() {
374            errors.push(LoadError::ParseErrors {
375                path: path.to_path_buf(),
376                errors: result.errors,
377            });
378        }
379
380        // Process options
381        for (key, value, _span) in result.options {
382            options.set(&key, &value);
383        }
384
385        // Process plugins
386        for (name, config, span) in result.plugins {
387            plugins.push(Plugin {
388                name,
389                config,
390                span,
391                file_id,
392            });
393        }
394
395        // Process includes
396        let base_dir = path.parent().unwrap_or(Path::new("."));
397        for (include_path, _span) in &result.includes {
398            let full_path = base_dir.join(include_path);
399            // Use normalize_path for WASI compatibility (canonicalize not supported)
400            let canonical = normalize_path(&full_path);
401
402            // Path traversal protection: ensure include stays within root directory
403            if self.enforce_path_security {
404                if let Some(ref root) = self.root_dir {
405                    if !canonical.starts_with(root) {
406                        errors.push(LoadError::PathTraversal {
407                            include_path: include_path.clone(),
408                            base_dir: root.clone(),
409                        });
410                        continue;
411                    }
412                }
413            }
414
415            if let Err(e) =
416                self.load_recursive(&canonical, directives, options, plugins, source_map, errors)
417            {
418                errors.push(e);
419            }
420        }
421
422        // Add directives from this file, setting the file_id
423        directives.extend(
424            result
425                .directives
426                .into_iter()
427                .map(|d| d.with_file_id(file_id)),
428        );
429
430        // Pop from stack
431        self.include_stack.pop();
432
433        Ok(())
434    }
435}
436
437/// Build a display context from loaded directives and options.
438///
439/// This scans all directives for amounts and tracks the maximum precision seen
440/// for each currency. Fixed precisions from `option "display_precision"` override
441/// the inferred values.
442fn build_display_context(directives: &[Spanned<Directive>], options: &Options) -> DisplayContext {
443    let mut ctx = DisplayContext::new();
444
445    // Set render_commas from options
446    ctx.set_render_commas(options.render_commas);
447
448    // Scan directives for amounts to infer precision
449    for spanned in directives {
450        match &spanned.value {
451            Directive::Transaction(txn) => {
452                for posting in &txn.postings {
453                    // Units (IncompleteAmount)
454                    if let Some(ref units) = posting.units {
455                        if let (Some(number), Some(currency)) = (units.number(), units.currency()) {
456                            ctx.update(number, currency);
457                        }
458                    }
459                    // Cost (CostSpec)
460                    if let Some(ref cost) = posting.cost {
461                        if let (Some(number), Some(currency)) =
462                            (cost.number_per.or(cost.number_total), &cost.currency)
463                        {
464                            ctx.update(number, currency.as_str());
465                        }
466                    }
467                    // Price (PriceAnnotation)
468                    if let Some(ref price) = posting.price {
469                        if let Some(amount) = price.amount() {
470                            ctx.update(amount.number, amount.currency.as_str());
471                        }
472                    }
473                }
474            }
475            Directive::Balance(bal) => {
476                ctx.update(bal.amount.number, bal.amount.currency.as_str());
477                if let Some(tol) = bal.tolerance {
478                    ctx.update(tol, bal.amount.currency.as_str());
479                }
480            }
481            Directive::Price(price) => {
482                ctx.update(price.amount.number, price.amount.currency.as_str());
483            }
484            Directive::Pad(_)
485            | Directive::Open(_)
486            | Directive::Close(_)
487            | Directive::Commodity(_)
488            | Directive::Event(_)
489            | Directive::Query(_)
490            | Directive::Note(_)
491            | Directive::Document(_)
492            | Directive::Custom(_) => {}
493        }
494    }
495
496    // Apply fixed precisions from options (these override inferred values)
497    for (currency, precision) in &options.display_precision {
498        ctx.set_fixed_precision(currency, *precision);
499    }
500
501    ctx
502}
503
504/// Load a beancount file.
505///
506/// This is a convenience function that creates a loader and loads a single file.
507pub fn load(path: &Path) -> Result<LoadResult, LoadError> {
508    Loader::new().load(path)
509}
510
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Create a temporary file with the given suffix holding `lines`, one per line.
    fn temp_file_with(suffix: &str, lines: &[&str]) -> NamedTempFile {
        let mut file = NamedTempFile::with_suffix(suffix).unwrap();
        for line in lines {
            writeln!(file, "{line}").unwrap();
        }
        file.flush().unwrap();
        file
    }

    #[test]
    fn test_is_encrypted_file_gpg_extension() {
        assert!(is_encrypted_file(Path::new("test.beancount.gpg")));
    }

    #[test]
    fn test_is_encrypted_file_plain_beancount() {
        assert!(!is_encrypted_file(Path::new("test.beancount")));
    }

    #[test]
    fn test_is_encrypted_file_asc_with_pgp_header() {
        let file = temp_file_with(
            ".asc",
            &[
                "-----BEGIN PGP MESSAGE-----",
                "some encrypted content",
                "-----END PGP MESSAGE-----",
            ],
        );

        assert!(is_encrypted_file(file.path()));
    }

    #[test]
    fn test_is_encrypted_file_asc_without_pgp_header() {
        let file = temp_file_with(
            ".asc",
            &[
                "This is just a plain text file",
                "with .asc extension but no PGP content",
            ],
        );

        assert!(!is_encrypted_file(file.path()));
    }

    #[test]
    fn test_decrypt_gpg_file_missing_gpg() {
        // A fake .gpg file: decryption fails whether gpg is missing, the
        // content is rejected, or no matching key exists.
        let file = temp_file_with(".gpg", &["fake encrypted content"]);

        let result = decrypt_gpg_file(file.path());
        assert!(result.is_err());

        match result {
            Err(LoadError::Decryption { path, message }) => {
                assert_eq!(path, file.path().to_path_buf());
                assert!(!message.is_empty());
            }
            _ => panic!("Expected Decryption error"),
        }
    }
}