Skip to main content

ggen_cli_lib/conventions/
planner.rs

1//! Generation planner for creating task execution plans
2
3use ggen_core::utils::error::Result;
4use std::collections::{HashMap, HashSet};
5use std::fs;
6use std::path::{Path, PathBuf};
7
8use super::ProjectConventions;
9
/// Metadata extracted from a template's `{# key: value #}` directive comments.
///
/// Every field type (`String`, `Vec<String>`, `Option<String>`) implements `Eq`,
/// so the struct derives full equality rather than only `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TemplateMetadata {
    /// Output path pattern from the `output:` directive.
    pub output: String,
    /// Values of every `when:` directive, in the order they appear in the file.
    pub when: Vec<String>,
    /// Query text from the `query:` directive.
    pub query: String,
    /// Value of the optional `foreach:` directive.
    pub foreach: Option<String>,
}
20
21impl TemplateMetadata {
22    /// Parse template metadata from file content
23    pub fn parse(content: &str) -> Result<Self> {
24        let mut output = None;
25        let mut when = Vec::new();
26        let mut query = None;
27        let mut foreach = None;
28
29        // Parse {# ... #} style comments
30        for line in content.lines() {
31            let line = line.trim();
32
33            if line.starts_with("{#") && line.ends_with("#}") {
34                let inner = line[2..line.len() - 2].trim();
35
36                if let Some(value) = inner.strip_prefix("output:") {
37                    output = Some(value.trim().to_string());
38                } else if let Some(value) = inner.strip_prefix("when:") {
39                    when.push(value.trim().to_string());
40                } else if let Some(value) = inner.strip_prefix("query:") {
41                    query = Some(value.trim().to_string());
42                } else if let Some(value) = inner.strip_prefix("foreach:") {
43                    foreach = Some(value.trim().to_string());
44                }
45            }
46        }
47
48        Ok(TemplateMetadata {
49            output: output
50                .ok_or_else(|| ggen_core::utils::error::Error::new("Missing 'output' directive"))?,
51            when,
52            query: query
53                .ok_or_else(|| ggen_core::utils::error::Error::new("Missing 'query' directive"))?,
54            foreach,
55        })
56    }
57}
58
/// One unit of work in a [`GenerationPlan`]: render a single template.
#[derive(Clone, Debug)]
pub struct GenerationTask {
    /// Template name, as registered in the project conventions.
    pub template: String,
    /// Output path pattern, taken from the template's `output:` directive.
    pub output_pattern: String,
    /// Paths listed in the template's `when:` directives.
    pub trigger_files: Vec<PathBuf>,
    /// Query text associated with the template, if any.
    pub query: Option<String>,
    /// Value of the template's `foreach:` directive, if any.
    pub foreach: Option<String>,
}

