use super::{QualityTarget, SystemMetrics};
use crate::{Result, VoirsError};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::Float;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
/// System and request features fed to the quality-prediction models.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionInput {
    /// CPU utilization — presumably a 0.0–1.0 fraction (TODO confirm against `SystemMetrics`).
    pub cpu_usage: f32,
    /// Memory utilization — presumably a 0.0–1.0 fraction (TODO confirm against `SystemMetrics`).
    pub memory_usage: f32,
    /// Text complexity score (see `TextComplexityAnalyzer::analyze`, which yields [0.0, 1.0]).
    pub text_complexity: f32,
    /// Local hour of day, 0–23; normalized by /24 in `to_features`.
    pub time_of_day: u8,
    /// Recently observed real-time factor.
    pub recent_rtf: f32,
}
impl PredictionInput {
    /// Flatten this input into the 5-element feature vector consumed by
    /// `LinearModel` (hour of day is normalized into [0, 1)).
    pub fn to_features(&self) -> Array1<f32> {
        let normalized_hour = f32::from(self.time_of_day) / 24.0;
        let feature_values = vec![
            self.cpu_usage,
            self.memory_usage,
            self.text_complexity,
            normalized_hour,
            self.recent_rtf,
        ];
        Array1::from_vec(feature_values)
    }

    /// Build an input from live system metrics, the given text-complexity
    /// score, and the current local hour.
    pub fn from_metrics(metrics: &SystemMetrics, text_complexity: f32) -> Self {
        use chrono::Timelike;
        let current_hour = chrono::Local::now().hour() as u8;
        Self {
            cpu_usage: metrics.cpu_usage,
            memory_usage: metrics.memory_usage,
            text_complexity,
            time_of_day: current_hour,
            recent_rtf: metrics.current_rtf,
        }
    }
}
/// Result of a `QualityPredictor::predict` call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityPrediction {
    /// Recommended quality target (a `Custom` score once the models are trained).
    pub quality: QualityTarget,
    /// Confidence in this prediction, in [0.0, 1.0]; 0.0 during cold start.
    pub confidence: f32,
    /// Expected synthesis duration in milliseconds (fixed 100 during cold start).
    pub expected_time_ms: u64,
    /// Estimated probability the synthesis succeeds, in [0.0, 1.0].
    pub success_probability: f32,
}
/// One observed synthesis outcome used to train the predictor models.
#[derive(Debug, Clone)]
pub struct TrainingSample {
    /// System/text state captured when the synthesis was requested.
    pub input: PredictionInput,
    /// Quality target that was actually used for this synthesis.
    pub quality: QualityTarget,
    /// Wall-clock synthesis duration in milliseconds.
    pub synthesis_time_ms: u64,
    /// Whether the synthesis completed successfully.
    pub success: bool,
    /// Real-time factor measured for this synthesis (stored in history but
    /// not currently used as a training target).
    pub measured_rtf: f32,
}
/// Minimal online linear regressor trained by stochastic gradient descent.
#[derive(Debug, Clone)]
struct LinearModel {
    /// Six weights: indices 0..=4 multiply the 5 input features, index 5 is the bias term.
    weights: Array1<f32>,
    /// Base learning rate; decayed by sqrt(samples_seen) in `update`.
    learning_rate: f32,
    /// Lifetime number of `update` calls, used for learning-rate decay.
    samples_seen: u64,
}
impl LinearModel {
    /// New model with zeroed weights/bias and a 0.01 base learning rate.
    fn new() -> Self {
        Self {
            weights: Array1::zeros(6),
            learning_rate: 0.01,
            samples_seen: 0,
        }
    }

    /// Raw affine score: dot(weights[0..5], features) + bias (weights[5]).
    /// Assumes `features` has exactly 5 elements (see `PredictionInput::to_features`).
    fn raw_score(&self, features: &Array1<f32>) -> f32 {
        let mut score = self.weights[5];
        for (i, &feature) in features.iter().enumerate() {
            score += self.weights[i] * feature;
        }
        score
    }

    /// Prediction clamped to the [0, 100] output range callers rely on.
    fn predict(&self, features: &Array1<f32>) -> f32 {
        self.raw_score(features).clamp(0.0, 100.0)
    }

