Skip to main content

orion_variate/vars/
collection.rs

1use getset::Getters;
2use indexmap::IndexMap;
3use serde_derive::{Deserialize, Serialize};
4
5use crate::vars::VarToValue;
6
7use super::{ValueDict, VarDefinition, definition::Mutability};
8
9#[derive(Getters, Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
10#[getset(get = "pub")]
11//#[serde(transparent)]
12pub struct VarCollection {
13    #[serde(default, skip_serializing_if = "Vec::is_empty", rename = "immutable")]
14    immutable_vars: Vec<VarDefinition>,
15
16    #[serde(default, skip_serializing_if = "Vec::is_empty", rename = "system")]
17    system_vars: Vec<VarDefinition>,
18
19    #[serde(
20        default,
21        skip_serializing_if = "Vec::is_empty",
22        rename = "module",
23        alias = "vars"
24    )]
25    module_vars: Vec<VarDefinition>,
26}
27impl VarToValue<ValueDict> for Vec<VarDefinition> {
28    fn to_val(&self) -> ValueDict {
29        let mut dict = ValueDict::new();
30        for var in self {
31            dict.insert(var.name().to_string(), var.value().clone());
32        }
33        dict
34    }
35}
36impl VarCollection {
37    pub fn define(vars: Vec<VarDefinition>) -> Self {
38        let mut immutable_vars = Vec::new();
39        let mut system_vars = Vec::new();
40        let mut module_vars = Vec::new();
41
42        for v in vars {
43            match v.mutability() {
44                Mutability::Immutable => {
45                    immutable_vars.push(v);
46                }
47                Mutability::System => {
48                    system_vars.push(v);
49                }
50                Mutability::Module => module_vars.push(v),
51            }
52        }
53        Self {
54            immutable_vars,
55            system_vars,
56            module_vars,
57        }
58    }
59    pub fn mark_vars_scope(&mut self) {
60        for var in self.immutable_vars.iter_mut() {
61            var.set_mutability(Mutability::Immutable);
62        }
63        for var in self.system_vars.iter_mut() {
64            var.set_mutability(Mutability::System);
65        }
66        for var in self.module_vars.iter_mut() {
67            var.set_mutability(Mutability::Module);
68        }
69    }
70
71    pub fn value_dict(&self) -> ValueDict {
72        let mut dict = ValueDict::new();
73        for var in self.immutable_vars() {
74            dict.insert(var.name().to_string(), var.value().clone()); // String 自动转换为 UpperKey
75        }
76        for var in self.system_vars() {
77            dict.insert(var.name().to_string(), var.value().clone()); // String 自动转换为 UpperKey
78        }
79        for var in self.module_vars() {
80            dict.insert(var.name().to_string(), var.value().clone()); // String 自动转换为 UpperKey
81        }
82        dict
83    }
84    // 基于 VarDefinition 的 name 合并;当 `overwrite=true` 时后者覆盖前者
85    pub fn merge(self, other: VarCollection) -> Self {
86        let immutable_vars = merge_vec(self.immutable_vars, other.immutable_vars, false);
87        let system_vars = merge_vec(self.system_vars, other.system_vars, true);
88        let module_vars = merge_vec(self.module_vars, other.module_vars, true);
89        Self {
90            immutable_vars,
91            system_vars,
92            module_vars,
93        }
94    }
95
96    pub fn merge_system(self, other: VarCollection) -> Self {
97        let system_vars = merge_vec(self.system_vars, other.system_vars, true);
98        Self {
99            immutable_vars: Vec::new(),
100            system_vars,
101            module_vars: Vec::new(),
102        }
103    }
104}
105fn merge_vec(
106    my: Vec<VarDefinition>,
107    other: Vec<VarDefinition>,
108    is_over: bool,
109) -> Vec<VarDefinition> {
110    let mut target = Vec::new();
111    let mut merged = IndexMap::new();
112    for var in my {
113        //immutable_vars.push(var)
114        merged.insert(var.name().clone(), var);
115    }
116    for var in other {
117        if is_over || !merged.contains_key(var.name()) {
118            merged.insert(var.name().clone(), var);
119        }
120    }
121    for var in merged.into_values() {
122        target.push(var);
123    }
124    target
125}
126
#[cfg(test)]
mod tests {
    use crate::vars::ValueType;
    use crate::vars::definition::Mutability;

    use super::*;

    #[test]
    fn test_define_classification() {
        // One variable per scope: `define` must route each into its own group.
        let vars = vec![
            VarDefinition::from(("immutable_var", "immutable_value"))
                .with_mutability(Mutability::Immutable),
            VarDefinition::from(("public_var", "public_value")).with_mutability(Mutability::System),
            VarDefinition::from(("model_var", "model_value")).with_mutability(Mutability::Module),
        ];

        let collection = VarCollection::define(vars);

        // Each group holds exactly the variable declared with its mutability.
        assert_eq!(collection.immutable_vars().len(), 1);
        assert_eq!(collection.immutable_vars()[0].name(), "immutable_var");

        assert_eq!(collection.system_vars().len(), 1);
        assert_eq!(collection.system_vars()[0].name(), "public_var");

        assert_eq!(collection.module_vars().len(), 1);
        assert_eq!(collection.module_vars()[0].name(), "model_var");
    }

