use super::{BatchConfig, BatchRequest, BatchResult};
use crate::pipeline::VoirsPipeline;
use crate::{Result as VoirsResult, VoirsError};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
/// Tuning knobs for the batch optimizer: one toggle per optimization pass
/// (deduplication, cache lookups, priority scheduling, adaptive batching)
/// plus the numeric parameters those passes consume.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationConfig {
    /// Collapse requests whose normalized-text hashes are identical.
    pub enable_deduplication: bool,
    /// Serve previously synthesized audio from the in-memory result cache.
    pub enable_cache_awareness: bool,
    /// Sort requests by descending priority before batching.
    pub enable_priority_scheduling: bool,
    /// Split work into batches bounded by `min_batch_size`/`max_batch_size`.
    pub enable_adaptive_batching: bool,
    /// How request text is canonicalized before hashing for dedup/caching.
    pub normalization_strategy: NormalizationStrategy,
    /// Maximum number of entries kept in the result cache before eviction.
    pub max_cache_entries: usize,
    /// Cache entry time-to-live in seconds (enforced by `prune_cache`).
    pub cache_ttl_seconds: u64,
    /// Similarity threshold in [0, 1].
    /// NOTE(review): not read anywhere in this module — presumably reserved
    /// for fuzzy matching; confirm before relying on it.
    pub similarity_threshold: f32,
    /// Priority weighting factor.
    /// NOTE(review): not read anywhere in this module — confirm intended use.
    pub priority_weight: f32,
    /// Preferred batch size for adaptive splitting.
    pub target_batch_size: usize,
    /// Trailing batches smaller than this are merged into the previous batch.
    pub min_batch_size: usize,
    /// Hard upper bound on batch size.
    pub max_batch_size: usize,
}
impl Default for OptimizationConfig {
fn default() -> Self {
Self {
enable_deduplication: true,
enable_cache_awareness: true,
enable_priority_scheduling: true,
enable_adaptive_batching: true,
normalization_strategy: NormalizationStrategy::Aggressive,
max_cache_entries: 10000,
cache_ttl_seconds: 3600, similarity_threshold: 0.95,
priority_weight: 2.0,
target_batch_size: 32,
min_batch_size: 4,
max_batch_size: 128,
}
}
}
/// Strategy used to canonicalize request text before hashing, so that
/// near-identical requests deduplicate and hit the same cache entry.
/// (See `BatchOptimizer::normalize_text` for the exact transformations.)
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum NormalizationStrategy {
    /// Use the text verbatim.
    None,
    /// Trim surrounding whitespace and lowercase.
    Basic,
    /// Keep only alphanumerics/whitespace, collapse whitespace runs to single
    /// spaces, then lowercase.
    Aggressive,
    /// Lowercase and keep alphabetic characters only (whitespace dropped).
    Phonetic,
}
/// A synthesized result stored in the optimizer's cache, keyed by the hash of
/// the normalized request text.
#[derive(Debug, Clone)]
struct CachedResult {
    /// Raw audio samples; rebuilt as a mono buffer when the entry is served.
    audio: Vec<f32>,
    /// Sample rate of `audio`, in Hz.
    sample_rate: u32,
    /// Insertion time; compared against the configured TTL in `prune_cache`.
    cached_at: Instant,
    /// Number of times this entry has been served.
    /// NOTE(review): initialized to 0 but never incremented in this module —
    /// the loop in `process_cache_lookups` that should bump it is empty.
    hit_count: usize,
}
/// A `BatchRequest` annotated with the bookkeeping the optimizer needs:
/// normalized text, its hash, original position, and deduplication links.
#[derive(Debug, Clone)]
struct OptimizedRequest {
    /// The underlying request.
    request: BatchRequest,
    /// Text after applying the configured `NormalizationStrategy`.
    normalized_text: String,
    /// Hash of `normalized_text`; the dedup and cache key.
    text_hash: u64,
    /// Scheduling priority copied from the request (higher runs earlier).
    priority: i32,
    /// Position in the caller's original request vector; used to restore the
    /// submission order in the final result list.
    original_index: usize,
    /// True when an earlier request in the same run shares `text_hash`.
    is_duplicate: bool,
    /// Index of the canonical (first-seen) request this one duplicates.
    /// NOTE(review): written during deduplication but never read anywhere.
    canonical_index: Option<usize>,
}
/// Optimizing front-end for batch synthesis: deduplicates, caches,
/// priority-schedules, and adaptively batches requests before handing them
/// to the underlying pipeline.
pub struct BatchOptimizer {
    /// The synthesis pipeline that performs the actual TTS work.
    pipeline: Arc<VoirsPipeline>,
    /// Optimization toggles and tuning parameters.
    config: OptimizationConfig,
    /// Shared text-hash → synthesized-audio cache.
    cache: Arc<RwLock<HashMap<u64, CachedResult>>>,
    /// Running counters exposed via `get_statistics`.
    stats: Arc<RwLock<OptimizationStats>>,
}
/// Cumulative counters describing how effective the optimizations have been
/// across all calls to `process_optimized`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OptimizationStats {
    /// Total requests submitted across all optimized batches.
    pub total_requests: usize,
    /// Requests skipped because an identical request was already present.
    pub deduplicated_requests: usize,
    /// Requests served directly from the audio cache.
    pub cache_hits: usize,
    /// Requests that required fresh synthesis.
    pub cache_misses: usize,
    /// Rough estimate of synthesis time avoided, in milliseconds.
    pub time_saved_ms: u64,
    /// Average batch size.
    /// NOTE(review): never updated anywhere in this module.
    pub average_batch_size: f32,
    /// Priority inversions prevented by scheduling.
    /// NOTE(review): never updated anywhere in this module.
    pub priority_inversions_prevented: usize,
}
impl BatchOptimizer {
/// Creates an optimizer wrapping `pipeline`, starting with an empty result
/// cache and zeroed statistics.
pub fn new(pipeline: Arc<VoirsPipeline>, config: OptimizationConfig) -> Self {
    let cache = Arc::new(RwLock::new(HashMap::new()));
    let stats = Arc::new(RwLock::new(OptimizationStats::default()));
    Self {
        pipeline,
        config,
        cache,
        stats,
    }
}
pub async fn process_optimized(
&self,
requests: Vec<BatchRequest>,
) -> VoirsResult<Vec<BatchResult>> {
let start_time = Instant::now();
let total_requests = requests.len();
info!("Processing optimized batch: {} requests", total_requests);
let optimized = self.analyze_requests(requests).await?;
let (cache_hits, cache_misses) = self.process_cache_lookups(&optimized).await?;
let cache_hits_count = cache_hits.len();
let cache_misses_count = cache_misses.len();
let deduplicated = self.deduplicate_requests(optimized).await?;
let scheduled = if self.config.enable_priority_scheduling {
self.schedule_by_priority(deduplicated).await?
} else {
deduplicated
};
let batches = if self.config.enable_adaptive_batching {
self.adaptive_batch_split(scheduled).await?
} else {
vec![scheduled]
};
let mut all_results = Vec::with_capacity(total_requests);
for batch in batches {
let batch_results = self.execute_batch(batch).await?;
all_results.extend(batch_results);
}
let final_results = self
.restore_order_and_duplicates(all_results, cache_hits)
.await?;
let processing_time = start_time.elapsed();
self.update_statistics(
total_requests,
cache_hits_count,
cache_misses_count,
processing_time,
)
.await;
info!(
"Optimized batch complete: {} requests in {:?} ({} cache hits, {} deduplicated)",
total_requests,
processing_time,
cache_hits_count,
total_requests - cache_misses_count
);
Ok(final_results)
}
/// Returns a snapshot of the accumulated optimization statistics.
pub async fn get_statistics(&self) -> OptimizationStats {
    let guard = self.stats.read().await;
    guard.clone()
}
/// Drops every cached synthesis result.
pub async fn clear_cache(&self) {
    {
        let mut cache = self.cache.write().await;
        cache.clear();
    }
    info!("Optimization cache cleared");
}
/// Evicts cache entries older than the configured TTL.
pub async fn prune_cache(&self) {
    let max_age = Duration::from_secs(self.config.cache_ttl_seconds);
    let now = Instant::now();
    let mut cache = self.cache.write().await;
    let original_len = cache.len();
    // Keep only entries that are still within the TTL window.
    cache.retain(|_, entry| now.duration_since(entry.cached_at) < max_age);
    let evicted = original_len - cache.len();
    if evicted > 0 {
        info!("Pruned {} expired cache entries", evicted);
    }
}
/// Wraps each incoming request with the metadata the optimizer needs:
/// normalized text, its hash, the request's priority, and its original
/// position (used later to restore submission order).
async fn analyze_requests(
    &self,
    requests: Vec<BatchRequest>,
) -> VoirsResult<Vec<OptimizedRequest>> {
    let analyzed = requests
        .into_iter()
        .enumerate()
        .map(|(original_index, request)| {
            let normalized_text = self.normalize_text(&request.text);
            let text_hash = self.compute_text_hash(&normalized_text);
            let priority = request.priority;
            OptimizedRequest {
                request,
                normalized_text,
                text_hash,
                priority,
                original_index,
                is_duplicate: false,
                canonical_index: None,
            }
        })
        .collect();
    Ok(analyzed)
}
/// Partitions `requests` into cache hits (returned with their cached audio
/// and sample rate) and cache misses (returned as indices into `requests`).
/// When cache awareness is disabled, everything is reported as a miss.
///
/// Hit counters on matching cache entries are incremented; the original
/// implementation took a write lock and ran an empty loop here, so
/// `hit_count` was never updated.
async fn process_cache_lookups(
    &self,
    requests: &[OptimizedRequest],
) -> VoirsResult<(Vec<(usize, Vec<f32>, u32)>, Vec<usize>)> {
    if !self.config.enable_cache_awareness {
        return Ok((Vec::new(), (0..requests.len()).collect()));
    }
    let mut cache_hits = Vec::new();
    let mut cache_misses = Vec::new();
    {
        // Read lock only while classifying; dropped before the write below.
        let cache = self.cache.read().await;
        for (idx, req) in requests.iter().enumerate() {
            if let Some(cached) = cache.get(&req.text_hash) {
                debug!("Cache hit for: {}", req.request.text);
                cache_hits.push((idx, cached.audio.clone(), cached.sample_rate));
            } else {
                cache_misses.push(idx);
            }
        }
    }
    if !cache_hits.is_empty() {
        let mut cache = self.cache.write().await;
        for (idx, _, _) in &cache_hits {
            if let Some(entry) = cache.get_mut(&requests[*idx].text_hash) {
                entry.hit_count += 1;
            }
        }
    }
    Ok((cache_hits, cache_misses))
}
/// Marks requests whose normalized-text hash was already seen earlier in the
/// list as duplicates, linking each to the position of its canonical (first
/// seen) request. Nothing is removed; order is preserved.
async fn deduplicate_requests(
    &self,
    mut requests: Vec<OptimizedRequest>,
) -> VoirsResult<Vec<OptimizedRequest>> {
    if !self.config.enable_deduplication {
        return Ok(requests);
    }
    // Maps a text hash to the position of the first request that produced it.
    let mut first_seen: HashMap<u64, usize> = HashMap::new();
    let mut duplicate_total = 0usize;
    for (position, req) in requests.iter_mut().enumerate() {
        match first_seen.get(&req.text_hash).copied() {
            Some(canonical) => {
                req.is_duplicate = true;
                req.canonical_index = Some(canonical);
                duplicate_total += 1;
                debug!(
                    "Deduplicated request {} (duplicate of {}): {}",
                    position, canonical, req.request.text
                );
            }
            None => {
                first_seen.insert(req.text_hash, position);
            }
        }
    }
    if duplicate_total > 0 {
        info!("Deduplicated {} requests", duplicate_total);
    }
    Ok(requests)
}
/// Reorders requests so higher-priority requests run first; ties keep their
/// submission order (the sort key falls back to `original_index`).
///
/// Inversions are counted *before* sorting — the original counted them after
/// the descending sort, where `w[0].priority < w[1].priority` can never hold,
/// so the count was provably always zero. The count is also recorded into the
/// previously never-updated `priority_inversions_prevented` statistic.
async fn schedule_by_priority(
    &self,
    mut requests: Vec<OptimizedRequest>,
) -> VoirsResult<Vec<OptimizedRequest>> {
    // An inversion: a lower-priority request immediately precedes a
    // higher-priority one in the incoming order.
    let priority_inversions: usize = requests
        .windows(2)
        .filter(|w| w[0].priority < w[1].priority)
        .count();
    requests.sort_by(|a, b| {
        b.priority
            .cmp(&a.priority)
            .then_with(|| a.original_index.cmp(&b.original_index))
    });
    if priority_inversions > 0 {
        self.stats.write().await.priority_inversions_prevented += priority_inversions;
        debug!("Prevented {} priority inversions", priority_inversions);
    }
    Ok(requests)
}
/// Splits the scheduled requests into batches. A batch is closed when it
/// reaches `max_batch_size`, or when it has reached `target_batch_size` and
/// spans a wide priority range (see `should_split_batch`). A trailing batch
/// smaller than `min_batch_size` is folded into the previous batch.
///
/// Fixes a corrupted token on the split condition: `&current_batch` had been
/// mangled into `¤t_batch` (an HTML-entity artifact), which does not compile.
async fn adaptive_batch_split(
    &self,
    requests: Vec<OptimizedRequest>,
) -> VoirsResult<Vec<Vec<OptimizedRequest>>> {
    let target_size = self.config.target_batch_size;
    let min_size = self.config.min_batch_size;
    let max_size = self.config.max_batch_size;
    let mut batches = Vec::new();
    let mut current_batch = Vec::new();
    for request in requests {
        current_batch.push(request);
        if current_batch.len() >= max_size
            || (current_batch.len() >= target_size && should_split_batch(&current_batch))
        {
            batches.push(std::mem::take(&mut current_batch));
        }
    }
    if !current_batch.is_empty() {
        if current_batch.len() < min_size && !batches.is_empty() {
            // Too small to stand alone: merge into the previous batch. This
            // may push that batch past `max_size`; accepted here in exchange
            // for never emitting an undersized batch.
            batches
                .last_mut()
                .expect("checked non-empty above")
                .extend(current_batch);
        } else {
            batches.push(current_batch);
        }
    }
    debug!("Split into {} adaptive batches", batches.len());
    Ok(batches)
}
/// Synthesizes every non-duplicate request in `batch`, caching successful
/// results, and returns `(original_index, BatchResult)` pairs.
/// Duplicate-marked requests are skipped and produce no result here.
///
/// Fixes: cache eviction now removes the *oldest* entry (the original removed
/// `keys().next()`, an arbitrary entry under HashMap's iteration order), and
/// `processing_time` records the measured synthesis duration instead of a
/// hard-coded zero.
async fn execute_batch(
    &self,
    batch: Vec<OptimizedRequest>,
) -> VoirsResult<Vec<(usize, BatchResult)>> {
    let mut results = Vec::new();
    let to_synthesize: Vec<_> = batch.iter().filter(|r| !r.is_duplicate).cloned().collect();
    if to_synthesize.is_empty() {
        return Ok(results);
    }
    for opt_req in &to_synthesize {
        let req = &opt_req.request;
        let synth_start = Instant::now();
        match self.pipeline.synthesize(&req.text).await {
            Ok(audio_buffer) => {
                let elapsed = synth_start.elapsed();
                if self.config.enable_cache_awareness {
                    // Copy the samples only when caching is actually enabled.
                    let samples = audio_buffer.samples().to_vec();
                    let sample_rate = audio_buffer.sample_rate();
                    let mut cache = self.cache.write().await;
                    cache.insert(
                        opt_req.text_hash,
                        CachedResult {
                            audio: samples,
                            sample_rate,
                            cached_at: Instant::now(),
                            hit_count: 0,
                        },
                    );
                    if cache.len() > self.config.max_cache_entries {
                        // Evict the entry with the earliest insertion time.
                        if let Some(oldest_key) = cache
                            .iter()
                            .min_by_key(|(_, entry)| entry.cached_at)
                            .map(|(key, _)| *key)
                        {
                            cache.remove(&oldest_key);
                        }
                    }
                }
                results.push((
                    opt_req.original_index,
                    BatchResult {
                        request_id: req.id.clone(),
                        result: Ok(audio_buffer),
                        processing_time: elapsed,
                        retry_count: 0,
                        worker_id: None,
                    },
                ));
            }
            Err(e) => {
                results.push((
                    opt_req.original_index,
                    BatchResult {
                        request_id: req.id.clone(),
                        result: Err(e),
                        processing_time: synth_start.elapsed(),
                        retry_count: 0,
                        worker_id: None,
                    },
                ));
            }
        }
    }
    Ok(results)
}
/// Converts cache hits into `BatchResult`s, merges them with the synthesized
/// results, and sorts everything back into the caller's submission order.
///
/// NOTE(review): despite the name, duplicate-marked requests are not expanded
/// here — nothing re-inserts a result for them, and `canonical_index` is
/// never consumed. Cache-hit results also carry an empty `request_id`, since
/// the original request is not available at this point. Both would require a
/// signature change to fix.
async fn restore_order_and_duplicates(
    &self,
    mut results: Vec<(usize, BatchResult)>,
    cache_hits: Vec<(usize, Vec<f32>, u32)>,
) -> VoirsResult<Vec<BatchResult>> {
    use crate::audio::AudioBuffer;
    for (idx, samples, sample_rate) in cache_hits {
        let buffer = AudioBuffer::mono(samples, sample_rate);
        let cached_result = BatchResult {
            request_id: String::new(),
            result: Ok(buffer),
            processing_time: Duration::from_millis(0),
            retry_count: 0,
            worker_id: None,
        };
        results.push((idx, cached_result));
    }
    // Stable sort on the original index restores submission order.
    results.sort_by_key(|(idx, _)| *idx);
    let ordered = results.into_iter().map(|(_, result)| result).collect();
    Ok(ordered)
}
/// Folds one optimized-batch run into the cumulative statistics.
///
/// `total` is the number of requests submitted; `cache_hits`/`cache_misses`
/// classify them by cache outcome. Requests covered by neither count are
/// treated as skipped (deduplicated). Uses saturating subtraction so an
/// accounting mismatch (hits + misses > total) can never underflow `usize`
/// and panic, as the original `total - cache_misses - cache_hits` could.
async fn update_statistics(
    &self,
    total: usize,
    cache_hits: usize,
    cache_misses: usize,
    _duration: Duration, // currently unused; kept for interface stability
) {
    let mut stats = self.stats.write().await;
    stats.total_requests += total;
    stats.cache_hits += cache_hits;
    stats.cache_misses += cache_misses;
    let skipped = total.saturating_sub(cache_hits).saturating_sub(cache_misses);
    stats.deduplicated_requests += skipped;
    // Rough estimate: each request avoided saves ~100 ms of synthesis.
    let requests_saved = cache_hits + skipped;
    stats.time_saved_ms += (requests_saved as u64) * 100;
}
/// Canonicalizes `text` according to the configured normalization strategy,
/// so equivalent requests hash to the same dedup/cache key.
fn normalize_text(&self, text: &str) -> String {
    match self.config.normalization_strategy {
        NormalizationStrategy::None => text.to_string(),
        NormalizationStrategy::Basic => text.trim().to_lowercase(),
        NormalizationStrategy::Aggressive => {
            // Keep alphanumerics and whitespace, collapse whitespace runs to
            // single spaces, then lowercase.
            let filtered: String = text
                .chars()
                .filter(|c| c.is_alphanumeric() || c.is_whitespace())
                .collect();
            let words: Vec<&str> = filtered.split_whitespace().collect();
            words.join(" ").to_lowercase()
        }
        NormalizationStrategy::Phonetic => text
            // Alphabetic characters only, lowercased; whitespace dropped.
            .to_lowercase()
            .chars()
            .filter(|c| c.is_alphabetic())
            .collect(),
    }
}
/// Hashes normalized text with the standard library's `DefaultHasher`; the
/// result keys both the deduplication table and the audio cache.
fn compute_text_hash(&self, text: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    let mut state = DefaultHasher::new();
    text.hash(&mut state);
    state.finish()
}
}
/// Heuristic for closing a batch early: returns true when any request's
/// priority differs from the most recent (last) request's priority by more
/// than 5 — i.e. the batch spans a wide priority range. Batches of fewer
/// than two requests never split.
///
/// Uses `abs_diff` instead of the original `(a - b).abs()`, which can
/// overflow (panic in debug builds) on extreme `i32` priorities.
fn should_split_batch(batch: &[OptimizedRequest]) -> bool {
    if batch.len() < 2 {
        return false;
    }
    let last_priority = batch[batch.len() - 1].priority;
    batch
        .iter()
        .any(|r| r.priority.abs_diff(last_priority) > 5)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal `OptimizedRequest` with the given priority and index;
    /// the text fields are constant because `should_split_batch` ignores them.
    fn make_request(priority: i32, original_index: usize) -> OptimizedRequest {
        OptimizedRequest {
            request: BatchRequest::new("test", None),
            normalized_text: "test".to_string(),
            text_hash: 0,
            priority,
            original_index,
            is_duplicate: false,
            canonical_index: None,
        }
    }

    #[test]
    fn test_should_split_batch() {
        // Uniform priorities: no reason to split.
        let mut batch = vec![make_request(10, 0), make_request(10, 1)];
        assert!(!should_split_batch(&batch));
        // A priority gap larger than 5 relative to the last request triggers
        // a split.
        batch.push(make_request(20, 2));
        assert!(should_split_batch(&batch));
    }
}