ggen_cli_lib/conventions/
planner.rs1use ggen_utils::error::Result;
4use std::collections::{HashMap, HashSet};
5use std::fs;
6use std::path::{Path, PathBuf};
7
8use super::ProjectConventions;
9
/// Metadata directives parsed from a template's `{# key: value #}` header lines.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TemplateMetadata {
    /// Output path pattern from the `output:` directive (may contain placeholders
    /// such as `{{name}}`).
    pub output: String,
    /// Trigger patterns from every `when:` directive, in order of appearance.
    pub when: Vec<String>,
    /// Query text from the `query:` directive.
    pub query: String,
    /// Value of the optional `foreach:` directive.
    pub foreach: Option<String>,
}
18
19impl TemplateMetadata {
20 pub fn parse(content: &str) -> Result<Self> {
22 let mut output = None;
23 let mut when = Vec::new();
24 let mut query = None;
25 let mut foreach = None;
26
27 for line in content.lines() {
29 let line = line.trim();
30
31 if line.starts_with("{#") && line.ends_with("#}") {
32 let inner = line[2..line.len() - 2].trim();
33
34 if let Some(value) = inner.strip_prefix("output:") {
35 output = Some(value.trim().to_string());
36 } else if let Some(value) = inner.strip_prefix("when:") {
37 when.push(value.trim().to_string());
38 } else if let Some(value) = inner.strip_prefix("query:") {
39 query = Some(value.trim().to_string());
40 } else if let Some(value) = inner.strip_prefix("foreach:") {
41 foreach = Some(value.trim().to_string());
42 }
43 }
44 }
45
46 Ok(TemplateMetadata {
47 output: output
48 .ok_or_else(|| ggen_utils::error::Error::new("Missing 'output' directive"))?,
49 when,
50 query: query
51 .ok_or_else(|| ggen_utils::error::Error::new("Missing 'query' directive"))?,
52 foreach,
53 })
54 }
55}
56
/// One unit of work in a [`GenerationPlan`]: render `template` to `output_pattern`.
#[derive(Clone, Debug)]
pub struct GenerationTask {
    /// Name (the conventions key) of the template this task renders.
    pub template: String,
    /// Output path pattern taken from the template's `output:` directive.
    pub output_pattern: String,
    /// Trigger paths taken from the template's `when:` directives.
    pub trigger_files: Vec<PathBuf>,
    /// Query text from the template's `query:` directive.
    pub query: Option<String>,
    /// Value of the template's `foreach:` directive, if present.
    pub foreach: Option<String>,
}