    #[test]
    fn test_value_dict_generation() {
        let vars = vec![
            VarDefinition::from(("immutable_var", "immutable_value"))
                .with_mutability(Mutability::Immutable),
            VarDefinition::from(("public_var", "public_value")).with_mutability(Mutability::System),
            VarDefinition::from(("model_var", "model_value")).with_mutability(Mutability::Module),
            VarDefinition::from(("numeric_var", 42u64)).with_mutability(Mutability::System),
        ];

        let collection = VarCollection::define(vars);
        let dict = collection.value_dict();

        // Every variable from every group is present, keyed by its upper-cased name.
        assert_eq!(dict.len(), 4);
        assert_eq!(
            dict.get("IMMUTABLE_VAR"),
            Some(&ValueType::from("immutable_value"))
        );
        assert_eq!(
            dict.get("PUBLIC_VAR"),
            Some(&ValueType::from("public_value"))
        );
        assert_eq!(dict.get("MODEL_VAR"), Some(&ValueType::from("model_value")));
        assert_eq!(dict.get("NUMERIC_VAR"), Some(&ValueType::from(42u64)));
    }

    #[test]
    fn test_merge_collections() {
        let vars1 = vec![
            VarDefinition::from(("var1", "value1_from_1")).with_mutability(Mutability::System),
            VarDefinition::from(("var2", "value2_from_1")).with_mutability(Mutability::Immutable),
            VarDefinition::from(("unique_to_1", "unique")).with_mutability(Mutability::Module),
        ];

        let vars2 = vec![
            VarDefinition::from(("var1", "value1_from_2")).with_mutability(Mutability::System),
            VarDefinition::from(("var3", "value3_from_2")).with_mutability(Mutability::Module),
            VarDefinition::from(("unique_to_2", "unique2")).with_mutability(Mutability::System),
        ];

        let collection1 = VarCollection::define(vars1);
        let collection2 = VarCollection::define(vars2);

        let merged = collection1.merge(collection2);

        // system: `var1` is in both inputs and keeps its original position;
        // `unique_to_2` is appended.
        assert_eq!(merged.system_vars().len(), 2);
        assert_eq!(merged.system_vars()[0].name(), "var1");
        assert_eq!(merged.system_vars()[1].name(), "unique_to_2");

        // immutable: only `var2`.
        assert_eq!(merged.immutable_vars().len(), 1);
        assert_eq!(merged.immutable_vars()[0].name(), "var2");

        // module: `unique_to_1` from the first collection, then `var3` from the second.
        assert_eq!(merged.module_vars().len(), 2);
        assert_eq!(merged.module_vars()[0].name(), "unique_to_1");
        assert_eq!(merged.module_vars()[1].name(), "var3");

        // Duplicate system variable: the second collection's value wins.
        let dict = merged.value_dict();
        assert_eq!(dict.get("VAR1"), Some(&ValueType::from("value1_from_2")));
        assert_eq!(dict.get("VAR2"), Some(&ValueType::from("value2_from_1")));
        assert_eq!(dict.get("VAR3"), Some(&ValueType::from("value3_from_2")));
        assert_eq!(dict.get("UNIQUE_TO_1"), Some(&ValueType::from("unique")));
        assert_eq!(dict.get("UNIQUE_TO_2"), Some(&ValueType::from("unique2")));
    }

    #[test]
    fn test_merge_immutable_keeps_first_definition() {
        // Immutable variables are merged without overwrite: the first definition wins.
        let first = VarCollection::define(vec![
            VarDefinition::from(("locked", "original")).with_mutability(Mutability::Immutable),
        ]);
        let second = VarCollection::define(vec![
            VarDefinition::from(("locked", "override")).with_mutability(Mutability::Immutable),
        ]);

        let merged = first.merge(second);

        assert_eq!(merged.immutable_vars().len(), 1);
        assert_eq!(
            merged.value_dict().get("LOCKED"),
            Some(&ValueType::from("original"))
        );
    }

    #[test]
    fn test_merge_system_drops_other_groups() {
        let base = VarCollection::define(vec![
            VarDefinition::from(("sys_var", "base")).with_mutability(Mutability::System),
            VarDefinition::from(("imm_var", "base")).with_mutability(Mutability::Immutable),
            VarDefinition::from(("module_var", "base")).with_mutability(Mutability::Module),
        ]);
        let extra = VarCollection::define(vec![
            VarDefinition::from(("sys_var", "extra")).with_mutability(Mutability::System),
        ]);

        let merged = base.merge_system(extra);

        // Only the system group survives, with the later definition overriding.
        assert!(merged.immutable_vars().is_empty());
        assert!(merged.module_vars().is_empty());
        assert_eq!(merged.system_vars().len(), 1);
        assert_eq!(
            merged.value_dict().get("SYS_VAR"),
            Some(&ValueType::from("extra"))
        );
    }

