Skip to main content

khive_runtime/
registry.rs

1//! Objective registry for dynamic dispatch.
2//!
3//! Runtime infrastructure: named registration, lookup, defaults.
4//! Lives in khive-runtime (not khive-fold) because it depends on runtime types.
5
6use std::collections::HashMap;
7use std::fmt;
8use std::sync::Arc;
9
10use parking_lot::RwLock;
11
12use khive_fold::objective::{
13    Objective, ObjectiveContext, ObjectiveError, ObjectiveResult, Selection,
14};
15
16/// A type-erased objective wrapper.
17pub struct RegisteredObjective<T: Send + Sync> {
18    pub name: String,
19    pub description: Option<String>,
20    objective: Box<dyn Objective<T>>,
21}
22
23impl<T: Send + Sync> fmt::Debug for RegisteredObjective<T> {
24    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25        f.debug_struct("RegisteredObjective")
26            .field("name", &self.name)
27            .field("description", &self.description)
28            .finish_non_exhaustive()
29    }
30}
31
32impl<T: Send + Sync> RegisteredObjective<T> {
33    /// Create a new registered objective with the given name and no description.
34    pub fn new(name: impl Into<String>, objective: Box<dyn Objective<T>>) -> Self {
35        Self {
36            name: name.into(),
37            description: None,
38            objective,
39        }
40    }
41
42    /// Attach a human-readable description to this registered objective.
43    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
44        self.description = Some(desc.into());
45        self
46    }
47
48    /// Raw score (no precision weighting). Use `select()` for ranked selection with salience boost.
49    pub fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
50        self.objective.score(candidate, context)
51    }
52
53    /// Select the best candidate according to this objective's ranking.
54    ///
55    /// Returns `ObjectiveError::NoMatch` when the candidate slice is empty.
56    pub fn select<'a>(
57        &self,
58        candidates: &'a [T],
59        context: &ObjectiveContext,
60    ) -> ObjectiveResult<Selection<&'a T>> {
61        self.objective
62            .select(candidates, context)
63            .into_iter()
64            .next()
65            .ok_or_else(|| ObjectiveError::NoMatch("No candidate selected".into()))
66    }
67}
68
69struct RegistryInner<T: Send + Sync> {
70    objectives: HashMap<String, Arc<RegisteredObjective<T>>>,
71    default: Option<String>,
72}
73
74/// Registry of named objectives.
75///
76/// Thread-safe: all operations are behind a single `RwLock`.
77pub struct ObjectiveRegistry<T: Send + Sync> {
78    inner: RwLock<RegistryInner<T>>,
79}
80
81impl<T: Send + Sync> fmt::Debug for ObjectiveRegistry<T> {
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        let inner = self.inner.read();
84        f.debug_struct("ObjectiveRegistry")
85            .field("count", &inner.objectives.len())
86            .field("default", &inner.default)
87            .finish()
88    }
89}
90
91impl<T: Send + Sync> Default for ObjectiveRegistry<T> {
92    fn default() -> Self {
93        Self::new()
94    }
95}
96
97impl<T: Send + Sync> ObjectiveRegistry<T> {
98    /// Create an empty registry with no objectives and no default set.
99    pub fn new() -> Self {
100        Self {
101            inner: RwLock::new(RegistryInner {
102                objectives: HashMap::new(),
103                default: None,
104            }),
105        }
106    }
107
108    /// Register a named objective; returns the previous entry if the name was already taken.
109    pub fn register(
110        &self,
111        name: impl Into<String>,
112        objective: Box<dyn Objective<T>>,
113    ) -> Option<Arc<RegisteredObjective<T>>> {
114        let name = name.into();
115        let registered = Arc::new(RegisteredObjective::new(name.clone(), objective));
116        self.inner.write().objectives.insert(name, registered)
117    }
118
119    /// Register a named objective with a human-readable description.
120    pub fn register_with_desc(
121        &self,
122        name: impl Into<String>,
123        description: impl Into<String>,
124        objective: Box<dyn Objective<T>>,
125    ) -> Option<Arc<RegisteredObjective<T>>> {
126        let name = name.into();
127        let registered = Arc::new(
128            RegisteredObjective::new(name.clone(), objective).with_description(description),
129        );
130        self.inner.write().objectives.insert(name, registered)
131    }
132
133    /// Set the named objective as the registry default.
134    ///
135    /// Returns `ObjectiveError::NotFound` when no objective with that name is registered.
136    pub fn set_default(&self, name: impl Into<String>) -> ObjectiveResult<()> {
137        let name = name.into();
138        let mut inner = self.inner.write();
139        if !inner.objectives.contains_key(&name) {
140            return Err(ObjectiveError::NotFound(name));
141        }
142        inner.default = Some(name);
143        Ok(())
144    }
145
146    /// Retrieve a registered objective by name.
147    ///
148    /// Returns `ObjectiveError::NotFound` when the name is not registered.
149    pub fn get(&self, name: &str) -> ObjectiveResult<Arc<RegisteredObjective<T>>> {
150        self.inner
151            .read()
152            .objectives
153            .get(name)
154            .cloned()
155            .ok_or_else(|| ObjectiveError::NotFound(name.to_string()))
156    }
157
158    /// Retrieve the current default objective.
159    ///
160    /// Returns `ObjectiveError::NotFound` when no default has been set.
161    pub fn get_default(&self) -> ObjectiveResult<Arc<RegisteredObjective<T>>> {
162        let inner = self.inner.read();
163        match inner.default.as_ref() {
164            Some(name) => inner
165                .objectives
166                .get(name)
167                .cloned()
168                .ok_or_else(|| ObjectiveError::NotFound(name.clone())),
169            None => Err(ObjectiveError::NotFound("No default set".to_string())),
170        }
171    }
172
173    /// List all registered objective names in sorted order.
174    pub fn list(&self) -> Vec<String> {
175        let inner = self.inner.read();
176        let mut names: Vec<String> = inner.objectives.keys().cloned().collect();
177        names.sort();
178        names
179    }
180
181    /// Return `true` if an objective with the given name is registered.
182    pub fn contains(&self, name: &str) -> bool {
183        self.inner.read().objectives.contains_key(name)
184    }
185
186    /// Raw score via a named objective (no precision weighting).
187    pub fn score(
188        &self,
189        name: &str,
190        candidate: &T,
191        context: &ObjectiveContext,
192    ) -> ObjectiveResult<f64> {
193        let objective = self.get(name)?;
194        Ok(objective.score(candidate, context))
195    }
196
197    /// Select the best candidate using the named objective.
198    ///
199    /// Returns `ObjectiveError::NoMatch` when the candidate slice is empty.
200    pub fn select<'a>(
201        &self,
202        name: &str,
203        candidates: &'a [T],
204        context: &ObjectiveContext,
205    ) -> ObjectiveResult<Selection<&'a T>> {
206        let objective = self.get(name)?;
207        objective
208            .select(candidates, context)
209            .into_iter()
210            .next()
211            .ok_or_else(|| ObjectiveError::NoMatch("No candidate selected".into()))
212    }
213
214    /// Select the best candidate using the current default objective.
215    ///
216    /// Returns `ObjectiveError::NotFound` when no default is set.
217    pub fn select_default<'a>(
218        &self,
219        candidates: &'a [T],
220        context: &ObjectiveContext,
221    ) -> ObjectiveResult<Selection<&'a T>> {
222        let objective = self.get_default()?;
223        objective
224            .select(candidates, context)
225            .into_iter()
226            .next()
227            .ok_or_else(|| ObjectiveError::NoMatch("No candidate selected".into()))
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use khive_fold::objective::objective_fn;

    #[test]
    fn register_and_get() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let old = registry.register("max", Box::new(obj));
        assert!(old.is_none());
        assert!(registry.contains("max"));
        assert!(!registry.contains("min"));
    }

    #[test]
    fn register_with_desc_records_description() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let old = registry.register_with_desc("max", "prefers larger values", Box::new(obj));
        assert!(old.is_none());

        let registered = registry.get("max").unwrap();
        assert_eq!(registered.name, "max");
        assert_eq!(
            registered.description.as_deref(),
            Some("prefers larger values")
        );
    }

    #[test]
    fn register_overwrites() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj1 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| -(*n as f64));
        assert!(registry.register("test", Box::new(obj1)).is_none());
        assert!(registry.register("test", Box::new(obj2)).is_some());

        let candidates = vec![1, 5, 3];
        let selection = registry
            .select("test", &candidates, &ObjectiveContext::new())
            .unwrap();
        assert_eq!(*selection.item, 1);
    }

    #[test]
    fn overwriting_default_name_updates_default() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let max = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let min = objective_fn(|n: &i32, _ctx: &ObjectiveContext| -(*n as f64));
        registry.register("pick", Box::new(max));
        registry.set_default("pick").unwrap();
        registry.register("pick", Box::new(min));

        let candidates = vec![1, 5, 3];
        let selection = registry
            .select_default(&candidates, &ObjectiveContext::new())
            .unwrap();
        assert_eq!(*selection.item, 1);
    }

    #[test]
    fn select_by_name() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        registry.register("max", Box::new(obj));

        let candidates = vec![1, 5, 3];
        let selection = registry
            .select("max", &candidates, &ObjectiveContext::new())
            .unwrap();
        assert_eq!(*selection.item, 5);
    }

    #[test]
    fn select_unknown_name_returns_not_found() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let candidates = vec![1, 2, 3];
        let result = registry.select("missing", &candidates, &ObjectiveContext::new());
        assert!(matches!(result, Err(ObjectiveError::NotFound(ref s)) if s == "missing"));
    }

    #[test]
    fn select_empty_candidates_returns_no_match() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        registry.register("max", Box::new(obj));

        let candidates: Vec<i32> = Vec::new();
        let result = registry.select("max", &candidates, &ObjectiveContext::new());
        assert!(matches!(result, Err(ObjectiveError::NoMatch(_))));
    }

    #[test]
    fn default_objective() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        registry.register("max", Box::new(obj));
        registry.set_default("max").unwrap();

        let candidates = vec![1, 5, 3];
        let selection = registry
            .select_default(&candidates, &ObjectiveContext::new())
            .unwrap();
        assert_eq!(*selection.item, 5);
    }

    #[test]
    fn list_objectives_sorted() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj1 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| -(*n as f64));
        let obj3 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| (*n as f64).abs());
        registry.register("zebra", Box::new(obj1));
        registry.register("alpha", Box::new(obj2));
        registry.register("middle", Box::new(obj3));

        let names = registry.list();
        assert_eq!(names, vec!["alpha", "middle", "zebra"]);
    }

    #[test]
    fn get_nonexistent_returns_error() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let result = registry.get("nope");
        assert!(matches!(result, Err(ObjectiveError::NotFound(ref s)) if s == "nope"));
    }

    #[test]
    fn get_default_without_setting_returns_error() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let result = registry.get_default();
        assert!(matches!(result, Err(ObjectiveError::NotFound(_))));
    }

    #[test]
    fn select_default_without_default_returns_not_found() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let candidates = vec![1, 2, 3];
        let result = registry.select_default(&candidates, &ObjectiveContext::new());
        assert!(matches!(result, Err(ObjectiveError::NotFound(_))));
    }

    #[test]
    fn set_default_nonexistent_returns_error() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let result = registry.set_default("ghost");
        assert!(matches!(result, Err(ObjectiveError::NotFound(ref s)) if s == "ghost"));
    }

    #[test]
    fn score_via_registry() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64 * 2.0);
        registry.register("double", Box::new(obj));

        let score = registry
            .score("double", &5, &ObjectiveContext::new())
            .unwrap();
        assert!((score - 10.0).abs() < 1e-12);
    }

    #[test]
    fn select_default_via_registry() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| -(*n as f64));
        registry.register("min", Box::new(obj));
        registry.set_default("min").unwrap();

        let candidates = vec![1, 5, 3];
        let selection = registry
            .select_default(&candidates, &ObjectiveContext::new())
            .unwrap();
        assert_eq!(*selection.item, 1);
    }

    #[test]
    fn debug_impls() {
        let registry: ObjectiveRegistry<i32> = ObjectiveRegistry::new();
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        registry.register("test", Box::new(obj));
        let debug = format!("{:?}", registry);
        assert!(debug.contains("ObjectiveRegistry"));
        assert!(debug.contains("count: 1"));

        let registered = registry.get("test").unwrap();
        let debug = format!("{:?}", registered);
        assert!(debug.contains("RegisteredObjective"));
        assert!(debug.contains("test"));
    }

    #[test]
    fn concurrent_read_write() {
        let registry = ObjectiveRegistry::<i32>::new();

        // `thread::scope` lets each worker borrow the registry directly; no `Arc` needed.
        std::thread::scope(|s| {
            for i in 0..8 {
                let reg = &registry;
                s.spawn(move || {
                    let name = format!("obj_{i}");
                    let obj = objective_fn(move |n: &i32, _ctx: &ObjectiveContext| {
                        *n as f64 + i as f64
                    });
                    reg.register(name.clone(), Box::new(obj));

                    assert!(reg.contains(&name));

                    // Each worker owns a distinct name, so its own objective is what runs.
                    let candidates = vec![1, 2, 3];
                    let selection = reg
                        .select(&name, &candidates, &ObjectiveContext::new())
                        .unwrap();
                    assert_eq!(*selection.item, 3);
                });
            }
        });

        assert_eq!(registry.list().len(), 8);
    }
}
387}