    /// One SGD step toward `target` with a sqrt-decayed learning rate.
    fn update(&mut self, features: &Array1<f32>, target: f32) {
        // BUGFIX: the gradient must use the UNCLAMPED score. With the clamped
        // prediction, any target outside [0, 100] (e.g. synthesis times in ms,
        // which routinely exceed 100) leaves a constant residual error once the
        // clamp saturates, so the weights grow without bound instead of
        // converging.
        let error = target - self.raw_score(features);
        let effective_lr = self.learning_rate / (1.0 + (self.samples_seen as f32).sqrt());
        for (i, &feature) in features.iter().enumerate() {
            self.weights[i] += effective_lr * error * feature;
        }
        // Bias gets the same update with an implicit feature value of 1.0.
        self.weights[5] += effective_lr * error;
        self.samples_seen += 1;
    }
}
/// Online predictor that learns quality, synthesis-time and success models
/// from observed synthesis outcomes.
#[derive(Debug, Clone)]
pub struct QualityPredictor {
    /// Predicts a quality score in [0, 100].
    quality_model: LinearModel,
    /// Predicts expected synthesis time in milliseconds.
    time_model: LinearModel,
    /// Predicts the probability of a successful synthesis.
    success_model: LinearModel,
    /// Sliding window of recent training samples (newest at the back).
    recent_samples: VecDeque<TrainingSample>,
    /// Maximum number of samples retained in `recent_samples`.
    max_history: usize,
    /// Minimum samples required before `predict` trusts the learned models.
    min_samples: usize,
}
impl QualityPredictor {
    /// Create a predictor with default limits: keep up to 1000 samples and
    /// require 10 before trusting the learned models.
    pub fn new() -> Self {
        Self {
            quality_model: LinearModel::new(),
            time_model: LinearModel::new(),
            success_model: LinearModel::new(),
            recent_samples: VecDeque::new(),
            max_history: 1000,
            min_samples: 10,
        }
    }

    /// Builder-style override of the history limits.
    ///
    /// `max_history` is clamped to at least 1 — a zero capacity would make the
    /// `len / max_history` confidence ratios divide by zero (NaN for f32) and
    /// evict every sample immediately. `min_samples` is capped at
    /// `max_history`, since a threshold above the window capacity could never
    /// be reached and would leave the predictor stuck in its cold-start
    /// fallback forever. Non-degenerate inputs behave exactly as before.
    pub fn with_history_size(mut self, max_history: usize, min_samples: usize) -> Self {
        self.max_history = max_history.max(1);
        self.min_samples = min_samples.min(self.max_history);
        self
    }

    /// Predict quality target, expected synthesis time and success
    /// probability for `input`.
    ///
    /// Until `min_samples` training samples have been observed, returns a
    /// conservative fixed fallback (medium quality, zero confidence).
    pub async fn predict(&self, input: &PredictionInput) -> Result<QualityPrediction> {
        let features = input.to_features();
        if self.recent_samples.len() < self.min_samples {
            // Cold start: the models have not seen enough data to be useful.
            return Ok(QualityPrediction {
                quality: QualityTarget::Medium,
                confidence: 0.0,
                expected_time_ms: 100,
                success_probability: 0.95,
            });
        }
        // Model output is clamped to [0, 100], so the u8 cast cannot wrap.
        let quality_score = self.quality_model.predict(&features);
        let quality = QualityTarget::Custom(quality_score as u8);
        // Floor at 10 ms: never promise an implausibly fast synthesis.
        let expected_time = self.time_model.predict(&features).max(10.0);
        let success_prob = self.success_model.predict(&features).clamp(0.0, 1.0);
        let confidence = self.calculate_confidence(&features);
        Ok(QualityPrediction {
            quality,
            confidence,
            expected_time_ms: expected_time as u64,
            success_probability: success_prob,
        })
    }

