// ggen_cli_lib/conventions/planner.rs

1//! Generation planner for creating task execution plans
2
3use ggen_utils::error::Result;
4use std::collections::{HashMap, HashSet};
5use std::fs;
6use std::path::{Path, PathBuf};
7
8use super::ProjectConventions;
9
10/// Metadata extracted from template comments
11#[derive(Debug, Clone, PartialEq)]
12pub struct TemplateMetadata {
13    pub output: String,
14    pub when: Vec<String>,
15    pub query: String,
16    pub foreach: Option<String>,
17}
18
19impl TemplateMetadata {
20    /// Parse template metadata from file content
21    pub fn parse(content: &str) -> Result<Self> {
22        let mut output = None;
23        let mut when = Vec::new();
24        let mut query = None;
25        let mut foreach = None;
26
27        // Parse {# ... #} style comments
28        for line in content.lines() {
29            let line = line.trim();
30
31            if line.starts_with("{#") && line.ends_with("#}") {
32                let inner = line[2..line.len() - 2].trim();
33
34                if let Some(value) = inner.strip_prefix("output:") {
35                    output = Some(value.trim().to_string());
36                } else if let Some(value) = inner.strip_prefix("when:") {
37                    when.push(value.trim().to_string());
38                } else if let Some(value) = inner.strip_prefix("query:") {
39                    query = Some(value.trim().to_string());
40                } else if let Some(value) = inner.strip_prefix("foreach:") {
41                    foreach = Some(value.trim().to_string());
42                }
43            }
44        }
45
46        Ok(TemplateMetadata {
47            output: output
48                .ok_or_else(|| ggen_utils::error::Error::new("Missing 'output' directive"))?,
49            when,
50            query: query
51                .ok_or_else(|| ggen_utils::error::Error::new("Missing 'query' directive"))?,
52            foreach,
53        })
54    }
55}
56
/// One unit of work in a [`GenerationPlan`]: a template together with the
/// directive values that drive its rendering.
#[derive(Clone, Debug)]
pub struct GenerationTask {
    /// Name of the template (as filled in by the planner: its key in the
    /// project's template map).
    pub template: String,
    /// Output path pattern, taken from the template's `output:` directive.
    pub output_pattern: String,
    /// Paths built from the template's `when:` directives.
    pub trigger_files: Vec<PathBuf>,
    /// Query from the template's `query:` directive, when present.
    pub query: Option<String>,
    /// Value of the template's `foreach:` directive, when present.
    pub foreach: Option<String>,
}

