use crate::{Result, job::Job};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
#[cfg(any(feature = "postgres", feature = "mysql"))]
use sqlx::{Decode, Encode, Type};
#[cfg(feature = "postgres")]
use sqlx::Postgres;
#[cfg(feature = "mysql")]
use sqlx::MySql;
/// Unique identifier for a job batch (a v4 UUID, see [`JobBatch::new`]).
pub type BatchId = Uuid;
/// Strategy for handling failures of individual jobs within a batch.
///
/// NOTE(review): the processing logic that interprets these modes lives
/// outside this module; variant descriptions below follow the names and
/// should be confirmed against the worker implementation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum PartialFailureMode {
    /// Continue processing the remaining jobs when one fails (the default).
    ContinueOnError,
    /// Presumably aborts the batch on the first job failure.
    FailFast,
    /// Presumably continues processing while collecting per-job errors.
    CollectErrors,
}
impl Default for PartialFailureMode {
fn default() -> Self {
Self::ContinueOnError
}
}
/// Lifecycle state of a batch; persisted as a string via the feature-gated
/// sqlx `Encode`/`Decode` impls below (variant name, e.g. `"Pending"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum BatchStatus {
    /// Batch created but not yet started.
    Pending,
    /// Batch is currently being worked on.
    Processing,
    /// All jobs finished successfully.
    Completed,
    /// Some jobs succeeded and some failed.
    PartiallyFailed,
    /// The batch failed.
    Failed,
}
#[cfg(feature = "postgres")]
// Stored as a Postgres text column: reuse String's type info so sqlx treats
// the enum as TEXT/VARCHAR at the wire level.
impl Type<Postgres> for BatchStatus {
    fn type_info() -> sqlx::postgres::PgTypeInfo {
        <String as Type<Postgres>>::type_info()
    }
}
#[cfg(feature = "postgres")]
impl Encode<'_, Postgres> for BatchStatus {
    /// Encodes the status as its bare variant name (e.g. `"Pending"`),
    /// delegating the actual buffer write to the `&str` encoder.
    ///
    /// Must stay in sync with the `Decode` impl below and the MySQL pair.
    fn encode_by_ref(
        &self,
        buf: &mut sqlx::postgres::PgArgumentBuffer,
    ) -> std::result::Result<sqlx::encode::IsNull, Box<dyn std::error::Error + Send + Sync + 'static>>
    {
        let status_str = match self {
            BatchStatus::Pending => "Pending",
            BatchStatus::Processing => "Processing",
            BatchStatus::Completed => "Completed",
            BatchStatus::PartiallyFailed => "PartiallyFailed",
            BatchStatus::Failed => "Failed",
        };
        <&str as Encode<'_, Postgres>>::encode_by_ref(&status_str, buf)
    }
}
#[cfg(feature = "postgres")]
impl Decode<'_, Postgres> for BatchStatus {
    /// Decodes a status from its stored string form.
    ///
    /// # Errors
    ///
    /// Returns an error when the stored string matches no known variant.
    fn decode(
        value: sqlx::postgres::PgValueRef<'_>,
    ) -> std::result::Result<Self, sqlx::error::BoxDynError> {
        let status_str = <String as Decode<Postgres>>::decode(value)?;
        // Tolerate values wrapped in double quotes — presumably rows written
        // as JSON-serialized strings by an older code path; confirm before
        // removing.
        let cleaned_str = status_str.trim_matches('"');
        match cleaned_str {
            "Pending" => Ok(BatchStatus::Pending),
            "Processing" => Ok(BatchStatus::Processing),
            "Completed" => Ok(BatchStatus::Completed),
            "PartiallyFailed" => Ok(BatchStatus::PartiallyFailed),
            "Failed" => Ok(BatchStatus::Failed),
            _ => Err(format!("Unknown batch status: {}", status_str).into()),
        }
    }
}
#[cfg(feature = "mysql")]
// Stored as a MySQL text column: reuse String's type info, mirroring the
// Postgres impl above.
impl Type<MySql> for BatchStatus {
    fn type_info() -> sqlx::mysql::MySqlTypeInfo {
        <String as Type<MySql>>::type_info()
    }
}
#[cfg(feature = "mysql")]
impl Encode<'_, MySql> for BatchStatus {
    /// Encodes the status as its bare variant name (e.g. `"Pending"`),
    /// delegating the buffer write to the `&str` encoder.
    ///
    /// Must stay in sync with the `Decode` impl below and the Postgres pair.
    fn encode_by_ref(
        &self,
        buf: &mut Vec<u8>,
    ) -> std::result::Result<sqlx::encode::IsNull, Box<dyn std::error::Error + Send + Sync + 'static>>
    {
        let status_str = match self {
            BatchStatus::Pending => "Pending",
            BatchStatus::Processing => "Processing",
            BatchStatus::Completed => "Completed",
            BatchStatus::PartiallyFailed => "PartiallyFailed",
            BatchStatus::Failed => "Failed",
        };
        <&str as Encode<'_, MySql>>::encode_by_ref(&status_str, buf)
    }
}
#[cfg(feature = "mysql")]
impl Decode<'_, MySql> for BatchStatus {
    /// Decodes a status from its stored string form.
    ///
    /// # Errors
    ///
    /// Returns an error when the stored string matches no known variant.
    fn decode(
        value: sqlx::mysql::MySqlValueRef<'_>,
    ) -> std::result::Result<Self, sqlx::error::BoxDynError> {
        let status_str = <String as Decode<MySql>>::decode(value)?;
        // Tolerate values wrapped in double quotes — presumably rows written
        // as JSON-serialized strings by an older code path; confirm before
        // removing.
        let cleaned_str = status_str.trim_matches('"');
        match cleaned_str {
            "Pending" => Ok(BatchStatus::Pending),
            "Processing" => Ok(BatchStatus::Processing),
            "Completed" => Ok(BatchStatus::Completed),
            "PartiallyFailed" => Ok(BatchStatus::PartiallyFailed),
            "Failed" => Ok(BatchStatus::Failed),
            _ => Err(format!("Unknown batch status: {}", status_str).into()),
        }
    }
}
/// Aggregated outcome of a processed (or in-flight) batch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchResult {
    /// Identifier of the batch this result describes.
    pub batch_id: BatchId,
    /// Total number of jobs in the batch.
    pub total_jobs: u32,
    /// Jobs that finished successfully.
    pub completed_jobs: u32,
    /// Jobs that failed.
    pub failed_jobs: u32,
    /// Jobs not yet processed.
    pub pending_jobs: u32,
    /// Overall batch status.
    pub status: BatchStatus,
    /// When the batch was created.
    pub created_at: DateTime<Utc>,
    /// When processing finished, if it has.
    pub completed_at: Option<DateTime<Utc>>,
    /// Optional human-readable summary of what went wrong.
    pub error_summary: Option<String>,
    /// Per-job error messages, keyed by job id.
    pub job_errors: HashMap<Uuid, String>,
}
/// A named collection of jobs submitted and tracked as a unit.
///
/// Built fluently via [`JobBatch::new`] and the `with_*`/`add_job` methods.
#[derive(Debug, Clone)]
pub struct JobBatch {
    /// Unique batch identifier.
    pub id: BatchId,
    /// Human-readable batch name (chunking appends `_chunk_N`).
    pub name: String,
    /// The jobs in the batch; `validate` requires one shared queue name.
    pub jobs: Vec<Job>,
    /// Optional maximum chunk size for `into_chunks` (default 1000 when unset).
    pub batch_size: Option<u32>,
    /// How individual job failures are handled.
    pub failure_mode: PartialFailureMode,
    /// When the batch was created.
    pub created_at: DateTime<Utc>,
    /// Free-form string metadata attached to the batch.
    pub metadata: HashMap<String, String>,
}
impl JobBatch {
    /// Creates an empty batch with a fresh random id, default failure
    /// handling (`ContinueOnError`), no chunk-size limit, and a creation
    /// timestamp of now.
    pub fn new<S: Into<String>>(name: S) -> Self {
        Self {
            id: Uuid::new_v4(),
            name: name.into(),
            jobs: Vec::new(),
            batch_size: None,
            failure_mode: PartialFailureMode::default(),
            created_at: Utc::now(),
            metadata: HashMap::new(),
        }
    }