    /// Update all three models from one observed outcome and record the
    /// sample in the sliding history window.
    pub async fn train(&mut self, sample: TrainingSample) -> Result<()> {
        let features = sample.input.to_features();
        self.quality_model
            .update(&features, sample.quality.score() as f32);
        self.time_model
            .update(&features, sample.synthesis_time_ms as f32);
        let success_target = if sample.success { 1.0 } else { 0.0 };
        self.success_model.update(&features, success_target);
        self.recent_samples.push_back(sample);
        // Evict the oldest sample once the window is full.
        if self.recent_samples.len() > self.max_history {
            self.recent_samples.pop_front();
        }
        Ok(())
    }

    /// Train on a batch of samples in order; stops at the first error.
    pub async fn batch_train(&mut self, samples: Vec<TrainingSample>) -> Result<()> {
        for sample in samples {
            self.train(sample).await?;
        }
        Ok(())
    }

    /// Summarize the current training state.
    pub fn get_stats(&self) -> PredictorStats {
        let total_samples = self.recent_samples.len();
        let successful = self.recent_samples.iter().filter(|s| s.success).count();
        let avg_time = if total_samples > 0 {
            self.recent_samples
                .iter()
                .map(|s| s.synthesis_time_ms as f64)
                .sum::<f64>()
                / total_samples as f64
        } else {
            0.0
        };
        PredictorStats {
            total_samples,
            successful_samples: successful,
            avg_synthesis_time_ms: avg_time,
            model_samples_seen: self.quality_model.samples_seen,
            confidence: if total_samples >= self.min_samples {
                // Fill ratio of the history window; `max_history` >= 1 after
                // construction, so the division is safe.
                (total_samples as f32 / self.max_history as f32).min(1.0)
            } else {
                0.0
            },
        }
    }

    /// Discard all learned weights and recorded history.
    pub fn reset(&mut self) {
        self.quality_model = LinearModel::new();
        self.time_model = LinearModel::new();
        self.success_model = LinearModel::new();
        self.recent_samples.clear();
    }

    /// Blend the history fill ratio (60%) with the success rate over the last
    /// 20 samples (40%) into a confidence score in [0.0, 1.0]. The feature
    /// vector is currently unused.
    fn calculate_confidence(&self, _features: &Array1<f32>) -> f32 {
        let sample_confidence =
            (self.recent_samples.len() as f32 / self.max_history as f32).min(1.0);
        let recent_window = 20.min(self.recent_samples.len());
        let recent_success_rate = if recent_window > 0 {
            self.recent_samples
                .iter()
                .rev()
                .take(recent_window)
                .filter(|s| s.success)
                .count() as f32
                / recent_window as f32
        } else {
            // No samples yet: optimistically assume success.
            1.0
        };
        (sample_confidence * 0.6 + recent_success_rate * 0.4).clamp(0.0, 1.0)
    }
}
impl Default for QualityPredictor {
    /// Equivalent to [`QualityPredictor::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Snapshot of predictor training state, as returned by `QualityPredictor::get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictorStats {
    /// Number of samples currently retained in the history window.
    pub total_samples: usize,
    /// How many of the retained samples reported a successful synthesis.
    pub successful_samples: usize,
    /// Mean synthesis time in ms over retained samples (0.0 when empty).
    pub avg_synthesis_time_ms: f64,
    /// Lifetime number of updates applied to the quality model.
    pub model_samples_seen: u64,
    /// History fill-ratio confidence in [0.0, 1.0]; 0.0 until `min_samples` reached.
    pub confidence: f32,
}
/// Heuristic estimator of how hard a piece of text is to synthesize,
/// producing a score in [0.0, 1.0].
pub struct TextComplexityAnalyzer;

impl TextComplexityAnalyzer {
    /// Weighted blend of length (30%), vocabulary diversity (40%) and
    /// sentence structure (30%); empty text scores 0.0.
    pub fn analyze(text: &str) -> f32 {
        if text.is_empty() {
            return 0.0;
        }
        let weighted = Self::length_complexity(text) * 0.3
            + Self::vocabulary_complexity(text) * 0.4
            + Self::structure_complexity(text) * 0.3;
        weighted.clamp(0.0, 1.0)
    }

    /// Saturating length score: 100+ characters counts as maximally long.
    fn length_complexity(text: &str) -> f32 {
        (text.chars().count() as f32 / 100.0).min(1.0)
    }

