Skip to main content

rustledger_loader/
lib.rs

1//! Beancount file loader with include resolution.
2//!
3//! This crate handles loading beancount files, resolving includes,
4//! and collecting options. It builds on the parser to provide a
5//! complete loading pipeline.
6//!
7//! # Features
8//!
9//! - Recursive include resolution with cycle detection
10//! - Options collection and parsing
11//! - Plugin directive collection
12//! - Source map for error reporting
13//! - Push/pop tag and metadata handling
14//! - Automatic GPG decryption for encrypted files (`.gpg`, `.asc`)
15//!
16//! # Example
17//!
18//! ```ignore
19//! use rustledger_loader::Loader;
20//! use std::path::Path;
21//!
22//! let result = Loader::new().load(Path::new("ledger.beancount"))?;
23//! for directive in result.directives {
24//!     println!("{:?}", directive);
25//! }
26//! ```
27
28#![forbid(unsafe_code)]
29#![warn(missing_docs)]
30
31#[cfg(feature = "cache")]
32pub mod cache;
33mod options;
34mod source_map;
35
36#[cfg(feature = "cache")]
37pub use cache::{
38    CacheEntry, CachedOptions, CachedPlugin, invalidate_cache, load_cache_entry,
39    reintern_directives, save_cache_entry,
40};
41pub use options::Options;
42pub use source_map::{SourceFile, SourceMap};
43
44use rustledger_core::{Directive, DisplayContext};
45use rustledger_parser::{ParseError, Span, Spanned};
46use std::collections::HashSet;
47use std::fs;
48use std::path::{Path, PathBuf};
49use std::process::Command;
50use thiserror::Error;
51
/// Resolve `path` to a normalized absolute path without requiring
/// `canonicalize` support (e.g., on WASI).
///
/// Resolution strategy:
/// 1. Try [`Path::canonicalize`], which resolves symlinks and returns an
///    absolute path.
/// 2. If that fails (unsupported platform, file does not exist, ...), normalize
///    the path lexically: relative paths are anchored at the current working
///    directory and `.` / `..` components are folded away without touching the
///    file system.
/// 3. If the path is relative and the current directory is unavailable, return
///    the path unchanged.
///
/// The lexical fallback is applied to absolute paths as well. Callers compare
/// the result against a root directory with [`Path::starts_with`], which is a
/// purely component-wise check, so leftover `..` components would let
/// `/root/../elsewhere` pass as being "inside" `/root`.
///
/// Note that lexical folding does not resolve symlinks, so `dir/link/..` is
/// treated as `dir` even if `link` points somewhere else.
fn normalize_path(path: &Path) -> PathBuf {
    // Preferred: let the OS resolve symlinks and produce an absolute path.
    if let Ok(canonical) = path.canonicalize() {
        return canonical;
    }

    // Fallback: pick the anchor the components are applied to.
    let mut result = if path.is_absolute() {
        // Prefix/root components of `path` itself rebuild the anchor below.
        PathBuf::new()
    } else if let Ok(cwd) = std::env::current_dir() {
        cwd
    } else {
        // Last resort: nothing to anchor a relative path to.
        return path.to_path_buf();
    };

    for component in path.components() {
        match component {
            std::path::Component::Prefix(prefix) => {
                // Windows drive/UNC prefix: restart from it.
                result = PathBuf::from(prefix.as_os_str());
            }
            std::path::Component::RootDir => {
                // Pushing a root replaces everything except an existing prefix.
                result.push(component.as_os_str());
            }
            std::path::Component::CurDir => {}
            std::path::Component::ParentDir => {
                // `pop` is a no-op at the root, so `/..` stays `/`.
                result.pop();
            }
            std::path::Component::Normal(segment) => {
                result.push(segment);
            }
        }
    }
    result
}
94
/// Errors that can occur during loading.
///
/// Errors for the main file's read/decrypt step are returned from
/// [`Loader::load`]; errors raised while processing includes, parse errors and
/// path traversal violations are collected in [`LoadResult::errors`].
#[derive(Debug, Error)]
pub enum LoadError {
    /// IO error reading a file.
    ///
    /// Produced when reading the raw bytes of a (non-encrypted) file fails.
    #[error("failed to read file {path}: {source}")]
    Io {
        /// The path that failed to read.
        path: PathBuf,
        /// The underlying IO error.
        #[source]
        source: std::io::Error,
    },

