use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::api::Result;
use crate::task::{ActivityContext, OrchestrationContext};
/// Outcome of running an orchestrator: an optional string output, or a crate error.
pub type OrchestratorResult = Result<Option<String>>;

/// Type-erased orchestrator entry point stored in the [`Registry`].
///
/// Given an [`OrchestrationContext`], it returns a heap-pinned `Send` future that
/// resolves to an [`OrchestratorResult`]. The function itself must be `Sync + Send`
/// and lives behind an `Arc`, so clones share a single allocation.
pub type OrchestratorFn = Arc<
    dyn Fn(OrchestrationContext) -> Pin<Box<dyn Future<Output = OrchestratorResult> + Send>>
        + Sync
        + Send,
>;
/// Outcome of running an activity: an optional string output, or a crate error.
pub type ActivityResult = Result<Option<String>>;

/// Type-erased activity entry point stored in the [`Registry`].
///
/// Receives an [`ActivityContext`] plus an optional string input and returns a
/// heap-pinned `Send` future resolving to an [`ActivityResult`]. The function must be
/// `Sync + Send` and is shared through an `Arc`.
pub type ActivityFn = Arc<
    dyn Fn(ActivityContext, Option<String>) -> Pin<Box<dyn Future<Output = ActivityResult> + Send>>
        + Sync
        + Send,
>;
/// One registered implementation of a named orchestrator.
#[derive(Clone)]
struct OrchestratorEntry {
    /// The type-erased orchestrator function.
    f: OrchestratorFn,
    /// Version string, or `None` for an unversioned registration
    /// (`add_named_orchestrator`).
    version: Option<String>,
    /// Set for registrations made via `add_latest_orchestrator`; such an entry is the
    /// fallback used when a lookup has no exact version match.
    is_latest: bool,
}
/// Name-indexed lookup table for orchestrator and activity implementations.
pub struct Registry {
    /// Orchestrator registrations keyed by name; one name may hold several entries
    /// (unversioned, versioned and/or latest).
    orchestrators: HashMap<String, Vec<OrchestratorEntry>>,
    /// Activity functions keyed by name; registering a name again replaces the
    /// previous function.
    activities: HashMap<String, ActivityFn>,
}
impl std::fmt::Debug for Registry {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
struct FnDebug;
impl std::fmt::Debug for FnDebug {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("<fn>")
}
}
type EntryView<'a> = (Option<&'a String>, bool, FnDebug);
let orchestrators: HashMap<&String, Vec<EntryView<'_>>> = self
.orchestrators
.iter()
.map(|(name, entries)| {
let rendered = entries
.iter()
.map(|e| (e.version.as_ref(), e.is_latest, FnDebug))
.collect();
(name, rendered)
})
.collect();
let activities: Vec<&String> = self.activities.keys().collect();
f.debug_struct("Registry")
.field("orchestrators", &orchestrators)
.field("activities", &activities)
.finish()
}
}
impl Registry {
    /// Creates an empty registry with no orchestrators or activities.
    pub fn new() -> Self {
        Self {
            orchestrators: HashMap::new(),
            activities: HashMap::new(),
        }
    }

    /// Type-erases an orchestrator function into the shared [`OrchestratorFn`] form.
    fn box_orchestrator<F, Fut>(f: F) -> OrchestratorFn
    where
        F: Fn(OrchestrationContext) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = OrchestratorResult> + Send + 'static,
    {
        Arc::new(move |ctx| {
            Box::pin(f(ctx)) as Pin<Box<dyn Future<Output = OrchestratorResult> + Send>>
        })
    }

    /// Stores `entry` under `name`.
    ///
    /// If an entry with the same version (including "no version") already exists,
    /// its function is replaced instead of appending a duplicate. Previously the
    /// duplicate was appended and lookups kept returning the first registration,
    /// silently ignoring the new one. Last-registration-wins matches
    /// [`Registry::add_named_activity`].
    ///
    /// Registering a latest entry clears the latest flag on every other entry for
    /// `name`, so at most one entry per name is marked latest. Re-registering an
    /// existing version as non-latest keeps its latest mark; only
    /// [`Registry::add_latest_orchestrator`] moves it.
    fn push_orchestrator_entry(&mut self, name: &str, entry: OrchestratorEntry) {
        let entries = self.orchestrators.entry(name.to_string()).or_default();
        if entry.is_latest {
            for other in entries.iter_mut() {
                other.is_latest = false;
            }
        }
        match entries.iter().position(|e| e.version == entry.version) {
            Some(idx) => {
                let existing = &mut entries[idx];
                existing.is_latest |= entry.is_latest;
                existing.f = entry.f;
            }
            None => entries.push(entry),
        }
    }