/// The full set of generation tasks produced by the planner.
#[derive(Debug)]
pub struct GenerationPlan {
    /// Tasks in execution order: the planner sorts them so that a task comes
    /// after the tasks it depends on.
    pub tasks: Vec<GenerationTask>,
}
72
73/// Plans code generation based on conventions and templates
74pub struct GenerationPlanner {
75    conventions: ProjectConventions,
76}
77
78impl GenerationPlanner {
79    /// Create a new generation planner
80    pub fn new(conventions: ProjectConventions) -> Self {
81        Self { conventions }
82    }
83
84    /// Create a generation plan by analyzing all templates
85    pub fn plan(&self) -> Result<GenerationPlan> {
86        let mut tasks = Vec::new();
87        let mut task_graph: HashMap<String, Vec<String>> = HashMap::new();
88
89        // Iterate through all discovered templates
90        for (template_name, template_path) in &self.conventions.templates {
91            let metadata = self.parse_template_metadata(template_path)?;
92            let trigger_files = self.resolve_dependencies(&metadata);
93
94            // Track dependencies for circular detection
95            let deps: Vec<String> = trigger_files
96                .iter()
97                .filter_map(|p| {
98                    p.file_stem()
99                        .and_then(|s| s.to_str())
100                        .map(|s| s.to_string())
101                })
102                .collect();
103
104            task_graph.insert(template_name.clone(), deps);
105
106            tasks.push(GenerationTask {
107                template: template_name.clone(),
108                output_pattern: metadata.output,
109                trigger_files,
110                query: Some(metadata.query),
111                foreach: metadata.foreach,
112            });
113        }
114
115        // Check for circular dependencies
116        self.check_circular_dependencies(&task_graph)?;
117
118        // Sort tasks by dependencies (topological sort)
119        tasks = self.topological_sort(tasks)?;
120
121        Ok(GenerationPlan { tasks })
122    }
123
124    /// Parse template metadata from a template file
125    fn parse_template_metadata(&self, path: &Path) -> Result<TemplateMetadata> {
126        let content = fs::read_to_string(path)?;
127        TemplateMetadata::parse(&content)
128    }
129
130    /// Resolve file dependencies from template metadata
131    fn resolve_dependencies(&self, metadata: &TemplateMetadata) -> Vec<PathBuf> {
132        metadata
133            .when
134            .iter()
135            .map(|pattern| {
136                // Simple glob pattern resolution
137                // In a real implementation, this would use proper glob matching
138                PathBuf::from(pattern)
139            })
140            .collect()
141    }
142
143    /// Check for circular dependencies in the task graph
144    fn check_circular_dependencies(&self, graph: &HashMap<String, Vec<String>>) -> Result<()> {
145        for task in graph.keys() {
146            let mut visited = HashSet::new();
147            let mut rec_stack = HashSet::new();
148
149            if self.has_cycle(task, graph, &mut visited, &mut rec_stack) {
150                return Err(ggen_utils::error::Error::new(&format!(
151                    "Circular dependency detected involving task: {}",
152                    task
153                )));
154            }
155        }
156
157        Ok(())
158    }
159
160    /// DFS-based cycle detection
161    #[allow(clippy::only_used_in_recursion)] // Parameter used in recursive calls
162    fn has_cycle(
163        &self, task: &str, graph: &HashMap<String, Vec<String>>, visited: &mut HashSet<String>,
164        rec_stack: &mut HashSet<String>,
165    ) -> bool {
166        if rec_stack.contains(task) {
167            return true;
168        }
169
170        if visited.contains(task) {
171            return false;
172        }
173
174        visited.insert(task.to_string());
175        rec_stack.insert(task.to_string());
176
177        if let Some(deps) = graph.get(task) {
178            for dep in deps {
179                if self.has_cycle(dep, graph, visited, rec_stack) {
180                    return true;
181                }
182            }
183        }
184
185        rec_stack.remove(task);
186        false
187    }
188
189    /// Topologically sort tasks by dependencies
190    fn topological_sort(&self, mut tasks: Vec<GenerationTask>) -> Result<Vec<GenerationTask>> {
191        // Build dependency map
192        let mut dep_count: HashMap<String, usize> = HashMap::new();
193        let mut graph: HashMap<String, Vec<String>> = HashMap::new();
194
195        for task in &tasks {
196            dep_count.entry(task.template.clone()).or_insert(0);
197
198            for trigger in &task.trigger_files {
199                if let Some(dep) = trigger.file_stem().and_then(|s| s.to_str()) {
200                    graph
201                        .entry(dep.to_string())
202                        .or_default()
203                        .push(task.template.clone());
204                    *dep_count.entry(task.template.clone()).or_insert(0) += 1;
205                }
206            }
207        }
208
209        // Find tasks with no dependencies
210        let mut ready: Vec<String> = dep_count
211            .iter()
212            .filter(|(_, &count)| count == 0)
213            .map(|(name, _)| name.clone())
214            .collect();
215
216        let mut sorted = Vec::new();
217
218        while let Some(task_name) = ready.pop() {
219            // Find and add the task
220            if let Some(pos) = tasks.iter().position(|t| t.template == task_name) {
221                let task = tasks.remove(pos);
222                sorted.push(task);
223            }
224
225            // Update dependent tasks
226            if let Some(dependents) = graph.get(&task_name) {
227                for dependent in dependents {
228                    if let Some(count) = dep_count.get_mut(dependent) {
229                        *count = count.saturating_sub(1);
230                        if *count == 0 {
231                            ready.push(dependent.clone());
232                        }
233                    }
234                }
235            }
236        }
237
238        // If there are remaining tasks, there's a cycle
239        if !tasks.is_empty() {
240            return Err(ggen_utils::error::Error::new(
241                "Circular dependency detected during topological sort",
242            ));
243        }
244
245        Ok(sorted)
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Writes `content` to `dir/name` and returns the path of the new file.
    fn create_test_template(dir: &Path, name: &str, content: &str) -> PathBuf {
        let target = dir.join(name);
        fs::write(&target, content).unwrap();
        target
    }

    /// Builds conventions rooted at `root`. Each `(name, content)` pair becomes
    /// a `<key>.tmpl` file under `root/templates`, where `key` is `name` without
    /// a trailing `.tmpl`, and is registered in the template map under `key`.
    fn create_test_conventions(root: &Path, templates: Vec<(&str, &str)>) -> ProjectConventions {
        let templates_dir = root.join("templates");
        fs::create_dir_all(&templates_dir).unwrap();

        let template_map: HashMap<_, _> = templates
            .into_iter()
            .map(|(name, content)| {
                let key = name.strip_suffix(".tmpl").unwrap_or(name);
                let file_name = format!("{}.tmpl", key);
                let written = create_test_template(&templates_dir, &file_name, content);
                (key.to_string(), written)
            })
            .collect();

        ProjectConventions {
            rdf_files: vec![],
            rdf_dir: root.join("domain"),
            templates: template_map,
            templates_dir,
            queries: HashMap::new(),
            output_dir: root.join("generated"),
            preset: "test".to_string(),
        }
    }

    /// Plans `templates` inside a throwaway project directory, expecting success.
    fn plan_templates(templates: Vec<(&str, &str)>) -> GenerationPlan {
        let project = TempDir::new().unwrap();
        let conventions = create_test_conventions(project.path(), templates);
        GenerationPlanner::new(conventions).plan().unwrap()
    }

    /// Position of the task for `template` in `plan`; panics if it is absent.
    fn index_of(plan: &GenerationPlan, template: &str) -> usize {
        plan.tasks
            .iter()
            .position(|task| task.template == template)
            .unwrap()
    }

    #[test]
    fn test_parse_template_metadata_basic() {
        let content = "\n{# output: src/{{name}}.rs #}\n{# when: src/models/*.rs #}\n{# query: SELECT * FROM models #}\n\ntemplate content here\n";

        let expected = TemplateMetadata {
            output: "src/{{name}}.rs".to_string(),
            when: vec!["src/models/*.rs".to_string()],
            query: "SELECT * FROM models".to_string(),
            foreach: None,
        };
        assert_eq!(TemplateMetadata::parse(content).unwrap(), expected);
    }

    #[test]
    fn test_parse_template_metadata_with_foreach() {
        let content = "\n{# output: tests/{{item}}_test.rs #}\n{# when: src/{{item}}.rs #}\n{# query: SELECT name FROM entities #}\n{# foreach: entity #}\n\ntest template\n";

        let metadata = TemplateMetadata::parse(content).unwrap();
        assert_eq!(
            (metadata.output.as_str(), metadata.foreach.as_deref()),
            ("tests/{{item}}_test.rs", Some("entity"))
        );
    }

    #[test]
    fn test_parse_template_metadata_multiple_when() {
        let content =
            "\n{# output: generated.rs #}\n{# when: file1.rs #}\n{# when: file2.rs #}\n{# query: SELECT * #}\n";

        let when = TemplateMetadata::parse(content).unwrap().when;
        assert_eq!(when.len(), 2);
        for expected in ["file1.rs", "file2.rs"].iter() {
            assert!(when.iter().any(|w| w.as_str() == *expected));
        }
    }

    #[test]
    fn test_parse_template_metadata_missing_output() {
        let message = TemplateMetadata::parse("\n{# query: SELECT * #}\n")
            .unwrap_err()
            .to_string();
        assert!(message.contains("Missing 'output'"));
    }

    #[test]
    fn test_parse_template_metadata_missing_query() {
        let message = TemplateMetadata::parse("\n{# output: file.rs #}\n")
            .unwrap_err()
            .to_string();
        assert!(message.contains("Missing 'query'"));
    }

    #[test]
    fn test_generation_planner_empty() {
        assert!(plan_templates(vec![]).tasks.is_empty());
    }

    #[test]
    fn test_generation_planner_single_task() {
        let plan = plan_templates(vec![(
            "test",
            "\n{# output: generated.rs #}\n{# query: SELECT * #}\n\ncontent\n",
        )]);

        assert_eq!(plan.tasks.len(), 1);
        let task = &plan.tasks[0];
        assert_eq!(task.template, "test");
        assert_eq!(task.output_pattern, "generated.rs");
    }

    #[test]
    fn test_generation_planner_multiple_tasks() {
        let plan = plan_templates(vec![
            ("task1", "\n{# output: out1.rs #}\n{# query: SELECT * FROM table1 #}\n"),
            ("task2", "\n{# output: out2.rs #}\n{# query: SELECT * FROM table2 #}\n"),
        ]);

        assert_eq!(plan.tasks.len(), 2);
    }

    #[test]
    fn test_generation_planner_with_dependencies() {
        let plan = plan_templates(vec![
            ("base", "\n{# output: base.rs #}\n{# query: SELECT * FROM base #}\n"),
            (
                "derived",
                "\n{# output: derived.rs #}\n{# when: base.rs #}\n{# query: SELECT * FROM derived #}\n",
            ),
        ]);

        assert_eq!(plan.tasks.len(), 2);
        // The base template must be planned before the template it triggers.
        assert!(index_of(&plan, "base") < index_of(&plan, "derived"));
    }

    #[test]
    fn test_circular_dependency_detection() {
        let project = TempDir::new().unwrap();
        let conventions = create_test_conventions(
            project.path(),
            vec![
                ("task1", "\n{# output: task1.rs #}\n{# when: task2.rs #}\n{# query: SELECT * #}\n"),
                ("task2", "\n{# output: task2.rs #}\n{# when: task1.rs #}\n{# query: SELECT * #}\n"),
            ],
        );

        let err_msg = GenerationPlanner::new(conventions)
            .plan()
            .unwrap_err()
            .to_string();
        assert!(
            err_msg.contains("Circular dependency") || err_msg.contains("cycle"),
            "Expected circular dependency error, got: {}",
            err_msg
        );
    }

    #[test]
    fn test_foreach_pattern() {
        let plan = plan_templates(vec![(
            "test",
            "\n{# output: tests/{{entity}}_test.rs #}\n{# query: SELECT name FROM entities #}\n{# foreach: entity #}\n",
        )]);

        assert_eq!(plan.tasks.len(), 1);
        let task = &plan.tasks[0];
        assert_eq!(task.foreach.as_deref(), Some("entity"));
        assert_eq!(task.output_pattern, "tests/{{entity}}_test.rs");
    }

    #[test]
    fn test_once_pattern() {
        let plan = plan_templates(vec![(
            "once",
            "\n{# output: single_file.rs #}\n{# query: SELECT COUNT(*) #}\n",
        )]);

        assert_eq!(plan.tasks.len(), 1);
        let task = &plan.tasks[0];
        assert!(task.foreach.is_none());
        assert_eq!(task.output_pattern, "single_file.rs");
    }

    #[test]
    fn test_complex_dependency_graph() {
        let plan = plan_templates(vec![
            ("models", "\n{# output: models.rs #}\n{# query: SELECT * FROM schema #}\n"),
            (
                "services",
                "\n{# output: services.rs #}\n{# when: models.rs #}\n{# query: SELECT * FROM services #}\n",
            ),
            (
                "controllers",
                "\n{# output: controllers.rs #}\n{# when: services.rs #}\n{# query: SELECT * FROM controllers #}\n",
            ),
            (
                "routes",
                "\n{# output: routes.rs #}\n{# when: controllers.rs #}\n{# query: SELECT * FROM routes #}\n",
            ),
        ]);

        assert_eq!(plan.tasks.len(), 4);

        // Each layer must come strictly after the one it is triggered by.
        let positions: Vec<usize> = ["models", "services", "controllers", "routes"]
            .iter()
            .map(|name| index_of(&plan, name))
            .collect();
        assert!(positions.windows(2).all(|pair| pair[0] < pair[1]));
    }
}