    /// Include cycle detected.
    ///
    /// Produced when a file is included while it is still being loaded
    /// (i.e. it is already on the include stack).
    #[error("include cycle detected: {}", .cycle.join(" -> "))]
    IncludeCycle {
        /// The chain of file paths: the current include stack from the
        /// outermost file inward, followed by the path that repeated.
        cycle: Vec<String>,
    },

    /// Parse errors occurred.
    ///
    /// Loading continues after parse errors; whatever the parser did produce
    /// for the file is still processed.
    #[error("parse errors in {path}")]
    ParseErrors {
        /// The file with parse errors.
        path: PathBuf,
        /// The parse errors.
        errors: Vec<ParseError>,
    },

    /// Path traversal attempt detected.
    ///
    /// Produced when path security is enabled and an include resolves to a
    /// location outside the root directory; the offending include is skipped.
    #[error("path traversal not allowed: {include_path} escapes base directory {base_dir}")]
    PathTraversal {
        /// The include path that attempted traversal, as written in the file.
        include_path: String,
        /// The root directory the include was required to stay within.
        base_dir: PathBuf,
    },

    /// GPG decryption failed.
    ///
    /// Covers failing to launch `gpg`, a non-zero `gpg` exit status, and
    /// decrypted output that is not valid UTF-8.
    #[error("failed to decrypt {path}: {message}")]
    Decryption {
        /// The encrypted file path.
        path: PathBuf,
        /// Description of the failure (for a non-zero exit this is gpg's
        /// trimmed stderr output).
        message: String,
    },
}
142
/// Result of loading a beancount file.
///
/// A `LoadResult` can be partial: non-fatal problems are reported through
/// [`LoadResult::errors`] alongside whatever was loaded successfully.
#[derive(Debug)]
pub struct LoadResult {
    /// All directives from all files, in order.
    ///
    /// Directives of included files are appended before the directives of the
    /// file that includes them. Each directive is tagged with the id of the
    /// file it came from.
    pub directives: Vec<Spanned<Directive>>,
    /// Parsed options, accumulated across all loaded files.
    pub options: Options,
    /// Plugins to load, collected from `plugin` directives in all loaded files.
    pub plugins: Vec<Plugin>,
    /// Source map for error reporting (one entry per loaded file).
    pub source_map: SourceMap,
    /// All errors encountered during loading.
    pub errors: Vec<LoadError>,
    /// Display context for formatting numbers (tracks precision per currency).
    pub display_context: DisplayContext,
}
159
/// A plugin directive.
///
/// Collected by the loader from each parsed file; the loader itself only
/// records the declaration.
#[derive(Debug, Clone)]
pub struct Plugin {
    /// Plugin module name.
    pub name: String,
    /// Optional configuration string.
    pub config: Option<String>,
    /// Source location.
    pub span: Span,
    /// File this plugin was declared in (the id assigned by the source map
    /// when the file was added).
    pub file_id: usize,
}
172
/// Check if a file is GPG-encrypted based on extension or content.
///
/// Returns `true` for:
/// - Files with `.gpg` extension (the file is not opened)
/// - Files with `.asc` extension whose first 1024 bytes contain a PGP
///   message header
///
/// Returns `false` for every other extension, and for `.asc` files that
/// cannot be opened or read.
///
/// The header check works on raw bytes and reads at most 1024 bytes, so it
/// neither loads large files into memory nor depends on where UTF-8 character
/// boundaries fall inside the file.
fn is_encrypted_file(path: &Path) -> bool {
    // ASCII-armor line that opens an armored PGP message.
    const PGP_MESSAGE_HEADER: &[u8] = b"-----BEGIN PGP MESSAGE-----";
    // Number of leading bytes inspected when sniffing `.asc` files.
    const SNIFF_LEN: u64 = 1024;

    match path.extension().and_then(|e| e.to_str()) {
        Some("gpg") => true,
        Some("asc") => {
            let Ok(file) = fs::File::open(path) else {
                return false;
            };

            // Read only the sniffing window, never the whole file.
            let mut head = Vec::new();
            let mut limited = std::io::Read::take(file, SNIFF_LEN);
            if std::io::Read::read_to_end(&mut limited, &mut head).is_err() {
                return false;
            }

            // Byte-wise search: slicing a `str` at a fixed offset would panic
            // if the offset landed inside a multi-byte character.
            head.windows(PGP_MESSAGE_HEADER.len())
                .any(|window| window == PGP_MESSAGE_HEADER)
        }
        _ => false,
    }
}
193
194/// Decrypt a GPG-encrypted file using the system `gpg` command.
195///
196/// This uses `gpg --batch --decrypt` which will use the user's
197/// GPG keyring and gpg-agent for passphrase handling.
198fn decrypt_gpg_file(path: &Path) -> Result<String, LoadError> {
199    let output = Command::new("gpg")
200        .args(["--batch", "--decrypt"])
201        .arg(path)
202        .output()
203        .map_err(|e| LoadError::Decryption {
204            path: path.to_path_buf(),
205            message: format!("failed to run gpg: {e}"),
206        })?;
207
208    if !output.status.success() {
209        return Err(LoadError::Decryption {
210            path: path.to_path_buf(),
211            message: String::from_utf8_lossy(&output.stderr).trim().to_string(),
212        });
213    }
214
215    String::from_utf8(output.stdout).map_err(|e| LoadError::Decryption {
216        path: path.to_path_buf(),
217        message: format!("decrypted content is not valid UTF-8: {e}"),
218    })
219}
220
/// Beancount file loader.
///
/// Holds the bookkeeping needed while following `include` directives:
/// which files were already loaded, which files are currently being loaded,
/// and the optional path-security configuration.
#[derive(Debug, Default)]
pub struct Loader {
    /// Files that have already been loaded; a file included more than once is
    /// only processed the first time.
    loaded_files: HashSet<PathBuf>,
    /// Stack for cycle detection during loading (maintains order for error messages).
    include_stack: Vec<PathBuf>,
    /// Set for O(1) cycle detection (mirrors `include_stack`).
    include_stack_set: HashSet<PathBuf>,
    /// Root directory for path traversal protection.
    /// If set, includes must resolve to paths within this directory.
    root_dir: Option<PathBuf>,
    /// Whether to enforce path traversal protection.
    enforce_path_security: bool,
}
236
237impl Loader {
238    /// Create a new loader.
239    #[must_use]
240    pub fn new() -> Self {
241        Self::default()
242    }
243
244    /// Enable path traversal protection.
245    ///
246    /// When enabled, include directives cannot escape the root directory
247    /// of the main beancount file. This prevents malicious ledger files
248    /// from accessing sensitive files outside the ledger directory.
249    ///
250    /// # Example
251    ///
252    /// ```ignore
253    /// let result = Loader::new()
254    ///     .with_path_security(true)
255    ///     .load(Path::new("ledger.beancount"))?;
256    /// ```
257    #[must_use]
258    pub const fn with_path_security(mut self, enabled: bool) -> Self {
259        self.enforce_path_security = enabled;
260        self
261    }
262
263    /// Set a custom root directory for path security.
264    ///
265    /// By default, the root directory is the parent directory of the main file.
266    /// This method allows overriding that to a custom directory.
267    #[must_use]
268    pub fn with_root_dir(mut self, root: PathBuf) -> Self {
269        self.root_dir = Some(root);
270        self.enforce_path_security = true;
271        self
272    }
273
274    /// Load a beancount file and all its includes.
275    ///
276    /// Parses the file, processes options and plugin directives, and recursively
277    /// loads any included files.
278    ///
279    /// # Errors
280    ///
281    /// Returns [`LoadError`] in the following cases:
282    ///
283    /// - [`LoadError::Io`] - Failed to read the file or an included file
284    /// - [`LoadError::IncludeCycle`] - Circular include detected
285    ///
286    /// Note: Parse errors and path traversal errors are collected in
287    /// [`LoadResult::errors`] rather than returned directly, allowing
288    /// partial results to be returned.
289    pub fn load(&mut self, path: &Path) -> Result<LoadResult, LoadError> {
290        let mut directives = Vec::new();
291        let mut options = Options::default();
292        let mut plugins = Vec::new();
293        let mut source_map = SourceMap::new();
294        let mut errors = Vec::new();
295
296        // Get normalized absolute path (WASI-compatible, doesn't require canonicalize)
297        let canonical = normalize_path(path);
298
299        // Set root directory for path security if enabled but not explicitly set
300        if self.enforce_path_security && self.root_dir.is_none() {
301            self.root_dir = canonical.parent().map(Path::to_path_buf);
302        }
303
304        self.load_recursive(
305            &canonical,
306            &mut directives,
307            &mut options,
308            &mut plugins,
309            &mut source_map,
310            &mut errors,
311        )?;
312
313        // Build display context from directives and options
314        let display_context = build_display_context(&directives, &options);
315
316        Ok(LoadResult {
317            directives,
318            options,
319            plugins,
320            source_map,
321            errors,
322            display_context,
323        })
324    }
325
326    fn load_recursive(
327        &mut self,
328        path: &Path,
329        directives: &mut Vec<Spanned<Directive>>,
330        options: &mut Options,
331        plugins: &mut Vec<Plugin>,
332        source_map: &mut SourceMap,
333        errors: &mut Vec<LoadError>,
334    ) -> Result<(), LoadError> {
335        // Allocate path once for reuse
336        let path_buf = path.to_path_buf();
337
338        // Check for cycles using O(1) HashSet lookup
339        if self.include_stack_set.contains(&path_buf) {
340            let mut cycle: Vec<String> = self
341                .include_stack
342                .iter()
343                .map(|p| p.display().to_string())
344                .collect();
345            cycle.push(path.display().to_string());
346            return Err(LoadError::IncludeCycle { cycle });
347        }
348
349        // Check if already loaded
350        if self.loaded_files.contains(&path_buf) {
351            return Ok(());
352        }
353
354        // Read file (decrypting if necessary)
355        // Try fast UTF-8 conversion first, fall back to lossy for non-UTF-8 files
356        let source: std::sync::Arc<str> = if is_encrypted_file(path) {
357            decrypt_gpg_file(path)?.into()
358        } else {
359            let bytes = fs::read(path).map_err(|e| LoadError::Io {
360                path: path_buf.clone(),
361                source: e,
362            })?;
363            // Try zero-copy conversion first (common case), fall back to lossy
364            match String::from_utf8(bytes) {
365                Ok(s) => s.into(),
366                Err(e) => String::from_utf8_lossy(e.as_bytes()).into_owned().into(),
367            }
368        };
369
370        // Add to source map (Arc::clone is cheap - just increments refcount)
371        let file_id = source_map.add_file(path_buf.clone(), std::sync::Arc::clone(&source));
372
373        // Mark as loading (update both stack and set)
374        self.include_stack_set.insert(path_buf.clone());
375        self.include_stack.push(path_buf.clone());
376        self.loaded_files.insert(path_buf);
377
378        // Parse (borrows from Arc, no allocation)
379        let result = rustledger_parser::parse(&source);
380
381        // Collect parse errors
382        if !result.errors.is_empty() {
383            errors.push(LoadError::ParseErrors {
384                path: path.to_path_buf(),
385                errors: result.errors,
386            });
387        }
388
389        // Process options
390        for (key, value, _span) in result.options {
391            options.set(&key, &value);
392        }
393
394        // Process plugins
395        for (name, config, span) in result.plugins {
396            plugins.push(Plugin {
397                name,
398                config,
399                span,
400                file_id,
401            });
402        }
403
404        // Process includes
405        let base_dir = path.parent().unwrap_or(Path::new("."));
406        for (include_path, _span) in &result.includes {
407            let full_path = base_dir.join(include_path);
408            // Use normalize_path for WASI compatibility (canonicalize not supported)
409            let canonical = normalize_path(&full_path);
410
411            // Path traversal protection: ensure include stays within root directory
412            if self.enforce_path_security {
413                if let Some(ref root) = self.root_dir {
414                    if !canonical.starts_with(root) {
415                        errors.push(LoadError::PathTraversal {
416                            include_path: include_path.clone(),
417                            base_dir: root.clone(),
418                        });
419                        continue;
420                    }
421                }
422            }
423
424            if let Err(e) =
425                self.load_recursive(&canonical, directives, options, plugins, source_map, errors)
426            {
427                errors.push(e);
428            }
429        }
430
431        // Add directives from this file, setting the file_id
432        directives.extend(
433            result
434                .directives
435                .into_iter()
436                .map(|d| d.with_file_id(file_id)),
437        );
438
439        // Pop from stack and set
440        if let Some(popped) = self.include_stack.pop() {
441            self.include_stack_set.remove(&popped);
442        }
443
444        Ok(())
445    }
446}
447
448/// Build a display context from loaded directives and options.
449///
450/// This scans all directives for amounts and tracks the maximum precision seen
451/// for each currency. Fixed precisions from `option "display_precision"` override
452/// the inferred values.
453fn build_display_context(directives: &[Spanned<Directive>], options: &Options) -> DisplayContext {
454    let mut ctx = DisplayContext::new();
455
456    // Set render_commas from options
457    ctx.set_render_commas(options.render_commas);
458
459    // Scan directives for amounts to infer precision
460    for spanned in directives {
461        match &spanned.value {
462            Directive::Transaction(txn) => {
463                for posting in &txn.postings {
464                    // Units (IncompleteAmount)
465                    if let Some(ref units) = posting.units {
466                        if let (Some(number), Some(currency)) = (units.number(), units.currency()) {
467                            ctx.update(number, currency);
468                        }
469                    }
470                    // Cost (CostSpec)
471                    if let Some(ref cost) = posting.cost {
472                        if let (Some(number), Some(currency)) =
473                            (cost.number_per.or(cost.number_total), &cost.currency)
474                        {
475                            ctx.update(number, currency.as_str());
476                        }
477                    }
478                    // Price (PriceAnnotation)
479                    if let Some(ref price) = posting.price {
480                        if let Some(amount) = price.amount() {
481                            ctx.update(amount.number, amount.currency.as_str());
482                        }
483                    }
484                }
485            }
486            Directive::Balance(bal) => {
487                ctx.update(bal.amount.number, bal.amount.currency.as_str());
488                if let Some(tol) = bal.tolerance {
489                    ctx.update(tol, bal.amount.currency.as_str());
490                }
491            }
492            Directive::Price(price) => {
493                ctx.update(price.amount.number, price.amount.currency.as_str());
494            }
495            Directive::Pad(_)
496            | Directive::Open(_)
497            | Directive::Close(_)
498            | Directive::Commodity(_)
499            | Directive::Event(_)
500            | Directive::Query(_)
501            | Directive::Note(_)
502            | Directive::Document(_)
503            | Directive::Custom(_) => {}
504        }
505    }
506
507    // Apply fixed precisions from options (these override inferred values)
508    for (currency, precision) in &options.display_precision {
509        ctx.set_fixed_precision(currency, *precision);
510    }
511
512    ctx
513}
514
515/// Load a beancount file.
516///
517/// This is a convenience function that creates a loader and loads a single file.
518pub fn load(path: &Path) -> Result<LoadResult, LoadError> {
519    Loader::new().load(path)
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// A `.gpg` extension alone marks a file as encrypted.
    #[test]
    fn test_is_encrypted_file_gpg_extension() {
        assert!(is_encrypted_file(Path::new("test.beancount.gpg")));
    }

    /// A regular ledger file is not treated as encrypted.
    #[test]
    fn test_is_encrypted_file_plain_beancount() {
        assert!(!is_encrypted_file(Path::new("test.beancount")));
    }

    /// An `.asc` file is encrypted when it carries the PGP armor header.
    #[test]
    fn test_is_encrypted_file_asc_with_pgp_header() {
        let mut armored = NamedTempFile::with_suffix(".asc").unwrap();
        for line in [
            "-----BEGIN PGP MESSAGE-----",
            "some encrypted content",
            "-----END PGP MESSAGE-----",
        ] {
            writeln!(armored, "{line}").unwrap();
        }
        armored.flush().unwrap();

        assert!(is_encrypted_file(armored.path()));
    }

    /// An `.asc` file without the armor header is plain text.
    #[test]
    fn test_is_encrypted_file_asc_without_pgp_header() {
        let mut plain = NamedTempFile::with_suffix(".asc").unwrap();
        for line in [
            "This is just a plain text file",
            "with .asc extension but no PGP content",
        ] {
            writeln!(plain, "{line}").unwrap();
        }
        plain.flush().unwrap();

        assert!(!is_encrypted_file(plain.path()));
    }

    /// Decrypting bogus content must surface a `Decryption` error.
    #[test]
    fn test_decrypt_gpg_file_missing_gpg() {
        // Create a fake .gpg file
        let mut bogus = NamedTempFile::with_suffix(".gpg").unwrap();
        writeln!(bogus, "fake encrypted content").unwrap();
        bogus.flush().unwrap();

        // This will fail because the content isn't actually GPG-encrypted
        // (or gpg isn't installed, or there's no matching key)
        let outcome = decrypt_gpg_file(bogus.path());
        assert!(outcome.is_err());

        match outcome {
            Err(LoadError::Decryption { path, message }) => {
                assert_eq!(path, bogus.path().to_path_buf());
                assert!(!message.is_empty());
            }
            _ => panic!("Expected Decryption error"),
        }
    }
}