frontmatter_gen/
utils.rs

1// Copyright © 2024 Shokunin Static Site Generator. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! # Utility Module
5//!
6//! Provides common utilities for file system operations, logging, and other shared functionality.
7//!
8//! ## Features
9//!
10//! - Secure file system operations
11//! - Path validation and normalization
12//! - Temporary file management
13//! - Logging utilities
14//!
15//! ## Security
16//!
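//! The path-taking helpers in `fs` (`create_directory` and `copy_file`)
//! first pass their arguments to `fs::validate_path_safety`, which looks for:
//! - Path traversal attempts (`..`)
//! - Symlinks
//! - Backslashes and control characters
//! - Windows reserved device names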
22
23#[cfg(feature = "ssg")]
24use std::collections::HashSet;
25#[cfg(feature = "ssg")]
26use std::fs::File;
27
28use std::fs::create_dir_all;
29#[cfg(feature = "ssg")]
30use std::fs::remove_file;
31
32use std::io::{self};
33use std::path::Path;
34
35#[cfg(feature = "ssg")]
36use std::sync::Arc;
37
38use anyhow::{Context, Result};
39use thiserror::Error;
40
41#[cfg(feature = "ssg")]
42use tokio::sync::RwLock;
43
44#[cfg(feature = "ssg")]
45use uuid::Uuid;
46
/// Error type returned by the utilities in this module.
///
/// The `Display` text of every variant is generated by `thiserror` from the
/// `#[error(...)]` attribute placed on it.
#[derive(Debug, Error)]
pub enum UtilsError {
    /// An underlying I/O operation failed.
    ///
    /// Thanks to `#[from]`, the `?` operator converts an `io::Error` into
    /// this variant automatically.
    #[error("File system error: {0}")]
    FileSystem(#[from] io::Error),

    /// A path was rejected by validation.
    #[error("Invalid path '{path}': {details}")]
    InvalidPath {
        /// The rejected path, rendered as text
        path: String,
        /// Human-readable explanation of why the path was rejected
        details: String,
    },

    /// Permission error; carries a free-form description.
    #[error("Permission denied: {0}")]
    PermissionDenied(String),

    /// A requested resource could not be found; carries a description.
    #[error("Resource not found: {0}")]
    NotFound(String),

    /// Generic failure, also used when converting `anyhow::Error` and
    /// `tokio::task::JoinError` into this type.
    #[error("Invalid operation: {0}")]
    InvalidOperation(String),
}
75
76/// File system utilities module
77pub mod fs {
78    use super::*;
79    use std::path::PathBuf;
80
81    /// Tracks temporary files for cleanup
82    #[cfg(feature = "ssg")]
83    #[derive(Debug, Default)]
84    pub struct TempFileTracker {
85        files: Arc<RwLock<HashSet<PathBuf>>>,
86    }
87
88    #[cfg(feature = "ssg")]
89    impl TempFileTracker {
90        /// Creates a new temporary file tracker
91        pub fn new() -> Self {
92            Self {
93                files: Arc::new(RwLock::new(HashSet::new())),
94            }
95        }
96
97        /// Registers a temporary file for tracking
98        pub async fn register(&self, path: PathBuf) -> Result<()> {
99            let mut files = self.files.write().await;
100            let _ = files.insert(path);
101            Ok(())
102        }
103
104        /// Cleans up all tracked temporary files
105        pub async fn cleanup(&self) -> Result<()> {
106            let files = self.files.read().await;
107            for path in files.iter() {
108                if path.exists() {
109                    remove_file(path).with_context(|| {
110                        format!(
111                            "Failed to remove temporary file: {}",
112                            path.display()
113                        )
114                    })?;
115                }
116            }
117            Ok(())
118        }
119    }
120
121    /// Creates a new temporary file with the given prefix
122    #[cfg(feature = "ssg")]
123    pub async fn create_temp_file(
124        prefix: &str,
125    ) -> Result<(PathBuf, File), UtilsError> {
126        let temp_dir = std::env::temp_dir();
127        let file_name = format!("{}-{}", prefix, Uuid::new_v4());
128        let path = temp_dir.join(file_name);
129
130        let file =
131            File::create(&path).map_err(UtilsError::FileSystem)?;
132
133        Ok((path, file))
134    }
135
136    /// Validates that a path is safe to use
137    ///
138    /// # Arguments
139    ///
140    /// * `path` - Path to validate
141    ///
142    /// # Returns
143    ///
144    /// Returns Ok(()) if the path is safe, or an error if validation fails
145    ///
146    /// # Security
147    ///
148    /// Checks for:
149    /// - Path length limits
150    /// - Invalid characters
151    /// - Path traversal attempts
152    /// - Symlinks
153    /// - Reserved names
154    pub fn validate_path_safety(path: &Path) -> Result<()> {
155        let path_str = path.to_string_lossy();
156
157        // 1. Disallow backslashes for POSIX compatibility
158        if path_str.contains('\\') {
159            return Err(UtilsError::InvalidPath {
160                path: path_str.to_string(),
161                details: "Backslashes are not allowed in paths"
162                    .to_string(),
163            }
164            .into());
165        }
166
167        // 2. Check for null bytes and control characters
168        if path_str.contains('\0')
169            || path_str.chars().any(|c| c.is_control())
170        {
171            return Err(UtilsError::InvalidPath {
172                path: path_str.to_string(),
173                details: "Path contains invalid characters".to_string(),
174            }
175            .into());
176        }
177
178        // 3. Disallow path traversal using `..`
179        if path_str.contains("..") {
180            return Err(UtilsError::InvalidPath {
181                path: path_str.to_string(),
182                details: "Path traversal not allowed".to_string(),
183            }
184            .into());
185        }
186
187        // 4. Handle absolute paths
188        if path.is_absolute() {
189            println!(
190                "Debug: Absolute path detected: {}",
191                path.display()
192            );
193
194            // In test mode, allow absolute paths in the temporary directory
195            if cfg!(test) {
196                let temp_dir = std::env::temp_dir();
197                let path_canonicalized = path
198                    .canonicalize()
199                    .or_else(|_| {
200                        Ok::<PathBuf, io::Error>(path.to_path_buf())
201                    }) // Specify the type explicitly
202                    .with_context(|| {
203                        format!(
204                            "Failed to canonicalize path: {}",
205                            path.display()
206                        )
207                    })?;
208                let temp_dir_canonicalized = temp_dir
209                    .canonicalize()
210                    .or_else(|_| {
211                        Ok::<PathBuf, io::Error>(temp_dir.clone())
212                    }) // Specify the type explicitly
213                    .with_context(|| {
214                        format!(
215                            "Failed to canonicalize temp_dir: {}",
216                            temp_dir.display()
217                        )
218                    })?;
219
220                if path_canonicalized
221                    .starts_with(&temp_dir_canonicalized)
222                {
223                    return Ok(());
224                }
225            }
226
227            // Allow all absolute paths in non-test mode
228            return Ok(());
229        }
230
231        // 5. Check for symlinks
232        if path.exists() {
233            let metadata =
234                path.symlink_metadata().with_context(|| {
235                    format!(
236                        "Failed to get metadata for path: {}",
237                        path.display()
238                    )
239                })?;
240
241            if metadata.file_type().is_symlink() {
242                return Err(UtilsError::InvalidPath {
243                    path: path_str.to_string(),
244                    details: "Symlinks are not allowed".to_string(),
245                }
246                .into());
247            }
248        }
249
250        // 6. Prevent the use of reserved names (Windows compatibility)
251        let reserved_names =
252            ["con", "prn", "aux", "nul", "com1", "lpt1"];
253        if let Some(file_name) =
254            path.file_name().and_then(|n| n.to_str())
255        {
256            if reserved_names
257                .contains(&file_name.to_lowercase().as_str())
258            {
259                return Err(UtilsError::InvalidPath {
260                    path: path_str.to_string(),
261                    details: "Reserved file name not allowed"
262                        .to_string(),
263                }
264                .into());
265            }
266        }
267
268        Ok(())
269    }
270
271    /// Creates a directory and all parent directories
272    ///
273    /// # Arguments
274    ///
275    /// * `path` - Path to create
276    ///
277    /// # Returns
278    ///
279    /// Returns Ok(()) on success, or an error if creation fails
280    ///
281    /// # Security
282    ///
283    /// Validates path safety before creation
284    #[cfg(feature = "ssg")]
285    pub async fn create_directory(path: &Path) -> Result<()> {
286        validate_path_safety(path)?;
287
288        create_dir_all(path).with_context(|| {
289            format!("Failed to create directory: {}", path.display())
290        })?;
291
292        Ok(())
293    }
294
295    /// Copies a file from source to destination
296    ///
297    /// # Arguments
298    ///
299    /// * `src` - Source path
300    /// * `dst` - Destination path
301    ///
302    /// # Returns
303    ///
304    /// Returns Ok(()) on success, or an error if copy fails
305    ///
306    /// # Security
307    ///
308    /// Validates both paths and ensures proper permissions
309    pub async fn copy_file(src: &Path, dst: &Path) -> Result<()> {
310        validate_path_safety(src)?;
311        validate_path_safety(dst)?;
312
313        if let Some(parent) = dst.parent() {
314            create_dir_all(parent).with_context(|| {
315                format!(
316                    "Failed to create parent directory: {}",
317                    parent.display()
318                )
319            })?;
320        }
321
322        let _ = std::fs::copy(src, dst).with_context(|| {
323            format!(
324                "Failed to copy {} to {}",
325                src.display(),
326                dst.display()
327            )
328        })?;
329
330        Ok(())
331    }
332}
333
/// Logging utilities module
pub mod log {
    #[cfg(feature = "ssg")]
    use anyhow::{Context, Result};
    #[cfg(feature = "ssg")]
    use dtt::datetime::DateTime;
    #[cfg(feature = "ssg")]
    use log::{Level, Record};
    #[cfg(feature = "ssg")]
    use std::{
        fs::{File, OpenOptions},
        io::Write,
        path::Path,
    };

    /// A single log entry, ready to be formatted and written.
    #[cfg(feature = "ssg")]
    #[derive(Debug)]
    pub struct LogEntry {
        /// Timestamp of the log entry (taken when the entry is created)
        pub timestamp: DateTime,
        /// Log level
        pub level: Level,
        /// Log message
        pub message: String,
        /// Optional error details, appended to the formatted line when set
        pub error: Option<String>,
    }

    #[cfg(feature = "ssg")]
    impl LogEntry {
        /// Creates a new log entry from a `log::Record`.
        ///
        /// The level and rendered message are copied from `record`. The
        /// timestamp is `DateTime::new()` at call time, and `error` starts
        /// as `None`.
        pub fn new(record: &Record<'_>) -> Self {
            Self {
                timestamp: DateTime::new(),
                level: record.level(),
                message: record.args().to_string(),
                error: None,
            }
        }

        /// Formats the entry as `[<timestamp> <level>] <message>`.
        ///
        /// The level is right-aligned to width 5. When `error` is set,
        /// ` (Error: <details>)` is appended. No trailing newline is added.
        pub fn format(&self) -> String {
            let error_info = self
                .error
                .as_ref()
                .map(|e| format!(" (Error: {})", e))
                .unwrap_or_default();

            format!(
                "[{} {:>5}] {}{}",
                self.timestamp, self.level, self.message, error_info
            )
        }
    }

    /// Appends formatted log entries to a file.
    #[cfg(feature = "ssg")]
    #[derive(Debug)]
    pub struct LogWriter {
        file: File,
    }

    #[cfg(feature = "ssg")]
    impl LogWriter {
        /// Opens `path` for appending, creating the file if it is missing.
        ///
        /// # Errors
        ///
        /// Returns an error naming the path if the file cannot be opened.
        pub fn new(path: &Path) -> Result<Self> {
            let file = OpenOptions::new()
                .create(true)
                .append(true)
                .open(path)
                .with_context(|| {
                    format!(
                        "Failed to open log file: {}",
                        path.display()
                    )
                })?;

            Ok(Self { file })
        }

        /// Writes a log entry as one newline-terminated line.
        ///
        /// The formatted entry and its newline are assembled into a single
        /// buffer and handed to one `write_all` call. `writeln!` on an
        /// unbuffered `File` would issue separate writes for the message
        /// and the newline. That lets output from concurrent writers land
        /// between the two.
        ///
        /// # Errors
        ///
        /// Returns an error if writing to the underlying file fails.
        pub fn write(&mut self, entry: &LogEntry) -> Result<()> {
            let mut line = entry.format();
            line.push('\n');
            self.file
                .write_all(line.as_bytes())
                .context("Failed to write log entry")?;
            Ok(())
        }
    }
}
423
424impl From<anyhow::Error> for UtilsError {
425    fn from(err: anyhow::Error) -> Self {
426        UtilsError::InvalidOperation(err.to_string())
427    }
428}
429
430impl From<tokio::task::JoinError> for UtilsError {
431    fn from(err: tokio::task::JoinError) -> Self {
432        UtilsError::InvalidOperation(err.to_string())
433    }
434}
435
#[cfg(all(test, feature = "ssg"))]
mod tests {
    use crate::utils::fs::copy_file;
    use crate::utils::fs::create_directory;
    use crate::utils::fs::create_temp_file;
    use crate::utils::fs::validate_path_safety;
    use crate::utils::fs::TempFileTracker;
    use crate::utils::log::LogEntry;
    use crate::utils::log::LogWriter;
    use crate::utils::UtilsError;
    use log::Level;
    use log::Record;
    use std::fs::read_to_string;
    use std::fs::remove_file;
    use std::path::{Path, PathBuf};
    use std::sync::Arc;
    use uuid::Uuid;

    /// Builds a path in the system temp directory that nothing else uses.
    ///
    /// Unique names keep concurrent test runs from colliding. They also
    /// keep the tests from touching, or deleting, entries other programs
    /// created in the shared temp directory.
    fn unique_temp_path(stem: &str) -> PathBuf {
        std::env::temp_dir().join(format!("{}-{}", stem, Uuid::new_v4()))
    }

    #[tokio::test]
    async fn test_temp_file_creation_and_cleanup() -> anyhow::Result<()>
    {
        let tracker = TempFileTracker::new();
        let (path, _file) = create_temp_file("test").await?;

        tracker.register(path.clone()).await?;
        assert!(path.exists());

        tracker.cleanup().await?;
        assert!(!path.exists());

        // A second cleanup has nothing left to do and must still succeed.
        tracker.cleanup().await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_temp_file_concurrent_access() -> Result<(), UtilsError>
    {
        use tokio::task;

        let tracker = Arc::new(TempFileTracker::new());
        let mut handles = Vec::new();

        for i in 0..5 {
            let tracker = Arc::clone(&tracker);
            handles.push(task::spawn(async move {
                let (path, _) =
                    create_temp_file(&format!("test{}", i)).await?;
                tracker.register(path).await
            }));
        }

        for handle in handles {
            handle.await??;
        }

        tracker.cleanup().await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_create_directory_valid_path() -> anyhow::Result<()> {
        // A fresh unique name: nothing to pre-delete, nobody else's data
        // at risk.
        let dir = unique_temp_path("test_dir");

        create_directory(&dir).await?;
        assert!(dir.is_dir());
        tokio::fs::remove_dir_all(&dir).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_copy_file_valid_paths() -> anyhow::Result<()> {
        let src = unique_temp_path("src");
        let dst = unique_temp_path("dst");

        // Create the source file with content
        tokio::fs::write(&src, "test content").await?;

        copy_file(&src, &dst).await?;
        assert_eq!(
            tokio::fs::read_to_string(&dst).await?,
            "test content"
        );

        tokio::fs::remove_file(&src).await?;
        tokio::fs::remove_file(&dst).await?;
        Ok(())
    }

    #[test]
    fn test_validate_path_safety_valid_paths() {
        assert!(
            validate_path_safety(Path::new("content/file.txt")).is_ok()
        );
        assert!(
            validate_path_safety(Path::new("templates/blog")).is_ok()
        );
        assert!(
            validate_path_safety(Path::new("./content/file.txt")).is_ok()
        );
        // An absolute path that does not exist yet
        assert!(validate_path_safety(&unique_temp_path("absent")).is_ok());
    }

    #[test]
    fn test_validate_path_safety_invalid_paths() {
        assert!(validate_path_safety(Path::new("../outside")).is_err());
        assert!(
            validate_path_safety(Path::new("content/../secret")).is_err()
        );
        assert!(
            validate_path_safety(Path::new("content\0file")).is_err()
        );
        assert!(
            validate_path_safety(Path::new("content\\file")).is_err()
        );
        assert!(validate_path_safety(Path::new("CON")).is_err());
    }

    #[test]
    fn test_validate_path_safety_edge_cases() {
        // Test Unicode
        assert!(validate_path_safety(Path::new("content/📚")).is_ok());

        // Long paths
        let long_name = "a".repeat(255);
        assert!(validate_path_safety(Path::new(&long_name)).is_ok());

        // Special characters
        assert!(validate_path_safety(Path::new("content/#$@!")).is_ok());
    }

    #[test]
    fn test_log_entry_format() {
        let record = Record::builder()
            .args(format_args!("Test log message"))
            .level(Level::Info)
            .target("test")
            .module_path_static(Some("test"))
            .file_static(Some("test.rs"))
            .line(Some(42))
            .build();

        let entry = LogEntry::new(&record);
        let formatted = entry.format();
        assert!(formatted.contains("Test log message"));
        assert!(formatted.contains("INFO"));
    }

    #[test]
    fn test_log_entry_with_error() {
        let record = Record::builder()
            .args(format_args!("Test error message"))
            .level(Level::Error)
            .target("test")
            .module_path_static(Some("test"))
            .file_static(Some("test.rs"))
            .line(Some(42))
            .build();

        let mut entry = LogEntry::new(&record);
        entry.error = Some("Error details".to_string());

        let formatted = entry.format();
        assert!(formatted.contains("Error details"));
        assert!(formatted.contains("ERROR"));
    }

    #[test]
    fn test_log_writer_creation() {
        let temp_log_path = unique_temp_path("test_log");
        let writer = LogWriter::new(&temp_log_path)
            .expect("log file in the temp dir should be creatable");

        assert!(temp_log_path.exists());
        drop(writer); // Ensure file is closed
        remove_file(temp_log_path).unwrap();
    }

    #[test]
    fn test_log_writer_write() {
        let temp_log_path = unique_temp_path("test_log_write");
        let mut writer = LogWriter::new(&temp_log_path)
            .expect("log file in the temp dir should be creatable");

        let record = Record::builder()
            .args(format_args!("Write test message"))
            .level(Level::Info)
            .target("test")
            .build();

        let entry = LogEntry::new(&record);
        writer.write(&entry).unwrap();
        drop(writer);

        let content = read_to_string(&temp_log_path).unwrap();
        assert!(content.contains("Write test message"));
        assert!(content.ends_with('\n'));
        remove_file(temp_log_path).unwrap();
    }
}