lmn_core/request_template/
definition.rs1use std::collections::HashMap;
2
3use serde::Deserialize;
4use tracing::instrument;
5
6use crate::request_template::error::TemplateError;
7use crate::request_template::validators;
8
9pub use crate::request_template::validators::{
12 float::{FloatDef, FloatStrategy, RawFloatDetails},
13 object::ObjectDef,
14 string::{LengthSpec, RawStringDetails, StringDef, StringGenConfig, StringStrategy},
15};
16
17#[derive(Deserialize)]
20#[serde(tag = "type", rename_all = "snake_case")]
21pub enum RawTemplateDef {
22 String {
23 exact: Option<f64>,
24 min: Option<f64>,
25 max: Option<f64>,
26 details: Option<RawStringDetails>,
27 },
28 Float {
29 exact: Option<f64>,
30 min: Option<f64>,
31 max: Option<f64>,
32 details: Option<RawFloatDetails>,
33 },
34 Object {
35 composition: HashMap<String, String>,
36 },
37}
38
39pub enum TemplateDef {
42 String(StringDef),
43 Float(FloatDef),
44 Object(ObjectDef),
45}
46
47pub fn validate_all(
50 raw: HashMap<String, RawTemplateDef>,
51) -> Result<HashMap<String, TemplateDef>, TemplateError> {
52 raw.into_iter()
53 .map(|(name, raw_def)| validators::validate(raw_def, &name).map(|def| (name, def)))
54 .collect()
55}
56
57#[instrument(name = "lmn.template.check_circular_refs", skip(defs), fields(def_count = defs.len()))]
60pub fn check_circular_refs(defs: &HashMap<String, TemplateDef>) -> Result<(), TemplateError> {
61 for def in defs.values() {
62 if let TemplateDef::Object(obj) = def {
63 for ref_name in obj.composition.values() {
64 if !defs.contains_key(ref_name.as_str()) {
65 return Err(TemplateError::MissingDefinition(ref_name.clone()));
66 }
67 }
68 }
69 }
70
71 for name in defs.keys() {
72 let mut visiting: Vec<&str> = Vec::new();
73 detect_cycle(name, defs, &mut visiting)?;
74 }
75
76 Ok(())
77}
78
79fn detect_cycle<'a>(
80 name: &'a str,
81 defs: &'a HashMap<String, TemplateDef>,
82 visiting: &mut Vec<&'a str>,
83) -> Result<(), TemplateError> {
84 if visiting.contains(&name) {
85 let mut cycle: Vec<String> = visiting.iter().map(|s| s.to_string()).collect();
86 cycle.push(name.to_string());
87 return Err(TemplateError::CircularReference(cycle));
88 }
89
90 if let Some(TemplateDef::Object(obj)) = defs.get(name) {
91 visiting.push(name);
92 for ref_name in obj.composition.values() {
93 detect_cycle(ref_name, defs, visiting)?;
94 }
95 visiting.pop();
96 }
97
98 Ok(())
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// A trivially valid non-object definition.
    fn float_def() -> TemplateDef {
        TemplateDef::Float(FloatDef {
            strategy: FloatStrategy::Exact(1.0),
            decimals: 0,
        })
    }

    /// An object definition whose composition maps field -> referenced name.
    fn object_def(refs: &[(&str, &str)]) -> TemplateDef {
        TemplateDef::Object(ObjectDef {
            composition: refs
                .iter()
                .map(|(k, v)| (k.to_string(), v.to_string()))
                .collect(),
        })
    }

    /// Builds a definition map from `(name, definition)` pairs.
    fn defs_of(entries: Vec<(&str, TemplateDef)>) -> HashMap<String, TemplateDef> {
        entries
            .into_iter()
            .map(|(name, def)| (name.to_string(), def))
            .collect()
    }

    #[test]
    fn detect_cycle_finds_direct_cycle() {
        let mut defs = HashMap::new();
        defs.insert("a".to_string(), object_def(&[("x", "b")]));
        defs.insert("b".to_string(), object_def(&[("y", "a")]));
        let mut visiting = Vec::new();
        assert!(detect_cycle("a", &defs, &mut visiting).is_err());
    }

    #[test]
    fn detect_cycle_ok_for_acyclic_graph() {
        let mut defs = HashMap::new();
        defs.insert("a".to_string(), object_def(&[("x", "b")]));
        defs.insert("b".to_string(), float_def());
        let mut visiting = Vec::new();
        assert!(detect_cycle("a", &defs, &mut visiting).is_ok());
    }

    #[test]
    fn detect_cycle_ok_for_non_object() {
        let mut defs = HashMap::new();
        defs.insert("x".to_string(), float_def());
        let mut visiting = Vec::new();
        assert!(detect_cycle("x", &defs, &mut visiting).is_ok());
    }

    #[test]
    fn detect_cycle_finds_self_reference() {
        let defs = defs_of(vec![("a", object_def(&[("x", "a")]))]);
        let mut visiting = Vec::new();
        assert!(detect_cycle("a", &defs, &mut visiting).is_err());
    }

    #[test]
    fn check_circular_refs_reports_missing_definition() {
        let defs = defs_of(vec![("a", object_def(&[("x", "ghost")]))]);
        let err = check_circular_refs(&defs).unwrap_err();
        assert!(matches!(err, TemplateError::MissingDefinition(ref name) if name == "ghost"));
    }

    #[test]
    fn check_circular_refs_finds_indirect_cycle() {
        let defs = defs_of(vec![
            ("a", object_def(&[("x", "b")])),
            ("b", object_def(&[("y", "c")])),
            ("c", object_def(&[("z", "a")])),
        ]);
        match check_circular_refs(&defs) {
            Err(TemplateError::CircularReference(cycle)) => {
                // Every root lies on the loop, so the report is the loop closed on itself.
                assert_eq!(cycle.len(), 4);
                assert_eq!(cycle.first(), cycle.last());
            }
            _ => panic!("expected a CircularReference error"),
        }
    }

    #[test]
    fn check_circular_refs_accepts_shared_subobject() {
        // Diamond: `d` is reachable through both `b` and `c` but forms no cycle.
        let defs = defs_of(vec![
            ("a", object_def(&[("left", "b"), ("right", "c")])),
            ("b", object_def(&[("v", "d")])),
            ("c", object_def(&[("v", "d")])),
            ("d", float_def()),
        ]);
        assert!(check_circular_refs(&defs).is_ok());
    }
}