mod delegation;
mod submission;
use std::collections::HashMap;
use std::sync::Arc;
use rustvello_core::broker::Broker;
use rustvello_core::client_data_store::ClientDataStoreManager;
use rustvello_core::error::RustvelloResult;
use rustvello_core::orchestrator::Orchestrator;
use rustvello_core::state_backend::StateBackend;
use rustvello_core::task::{DynTask, Task, TaskDefinition, TaskFn, TaskRegistry};
use rustvello_core::trigger::TriggerManager;
use rustvello_proto::config::{AppConfig, TaskConfig};
use rustvello_proto::identifiers::TaskId;
use crate::orchestration::OrchestratorCoordinator;
use crate::task_config::{apply_task_env_overrides, TaskConfigOverride};
/// Task self-registration hook collected through the `inventory` crate.
///
/// Each submitted entry carries a function that registers its task(s) into a
/// [`TaskRegistry`]. The code that iterates the collected entries lives outside this
/// section.
pub struct TaskEntry {
    /// Registers this entry's task(s) into `registry`, returning any registration error.
    pub register_fn: fn(&mut TaskRegistry) -> RustvelloResult<()>,
}

// Declares `TaskEntry` as an inventory-collected type so entries can be submitted
// from anywhere in the program.
inventory::collect!(TaskEntry);
/// Top-level application handle.
///
/// Bundles the app configuration, the task registry, the backend coordinator and the
/// task-config override state.
pub struct RustvelloApp {
    /// Application-wide configuration.
    pub config: AppConfig,
    /// Registry of every task this app knows how to run.
    pub task_registry: TaskRegistry,
    /// Holds the broker, orchestrator, state backend, client data store and optional
    /// trigger manager shared by this app.
    pub(crate) coordinator: OrchestratorCoordinator,
    /// Per-task config overrides, keyed by task name (`TaskId::name`).
    task_config_overrides: HashMap<String, TaskConfigOverride>,
    /// Override applied to every task before any per-task override.
    task_defaults_override: TaskConfigOverride,
    /// Memoized environment-variable overrides, keyed by task name.
    env_override_cache: std::sync::Mutex<HashMap<String, Arc<TaskConfigOverride>>>,
    /// NOTE(review): used outside this chunk; presumably the names of runners already
    /// persisted to the state backend — confirm against its users.
    stored_runner_cache: tokio::sync::Mutex<std::collections::HashSet<String>>,
}
impl std::fmt::Debug for RustvelloApp {
    /// Renders the config and the number of registered tasks; other fields are elided
    /// (shown as `..`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let registered = self.task_registry.task_ids().len();
        let mut out = f.debug_struct("RustvelloApp");
        out.field("config", &self.config);
        out.field("tasks", &registered);
        out.finish_non_exhaustive()
    }
}
impl RustvelloApp {
    /// Creates an app wired to fresh in-memory backends from `rustvello_mem`.
    ///
    /// Only available with the `mem` feature. The app gets in-memory instances of the
    /// broker, orchestrator, state backend and client data store (the latter with the
    /// default `ClientDataStoreConfig`), and no trigger manager.
    ///
    /// Construction is delegated to [`Self::with_backends_and_triggers`]. All
    /// constructors therefore share a single initialization path.
    #[cfg(feature = "mem")]
    pub fn new(config: AppConfig) -> Self {
        use rustvello_core::client_data_store::ClientDataStore;
        use rustvello_proto::config::ClientDataStoreConfig;

        let broker: Arc<dyn Broker> = Arc::new(rustvello_mem::broker::MemBroker::new());
        let orchestrator: Arc<dyn Orchestrator> =
            Arc::new(rustvello_mem::orchestrator::MemOrchestrator::new());
        let state_backend: Arc<dyn StateBackend> =
            Arc::new(rustvello_mem::state_backend::MemStateBackend::new());
        let mem_cds: Arc<dyn ClientDataStore> =
            Arc::new(rustvello_mem::client_data_store::MemClientDataStore::new());
        let client_data_store = Arc::new(ClientDataStoreManager::new(
            mem_cds,
            ClientDataStoreConfig::default(),
        ));
        Self::with_backends_and_triggers(
            config,
            broker,
            orchestrator,
            state_backend,
            client_data_store,
            None,
        )
    }