    /// Registers an unversioned orchestrator under `name`.
    ///
    /// Unversioned registrations are the last fallback of
    /// [`Registry::get_orchestrator_version`], used when neither an exact version
    /// match nor a latest version exists. Registering again replaces the function.
    pub fn add_named_orchestrator<F, Fut>(&mut self, name: &str, f: F)
    where
        F: Fn(OrchestrationContext) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = OrchestratorResult> + Send + 'static,
    {
        tracing::info!(orchestrator = %name, "Registering orchestrator");
        self.push_orchestrator_entry(
            name,
            OrchestratorEntry {
                f: Self::box_orchestrator(f),
                version: None,
                is_latest: false,
            },
        );
    }

    /// Registers `f` as version `version` of orchestrator `name`.
    ///
    /// It is only selected by an exact version match. Registering the same version
    /// again replaces the function.
    pub fn add_versioned_orchestrator<F, Fut>(&mut self, name: &str, version: &str, f: F)
    where
        F: Fn(OrchestrationContext) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = OrchestratorResult> + Send + 'static,
    {
        tracing::info!(orchestrator = %name, version = %version, "Registering versioned orchestrator");
        self.push_orchestrator_entry(
            name,
            OrchestratorEntry {
                f: Self::box_orchestrator(f),
                version: Some(version.to_string()),
                is_latest: false,
            },
        );
    }

    /// Registers `f` as version `version` of orchestrator `name` and marks it latest.
    ///
    /// The latest entry serves exact matches on `version`, plus lookups that give no
    /// version or an unregistered version. Any previously latest entry for `name` is
    /// demoted but stays reachable by its exact version.
    pub fn add_latest_orchestrator<F, Fut>(&mut self, name: &str, version: &str, f: F)
    where
        F: Fn(OrchestrationContext) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = OrchestratorResult> + Send + 'static,
    {
        tracing::info!(orchestrator = %name, version = %version, "Registering latest orchestrator");
        self.push_orchestrator_entry(
            name,
            OrchestratorEntry {
                f: Self::box_orchestrator(f),
                version: Some(version.to_string()),
                is_latest: true,
            },
        );
    }

    /// Registers an activity under `name`, replacing any previous activity of that name.
    pub fn add_named_activity<F, Fut>(&mut self, name: &str, f: F)
    where
        F: Fn(ActivityContext, Option<String>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ActivityResult> + Send + 'static,
    {
        tracing::info!(activity = %name, "Registering activity");
        let f: ActivityFn = Arc::new(move |ctx, input| {
            Box::pin(f(ctx, input)) as Pin<Box<dyn Future<Output = ActivityResult> + Send>>
        });
        self.activities.insert(name.to_string(), f);
    }

    /// Resolves the default orchestrator for `name`: the latest version, else the
    /// unversioned registration.
    ///
    /// Equivalent to `get_orchestrator_version(name, None)`.
    pub fn get_orchestrator(&self, name: &str) -> Option<&OrchestratorFn> {
        self.get_orchestrator_version(name, None)
    }

    /// Resolves orchestrator `name` for the requested `version`.
    ///
    /// Resolution order:
    /// 1. if `version` is `Some`, the entry registered with exactly that version;
    /// 2. the entry marked latest ([`Registry::add_latest_orchestrator`]);
    /// 3. the unversioned entry ([`Registry::add_named_orchestrator`]).
    ///
    /// Returns `None` when `name` is unknown or none of the above exists.
    pub fn get_orchestrator_version(
        &self,
        name: &str,
        version: Option<&str>,
    ) -> Option<&OrchestratorFn> {
        let entries = self.orchestrators.get(name)?;
        let exact = version
            .and_then(|v| entries.iter().find(|e| e.version.as_deref() == Some(v)));
        exact
            .or_else(|| entries.iter().rev().find(|e| e.is_latest))
            .or_else(|| entries.iter().find(|e| e.version.is_none()))
            .map(|e| &e.f)
    }

    /// Looks up the activity registered under `name`.
    pub fn get_activity(&self, name: &str) -> Option<&ActivityFn> {
        self.activities.get(name)
    }
}
impl Default for Registry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    async fn dummy_orchestrator(_ctx: OrchestrationContext) -> OrchestratorResult {
        Ok(Some("\"done\"".to_string()))
    }

