cognis_core/wrappers/
configurable.rs1use std::collections::HashMap;
24use std::marker::PhantomData;
25use std::sync::Arc;
26
27use async_trait::async_trait;
28
29use crate::runnable::{Runnable, RunnableConfig};
30use crate::Result;
31
/// Per-call selector telling a [`Configurable`] which alternative to run.
///
/// Insert one into `RunnableConfig::extras`; a [`Configurable`] looks its `name` up among
/// the alternatives registered with [`Configurable::alternative`] and falls back to its
/// default runnable when the name is not registered.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ConfigKey {
    /// Name of the alternative to select.
    pub name: String,
}

impl ConfigKey {
    /// Builds a key that targets the alternative registered under `name`.
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        ConfigKey { name }
    }
}
45
46pub struct Configurable<I, O> {
48 default: Arc<dyn Runnable<I, O>>,
49 alternatives: HashMap<String, Arc<dyn Runnable<I, O>>>,
50 _phantom: PhantomData<fn(I) -> O>,
51}
52
53impl<I, O> Configurable<I, O>
54where
55 I: Send + 'static,
56 O: Send + 'static,
57{
58 pub fn new(default: Arc<dyn Runnable<I, O>>) -> Self {
60 Self {
61 default,
62 alternatives: HashMap::new(),
63 _phantom: PhantomData,
64 }
65 }
66
67 pub fn alternative(mut self, name: impl Into<String>, r: Arc<dyn Runnable<I, O>>) -> Self {
70 self.alternatives.insert(name.into(), r);
71 self
72 }
73}
74
75#[async_trait]
76impl<I, O> Runnable<I, O> for Configurable<I, O>
77where
78 I: Send + 'static,
79 O: Send + 'static,
80{
81 async fn invoke(&self, input: I, config: RunnableConfig) -> Result<O> {
82 let chosen = config
83 .extras
84 .get::<ConfigKey>()
85 .and_then(|k| self.alternatives.get(&k.name))
86 .cloned()
87 .unwrap_or_else(|| self.default.clone());
88 chosen.invoke(input, config).await
89 }
90 fn name(&self) -> &str {
91 "Configurable"
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that ignores its input and always yields a fixed value.
    struct Const(u32);

    #[async_trait]
    impl Runnable<u32, u32> for Const {
        async fn invoke(&self, _: u32, _: RunnableConfig) -> Result<u32> {
            Ok(self.0)
        }
    }

    /// Test double that adds a fixed offset to its input, proving the input is forwarded.
    struct AddOffset(u32);

    #[async_trait]
    impl Runnable<u32, u32> for AddOffset {
        async fn invoke(&self, input: u32, _: RunnableConfig) -> Result<u32> {
            Ok(input + self.0)
        }
    }

    /// Default `Const(1)` plus alternatives `"alt" -> Const(2)` and `"other" -> Const(3)`.
    fn sample() -> Configurable<u32, u32> {
        Configurable::new(Arc::new(Const(1)))
            .alternative("alt", Arc::new(Const(2)))
            .alternative("other", Arc::new(Const(3)))
    }

    /// A default config carrying a `ConfigKey` for `name`.
    fn keyed(name: &str) -> RunnableConfig {
        let mut cfg = RunnableConfig::default();
        cfg.extras.insert(ConfigKey::new(name));
        cfg
    }

    #[tokio::test]
    async fn picks_default_when_no_key() {
        assert_eq!(sample().invoke(0, RunnableConfig::default()).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn picks_alternative_when_keyed() {
        let c = sample();
        assert_eq!(c.invoke(0, keyed("alt")).await.unwrap(), 2);
        assert_eq!(c.invoke(0, keyed("other")).await.unwrap(), 3);
    }

    #[tokio::test]
    async fn unknown_key_falls_back_to_default() {
        // No alternatives registered at all.
        let bare: Configurable<u32, u32> = Configurable::new(Arc::new(Const(1)));
        assert_eq!(bare.invoke(0, keyed("missing")).await.unwrap(), 1);
        // Alternatives registered, but under different names.
        assert_eq!(sample().invoke(0, keyed("missing")).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn forwards_input_to_chosen_runnable() {
        let c: Configurable<u32, u32> = Configurable::new(Arc::new(AddOffset(10)))
            .alternative("alt", Arc::new(AddOffset(100)));
        assert_eq!(c.invoke(5, RunnableConfig::default()).await.unwrap(), 15);
        assert_eq!(c.invoke(5, keyed("alt")).await.unwrap(), 105);
    }

    #[tokio::test]
    async fn later_registration_replaces_earlier() {
        let c: Configurable<u32, u32> = Configurable::new(Arc::new(Const(1)))
            .alternative("alt", Arc::new(Const(2)))
            .alternative("alt", Arc::new(Const(3)));
        assert_eq!(c.invoke(0, keyed("alt")).await.unwrap(), 3);
    }

    #[test]
    fn reports_its_name() {
        assert_eq!(sample().name(), "Configurable");
    }
}