use crate::{AudioBuffer, VoirsError, VoirsPipeline, VoirsResult};
use async_trait::async_trait;
use futures::stream::{self, StreamExt};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{RwLock, Semaphore};
use tracing::{debug, info, warn};
mod optimization;
mod processor;
mod scheduler;
mod statistics;
pub use optimization::{
BatchOptimizer, NormalizationStrategy, OptimizationConfig, OptimizationStats,
};
pub use processor::BatchProcessor;
pub use scheduler::{BatchScheduler, SchedulingStrategy};
pub use statistics::{BatchStatistics, ProcessingMetrics};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchConfig {
pub max_concurrency: usize,
pub max_batch_size: usize,
pub memory_limit_mb: usize,
pub synthesis_timeout: Duration,
pub adaptive_resources: bool,
pub track_progress: bool,
pub retry_failed: bool,
pub max_retries: usize,
pub scheduling_strategy: SchedulingStrategy,
}
impl Default for BatchConfig {
fn default() -> Self {
Self {
max_concurrency: num_cpus::get(),
max_batch_size: 100,
memory_limit_mb: 2048,
synthesis_timeout: Duration::from_secs(30),
adaptive_resources: true,
track_progress: true,
retry_failed: true,
max_retries: 3,
scheduling_strategy: SchedulingStrategy::LoadBalanced,
}
}
}
/// A single synthesis job submitted as part of a batch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchRequest {
    /// Identifier for the request; `BatchRequest::new` fills it with a random UUID v4.
    pub id: String,
    /// Input text for the request.
    pub text: String,
    /// Optional voice identifier.
    /// NOTE(review): presumably `None` means "pipeline default" — confirm.
    pub voice: Option<String>,
    /// Optional speed override.
    /// NOTE(review): scale (e.g. whether 1.0 is normal speed) is not visible here.
    pub speed: Option<f32>,
    /// Optional pitch override.
    /// NOTE(review): scale is not visible here.
    pub pitch: Option<f32>,
    /// Priority value; `BatchRequest::new` starts it at 0.
    /// How schedulers order by it is defined elsewhere.
    pub priority: i32,
    /// Arbitrary JSON attached by the caller (see `with_metadata`).
    pub metadata: Option<serde_json::Value>,
}
impl BatchRequest {
    /// Creates a request for `text` with a freshly generated UUID v4 `id`.
    ///
    /// `voice` is copied into an owned `String` when present. `speed`, `pitch`
    /// and `metadata` start unset, and `priority` starts at `0`.
    #[must_use]
    pub fn new(text: impl Into<String>, voice: Option<&str>) -> Self {
        Self {
            id: uuid::Uuid::new_v4().to_string(),
            text: text.into(),
            voice: voice.map(str::to_owned),
            speed: None,
            pitch: None,
            priority: 0,
            metadata: None,
        }
    }

    /// Sets `priority` and returns the updated request.
    ///
    /// `#[must_use]`: the request is consumed, so ignoring the return value
    /// silently drops it.
    #[must_use]
    pub fn with_priority(mut self, priority: i32) -> Self {
        self.priority = priority;
        self
    }

    /// Sets the speed override and returns the updated request.
    #[must_use]
    pub fn with_speed(mut self, speed: f32) -> Self {
        self.speed = Some(speed);
        self
    }

    /// Sets the pitch override and returns the updated request.
    #[must_use]
    pub fn with_pitch(mut self, pitch: f32) -> Self {
        self.pitch = Some(pitch);
        self
    }

    /// Attaches arbitrary JSON `metadata`, replacing any previous value, and
    /// returns the updated request.
    #[must_use]
    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
        self.metadata = Some(metadata);
        self
    }
}
/// Outcome of processing one [`BatchRequest`].
#[derive(Debug)]
pub struct BatchResult {
    /// Identifier of the request this result belongs to.
    /// NOTE(review): presumably copied from `BatchRequest::id` — confirm in the processor.
    pub request_id: String,
    /// Synthesized audio on success, or the error that ended processing.
    pub result: VoirsResult<AudioBuffer>,
    /// Time spent processing the request.
    /// NOTE(review): confirm whether this includes time spent on retries.
    pub processing_time: Duration,
    /// Number of retries recorded for the request.
    /// NOTE(review): confirm whether the first attempt is counted.
    pub retry_count: usize,
    /// Identifier of the worker that handled the request, if one was recorded.
    pub worker_id: Option<usize>,
}
impl BatchResult {
    /// Returns `true` when the request produced audio, i.e. there is no error.
    pub fn is_success(&self) -> bool {
        self.error().is_none()
    }

    /// Borrows the synthesized audio, or returns `None` if the request failed.
    pub fn audio(&self) -> Option<&AudioBuffer> {
        let borrowed = self.result.as_ref();
        borrowed.ok()
    }

