ricecoder_generation/templates/
loader.rs

1//! Template loading from files and directories
2//!
3//! Loads templates from `.tmpl` files in global and project-specific locations.
4//! Supports template inheritance and includes.
5
6use crate::models::{Template, TemplateMetadata};
7use crate::templates::error::TemplateError;
8use crate::templates::parser::TemplateParser;
9use std::collections::HashMap;
10use std::fs;
11use std::path::{Path, PathBuf};
12
13/// Loads templates from files and directories
14pub struct TemplateLoader {
15    /// Cache of loaded templates
16    cache: HashMap<String, Template>,
17}
18
19impl TemplateLoader {
20    /// Create a new template loader
21    pub fn new() -> Self {
22        Self {
23            cache: HashMap::new(),
24        }
25    }
26
27    /// Load a template from a file
28    ///
29    /// # Arguments
30    /// * `path` - Path to the template file (.tmpl extension)
31    ///
32    /// # Returns
33    /// Loaded template or error
34    pub fn load_from_file(&mut self, path: &Path) -> Result<Template, TemplateError> {
35        // Check cache first
36        if let Some(cached) = self.cache.get(path.to_string_lossy().as_ref()) {
37            return Ok(cached.clone());
38        }
39
40        // Read file
41        let content = fs::read_to_string(path).map_err(TemplateError::IoError)?;
42
43        // Parse template to validate syntax
44        let parsed = TemplateParser::parse(&content)?;
45
46        // Extract template ID from filename (without .tmpl and language extension)
47        let id = path
48            .file_name()
49            .and_then(|name| name.to_str())
50            .map(|name| {
51                // Remove .tmpl extension
52                let name = name.strip_suffix(".tmpl").unwrap_or(name);
53                // Remove language extension (e.g., .rs, .ts, .py)
54                if let Some(dot_pos) = name.rfind('.') {
55                    &name[..dot_pos]
56                } else {
57                    name
58                }
59            })
60            .unwrap_or("unknown")
61            .to_string();
62
63        // Create template
64        let template = Template {
65            id: id.clone(),
66            name: id,
67            language: self.detect_language(path),
68            content,
69            placeholders: parsed.placeholders,
70            metadata: TemplateMetadata {
71                description: None,
72                version: None,
73                author: None,
74            },
75        };
76
77        // Cache the template
78        self.cache
79            .insert(path.to_string_lossy().to_string(), template.clone());
80
81        Ok(template)
82    }
83
84    /// Load all templates from a directory
85    ///
86    /// # Arguments
87    /// * `dir` - Directory containing .tmpl files
88    ///
89    /// # Returns
90    /// Vector of loaded templates or error
91    pub fn load_from_directory(&mut self, dir: &Path) -> Result<Vec<Template>, TemplateError> {
92        if !dir.exists() {
93            return Ok(Vec::new());
94        }
95
96        let mut templates = Vec::new();
97
98        // Recursively scan directory for .tmpl files
99        self.scan_directory(dir, &mut templates)?;
100
101        Ok(templates)
102    }
103
104    /// Scan directory recursively for .tmpl files
105    fn scan_directory(
106        &mut self,
107        dir: &Path,
108        templates: &mut Vec<Template>,
109    ) -> Result<(), TemplateError> {
110        let entries = fs::read_dir(dir).map_err(TemplateError::IoError)?;
111
112        for entry in entries {
113            let entry = entry.map_err(TemplateError::IoError)?;
114            let path = entry.path();
115
116            if path.is_dir() {
117                // Recursively scan subdirectories
118                self.scan_directory(&path, templates)?;
119            } else if path.extension().and_then(|s| s.to_str()) == Some("tmpl") {
120                // Load template file
121                match self.load_from_file(&path) {
122                    Ok(template) => templates.push(template),
123                    Err(e) => {
124                        // Log error but continue scanning
125                        eprintln!("Failed to load template {}: {}", path.display(), e);
126                    }
127                }
128            }
129        }
130
131        Ok(())
132    }
133
134    /// Load templates from global location (~/.ricecoder/templates/)
135    ///
136    /// # Returns
137    /// Vector of loaded templates or error
138    pub fn load_global_templates(&mut self) -> Result<Vec<Template>, TemplateError> {
139        let global_dir = self.get_global_templates_dir();
140        self.load_from_directory(&global_dir)
141    }
142
143    /// Load templates from project location (./.ricecoder/templates/)
144    ///
145    /// # Arguments
146    /// * `project_root` - Root directory of the project
147    ///
148    /// # Returns
149    /// Vector of loaded templates or error
150    pub fn load_project_templates(
151        &mut self,
152        project_root: &Path,
153    ) -> Result<Vec<Template>, TemplateError> {
154        let project_dir = project_root.join(".ricecoder").join("templates");
155        self.load_from_directory(&project_dir)
156    }
157
158    /// Load templates from both global and project locations
159    ///
160    /// Project templates take precedence over global templates with the same name.
161    ///
162    /// # Arguments
163    /// * `project_root` - Root directory of the project
164    ///
165    /// # Returns
166    /// Vector of loaded templates (project templates override global ones)
167    pub fn load_all_templates(
168        &mut self,
169        project_root: &Path,
170    ) -> Result<Vec<Template>, TemplateError> {
171        // Load global templates first
172        let templates = self.load_global_templates()?;
173
174        // Load project templates
175        let project_templates = self.load_project_templates(project_root)?;
176
177        // Create a map of templates by ID for deduplication
178        let mut template_map: HashMap<String, Template> =
179            templates.into_iter().map(|t| (t.id.clone(), t)).collect();
180
181        // Add/override with project templates
182        for template in project_templates {
183            template_map.insert(template.id.clone(), template);
184        }
185
186        Ok(template_map.into_values().collect())
187    }
188
189    /// Get the global templates directory path
190    fn get_global_templates_dir(&self) -> PathBuf {
191        if let Ok(home) = std::env::var("HOME") {
192            PathBuf::from(home).join(".ricecoder").join("templates")
193        } else if let Ok(home) = std::env::var("USERPROFILE") {
194            // Windows
195            PathBuf::from(home).join(".ricecoder").join("templates")
196        } else {
197            PathBuf::from(".ricecoder/templates")
198        }
199    }
200
201    /// Detect programming language from file extension
202    fn detect_language(&self, path: &Path) -> String {
203        path.extension()
204            .and_then(|_ext| {
205                // Get the extension before .tmpl
206                let parts: Vec<&str> = path.file_name()?.to_str()?.split('.').collect();
207
208                if parts.len() >= 2 {
209                    Some(parts[parts.len() - 2].to_string())
210                } else {
211                    None
212                }
213            })
214            .unwrap_or_else(|| "unknown".to_string())
215    }
216
217    /// Clear the template cache
218    pub fn clear_cache(&mut self) {
219        self.cache.clear();
220    }
221
222    /// Get cache statistics
223    pub fn cache_stats(&self) -> CacheStats {
224        CacheStats {
225            cached_templates: self.cache.len(),
226        }
227    }
228}
229
230impl Default for TemplateLoader {
231    fn default() -> Self {
232        Self::new()
233    }
234}
235
/// Statistics about the template cache, as returned by `TemplateLoader::cache_stats`.
///
/// A small plain-data snapshot: it is `Copy`, comparable, and defaults to an
/// empty cache.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct CacheStats {
    /// Number of templates currently held in the cache
    pub cached_templates: usize,
}
242
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    #[test]
    fn test_load_from_file() {
        let temp_dir = TempDir::new().unwrap();
        let template_path = temp_dir.path().join("test.rs.tmpl");

        let content = "pub struct {{Name}} {\n    pub field: String,\n}";
        fs::write(&template_path, content).unwrap();

        let mut loader = TemplateLoader::new();
        let template = loader.load_from_file(&template_path).unwrap();

        assert_eq!(template.id, "test");
        assert_eq!(template.language, "rs");
        assert_eq!(template.content, content);
    }

    #[test]
    fn test_load_from_directory() {
        let temp_dir = TempDir::new().unwrap();

        // Create multiple template files
        fs::write(
            temp_dir.path().join("struct.rs.tmpl"),
            "pub struct {{Name}} {}",
        )
        .unwrap();
        fs::write(temp_dir.path().join("impl.rs.tmpl"), "impl {{Name}} {}").unwrap();

        let mut loader = TemplateLoader::new();
        let templates = loader.load_from_directory(temp_dir.path()).unwrap();

        assert_eq!(templates.len(), 2);
        assert!(templates.iter().any(|t| t.id == "struct"));
        assert!(templates.iter().any(|t| t.id == "impl"));
    }

    #[test]
    fn test_load_from_directory_recurses_and_skips_non_templates() {
        let temp_dir = TempDir::new().unwrap();
        let nested = temp_dir.path().join("nested");
        fs::create_dir(&nested).unwrap();

        fs::write(
            temp_dir.path().join("struct.rs.tmpl"),
            "pub struct {{Name}} {}",
        )
        .unwrap();
        fs::write(nested.join("impl.rs.tmpl"), "impl {{Name}} {}").unwrap();
        // Files without a `.tmpl` extension must be ignored.
        fs::write(temp_dir.path().join("README.md"), "not a template").unwrap();

        let mut loader = TemplateLoader::new();
        let templates = loader.load_from_directory(temp_dir.path()).unwrap();

        assert_eq!(templates.len(), 2);
        assert!(templates.iter().any(|t| t.id == "struct"));
        assert!(templates.iter().any(|t| t.id == "impl"));
    }

    #[test]
    fn test_load_nonexistent_directory() {
        let mut loader = TemplateLoader::new();
        let templates = loader
            .load_from_directory(Path::new("/nonexistent/path"))
            .unwrap();

        assert_eq!(templates.len(), 0);
    }

    #[test]
    fn test_load_project_templates() {
        let project_root = TempDir::new().unwrap();
        let template_dir = project_root.path().join(".ricecoder").join("templates");
        fs::create_dir_all(&template_dir).unwrap();
        fs::write(template_dir.join("struct.rs.tmpl"), "pub struct {{Name}} {}").unwrap();

        let mut loader = TemplateLoader::new();
        let templates = loader.load_project_templates(project_root.path()).unwrap();

        assert_eq!(templates.len(), 1);
        assert_eq!(templates[0].id, "struct");
        assert_eq!(templates[0].language, "rs");
    }

    #[test]
    fn test_load_project_templates_missing_dir() {
        let project_root = TempDir::new().unwrap();

        let mut loader = TemplateLoader::new();
        let templates = loader.load_project_templates(project_root.path()).unwrap();

        assert!(templates.is_empty());
    }

    #[test]
    fn test_detect_language() {
        let loader = TemplateLoader::new();

        assert_eq!(loader.detect_language(Path::new("test.rs.tmpl")), "rs");
        assert_eq!(loader.detect_language(Path::new("test.ts.tmpl")), "ts");
        assert_eq!(loader.detect_language(Path::new("test.py.tmpl")), "py");
    }

    #[test]
    fn test_cache_stats() {
        let temp_dir = TempDir::new().unwrap();
        let template_path = temp_dir.path().join("test.rs.tmpl");
        fs::write(&template_path, "pub struct {{Name}} {}").unwrap();

        let mut loader = TemplateLoader::new();
        loader.load_from_file(&template_path).unwrap();

        let stats = loader.cache_stats();
        assert_eq!(stats.cached_templates, 1);
    }

    #[test]
    fn test_clear_cache() {
        let temp_dir = TempDir::new().unwrap();
        let template_path = temp_dir.path().join("test.rs.tmpl");
        fs::write(&template_path, "pub struct {{Name}} {}").unwrap();

        let mut loader = TemplateLoader::new();
        loader.load_from_file(&template_path).unwrap();
        assert_eq!(loader.cache_stats().cached_templates, 1);

        loader.clear_cache();
        assert_eq!(loader.cache_stats().cached_templates, 0);
    }
}