    /// Creates an app using caller-supplied backends and no trigger manager.
    ///
    /// Equivalent to [`Self::with_backends_and_triggers`] with `trigger_manager = None`.
    pub fn with_backends(
        config: AppConfig,
        broker: Arc<dyn Broker>,
        orchestrator: Arc<dyn Orchestrator>,
        state_backend: Arc<dyn StateBackend>,
        client_data_store: Arc<ClientDataStoreManager>,
    ) -> Self {
        Self::with_backends_and_triggers(
            config,
            broker,
            orchestrator,
            state_backend,
            client_data_store,
            None,
        )
    }

    /// Creates an app from caller-supplied backends and an optional trigger manager.
    ///
    /// This is the single place where a `RustvelloApp` is assembled. The task registry
    /// starts empty, no config overrides are set, and both caches start empty.
    /// `config.auto_final_invocation_purge_hours` is forwarded to the coordinator.
    pub fn with_backends_and_triggers(
        config: AppConfig,
        broker: Arc<dyn Broker>,
        orchestrator: Arc<dyn Orchestrator>,
        state_backend: Arc<dyn StateBackend>,
        client_data_store: Arc<ClientDataStoreManager>,
        trigger_manager: Option<TriggerManager>,
    ) -> Self {
        let coordinator = OrchestratorCoordinator::new(
            orchestrator,
            state_backend,
            broker,
            client_data_store,
            trigger_manager,
            config.auto_final_invocation_purge_hours,
        );
        Self {
            config,
            task_registry: TaskRegistry::new(),
            coordinator,
            task_config_overrides: HashMap::new(),
            task_defaults_override: TaskConfigOverride::default(),
            env_override_cache: std::sync::Mutex::new(HashMap::new()),
            stored_runner_cache: tokio::sync::Mutex::new(std::collections::HashSet::new()),
        }
    }

    /// Registers an untyped task built from `task_id`, `config` and `func`.
    ///
    /// # Errors
    /// Returns whatever error [`TaskRegistry::register`] reports for the definition.
    pub fn register_task(
        &mut self,
        task_id: TaskId,
        config: TaskConfig,
        func: TaskFn,
    ) -> RustvelloResult<()> {
        let definition = TaskDefinition::new(task_id, config, func);
        self.task_registry.register(definition)
    }

    /// Registers a typed task implementing [`Task`].
    ///
    /// # Errors
    /// Returns whatever error [`TaskRegistry::register_typed`] reports.
    pub fn register<T: Task>(&mut self, task: T) -> RustvelloResult<()> {
        self.task_registry.register_typed(task)
    }

    /// Replaces both the per-task overrides (keyed by task name) and the app-wide
    /// default override.
    ///
    /// Environment-variable overrides are cached separately and are not affected.
    pub fn set_task_config_overrides(
        &mut self,
        overrides: HashMap<String, TaskConfigOverride>,
        defaults: TaskConfigOverride,
    ) {
        self.task_config_overrides = overrides;
        self.task_defaults_override = defaults;
    }

    /// Computes the effective config for `task_id`, starting from `base`.
    ///
    /// Layers are applied from lowest to highest precedence:
    /// 1. `base`
    /// 2. the app-wide default override
    /// 3. the per-task override keyed by `task_id.name()`
    /// 4. environment-variable overrides (`RUSTVELLO__TASK__<NAME>__...`)
    pub fn resolve_task_config(&self, task_id: &TaskId, base: &TaskConfig) -> TaskConfig {
        let mut config = base.clone();
        self.task_defaults_override.apply_to(&mut config);
        if let Some(per_task) = self.task_config_overrides.get(task_id.name()) {
            per_task.apply_to(&mut config);
        }
        self.get_or_load_env_override(task_id.name())
            .apply_to(&mut config);
        config
    }