    /// Replaces the batch's job list (builder style).
    pub fn with_jobs(mut self, jobs: Vec<Job>) -> Self {
        self.jobs = jobs;
        self
    }

    /// Appends one job to the batch (builder style).
    pub fn add_job(mut self, job: Job) -> Self {
        self.jobs.push(job);
        self
    }

    /// Sets the maximum chunk size used by [`into_chunks`](Self::into_chunks).
    /// A size of 0 is treated as 1 when chunking.
    pub fn with_batch_size(mut self, size: u32) -> Self {
        self.batch_size = Some(size);
        self
    }

    /// Sets how failures of individual jobs are handled.
    pub fn with_partial_failure_handling(mut self, mode: PartialFailureMode) -> Self {
        self.failure_mode = mode;
        self
    }

    /// Attaches a metadata key/value pair; an existing key is overwritten.
    pub fn with_metadata<K, V>(mut self, key: K, value: V) -> Self
    where
        K: Into<String>,
        V: Into<String>,
    {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Number of jobs currently in the batch.
    pub fn job_count(&self) -> usize {
        self.jobs.len()
    }

    /// True when the batch contains no jobs.
    pub fn is_empty(&self) -> bool {
        self.jobs.is_empty()
    }

    /// Checks batch invariants prior to submission.
    ///
    /// # Errors
    ///
    /// Returns a `HammerworkError::Queue` when the batch is empty, holds
    /// more than 10,000 jobs, or mixes jobs from different queues.
    pub fn validate(&self) -> Result<()> {
        if self.jobs.is_empty() {
            return Err(crate::HammerworkError::Queue {
                message: "Batch cannot be empty".to_string(),
            });
        }
        if self.jobs.len() > 10_000 {
            return Err(crate::HammerworkError::Queue {
                message: format!(
                    "Batch size {} exceeds maximum allowed size of 10,000 jobs",
                    self.jobs.len()
                ),
            });
        }
        // Every job must target the same queue so the batch can be
        // enqueued/processed as one unit.
        if let Some(first_queue) = self.jobs.first().map(|j| &j.queue_name) {
            for (i, job) in self.jobs.iter().enumerate() {
                if &job.queue_name != first_queue {
                    return Err(crate::HammerworkError::Queue {
                        message: format!(
                            "All jobs in a batch must have the same queue name. \
                            Job at index {} has queue '{}' but expected '{}'",
                            i, job.queue_name, first_queue
                        ),
                    });
                }
            }
        }
        Ok(())
    }

    /// Splits the batch into sub-batches of at most `batch_size` jobs
    /// (default 1000). A batch that already fits is returned unchanged;
    /// each chunk otherwise gets a fresh id and a `"{name}_chunk_{i}"` name
    /// while inheriting the failure mode, timestamp, and metadata.
    pub fn into_chunks(self) -> Vec<JobBatch> {
        // Clamp to at least 1: `slice::chunks` panics when given 0, and a
        // caller can reach that via `with_batch_size(0)`.
        let chunk_size = self.batch_size.unwrap_or(1000).max(1) as usize;
        if self.jobs.len() <= chunk_size {
            return vec![self];
        }
        self.jobs
            .chunks(chunk_size)
            .enumerate()
            .map(|(i, chunk)| JobBatch {
                id: Uuid::new_v4(),
                name: format!("{}_chunk_{}", self.name, i + 1),
                jobs: chunk.to_vec(),
                batch_size: self.batch_size,
                failure_mode: self.failure_mode.clone(),
                created_at: self.created_at,
                metadata: self.metadata.clone(),
            })
            .collect()
    }
}
impl BatchResult {
    /// True when every job completed (status is `Completed`).
    pub fn is_successful(&self) -> bool {
        self.status == BatchStatus::Completed
    }

