1use std::collections::HashMap;
7use std::fmt;
8use std::sync::Arc;
9
10use parking_lot::RwLock;
11
12use khive_fold::objective::{
13 Objective, ObjectiveContext, ObjectiveError, ObjectiveResult, Selection,
14};
15
16pub struct RegisteredObjective<T: Send + Sync> {
18 pub name: String,
19 pub description: Option<String>,
20 objective: Box<dyn Objective<T>>,
21}
22
23impl<T: Send + Sync> fmt::Debug for RegisteredObjective<T> {
24 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25 f.debug_struct("RegisteredObjective")
26 .field("name", &self.name)
27 .field("description", &self.description)
28 .finish_non_exhaustive()
29 }
30}
31
32impl<T: Send + Sync> RegisteredObjective<T> {
33 pub fn new(name: impl Into<String>, objective: Box<dyn Objective<T>>) -> Self {
35 Self {
36 name: name.into(),
37 description: None,
38 objective,
39 }
40 }
41
42 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
44 self.description = Some(desc.into());
45 self
46 }
47
48 pub fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
50 self.objective.score(candidate, context)
51 }
52
53 pub fn select<'a>(
57 &self,
58 candidates: &'a [T],
59 context: &ObjectiveContext,
60 ) -> ObjectiveResult<Selection<&'a T>> {
61 self.objective
62 .select(candidates, context)
63 .into_iter()
64 .next()
65 .ok_or_else(|| ObjectiveError::NoMatch("No candidate selected".into()))
66 }
67}
68
69struct RegistryInner<T: Send + Sync> {
70 objectives: HashMap<String, Arc<RegisteredObjective<T>>>,
71 default: Option<String>,
72}
73
74pub struct ObjectiveRegistry<T: Send + Sync> {
78 inner: RwLock<RegistryInner<T>>,
79}
80
81impl<T: Send + Sync> fmt::Debug for ObjectiveRegistry<T> {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 let inner = self.inner.read();
84 f.debug_struct("ObjectiveRegistry")
85 .field("count", &inner.objectives.len())
86 .field("default", &inner.default)
87 .finish()
88 }
89}
90
91impl<T: Send + Sync> Default for ObjectiveRegistry<T> {
92 fn default() -> Self {
93 Self::new()
94 }
95}
96
97impl<T: Send + Sync> ObjectiveRegistry<T> {
98 pub fn new() -> Self {
100 Self {
101 inner: RwLock::new(RegistryInner {
102 objectives: HashMap::new(),
103 default: None,
104 }),
105 }
106 }
107
108 pub fn register(
110 &self,
111 name: impl Into<String>,
112 objective: Box<dyn Objective<T>>,
113 ) -> Option<Arc<RegisteredObjective<T>>> {
114 let name = name.into();
115 let registered = Arc::new(RegisteredObjective::new(name.clone(), objective));
116 self.inner.write().objectives.insert(name, registered)
117 }
118
119 pub fn register_with_desc(
121 &self,
122 name: impl Into<String>,
123 description: impl Into<String>,
124 objective: Box<dyn Objective<T>>,
125 ) -> Option<Arc<RegisteredObjective<T>>> {
126 let name = name.into();
127 let registered = Arc::new(
128 RegisteredObjective::new(name.clone(), objective).with_description(description),
129 );
130 self.inner.write().objectives.insert(name, registered)
131 }
132
133 pub fn set_default(&self, name: impl Into<String>) -> ObjectiveResult<()> {
137 let name = name.into();
138 let mut inner = self.inner.write();
139 if !inner.objectives.contains_key(&name) {
140 return Err(ObjectiveError::NotFound(name));
141 }
142 inner.default = Some(name);
143 Ok(())
144 }
145
146 pub fn get(&self, name: &str) -> ObjectiveResult<Arc<RegisteredObjective<T>>> {
150 self.inner
151 .read()
152 .objectives
153 .get(name)
154 .cloned()
155 .ok_or_else(|| ObjectiveError::NotFound(name.to_string()))
156 }
157
158 pub fn get_default(&self) -> ObjectiveResult<Arc<RegisteredObjective<T>>> {
162 let inner = self.inner.read();
163 match inner.default.as_ref() {
164 Some(name) => inner
165 .objectives
166 .get(name)
167 .cloned()
168 .ok_or_else(|| ObjectiveError::NotFound(name.clone())),
169 None => Err(ObjectiveError::NotFound("No default set".to_string())),
170 }
171 }
172
173 pub fn list(&self) -> Vec<String> {
175 let inner = self.inner.read();
176 let mut names: Vec<String> = inner.objectives.keys().cloned().collect();
177 names.sort();
178 names
179 }
180
181 pub fn contains(&self, name: &str) -> bool {
183 self.inner.read().objectives.contains_key(name)
184 }
185
186 pub fn score(
188 &self,
189 name: &str,
190 candidate: &T,
191 context: &ObjectiveContext,
192 ) -> ObjectiveResult<f64> {
193 let objective = self.get(name)?;
194 Ok(objective.score(candidate, context))
195 }
196
197 pub fn select<'a>(
201 &self,
202 name: &str,
203 candidates: &'a [T],
204 context: &ObjectiveContext,
205 ) -> ObjectiveResult<Selection<&'a T>> {
206 let objective = self.get(name)?;
207 objective
208 .select(candidates, context)
209 .into_iter()
210 .next()
211 .ok_or_else(|| ObjectiveError::NoMatch("No candidate selected".into()))
212 }
213
214 pub fn select_default<'a>(
218 &self,
219 candidates: &'a [T],
220 context: &ObjectiveContext,
221 ) -> ObjectiveResult<Selection<&'a T>> {
222 let objective = self.get_default()?;
223 objective
224 .select(candidates, context)
225 .into_iter()
226 .next()
227 .ok_or_else(|| ObjectiveError::NoMatch("No candidate selected".into()))
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use khive_fold::objective::objective_fn;

    #[test]
    fn register_and_get() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let old = registry.register("max", Box::new(obj));
        assert!(old.is_none());
        assert!(registry.contains("max"));
        assert!(!registry.contains("min"));
    }

    #[test]
    fn register_overwrites() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj1 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| -(*n as f64));
        assert!(registry.register("test", Box::new(obj1)).is_none());
        assert!(registry.register("test", Box::new(obj2)).is_some());

        let candidates = vec![1, 5, 3];
        let selection = registry
            .select("test", &candidates, &ObjectiveContext::new())
            .unwrap();
        assert_eq!(*selection.item, 1);
    }

    #[test]
    fn register_with_desc_stores_description() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        assert!(registry
            .register_with_desc("max", "picks the largest value", Box::new(obj))
            .is_none());

        let registered = registry.get("max").unwrap();
        assert_eq!(registered.name, "max");
        assert_eq!(
            registered.description.as_deref(),
            Some("picks the largest value")
        );
    }

    #[test]
    fn select_by_name() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        registry.register("max", Box::new(obj));

        let candidates = vec![1, 5, 3];
        let selection = registry
            .select("max", &candidates, &ObjectiveContext::new())
            .unwrap();
        assert_eq!(*selection.item, 5);
    }

    #[test]
    fn select_unknown_name_returns_not_found() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let candidates = vec![1, 2, 3];
        let result = registry.select("missing", &candidates, &ObjectiveContext::new());
        assert!(matches!(result, Err(ObjectiveError::NotFound(ref s)) if s == "missing"));
    }

    #[test]
    fn default_objective() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        registry.register("max", Box::new(obj));
        registry.set_default("max").unwrap();

        let candidates = vec![1, 5, 3];
        let selection = registry
            .select_default(&candidates, &ObjectiveContext::new())
            .unwrap();
        assert_eq!(*selection.item, 5);
    }

    #[test]
    fn select_default_without_default_returns_not_found() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let candidates = vec![1, 2, 3];
        let result = registry.select_default(&candidates, &ObjectiveContext::new());
        assert!(matches!(result, Err(ObjectiveError::NotFound(_))));
    }

    #[test]
    fn list_objectives_sorted() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj1 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| -(*n as f64));
        let obj3 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| (*n as f64).abs());
        registry.register("zebra", Box::new(obj1));
        registry.register("alpha", Box::new(obj2));
        registry.register("middle", Box::new(obj3));

        let names = registry.list();
        assert_eq!(names, vec!["alpha", "middle", "zebra"]);
    }

    #[test]
    fn get_nonexistent_returns_error() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let result = registry.get("nope");
        assert!(matches!(result, Err(ObjectiveError::NotFound(ref s)) if s == "nope"));
    }

    #[test]
    fn get_default_without_setting_returns_error() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let result = registry.get_default();
        assert!(matches!(result, Err(ObjectiveError::NotFound(_))));
    }

    #[test]
    fn set_default_nonexistent_returns_error() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let result = registry.set_default("ghost");
        assert!(matches!(result, Err(ObjectiveError::NotFound(ref s)) if s == "ghost"));
    }

    #[test]
    fn score_via_registry() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64 * 2.0);
        registry.register("double", Box::new(obj));

        let score = registry
            .score("double", &5, &ObjectiveContext::new())
            .unwrap();
        assert!((score - 10.0).abs() < 1e-12);
    }

    #[test]
    fn score_unknown_name_returns_not_found() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let result = registry.score("absent", &5, &ObjectiveContext::new());
        assert!(matches!(result, Err(ObjectiveError::NotFound(ref s)) if s == "absent"));
    }

    #[test]
    fn select_default_via_registry() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| -(*n as f64));
        registry.register("min", Box::new(obj));
        registry.set_default("min").unwrap();

        let candidates = vec![1, 5, 3];
        let selection = registry
            .select_default(&candidates, &ObjectiveContext::new())
            .unwrap();
        assert_eq!(*selection.item, 1);
    }

    #[test]
    fn debug_impls() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        registry.register("test", Box::new(obj));
        let debug = format!("{:?}", registry);
        assert!(debug.contains("ObjectiveRegistry"));
        assert!(debug.contains("count: 1"));

        let registered = registry.get("test").unwrap();
        let debug = format!("{:?}", registered);
        assert!(debug.contains("RegisteredObjective"));
        assert!(debug.contains("test"));
    }

    #[test]
    fn concurrent_read_write() {
        let registry = Arc::new(ObjectiveRegistry::<i32>::new());

        std::thread::scope(|s| {
            for i in 0..8 {
                let reg = Arc::clone(&registry);
                s.spawn(move || {
                    let name = format!("obj_{i}");
                    let obj =
                        objective_fn(move |n: &i32, _ctx: &ObjectiveContext| *n as f64 + i as f64);
                    reg.register(name.clone(), Box::new(obj));

                    assert!(reg.contains(&name));

                    let candidates = vec![1, 2, 3];
                    let _ = reg.select(&name, &candidates, &ObjectiveContext::new());
                });
            }
        });

        assert_eq!(registry.list().len(), 8);
    }
}