commit_wizard/engine/config/
rules.rs1use std::collections::BTreeSet;
2
3use serde::{Deserialize, Serialize};
4use toml::Value;
5
6use crate::engine::error::{ErrorCode, Result};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
11#[serde(transparent)]
12pub struct RulesConfig(pub Value);
13
14impl Default for RulesConfig {
15 fn default() -> Self {
16 Self(Value::Table(Default::default()))
17 }
18}
19
20impl RulesConfig {
21 pub fn get(&self, path: &str) -> Option<&Value> {
22 let mut current = &self.0;
23
24 for segment in path.split('.') {
25 current = current.get(segment)?;
26 }
27
28 Some(current)
29 }
30 pub fn from_toml_str(input: &str) -> Result<Self> {
31 toml::from_str(input).map_err(|err| {
32 ErrorCode::ConfigInvalid
33 .error()
34 .with_context("error", err.to_string())
35 })
36 }
37 pub fn resolve_ref(&self, reference: &str) -> Option<&Value> {
38 let path = reference.strip_prefix("@rules.")?;
39 self.get(path)
40 }
41
42 pub fn resolve<T>(&self, reference: &str) -> Result<T>
43 where
44 T: serde::de::DeserializeOwned,
45 {
46 let value = self.resolve_ref(reference).cloned().ok_or_else(|| {
47 ErrorCode::ConfigReferenceInvalid
48 .error()
49 .with_context("reference", reference)
50 })?;
51
52 value.try_into().map_err(|err| {
53 ErrorCode::ConfigReferenceInvalid
54 .error()
55 .with_context("reference", reference)
56 .with_context("error", err.to_string())
57 })
58 }
59
60 pub fn is_reference(value: &str) -> bool {
61 value.starts_with("@rules.")
62 }
63
64 pub fn resolve_string(&self, value: &str) -> Result<String> {
68 if Self::is_reference(value) {
69 let resolved: String = self.resolve(value)?;
70 Ok(resolved)
71 } else {
72 Ok(value.to_string())
73 }
74 }
75
76 pub fn resolve_value_refs(&self, value: &mut Value) -> Result<()> {
77 let mut visiting = BTreeSet::new();
78 self.resolve_value_refs_inner(value, "$", &mut visiting)
79 }
80
81 fn resolve_value_refs_inner(
82 &self,
83 value: &mut Value,
84 path: &str,
85 visiting: &mut BTreeSet<String>,
86 ) -> Result<()> {
87 match value {
88 Value::String(s) if Self::is_reference(s) => {
89 let reference = s.clone();
90 let resolved = self.resolve_ref(&reference).cloned().ok_or_else(|| {
91 ErrorCode::ConfigReferenceInvalid
92 .error()
93 .with_context("reference", reference.clone())
94 .with_context("path", path)
95 })?;
96
97 if !visiting.insert(reference.clone()) {
98 return Err(ErrorCode::ConfigReferenceInvalid
99 .error()
100 .with_context("reference", reference)
101 .with_context("path", path)
102 .with_context("reason", "cyclic reference detected"));
103 }
104
105 let mut resolved = resolved;
106 self.resolve_value_refs_inner(&mut resolved, path, visiting)?;
107 visiting.remove(&reference);
108
109 *value = resolved;
110 Ok(())
111 }
112 Value::Array(items) => {
113 for (index, item) in items.iter_mut().enumerate() {
114 let child_path = format!("{path}[{index}]");
115 self.resolve_value_refs_inner(item, &child_path, visiting)?;
116 }
117 Ok(())
118 }
119 Value::Table(table) => {
120 for (key, item) in table.iter_mut() {
121 let child_path = format!("{path}.{key}");
122 self.resolve_value_refs_inner(item, &child_path, visiting)?;
123 }
124 Ok(())
125 }
126 _ => Ok(()),
127 }
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a TOML snippet into a `RulesConfig`, panicking on invalid input.
    fn parse_rules(source: &str) -> RulesConfig {
        toml::from_str(source).unwrap()
    }

    /// A chain `a -> b -> c` must collapse to the value stored at `c`.
    #[test]
    fn resolves_nested_rule_references() {
        let rules = parse_rules(
            r#"
            a = "@rules.b"
            b = "@rules.c"
            c = "final"
            "#,
        );

        let mut value = Value::String(String::from("@rules.a"));
        let outcome = rules.resolve_value_refs(&mut value);

        outcome.unwrap();
        assert_eq!(value, Value::String(String::from("final")));
    }

    /// Two rules pointing at each other must be rejected rather than looping.
    #[test]
    fn rejects_cyclic_rule_references() {
        let rules = parse_rules(
            r#"
            a = "@rules.b"
            b = "@rules.a"
            "#,
        );

        let mut value = Value::String(String::from("@rules.a"));
        let err = rules.resolve_value_refs(&mut value).unwrap_err();

        assert_eq!(err.code, ErrorCode::ConfigReferenceInvalid);
    }
}