    /// Borrows the failure reason, or returns `None` if the request succeeded.
    pub fn error(&self) -> Option<&VoirsError> {
        let borrowed = self.result.as_ref();
        borrowed.err()
    }
}
/// Shared, thread-safe callback that receives two `usize` counters.
///
/// NOTE(review): the arguments are presumably `(completed, total)` — confirm
/// against the code that invokes it.
pub type ProgressCallback = Arc<dyn Fn(usize, usize) + Send + Sync>;
/// Shared state passed (by reference) to [`BatchStrategy::process_batch`].
///
/// Cloning shares the `Arc`-wrapped pipeline, semaphore, statistics and
/// callback between clones; only `config` is copied.
#[derive(Clone)]
pub struct BatchContext {
    /// Batch configuration used by this context.
    pub config: BatchConfig,
    /// Pipeline shared by every clone of this context.
    pub pipeline: Arc<VoirsPipeline>,
    /// Concurrency limiter; `BatchContext::new` sizes it from `config.max_concurrency`.
    pub semaphore: Arc<Semaphore>,
    /// Aggregate statistics behind a tokio (async) `RwLock`.
    pub statistics: Arc<RwLock<BatchStatistics>>,
    /// Optional progress reporter; `None` unless set via `with_progress_callback`.
    pub progress_callback: Option<ProgressCallback>,
}
impl BatchContext {
    /// Creates a context around `pipeline`.
    ///
    /// The context gets a fresh semaphore, fresh statistics and no progress callback.
    ///
    /// `config.max_concurrency` is clamped to at least 1 before the semaphore is
    /// built. A tokio semaphore created with zero permits never grants an
    /// acquire unless permits are added later, which would stall the whole batch.
    /// The clamped value is written back into the stored config so the config
    /// and the semaphore agree.
    pub fn new(pipeline: Arc<VoirsPipeline>, mut config: BatchConfig) -> Self {
        config.max_concurrency = config.max_concurrency.max(1);
        let semaphore = Arc::new(Semaphore::new(config.max_concurrency));
        let statistics = Arc::new(RwLock::new(BatchStatistics::new()));
        Self {
            config,
            pipeline,
            semaphore,
            statistics,
            progress_callback: None,
        }
    }

    /// Installs `callback` as the progress reporter and returns the updated
    /// context. Any previously installed callback is replaced.
    #[must_use]
    pub fn with_progress_callback(mut self, callback: ProgressCallback) -> Self {
        self.progress_callback = Some(callback);
        self
    }
}
/// A pluggable way of processing a batch of [`BatchRequest`]s.
///
/// Implementors must be `Send + Sync`; the async method is desugared by
/// `async_trait`.
#[async_trait]
pub trait BatchStrategy: Send + Sync {
    /// Processes `requests` using the shared resources in `context`.
    ///
    /// NOTE(review): presumably returns one `BatchResult` per request, with
    /// per-request failures carried inside each result and the outer `Err`
    /// reserved for batch-level failures — confirm against implementations.
    async fn process_batch(
        &self,
        context: &BatchContext,
        requests: Vec<BatchRequest>,
    ) -> VoirsResult<Vec<BatchResult>>;
    /// Estimates how long processing `requests` would take.
    fn estimate_time(&self, requests: &[BatchRequest]) -> Duration;
    /// Estimates the memory needed to process `requests`.
    ///
    /// NOTE(review): the unit is not visible here (bytes vs. MB — compare
    /// `BatchConfig::memory_limit_mb`); confirm with implementations.
    fn estimate_memory(&self, requests: &[BatchRequest]) -> usize;
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_config_default() {
        let config = BatchConfig::default();
        assert!(config.max_concurrency > 0);
        assert!(config.max_batch_size > 0);
        assert!(config.memory_limit_mb > 0);
    }

    /// A serialized default config must decode back to the same values.
    #[test]
    fn test_batch_config_serde_round_trip() {
        let config = BatchConfig::default();
        let json = serde_json::to_string(&config).expect("BatchConfig should serialize");
        let decoded: BatchConfig =
            serde_json::from_str(&json).expect("serialized BatchConfig should deserialize");
        assert_eq!(decoded.max_concurrency, config.max_concurrency);
        assert_eq!(decoded.max_batch_size, config.max_batch_size);
        assert_eq!(decoded.memory_limit_mb, config.memory_limit_mb);
        assert_eq!(decoded.synthesis_timeout, config.synthesis_timeout);
        assert_eq!(decoded.retry_failed, config.retry_failed);
        assert_eq!(decoded.max_retries, config.max_retries);
    }

    #[test]
    fn test_batch_request_creation() {
        let request = BatchRequest::new("Hello, world!", Some("voice-1"));
        assert_eq!(request.text, "Hello, world!");
        assert_eq!(request.voice.as_deref(), Some("voice-1"));
        assert_eq!(request.priority, 0);
        // Optional overrides start unset.
        assert!(request.speed.is_none());
        assert!(request.pitch.is_none());
        assert!(request.metadata.is_none());
    }

    /// Each request gets its own generated id, even for identical input.
    #[test]
    fn test_batch_request_ids_are_unique() {
        let first = BatchRequest::new("same text", None);
        let second = BatchRequest::new("same text", None);
        assert!(!first.id.is_empty());
        assert_ne!(first.id, second.id);
    }

    #[test]
    fn test_batch_request_builder() {
        let request = BatchRequest::new("Test", None)
            .with_priority(10)
            .with_speed(1.2)
            .with_pitch(0.8);
        assert_eq!(request.priority, 10);
        assert_eq!(request.speed, Some(1.2));
        assert_eq!(request.pitch, Some(0.8));
        assert!(request.voice.is_none());
    }

    #[test]
    fn test_batch_request_metadata() {
        let metadata = serde_json::json!({ "chapter": 3, "tag": "intro" });
        let request = BatchRequest::new("Test", None).with_metadata(metadata.clone());
        assert_eq!(request.metadata, Some(metadata));
    }

    #[test]
    fn test_batch_result() {
        let result = BatchResult {
            request_id: "test-id".to_string(),
            result: Err(VoirsError::audio_error("test error")),
            processing_time: Duration::from_millis(100),
            retry_count: 2,
            worker_id: Some(1),
        };
        assert!(!result.is_success());
        assert!(result.audio().is_none());
        assert!(result.error().is_some());
    }
}