1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use crate::api::Result;
7use crate::task::{ActivityContext, OrchestrationContext};
8
9pub type OrchestratorResult = Result<Option<String>>;
11
12pub type OrchestratorFn = Arc<
17 dyn Fn(OrchestrationContext) -> Pin<Box<dyn Future<Output = OrchestratorResult> + Send>>
18 + Send
19 + Sync,
20>;
21
22pub type ActivityResult = Result<Option<String>>;
24
25pub type ActivityFn = Arc<
30 dyn Fn(ActivityContext, Option<String>) -> Pin<Box<dyn Future<Output = ActivityResult> + Send>>
31 + Send
32 + Sync,
33>;
34
35#[derive(Clone)]
41struct OrchestratorEntry {
42 f: OrchestratorFn,
43 version: Option<String>,
46 is_latest: bool,
47}
48
49pub struct Registry {
59 orchestrators: HashMap<String, Vec<OrchestratorEntry>>,
61 activities: HashMap<String, ActivityFn>,
62}
63
64impl Registry {
65 pub fn new() -> Self {
67 Self {
68 orchestrators: HashMap::new(),
69 activities: HashMap::new(),
70 }
71 }
72
73 fn push_orchestrator_entry(&mut self, name: &str, entry: OrchestratorEntry) {
76 self.orchestrators
77 .entry(name.to_string())
78 .or_default()
79 .push(entry);
80 }
81
82 pub fn add_named_orchestrator<F, Fut>(&mut self, name: &str, f: F)
89 where
90 F: Fn(OrchestrationContext) -> Fut + Send + Sync + 'static,
91 Fut: Future<Output = OrchestratorResult> + Send + 'static,
92 {
93 tracing::info!(orchestrator = %name, "Registering orchestrator");
94 let f: OrchestratorFn = Arc::new(move |ctx| {
95 Box::pin(f(ctx)) as Pin<Box<dyn Future<Output = OrchestratorResult> + Send>>
96 });
97 self.push_orchestrator_entry(
98 name,
99 OrchestratorEntry {
100 f,
101 version: None,
102 is_latest: false,
103 },
104 );
105 }
106
107 pub fn add_versioned_orchestrator<F, Fut>(&mut self, name: &str, version: &str, f: F)
123 where
124 F: Fn(OrchestrationContext) -> Fut + Send + Sync + 'static,
125 Fut: Future<Output = OrchestratorResult> + Send + 'static,
126 {
127 tracing::info!(orchestrator = %name, version = %version, "Registering versioned orchestrator");
128 let f: OrchestratorFn = Arc::new(move |ctx| {
129 Box::pin(f(ctx)) as Pin<Box<dyn Future<Output = OrchestratorResult> + Send>>
130 });
131 self.push_orchestrator_entry(
132 name,
133 OrchestratorEntry {
134 f,
135 version: Some(version.to_string()),
136 is_latest: false,
137 },
138 );
139 }
140
141 pub fn add_latest_orchestrator<F, Fut>(&mut self, name: &str, version: &str, f: F)
158 where
159 F: Fn(OrchestrationContext) -> Fut + Send + Sync + 'static,
160 Fut: Future<Output = OrchestratorResult> + Send + 'static,
161 {
162 tracing::info!(orchestrator = %name, version = %version, "Registering latest orchestrator");
163 let f: OrchestratorFn = Arc::new(move |ctx| {
164 Box::pin(f(ctx)) as Pin<Box<dyn Future<Output = OrchestratorResult> + Send>>
165 });
166 self.push_orchestrator_entry(
167 name,
168 OrchestratorEntry {
169 f,
170 version: Some(version.to_string()),
171 is_latest: true,
172 },
173 );
174 }
175
176 pub fn add_named_activity<F, Fut>(&mut self, name: &str, f: F)
178 where
179 F: Fn(ActivityContext, Option<String>) -> Fut + Send + Sync + 'static,
180 Fut: Future<Output = ActivityResult> + Send + 'static,
181 {
182 tracing::info!(activity = %name, "Registering activity");
183 let f: ActivityFn = Arc::new(move |ctx, input| {
184 Box::pin(f(ctx, input)) as Pin<Box<dyn Future<Output = ActivityResult> + Send>>
185 });
186 self.activities.insert(name.to_string(), f);
187 }
188
189 pub fn get_orchestrator(&self, name: &str) -> Option<&OrchestratorFn> {
203 self.get_orchestrator_version(name, None)
204 }
205
206 pub fn get_orchestrator_version(
210 &self,
211 name: &str,
212 version: Option<&str>,
213 ) -> Option<&OrchestratorFn> {
214 let entries = self.orchestrators.get(name)?;
215
216 if let Some(v) = version {
218 if let Some(entry) = entries.iter().find(|e| e.version.as_deref() == Some(v)) {
219 return Some(&entry.f);
220 }
221 }
222
223 if let Some(entry) = entries.iter().rev().find(|e| e.is_latest) {
225 return Some(&entry.f);
226 }
227
228 if let Some(entry) = entries.iter().find(|e| e.version.is_none()) {
230 return Some(&entry.f);
231 }
232
233 None
234 }
235
236 pub fn get_activity(&self, name: &str) -> Option<&ActivityFn> {
238 self.activities.get(name)
239 }
240}
241
242impl Default for Registry {
243 fn default() -> Self {
244 Self::new()
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    async fn dummy_orchestrator(_ctx: OrchestrationContext) -> OrchestratorResult {
        Ok(Some("\"done\"".to_string()))
    }

    async fn dummy_activity(_ctx: ActivityContext, _input: Option<String>) -> ActivityResult {
        Ok(Some("\"result\"".to_string()))
    }

    /// Builds an `OrchestrationContext` with fixed test values so registered
    /// orchestrators can be invoked directly.
    fn test_orchestration_ctx() -> OrchestrationContext {
        OrchestrationContext::new(
            "test".to_string(),
            "orch".to_string(),
            None,
            chrono::Utc::now(),
            false,
            &crate::worker::WorkerOptions::default(),
            0,
        )
    }

    /// Invokes a resolved orchestrator and returns its output. Panics if the
    /// orchestrator returned an error.
    async fn invoke_orchestrator(f: &OrchestratorFn) -> Option<String> {
        f(test_orchestration_ctx()).await.unwrap()
    }

    #[test]
    fn test_register_and_lookup_orchestrator() {
        let mut reg = Registry::new();
        reg.add_named_orchestrator("my_orch", dummy_orchestrator);
        assert!(reg.get_orchestrator("my_orch").is_some());
        assert!(reg.get_orchestrator("missing").is_none());
    }

    #[test]
    fn test_register_and_lookup_activity() {
        let mut reg = Registry::new();
        reg.add_named_activity("my_act", dummy_activity);
        assert!(reg.get_activity("my_act").is_some());
        assert!(reg.get_activity("missing").is_none());
    }

    #[tokio::test]
    async fn test_invoke_orchestrator() {
        let mut reg = Registry::new();
        reg.add_named_orchestrator("orch", dummy_orchestrator);

        let f = reg.get_orchestrator("orch").unwrap();
        assert_eq!(invoke_orchestrator(f).await, Some("\"done\"".to_string()));
    }

    #[tokio::test]
    async fn test_invoke_activity() {
        let mut reg = Registry::new();
        reg.add_named_activity("act", dummy_activity);

        let f = reg.get_activity("act").unwrap();
        let ctx = ActivityContext::new("test".to_string(), 0, String::new());
        let result = (f)(ctx, None).await;
        assert_eq!(result.unwrap(), Some("\"result\"".to_string()));
    }

    /// An exact version request must resolve to the implementation registered
    /// under that version. The test checks this by invoking the result, not
    /// just with `is_some()`.
    #[tokio::test]
    async fn test_versioned_exact_match() {
        let mut reg = Registry::new();
        reg.add_versioned_orchestrator("orch", "v1", |_| async move { Ok(Some("v1".to_string())) });
        reg.add_versioned_orchestrator("orch", "v2", |_| async move { Ok(Some("v2".to_string())) });

        let v1 = reg.get_orchestrator_version("orch", Some("v1")).expect("v1 is registered");
        assert_eq!(invoke_orchestrator(v1).await, Some("v1".to_string()));

        let v2 = reg.get_orchestrator_version("orch", Some("v2")).expect("v2 is registered");
        assert_eq!(invoke_orchestrator(v2).await, Some("v2".to_string()));

        // Plain versioned entries are never used as a fallback.
        assert!(reg.get_orchestrator_version("orch", Some("v3")).is_none());
    }

    #[tokio::test]
    async fn test_latest_is_fallback() {
        let mut reg = Registry::new();
        reg.add_versioned_orchestrator("orch", "v1", |_| async move { Ok(Some("v1".to_string())) });
        reg.add_latest_orchestrator("orch", "v2", |_| async move { Ok(Some("v2".to_string())) });

        // An unknown version falls back to the entry registered as latest.
        let unknown = reg
            .get_orchestrator_version("orch", Some("v99"))
            .expect("unknown version falls back to latest");
        assert_eq!(invoke_orchestrator(unknown).await, Some("v2".to_string()));

        // A lookup without a version resolves to latest as well.
        let unversioned = reg.get_orchestrator("orch").expect("no version falls back to latest");
        assert_eq!(invoke_orchestrator(unversioned).await, Some("v2".to_string()));

        // An exact match still takes precedence over latest.
        let v1 = reg.get_orchestrator_version("orch", Some("v1")).expect("v1 is registered");
        assert_eq!(invoke_orchestrator(v1).await, Some("v1".to_string()));
    }

    #[tokio::test]
    async fn test_unversioned_fallback() {
        let mut reg = Registry::new();
        reg.add_named_orchestrator("orch", dummy_orchestrator);

        let f = reg
            .get_orchestrator_version("orch", Some("any"))
            .expect("any version falls back to the unversioned entry");
        assert_eq!(invoke_orchestrator(f).await, Some("\"done\"".to_string()));
    }
}