1use std::collections::HashMap;
7use std::fmt;
8use std::sync::Arc;
9
10use parking_lot::RwLock;
11
12use khive_fold::objective::{
13 Objective, ObjectiveContext, ObjectiveError, ObjectiveResult, Selection,
14};
15
16pub struct RegisteredObjective<T: Send + Sync> {
18 pub name: String,
19 pub description: Option<String>,
20 objective: Box<dyn Objective<T>>,
21}
22
23impl<T: Send + Sync> fmt::Debug for RegisteredObjective<T> {
24 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25 f.debug_struct("RegisteredObjective")
26 .field("name", &self.name)
27 .field("description", &self.description)
28 .finish_non_exhaustive()
29 }
30}
31
32impl<T: Send + Sync> RegisteredObjective<T> {
33 pub fn new(name: impl Into<String>, objective: Box<dyn Objective<T>>) -> Self {
34 Self {
35 name: name.into(),
36 description: None,
37 objective,
38 }
39 }
40
41 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
42 self.description = Some(desc.into());
43 self
44 }
45
46 pub fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
49 self.objective.score(candidate, context)
50 }
51
52 pub fn select<'a>(
53 &self,
54 candidates: &'a [T],
55 context: &ObjectiveContext,
56 ) -> ObjectiveResult<Selection<&'a T>> {
57 self.objective
58 .select(candidates, context)
59 .into_iter()
60 .next()
61 .ok_or_else(|| ObjectiveError::NoMatch("No candidate selected".into()))
62 }
63}
64
65struct RegistryInner<T: Send + Sync> {
66 objectives: HashMap<String, Arc<RegisteredObjective<T>>>,
67 default: Option<String>,
68}
69
70pub struct ObjectiveRegistry<T: Send + Sync> {
74 inner: RwLock<RegistryInner<T>>,
75}
76
77impl<T: Send + Sync> fmt::Debug for ObjectiveRegistry<T> {
78 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79 let inner = self.inner.read();
80 f.debug_struct("ObjectiveRegistry")
81 .field("count", &inner.objectives.len())
82 .field("default", &inner.default)
83 .finish()
84 }
85}
86
87impl<T: Send + Sync> Default for ObjectiveRegistry<T> {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93impl<T: Send + Sync> ObjectiveRegistry<T> {
94 pub fn new() -> Self {
95 Self {
96 inner: RwLock::new(RegistryInner {
97 objectives: HashMap::new(),
98 default: None,
99 }),
100 }
101 }
102
103 pub fn register(
104 &self,
105 name: impl Into<String>,
106 objective: Box<dyn Objective<T>>,
107 ) -> Option<Arc<RegisteredObjective<T>>> {
108 let name = name.into();
109 let registered = Arc::new(RegisteredObjective::new(name.clone(), objective));
110 self.inner.write().objectives.insert(name, registered)
111 }
112
113 pub fn register_with_desc(
114 &self,
115 name: impl Into<String>,
116 description: impl Into<String>,
117 objective: Box<dyn Objective<T>>,
118 ) -> Option<Arc<RegisteredObjective<T>>> {
119 let name = name.into();
120 let registered = Arc::new(
121 RegisteredObjective::new(name.clone(), objective).with_description(description),
122 );
123 self.inner.write().objectives.insert(name, registered)
124 }
125
126 pub fn set_default(&self, name: impl Into<String>) -> ObjectiveResult<()> {
127 let name = name.into();
128 let mut inner = self.inner.write();
129 if !inner.objectives.contains_key(&name) {
130 return Err(ObjectiveError::NotFound(name));
131 }
132 inner.default = Some(name);
133 Ok(())
134 }
135
136 pub fn get(&self, name: &str) -> ObjectiveResult<Arc<RegisteredObjective<T>>> {
137 self.inner
138 .read()
139 .objectives
140 .get(name)
141 .cloned()
142 .ok_or_else(|| ObjectiveError::NotFound(name.to_string()))
143 }
144
145 pub fn get_default(&self) -> ObjectiveResult<Arc<RegisteredObjective<T>>> {
146 let inner = self.inner.read();
147 match inner.default.as_ref() {
148 Some(name) => inner
149 .objectives
150 .get(name)
151 .cloned()
152 .ok_or_else(|| ObjectiveError::NotFound(name.clone())),
153 None => Err(ObjectiveError::NotFound("No default set".to_string())),
154 }
155 }
156
157 pub fn list(&self) -> Vec<String> {
158 let inner = self.inner.read();
159 let mut names: Vec<String> = inner.objectives.keys().cloned().collect();
160 names.sort();
161 names
162 }
163
164 pub fn contains(&self, name: &str) -> bool {
165 self.inner.read().objectives.contains_key(name)
166 }
167
168 pub fn score(
170 &self,
171 name: &str,
172 candidate: &T,
173 context: &ObjectiveContext,
174 ) -> ObjectiveResult<f64> {
175 let objective = self.get(name)?;
176 Ok(objective.score(candidate, context))
177 }
178
179 pub fn select<'a>(
180 &self,
181 name: &str,
182 candidates: &'a [T],
183 context: &ObjectiveContext,
184 ) -> ObjectiveResult<Selection<&'a T>> {
185 let objective = self.get(name)?;
186 objective
187 .select(candidates, context)
188 .into_iter()
189 .next()
190 .ok_or_else(|| ObjectiveError::NoMatch("No candidate selected".into()))
191 }
192
193 pub fn select_default<'a>(
194 &self,
195 candidates: &'a [T],
196 context: &ObjectiveContext,
197 ) -> ObjectiveResult<Selection<&'a T>> {
198 let objective = self.get_default()?;
199 objective
200 .select(candidates, context)
201 .into_iter()
202 .next()
203 .ok_or_else(|| ObjectiveError::NoMatch("No candidate selected".into()))
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use khive_fold::objective::objective_fn;

    #[test]
    fn register_and_get() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let old = registry.register("max", Box::new(obj));
        assert!(old.is_none());
        assert!(registry.contains("max"));
        assert!(!registry.contains("min"));
    }

    #[test]
    fn register_overwrites() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj1 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| -(*n as f64));
        assert!(registry.register("test", Box::new(obj1)).is_none());
        assert!(registry.register("test", Box::new(obj2)).is_some());

        let candidates = vec![1, 5, 3];
        let selection = registry
            .select("test", &candidates, &ObjectiveContext::new())
            .unwrap();
        assert_eq!(*selection.item, 1);
    }

    #[test]
    fn register_with_desc_records_description() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        assert!(registry
            .register_with_desc("max", "largest value wins", Box::new(obj))
            .is_none());

        let registered = registry.get("max").unwrap();
        assert_eq!(registered.name, "max");
        assert_eq!(registered.description.as_deref(), Some("largest value wins"));
    }

    #[test]
    fn select_by_name() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        registry.register("max", Box::new(obj));

        let candidates = vec![1, 5, 3];
        let selection = registry
            .select("max", &candidates, &ObjectiveContext::new())
            .unwrap();
        assert_eq!(*selection.item, 5);
    }

    #[test]
    fn select_unknown_name_returns_not_found() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let candidates = vec![1, 2, 3];
        let result = registry.select("missing", &candidates, &ObjectiveContext::new());
        assert!(matches!(result, Err(ObjectiveError::NotFound(ref s)) if s == "missing"));
    }

    #[test]
    fn default_objective() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        registry.register("max", Box::new(obj));
        registry.set_default("max").unwrap();

        let candidates = vec![1, 5, 3];
        let selection = registry
            .select_default(&candidates, &ObjectiveContext::new())
            .unwrap();
        assert_eq!(*selection.item, 5);
    }

    #[test]
    fn list_objectives_sorted() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj1 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| -(*n as f64));
        let obj3 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| (*n as f64).abs());
        registry.register("zebra", Box::new(obj1));
        registry.register("alpha", Box::new(obj2));
        registry.register("middle", Box::new(obj3));

        let names = registry.list();
        assert_eq!(names, vec!["alpha", "middle", "zebra"]);
    }

    #[test]
    fn get_nonexistent_returns_error() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let result = registry.get("nope");
        assert!(matches!(result, Err(ObjectiveError::NotFound(ref s)) if s == "nope"));
    }

    #[test]
    fn get_default_without_setting_returns_error() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let result = registry.get_default();
        assert!(matches!(result, Err(ObjectiveError::NotFound(_))));
    }

    #[test]
    fn set_default_nonexistent_returns_error() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let result = registry.set_default("ghost");
        assert!(matches!(result, Err(ObjectiveError::NotFound(ref s)) if s == "ghost"));
    }

    #[test]
    fn score_via_registry() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64 * 2.0);
        registry.register("double", Box::new(obj));

        let score = registry
            .score("double", &5, &ObjectiveContext::new())
            .unwrap();
        assert!((score - 10.0).abs() < 1e-12);
    }

    #[test]
    fn select_default_via_registry() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| -(*n as f64));
        registry.register("min", Box::new(obj));
        registry.set_default("min").unwrap();

        let candidates = vec![1, 5, 3];
        let selection = registry
            .select_default(&candidates, &ObjectiveContext::new())
            .unwrap();
        assert_eq!(*selection.item, 1);
    }

    #[test]
    fn debug_impls() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        registry.register("test", Box::new(obj));
        let debug = format!("{:?}", registry);
        assert!(debug.contains("ObjectiveRegistry"));
        assert!(debug.contains("count: 1"));

        let registered = registry.get("test").unwrap();
        let debug = format!("{:?}", registered);
        assert!(debug.contains("RegisteredObjective"));
        assert!(debug.contains("test"));
    }

    #[test]
    fn concurrent_read_write() {
        let registry = ObjectiveRegistry::<i32>::new();

        // Scoped threads may borrow `registry` directly, so no `Arc` is needed;
        // all threads are joined before `scope` returns.
        std::thread::scope(|s| {
            for i in 0..8 {
                let reg = &registry;
                s.spawn(move || {
                    let name = format!("obj_{i}");
                    let obj =
                        objective_fn(move |n: &i32, _ctx: &ObjectiveContext| *n as f64 + i as f64);
                    reg.register(name.clone(), Box::new(obj));

                    assert!(reg.contains(&name));

                    let candidates = vec![1, 2, 3];
                    let _ = reg.select(&name, &candidates, &ObjectiveContext::new());
                });
            }
        });

        assert_eq!(registry.list().len(), 8);
    }
}