Skip to main content

khive_runtime/
registry.rs

1//! Objective registry for dynamic dispatch.
2//!
3//! Runtime infrastructure: named registration, lookup, defaults.
4//! Lives in khive-runtime (not khive-fold) per ADR-058.
5
6use std::collections::HashMap;
7use std::fmt;
8use std::sync::Arc;
9
10use parking_lot::RwLock;
11
12use khive_fold::objective::{
13    Objective, ObjectiveContext, ObjectiveError, ObjectiveResult, Selection,
14};
15
16/// A type-erased objective wrapper.
17pub struct RegisteredObjective<T: Send + Sync> {
18    pub name: String,
19    pub description: Option<String>,
20    objective: Box<dyn Objective<T>>,
21}
22
23impl<T: Send + Sync> fmt::Debug for RegisteredObjective<T> {
24    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25        f.debug_struct("RegisteredObjective")
26            .field("name", &self.name)
27            .field("description", &self.description)
28            .finish_non_exhaustive()
29    }
30}
31
32impl<T: Send + Sync> RegisteredObjective<T> {
33    pub fn new(name: impl Into<String>, objective: Box<dyn Objective<T>>) -> Self {
34        Self {
35            name: name.into(),
36            description: None,
37            objective,
38        }
39    }
40
41    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
42        self.description = Some(desc.into());
43        self
44    }
45
46    /// Raw score (no precision weighting). Use `select()` for ranked selection
47    /// that applies `score * precision` per ADR-059.
48    pub fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
49        self.objective.score(candidate, context)
50    }
51
52    pub fn select<'a>(
53        &self,
54        candidates: &'a [T],
55        context: &ObjectiveContext,
56    ) -> ObjectiveResult<Selection<&'a T>> {
57        self.objective
58            .select(candidates, context)
59            .into_iter()
60            .next()
61            .ok_or_else(|| ObjectiveError::NoMatch("No candidate selected".into()))
62    }
63}
64
65struct RegistryInner<T: Send + Sync> {
66    objectives: HashMap<String, Arc<RegisteredObjective<T>>>,
67    default: Option<String>,
68}
69
70/// Registry of named objectives.
71///
72/// Thread-safe: all operations are behind a single `RwLock`.
73pub struct ObjectiveRegistry<T: Send + Sync> {
74    inner: RwLock<RegistryInner<T>>,
75}
76
77impl<T: Send + Sync> fmt::Debug for ObjectiveRegistry<T> {
78    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79        let inner = self.inner.read();
80        f.debug_struct("ObjectiveRegistry")
81            .field("count", &inner.objectives.len())
82            .field("default", &inner.default)
83            .finish()
84    }
85}
86
87impl<T: Send + Sync> Default for ObjectiveRegistry<T> {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
93impl<T: Send + Sync> ObjectiveRegistry<T> {
94    pub fn new() -> Self {
95        Self {
96            inner: RwLock::new(RegistryInner {
97                objectives: HashMap::new(),
98                default: None,
99            }),
100        }
101    }
102
103    pub fn register(
104        &self,
105        name: impl Into<String>,
106        objective: Box<dyn Objective<T>>,
107    ) -> Option<Arc<RegisteredObjective<T>>> {
108        let name = name.into();
109        let registered = Arc::new(RegisteredObjective::new(name.clone(), objective));
110        self.inner.write().objectives.insert(name, registered)
111    }
112
113    pub fn register_with_desc(
114        &self,
115        name: impl Into<String>,
116        description: impl Into<String>,
117        objective: Box<dyn Objective<T>>,
118    ) -> Option<Arc<RegisteredObjective<T>>> {
119        let name = name.into();
120        let registered = Arc::new(
121            RegisteredObjective::new(name.clone(), objective).with_description(description),
122        );
123        self.inner.write().objectives.insert(name, registered)
124    }
125
126    pub fn set_default(&self, name: impl Into<String>) -> ObjectiveResult<()> {
127        let name = name.into();
128        let mut inner = self.inner.write();
129        if !inner.objectives.contains_key(&name) {
130            return Err(ObjectiveError::NotFound(name));
131        }
132        inner.default = Some(name);
133        Ok(())
134    }
135
136    pub fn get(&self, name: &str) -> ObjectiveResult<Arc<RegisteredObjective<T>>> {
137        self.inner
138            .read()
139            .objectives
140            .get(name)
141            .cloned()
142            .ok_or_else(|| ObjectiveError::NotFound(name.to_string()))
143    }
144
145    pub fn get_default(&self) -> ObjectiveResult<Arc<RegisteredObjective<T>>> {
146        let inner = self.inner.read();
147        match inner.default.as_ref() {
148            Some(name) => inner
149                .objectives
150                .get(name)
151                .cloned()
152                .ok_or_else(|| ObjectiveError::NotFound(name.clone())),
153            None => Err(ObjectiveError::NotFound("No default set".to_string())),
154        }
155    }
156
157    pub fn list(&self) -> Vec<String> {
158        let inner = self.inner.read();
159        let mut names: Vec<String> = inner.objectives.keys().cloned().collect();
160        names.sort();
161        names
162    }
163
164    pub fn contains(&self, name: &str) -> bool {
165        self.inner.read().objectives.contains_key(name)
166    }
167
168    /// Raw score via a named objective (no precision weighting).
169    pub fn score(
170        &self,
171        name: &str,
172        candidate: &T,
173        context: &ObjectiveContext,
174    ) -> ObjectiveResult<f64> {
175        let objective = self.get(name)?;
176        Ok(objective.score(candidate, context))
177    }
178
179    pub fn select<'a>(
180        &self,
181        name: &str,
182        candidates: &'a [T],
183        context: &ObjectiveContext,
184    ) -> ObjectiveResult<Selection<&'a T>> {
185        let objective = self.get(name)?;
186        objective
187            .select(candidates, context)
188            .into_iter()
189            .next()
190            .ok_or_else(|| ObjectiveError::NoMatch("No candidate selected".into()))
191    }
192
193    pub fn select_default<'a>(
194        &self,
195        candidates: &'a [T],
196        context: &ObjectiveContext,
197    ) -> ObjectiveResult<Selection<&'a T>> {
198        let objective = self.get_default()?;
199        objective
200            .select(candidates, context)
201            .into_iter()
202            .next()
203            .ok_or_else(|| ObjectiveError::NoMatch("No candidate selected".into()))
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use khive_fold::objective::objective_fn;

    #[test]
    fn register_and_get() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let old = registry.register("max", Box::new(obj));
        assert!(old.is_none());
        assert!(registry.contains("max"));
        assert!(!registry.contains("min"));
    }

    #[test]
    fn register_overwrites() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj1 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| -(*n as f64));
        assert!(registry.register("test", Box::new(obj1)).is_none());
        assert!(registry.register("test", Box::new(obj2)).is_some());

        let candidates = vec![1, 5, 3];
        let selection = registry
            .select("test", &candidates, &ObjectiveContext::new())
            .unwrap();
        assert_eq!(*selection.item, 1);
    }

    #[test]
    fn register_with_desc_stores_description() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        assert!(registry
            .register_with_desc("max", "prefers larger values", Box::new(obj))
            .is_none());

        let registered = registry.get("max").unwrap();
        assert_eq!(registered.name, "max");
        assert_eq!(
            registered.description.as_deref(),
            Some("prefers larger values")
        );
    }

    #[test]
    fn register_without_desc_has_no_description() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        registry.register("plain", Box::new(obj));

        let registered = registry.get("plain").unwrap();
        assert_eq!(registered.name, "plain");
        assert!(registered.description.is_none());
    }

    #[test]
    fn select_by_name() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        registry.register("max", Box::new(obj));

        let candidates = vec![1, 5, 3];
        let selection = registry
            .select("max", &candidates, &ObjectiveContext::new())
            .unwrap();
        assert_eq!(*selection.item, 5);
    }

    #[test]
    fn default_objective() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        registry.register("max", Box::new(obj));
        registry.set_default("max").unwrap();

        let candidates = vec![1, 5, 3];
        let selection = registry
            .select_default(&candidates, &ObjectiveContext::new())
            .unwrap();
        assert_eq!(*selection.item, 5);
    }

    #[test]
    fn default_follows_reregistered_name() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let max = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let min = objective_fn(|n: &i32, _ctx: &ObjectiveContext| -(*n as f64));
        registry.register("pick", Box::new(max));
        registry.set_default("pick").unwrap();
        // Re-registering under the default's name swaps what the default resolves to.
        assert!(registry.register("pick", Box::new(min)).is_some());

        let candidates = vec![1, 5, 3];
        let selection = registry
            .select_default(&candidates, &ObjectiveContext::new())
            .unwrap();
        assert_eq!(*selection.item, 1);
    }

    #[test]
    fn failed_set_default_keeps_previous_default() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        registry.register("max", Box::new(obj));
        registry.set_default("max").unwrap();

        assert!(registry.set_default("ghost").is_err());
        assert_eq!(registry.get_default().unwrap().name, "max");
    }

    #[test]
    fn list_objectives_sorted() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj1 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| -(*n as f64));
        let obj3 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| (*n as f64).abs());
        registry.register("zebra", Box::new(obj1));
        registry.register("alpha", Box::new(obj2));
        registry.register("middle", Box::new(obj3));

        let names = registry.list();
        assert_eq!(names, vec!["alpha", "middle", "zebra"]);
    }

    #[test]
    fn list_empty_registry() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        assert!(registry.list().is_empty());
    }

    #[test]
    fn get_nonexistent_returns_error() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let result = registry.get("nope");
        assert!(matches!(result, Err(ObjectiveError::NotFound(ref s)) if s == "nope"));
    }

    #[test]
    fn score_and_select_nonexistent_return_not_found() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let ctx = ObjectiveContext::new();

        let score = registry.score("nope", &1, &ctx);
        assert!(matches!(score, Err(ObjectiveError::NotFound(ref s)) if s == "nope"));

        let candidates = vec![1, 2, 3];
        let selection = registry.select("nope", &candidates, &ctx);
        assert!(matches!(selection, Err(ObjectiveError::NotFound(ref s)) if s == "nope"));
    }

    #[test]
    fn get_default_without_setting_returns_error() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let result = registry.get_default();
        assert!(matches!(result, Err(ObjectiveError::NotFound(_))));
    }

    #[test]
    fn select_default_without_setting_returns_error() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let candidates = vec![1, 2, 3];
        let result = registry.select_default(&candidates, &ObjectiveContext::new());
        assert!(matches!(result, Err(ObjectiveError::NotFound(_))));
    }

    #[test]
    fn set_default_nonexistent_returns_error() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let result = registry.set_default("ghost");
        assert!(matches!(result, Err(ObjectiveError::NotFound(ref s)) if s == "ghost"));
    }

    #[test]
    fn score_via_registry() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64 * 2.0);
        registry.register("double", Box::new(obj));

        let score = registry
            .score("double", &5, &ObjectiveContext::new())
            .unwrap();
        assert!((score - 10.0).abs() < 1e-12);
    }

    #[test]
    fn select_default_via_registry() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| -(*n as f64));
        registry.register("min", Box::new(obj));
        registry.set_default("min").unwrap();

        let candidates = vec![1, 5, 3];
        let selection = registry
            .select_default(&candidates, &ObjectiveContext::new())
            .unwrap();
        assert_eq!(*selection.item, 1);
    }

    #[test]
    fn debug_impls() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        registry.register("test", Box::new(obj));
        let debug = format!("{:?}", registry);
        assert!(debug.contains("ObjectiveRegistry"));
        assert!(debug.contains("count: 1"));

        let registered = registry.get("test").unwrap();
        let debug = format!("{:?}", registered);
        assert!(debug.contains("RegisteredObjective"));
        assert!(debug.contains("test"));
    }

    #[test]
    fn concurrent_read_write() {
        let registry = Arc::new(ObjectiveRegistry::<i32>::new());

        std::thread::scope(|s| {
            for i in 0..8 {
                let reg = Arc::clone(&registry);
                s.spawn(move || {
                    let name = format!("obj_{i}");
                    let obj =
                        objective_fn(move |n: &i32, _ctx: &ObjectiveContext| *n as f64 + i as f64);
                    reg.register(name.clone(), Box::new(obj));

                    assert!(reg.contains(&name));

                    let candidates = vec![1, 2, 3];
                    let _ = reg.select(&name, &candidates, &ObjectiveContext::new());
                });
            }
        });

        assert_eq!(registry.list().len(), 8);
    }
}