ggen_cli_lib/conventions/
planner.rs1use ggen_core::utils::error::Result;
4use std::collections::{HashMap, HashSet};
5use std::fs;
6use std::path::{Path, PathBuf};
7
8use super::ProjectConventions;
9
10#[allow(clippy::derive_partial_eq_without_eq)]
13#[derive(Debug, Clone, PartialEq)]
14pub struct TemplateMetadata {
15 pub output: String,
16 pub when: Vec<String>,
17 pub query: String,
18 pub foreach: Option<String>,
19}
20
21impl TemplateMetadata {
22 pub fn parse(content: &str) -> Result<Self> {
24 let mut output = None;
25 let mut when = Vec::new();
26 let mut query = None;
27 let mut foreach = None;
28
29 for line in content.lines() {
31 let line = line.trim();
32
33 if line.starts_with("{#") && line.ends_with("#}") {
34 let inner = line[2..line.len() - 2].trim();
35
36 if let Some(value) = inner.strip_prefix("output:") {
37 output = Some(value.trim().to_string());
38 } else if let Some(value) = inner.strip_prefix("when:") {
39 when.push(value.trim().to_string());
40 } else if let Some(value) = inner.strip_prefix("query:") {
41 query = Some(value.trim().to_string());
42 } else if let Some(value) = inner.strip_prefix("foreach:") {
43 foreach = Some(value.trim().to_string());
44 }
45 }
46 }
47
48 Ok(TemplateMetadata {
49 output: output
50 .ok_or_else(|| ggen_core::utils::error::Error::new("Missing 'output' directive"))?,
51 when,
52 query: query
53 .ok_or_else(|| ggen_core::utils::error::Error::new("Missing 'query' directive"))?,
54 foreach,
55 })
56 }
57}
58
/// One template to render, as scheduled by [`GenerationPlanner::plan`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GenerationTask {
    /// Template name (its key in the project conventions' template map).
    pub template: String,
    /// Output path pattern, taken from the template's `output:` directive.
    pub output_pattern: String,
    /// Paths from the template's `when:` directives, in declaration order.
    pub trigger_files: Vec<PathBuf>,
    /// The template's query; the planner sets it from the `query:` directive.
    pub query: Option<String>,
    /// Optional iteration variable from the `foreach:` directive.
    pub foreach: Option<String>,
}
68
69#[derive(Debug)]
71pub struct GenerationPlan {
72 pub tasks: Vec<GenerationTask>,
73}
74
75pub struct GenerationPlanner {
77 conventions: ProjectConventions,
78}
79
80impl GenerationPlanner {
81 pub fn new(conventions: ProjectConventions) -> Self {
83 Self { conventions }
84 }
85
86 pub fn plan(&self) -> Result<GenerationPlan> {
88 let mut tasks = Vec::new();
89 let mut task_graph: HashMap<String, Vec<String>> = HashMap::new();
90
91 for (template_name, template_path) in &self.conventions.templates {
93 let metadata = self.parse_template_metadata(template_path)?;
94 let trigger_files = self.resolve_dependencies(&metadata);
95
96 let deps: Vec<String> = trigger_files
98 .iter()
99 .filter_map(|p| {
100 p.file_stem()
101 .and_then(|s| s.to_str())
102 .map(|s| s.to_string())
103 })
104 .collect();
105
106 task_graph.insert(template_name.clone(), deps);
107
108 tasks.push(GenerationTask {
109 template: template_name.clone(),
110 output_pattern: metadata.output,
111 trigger_files,
112 query: Some(metadata.query),
113 foreach: metadata.foreach,
114 });
115 }
116
117 self.check_circular_dependencies(&task_graph)?;
119
120 tasks = self.topological_sort(tasks)?;
122
123 Ok(GenerationPlan { tasks })
124 }
125
126 fn parse_template_metadata(&self, path: &Path) -> Result<TemplateMetadata> {
128 let content = fs::read_to_string(path)?;
129 TemplateMetadata::parse(&content)
130 }
131
132 fn resolve_dependencies(&self, metadata: &TemplateMetadata) -> Vec<PathBuf> {
134 metadata
135 .when
136 .iter()
137 .map(|pattern| {
138 PathBuf::from(pattern)
141 })
142 .collect()
143 }
144
145 fn check_circular_dependencies(&self, graph: &HashMap<String, Vec<String>>) -> Result<()> {
147 for task in graph.keys() {
148 let mut visited = HashSet::new();
149 let mut rec_stack = HashSet::new();
150
151 if self.has_cycle(task, graph, &mut visited, &mut rec_stack) {
152 return Err(ggen_core::utils::error::Error::new(&format!(
153 "Circular dependency detected involving task: {}",
154 task
155 )));
156 }
157 }
158
159 Ok(())
160 }
161
162 #[allow(clippy::self_only_used_in_recursion)] fn has_cycle(
165 &self, task: &str, graph: &HashMap<String, Vec<String>>, visited: &mut HashSet<String>,
166 rec_stack: &mut HashSet<String>,
167 ) -> bool {
168 if rec_stack.contains(task) {
169 return true;
170 }
171
172 if visited.contains(task) {
173 return false;
174 }
175
176 visited.insert(task.to_string());
177 rec_stack.insert(task.to_string());
178
179 if let Some(deps) = graph.get(task) {
180 for dep in deps {
181 if self.has_cycle(dep, graph, visited, rec_stack) {
182 return true;
183 }
184 }
185 }
186
187 rec_stack.remove(task);
188 false
189 }
190
191 fn topological_sort(&self, mut tasks: Vec<GenerationTask>) -> Result<Vec<GenerationTask>> {
193 let mut dep_count: HashMap<String, usize> = HashMap::new();
195 let mut graph: HashMap<String, Vec<String>> = HashMap::new();
196
197 for task in &tasks {
198 dep_count.entry(task.template.clone()).or_insert(0);
199
200 for trigger in &task.trigger_files {
201 if let Some(dep) = trigger.file_stem().and_then(|s| s.to_str()) {
202 graph
203 .entry(dep.to_string())
204 .or_default()
205 .push(task.template.clone());
206 *dep_count.entry(task.template.clone()).or_insert(0) += 1;
207 }
208 }
209 }
210
211 let mut ready: Vec<String> = dep_count
213 .iter()
214 .filter(|(_, &count)| count == 0)
215 .map(|(name, _)| name.clone())
216 .collect();
217
218 let mut sorted = Vec::new();
219
220 while let Some(task_name) = ready.pop() {
221 if let Some(pos) = tasks.iter().position(|t| t.template == task_name) {
223 let task = tasks.remove(pos);
224 sorted.push(task);
225 }
226
227 if let Some(dependents) = graph.get(&task_name) {
229 for dependent in dependents {
230 if let Some(count) = dep_count.get_mut(dependent) {
231 *count = count.saturating_sub(1);
232 if *count == 0 {
233 ready.push(dependent.clone());
234 }
235 }
236 }
237 }
238 }
239
240 if !tasks.is_empty() {
242 return Err(ggen_core::utils::error::Error::new(
243 "Circular dependency detected during topological sort",
244 ));
245 }
246
247 Ok(sorted)
248 }
249}
250
#[cfg(test)]
// Tests deliberately panic (via unwrap) on any unexpected failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Writes `content` to `dir/name` and returns the written path.
    fn create_test_template(dir: &Path, name: &str, content: &str) -> PathBuf {
        let target = dir.join(name);
        fs::write(&target, content).unwrap();
        target
    }

    /// Writes each `(name, content)` pair to `<temp_dir>/templates/<name>.tmpl` (a trailing
    /// `.tmpl` already in `name` is not doubled). Returns conventions whose `templates` map
    /// is keyed by the name without its `.tmpl` suffix.
    fn create_test_conventions(
        temp_dir: &Path,
        templates: Vec<(&str, &str)>,
    ) -> ProjectConventions {
        let templates_dir = temp_dir.join("templates");
        fs::create_dir_all(&templates_dir).unwrap();

        let by_name: HashMap<String, PathBuf> = templates
            .into_iter()
            .map(|(name, content)| {
                let key = name.strip_suffix(".tmpl").unwrap_or(name);
                let file_name = format!("{}.tmpl", key);
                let written = create_test_template(&templates_dir, &file_name, content);
                (key.to_string(), written)
            })
            .collect();

        ProjectConventions {
            rdf_files: vec![],
            rdf_dir: temp_dir.join("domain"),
            templates: by_name,
            templates_dir,
            queries: HashMap::new(),
            output_dir: temp_dir.to_path_buf(),
            preset: "test".to_string(),
        }
    }

    /// Builds a planner over `templates`, written beneath `temp_dir`.
    fn planner_for(temp_dir: &TempDir, templates: Vec<(&str, &str)>) -> GenerationPlanner {
        GenerationPlanner::new(create_test_conventions(temp_dir.path(), templates))
    }

    /// Index of the task for `template` within `plan`; panics if the plan has no such task.
    fn position_of(plan: &GenerationPlan, template: &str) -> usize {
        plan.tasks
            .iter()
            .position(|task| task.template == template)
            .unwrap()
    }

    #[test]
    fn test_parse_template_metadata_basic() {
        let content = r"
{# output: src/{{name}}.rs #}
{# when: src/models/*.rs #}
{# query: SELECT * FROM models #}

template content here
";

        let expected = TemplateMetadata {
            output: "src/{{name}}.rs".to_string(),
            when: vec!["src/models/*.rs".to_string()],
            query: "SELECT * FROM models".to_string(),
            foreach: None,
        };
        assert_eq!(TemplateMetadata::parse(content).unwrap(), expected);
    }

    #[test]
    fn test_parse_template_metadata_with_foreach() {
        let content = r"
{# output: tests/{{item}}_test.rs #}
{# when: src/{{item}}.rs #}
{# query: SELECT name FROM entities #}
{# foreach: entity #}

test template
";

        let metadata = TemplateMetadata::parse(content).unwrap();
        assert_eq!(
            (metadata.output.as_str(), metadata.foreach.as_deref()),
            ("tests/{{item}}_test.rs", Some("entity"))
        );
    }

    #[test]
    fn test_parse_template_metadata_multiple_when() {
        let content = r"
{# output: generated.rs #}
{# when: file1.rs #}
{# when: file2.rs #}
{# query: SELECT * #}
";

        let metadata = TemplateMetadata::parse(content).unwrap();
        assert_eq!(metadata.when.len(), 2);
        for expected in &["file1.rs", "file2.rs"] {
            assert!(metadata.when.iter().any(|w| w.as_str() == *expected));
        }
    }

    #[test]
    fn test_parse_template_metadata_missing_output() {
        let content = r"
{# query: SELECT * #}
";

        let message = TemplateMetadata::parse(content).unwrap_err().to_string();
        assert!(message.contains("Missing 'output'"));
    }

    #[test]
    fn test_parse_template_metadata_missing_query() {
        let content = r"
{# output: file.rs #}
";

        let message = TemplateMetadata::parse(content).unwrap_err().to_string();
        assert!(message.contains("Missing 'query'"));
    }

    #[test]
    fn test_generation_planner_empty() {
        let temp_dir = TempDir::new().unwrap();
        let plan = planner_for(&temp_dir, vec![]).plan().unwrap();

        assert!(plan.tasks.is_empty());
    }

    #[test]
    fn test_generation_planner_single_task() {
        const TEMPLATE: &str = r"
{# output: generated.rs #}
{# query: SELECT * #}

content
";
        let temp_dir = TempDir::new().unwrap();
        let plan = planner_for(&temp_dir, vec![("test", TEMPLATE)])
            .plan()
            .unwrap();

        assert_eq!(plan.tasks.len(), 1);
        let only = &plan.tasks[0];
        assert_eq!(only.template, "test");
        assert_eq!(only.output_pattern, "generated.rs");
    }

    #[test]
    fn test_generation_planner_multiple_tasks() {
        const TASK1: &str = r"
{# output: out1.rs #}
{# query: SELECT * FROM table1 #}
";
        const TASK2: &str = r"
{# output: out2.rs #}
{# query: SELECT * FROM table2 #}
";
        let temp_dir = TempDir::new().unwrap();
        let plan = planner_for(&temp_dir, vec![("task1", TASK1), ("task2", TASK2)])
            .plan()
            .unwrap();

        assert_eq!(plan.tasks.len(), 2);
    }

    #[test]
    fn test_generation_planner_with_dependencies() {
        const BASE: &str = r"
{# output: base.rs #}
{# query: SELECT * FROM base #}
";
        const DERIVED: &str = r"
{# output: derived.rs #}
{# when: base.rs #}
{# query: SELECT * FROM derived #}
";
        let temp_dir = TempDir::new().unwrap();
        let plan = planner_for(&temp_dir, vec![("base", BASE), ("derived", DERIVED)])
            .plan()
            .unwrap();

        assert_eq!(plan.tasks.len(), 2);
        assert!(position_of(&plan, "base") < position_of(&plan, "derived"));
    }

    #[test]
    fn test_circular_dependency_detection() {
        const TASK1: &str = r"
{# output: task1.rs #}
{# when: task2.rs #}
{# query: SELECT * #}
";
        const TASK2: &str = r"
{# output: task2.rs #}
{# when: task1.rs #}
{# query: SELECT * #}
";
        let temp_dir = TempDir::new().unwrap();
        let outcome = planner_for(&temp_dir, vec![("task1", TASK1), ("task2", TASK2)]).plan();

        let err_msg = outcome.unwrap_err().to_string();
        assert!(
            err_msg.contains("Circular dependency") || err_msg.contains("cycle"),
            "Expected circular dependency error, got: {}",
            err_msg
        );
    }

    #[test]
    fn test_foreach_pattern() {
        const TEMPLATE: &str = r"
{# output: tests/{{entity}}_test.rs #}
{# query: SELECT name FROM entities #}
{# foreach: entity #}
";
        let temp_dir = TempDir::new().unwrap();
        let plan = planner_for(&temp_dir, vec![("test", TEMPLATE)])
            .plan()
            .unwrap();

        assert_eq!(plan.tasks.len(), 1);
        let only = &plan.tasks[0];
        assert_eq!(only.foreach, Some("entity".to_string()));
        assert_eq!(only.output_pattern, "tests/{{entity}}_test.rs");
    }

    #[test]
    fn test_once_pattern() {
        const TEMPLATE: &str = r"
{# output: single_file.rs #}
{# query: SELECT COUNT(*) #}
";
        let temp_dir = TempDir::new().unwrap();
        let plan = planner_for(&temp_dir, vec![("once", TEMPLATE)])
            .plan()
            .unwrap();

        assert_eq!(plan.tasks.len(), 1);
        let only = &plan.tasks[0];
        assert_eq!(only.foreach, None);
        assert_eq!(only.output_pattern, "single_file.rs");
    }

    #[test]
    fn test_complex_dependency_graph() {
        const MODELS: &str = r"
{# output: models.rs #}
{# query: SELECT * FROM schema #}
";
        const SERVICES: &str = r"
{# output: services.rs #}
{# when: models.rs #}
{# query: SELECT * FROM services #}
";
        const CONTROLLERS: &str = r"
{# output: controllers.rs #}
{# when: services.rs #}
{# query: SELECT * FROM controllers #}
";
        const ROUTES: &str = r"
{# output: routes.rs #}
{# when: controllers.rs #}
{# query: SELECT * FROM routes #}
";
        let temp_dir = TempDir::new().unwrap();
        let plan = planner_for(
            &temp_dir,
            vec![
                ("models", MODELS),
                ("services", SERVICES),
                ("controllers", CONTROLLERS),
                ("routes", ROUTES),
            ],
        )
        .plan()
        .unwrap();

        assert_eq!(plan.tasks.len(), 4);

        // Each layer must come strictly after the one it triggers on.
        let positions: Vec<usize> = ["models", "services", "controllers", "routes"]
            .iter()
            .map(|name| position_of(&plan, name))
            .collect();
        assert!(positions.windows(2).all(|pair| pair[0] < pair[1]));
    }
}