Skip to main content

dapr_durabletask/worker/
registry.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use crate::api::Result;
7use crate::task::{ActivityContext, OrchestrationContext};
8
9/// Type alias for orchestrator function return type.
10pub type OrchestratorResult = Result<Option<String>>;
11
12/// Type alias for boxed orchestrator functions.
13///
14/// Orchestrator functions are async functions that take an `OrchestrationContext`
15/// and return an optional JSON-serialised result string.
16pub type OrchestratorFn = Arc<
17    dyn Fn(OrchestrationContext) -> Pin<Box<dyn Future<Output = OrchestratorResult> + Send>>
18        + Send
19        + Sync,
20>;
21
22/// Type alias for activity function return type.
23pub type ActivityResult = Result<Option<String>>;
24
25/// Type alias for boxed activity functions.
26///
27/// Activity functions are async functions that take an `ActivityContext` and
28/// optional JSON-serialised input string.
29pub type ActivityFn = Arc<
30    dyn Fn(ActivityContext, Option<String>) -> Pin<Box<dyn Future<Output = ActivityResult> + Send>>
31        + Send
32        + Sync,
33>;
34
35/// An entry in the orchestrator registry.
36///
37/// Holds the function pointer and an optional version label. The sidecar
38/// dispatches work items by name; the registry resolves the correct function
39/// using the version selection rules (exact match → latest → unversioned).
40#[derive(Clone)]
41struct OrchestratorEntry {
42    f: OrchestratorFn,
43    /// `None` means the entry is unversioned (registered with
44    /// [`Registry::add_named_orchestrator`]).
45    version: Option<String>,
46    is_latest: bool,
47}
48
49/// Registry for orchestrator and activity functions.
50///
51/// Functions must be registered before the worker is started. Versioned
52/// orchestrators can be registered with [`add_versioned_orchestrator`] and a
53/// specific version string, or with [`add_latest_orchestrator`] to mark a
54/// version as the default when no exact version match is found.
55///
56/// [`add_versioned_orchestrator`]: Registry::add_versioned_orchestrator
57/// [`add_latest_orchestrator`]: Registry::add_latest_orchestrator
58pub struct Registry {
59    /// Map of orchestrator name → list of registered entries.
60    orchestrators: HashMap<String, Vec<OrchestratorEntry>>,
61    activities: HashMap<String, ActivityFn>,
62}
63
64impl std::fmt::Debug for Registry {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        struct FnDebug;
67        impl std::fmt::Debug for FnDebug {
68            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69                f.write_str("<fn>")
70            }
71        }
72        type EntryView<'a> = (Option<&'a String>, bool, FnDebug);
73        let orchestrators: HashMap<&String, Vec<EntryView<'_>>> = self
74            .orchestrators
75            .iter()
76            .map(|(name, entries)| {
77                let rendered = entries
78                    .iter()
79                    .map(|e| (e.version.as_ref(), e.is_latest, FnDebug))
80                    .collect();
81                (name, rendered)
82            })
83            .collect();
84        let activities: Vec<&String> = self.activities.keys().collect();
85        f.debug_struct("Registry")
86            .field("orchestrators", &orchestrators)
87            .field("activities", &activities)
88            .finish()
89    }
90}
91
92impl Registry {
93    /// Create an empty registry.
94    pub fn new() -> Self {
95        Self {
96            orchestrators: HashMap::new(),
97            activities: HashMap::new(),
98        }
99    }
100
101    // ─── Internal helpers ─────────────────────────────────────────────────
102
103    fn push_orchestrator_entry(&mut self, name: &str, entry: OrchestratorEntry) {
104        self.orchestrators
105            .entry(name.to_string())
106            .or_default()
107            .push(entry);
108    }
109
110    // ─── Registration ─────────────────────────────────────────────────────
111
112    /// Register an unversioned orchestrator function with the given name.
113    ///
114    /// This is the simplest registration path. When versioning is not
115    /// required, prefer this method.
116    pub fn add_named_orchestrator<F, Fut>(&mut self, name: &str, f: F)
117    where
118        F: Fn(OrchestrationContext) -> Fut + Send + Sync + 'static,
119        Fut: Future<Output = OrchestratorResult> + Send + 'static,
120    {
121        tracing::info!(orchestrator = %name, "Registering orchestrator");
122        let f: OrchestratorFn = Arc::new(move |ctx| {
123            Box::pin(f(ctx)) as Pin<Box<dyn Future<Output = OrchestratorResult> + Send>>
124        });
125        self.push_orchestrator_entry(
126            name,
127            OrchestratorEntry {
128                f,
129                version: None,
130                is_latest: false,
131            },
132        );
133    }
134
135    /// Register a versioned orchestrator function.
136    ///
137    /// The `version` is matched against the `version` field of incoming
138    /// `ExecutionStarted` history events. Use this when you need to run
139    /// multiple versions of the same orchestrator simultaneously.
140    ///
141    /// # Examples
142    ///
143    /// ```rust,no_run
144    /// use dapr_durabletask::worker::Registry;
145    ///
146    /// let mut reg = Registry::new();
147    /// reg.add_versioned_orchestrator("my_orch", "v1", |ctx| async move { Ok(None) });
148    /// reg.add_versioned_orchestrator("my_orch", "v2", |ctx| async move { Ok(None) });
149    /// ```
150    pub fn add_versioned_orchestrator<F, Fut>(&mut self, name: &str, version: &str, f: F)
151    where
152        F: Fn(OrchestrationContext) -> Fut + Send + Sync + 'static,
153        Fut: Future<Output = OrchestratorResult> + Send + 'static,
154    {
155        tracing::info!(orchestrator = %name, version = %version, "Registering versioned orchestrator");
156        let f: OrchestratorFn = Arc::new(move |ctx| {
157            Box::pin(f(ctx)) as Pin<Box<dyn Future<Output = OrchestratorResult> + Send>>
158        });
159        self.push_orchestrator_entry(
160            name,
161            OrchestratorEntry {
162                f,
163                version: Some(version.to_string()),
164                is_latest: false,
165            },
166        );
167    }
168
169    /// Register a versioned orchestrator and mark it as the *latest*.
170    ///
171    /// The latest entry is selected when no exact version match is found. If
172    /// multiple entries are marked as latest for the same name, the last one
173    /// registered wins.
174    ///
175    /// # Examples
176    ///
177    /// ```rust,no_run
178    /// use dapr_durabletask::worker::Registry;
179    ///
180    /// let mut reg = Registry::new();
181    /// reg.add_versioned_orchestrator("my_orch", "v1", |ctx| async move { Ok(None) });
182    /// reg.add_latest_orchestrator("my_orch", "v2", |ctx| async move { Ok(None) });
183    /// // Requests for "v1" → v1 handler; any other version → v2 handler.
184    /// ```
185    pub fn add_latest_orchestrator<F, Fut>(&mut self, name: &str, version: &str, f: F)
186    where
187        F: Fn(OrchestrationContext) -> Fut + Send + Sync + 'static,
188        Fut: Future<Output = OrchestratorResult> + Send + 'static,
189    {
190        tracing::info!(orchestrator = %name, version = %version, "Registering latest orchestrator");
191        let f: OrchestratorFn = Arc::new(move |ctx| {
192            Box::pin(f(ctx)) as Pin<Box<dyn Future<Output = OrchestratorResult> + Send>>
193        });
194        self.push_orchestrator_entry(
195            name,
196            OrchestratorEntry {
197                f,
198                version: Some(version.to_string()),
199                is_latest: true,
200            },
201        );
202    }
203
204    /// Register an activity function with the given name.
205    pub fn add_named_activity<F, Fut>(&mut self, name: &str, f: F)
206    where
207        F: Fn(ActivityContext, Option<String>) -> Fut + Send + Sync + 'static,
208        Fut: Future<Output = ActivityResult> + Send + 'static,
209    {
210        tracing::info!(activity = %name, "Registering activity");
211        let f: ActivityFn = Arc::new(move |ctx, input| {
212            Box::pin(f(ctx, input)) as Pin<Box<dyn Future<Output = ActivityResult> + Send>>
213        });
214        self.activities.insert(name.to_string(), f);
215    }
216
217    // ─── Lookup ───────────────────────────────────────────────────────────
218
219    /// Look up a registered orchestrator by name and optional version.
220    ///
221    /// Resolution order:
222    /// 1. Exact version match (when `version` is `Some`).
223    /// 2. The entry marked as `is_latest` for this name.
224    /// 3. The single unversioned entry (registered via
225    ///    [`add_named_orchestrator`]).
226    ///
227    /// Returns `None` if no suitable entry is found.
228    ///
229    /// [`add_named_orchestrator`]: Registry::add_named_orchestrator
230    pub fn get_orchestrator(&self, name: &str) -> Option<&OrchestratorFn> {
231        self.get_orchestrator_version(name, None)
232    }
233
234    /// Look up a registered orchestrator by name, optionally constraining the
235    /// version. See [`get_orchestrator`](Self::get_orchestrator) for the
236    /// resolution rules.
237    pub fn get_orchestrator_version(
238        &self,
239        name: &str,
240        version: Option<&str>,
241    ) -> Option<&OrchestratorFn> {
242        let entries = self.orchestrators.get(name)?;
243
244        // 1. Exact version match.
245        if let Some(v) = version
246            && let Some(entry) = entries.iter().find(|e| e.version.as_deref() == Some(v))
247        {
248            return Some(&entry.f);
249        }
250
251        // 2. Latest-flagged entry (last registered wins).
252        if let Some(entry) = entries.iter().rev().find(|e| e.is_latest) {
253            return Some(&entry.f);
254        }
255
256        // 3. Unversioned entry.
257        if let Some(entry) = entries.iter().find(|e| e.version.is_none()) {
258            return Some(&entry.f);
259        }
260
261        None
262    }
263
264    /// Look up a registered activity by name.
265    pub fn get_activity(&self, name: &str) -> Option<&ActivityFn> {
266        self.activities.get(name)
267    }
268}
269
270impl Default for Registry {
271    fn default() -> Self {
272        Self::new()
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    async fn dummy_orchestrator(_ctx: OrchestrationContext) -> OrchestratorResult {
        Ok(Some("\"done\"".to_string()))
    }

    async fn dummy_activity(_ctx: ActivityContext, _input: Option<String>) -> ActivityResult {
        Ok(Some("\"result\"".to_string()))
    }

    /// Build a throwaway orchestration context for invoking registered functions.
    fn test_ctx(name: &str) -> OrchestrationContext {
        OrchestrationContext::new(
            "test".to_string(),
            name.to_string(),
            None,
            chrono::Utc::now(),
            false,
            &crate::worker::WorkerOptions::default(),
            0,
        )
    }

    /// Resolve `name`/`version` and run whichever orchestrator is selected.
    ///
    /// Returns `None` when nothing resolves, otherwise `Some(output)`. This
    /// lets tests assert *which* registration was chosen, not merely that one
    /// was found.
    async fn run_resolved(
        reg: &Registry,
        name: &str,
        version: Option<&str>,
    ) -> Option<Option<String>> {
        let f = reg.get_orchestrator_version(name, version)?;
        Some((f)(test_ctx(name)).await.unwrap())
    }

    /// Expected `run_resolved` value for an orchestrator that returned `out`.
    fn resolved(out: &str) -> Option<Option<String>> {
        Some(Some(out.to_string()))
    }

    #[test]
    fn test_register_and_lookup_orchestrator() {
        let mut reg = Registry::new();
        reg.add_named_orchestrator("my_orch", dummy_orchestrator);
        assert!(reg.get_orchestrator("my_orch").is_some());
        assert!(reg.get_orchestrator("missing").is_none());
    }

    #[test]
    fn test_register_and_lookup_activity() {
        let mut reg = Registry::new();
        reg.add_named_activity("my_act", dummy_activity);
        assert!(reg.get_activity("my_act").is_some());
        assert!(reg.get_activity("missing").is_none());
    }

    #[tokio::test]
    async fn test_invoke_orchestrator() {
        let mut reg = Registry::new();
        reg.add_named_orchestrator("orch", dummy_orchestrator);

        let f = reg.get_orchestrator("orch").unwrap();
        let result = (f)(test_ctx("orch")).await;
        assert_eq!(result.unwrap(), Some("\"done\"".to_string()));
    }

    #[tokio::test]
    async fn test_invoke_activity() {
        let mut reg = Registry::new();
        reg.add_named_activity("act", dummy_activity);

        let f = reg.get_activity("act").unwrap();
        let ctx = ActivityContext::new("test".to_string(), 0, String::new());
        let result = (f)(ctx, None).await;
        assert_eq!(result.unwrap(), Some("\"result\"".to_string()));
    }

    // ─── Versioned orchestrator tests ──────────────────────────────────────

    #[tokio::test]
    async fn test_versioned_exact_match() {
        let mut reg = Registry::new();
        reg.add_versioned_orchestrator("orch", "v1", |_| async move { Ok(Some("v1".to_string())) });
        reg.add_versioned_orchestrator("orch", "v2", |_| async move { Ok(Some("v2".to_string())) });

        // Each exact version must resolve to *its own* function.
        assert_eq!(run_resolved(&reg, "orch", Some("v1")).await, resolved("v1"));
        assert_eq!(run_resolved(&reg, "orch", Some("v2")).await, resolved("v2"));
        // No latest or unversioned entry: unknown or absent versions resolve to nothing.
        assert_eq!(run_resolved(&reg, "orch", Some("v3")).await, None);
        assert!(reg.get_orchestrator("orch").is_none());
        assert_eq!(run_resolved(&reg, "missing", Some("v1")).await, None);
    }

    #[tokio::test]
    async fn test_latest_is_fallback() {
        let mut reg = Registry::new();
        reg.add_versioned_orchestrator("orch", "v1", |_| async move { Ok(Some("v1".to_string())) });
        reg.add_latest_orchestrator("orch", "v2", |_| async move { Ok(Some("v2".to_string())) });

        // An exact match still beats the latest entry.
        assert_eq!(run_resolved(&reg, "orch", Some("v1")).await, resolved("v1"));
        // Unknown version should fall back to the latest-flagged entry.
        assert_eq!(run_resolved(&reg, "orch", Some("v99")).await, resolved("v2"));
        // No version should also resolve to latest.
        assert_eq!(run_resolved(&reg, "orch", None).await, resolved("v2"));
    }

    #[tokio::test]
    async fn test_last_latest_wins() {
        let mut reg = Registry::new();
        reg.add_latest_orchestrator("orch", "v1", |_| async move { Ok(Some("v1".to_string())) });
        reg.add_latest_orchestrator("orch", "v2", |_| async move { Ok(Some("v2".to_string())) });

        assert_eq!(run_resolved(&reg, "orch", None).await, resolved("v2"));
        assert_eq!(run_resolved(&reg, "orch", Some("v1")).await, resolved("v1"));
    }

    #[tokio::test]
    async fn test_latest_preferred_over_unversioned() {
        let mut reg = Registry::new();
        reg.add_named_orchestrator("orch", dummy_orchestrator);
        reg.add_latest_orchestrator("orch", "v2", |_| async move { Ok(Some("v2".to_string())) });

        assert_eq!(run_resolved(&reg, "orch", None).await, resolved("v2"));
        assert_eq!(run_resolved(&reg, "orch", Some("v1")).await, resolved("v2"));
    }

    #[tokio::test]
    async fn test_unversioned_fallback() {
        let mut reg = Registry::new();
        reg.add_named_orchestrator("orch", dummy_orchestrator);

        // Requesting a specific version when only unversioned exists falls
        // back to the unversioned entry.
        assert_eq!(run_resolved(&reg, "orch", Some("any")).await, resolved("\"done\""));
        assert_eq!(run_resolved(&reg, "orch", None).await, resolved("\"done\""));
    }

    #[test]
    fn test_debug_lists_registrations() {
        let mut reg = Registry::new();
        reg.add_versioned_orchestrator("orch", "v1", dummy_orchestrator);
        reg.add_named_activity("act", dummy_activity);

        let rendered = format!("{reg:?}");
        assert!(rendered.contains("\"orch\""));
        assert!(rendered.contains("Some(\"v1\")"));
        assert!(rendered.contains("<fn>"));
        assert!(rendered.contains("\"act\""));
    }
}