    /// True when the whole batch failed (status is `Failed`).
    pub fn is_failed(&self) -> bool {
        self.status == BatchStatus::Failed
    }

    /// True when some jobs succeeded and some failed.
    pub fn is_partially_failed(&self) -> bool {
        self.status == BatchStatus::PartiallyFailed
    }

    /// Percentage of jobs that completed, in `[0, 100]`. An empty batch
    /// counts as fully successful (100%).
    pub fn success_rate(&self) -> f64 {
        match self.total_jobs {
            0 => 100.0,
            total => (f64::from(self.completed_jobs) / f64::from(total)) * 100.0,
        }
    }

    /// Percentage of jobs that failed, in `[0, 100]`. An empty batch has a
    /// failure rate of 0%.
    pub fn failure_rate(&self) -> f64 {
        match self.total_jobs {
            0 => 0.0,
            total => (f64::from(self.failed_jobs) / f64::from(total)) * 100.0,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Job;
    use serde_json::json;

    // A fresh batch starts empty with default settings.
    #[test]
    fn test_job_batch_creation() {
        let batch = JobBatch::new("test_batch");
        assert_eq!(batch.name, "test_batch");
        assert!(batch.jobs.is_empty());
        assert_eq!(batch.failure_mode, PartialFailureMode::ContinueOnError);
        assert!(batch.batch_size.is_none());
    }

    // `with_jobs` replaces the job list wholesale.
    #[test]
    fn test_job_batch_with_jobs() {
        let job1 = Job::new("queue1".to_string(), json!({"id": 1}));
        let job2 = Job::new("queue1".to_string(), json!({"id": 2}));
        let batch = JobBatch::new("test_batch").with_jobs(vec![job1, job2]);
        assert_eq!(batch.job_count(), 2);
        assert!(!batch.is_empty());
    }

    // `add_job` appends a single job.
    #[test]
    fn test_job_batch_add_job() {
        let job = Job::new("queue1".to_string(), json!({"id": 1}));
        let batch = JobBatch::new("test_batch").add_job(job);
        assert_eq!(batch.job_count(), 1);
    }

    // Builder methods set size, failure mode, and metadata.
    #[test]
    fn test_job_batch_configuration() {
        let batch = JobBatch::new("configured_batch")
            .with_batch_size(50)
            .with_partial_failure_handling(PartialFailureMode::FailFast)
            .with_metadata("user_id", "123")
            .with_metadata("campaign", "test");
        assert_eq!(batch.batch_size, Some(50));
        assert_eq!(batch.failure_mode, PartialFailureMode::FailFast);
        assert_eq!(batch.metadata.get("user_id"), Some(&"123".to_string()));
        assert_eq!(batch.metadata.get("campaign"), Some(&"test".to_string()));
    }

    // An empty batch fails validation.
    #[test]
    fn test_batch_validation_empty() {
        let batch = JobBatch::new("empty_batch");
        let result = batch.validate();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Batch cannot be empty")
        );
    }

    // More than 10,000 jobs fails validation.
    #[test]
    fn test_batch_validation_too_large() {
        let mut jobs = Vec::new();
        for i in 0..10_001 {
            jobs.push(Job::new("queue1".to_string(), json!({"id": i})));
        }
        let batch = JobBatch::new("large_batch").with_jobs(jobs);
        let result = batch.validate();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("exceeds maximum allowed size")
        );
    }

