use crate::{
Result,
error::HammerworkError,
job::{Job, JobId},
queue::DatabaseQueue,
};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Outcome of a successful [`SpawnManager::execute_spawn`] call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpawnResult {
    /// ID of the job whose handler produced the children.
    pub parent_job_id: JobId,
    /// IDs returned by the queue for each enqueued child, in the order the
    /// handler returned the children.
    pub spawned_jobs: Vec<JobId>,
    /// Timestamp taken after all children were enqueued.
    pub spawned_at: DateTime<Utc>,
    /// Copied from [`SpawnConfig::operation_id`].
    pub spawn_operation_id: Option<String>,
}
/// Policy controlling how [`SpawnManager::execute_spawn`] limits and decorates
/// the child jobs returned by a [`SpawnHandler`].
///
/// When an `inherit_*` flag is `false`, the corresponding child fields are
/// left exactly as the handler set them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpawnConfig {
    /// Maximum number of children one spawn may produce. If the handler returns
    /// more, the spawn fails with [`SpawnError::SpawnLimitExceeded`] before any
    /// child is enqueued. `None` disables the limit.
    pub max_spawn_count: Option<usize>,
    /// Copy the parent's `priority` onto each child.
    pub inherit_priority: bool,
    /// Copy the parent's `retry_strategy` onto each child.
    pub inherit_retry_strategy: bool,
    /// Copy the parent's `timeout` onto each child.
    pub inherit_timeout: bool,
    /// Copy the parent's `trace_id`, `correlation_id`, `parent_span_id` and
    /// `span_context` onto each child.
    pub inherit_trace_context: bool,
    /// Caller-supplied identifier echoed into [`SpawnResult::spawn_operation_id`].
    pub operation_id: Option<String>,
}
impl Default for SpawnConfig {
fn default() -> Self {
Self {
max_spawn_count: Some(100), inherit_priority: true,
inherit_retry_strategy: true,
inherit_timeout: false, inherit_trace_context: true,
operation_id: None,
}
}
}
/// Input handed to [`SpawnHandler::spawn_jobs`] by [`SpawnManager::execute_spawn`].
pub struct SpawnContext<DB: sqlx::Database> {
    /// Clone of the job that is spawning children.
    pub parent_job: Job,
    /// The spawn policy in effect for this call.
    pub config: SpawnConfig,
    /// The same queue the manager enqueues the returned children into.
    pub queue: Arc<dyn DatabaseQueue<Database = DB> + Send + Sync>,
}
/// Errors specific to job spawning.
///
/// Within this module only `SpawnLimitExceeded` is constructed (by
/// `SpawnManager::execute_spawn`, wrapped in `HammerworkError::SpawnError`).
/// NOTE(review): the other variants are presumably meant for handler
/// implementations (e.g. `validate_spawn`) — confirm against callers.
#[derive(Debug, thiserror::Error)]
pub enum SpawnError {
    /// The handler returned more children than `SpawnConfig::max_spawn_count` allows.
    #[error("Spawn limit exceeded: attempted to spawn {attempted} jobs, limit is {limit}")]
    SpawnLimitExceeded { attempted: usize, limit: usize },
    /// A spawn configuration was rejected.
    #[error("Invalid spawn configuration: {message}")]
    InvalidConfig { message: String },
    /// The given parent job is not allowed to spawn children.
    #[error("Parent job {parent_id} is not eligible for spawning")]
    ParentNotEligible { parent_id: JobId },
    /// Catch-all failure with a free-form description.
    #[error("Spawn operation failed: {message}")]
    SpawnOperationFailed { message: String },
}
/// Produces child jobs for a parent job; registered with a [`SpawnManager`].
///
/// [`SpawnManager::execute_spawn`] calls `validate_spawn`, then `spawn_jobs`,
/// enqueues the returned children, and finally calls `on_spawn_complete`.
#[async_trait]
pub trait SpawnHandler<DB: sqlx::Database>: Send + Sync {
    /// Builds the child jobs to enqueue. Before enqueueing, the manager applies
    /// the `inherit_*` rules from `context.config` and links each child to the
    /// parent (dependency and workflow fields).
    async fn spawn_jobs(&self, context: SpawnContext<DB>) -> Result<Vec<Job>>;
    /// Called before `spawn_jobs`; returning `Err` aborts the spawn before any
    /// child is generated or enqueued. The default accepts every spawn.
    async fn validate_spawn(&self, _parent_job: &Job, _config: &SpawnConfig) -> Result<()> {
        Ok(())
    }
    /// Called after every child has been enqueued. An `Err` here is returned
    /// to the caller of `execute_spawn`, but the children stay enqueued.
    /// The default does nothing.
    async fn on_spawn_complete(&self, _result: &SpawnResult) -> Result<()> {
        Ok(())
    }
}
/// Adapter that turns a synchronous closure into a [`SpawnHandler`].
///
/// The closure receives the [`SpawnContext`] and returns the child jobs;
/// `validate_spawn` and `on_spawn_complete` keep the trait defaults.
pub struct ClosureSpawnHandler<F, DB: sqlx::Database> {
    // The wrapped closure, invoked once per spawn.
    handler: F,
    // Binds the adapter to a database type without storing a value of it.
    _phantom: std::marker::PhantomData<DB>,
}
impl<F, DB> ClosureSpawnHandler<F, DB>
where
F: Fn(SpawnContext<DB>) -> Result<Vec<Job>> + Send + Sync,
DB: sqlx::Database + Send + Sync,
{
pub fn new(handler: F) -> Self {
Self {
handler,
_phantom: std::marker::PhantomData,
}
}
}
#[async_trait]
impl<F, DB> SpawnHandler<DB> for ClosureSpawnHandler<F, DB>
where
    F: Fn(SpawnContext<DB>) -> Result<Vec<Job>> + Send + Sync,
    DB: sqlx::Database + Send + Sync,
{
    /// Delegates directly to the wrapped closure; the closure runs
    /// synchronously inside this async method.
    async fn spawn_jobs(&self, context: SpawnContext<DB>) -> Result<Vec<Job>> {
        let run = &self.handler;
        run(context)
    }
}
/// Registry of [`SpawnHandler`]s keyed by job type.
///
/// `execute_spawn` looks handlers up by the parent job's `queue_name`, so the
/// key passed to `register_handler` must be that queue name.
pub struct SpawnManager<DB: sqlx::Database> {
    // Handlers keyed by the name given to `register_handler`.
    handlers: std::collections::HashMap<String, Arc<dyn SpawnHandler<DB>>>,
    // NOTE(review): redundant marker — `handlers` already mentions `DB`.
    _phantom: std::marker::PhantomData<DB>,
}
impl<DB: sqlx::Database> SpawnManager<DB> {
    /// Creates a manager with no registered handlers.
    pub fn new() -> Self {
        Self {
            handlers: std::collections::HashMap::new(),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Registers `handler` under `job_type`, replacing any handler previously
    /// registered under the same key.
    ///
    /// `execute_spawn` matches this key against the parent job's `queue_name`.
    pub fn register_handler<H>(&mut self, job_type: impl Into<String>, handler: H)
    where
        H: SpawnHandler<DB> + 'static,
    {
        self.handlers.insert(job_type.into(), Arc::new(handler));
    }

    /// Runs the spawn handler registered for `job.queue_name`, if any.
    ///
    /// The sequence is:
    /// 1. `validate_spawn`.
    /// 2. `spawn_jobs`.
    /// 3. Enforce `config.max_spawn_count`.
    /// 4. Apply the `inherit_*` rules and link each child to the parent.
    /// 5. Enqueue every child.
    /// 6. `on_spawn_complete`.
    ///
    /// Returns `Ok(None)` when no handler is registered for the job's queue name.
    ///
    /// # Errors
    ///
    /// Returns `HammerworkError::SpawnError(SpawnError::SpawnLimitExceeded { .. })`
    /// when the handler produces more children than `config.max_spawn_count`;
    /// nothing is enqueued in that case. Errors from `validate_spawn`,
    /// `spawn_jobs`, `enqueue` and `on_spawn_complete` are propagated as-is.
    ///
    /// Children are enqueued one at a time, not atomically. If an `enqueue`
    /// fails part-way, the children already enqueued remain in the queue, and
    /// `on_spawn_complete` is not called.
    pub async fn execute_spawn(
        &self,
        job: Job,
        config: SpawnConfig,
        queue: Arc<dyn DatabaseQueue<Database = DB> + Send + Sync>,
    ) -> Result<Option<SpawnResult>> {
        let handler = match self.handlers.get(&job.queue_name) {
            Some(handler) => handler,
            None => return Ok(None),
        };

        handler.validate_spawn(&job, &config).await?;

        let context = SpawnContext {
            parent_job: job.clone(),
            config: config.clone(),
            queue: Arc::clone(&queue),
        };
        let mut child_jobs = handler.spawn_jobs(context).await?;

        // Enforce the limit before enqueueing so an oversized spawn leaves no trace.
        if let Some(limit) = config.max_spawn_count {
            if child_jobs.len() > limit {
                return Err(HammerworkError::SpawnError(
                    SpawnError::SpawnLimitExceeded {
                        attempted: child_jobs.len(),
                        limit,
                    },
                ));
            }
        }

        for child_job in &mut child_jobs {
            Self::apply_inheritance(&job, child_job, &config);
        }

        let mut spawned_job_ids = Vec::with_capacity(child_jobs.len());
        for child_job in child_jobs {
            spawned_job_ids.push(queue.enqueue(child_job).await?);
        }

        let spawn_result = SpawnResult {
            parent_job_id: job.id,
            spawned_jobs: spawned_job_ids,
            spawned_at: Utc::now(),
            spawn_operation_id: config.operation_id,
        };
        handler.on_spawn_complete(&spawn_result).await?;
        Ok(Some(spawn_result))
    }

    /// Copies the attributes selected by `config` from `parent` onto `child`.
    /// It also makes `child` depend on `parent` and join the parent's workflow.
    fn apply_inheritance(parent: &Job, child: &mut Job, config: &SpawnConfig) {
        if config.inherit_priority {
            child.priority = parent.priority;
        }
        if config.inherit_retry_strategy {
            child.retry_strategy = parent.retry_strategy.clone();
        }
        if config.inherit_timeout {
            child.timeout = parent.timeout;
        }
        if config.inherit_trace_context {
            child.trace_id = parent.trace_id.clone();
            child.correlation_id = parent.correlation_id.clone();
            child.parent_span_id = parent.parent_span_id.clone();
            child.span_context = parent.span_context.clone();
        }
        // Preserve dependencies the handler declared (e.g. ordering between
        // siblings) instead of overwriting them; add the parent exactly once.
        if !child.depends_on.contains(&parent.id) {
            child.depends_on.push(parent.id);
        }
        child.workflow_id = parent.workflow_id;
        child.workflow_name = parent.workflow_name.clone();
    }

    /// Returns `true` if a handler is registered under `job_type`.
    pub fn has_handler(&self, job_type: &str) -> bool {
        self.handlers.contains_key(job_type)
    }

    /// Returns the keys of all registered handlers, in unspecified order.
    pub fn registered_types(&self) -> Vec<String> {
        self.handlers.keys().cloned().collect()
    }
}
impl<DB: sqlx::Database> Default for SpawnManager<DB> {
fn default() -> Self {
Self::new()
}
}
/// Builder-style helpers for attaching a [`SpawnConfig`] to a job.
pub trait JobSpawnExt {
    /// Returns the job with `config` attached (for `Job`, stored in the
    /// payload under the `_spawn_config` key).
    fn with_spawn_config(self, config: SpawnConfig) -> Self;
    /// Shorthand for `with_spawn_config(SpawnConfig::default())`.
    fn with_spawning(self) -> Self;
}
impl JobSpawnExt for Job {
    /// Stores `config` as JSON in the job payload under the `"_spawn_config"`
    /// key, replacing any value already there.
    ///
    /// Only object payloads can carry the config. For any other payload (array,
    /// string, number, bool, null) the job is returned unchanged.
    ///
    /// NOTE(review): `SpawnManager::execute_spawn` takes its `SpawnConfig` as an
    /// explicit argument rather than reading this key. Presumably another
    /// component reads it back; verify.
    fn with_spawn_config(mut self, config: SpawnConfig) -> Self {
        if let Some(payload_obj) = self.payload.as_object_mut() {
            // SpawnConfig holds only integers, bools and strings, so serde_json
            // cannot fail to serialize it; a failure would be a bug here.
            let config_value = serde_json::to_value(config)
                .expect("SpawnConfig contains only plain fields and always serializes to JSON");
            payload_obj.insert("_spawn_config".to_string(), config_value);
        }
        self
    }

    /// Attaches the default [`SpawnConfig`]; see `with_spawn_config`.
    fn with_spawning(self) -> Self {
        self.with_spawn_config(SpawnConfig::default())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Test handler that spawns `payload["spawn_count"]` children (default 1),
    /// each tagged with its index and the parent's ID.
    struct TestSpawnHandler;

    #[async_trait]
    impl<DB: sqlx::Database> SpawnHandler<DB> for TestSpawnHandler {
        async fn spawn_jobs(&self, context: SpawnContext<DB>) -> Result<Vec<Job>> {
            let count = context.parent_job.payload["spawn_count"]
                .as_u64()
                .unwrap_or(1) as usize;
            let mut jobs = Vec::new();
            for i in 0..count {
                let job = Job::new(
                    "child_task".to_string(),
                    json!({
                        "index": i,
                        "parent_id": context.parent_job.id
                    }),
                );
                jobs.push(job);
            }
            Ok(jobs)
        }
    }

    // Only builds the fixtures: invoking `spawn_jobs` requires a `SpawnContext`,
    // which needs a live `DatabaseQueue`.
    #[tokio::test]
    async fn test_spawn_handler_basic() {
        let _handler = TestSpawnHandler;
        let _parent_job = Job::new("parent_task".to_string(), json!({"spawn_count": 3}));
    }

    #[test]
    fn test_spawn_config_defaults() {
        let config = SpawnConfig::default();
        assert_eq!(config.max_spawn_count, Some(100));
        assert!(config.inherit_priority);
        assert!(config.inherit_retry_strategy);
        assert!(!config.inherit_timeout);
        assert!(config.inherit_trace_context);
        assert!(config.operation_id.is_none());
    }

    #[test]
    fn test_spawn_config_serde_roundtrip() {
        let config = SpawnConfig {
            max_spawn_count: None,
            inherit_priority: false,
            inherit_retry_strategy: false,
            inherit_timeout: true,
            inherit_trace_context: false,
            operation_id: Some("op-1".to_string()),
        };
        let value = serde_json::to_value(&config).expect("serialize SpawnConfig");
        let decoded: SpawnConfig =
            serde_json::from_value(value).expect("deserialize SpawnConfig");
        assert_eq!(decoded.max_spawn_count, None);
        assert!(!decoded.inherit_priority);
        assert!(!decoded.inherit_retry_strategy);
        assert!(decoded.inherit_timeout);
        assert!(!decoded.inherit_trace_context);
        assert_eq!(decoded.operation_id.as_deref(), Some("op-1"));
    }

    #[test]
    fn test_with_spawning_embeds_default_config() {
        let job = Job::new("parent_task".to_string(), json!({"spawn_count": 3})).with_spawning();
        // Existing payload keys are preserved.
        assert_eq!(job.payload["spawn_count"], json!(3));
        let embedded: SpawnConfig = serde_json::from_value(job.payload["_spawn_config"].clone())
            .expect("payload should carry a SpawnConfig");
        assert_eq!(embedded.max_spawn_count, Some(100));
        assert!(embedded.inherit_priority);
        assert!(!embedded.inherit_timeout);
    }

    #[test]
    fn test_with_spawn_config_ignores_non_object_payload() {
        let job = Job::new("parent_task".to_string(), json!([1, 2, 3])).with_spawning();
        assert_eq!(job.payload, json!([1, 2, 3]));
    }

    #[test]
    fn test_spawn_limit_error_message() {
        let err = SpawnError::SpawnLimitExceeded {
            attempted: 5,
            limit: 3,
        };
        assert_eq!(
            err.to_string(),
            "Spawn limit exceeded: attempted to spawn 5 jobs, limit is 3"
        );
    }

    #[cfg(feature = "postgres")]
    #[test]
    fn test_spawn_manager_registration() {
        let mut manager: SpawnManager<sqlx::Postgres> = SpawnManager::new();
        assert!(!manager.has_handler("test_job"));
        manager.register_handler("test_job", TestSpawnHandler);
        assert!(manager.has_handler("test_job"));
        let types = manager.registered_types();
        assert!(types.contains(&"test_job".to_string()));
    }
}