1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use crate::api::Result;
7use crate::task::{ActivityContext, OrchestrationContext};
8
9pub type OrchestratorResult = Result<Option<String>>;
11
12pub type OrchestratorFn = Arc<
17 dyn Fn(OrchestrationContext) -> Pin<Box<dyn Future<Output = OrchestratorResult> + Send>>
18 + Send
19 + Sync,
20>;
21
22pub type ActivityResult = Result<Option<String>>;
24
25pub type ActivityFn = Arc<
30 dyn Fn(ActivityContext, Option<String>) -> Pin<Box<dyn Future<Output = ActivityResult> + Send>>
31 + Send
32 + Sync,
33>;
34
35#[derive(Clone)]
41struct OrchestratorEntry {
42 f: OrchestratorFn,
43 version: Option<String>,
46 is_latest: bool,
47}
48
49pub struct Registry {
59 orchestrators: HashMap<String, Vec<OrchestratorEntry>>,
61 activities: HashMap<String, ActivityFn>,
62}
63
64impl std::fmt::Debug for Registry {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 struct FnDebug;
67 impl std::fmt::Debug for FnDebug {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 f.write_str("<fn>")
70 }
71 }
72 type EntryView<'a> = (Option<&'a String>, bool, FnDebug);
73 let orchestrators: HashMap<&String, Vec<EntryView<'_>>> = self
74 .orchestrators
75 .iter()
76 .map(|(name, entries)| {
77 let rendered = entries
78 .iter()
79 .map(|e| (e.version.as_ref(), e.is_latest, FnDebug))
80 .collect();
81 (name, rendered)
82 })
83 .collect();
84 let activities: Vec<&String> = self.activities.keys().collect();
85 f.debug_struct("Registry")
86 .field("orchestrators", &orchestrators)
87 .field("activities", &activities)
88 .finish()
89 }
90}
91
92impl Registry {
93 pub fn new() -> Self {
95 Self {
96 orchestrators: HashMap::new(),
97 activities: HashMap::new(),
98 }
99 }
100
101 fn push_orchestrator_entry(&mut self, name: &str, entry: OrchestratorEntry) {
104 self.orchestrators
105 .entry(name.to_string())
106 .or_default()
107 .push(entry);
108 }
109
110 pub fn add_named_orchestrator<F, Fut>(&mut self, name: &str, f: F)
117 where
118 F: Fn(OrchestrationContext) -> Fut + Send + Sync + 'static,
119 Fut: Future<Output = OrchestratorResult> + Send + 'static,
120 {
121 tracing::info!(orchestrator = %name, "Registering orchestrator");
122 let f: OrchestratorFn = Arc::new(move |ctx| {
123 Box::pin(f(ctx)) as Pin<Box<dyn Future<Output = OrchestratorResult> + Send>>
124 });
125 self.push_orchestrator_entry(
126 name,
127 OrchestratorEntry {
128 f,
129 version: None,
130 is_latest: false,
131 },
132 );
133 }
134
135 pub fn add_versioned_orchestrator<F, Fut>(&mut self, name: &str, version: &str, f: F)
151 where
152 F: Fn(OrchestrationContext) -> Fut + Send + Sync + 'static,
153 Fut: Future<Output = OrchestratorResult> + Send + 'static,
154 {
155 tracing::info!(orchestrator = %name, version = %version, "Registering versioned orchestrator");
156 let f: OrchestratorFn = Arc::new(move |ctx| {
157 Box::pin(f(ctx)) as Pin<Box<dyn Future<Output = OrchestratorResult> + Send>>
158 });
159 self.push_orchestrator_entry(
160 name,
161 OrchestratorEntry {
162 f,
163 version: Some(version.to_string()),
164 is_latest: false,
165 },
166 );
167 }
168
169 pub fn add_latest_orchestrator<F, Fut>(&mut self, name: &str, version: &str, f: F)
186 where
187 F: Fn(OrchestrationContext) -> Fut + Send + Sync + 'static,
188 Fut: Future<Output = OrchestratorResult> + Send + 'static,
189 {
190 tracing::info!(orchestrator = %name, version = %version, "Registering latest orchestrator");
191 let f: OrchestratorFn = Arc::new(move |ctx| {
192 Box::pin(f(ctx)) as Pin<Box<dyn Future<Output = OrchestratorResult> + Send>>
193 });
194 self.push_orchestrator_entry(
195 name,
196 OrchestratorEntry {
197 f,
198 version: Some(version.to_string()),
199 is_latest: true,
200 },
201 );
202 }
203
204 pub fn add_named_activity<F, Fut>(&mut self, name: &str, f: F)
206 where
207 F: Fn(ActivityContext, Option<String>) -> Fut + Send + Sync + 'static,
208 Fut: Future<Output = ActivityResult> + Send + 'static,
209 {
210 tracing::info!(activity = %name, "Registering activity");
211 let f: ActivityFn = Arc::new(move |ctx, input| {
212 Box::pin(f(ctx, input)) as Pin<Box<dyn Future<Output = ActivityResult> + Send>>
213 });
214 self.activities.insert(name.to_string(), f);
215 }
216
217 pub fn get_orchestrator(&self, name: &str) -> Option<&OrchestratorFn> {
231 self.get_orchestrator_version(name, None)
232 }
233
234 pub fn get_orchestrator_version(
238 &self,
239 name: &str,
240 version: Option<&str>,
241 ) -> Option<&OrchestratorFn> {
242 let entries = self.orchestrators.get(name)?;
243
244 if let Some(v) = version
246 && let Some(entry) = entries.iter().find(|e| e.version.as_deref() == Some(v))
247 {
248 return Some(&entry.f);
249 }
250
251 if let Some(entry) = entries.iter().rev().find(|e| e.is_latest) {
253 return Some(&entry.f);
254 }
255
256 if let Some(entry) = entries.iter().find(|e| e.version.is_none()) {
258 return Some(&entry.f);
259 }
260
261 None
262 }
263
264 pub fn get_activity(&self, name: &str) -> Option<&ActivityFn> {
266 self.activities.get(name)
267 }
268}
269
270impl Default for Registry {
271 fn default() -> Self {
272 Self::new()
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    async fn dummy_orchestrator(_ctx: OrchestrationContext) -> OrchestratorResult {
        Ok(Some("\"done\"".to_string()))
    }

    async fn dummy_activity(_ctx: ActivityContext, _input: Option<String>) -> ActivityResult {
        Ok(Some("\"result\"".to_string()))
    }

    /// True when both lookups succeeded and resolved to the very same
    /// registered function (same `Arc` allocation), i.e. the same entry.
    fn same_fn(a: Option<&OrchestratorFn>, b: Option<&OrchestratorFn>) -> bool {
        match (a, b) {
            (Some(a), Some(b)) => Arc::ptr_eq(a, b),
            _ => false,
        }
    }

    #[test]
    fn test_register_and_lookup_orchestrator() {
        let mut reg = Registry::new();
        reg.add_named_orchestrator("my_orch", dummy_orchestrator);
        assert!(reg.get_orchestrator("my_orch").is_some());
        assert!(reg.get_orchestrator("missing").is_none());
    }

    #[test]
    fn test_register_and_lookup_activity() {
        let mut reg = Registry::new();
        reg.add_named_activity("my_act", dummy_activity);
        assert!(reg.get_activity("my_act").is_some());
        assert!(reg.get_activity("missing").is_none());
    }

    #[tokio::test]
    async fn test_invoke_orchestrator() {
        let mut reg = Registry::new();
        reg.add_named_orchestrator("orch", dummy_orchestrator);

        let f = reg.get_orchestrator("orch").unwrap();
        let ctx = OrchestrationContext::new(
            "test".to_string(),
            "orch".to_string(),
            None,
            chrono::Utc::now(),
            false,
            &crate::worker::WorkerOptions::default(),
            0,
        );
        let result = (f)(ctx).await;
        assert_eq!(result.unwrap(), Some("\"done\"".to_string()));
    }

    #[tokio::test]
    async fn test_invoke_activity() {
        let mut reg = Registry::new();
        reg.add_named_activity("act", dummy_activity);

        let f = reg.get_activity("act").unwrap();
        let ctx = ActivityContext::new("test".to_string(), 0, String::new());
        let result = (f)(ctx, None).await;
        assert_eq!(result.unwrap(), Some("\"result\"".to_string()));
    }

    #[test]
    fn test_versioned_exact_match() {
        let mut reg = Registry::new();
        reg.add_versioned_orchestrator("orch", "v1", |_| async move { Ok(Some("v1".to_string())) });
        reg.add_versioned_orchestrator("orch", "v2", |_| async move { Ok(Some("v2".to_string())) });

        assert!(reg.get_orchestrator_version("orch", Some("v1")).is_some());
        assert!(reg.get_orchestrator_version("orch", Some("v2")).is_some());
        // Exact lookups must resolve to distinct entries.
        assert!(!same_fn(
            reg.get_orchestrator_version("orch", Some("v1")),
            reg.get_orchestrator_version("orch", Some("v2")),
        ));
        assert!(reg.get_orchestrator_version("orch", Some("v3")).is_none());
        // With no latest and no unversioned entry there is no default.
        assert!(reg.get_orchestrator("orch").is_none());
    }

    #[test]
    fn test_latest_is_fallback() {
        let mut reg = Registry::new();
        reg.add_versioned_orchestrator("orch", "v1", |_| async move { Ok(Some("v1".to_string())) });
        reg.add_latest_orchestrator("orch", "v2", |_| async move { Ok(Some("v2".to_string())) });

        let v2 = reg.get_orchestrator_version("orch", Some("v2"));
        // Unknown version and no version both fall back to the latest (v2), not v1.
        assert!(same_fn(reg.get_orchestrator_version("orch", Some("v99")), v2));
        assert!(same_fn(reg.get_orchestrator("orch"), v2));
        // The older version is still reachable by exact lookup.
        assert!(!same_fn(reg.get_orchestrator_version("orch", Some("v1")), v2));
    }

    #[test]
    fn test_unversioned_fallback() {
        let mut reg = Registry::new();
        reg.add_named_orchestrator("orch", dummy_orchestrator);

        assert!(same_fn(
            reg.get_orchestrator_version("orch", Some("any")),
            reg.get_orchestrator("orch"),
        ));
    }

    #[test]
    fn test_latest_preferred_over_unversioned() {
        let mut reg = Registry::new();
        reg.add_named_orchestrator("orch", dummy_orchestrator);
        reg.add_latest_orchestrator("orch", "v2", |_| async move { Ok(Some("v2".to_string())) });

        assert!(same_fn(
            reg.get_orchestrator("orch"),
            reg.get_orchestrator_version("orch", Some("v2")),
        ));
    }

    #[test]
    fn test_most_recent_latest_wins() {
        let mut reg = Registry::new();
        reg.add_latest_orchestrator("orch", "v1", |_| async move { Ok(Some("v1".to_string())) });
        reg.add_latest_orchestrator("orch", "v2", |_| async move { Ok(Some("v2".to_string())) });

        let v1 = reg.get_orchestrator_version("orch", Some("v1"));
        let v2 = reg.get_orchestrator_version("orch", Some("v2"));
        assert!(same_fn(reg.get_orchestrator("orch"), v2));
        assert!(!same_fn(v1, v2));
    }
}