/// The ordered set of tasks produced by planning.
#[derive(Debug)]
pub struct GenerationPlan {
    /// Tasks in the order the planner produced them.
    pub tasks: Vec<GenerationTask>,
}
74
75/// Plans code generation based on conventions and templates
76pub struct GenerationPlanner {
77    conventions: ProjectConventions,
78}
79
80impl GenerationPlanner {
81    /// Create a new generation planner
82    pub fn new(conventions: ProjectConventions) -> Self {
83        Self { conventions }
84    }
85
86    /// Create a generation plan by analyzing all templates
87    pub fn plan(&self) -> Result<GenerationPlan> {
88        let mut tasks = Vec::new();
89        let mut task_graph: HashMap<String, Vec<String>> = HashMap::new();
90
91        // Iterate through all discovered templates
92        for (template_name, template_path) in &self.conventions.templates {
93            let metadata = self.parse_template_metadata(template_path)?;
94            let trigger_files = self.resolve_dependencies(&metadata);
95
96            // Track dependencies for circular detection
97            let deps: Vec<String> = trigger_files
98                .iter()
99                .filter_map(|p| {
100                    p.file_stem()
101                        .and_then(|s| s.to_str())
102                        .map(|s| s.to_string())
103                })
104                .collect();
105
106            task_graph.insert(template_name.clone(), deps);
107
108            tasks.push(GenerationTask {
109                template: template_name.clone(),
110                output_pattern: metadata.output,
111                trigger_files,
112                query: Some(metadata.query),
113                foreach: metadata.foreach,
114            });
115        }
116
117        // Check for circular dependencies
118        self.check_circular_dependencies(&task_graph)?;
119
120        // Sort tasks by dependencies (topological sort)
121        tasks = self.topological_sort(tasks)?;
122
123        Ok(GenerationPlan { tasks })
124    }
125
126    /// Parse template metadata from a template file
127    fn parse_template_metadata(&self, path: &Path) -> Result<TemplateMetadata> {
128        let content = fs::read_to_string(path)?;
129        TemplateMetadata::parse(&content)
130    }
131
132    /// Resolve file dependencies from template metadata
133    fn resolve_dependencies(&self, metadata: &TemplateMetadata) -> Vec<PathBuf> {
134        metadata
135            .when
136            .iter()
137            .map(|pattern| {
138                // Simple glob pattern resolution
139                // In a real implementation, this would use proper glob matching
140                PathBuf::from(pattern)
141            })
142            .collect()
143    }
144
145    /// Check for circular dependencies in the task graph
146    fn check_circular_dependencies(&self, graph: &HashMap<String, Vec<String>>) -> Result<()> {
147        for task in graph.keys() {
148            let mut visited = HashSet::new();
149            let mut rec_stack = HashSet::new();
150
151            if self.has_cycle(task, graph, &mut visited, &mut rec_stack) {
152                return Err(ggen_core::utils::error::Error::new(&format!(
153                    "Circular dependency detected involving task: {}",
154                    task
155                )));
156            }
157        }
158
159        Ok(())
160    }
161
162    /// DFS-based cycle detection
163    #[allow(clippy::self_only_used_in_recursion)] // self needed for recursive dispatch
164    fn has_cycle(
165        &self, task: &str, graph: &HashMap<String, Vec<String>>, visited: &mut HashSet<String>,
166        rec_stack: &mut HashSet<String>,
167    ) -> bool {
168        if rec_stack.contains(task) {
169            return true;
170        }
171
172        if visited.contains(task) {
173            return false;
174        }
175
176        visited.insert(task.to_string());
177        rec_stack.insert(task.to_string());
178
179        if let Some(deps) = graph.get(task) {
180            for dep in deps {
181                if self.has_cycle(dep, graph, visited, rec_stack) {
182                    return true;
183                }
184            }
185        }
186
187        rec_stack.remove(task);
188        false
189    }
190
191    /// Topologically sort tasks by dependencies
192    fn topological_sort(&self, mut tasks: Vec<GenerationTask>) -> Result<Vec<GenerationTask>> {
193        // Build dependency map
194        let mut dep_count: HashMap<String, usize> = HashMap::new();
195        let mut graph: HashMap<String, Vec<String>> = HashMap::new();
196
197        for task in &tasks {
198            dep_count.entry(task.template.clone()).or_insert(0);
199
200            for trigger in &task.trigger_files {
201                if let Some(dep) = trigger.file_stem().and_then(|s| s.to_str()) {
202                    graph
203                        .entry(dep.to_string())
204                        .or_default()
205                        .push(task.template.clone());
206                    *dep_count.entry(task.template.clone()).or_insert(0) += 1;
207                }
208            }
209        }
210
211        // Find tasks with no dependencies
212        let mut ready: Vec<String> = dep_count
213            .iter()
214            .filter(|(_, &count)| count == 0)
215            .map(|(name, _)| name.clone())
216            .collect();
217
218        let mut sorted = Vec::new();
219
220        while let Some(task_name) = ready.pop() {
221            // Find and add the task
222            if let Some(pos) = tasks.iter().position(|t| t.template == task_name) {
223                let task = tasks.remove(pos);
224                sorted.push(task);
225            }
226
227            // Update dependent tasks
228            if let Some(dependents) = graph.get(&task_name) {
229                for dependent in dependents {
230                    if let Some(count) = dep_count.get_mut(dependent) {
231                        *count = count.saturating_sub(1);
232                        if *count == 0 {
233                            ready.push(dependent.clone());
234                        }
235                    }
236                }
237            }
238        }
239
240        // If there are remaining tasks, there's a cycle
241        if !tasks.is_empty() {
242            return Err(ggen_core::utils::error::Error::new(
243                "Circular dependency detected during topological sort",
244            ));
245        }
246
247        Ok(sorted)
248    }
249}
250
#[cfg(test)]
// `unwrap` is acceptable in tests (a panic is a test failure), and the fixtures use
// raw strings uniformly even where no escaping is needed.
#[allow(clippy::unwrap_used, clippy::needless_raw_string_hashes)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Writes `content` to `dir/name` and returns the written path.
    fn create_test_template(dir: &Path, name: &str, content: &str) -> PathBuf {
        let target = dir.join(name);
        fs::write(&target, content).unwrap();
        target
    }

    /// Builds conventions rooted at `temp_dir`. Each `(name, content)` pair is
    /// written to `templates/<name>.tmpl` and registered under `name` without the
    /// `.tmpl` suffix.
    fn create_test_conventions(
        temp_dir: &Path, templates: Vec<(&str, &str)>,
    ) -> ProjectConventions {
        let templates_dir = temp_dir.join("templates");
        fs::create_dir_all(&templates_dir).unwrap();

        let template_map: HashMap<_, _> = templates
            .into_iter()
            .map(|(name, content)| {
                let stem = name.strip_suffix(".tmpl").unwrap_or(name);
                let file_name = format!("{}.tmpl", stem);
                let path = create_test_template(&templates_dir, &file_name, content);
                (stem.to_string(), path)
            })
            .collect();

        ProjectConventions {
            rdf_files: Vec::new(),
            rdf_dir: temp_dir.join("domain"),
            templates: template_map,
            templates_dir,
            queries: HashMap::new(),
            output_dir: temp_dir.to_path_buf(),
            preset: String::from("test"),
        }
    }

    /// Writes `templates` into a fresh temporary project and plans it.
    fn plan_for(templates: Vec<(&str, &str)>) -> Result<GenerationPlan> {
        let temp_dir = TempDir::new().unwrap();
        let conventions = create_test_conventions(temp_dir.path(), templates);
        let planner = GenerationPlanner::new(conventions);
        planner.plan()
    }

    /// Index of `template` in the plan's task order; panics if it is absent.
    fn position_of(plan: &GenerationPlan, template: &str) -> usize {
        plan.tasks
            .iter()
            .position(|task| task.template == template)
            .unwrap()
    }

    #[test]
    fn test_parse_template_metadata_basic() {
        let content = r#"
{# output: src/{{name}}.rs #}
{# when: src/models/*.rs #}
{# query: SELECT * FROM models #}

template content here
"#;
        let expected = TemplateMetadata {
            output: "src/{{name}}.rs".to_string(),
            when: vec!["src/models/*.rs".to_string()],
            query: "SELECT * FROM models".to_string(),
            foreach: None,
        };
        assert_eq!(TemplateMetadata::parse(content).unwrap(), expected);
    }

    #[test]
    fn test_parse_template_metadata_with_foreach() {
        let content = r#"
{# output: tests/{{item}}_test.rs #}
{# when: src/{{item}}.rs #}
{# query: SELECT name FROM entities #}
{# foreach: entity #}

test template
"#;
        let metadata = TemplateMetadata::parse(content).unwrap();
        assert_eq!(
            (metadata.output.as_str(), metadata.foreach.as_deref()),
            ("tests/{{item}}_test.rs", Some("entity"))
        );
    }

    #[test]
    fn test_parse_template_metadata_multiple_when() {
        let content = r#"
{# output: generated.rs #}
{# when: file1.rs #}
{# when: file2.rs #}
{# query: SELECT * #}
"#;
        let metadata = TemplateMetadata::parse(content).unwrap();
        assert_eq!(metadata.when.len(), 2);
        for expected in ["file1.rs", "file2.rs"].iter() {
            assert!(metadata.when.iter().any(|w| w.as_str() == *expected));
        }
    }

    #[test]
    fn test_parse_template_metadata_missing_output() {
        let message = TemplateMetadata::parse("\n{# query: SELECT * #}\n")
            .unwrap_err()
            .to_string();
        assert!(message.contains("Missing 'output'"));
    }

    #[test]
    fn test_parse_template_metadata_missing_query() {
        let message = TemplateMetadata::parse("\n{# output: file.rs #}\n")
            .unwrap_err()
            .to_string();
        assert!(message.contains("Missing 'query'"));
    }

    #[test]
    fn test_generation_planner_empty() {
        assert!(plan_for(vec![]).unwrap().tasks.is_empty());
    }

    #[test]
    fn test_generation_planner_single_task() {
        let template = r#"
{# output: generated.rs #}
{# query: SELECT * #}

content
"#;
        let plan = plan_for(vec![("test", template)]).unwrap();

        assert_eq!(plan.tasks.len(), 1);
        let task = &plan.tasks[0];
        assert_eq!(task.template, "test");
        assert_eq!(task.output_pattern, "generated.rs");
    }

    #[test]
    fn test_generation_planner_multiple_tasks() {
        let first = r#"
{# output: out1.rs #}
{# query: SELECT * FROM table1 #}
"#;
        let second = r#"
{# output: out2.rs #}
{# query: SELECT * FROM table2 #}
"#;
        let plan = plan_for(vec![("task1", first), ("task2", second)]).unwrap();
        assert_eq!(plan.tasks.len(), 2);
    }

    #[test]
    fn test_generation_planner_with_dependencies() {
        let base = r#"
{# output: base.rs #}
{# query: SELECT * FROM base #}
"#;
        let derived = r#"
{# output: derived.rs #}
{# when: base.rs #}
{# query: SELECT * FROM derived #}
"#;
        let plan = plan_for(vec![("base", base), ("derived", derived)]).unwrap();

        assert_eq!(plan.tasks.len(), 2);
        // The base task must be scheduled before the task that depends on it.
        assert!(position_of(&plan, "base") < position_of(&plan, "derived"));
    }

    #[test]
    fn test_circular_dependency_detection() {
        let first = r#"
{# output: task1.rs #}
{# when: task2.rs #}
{# query: SELECT * #}
"#;
        let second = r#"
{# output: task2.rs #}
{# when: task1.rs #}
{# query: SELECT * #}
"#;
        let err_msg = plan_for(vec![("task1", first), ("task2", second)])
            .unwrap_err()
            .to_string();
        assert!(
            err_msg.contains("Circular dependency") || err_msg.contains("cycle"),
            "Expected circular dependency error, got: {}",
            err_msg
        );
    }

    #[test]
    fn test_foreach_pattern() {
        let template = r#"
{# output: tests/{{entity}}_test.rs #}
{# query: SELECT name FROM entities #}
{# foreach: entity #}
"#;
        let plan = plan_for(vec![("test", template)]).unwrap();

        assert_eq!(plan.tasks.len(), 1);
        let task = &plan.tasks[0];
        assert_eq!(task.foreach.as_deref(), Some("entity"));
        assert_eq!(task.output_pattern, "tests/{{entity}}_test.rs");
    }

    #[test]
    fn test_once_pattern() {
        let template = r#"
{# output: single_file.rs #}
{# query: SELECT COUNT(*) #}
"#;
        let plan = plan_for(vec![("once", template)]).unwrap();

        assert_eq!(plan.tasks.len(), 1);
        let task = &plan.tasks[0];
        assert!(task.foreach.is_none());
        assert_eq!(task.output_pattern, "single_file.rs");
    }

    #[test]
    fn test_complex_dependency_graph() {
        let models = r#"
{# output: models.rs #}
{# query: SELECT * FROM schema #}
"#;
        let services = r#"
{# output: services.rs #}
{# when: models.rs #}
{# query: SELECT * FROM services #}
"#;
        let controllers = r#"
{# output: controllers.rs #}
{# when: services.rs #}
{# query: SELECT * FROM controllers #}
"#;
        let routes = r#"
{# output: routes.rs #}
{# when: controllers.rs #}
{# query: SELECT * FROM routes #}
"#;
        let plan = plan_for(vec![
            ("models", models),
            ("services", services),
            ("controllers", controllers),
            ("routes", routes),
        ])
        .unwrap();

        assert_eq!(plan.tasks.len(), 4);

        // The chain models -> services -> controllers -> routes must appear in order.
        let order: Vec<usize> = ["models", "services", "controllers", "routes"]
            .iter()
            .map(|name| position_of(&plan, name))
            .collect();
        assert!(order.windows(2).all(|pair| pair[0] < pair[1]));
    }
}