Skip to main content

dapr_durabletask/worker/
registry.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use crate::api::Result;
7use crate::task::{ActivityContext, OrchestrationContext};
8
9/// Type alias for orchestrator function return type.
10pub type OrchestratorResult = Result<Option<String>>;
11
12/// Type alias for boxed orchestrator functions.
13///
14/// Orchestrator functions are async functions that take an `OrchestrationContext`
15/// and return an optional JSON-serialised result string.
16pub type OrchestratorFn = Arc<
17    dyn Fn(OrchestrationContext) -> Pin<Box<dyn Future<Output = OrchestratorResult> + Send>>
18        + Send
19        + Sync,
20>;
21
22/// Type alias for activity function return type.
23pub type ActivityResult = Result<Option<String>>;
24
25/// Type alias for boxed activity functions.
26///
27/// Activity functions are async functions that take an `ActivityContext` and
28/// optional JSON-serialised input string.
29pub type ActivityFn = Arc<
30    dyn Fn(ActivityContext, Option<String>) -> Pin<Box<dyn Future<Output = ActivityResult> + Send>>
31        + Send
32        + Sync,
33>;
34
35/// An entry in the orchestrator registry.
36///
37/// Holds the function pointer and an optional version label. The sidecar
38/// dispatches work items by name; the registry resolves the correct function
39/// using the version selection rules (exact match → latest → unversioned).
40#[derive(Clone)]
41struct OrchestratorEntry {
42    f: OrchestratorFn,
43    /// `None` means the entry is unversioned (registered with
44    /// [`Registry::add_named_orchestrator`]).
45    version: Option<String>,
46    is_latest: bool,
47}
48
49/// Registry for orchestrator and activity functions.
50///
51/// Functions must be registered before the worker is started. Versioned
52/// orchestrators can be registered with [`add_versioned_orchestrator`] and a
53/// specific version string, or with [`add_latest_orchestrator`] to mark a
54/// version as the default when no exact version match is found.
55///
56/// [`add_versioned_orchestrator`]: Registry::add_versioned_orchestrator
57/// [`add_latest_orchestrator`]: Registry::add_latest_orchestrator
58pub struct Registry {
59    /// Map of orchestrator name → list of registered entries.
60    orchestrators: HashMap<String, Vec<OrchestratorEntry>>,
61    activities: HashMap<String, ActivityFn>,
62}
63
64impl Registry {
65    /// Create an empty registry.
66    pub fn new() -> Self {
67        Self {
68            orchestrators: HashMap::new(),
69            activities: HashMap::new(),
70        }
71    }
72
73    // ─── Internal helpers ─────────────────────────────────────────────────
74
75    fn push_orchestrator_entry(&mut self, name: &str, entry: OrchestratorEntry) {
76        self.orchestrators
77            .entry(name.to_string())
78            .or_default()
79            .push(entry);
80    }
81
82    // ─── Registration ─────────────────────────────────────────────────────
83
84    /// Register an unversioned orchestrator function with the given name.
85    ///
86    /// This is the simplest registration path. When versioning is not
87    /// required, prefer this method.
88    pub fn add_named_orchestrator<F, Fut>(&mut self, name: &str, f: F)
89    where
90        F: Fn(OrchestrationContext) -> Fut + Send + Sync + 'static,
91        Fut: Future<Output = OrchestratorResult> + Send + 'static,
92    {
93        tracing::info!(orchestrator = %name, "Registering orchestrator");
94        let f: OrchestratorFn = Arc::new(move |ctx| {
95            Box::pin(f(ctx)) as Pin<Box<dyn Future<Output = OrchestratorResult> + Send>>
96        });
97        self.push_orchestrator_entry(
98            name,
99            OrchestratorEntry {
100                f,
101                version: None,
102                is_latest: false,
103            },
104        );
105    }
106
107    /// Register a versioned orchestrator function.
108    ///
109    /// The `version` is matched against the `version` field of incoming
110    /// `ExecutionStarted` history events. Use this when you need to run
111    /// multiple versions of the same orchestrator simultaneously.
112    ///
113    /// # Examples
114    ///
115    /// ```rust,no_run
116    /// use dapr_durabletask::worker::Registry;
117    ///
118    /// let mut reg = Registry::new();
119    /// reg.add_versioned_orchestrator("my_orch", "v1", |ctx| async move { Ok(None) });
120    /// reg.add_versioned_orchestrator("my_orch", "v2", |ctx| async move { Ok(None) });
121    /// ```
122    pub fn add_versioned_orchestrator<F, Fut>(&mut self, name: &str, version: &str, f: F)
123    where
124        F: Fn(OrchestrationContext) -> Fut + Send + Sync + 'static,
125        Fut: Future<Output = OrchestratorResult> + Send + 'static,
126    {
127        tracing::info!(orchestrator = %name, version = %version, "Registering versioned orchestrator");
128        let f: OrchestratorFn = Arc::new(move |ctx| {
129            Box::pin(f(ctx)) as Pin<Box<dyn Future<Output = OrchestratorResult> + Send>>
130        });
131        self.push_orchestrator_entry(
132            name,
133            OrchestratorEntry {
134                f,
135                version: Some(version.to_string()),
136                is_latest: false,
137            },
138        );
139    }
140
141    /// Register a versioned orchestrator and mark it as the *latest*.
142    ///
143    /// The latest entry is selected when no exact version match is found. If
144    /// multiple entries are marked as latest for the same name, the last one
145    /// registered wins.
146    ///
147    /// # Examples
148    ///
149    /// ```rust,no_run
150    /// use dapr_durabletask::worker::Registry;
151    ///
152    /// let mut reg = Registry::new();
153    /// reg.add_versioned_orchestrator("my_orch", "v1", |ctx| async move { Ok(None) });
154    /// reg.add_latest_orchestrator("my_orch", "v2", |ctx| async move { Ok(None) });
155    /// // Requests for "v1" → v1 handler; any other version → v2 handler.
156    /// ```
157    pub fn add_latest_orchestrator<F, Fut>(&mut self, name: &str, version: &str, f: F)
158    where
159        F: Fn(OrchestrationContext) -> Fut + Send + Sync + 'static,
160        Fut: Future<Output = OrchestratorResult> + Send + 'static,
161    {
162        tracing::info!(orchestrator = %name, version = %version, "Registering latest orchestrator");
163        let f: OrchestratorFn = Arc::new(move |ctx| {
164            Box::pin(f(ctx)) as Pin<Box<dyn Future<Output = OrchestratorResult> + Send>>
165        });
166        self.push_orchestrator_entry(
167            name,
168            OrchestratorEntry {
169                f,
170                version: Some(version.to_string()),
171                is_latest: true,
172            },
173        );
174    }
175
176    /// Register an activity function with the given name.
177    pub fn add_named_activity<F, Fut>(&mut self, name: &str, f: F)
178    where
179        F: Fn(ActivityContext, Option<String>) -> Fut + Send + Sync + 'static,
180        Fut: Future<Output = ActivityResult> + Send + 'static,
181    {
182        tracing::info!(activity = %name, "Registering activity");
183        let f: ActivityFn = Arc::new(move |ctx, input| {
184            Box::pin(f(ctx, input)) as Pin<Box<dyn Future<Output = ActivityResult> + Send>>
185        });
186        self.activities.insert(name.to_string(), f);
187    }
188
189    // ─── Lookup ───────────────────────────────────────────────────────────
190
191    /// Look up a registered orchestrator by name and optional version.
192    ///
193    /// Resolution order:
194    /// 1. Exact version match (when `version` is `Some`).
195    /// 2. The entry marked as `is_latest` for this name.
196    /// 3. The single unversioned entry (registered via
197    ///    [`add_named_orchestrator`]).
198    ///
199    /// Returns `None` if no suitable entry is found.
200    ///
201    /// [`add_named_orchestrator`]: Registry::add_named_orchestrator
202    pub fn get_orchestrator(&self, name: &str) -> Option<&OrchestratorFn> {
203        self.get_orchestrator_version(name, None)
204    }
205
206    /// Look up a registered orchestrator by name, optionally constraining the
207    /// version. See [`get_orchestrator`](Self::get_orchestrator) for the
208    /// resolution rules.
209    pub fn get_orchestrator_version(
210        &self,
211        name: &str,
212        version: Option<&str>,
213    ) -> Option<&OrchestratorFn> {
214        let entries = self.orchestrators.get(name)?;
215
216        // 1. Exact version match.
217        if let Some(v) = version {
218            if let Some(entry) = entries.iter().find(|e| e.version.as_deref() == Some(v)) {
219                return Some(&entry.f);
220            }
221        }
222
223        // 2. Latest-flagged entry (last registered wins).
224        if let Some(entry) = entries.iter().rev().find(|e| e.is_latest) {
225            return Some(&entry.f);
226        }
227
228        // 3. Unversioned entry.
229        if let Some(entry) = entries.iter().find(|e| e.version.is_none()) {
230            return Some(&entry.f);
231        }
232
233        None
234    }
235
236    /// Look up a registered activity by name.
237    pub fn get_activity(&self, name: &str) -> Option<&ActivityFn> {
238        self.activities.get(name)
239    }
240}
241
242impl Default for Registry {
243    fn default() -> Self {
244        Self::new()
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// Orchestrator stub that always succeeds with the JSON string `"done"`.
    async fn dummy_orchestrator(_ctx: OrchestrationContext) -> OrchestratorResult {
        Ok(Some("\"done\"".to_string()))
    }

    /// Activity stub that always succeeds with the JSON string `"result"`.
    async fn dummy_activity(_ctx: ActivityContext, _input: Option<String>) -> ActivityResult {
        Ok(Some("\"result\"".to_string()))
    }

    /// Build a throwaway orchestration context for invoking handlers.
    ///
    /// The handlers used in these tests ignore the context, so the concrete
    /// argument values are placeholders.
    fn test_ctx() -> OrchestrationContext {
        OrchestrationContext::new(
            "test".to_string(),
            "orch".to_string(),
            None,
            chrono::Utc::now(),
            false,
            &crate::worker::WorkerOptions::default(),
            0,
        )
    }

    /// Invoke a resolved orchestrator and return its output, so tests can
    /// check *which* handler the registry selected rather than merely that
    /// one was found.
    async fn run(f: &OrchestratorFn) -> Option<String> {
        (f)(test_ctx()).await.unwrap()
    }

    #[test]
    fn test_register_and_lookup_orchestrator() {
        let mut reg = Registry::new();
        reg.add_named_orchestrator("my_orch", dummy_orchestrator);
        assert!(reg.get_orchestrator("my_orch").is_some());
        assert!(reg.get_orchestrator("missing").is_none());
    }

    #[test]
    fn test_register_and_lookup_activity() {
        let mut reg = Registry::new();
        reg.add_named_activity("my_act", dummy_activity);
        assert!(reg.get_activity("my_act").is_some());
        assert!(reg.get_activity("missing").is_none());
    }

    #[tokio::test]
    async fn test_invoke_orchestrator() {
        let mut reg = Registry::new();
        reg.add_named_orchestrator("orch", dummy_orchestrator);

        let f = reg.get_orchestrator("orch").unwrap();
        assert_eq!(run(f).await, Some("\"done\"".to_string()));
    }

    #[tokio::test]
    async fn test_invoke_activity() {
        let mut reg = Registry::new();
        reg.add_named_activity("act", dummy_activity);

        let f = reg.get_activity("act").unwrap();
        let ctx = ActivityContext::new("test".to_string(), 0, String::new());
        let result = (f)(ctx, None).await;
        assert_eq!(result.unwrap(), Some("\"result\"".to_string()));
    }

    // ─── Versioned orchestrator tests ──────────────────────────────────────

    #[tokio::test]
    async fn test_versioned_exact_match() {
        let mut reg = Registry::new();
        reg.add_versioned_orchestrator("orch", "v1", |_| async move { Ok(Some("v1".to_string())) });
        reg.add_versioned_orchestrator("orch", "v2", |_| async move { Ok(Some("v2".to_string())) });

        // Each exact version must resolve to its own handler.
        let v1 = reg.get_orchestrator_version("orch", Some("v1")).unwrap();
        assert_eq!(run(v1).await, Some("v1".to_string()));
        let v2 = reg.get_orchestrator_version("orch", Some("v2")).unwrap();
        assert_eq!(run(v2).await, Some("v2".to_string()));

        // No latest and no unversioned entry: unknown versions do not resolve.
        assert!(reg.get_orchestrator_version("orch", Some("v3")).is_none());
    }

    #[tokio::test]
    async fn test_latest_is_fallback() {
        let mut reg = Registry::new();
        reg.add_versioned_orchestrator("orch", "v1", |_| async move { Ok(Some("v1".to_string())) });
        reg.add_latest_orchestrator("orch", "v2", |_| async move { Ok(Some("v2".to_string())) });

        // Unknown version should fall back to the latest-flagged entry.
        let unknown = reg.get_orchestrator_version("orch", Some("v99")).unwrap();
        assert_eq!(run(unknown).await, Some("v2".to_string()));

        // No version should also resolve to latest.
        let unconstrained = reg.get_orchestrator("orch").unwrap();
        assert_eq!(run(unconstrained).await, Some("v2".to_string()));

        // An exact match still takes precedence over the latest entry.
        let exact = reg.get_orchestrator_version("orch", Some("v1")).unwrap();
        assert_eq!(run(exact).await, Some("v1".to_string()));
    }

    #[tokio::test]
    async fn test_last_registered_latest_wins() {
        let mut reg = Registry::new();
        reg.add_latest_orchestrator("orch", "v1", |_| async move { Ok(Some("v1".to_string())) });
        reg.add_latest_orchestrator("orch", "v2", |_| async move { Ok(Some("v2".to_string())) });

        // With two latest-flagged entries, the most recent registration is
        // the fallback.
        let f = reg.get_orchestrator("orch").unwrap();
        assert_eq!(run(f).await, Some("v2".to_string()));
    }

    #[tokio::test]
    async fn test_unversioned_fallback() {
        let mut reg = Registry::new();
        reg.add_named_orchestrator("orch", dummy_orchestrator);

        // Requesting a specific version when only unversioned exists falls
        // back to the unversioned entry.
        let f = reg.get_orchestrator_version("orch", Some("any")).unwrap();
        assert_eq!(run(f).await, Some("\"done\"".to_string()));
    }
}
341}