1use super::{ActivityHandler, FnActivity, FnOrchestration, OrchestrationHandler};
9use crate::_typed_codec::Codec;
10use crate::OrchestrationContext;
11use semver::Version;
12use std::collections::HashMap;
13use std::sync::{Arc, Mutex};
14
/// Version assigned to handlers registered without an explicit version
/// (the plain `register` / `register_typed` builder methods).
const DEFAULT_VERSION: Version = Version::new(1, 0, 0);
17
18#[derive(Clone, Debug)]
19pub enum VersionPolicy {
20 Latest,
21 Exact(Version),
22}
23
/// Lookup table from handler name to its registered versions, plus a per-name
/// [`VersionPolicy`] used to pick a version at resolution time.
///
/// `H` is the (possibly unsized) handler type, e.g. `dyn OrchestrationHandler`.
/// Cloning a `Registry` shares both tables through `Arc`.
pub struct Registry<H: ?Sized> {
    /// Handlers keyed by name; each name maps versions (ordered by `BTreeMap`) to a handler.
    /// No method on `Registry` mutates this table after it is built.
    pub(crate) inner: Arc<HashMap<String, std::collections::BTreeMap<Version, Arc<H>>>>,
    /// Per-name version policies. Behind a `Mutex` so `set_version_policy` can update it
    /// through `&self`; shared by all clones of this registry.
    pub(crate) policy: Arc<Mutex<HashMap<String, VersionPolicy>>>,
}
32
33impl<H: ?Sized> Clone for Registry<H> {
35 fn clone(&self) -> Self {
36 Self {
37 inner: Arc::clone(&self.inner),
38 policy: Arc::clone(&self.policy),
39 }
40 }
41}
42
43impl<H: ?Sized> Default for Registry<H> {
44 fn default() -> Self {
45 Self {
46 inner: Arc::new(HashMap::new()),
47 policy: Arc::new(Mutex::new(HashMap::new())),
48 }
49 }
50}
51
52pub struct RegistryBuilder<H: ?Sized> {
54 map: HashMap<String, std::collections::BTreeMap<Version, Arc<H>>>,
55 policy: HashMap<String, VersionPolicy>,
56 errors: Vec<String>,
57}
58
/// Registry of orchestration handlers.
pub type OrchestrationRegistry = Registry<dyn OrchestrationHandler>;
/// Registry of activity handlers.
pub type ActivityRegistry = Registry<dyn ActivityHandler>;
/// Builder for an [`OrchestrationRegistry`].
pub type OrchestrationRegistryBuilder = RegistryBuilder<dyn OrchestrationHandler>;
/// Builder for an [`ActivityRegistry`].
pub type ActivityRegistryBuilder = RegistryBuilder<dyn ActivityHandler>;
64
65impl<H: ?Sized> Registry<H> {
70 pub fn builder() -> RegistryBuilder<H> {
71 RegistryBuilder {
72 map: HashMap::new(),
73 policy: HashMap::new(),
74 errors: Vec::new(),
75 }
76 }
77
78 pub fn builder_from(reg: &Registry<H>) -> RegistryBuilder<H> {
79 RegistryBuilder {
80 map: reg.inner.as_ref().clone(),
81 policy: reg.policy.lock().expect("Mutex should not be poisoned").clone(),
83 errors: Vec::new(),
84 }
85 }
86
87 pub fn resolve_handler(&self, name: &str) -> Option<(Version, Arc<H>)> {
89 let pol = self
90 .policy
91 .lock()
92 .expect("Mutex should not be poisoned")
94 .get(name)
95 .cloned()
96 .unwrap_or(VersionPolicy::Latest);
97
98 let result = match &pol {
99 VersionPolicy::Latest => {
100 if let Some(m) = self.inner.get(name) {
101 if let Some((v, h)) = m.iter().next_back() {
102 Some((v.clone(), h.clone()))
103 } else {
104 None
105 }
106 } else {
107 None
108 }
109 }
110 VersionPolicy::Exact(v) => self
111 .inner
112 .get(name)
113 .and_then(|versions| versions.get(v))
114 .map(|h| (v.clone(), Arc::clone(h))),
115 };
116
117 if result.is_none() {
118 self.log_registry_miss(name, None, Some(&pol));
119 }
120
121 result
122 }
123
124 pub fn resolve_version(&self, name: &str) -> Option<Version> {
127 self.resolve_handler(name).map(|(v, _h)| v)
128 }
129
130 pub fn resolve_handler_exact(&self, name: &str, v: &Version) -> Option<Arc<H>> {
132 let result = if let Some(versions) = self.inner.get(name) {
133 versions.get(v).cloned()
134 } else {
135 None
136 };
137
138 if result.is_none() {
139 self.log_registry_miss(name, Some(v), None);
140 }
141
142 result
143 }
144
145 pub fn set_version_policy(&self, name: &str, policy: VersionPolicy) {
147 self.policy
149 .lock()
150 .expect("Mutex should not be poisoned")
151 .insert(name.to_string(), policy);
152 }
153
154 pub fn list_names(&self) -> Vec<String> {
156 self.inner.keys().cloned().collect()
157 }
158
159 pub fn list_versions(&self, name: &str) -> Vec<Version> {
161 self.inner
162 .get(name)
163 .map(|m| m.keys().cloned().collect())
164 .unwrap_or_default()
165 }
166
167 pub fn has(&self, name: &str) -> bool {
169 self.inner.contains_key(name)
170 }
171
172 pub fn count(&self) -> usize {
174 self.inner.len()
175 }
176
177 fn debug_dump(&self) -> HashMap<String, Vec<String>> {
179 self.inner
180 .iter()
181 .map(|(name, versions)| (name.clone(), versions.keys().map(|v| v.to_string()).collect()))
182 .collect()
183 }
184
185 fn log_registry_miss(
186 &self,
187 name: &str,
188 requested_version: Option<&Version>,
189 requested_policy: Option<&VersionPolicy>,
190 ) {
191 let all_names = self.list_names();
192 let contents = self.debug_dump();
193 let policy_map = self.policy.lock().expect("Mutex should not be poisoned").clone();
195 let available_versions = self.list_versions(name);
196
197 tracing::debug!(
198 target: "duroxide::runtime::registry",
199 requested_name = %name,
200 requested_version = ?requested_version,
201 requested_policy = ?requested_policy,
202 available_versions_for_name = ?available_versions,
203 registered_count = all_names.len(),
204 registered_names = ?all_names,
205 full_registry_contents = ?contents,
206 current_policies = ?policy_map,
207 "Registry lookup miss - dumping full registry state"
208 );
209 }
210}
211
212impl<H: ?Sized> RegistryBuilder<H> {
217 pub fn build(self) -> Registry<H> {
218 Registry {
219 inner: Arc::new(self.map),
220 policy: Arc::new(Mutex::new(self.policy)),
221 }
222 }
223
224 pub fn build_result(self) -> Result<Registry<H>, String> {
230 if self.errors.is_empty() {
231 Ok(self.build())
232 } else {
233 Err(self.errors.join("; "))
234 }
235 }
236
237 pub fn merge_registry(mut self, other: Registry<H>, error_prefix: &str) -> Self {
239 for (name, versions) in other.inner.iter() {
240 let entry = self.map.entry(name.clone()).or_default();
241 for (version, handler) in versions.iter() {
242 if entry.contains_key(version) {
243 self.errors
244 .push(format!("duplicate {error_prefix} in merge: {name}@{version}"));
245 } else {
246 entry.insert(version.clone(), handler.clone());
247 }
248 }
249 }
250 self
251 }
252
253 pub fn register_all_handlers<F>(self, items: Vec<(&str, F)>, register_fn: impl Fn(Self, &str, F) -> Self) -> Self
255 where
256 F: Clone,
257 {
258 items
259 .into_iter()
260 .fold(self, |builder, (name, f)| register_fn(builder, name, f))
261 }
262
263 fn check_duplicate(&mut self, name: &str, version: &Version, error_prefix: &str) -> bool {
265 let entry = self.map.entry(name.to_string()).or_default();
266 if entry.contains_key(version) {
267 self.errors
268 .push(format!("duplicate {error_prefix} registration: {name}@{version}"));
269 true
270 } else {
271 false
272 }
273 }
274}
275
276impl OrchestrationRegistryBuilder {
281 pub fn register<F, Fut>(mut self, name: impl Into<String>, f: F) -> Self
282 where
283 F: Fn(OrchestrationContext, String) -> Fut + Send + Sync + 'static,
284 Fut: std::future::Future<Output = Result<String, String>> + Send + 'static,
285 {
286 let name = name.into();
287 if self.check_duplicate(&name, &DEFAULT_VERSION, "orchestration") {
288 return self;
289 }
290 self.map
291 .entry(name)
292 .or_default()
293 .insert(DEFAULT_VERSION, Arc::new(FnOrchestration(f)));
294 self
295 }
296
297 pub fn register_typed<In, Out, F, Fut>(mut self, name: impl Into<String>, f: F) -> Self
298 where
299 In: serde::de::DeserializeOwned + Send + 'static,
300 Out: serde::Serialize + Send + 'static,
301 F: Fn(OrchestrationContext, In) -> Fut + Send + Sync + Clone + 'static,
302 Fut: std::future::Future<Output = Result<Out, String>> + Send + 'static,
303 {
304 use super::FnOrchestration;
305 let f_clone = f.clone();
306 let wrapper = move |ctx: OrchestrationContext, input_s: String| {
307 let f_inner = f_clone.clone();
308 async move {
309 let input: In = crate::_typed_codec::Json::decode(&input_s)?;
310 let out: Out = f_inner(ctx, input).await?;
311 crate::_typed_codec::Json::encode(&out)
312 }
313 };
314 let name = name.into();
315 self.map
316 .entry(name)
317 .or_default()
318 .insert(DEFAULT_VERSION, Arc::new(FnOrchestration(wrapper)));
319 self
320 }
321
322 pub fn register_versioned<F, Fut>(mut self, name: impl Into<String>, version: impl AsRef<str>, f: F) -> Self
323 where
324 F: Fn(OrchestrationContext, String) -> Fut + Send + Sync + 'static,
325 Fut: std::future::Future<Output = Result<String, String>> + Send + 'static,
326 {
327 let name = name.into();
328 let v = Version::parse(version.as_ref()).expect("Version should be valid semver");
330 if self.check_duplicate(&name, &v, "orchestration") {
331 return self;
332 }
333 let entry = self.map.entry(name.clone()).or_default();
334 if let Some((latest, _)) = entry.iter().next_back()
335 && &v <= latest
336 {
337 panic!("non-monotonic orchestration version for {name}: {v} is not later than existing latest {latest}");
338 }
339 entry.insert(v, Arc::new(FnOrchestration(f)));
340 self
341 }
342
343 pub fn register_versioned_typed<In, Out, F, Fut>(
344 mut self,
345 name: impl Into<String>,
346 version: impl AsRef<str>,
347 f: F,
348 ) -> Self
349 where
350 In: serde::de::DeserializeOwned + Send + 'static,
351 Out: serde::Serialize + Send + 'static,
352 F: Fn(OrchestrationContext, In) -> Fut + Send + Sync + Clone + 'static,
353 Fut: std::future::Future<Output = Result<Out, String>> + Send + 'static,
354 {
355 use super::FnOrchestration;
356 let name = name.into();
357 let v = Version::parse(version.as_ref()).expect("Version should be valid semver");
359 if self.check_duplicate(&name, &v, "orchestration") {
360 return self;
361 }
362 let entry = self.map.entry(name.clone()).or_default();
363 if let Some((latest, _)) = entry.iter().next_back()
364 && &v <= latest
365 {
366 panic!("non-monotonic orchestration version for {name}: {v} is not later than existing latest {latest}");
367 }
368 let f_clone = f.clone();
369 let wrapper = move |ctx: OrchestrationContext, input_s: String| {
370 let f_inner = f_clone.clone();
371 async move {
372 let input: In = crate::_typed_codec::Json::decode(&input_s)?;
373 let out: Out = f_inner(ctx, input).await?;
374 crate::_typed_codec::Json::encode(&out)
375 }
376 };
377 self.map
378 .entry(name)
379 .or_default()
380 .insert(v, Arc::new(FnOrchestration(wrapper)));
381 self
382 }
383
384 pub fn merge(self, other: OrchestrationRegistry) -> Self {
385 self.merge_registry(other, "orchestration")
386 }
387
388 pub fn register_all<F, Fut>(self, items: Vec<(&str, F)>) -> Self
389 where
390 F: Fn(OrchestrationContext, String) -> Fut + Send + Sync + 'static + Clone,
391 Fut: std::future::Future<Output = Result<String, String>> + Send + 'static,
392 {
393 self.register_all_handlers(items, |builder, name, f| builder.register(name, f))
394 }
395
396 pub fn set_policy(mut self, name: impl Into<String>, policy: VersionPolicy) -> Self {
397 self.policy.insert(name.into(), policy);
398 self
399 }
400}
401
402impl ActivityRegistryBuilder {
407 pub fn from_registry(reg: &ActivityRegistry) -> Self {
409 ActivityRegistry::builder_from(reg)
410 }
411
412 pub fn register<F, Fut>(mut self, name: impl Into<String>, f: F) -> Self
413 where
414 F: Fn(crate::ActivityContext, String) -> Fut + Send + Sync + 'static,
415 Fut: std::future::Future<Output = Result<String, String>> + Send + 'static,
416 {
417 let name = name.into();
418 if self.check_duplicate(&name, &DEFAULT_VERSION, "activity") {
419 return self;
420 }
421 self.map
422 .entry(name.clone())
423 .or_default()
424 .insert(DEFAULT_VERSION, Arc::new(FnActivity(f)));
425 self.policy.insert(name, VersionPolicy::Latest);
427 self
428 }
429
430 pub fn register_typed<In, Out, F, Fut>(mut self, name: impl Into<String>, f: F) -> Self
431 where
432 In: serde::de::DeserializeOwned + Send + 'static,
433 Out: serde::Serialize + Send + 'static,
434 F: Fn(crate::ActivityContext, In) -> Fut + Send + Sync + 'static,
435 Fut: std::future::Future<Output = Result<Out, String>> + Send + 'static,
436 {
437 let f_clone = std::sync::Arc::new(f);
438 let wrapper = move |ctx: crate::ActivityContext, input_s: String| {
439 let f_inner = f_clone.clone();
440 async move {
441 let input: In = crate::_typed_codec::Json::decode(&input_s)?;
442 let out: Out = (f_inner)(ctx, input).await?;
443 crate::_typed_codec::Json::encode(&out)
444 }
445 };
446 let name = name.into();
447 if self.check_duplicate(&name, &DEFAULT_VERSION, "activity") {
448 return self;
449 }
450 self.map
451 .entry(name.clone())
452 .or_default()
453 .insert(DEFAULT_VERSION, Arc::new(FnActivity(wrapper)));
454 self.policy.insert(name, VersionPolicy::Latest);
456 self
457 }
458
459 pub fn merge(self, other: ActivityRegistry) -> Self {
460 self.merge_registry(other, "activity")
461 }
462
463 pub fn register_all<F, Fut>(self, items: Vec<(&str, F)>) -> Self
464 where
465 F: Fn(crate::ActivityContext, String) -> Fut + Send + Sync + 'static + Clone,
466 Fut: std::future::Future<Output = Result<String, String>> + Send + 'static,
467 {
468 self.register_all_handlers(items, |builder, name, f| builder.register(name, f))
469 }
470}