    #[test]
    fn test_serialization_deserialization() {
        // Deliberately no module variable, so the empty `module` group can be
        // checked for being skipped during serialization.
        let vars = vec![
            VarDefinition::from(("string_var", "hello")).with_mutability(Mutability::Immutable),
            VarDefinition::from(("bool_var", true)).with_mutability(Mutability::System),
        ];

        let original = VarCollection::define(vars);

        // JSON round trip. Group membership comes from the serialized keys, so
        // re-stamp each definition's mutability before comparing.
        let json = serde_json::to_string(&original).unwrap();
        let mut deserialized: VarCollection = serde_json::from_str(&json).unwrap();
        deserialized.mark_vars_scope();
        assert_eq!(original, deserialized);

        // YAML round trip.
        let yaml = serde_yaml::to_string(&original).unwrap();
        let mut deserialized_yaml: VarCollection = serde_yaml::from_str(&yaml).unwrap();
        deserialized_yaml.mark_vars_scope();
        assert_eq!(original, deserialized_yaml);

        // Empty groups are skipped; non-empty ones appear under their renamed key.
        assert_eq!(
            original.module_vars().len(),
            0,
            "module_vars should be empty"
        );
        assert!(
            !json.contains("\"module\""),
            "module field should be skipped in serialization"
        );
        assert!(
            json.contains("\"immutable\""),
            "immutable field should be included in serialization"
        );
        assert!(
            json.contains("\"system\""),
            "system field should be included in serialization"
        );
    }

    #[test]
    fn test_serialization_field_optimization() {
        // With every group empty, `skip_serializing_if` leaves an empty object.
        let empty_collection = VarCollection::default();
        let json = serde_json::to_string(&empty_collection).unwrap();
        assert_eq!(json, "{}");

        // A collection holding only system variables serializes only the `system` key.
        let vars =
            vec![VarDefinition::from(("public_var", "value")).with_mutability(Mutability::System)];
        let system_only = VarCollection::define(vars);
        let json_system = serde_json::to_string(&system_only).unwrap();

        assert!(json_system.contains("\"system\""));
        assert!(!json_system.contains("\"immutable\""));
        assert!(!json_system.contains("\"module\""));
    }

    #[test]
    fn test_empty_collection() {
        let empty_vars = vec![];
        let collection = VarCollection::define(empty_vars);

        assert_eq!(collection.immutable_vars().len(), 0);
        assert_eq!(collection.system_vars().len(), 0);
        assert_eq!(collection.module_vars().len(), 0);

        let dict = collection.value_dict();
        assert_eq!(dict.len(), 0);
    }

    #[test]
    fn test_duplicate_variable_names() {
        let vars = vec![
            VarDefinition::from(("duplicate", "first")).with_mutability(Mutability::Immutable),
            VarDefinition::from(("duplicate", "second")).with_mutability(Mutability::System),
            VarDefinition::from(("duplicate", "third")).with_mutability(Mutability::Module),
        ];

        let collection = VarCollection::define(vars);

        // Each group holds its own variable even though the names are identical.
        assert_eq!(collection.immutable_vars().len(), 1);
        assert_eq!(collection.system_vars().len(), 1);
        assert_eq!(collection.module_vars().len(), 1);

        // `value_dict` inserts immutable -> system -> module, so the module entry
        // overwrites the earlier ones with the same name.
        let dict = collection.value_dict();
        assert_eq!(dict.get("DUPLICATE"), Some(&ValueType::from("third")));
    }

    #[test]
    fn test_special_characters_in_names() {
        let vars = vec![
            VarDefinition::from(("normal_name", "normal")).with_mutability(Mutability::System),
            VarDefinition::from(("name-with-dashes", "dashed")).with_mutability(Mutability::System),
            VarDefinition::from(("name_with_underscores", "underscored"))
                .with_mutability(Mutability::System),
            VarDefinition::from(("name.with.dots", "dotted")).with_mutability(Mutability::System),
        ];

        let collection = VarCollection::define(vars);
        let dict = collection.value_dict();

        // Dashes, underscores and dots survive the upper-casing of the key.
        assert_eq!(dict.get("NORMAL_NAME"), Some(&ValueType::from("normal")));
        assert_eq!(
            dict.get("NAME-WITH-DASHES"),
            Some(&ValueType::from("dashed"))
        );
        assert_eq!(
            dict.get("NAME_WITH_UNDERSCORES"),
            Some(&ValueType::from("underscored"))
        );
        assert_eq!(dict.get("NAME.WITH.DOTS"), Some(&ValueType::from("dotted")));
    }

    #[test]
    fn test_default_collection() {
        let default_collection = VarCollection::default();

        assert_eq!(default_collection.immutable_vars().len(), 0);
        assert_eq!(default_collection.system_vars().len(), 0);
        assert_eq!(default_collection.module_vars().len(), 0);

        let dict = default_collection.value_dict();
        assert_eq!(dict.len(), 0);

        // The default (all-empty) collection serializes to an empty object.
        let json = serde_json::to_string(&default_collection).unwrap();
        assert_eq!(json, "{}");
    }
}