ggen_cli_lib/conventions/
planner.rs1use anyhow::{anyhow, Result};
4use std::path::{Path, PathBuf};
5use std::collections::{HashMap, HashSet};
6use std::fs;
7
8use super::ProjectConventions;
9
/// Directives declared in the `{# key: value #}` header lines of a template.
#[derive(Debug, Clone, PartialEq)]
pub struct TemplateMetadata {
    /// Value of the `output:` directive.
    pub output: String,
    /// Values of all `when:` directives, in the order they were written.
    pub when: Vec<String>,
    /// Value of the `query:` directive.
    pub query: String,
    /// Value of the `foreach:` directive, if one was given.
    pub foreach: Option<String>,
}
18
19impl TemplateMetadata {
20 pub fn parse(content: &str) -> Result<Self> {
22 let mut output = None;
23 let mut when = Vec::new();
24 let mut query = None;
25 let mut foreach = None;
26
27 for line in content.lines() {
29 let line = line.trim();
30
31 if line.starts_with("{#") && line.ends_with("#}") {
32 let inner = line[2..line.len() - 2].trim();
33
34 if let Some(value) = inner.strip_prefix("output:") {
35 output = Some(value.trim().to_string());
36 } else if let Some(value) = inner.strip_prefix("when:") {
37 when.push(value.trim().to_string());
38 } else if let Some(value) = inner.strip_prefix("query:") {
39 query = Some(value.trim().to_string());
40 } else if let Some(value) = inner.strip_prefix("foreach:") {
41 foreach = Some(value.trim().to_string());
42 }
43 }
44 }
45
46 Ok(TemplateMetadata {
47 output: output.ok_or_else(|| anyhow!("Missing 'output' directive"))?,
48 when,
49 query: query.ok_or_else(|| anyhow!("Missing 'query' directive"))?,
50 foreach,
51 })
52 }
53}
54
/// One unit of generation work: a template plus the settings read from it.
#[derive(Debug, Clone)]
pub struct GenerationTask {
    /// Name of the template this task belongs to.
    pub template: String,
    /// Output path pattern, taken from the template's `output` directive.
    pub output_pattern: String,
    /// Paths built from the template's `when` directives.
    pub trigger_files: Vec<PathBuf>,
    /// Query text associated with the task, if any.
    pub query: Option<String>,
    /// Value of the template's `foreach` directive, if any.
    pub foreach: Option<String>,
}
64
65#[derive(Debug)]
67pub struct GenerationPlan {
68 pub tasks: Vec<GenerationTask>,
69}
70
71pub struct GenerationPlanner {
73 conventions: ProjectConventions,
74}
75
76impl GenerationPlanner {
77 pub fn new(conventions: ProjectConventions) -> Self {
79 Self { conventions }
80 }
81
82 pub fn plan(&self) -> Result<GenerationPlan> {
84 let mut tasks = Vec::new();
85 let mut task_graph: HashMap<String, Vec<String>> = HashMap::new();
86
87 for (template_name, template_path) in &self.conventions.templates {
89 let metadata = self.parse_template_metadata(template_path)?;
90 let trigger_files = self.resolve_dependencies(&metadata);
91
92 let deps: Vec<String> = trigger_files
94 .iter()
95 .filter_map(|p| {
96 p.file_stem()
97 .and_then(|s| s.to_str())
98 .map(|s| s.to_string())
99 })
100 .collect();
101
102 task_graph.insert(template_name.clone(), deps);
103
104 tasks.push(GenerationTask {
105 template: template_name.clone(),
106 output_pattern: metadata.output,
107 trigger_files,
108 query: Some(metadata.query),
109 foreach: metadata.foreach,
110 });
111 }
112
113 self.check_circular_dependencies(&task_graph)?;
115
116 tasks = self.topological_sort(tasks)?;
118
119 Ok(GenerationPlan { tasks })
120 }
121
122 fn parse_template_metadata(&self, path: &Path) -> Result<TemplateMetadata> {
124 let content = fs::read_to_string(path)?;
125 TemplateMetadata::parse(&content)
126 }
127
128 fn resolve_dependencies(&self, metadata: &TemplateMetadata) -> Vec<PathBuf> {
130 metadata
131 .when
132 .iter()
133 .map(|pattern| {
134 PathBuf::from(pattern)
137 })
138 .collect()
139 }
140
141 fn check_circular_dependencies(&self, graph: &HashMap<String, Vec<String>>) -> Result<()> {
143 for (task, _) in graph {
144 let mut visited = HashSet::new();
145 let mut rec_stack = HashSet::new();
146
147 if self.has_cycle(task, graph, &mut visited, &mut rec_stack) {
148 return Err(anyhow!("Circular dependency detected involving task: {}", task));
149 }
150 }
151
152 Ok(())
153 }
154
155 fn has_cycle(
157 &self,
158 task: &str,
159 graph: &HashMap<String, Vec<String>>,
160 visited: &mut HashSet<String>,
161 rec_stack: &mut HashSet<String>,
162 ) -> bool {
163 if rec_stack.contains(task) {
164 return true;
165 }
166
167 if visited.contains(task) {
168 return false;
169 }
170
171 visited.insert(task.to_string());
172 rec_stack.insert(task.to_string());
173
174 if let Some(deps) = graph.get(task) {
175 for dep in deps {
176 if self.has_cycle(dep, graph, visited, rec_stack) {
177 return true;
178 }
179 }
180 }
181
182 rec_stack.remove(task);
183 false
184 }
185
186 fn topological_sort(&self, mut tasks: Vec<GenerationTask>) -> Result<Vec<GenerationTask>> {
188 let mut dep_count: HashMap<String, usize> = HashMap::new();
190 let mut graph: HashMap<String, Vec<String>> = HashMap::new();
191
192 for task in &tasks {
193 dep_count.entry(task.template.clone()).or_insert(0);
194
195 for trigger in &task.trigger_files {
196 if let Some(dep) = trigger.file_stem().and_then(|s| s.to_str()) {
197 graph
198 .entry(dep.to_string())
199 .or_default()
200 .push(task.template.clone());
201 *dep_count.entry(task.template.clone()).or_insert(0) += 1;
202 }
203 }
204 }
205
206 let mut ready: Vec<String> = dep_count
208 .iter()
209 .filter(|(_, &count)| count == 0)
210 .map(|(name, _)| name.clone())
211 .collect();
212
213 let mut sorted = Vec::new();
214
215 while let Some(task_name) = ready.pop() {
216 if let Some(pos) = tasks.iter().position(|t| t.template == task_name) {
218 let task = tasks.remove(pos);
219 sorted.push(task);
220 }
221
222 if let Some(dependents) = graph.get(&task_name) {
224 for dependent in dependents {
225 if let Some(count) = dep_count.get_mut(dependent) {
226 *count = count.saturating_sub(1);
227 if *count == 0 {
228 ready.push(dependent.clone());
229 }
230 }
231 }
232 }
233 }
234
235 if !tasks.is_empty() {
237 return Err(anyhow!("Circular dependency detected during topological sort"));
238 }
239
240 Ok(sorted)
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Writes `content` to `dir/name` and returns the path of the new file.
    fn create_test_template(dir: &Path, name: &str, content: &str) -> PathBuf {
        let path = dir.join(name);
        fs::write(&path, content).unwrap();
        path
    }

    /// Builds conventions whose templates live in `temp_dir/templates`.
    ///
    /// Every template is stored as `<name>.tmpl` and keyed by its name
    /// without the `.tmpl` suffix.
    fn create_test_conventions(temp_dir: &Path, templates: Vec<(&str, &str)>) -> ProjectConventions {
        let templates_dir = temp_dir.join("templates");
        fs::create_dir_all(&templates_dir).unwrap();

        let templates: HashMap<_, _> = templates
            .into_iter()
            .map(|(name, content)| {
                let key = name.strip_suffix(".tmpl").unwrap_or(name);
                let file_name = format!("{}.tmpl", key);
                let path = create_test_template(&templates_dir, &file_name, content);
                (key.to_string(), path)
            })
            .collect();

        ProjectConventions {
            rdf_files: vec![],
            rdf_dir: temp_dir.join("domain"),
            templates,
            templates_dir,
            queries: HashMap::new(),
            output_dir: temp_dir.join("generated"),
            preset: "test".to_string(),
        }
    }

    /// Writes `templates` into a fresh temporary directory and plans them.
    fn plan_for(templates: Vec<(&str, &str)>) -> Result<GenerationPlan> {
        let temp_dir = TempDir::new().unwrap();
        let conventions = create_test_conventions(temp_dir.path(), templates);
        GenerationPlanner::new(conventions).plan()
    }

    /// Position of the task for `template` within `plan`.
    fn index_of(plan: &GenerationPlan, template: &str) -> usize {
        plan.tasks
            .iter()
            .position(|task| task.template == template)
            .unwrap()
    }

    #[test]
    fn test_parse_template_metadata_basic() {
        let content = r#"
{# output: src/{{name}}.rs #}
{# when: src/models/*.rs #}
{# query: SELECT * FROM models #}

template content here
"#;

        let expected = TemplateMetadata {
            output: "src/{{name}}.rs".to_string(),
            when: vec!["src/models/*.rs".to_string()],
            query: "SELECT * FROM models".to_string(),
            foreach: None,
        };

        assert_eq!(TemplateMetadata::parse(content).unwrap(), expected);
    }

    #[test]
    fn test_parse_template_metadata_with_foreach() {
        let content = r#"
{# output: tests/{{item}}_test.rs #}
{# when: src/{{item}}.rs #}
{# query: SELECT name FROM entities #}
{# foreach: entity #}

test template
"#;

        let parsed = TemplateMetadata::parse(content).unwrap();
        assert_eq!(parsed.output, "tests/{{item}}_test.rs");
        assert_eq!(parsed.foreach.as_deref(), Some("entity"));
    }

    #[test]
    fn test_parse_template_metadata_multiple_when() {
        let content = r#"
{# output: generated.rs #}
{# when: file1.rs #}
{# when: file2.rs #}
{# query: SELECT * #}
"#;

        let parsed = TemplateMetadata::parse(content).unwrap();
        assert_eq!(parsed.when.len(), 2);
        for expected in ["file1.rs", "file2.rs"].iter() {
            assert!(parsed.when.iter().any(|pattern| pattern == expected));
        }
    }

    #[test]
    fn test_parse_template_metadata_missing_output() {
        let content = r#"
{# query: SELECT * #}
"#;

        let err = TemplateMetadata::parse(content).unwrap_err();
        assert!(err.to_string().contains("Missing 'output'"));
    }

    #[test]
    fn test_parse_template_metadata_missing_query() {
        let content = r#"
{# output: file.rs #}
"#;

        let err = TemplateMetadata::parse(content).unwrap_err();
        assert!(err.to_string().contains("Missing 'query'"));
    }

    #[test]
    fn test_generation_planner_empty() {
        let plan = plan_for(vec![]).unwrap();

        assert_eq!(plan.tasks.len(), 0);
    }

    #[test]
    fn test_generation_planner_single_task() {
        let plan = plan_for(vec![(
            "test",
            r#"
{# output: generated.rs #}
{# query: SELECT * #}

content
"#,
        )])
        .unwrap();

        assert_eq!(plan.tasks.len(), 1);
        let only = &plan.tasks[0];
        assert_eq!(only.template, "test");
        assert_eq!(only.output_pattern, "generated.rs");
    }

    #[test]
    fn test_generation_planner_multiple_tasks() {
        let plan = plan_for(vec![
            (
                "task1",
                r#"
{# output: out1.rs #}
{# query: SELECT * FROM table1 #}
"#,
            ),
            (
                "task2",
                r#"
{# output: out2.rs #}
{# query: SELECT * FROM table2 #}
"#,
            ),
        ])
        .unwrap();

        assert_eq!(plan.tasks.len(), 2);
    }

    #[test]
    fn test_generation_planner_with_dependencies() {
        let plan = plan_for(vec![
            (
                "base",
                r#"
{# output: base.rs #}
{# query: SELECT * FROM base #}
"#,
            ),
            (
                "derived",
                r#"
{# output: derived.rs #}
{# when: base.rs #}
{# query: SELECT * FROM derived #}
"#,
            ),
        ])
        .unwrap();

        assert_eq!(plan.tasks.len(), 2);

        // The task that is depended upon must be scheduled first.
        assert!(index_of(&plan, "base") < index_of(&plan, "derived"));
    }

    #[test]
    fn test_circular_dependency_detection() {
        let result = plan_for(vec![
            (
                "task1",
                r#"
{# output: task1.rs #}
{# when: task2.rs #}
{# query: SELECT * #}
"#,
            ),
            (
                "task2",
                r#"
{# output: task2.rs #}
{# when: task1.rs #}
{# query: SELECT * #}
"#,
            ),
        ]);

        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Circular dependency") || err_msg.contains("cycle"),
            "Expected circular dependency error, got: {}",
            err_msg
        );
    }

    #[test]
    fn test_foreach_pattern() {
        let plan = plan_for(vec![(
            "test",
            r#"
{# output: tests/{{entity}}_test.rs #}
{# query: SELECT name FROM entities #}
{# foreach: entity #}
"#,
        )])
        .unwrap();

        assert_eq!(plan.tasks.len(), 1);
        let only = &plan.tasks[0];
        assert_eq!(only.foreach, Some("entity".to_string()));
        assert_eq!(only.output_pattern, "tests/{{entity}}_test.rs");
    }

    #[test]
    fn test_once_pattern() {
        let plan = plan_for(vec![(
            "once",
            r#"
{# output: single_file.rs #}
{# query: SELECT COUNT(*) #}
"#,
        )])
        .unwrap();

        assert_eq!(plan.tasks.len(), 1);
        let only = &plan.tasks[0];
        assert_eq!(only.foreach, None);
        assert_eq!(only.output_pattern, "single_file.rs");
    }

    #[test]
    fn test_complex_dependency_graph() {
        let plan = plan_for(vec![
            (
                "models",
                r#"
{# output: models.rs #}
{# query: SELECT * FROM schema #}
"#,
            ),
            (
                "services",
                r#"
{# output: services.rs #}
{# when: models.rs #}
{# query: SELECT * FROM services #}
"#,
            ),
            (
                "controllers",
                r#"
{# output: controllers.rs #}
{# when: services.rs #}
{# query: SELECT * FROM controllers #}
"#,
            ),
            (
                "routes",
                r#"
{# output: routes.rs #}
{# when: controllers.rs #}
{# query: SELECT * FROM routes #}
"#,
            ),
        ])
        .unwrap();

        assert_eq!(plan.tasks.len(), 4);

        // Each layer must come after the layer it is triggered by.
        let models_idx = index_of(&plan, "models");
        let services_idx = index_of(&plan, "services");
        let controllers_idx = index_of(&plan, "controllers");
        let routes_idx = index_of(&plan, "routes");

        assert!(models_idx < services_idx);
        assert!(services_idx < controllers_idx);
        assert!(controllers_idx < routes_idx);
    }
}