    /// Resolves only `force_new_workflow`, using the same precedence as
    /// [`Self::resolve_task_config`]. The first layer that sets a value wins, in this
    /// order: env, per-task, defaults, then `base`.
    fn resolve_force_new_workflow(&self, task_id: &TaskId, base: &TaskConfig) -> bool {
        let name = task_id.name();
        self.get_or_load_env_override(name)
            .force_new_workflow
            .or_else(|| {
                self.task_config_overrides
                    .get(name)
                    .and_then(|per_task| per_task.force_new_workflow)
            })
            .or(self.task_defaults_override.force_new_workflow)
            .unwrap_or(base.force_new_workflow)
    }

    /// Returns the environment-variable override for `task_name`, loading it on first
    /// use and caching it for the lifetime of the app.
    ///
    /// Variables are read with the prefix `RUSTVELLO__TASK__<TASK_NAME_UPPERCASE>__`
    /// through `apply_task_env_overrides`. A field counts as overridden only if its
    /// loaded value differs from `TaskConfig::default()`. As a result, an env var that
    /// sets a field to its default value cannot override a per-task or default-level
    /// override. Fields not listed below are never taken from the environment here.
    ///
    /// Environment changes made after the first lookup for a task are not observed.
    /// A poisoned cache lock is recovered rather than propagated, because the cache
    /// holds only immutable, fully built entries.
    fn get_or_load_env_override(&self, task_name: &str) -> Arc<TaskConfigOverride> {
        /// `Some(value)` when the environment moved the field away from its default.
        fn changed<T: PartialEq>(value: T, default: T) -> Option<T> {
            (value != default).then_some(value)
        }

        // The lock is held across the load so concurrent first lookups build each
        // entry only once.
        let mut cache = self
            .env_override_cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if let Some(cached) = cache.get(task_name) {
            return Arc::clone(cached);
        }

        let env_prefix = format!("RUSTVELLO__TASK__{}__", task_name.to_uppercase());
        let base = TaskConfig::default();
        let mut from_env = TaskConfig::default();
        apply_task_env_overrides(&env_prefix, &mut from_env);

        let env_override = Arc::new(TaskConfigOverride {
            max_retries: changed(from_env.max_retries, base.max_retries),
            concurrency_control: changed(from_env.concurrency_control, base.concurrency_control),
            running_concurrency: changed(from_env.running_concurrency, base.running_concurrency),
            registration_concurrency: None,
            cache_results: changed(from_env.cache_results, base.cache_results),
            key_arguments: None,
            retry_for_errors: None,
            disable_cache_args: None,
            on_diff_non_key_args_raise: None,
            parallel_batch_size: None,
            force_new_workflow: changed(from_env.force_new_workflow, base.force_new_workflow),
            reroute_on_cc: changed(from_env.reroute_on_cc, base.reroute_on_cc),
            blocking: changed(from_env.blocking, base.blocking),
        });
        cache.insert(task_name.to_owned(), Arc::clone(&env_override));
        env_override
    }

    /// Looks up a registered task by id, returning `None` if it is unknown.
    pub fn get_task(&self, task_id: &TaskId) -> Option<Arc<dyn DynTask>> {
        self.task_registry.get_dyn(task_id)
    }

    /// Returns a shared handle to the app's broker.
    pub fn broker(&self) -> Arc<dyn Broker> {
        Arc::clone(&self.coordinator.broker)
    }

    /// Returns a shared handle to the app's orchestrator.
    pub fn orchestrator(&self) -> Arc<dyn Orchestrator> {
        Arc::clone(&self.coordinator.orchestrator)
    }

    /// Returns a shared handle to the app's state backend.
    pub fn state_backend(&self) -> Arc<dyn StateBackend> {
        Arc::clone(&self.coordinator.state_backend)
    }

    /// Returns a shared handle to the app's client data store manager.
    pub fn client_data_store(&self) -> Arc<ClientDataStoreManager> {
        Arc::clone(&self.coordinator.client_data_store)
    }