    async fn dummy_activity(_ctx: ActivityContext, _input: Option<String>) -> ActivityResult {
        Ok(Some("\"result\"".to_string()))
    }

    /// Builds a minimal orchestration context for invoking a resolved orchestrator.
    fn test_ctx(orchestration_name: &str) -> OrchestrationContext {
        OrchestrationContext::new(
            "test".to_string(),
            orchestration_name.to_string(),
            None,
            chrono::Utc::now(),
            false,
            &crate::worker::WorkerOptions::default(),
            0,
        )
    }

    /// Resolves `name`/`version`, runs the resolved orchestrator and returns its output.
    ///
    /// Asserting on the output (rather than only `is_some()`) pins down *which*
    /// registration the lookup selected.
    async fn run_resolved(reg: &Registry, name: &str, version: Option<&str>) -> Option<String> {
        let f = reg
            .get_orchestrator_version(name, version)
            .expect("orchestrator should resolve");
        (f)(test_ctx(name)).await.unwrap()
    }

    #[test]
    fn test_register_and_lookup_orchestrator() {
        let mut reg = Registry::new();
        reg.add_named_orchestrator("my_orch", dummy_orchestrator);
        assert!(reg.get_orchestrator("my_orch").is_some());
        assert!(reg.get_orchestrator("missing").is_none());
    }

    #[test]
    fn test_register_and_lookup_activity() {
        let mut reg = Registry::new();
        reg.add_named_activity("my_act", dummy_activity);
        assert!(reg.get_activity("my_act").is_some());
        assert!(reg.get_activity("missing").is_none());
    }

    #[tokio::test]
    async fn test_invoke_orchestrator() {
        let mut reg = Registry::new();
        reg.add_named_orchestrator("orch", dummy_orchestrator);
        assert_eq!(
            run_resolved(&reg, "orch", None).await,
            Some("\"done\"".to_string())
        );
    }

    #[tokio::test]
    async fn test_invoke_activity() {
        let mut reg = Registry::new();
        reg.add_named_activity("act", dummy_activity);
        let f = reg.get_activity("act").unwrap();
        let ctx = ActivityContext::new("test".to_string(), 0, String::new());
        let result = (f)(ctx, None).await;
        assert_eq!(result.unwrap(), Some("\"result\"".to_string()));
    }

    #[tokio::test]
    async fn test_versioned_exact_match() {
        let mut reg = Registry::new();
        reg.add_versioned_orchestrator("orch", "v1", |_| async move { Ok(Some("v1".to_string())) });
        reg.add_versioned_orchestrator("orch", "v2", |_| async move { Ok(Some("v2".to_string())) });
        assert_eq!(run_resolved(&reg, "orch", Some("v1")).await, Some("v1".to_string()));
        assert_eq!(run_resolved(&reg, "orch", Some("v2")).await, Some("v2".to_string()));
        assert!(reg.get_orchestrator_version("orch", Some("v3")).is_none());
    }

    #[tokio::test]
    async fn test_latest_is_fallback() {
        let mut reg = Registry::new();
        reg.add_versioned_orchestrator("orch", "v1", |_| async move { Ok(Some("v1".to_string())) });
        reg.add_latest_orchestrator("orch", "v2", |_| async move { Ok(Some("v2".to_string())) });
        assert_eq!(run_resolved(&reg, "orch", Some("v1")).await, Some("v1".to_string()));
        assert_eq!(run_resolved(&reg, "orch", Some("v99")).await, Some("v2".to_string()));
        assert_eq!(run_resolved(&reg, "orch", None).await, Some("v2".to_string()));
    }

    #[tokio::test]
    async fn test_unversioned_fallback() {
        let mut reg = Registry::new();
        reg.add_named_orchestrator("orch", dummy_orchestrator);
        assert_eq!(
            run_resolved(&reg, "orch", Some("any")).await,
            Some("\"done\"".to_string())
        );
    }

    #[tokio::test]
    async fn test_latest_preferred_over_unversioned() {
        let mut reg = Registry::new();
        reg.add_named_orchestrator("orch", dummy_orchestrator);
        reg.add_latest_orchestrator("orch", "v2", |_| async move { Ok(Some("v2".to_string())) });
        assert_eq!(run_resolved(&reg, "orch", None).await, Some("v2".to_string()));
    }

    #[test]
    fn test_versioned_only_has_no_default() {
        let mut reg = Registry::new();
        reg.add_versioned_orchestrator("orch", "v1", |_| async move { Ok(Some("v1".to_string())) });
        assert!(reg.get_orchestrator("orch").is_none());
    }
}