    // Jobs from different queues in one batch fail validation.
    #[test]
    fn test_batch_validation_mixed_queues() {
        let job1 = Job::new("queue1".to_string(), json!({"id": 1}));
        let job2 = Job::new("queue2".to_string(), json!({"id": 2}));
        let batch = JobBatch::new("mixed_batch").with_jobs(vec![job1, job2]);
        let result = batch.validate();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("same queue name"));
    }

    // A non-empty single-queue batch passes validation.
    #[test]
    fn test_batch_validation_success() {
        let job1 = Job::new("queue1".to_string(), json!({"id": 1}));
        let job2 = Job::new("queue1".to_string(), json!({"id": 2}));
        let batch = JobBatch::new("valid_batch").with_jobs(vec![job1, job2]);
        assert!(batch.validate().is_ok());
    }

    // 250 jobs with chunk size 100 -> chunks of 100, 100, 50.
    #[test]
    fn test_batch_into_chunks() {
        let mut jobs = Vec::new();
        for i in 0..250 {
            jobs.push(Job::new("queue1".to_string(), json!({"id": i})));
        }
        let batch = JobBatch::new("large_batch")
            .with_jobs(jobs)
            .with_batch_size(100);
        let chunks = batch.into_chunks();
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].job_count(), 100);
        assert_eq!(chunks[1].job_count(), 100);
        assert_eq!(chunks[2].job_count(), 50);
    }

    // A batch already within the chunk size is returned as-is.
    #[test]
    fn test_batch_into_chunks_small_batch() {
        let job = Job::new("queue1".to_string(), json!({"id": 1}));
        let batch = JobBatch::new("small_batch")
            .with_jobs(vec![job])
            .with_batch_size(100);
        let chunks = batch.into_chunks();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].job_count(), 1);
    }

    // Success/failure rates and status predicates on a partial failure.
    #[test]
    fn test_batch_result_success_rate() {
        let result = BatchResult {
            batch_id: Uuid::new_v4(),
            total_jobs: 100,
            completed_jobs: 80,
            failed_jobs: 20,
            pending_jobs: 0,
            status: BatchStatus::PartiallyFailed,
            created_at: Utc::now(),
            completed_at: Some(Utc::now()),
            error_summary: None,
            job_errors: HashMap::new(),
        };
        assert_eq!(result.success_rate(), 80.0);
        assert_eq!(result.failure_rate(), 20.0);
        assert!(result.is_partially_failed());
        assert!(!result.is_successful());
        assert!(!result.is_failed());
    }

    // Zero total jobs: 100% success, 0% failure (no division by zero).
    #[test]
    fn test_batch_result_zero_jobs() {
        let result = BatchResult {
            batch_id: Uuid::new_v4(),
            total_jobs: 0,
            completed_jobs: 0,
            failed_jobs: 0,
            pending_jobs: 0,
            status: BatchStatus::Completed,
            created_at: Utc::now(),
            completed_at: Some(Utc::now()),
            error_summary: None,
            job_errors: HashMap::new(),
        };
        assert_eq!(result.success_rate(), 100.0);
        assert_eq!(result.failure_rate(), 0.0);
    }

    // The default failure mode is ContinueOnError.
    #[test]
    fn test_partial_failure_mode_default() {
        assert_eq!(
            PartialFailureMode::default(),
            PartialFailureMode::ContinueOnError
        );
    }

    // All BatchStatus variants are pairwise distinct.
    #[test]
    fn test_batch_status_variants() {
        let statuses = [
            BatchStatus::Pending,
            BatchStatus::Processing,
            BatchStatus::Completed,
            BatchStatus::PartiallyFailed,
            BatchStatus::Failed,
        ];
        for (i, status1) in statuses.iter().enumerate() {
            for (j, status2) in statuses.iter().enumerate() {
                if i != j {
                    assert_ne!(status1, status2);
                }
            }
        }
    }
}