    /// Returns the trigger manager, if one is installed.
    pub fn trigger_manager(&self) -> Option<&TriggerManager> {
        self.coordinator.trigger_manager.as_ref()
    }

    /// Installs (or replaces) the trigger manager.
    pub fn set_trigger_manager(&mut self, manager: TriggerManager) {
        self.coordinator.trigger_manager = Some(manager);
    }

    /// Purges the orchestrator, the broker (`purge(None)`; presumably all queues — TODO
    /// confirm) and the state backend, in that order.
    ///
    /// After the state backend is purged, the stored-runner cache is cleared so that it
    /// no longer claims runners are stored when their records have just been wiped.
    /// The client data store is not purged here.
    ///
    /// # Errors
    /// Returns the first backend error; later steps are then skipped.
    pub async fn purge(&self) -> RustvelloResult<()> {
        self.coordinator.orchestrator.purge().await?;
        self.coordinator.broker.purge(None).await?;
        self.coordinator.state_backend.purge().await?;
        // The cache mirrors state-backend contents, so it must not outlive them.
        self.stored_runner_cache.lock().await.clear();
        Ok(())
    }

    /// Consumes the app and builds a [`crate::runner::TaskRunner`] from it.
    ///
    /// The runner receives the config, broker, orchestrator, state backend, task
    /// registry and trigger manager. The client data store and the config-override
    /// state are not passed to `TaskRunner::new` here.
    pub fn into_runner(self) -> crate::runner::TaskRunner {
        crate::runner::TaskRunner::new(
            self.config.app_id.clone(),
            self.config,
            self.coordinator.broker,
            self.coordinator.orchestrator,
            self.coordinator.state_backend,
            Arc::new(self.task_registry),
            self.coordinator.trigger_manager,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rustvello_core::error::RustvelloError;
    use rustvello_proto::call::SerializedArguments;
    use rustvello_proto::status::InvocationStatus;

    /// Builds an in-memory app with a single task, `test.double`.
    ///
    /// The task reads argument `x` as `i64` (defaulting to 0) and returns `x * 2` as JSON.
    fn make_app() -> RustvelloApp {
        let mut app = RustvelloApp::new(AppConfig::new("test-app"));
        app.register_task(
            TaskId::new("test", "double"),
            TaskConfig::default(),
            Arc::new(|args_json: String| {
                let args: std::collections::BTreeMap<String, String> =
                    serde_json::from_str(&args_json).map_err(|e| {
                        RustvelloError::Serialization {
                            message: e.to_string(),
                        }
                    })?;
                let x: i64 = args.get("x").and_then(|v| v.parse().ok()).unwrap_or(0);
                serde_json::to_string(&(x * 2)).map_err(|e| RustvelloError::Serialization {
                    message: e.to_string(),
                })
            }),
        )
        .unwrap();
        app
    }

    /// True when both handles point at the same allocation.
    ///
    /// The vtable metadata is dropped before comparing, so this is reliable for
    /// `Arc<dyn Trait>`.
    fn same_instance<T: ?Sized>(a: &Arc<T>, b: &Arc<T>) -> bool {
        Arc::as_ptr(a).cast::<()>() == Arc::as_ptr(b).cast::<()>()
    }

    /// Convenience constructor for an override that only sets `max_retries`.
    fn max_retries_override(value: u8) -> TaskConfigOverride {
        TaskConfigOverride {
            max_retries: Some(value.into()),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_submit_and_status() {
        let app = make_app();
        let mut args = SerializedArguments::new();
        args.insert("x", "21");
        let inv_id = app
            .submit(&TaskId::new("test", "double"), args)
            .await
            .unwrap();
        let status = app.get_status(&inv_id).await.unwrap();
        assert_eq!(status, InvocationStatus::Registered);
    }

    #[tokio::test]
    async fn test_submit_unregistered_task() {
        let app = make_app();
        let args = SerializedArguments::new();
        let result = app.submit(&TaskId::new("nonexistent", "task"), args).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_submit_sync() {
        let app = make_app();
        let mut args = SerializedArguments::new();
        args.insert("x", "21");
        let result = app
            .submit_sync(&TaskId::new("test", "double"), args)
            .await
            .unwrap();
        assert_eq!(result, "42");
    }

    #[tokio::test]
    async fn test_submit_sync_unregistered() {
        let app = make_app();
        let args = SerializedArguments::new();
        let result = app.submit_sync(&TaskId::new("no", "such"), args).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_get_result() {
        let app = make_app();
        let mut args = SerializedArguments::new();
        args.insert("x", "21");
        let inv_id = app
            .submit(&TaskId::new("test", "double"), args)
            .await
            .unwrap();
        let result = app.get_result(&inv_id).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_backend_accessors() {
        let app = make_app();
        // Every accessor hands out a clone of the same shared backend instance.
        assert!(same_instance(&app.broker(), &app.broker()));
        assert!(same_instance(&app.orchestrator(), &app.orchestrator()));
        assert!(same_instance(&app.state_backend(), &app.state_backend()));
        assert!(same_instance(
            &app.client_data_store(),
            &app.client_data_store()
        ));
        // `RustvelloApp::new` installs no trigger manager.
        assert!(app.trigger_manager().is_none());
    }

    #[tokio::test]
    async fn test_resolve_task_config_applies_defaults_override() {
        let mut app = make_app();
        let task_id = TaskId::new("test", "double");
        app.set_task_config_overrides(HashMap::new(), max_retries_override(7));
        let resolved = app.resolve_task_config(&task_id, &TaskConfig::default());
        assert_eq!(resolved.max_retries, 7u8.into());
    }

    #[tokio::test]
    async fn test_resolve_task_config_per_task_beats_defaults() {
        let mut app = make_app();
        let task_id = TaskId::new("test", "double");
        let mut per_task = HashMap::new();
        per_task.insert(task_id.name().to_owned(), max_retries_override(5));
        app.set_task_config_overrides(per_task, max_retries_override(1));
        let resolved = app.resolve_task_config(&task_id, &TaskConfig::default());
        assert_eq!(resolved.max_retries, 5u8.into());
    }

    #[tokio::test]
    async fn test_resolve_force_new_workflow_precedence() {
        let mut app = make_app();
        let task_id = TaskId::new("test", "double");
        let base = TaskConfig::default();

        // Without overrides the base value is used.
        assert_eq!(
            app.resolve_force_new_workflow(&task_id, &base),
            base.force_new_workflow
        );

        // The defaults override wins over base.
        app.set_task_config_overrides(
            HashMap::new(),
            TaskConfigOverride {
                force_new_workflow: Some(!base.force_new_workflow),
                ..Default::default()
            },
        );
        assert_eq!(
            app.resolve_force_new_workflow(&task_id, &base),
            !base.force_new_workflow
        );

        // A per-task override wins over the defaults override.
        let mut per_task = HashMap::new();
        per_task.insert(
            task_id.name().to_owned(),
            TaskConfigOverride {
                force_new_workflow: Some(false),
                ..Default::default()
            },
        );
        app.set_task_config_overrides(
            per_task,
            TaskConfigOverride {
                force_new_workflow: Some(true),
                ..Default::default()
            },
        );
        assert!(!app.resolve_force_new_workflow(&task_id, &base));
    }

    #[tokio::test]
    async fn test_debug_reports_task_count() {
        let app = make_app();
        let rendered = format!("{app:?}");
        assert!(rendered.starts_with("RustvelloApp"));
        assert!(rendered.contains("tasks: 1"));
    }

    #[tokio::test]
    async fn test_purge_after_submit() {
        let app = make_app();
        let mut args = SerializedArguments::new();
        args.insert("x", "1");
        app.submit(&TaskId::new("test", "double"), args)
            .await
            .unwrap();
        app.purge().await.unwrap();
    }
}