// ggen_cli_lib/conventions/planner.rs

//! Generation planner: reads template metadata and builds an ordered task execution plan.
2
use anyhow::{anyhow, Result};
use std::collections::{HashMap, HashSet, VecDeque};
use std::fs;
use std::path::{Path, PathBuf};
7
8use super::ProjectConventions;
9
/// Metadata extracted from template comments.
///
/// Produced by [`TemplateMetadata::parse`] from single-line `{# key: value #}`
/// directives found in a template file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TemplateMetadata {
    /// Output path pattern, from the required `output:` directive.
    pub output: String,
    /// Trigger file patterns, one per `when:` directive, in file order.
    pub when: Vec<String>,
    /// Query text, from the required `query:` directive.
    pub query: String,
    /// Iteration variable, from the optional `foreach:` directive.
    pub foreach: Option<String>,
}
18
19impl TemplateMetadata {
20    /// Parse template metadata from file content
21    pub fn parse(content: &str) -> Result<Self> {
22        let mut output = None;
23        let mut when = Vec::new();
24        let mut query = None;
25        let mut foreach = None;
26
27        // Parse {# ... #} style comments
28        for line in content.lines() {
29            let line = line.trim();
30
31            if line.starts_with("{#") && line.ends_with("#}") {
32                let inner = line[2..line.len() - 2].trim();
33
34                if let Some(value) = inner.strip_prefix("output:") {
35                    output = Some(value.trim().to_string());
36                } else if let Some(value) = inner.strip_prefix("when:") {
37                    when.push(value.trim().to_string());
38                } else if let Some(value) = inner.strip_prefix("query:") {
39                    query = Some(value.trim().to_string());
40                } else if let Some(value) = inner.strip_prefix("foreach:") {
41                    foreach = Some(value.trim().to_string());
42                }
43            }
44        }
45
46        Ok(TemplateMetadata {
47            output: output.ok_or_else(|| anyhow!("Missing 'output' directive"))?,
48            when,
49            query: query.ok_or_else(|| anyhow!("Missing 'query' directive"))?,
50            foreach,
51        })
52    }
53}
54
/// A single generation task: one template together with what the planner
/// extracted from its metadata.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GenerationTask {
    /// Name of the template this task renders.
    pub template: String,
    /// Output path pattern, from the template's `output:` directive.
    pub output_pattern: String,
    /// Trigger files, from the template's `when:` directives.
    pub trigger_files: Vec<PathBuf>,
    /// Query to run for this task (the template's `query:` directive).
    pub query: Option<String>,
    /// Iteration variable from the template's `foreach:` directive, if any.
    pub foreach: Option<String>,
}
64
65/// Complete generation plan with all tasks
66#[derive(Debug)]
67pub struct GenerationPlan {
68    pub tasks: Vec<GenerationTask>,
69}
70
/// Plans code generation based on conventions and templates.
///
/// Wraps the discovered [`ProjectConventions`]; [`GenerationPlanner::plan`]
/// reads every template listed in them.
pub struct GenerationPlanner {
    /// Project layout; `plan` iterates its `templates` map (template name ->
    /// template file path).
    conventions: ProjectConventions,
}
75
76impl GenerationPlanner {
77    /// Create a new generation planner
78    pub fn new(conventions: ProjectConventions) -> Self {
79        Self { conventions }
80    }
81
82    /// Create a generation plan by analyzing all templates
83    pub fn plan(&self) -> Result<GenerationPlan> {
84        let mut tasks = Vec::new();
85        let mut task_graph: HashMap<String, Vec<String>> = HashMap::new();
86
87        // Iterate through all discovered templates
88        for (template_name, template_path) in &self.conventions.templates {
89            let metadata = self.parse_template_metadata(template_path)?;
90            let trigger_files = self.resolve_dependencies(&metadata);
91
92            // Track dependencies for circular detection
93            let deps: Vec<String> = trigger_files
94                .iter()
95                .filter_map(|p| {
96                    p.file_stem()
97                        .and_then(|s| s.to_str())
98                        .map(|s| s.to_string())
99                })
100                .collect();
101
102            task_graph.insert(template_name.clone(), deps);
103
104            tasks.push(GenerationTask {
105                template: template_name.clone(),
106                output_pattern: metadata.output,
107                trigger_files,
108                query: Some(metadata.query),
109                foreach: metadata.foreach,
110            });
111        }
112
113        // Check for circular dependencies
114        self.check_circular_dependencies(&task_graph)?;
115
116        // Sort tasks by dependencies (topological sort)
117        tasks = self.topological_sort(tasks)?;
118
119        Ok(GenerationPlan { tasks })
120    }
121
122    /// Parse template metadata from a template file
123    fn parse_template_metadata(&self, path: &Path) -> Result<TemplateMetadata> {
124        let content = fs::read_to_string(path)?;
125        TemplateMetadata::parse(&content)
126    }
127
128    /// Resolve file dependencies from template metadata
129    fn resolve_dependencies(&self, metadata: &TemplateMetadata) -> Vec<PathBuf> {
130        metadata
131            .when
132            .iter()
133            .map(|pattern| {
134                // Simple glob pattern resolution
135                // In a real implementation, this would use proper glob matching
136                PathBuf::from(pattern)
137            })
138            .collect()
139    }
140
141    /// Check for circular dependencies in the task graph
142    fn check_circular_dependencies(&self, graph: &HashMap<String, Vec<String>>) -> Result<()> {
143        for (task, _) in graph {
144            let mut visited = HashSet::new();
145            let mut rec_stack = HashSet::new();
146
147            if self.has_cycle(task, graph, &mut visited, &mut rec_stack) {
148                return Err(anyhow!("Circular dependency detected involving task: {}", task));
149            }
150        }
151
152        Ok(())
153    }
154
155    /// DFS-based cycle detection
156    fn has_cycle(
157        &self,
158        task: &str,
159        graph: &HashMap<String, Vec<String>>,
160        visited: &mut HashSet<String>,
161        rec_stack: &mut HashSet<String>,
162    ) -> bool {
163        if rec_stack.contains(task) {
164            return true;
165        }
166
167        if visited.contains(task) {
168            return false;
169        }
170
171        visited.insert(task.to_string());
172        rec_stack.insert(task.to_string());
173
174        if let Some(deps) = graph.get(task) {
175            for dep in deps {
176                if self.has_cycle(dep, graph, visited, rec_stack) {
177                    return true;
178                }
179            }
180        }
181
182        rec_stack.remove(task);
183        false
184    }
185
186    /// Topologically sort tasks by dependencies
187    fn topological_sort(&self, mut tasks: Vec<GenerationTask>) -> Result<Vec<GenerationTask>> {
188        // Build dependency map
189        let mut dep_count: HashMap<String, usize> = HashMap::new();
190        let mut graph: HashMap<String, Vec<String>> = HashMap::new();
191
192        for task in &tasks {
193            dep_count.entry(task.template.clone()).or_insert(0);
194
195            for trigger in &task.trigger_files {
196                if let Some(dep) = trigger.file_stem().and_then(|s| s.to_str()) {
197                    graph
198                        .entry(dep.to_string())
199                        .or_default()
200                        .push(task.template.clone());
201                    *dep_count.entry(task.template.clone()).or_insert(0) += 1;
202                }
203            }
204        }
205
206        // Find tasks with no dependencies
207        let mut ready: Vec<String> = dep_count
208            .iter()
209            .filter(|(_, &count)| count == 0)
210            .map(|(name, _)| name.clone())
211            .collect();
212
213        let mut sorted = Vec::new();
214
215        while let Some(task_name) = ready.pop() {
216            // Find and add the task
217            if let Some(pos) = tasks.iter().position(|t| t.template == task_name) {
218                let task = tasks.remove(pos);
219                sorted.push(task);
220            }
221
222            // Update dependent tasks
223            if let Some(dependents) = graph.get(&task_name) {
224                for dependent in dependents {
225                    if let Some(count) = dep_count.get_mut(dependent) {
226                        *count = count.saturating_sub(1);
227                        if *count == 0 {
228                            ready.push(dependent.clone());
229                        }
230                    }
231                }
232            }
233        }
234
235        // If there are remaining tasks, there's a cycle
236        if !tasks.is_empty() {
237            return Err(anyhow!("Circular dependency detected during topological sort"));
238        }
239
240        Ok(sorted)
241    }
242}
243
// Unit tests for template metadata parsing and generation planning.
// Planner tests write real template files into a temporary directory.
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Write `content` to `dir/name` and return the path of the new file.
    fn create_test_template(dir: &Path, name: &str, content: &str) -> PathBuf {
        let path = dir.join(name);
        fs::write(&path, content).unwrap();
        path
    }

    /// Build `ProjectConventions` rooted at `temp_dir`, with one template per
    /// `(name, content)` pair.
    ///
    /// Each template is written to `temp_dir/templates/<name>.tmpl` (a `.tmpl`
    /// suffix already on `name` is not doubled). It is registered in
    /// `templates` under `name` without that suffix. All other fields get
    /// fixed test values.
    fn create_test_conventions(temp_dir: &Path, templates: Vec<(&str, &str)>) -> ProjectConventions {
        let template_dir = temp_dir.join("templates");
        fs::create_dir_all(&template_dir).unwrap();

        let mut template_map = HashMap::new();
        for (name, content) in templates {
            let full_name = if name.ends_with(".tmpl") {
                name.to_string()
            } else {
                format!("{}.tmpl", name)
            };
            let path = create_test_template(&template_dir, &full_name, content);
            // Map key is the bare template name, e.g. "base" for "base.tmpl".
            let key = name.strip_suffix(".tmpl").unwrap_or(name).to_string();
            template_map.insert(key, path);
        }

        ProjectConventions {
            rdf_files: vec![],
            rdf_dir: temp_dir.join("domain"),
            templates: template_map,
            templates_dir: template_dir,
            queries: HashMap::new(),
            output_dir: temp_dir.join("generated"),
            preset: "test".to_string(),
        }
    }

    // All three directive kinds parse. Values are trimmed, and non-directive
    // lines are ignored.
    #[test]
    fn test_parse_template_metadata_basic() {
        let content = r#"
{# output: src/{{name}}.rs #}
{# when: src/models/*.rs #}
{# query: SELECT * FROM models #}

template content here
"#;

        let metadata = TemplateMetadata::parse(content).unwrap();
        assert_eq!(metadata.output, "src/{{name}}.rs");
        assert_eq!(metadata.when, vec!["src/models/*.rs"]);
        assert_eq!(metadata.query, "SELECT * FROM models");
        assert_eq!(metadata.foreach, None);
    }

    // The optional `foreach` directive is captured when present.
    #[test]
    fn test_parse_template_metadata_with_foreach() {
        let content = r#"
{# output: tests/{{item}}_test.rs #}
{# when: src/{{item}}.rs #}
{# query: SELECT name FROM entities #}
{# foreach: entity #}

test template
"#;

        let metadata = TemplateMetadata::parse(content).unwrap();
        assert_eq!(metadata.output, "tests/{{item}}_test.rs");
        assert_eq!(metadata.foreach, Some("entity".to_string()));
    }

    // Repeated `when` directives accumulate rather than overwrite.
    #[test]
    fn test_parse_template_metadata_multiple_when() {
        let content = r#"
{# output: generated.rs #}
{# when: file1.rs #}
{# when: file2.rs #}
{# query: SELECT * #}
"#;

        let metadata = TemplateMetadata::parse(content).unwrap();
        assert_eq!(metadata.when.len(), 2);
        assert!(metadata.when.contains(&"file1.rs".to_string()));
        assert!(metadata.when.contains(&"file2.rs".to_string()));
    }

    // A template without `output` is rejected with a message naming the directive.
    #[test]
    fn test_parse_template_metadata_missing_output() {
        let content = r#"
{# query: SELECT * #}
"#;

        let result = TemplateMetadata::parse(content);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Missing 'output'"));
    }

    // A template without `query` is rejected with a message naming the directive.
    #[test]
    fn test_parse_template_metadata_missing_query() {
        let content = r#"
{# output: file.rs #}
"#;

        let result = TemplateMetadata::parse(content);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Missing 'query'"));
    }

    // No templates -> an empty plan, not an error.
    #[test]
    fn test_generation_planner_empty() {
        let temp_dir = TempDir::new().unwrap();
        let conventions = create_test_conventions(temp_dir.path(), vec![]);

        let planner = GenerationPlanner::new(conventions);
        let plan = planner.plan().unwrap();

        assert_eq!(plan.tasks.len(), 0);
    }

    // One template becomes one task carrying its name and output pattern.
    #[test]
    fn test_generation_planner_single_task() {
        let temp_dir = TempDir::new().unwrap();
        let conventions = create_test_conventions(
            temp_dir.path(),
            vec![(
                "test",
                r#"
{# output: generated.rs #}
{# query: SELECT * #}

content
"#,
            )],
        );

        let planner = GenerationPlanner::new(conventions);
        let plan = planner.plan().unwrap();

        assert_eq!(plan.tasks.len(), 1);
        assert_eq!(plan.tasks[0].template, "test");
        assert_eq!(plan.tasks[0].output_pattern, "generated.rs");
    }

    // Independent templates each produce a task; their order is not asserted.
    #[test]
    fn test_generation_planner_multiple_tasks() {
        let temp_dir = TempDir::new().unwrap();
        let conventions = create_test_conventions(
            temp_dir.path(),
            vec![
                (
                    "task1",
                    r#"
{# output: out1.rs #}
{# query: SELECT * FROM table1 #}
"#,
                ),
                (
                    "task2",
                    r#"
{# output: out2.rs #}
{# query: SELECT * FROM table2 #}
"#,
                ),
            ],
        );

        let planner = GenerationPlanner::new(conventions);
        let plan = planner.plan().unwrap();

        assert_eq!(plan.tasks.len(), 2);
    }

    // `derived` declares `when: base.rs`, so it must be scheduled after `base`.
    #[test]
    fn test_generation_planner_with_dependencies() {
        let temp_dir = TempDir::new().unwrap();
        let conventions = create_test_conventions(
            temp_dir.path(),
            vec![
                (
                    "base",
                    r#"
{# output: base.rs #}
{# query: SELECT * FROM base #}
"#,
                ),
                (
                    "derived",
                    r#"
{# output: derived.rs #}
{# when: base.rs #}
{# query: SELECT * FROM derived #}
"#,
                ),
            ],
        );

        let planner = GenerationPlanner::new(conventions);
        let plan = planner.plan().unwrap();

        assert_eq!(plan.tasks.len(), 2);

        // Base task should come before derived
        let base_idx = plan.tasks.iter().position(|t| t.template == "base").unwrap();
        let derived_idx = plan.tasks.iter().position(|t| t.template == "derived").unwrap();
        assert!(base_idx < derived_idx);
    }

    // task1 and task2 trigger on each other, so planning must fail with a cycle error.
    #[test]
    fn test_circular_dependency_detection() {
        let temp_dir = TempDir::new().unwrap();
        let conventions = create_test_conventions(
            temp_dir.path(),
            vec![
                (
                    "task1",
                    r#"
{# output: task1.rs #}
{# when: task2.rs #}
{# query: SELECT * #}
"#,
                ),
                (
                    "task2",
                    r#"
{# output: task2.rs #}
{# when: task1.rs #}
{# query: SELECT * #}
"#,
                ),
            ],
        );

        let planner = GenerationPlanner::new(conventions);
        let result = planner.plan();

        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Circular dependency") || err_msg.contains("cycle"),
            "Expected circular dependency error, got: {}",
            err_msg
        );
    }

    // The `foreach` variable and output pattern are carried into the task unchanged.
    #[test]
    fn test_foreach_pattern() {
        let temp_dir = TempDir::new().unwrap();
        let conventions = create_test_conventions(
            temp_dir.path(),
            vec![(
                "test",
                r#"
{# output: tests/{{entity}}_test.rs #}
{# query: SELECT name FROM entities #}
{# foreach: entity #}
"#,
            )],
        );

        let planner = GenerationPlanner::new(conventions);
        let plan = planner.plan().unwrap();

        assert_eq!(plan.tasks.len(), 1);
        assert_eq!(plan.tasks[0].foreach, Some("entity".to_string()));
        assert_eq!(plan.tasks[0].output_pattern, "tests/{{entity}}_test.rs");
    }

    // Without `foreach`, the task has `foreach == None`.
    #[test]
    fn test_once_pattern() {
        let temp_dir = TempDir::new().unwrap();
        let conventions = create_test_conventions(
            temp_dir.path(),
            vec![(
                "once",
                r#"
{# output: single_file.rs #}
{# query: SELECT COUNT(*) #}
"#,
            )],
        );

        let planner = GenerationPlanner::new(conventions);
        let plan = planner.plan().unwrap();

        assert_eq!(plan.tasks.len(), 1);
        assert_eq!(plan.tasks[0].foreach, None);
        assert_eq!(plan.tasks[0].output_pattern, "single_file.rs");
    }

    // Chain models -> services -> controllers -> routes must come out in that order.
    #[test]
    fn test_complex_dependency_graph() {
        let temp_dir = TempDir::new().unwrap();
        let conventions = create_test_conventions(
            temp_dir.path(),
            vec![
                (
                    "models",
                    r#"
{# output: models.rs #}
{# query: SELECT * FROM schema #}
"#,
                ),
                (
                    "services",
                    r#"
{# output: services.rs #}
{# when: models.rs #}
{# query: SELECT * FROM services #}
"#,
                ),
                (
                    "controllers",
                    r#"
{# output: controllers.rs #}
{# when: services.rs #}
{# query: SELECT * FROM controllers #}
"#,
                ),
                (
                    "routes",
                    r#"
{# output: routes.rs #}
{# when: controllers.rs #}
{# query: SELECT * FROM routes #}
"#,
                ),
            ],
        );

        let planner = GenerationPlanner::new(conventions);
        let plan = planner.plan().unwrap();

        assert_eq!(plan.tasks.len(), 4);

        // Verify ordering
        let names: Vec<_> = plan.tasks.iter().map(|t| t.template.as_str()).collect();
        let models_idx = names.iter().position(|&n| n == "models").unwrap();
        let services_idx = names.iter().position(|&n| n == "services").unwrap();
        let controllers_idx = names.iter().position(|&n| n == "controllers").unwrap();
        let routes_idx = names.iter().position(|&n| n == "routes").unwrap();

        assert!(models_idx < services_idx);
        assert!(services_idx < controllers_idx);
        assert!(controllers_idx < routes_idx);
    }
}