/// Ordered list of generation tasks produced by the planner.
#[derive(Debug)]
pub struct GenerationPlan {
    /// Tasks in execution order: each task follows the tasks it depends on.
    pub tasks: Vec<GenerationTask>,
}
72
73pub struct GenerationPlanner {
75 conventions: ProjectConventions,
76}
77
78impl GenerationPlanner {
79 pub fn new(conventions: ProjectConventions) -> Self {
81 Self { conventions }
82 }
83
84 pub fn plan(&self) -> Result<GenerationPlan> {
86 let mut tasks = Vec::new();
87 let mut task_graph: HashMap<String, Vec<String>> = HashMap::new();
88
89 for (template_name, template_path) in &self.conventions.templates {
91 let metadata = self.parse_template_metadata(template_path)?;
92 let trigger_files = self.resolve_dependencies(&metadata);
93
94 let deps: Vec<String> = trigger_files
96 .iter()
97 .filter_map(|p| {
98 p.file_stem()
99 .and_then(|s| s.to_str())
100 .map(|s| s.to_string())
101 })
102 .collect();
103
104 task_graph.insert(template_name.clone(), deps);
105
106 tasks.push(GenerationTask {
107 template: template_name.clone(),
108 output_pattern: metadata.output,
109 trigger_files,
110 query: Some(metadata.query),
111 foreach: metadata.foreach,
112 });
113 }
114
115 self.check_circular_dependencies(&task_graph)?;
117
118 tasks = self.topological_sort(tasks)?;
120
121 Ok(GenerationPlan { tasks })
122 }
123
124 fn parse_template_metadata(&self, path: &Path) -> Result<TemplateMetadata> {
126 let content = fs::read_to_string(path)?;
127 TemplateMetadata::parse(&content)
128 }
129
130 fn resolve_dependencies(&self, metadata: &TemplateMetadata) -> Vec<PathBuf> {
132 metadata
133 .when
134 .iter()
135 .map(|pattern| {
136 PathBuf::from(pattern)
139 })
140 .collect()
141 }
142
143 fn check_circular_dependencies(&self, graph: &HashMap<String, Vec<String>>) -> Result<()> {
145 for task in graph.keys() {
146 let mut visited = HashSet::new();
147 let mut rec_stack = HashSet::new();
148
149 if self.has_cycle(task, graph, &mut visited, &mut rec_stack) {
150 return Err(ggen_utils::error::Error::new(&format!(
151 "Circular dependency detected involving task: {}",
152 task
153 )));
154 }
155 }
156
157 Ok(())
158 }
159
160 #[allow(clippy::only_used_in_recursion)] fn has_cycle(
163 &self, task: &str, graph: &HashMap<String, Vec<String>>, visited: &mut HashSet<String>,
164 rec_stack: &mut HashSet<String>,
165 ) -> bool {
166 if rec_stack.contains(task) {
167 return true;
168 }
169
170 if visited.contains(task) {
171 return false;
172 }
173
174 visited.insert(task.to_string());
175 rec_stack.insert(task.to_string());
176
177 if let Some(deps) = graph.get(task) {
178 for dep in deps {
179 if self.has_cycle(dep, graph, visited, rec_stack) {
180 return true;
181 }
182 }
183 }
184
185 rec_stack.remove(task);
186 false
187 }
188
189 fn topological_sort(&self, mut tasks: Vec<GenerationTask>) -> Result<Vec<GenerationTask>> {
191 let mut dep_count: HashMap<String, usize> = HashMap::new();
193 let mut graph: HashMap<String, Vec<String>> = HashMap::new();
194
195 for task in &tasks {
196 dep_count.entry(task.template.clone()).or_insert(0);
197
198 for trigger in &task.trigger_files {
199 if let Some(dep) = trigger.file_stem().and_then(|s| s.to_str()) {
200 graph
201 .entry(dep.to_string())
202 .or_default()
203 .push(task.template.clone());
204 *dep_count.entry(task.template.clone()).or_insert(0) += 1;
205 }
206 }
207 }
208
209 let mut ready: Vec<String> = dep_count
211 .iter()
212 .filter(|(_, &count)| count == 0)
213 .map(|(name, _)| name.clone())
214 .collect();
215
216 let mut sorted = Vec::new();
217
218 while let Some(task_name) = ready.pop() {
219 if let Some(pos) = tasks.iter().position(|t| t.template == task_name) {
221 let task = tasks.remove(pos);
222 sorted.push(task);
223 }
224
225 if let Some(dependents) = graph.get(&task_name) {
227 for dependent in dependents {
228 if let Some(count) = dep_count.get_mut(dependent) {
229 *count = count.saturating_sub(1);
230 if *count == 0 {
231 ready.push(dependent.clone());
232 }
233 }
234 }
235 }
236 }
237
238 if !tasks.is_empty() {
240 return Err(ggen_utils::error::Error::new(
241 "Circular dependency detected during topological sort",
242 ));
243 }
244
245 Ok(sorted)
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Writes `content` to `dir/name` and returns the written path.
    fn create_test_template(dir: &Path, name: &str, content: &str) -> PathBuf {
        let target = dir.join(name);
        fs::write(&target, content).unwrap();
        target
    }

    /// Builds `ProjectConventions` whose templates live in `<temp_dir>/templates`.
    ///
    /// Each `(name, content)` pair is written to `<name>.tmpl` (the suffix is added
    /// when missing) and registered under `name` with any `.tmpl` suffix removed.
    fn create_test_conventions(
        temp_dir: &Path, templates: Vec<(&str, &str)>,
    ) -> ProjectConventions {
        let template_dir = temp_dir.join("templates");
        fs::create_dir_all(&template_dir).unwrap();

        let template_map: HashMap<_, _> = templates
            .into_iter()
            .map(|(name, content)| {
                let key = name.strip_suffix(".tmpl").unwrap_or(name);
                let file_name = format!("{}.tmpl", key);
                let written = create_test_template(&template_dir, &file_name, content);
                (key.to_string(), written)
            })
            .collect();

        ProjectConventions {
            rdf_files: vec![],
            rdf_dir: temp_dir.join("domain"),
            templates: template_map,
            templates_dir: template_dir,
            queries: HashMap::new(),
            output_dir: temp_dir.join("generated"),
            preset: "test".to_string(),
        }
    }

    /// Plans the given `(name, content)` templates inside a fresh temporary directory.
    fn plan_templates(templates: Vec<(&str, &str)>) -> Result<GenerationPlan> {
        let workspace = TempDir::new().unwrap();
        let conventions = create_test_conventions(workspace.path(), templates);
        GenerationPlanner::new(conventions).plan()
    }

    /// Position of the task for `template` within `plan`; panics if absent.
    fn task_index(plan: &GenerationPlan, template: &str) -> usize {
        plan.tasks
            .iter()
            .position(|task| task.template == template)
            .unwrap()
    }

    #[test]
    fn test_parse_template_metadata_basic() {
        let parsed = TemplateMetadata::parse(
            r#"
{# output: src/{{name}}.rs #}
{# when: src/models/*.rs #}
{# query: SELECT * FROM models #}

template content here
"#,
        )
        .unwrap();

        assert_eq!(parsed.output, "src/{{name}}.rs");
        assert_eq!(parsed.when, vec!["src/models/*.rs"]);
        assert_eq!(parsed.query, "SELECT * FROM models");
        assert_eq!(parsed.foreach, None);
    }

    #[test]
    fn test_parse_template_metadata_with_foreach() {
        let parsed = TemplateMetadata::parse(
            r#"
{# output: tests/{{item}}_test.rs #}
{# when: src/{{item}}.rs #}
{# query: SELECT name FROM entities #}
{# foreach: entity #}

test template
"#,
        )
        .unwrap();

        assert_eq!(parsed.output, "tests/{{item}}_test.rs");
        assert_eq!(parsed.foreach, Some("entity".to_string()));
    }

    #[test]
    fn test_parse_template_metadata_multiple_when() {
        let parsed = TemplateMetadata::parse(
            r#"
{# output: generated.rs #}
{# when: file1.rs #}
{# when: file2.rs #}
{# query: SELECT * #}
"#,
        )
        .unwrap();

        assert_eq!(parsed.when.len(), 2);
        for expected in ["file1.rs", "file2.rs"].iter() {
            assert!(parsed.when.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_parse_template_metadata_missing_output() {
        let outcome = TemplateMetadata::parse(
            r#"
{# query: SELECT * #}
"#,
        );

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Missing 'output'"));
    }

    #[test]
    fn test_parse_template_metadata_missing_query() {
        let outcome = TemplateMetadata::parse(
            r#"
{# output: file.rs #}
"#,
        );

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Missing 'query'"));
    }

    #[test]
    fn test_generation_planner_empty() {
        let plan = plan_templates(vec![]).unwrap();
        assert_eq!(plan.tasks.len(), 0);
    }

    #[test]
    fn test_generation_planner_single_task() {
        let plan = plan_templates(vec![(
            "test",
            r#"
{# output: generated.rs #}
{# query: SELECT * #}

content
"#,
        )])
        .unwrap();

        assert_eq!(plan.tasks.len(), 1);
        let only = &plan.tasks[0];
        assert_eq!(only.template, "test");
        assert_eq!(only.output_pattern, "generated.rs");
    }

    #[test]
    fn test_generation_planner_multiple_tasks() {
        let plan = plan_templates(vec![
            (
                "task1",
                r#"
{# output: out1.rs #}
{# query: SELECT * FROM table1 #}
"#,
            ),
            (
                "task2",
                r#"
{# output: out2.rs #}
{# query: SELECT * FROM table2 #}
"#,
            ),
        ])
        .unwrap();

        assert_eq!(plan.tasks.len(), 2);
    }

    #[test]
    fn test_generation_planner_with_dependencies() {
        let plan = plan_templates(vec![
            (
                "base",
                r#"
{# output: base.rs #}
{# query: SELECT * FROM base #}
"#,
            ),
            (
                "derived",
                r#"
{# output: derived.rs #}
{# when: base.rs #}
{# query: SELECT * FROM derived #}
"#,
            ),
        ])
        .unwrap();

        assert_eq!(plan.tasks.len(), 2);
        assert!(task_index(&plan, "base") < task_index(&plan, "derived"));
    }

    #[test]
    fn test_circular_dependency_detection() {
        let outcome = plan_templates(vec![
            (
                "task1",
                r#"
{# output: task1.rs #}
{# when: task2.rs #}
{# query: SELECT * #}
"#,
            ),
            (
                "task2",
                r#"
{# output: task2.rs #}
{# when: task1.rs #}
{# query: SELECT * #}
"#,
            ),
        ]);

        assert!(outcome.is_err());
        let err_msg = outcome.unwrap_err().to_string();
        assert!(
            err_msg.contains("Circular dependency") || err_msg.contains("cycle"),
            "Expected circular dependency error, got: {}",
            err_msg
        );
    }

    #[test]
    fn test_foreach_pattern() {
        let plan = plan_templates(vec![(
            "test",
            r#"
{# output: tests/{{entity}}_test.rs #}
{# query: SELECT name FROM entities #}
{# foreach: entity #}
"#,
        )])
        .unwrap();

        assert_eq!(plan.tasks.len(), 1);
        let only = &plan.tasks[0];
        assert_eq!(only.foreach, Some("entity".to_string()));
        assert_eq!(only.output_pattern, "tests/{{entity}}_test.rs");
    }

    #[test]
    fn test_once_pattern() {
        let plan = plan_templates(vec![(
            "once",
            r#"
{# output: single_file.rs #}
{# query: SELECT COUNT(*) #}
"#,
        )])
        .unwrap();

        assert_eq!(plan.tasks.len(), 1);
        let only = &plan.tasks[0];
        assert_eq!(only.foreach, None);
        assert_eq!(only.output_pattern, "single_file.rs");
    }

    #[test]
    fn test_complex_dependency_graph() {
        let plan = plan_templates(vec![
            (
                "models",
                r#"
{# output: models.rs #}
{# query: SELECT * FROM schema #}
"#,
            ),
            (
                "services",
                r#"
{# output: services.rs #}
{# when: models.rs #}
{# query: SELECT * FROM services #}
"#,
            ),
            (
                "controllers",
                r#"
{# output: controllers.rs #}
{# when: services.rs #}
{# query: SELECT * FROM controllers #}
"#,
            ),
            (
                "routes",
                r#"
{# output: routes.rs #}
{# when: controllers.rs #}
{# query: SELECT * FROM routes #}
"#,
            ),
        ])
        .unwrap();

        assert_eq!(plan.tasks.len(), 4);

        // Each layer must be scheduled strictly after the one it is triggered by.
        let positions: Vec<usize> = ["models", "services", "controllers", "routes"]
            .iter()
            .map(|name| task_index(&plan, name))
            .collect();
        assert!(positions.windows(2).all(|pair| pair[0] < pair[1]));
    }
}