    /// Lexical diversity: distinct whitespace-separated words over total words.
    fn vocabulary_complexity(text: &str) -> f32 {
        let total_words = text.split_whitespace().count();
        if total_words == 0 {
            return 0.0;
        }
        let distinct_words = text
            .split_whitespace()
            .collect::<std::collections::HashSet<_>>()
            .len();
        distinct_words as f32 / total_words as f32
    }

    /// Sentence-length score: average words per sentence mapped linearly from
    /// 5 words (0.0) to 20 words (1.0).
    fn structure_complexity(text: &str) -> f32 {
        let sentence_count = text.matches(&['.', '!', '?'][..]).count();
        if sentence_count == 0 {
            // No sentence terminator at all: fall back to a neutral score.
            return 0.5;
        }
        let word_count = text.split_whitespace().count();
        let avg_words_per_sentence = word_count as f32 / sentence_count as f32;
        ((avg_words_per_sentence - 5.0) / 15.0).clamp(0.0, 1.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mid-range fixture shared by the predictor tests below.
    fn midpoint_input() -> PredictionInput {
        PredictionInput {
            cpu_usage: 0.5,
            memory_usage: 0.5,
            text_complexity: 0.5,
            time_of_day: 12,
            recent_rtf: 0.5,
        }
    }

    #[test]
    fn test_prediction_input_features() {
        let input = PredictionInput {
            cpu_usage: 0.5,
            memory_usage: 0.6,
            text_complexity: 0.7,
            time_of_day: 12,
            recent_rtf: 0.4,
        };
        let features = input.to_features();
        assert_eq!(features.len(), 5);
        assert_eq!(features[0], 0.5);
        // Hour 12 normalizes to 12 / 24 = 0.5.
        assert_eq!(features[3], 0.5);
    }

    #[test]
    fn test_linear_model_prediction() {
        let model = LinearModel::new();
        let features = Array1::from_vec(vec![0.5; 5]);
        let prediction = model.predict(&features);
        assert!((0.0..=100.0).contains(&prediction));
    }

    #[tokio::test]
    async fn test_quality_predictor_without_training() {
        // An untrained predictor must return the conservative fallback.
        let predictor = QualityPredictor::new();
        let prediction = predictor.predict(&midpoint_input()).await.unwrap();
        assert_eq!(prediction.quality, QualityTarget::Medium);
        assert_eq!(prediction.confidence, 0.0);
    }

    #[tokio::test]
    async fn test_quality_predictor_with_training() {
        let mut predictor = QualityPredictor::new().with_history_size(100, 5);
        for round in 0..10u64 {
            let sample = TrainingSample {
                input: midpoint_input(),
                quality: QualityTarget::High,
                synthesis_time_ms: 100 + round * 10,
                success: true,
                measured_rtf: 0.5,
            };
            predictor.train(sample).await.unwrap();
        }
        let stats = predictor.get_stats();
        assert_eq!(stats.total_samples, 10);
        assert_eq!(stats.successful_samples, 10);
    }

    #[test]
    fn test_text_complexity_simple() {
        assert!(TextComplexityAnalyzer::analyze("Hello world.") < 0.5);
    }

    #[test]
    fn test_text_complexity_complex() {
        let complex_text = "The sophisticated implementation of machine learning \
             algorithms requires comprehensive understanding of mathematical \
             foundations, statistical methodologies, and computational efficiency.";
        assert!(TextComplexityAnalyzer::analyze(complex_text) > 0.5);
    }

    #[test]
    fn test_text_complexity_empty() {
        assert_eq!(TextComplexityAnalyzer::analyze(""), 0.0);
    }

    #[tokio::test]
    async fn test_predictor_reset() {
        let mut predictor = QualityPredictor::new();
        let sample = TrainingSample {
            input: midpoint_input(),
            quality: QualityTarget::High,
            synthesis_time_ms: 100,
            success: true,
            measured_rtf: 0.5,
        };
        predictor.train(sample).await.unwrap();
        assert_eq!(predictor.get_stats().total_samples, 1);
        // Reset must clear both the models and the history window.
        predictor.reset();
        assert_eq!(predictor.get_stats().total_samples, 0);
    }
}