1use crate::constraint::ConstraintChannel;
9use serde::{Deserialize, Serialize};
10use std::collections::BTreeMap;
11
12pub struct ConstraintRegistry {
14 channels: BTreeMap<String, RegisteredChannel>,
15}
16
17pub struct RegisteredChannel {
19 pub channel: Box<dyn ConstraintChannel>,
21 pub enabled: bool,
23 pub config: ChannelConfig,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct ChannelConfig {
30 pub weight: f64,
32 pub safe_threshold: Option<f64>,
34 pub caution_threshold: Option<f64>,
35 pub block_threshold: Option<f64>,
36 pub description: String,
38 pub domain_tag: String,
40}
41
42impl Default for ChannelConfig {
43 fn default() -> Self {
44 Self {
45 weight: 1.0,
46 safe_threshold: None,
47 caution_threshold: None,
48 block_threshold: None,
49 description: String::new(),
50 domain_tag: "custom".to_string(),
51 }
52 }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct ChannelSummary {
58 pub id: String,
59 pub name: String,
60 pub enabled: bool,
61 pub config: ChannelConfig,
62 pub dimensions: Vec<String>,
63}
64
65impl ConstraintRegistry {
66 pub fn new() -> Self {
67 Self {
68 channels: BTreeMap::new(),
69 }
70 }
71
72 pub fn register(
74 &mut self,
75 id: &str,
76 channel: Box<dyn ConstraintChannel>,
77 config: ChannelConfig,
78 ) {
79 self.channels.insert(
80 id.to_string(),
81 RegisteredChannel {
82 channel,
83 enabled: true,
84 config,
85 },
86 );
87 }
88
89 pub fn unregister(&mut self, id: &str) -> bool {
91 self.channels.remove(id).is_some()
92 }
93
94 pub fn set_enabled(&mut self, id: &str, enabled: bool) -> bool {
96 if let Some(ch) = self.channels.get_mut(id) {
97 ch.enabled = enabled;
98 true
99 } else {
100 false
101 }
102 }
103
104 pub fn update_config(&mut self, id: &str, config: ChannelConfig) -> bool {
106 if let Some(ch) = self.channels.get_mut(id) {
107 ch.config = config;
108 true
109 } else {
110 false
111 }
112 }
113
114 pub fn active_channels(&self) -> Vec<(&str, &RegisteredChannel)> {
116 self.channels
117 .iter()
118 .filter(|(_, c)| c.enabled)
119 .map(|(id, c)| (id.as_str(), c))
120 .collect()
121 }
122
123 pub fn len(&self) -> usize {
125 self.channels.len()
126 }
127
128 pub fn is_empty(&self) -> bool {
130 self.channels.is_empty()
131 }
132
133 pub fn active_count(&self) -> usize {
135 self.channels.values().filter(|c| c.enabled).count()
136 }
137
138 pub fn list_all(&self) -> Vec<ChannelSummary> {
140 self.channels
141 .iter()
142 .map(|(id, rc)| ChannelSummary {
143 id: id.clone(),
144 name: rc.channel.name().to_string(),
145 enabled: rc.enabled,
146 config: rc.config.clone(),
147 dimensions: rc.channel.dimension_names(),
148 })
149 .collect()
150 }
151
152 pub fn export(&self) -> Vec<ChannelSummary> {
154 self.list_all()
155 }
156
157 pub fn contains(&self, id: &str) -> bool {
159 self.channels.contains_key(id)
160 }
161}
162
163impl Default for ConstraintRegistry {
164 fn default() -> Self {
165 Self::new()
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::KernelResult;

    /// Minimal channel whose evaluation always yields a fixed margin.
    struct FixedChannel {
        name: String,
        margin: f64,
    }

    impl ConstraintChannel for FixedChannel {
        fn name(&self) -> &str {
            &self.name
        }

        fn evaluate(&self, _state: &[f64]) -> KernelResult<f64> {
            Ok(self.margin)
        }

        fn dimension_names(&self) -> Vec<String> {
            vec!["x".into()]
        }
    }

    /// Boxes a [`FixedChannel`] with the given name and margin.
    fn make_channel(name: &str, margin: f64) -> Box<dyn ConstraintChannel> {
        Box::new(FixedChannel {
            name: name.to_string(),
            margin,
        })
    }

    /// Default config with the domain tag and description derived from `domain`.
    fn make_config(domain: &str) -> ChannelConfig {
        ChannelConfig {
            domain_tag: domain.to_string(),
            description: format!("{} channel", domain),
            ..Default::default()
        }
    }

    #[test]
    fn empty_registry_reports_empty() {
        let reg = ConstraintRegistry::default();

        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
        assert_eq!(reg.active_count(), 0);
        assert!(reg.list_all().is_empty());
        assert!(reg.active_channels().is_empty());
    }

    #[test]
    fn register_and_list() {
        let mut reg = ConstraintRegistry::new();
        reg.register("ch1", make_channel("ch1", 0.8), make_config("test"));
        reg.register("ch2", make_channel("ch2", 0.5), make_config("test"));

        assert!(!reg.is_empty());
        assert_eq!(reg.len(), 2);
        assert_eq!(reg.active_count(), 2);

        let summaries = reg.list_all();
        assert_eq!(summaries.len(), 2);
        assert_eq!(summaries[0].id, "ch1");
        assert_eq!(summaries[1].id, "ch2");
        assert_eq!(summaries[0].dimensions, vec!["x".to_string()]);
    }

    #[test]
    fn register_replaces_existing_id() {
        let mut reg = ConstraintRegistry::new();
        reg.register("ch1", make_channel("first", 0.8), make_config("test"));
        assert!(reg.set_enabled("ch1", false));

        // Re-registering the same id swaps the entry and starts it enabled again.
        reg.register("ch1", make_channel("second", 0.2), make_config("test"));

        assert_eq!(reg.len(), 1);
        assert_eq!(reg.active_count(), 1);
        assert_eq!(reg.list_all()[0].name, "second");
    }

    #[test]
    fn unregister_channel() {
        let mut reg = ConstraintRegistry::new();
        reg.register("ch1", make_channel("ch1", 0.8), make_config("test"));

        assert!(reg.unregister("ch1"));
        // A second removal finds nothing.
        assert!(!reg.unregister("ch1"));
        assert_eq!(reg.len(), 0);
    }

    #[test]
    fn enable_disable_channel() {
        let mut reg = ConstraintRegistry::new();
        reg.register("ch1", make_channel("ch1", 0.8), make_config("test"));

        assert_eq!(reg.active_count(), 1);

        assert!(reg.set_enabled("ch1", false));
        assert_eq!(reg.active_count(), 0);
        // Disabling must not remove the channel.
        assert_eq!(reg.len(), 1);

        assert!(reg.set_enabled("ch1", true));
        assert_eq!(reg.active_count(), 1);
    }

    #[test]
    fn update_config() {
        let mut reg = ConstraintRegistry::new();
        reg.register("ch1", make_channel("ch1", 0.8), make_config("test"));

        let new_config = ChannelConfig {
            weight: 2.0,
            safe_threshold: Some(0.7),
            ..make_config("updated")
        };
        assert!(reg.update_config("ch1", new_config));

        let summaries = reg.list_all();
        assert!((summaries[0].config.weight - 2.0).abs() < f64::EPSILON);
        assert_eq!(summaries[0].config.safe_threshold, Some(0.7));
        assert_eq!(summaries[0].config.domain_tag, "updated");
        assert_eq!(summaries[0].config.description, "updated channel");
    }

    #[test]
    fn unknown_id_is_reported() {
        let mut reg = ConstraintRegistry::new();
        reg.register("ch1", make_channel("ch1", 0.8), make_config("test"));

        assert!(!reg.set_enabled("missing", false));
        assert!(!reg.update_config("missing", make_config("other")));
        assert!(!reg.unregister("missing"));

        // The existing entry is untouched by the failed operations.
        assert_eq!(reg.len(), 1);
        assert_eq!(reg.active_count(), 1);
        assert_eq!(reg.list_all()[0].config.domain_tag, "test");
    }

    #[test]
    fn active_channels_excludes_disabled() {
        let mut reg = ConstraintRegistry::new();
        reg.register("ch1", make_channel("ch1", 0.8), make_config("test"));
        reg.register("ch2", make_channel("ch2", 0.5), make_config("test"));

        assert!(reg.set_enabled("ch1", false));
        let active = reg.active_channels();
        assert_eq!(active.len(), 1);
        assert_eq!(active[0].0, "ch2");
    }

    #[test]
    fn export_matches_list_all() {
        let mut reg = ConstraintRegistry::new();
        reg.register("ch1", make_channel("ch1", 0.8), make_config("test"));
        reg.register("ch2", make_channel("ch2", 0.5), make_config("test"));
        assert!(reg.set_enabled("ch2", false));

        let listed = reg.list_all();
        let exported = reg.export();
        assert_eq!(exported.len(), listed.len());
        for (e, l) in exported.iter().zip(listed.iter()) {
            assert_eq!(e.id, l.id);
            assert_eq!(e.name, l.name);
            assert_eq!(e.enabled, l.enabled);
        }
    }

    #[test]
    fn contains_check() {
        let mut reg = ConstraintRegistry::new();
        reg.register("ch1", make_channel("ch1", 0.8), make_config("test"));

        assert!(reg.contains("ch1"));
        assert!(!reg.contains("ch2"));
    }

    #[test]
    fn deterministic_ordering() {
        let mut reg = ConstraintRegistry::new();
        reg.register("z_channel", make_channel("z", 0.1), make_config("test"));
        reg.register("a_channel", make_channel("a", 0.9), make_config("test"));
        reg.register("m_channel", make_channel("m", 0.5), make_config("test"));

        // Listing follows id order, not insertion order.
        let summaries = reg.list_all();
        assert_eq!(summaries[0].id, "a_channel");
        assert_eq!(summaries[1].id, "m_channel");
        assert_eq!(summaries[2].id, "z_channel");
    }
}