use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use async_trait::async_trait;
use super::agent::Action;
pub struct StreamLearner {
model: OnlineModel,
learning_rate: f64,
experience_buffer: VecDeque<Experience>,
buffer_size: usize,
iterations: u64,
strategy: AdaptationStrategy,
}
impl StreamLearner {
pub fn new(learning_rate: f64) -> Self {
Self {
model: OnlineModel::new(),
learning_rate,
experience_buffer: VecDeque::new(),
buffer_size: 1000,
iterations: 0,
strategy: AdaptationStrategy::default(),
}
}
pub async fn update(
&mut self,
action: &Action,
reward: f64,
context: &str,
) -> Result<(), String> {
self.iterations += 1;
let experience = Experience {
action: action.clone(),
reward,
context: context.to_string(),
timestamp: chrono::Utc::now().timestamp(),
};
self.experience_buffer.push_back(experience.clone());
if self.experience_buffer.len() > self.buffer_size {
self.experience_buffer.pop_front();
}
match &self.strategy {
AdaptationStrategy::Immediate => {
self.model.update_immediate(&experience, self.learning_rate).await?;
}
AdaptationStrategy::Batched { batch_size } => {
if self.iterations % batch_size == 0 {
self.model.update_batch(&self.experience_buffer, self.learning_rate).await?;
}
}
AdaptationStrategy::ExperienceReplay { replay_size } => {
self.model.update_immediate(&experience, self.learning_rate).await?;
let replay_samples = self.sample_experiences(*replay_size);
for sample in replay_samples {
self.model.update_immediate(&sample, self.learning_rate * 0.5).await?;
}
}
}
Ok(())
}
fn sample_experiences(&self, n: usize) -> Vec<Experience> {
let experiences: Vec<_> = self.experience_buffer.iter().cloned().collect();
let total = experiences.len();
if total == 0 || n == 0 {
return Vec::new();
}
let step = (total as f64 / n as f64).max(1.0) as usize;
experiences.iter()
.step_by(step)
.take(n)
.cloned()
.collect()
}
pub async fn predict_reward(&self, action: &Action, context: &str) -> f64 {
self.model.predict(action, context).await
}
pub fn get_stats(&self) -> LearningStats {
LearningStats {
iterations: self.iterations,
buffer_size: self.experience_buffer.len(),
average_reward: self.compute_average_reward(),
model_parameters: self.model.parameter_count(),
}
}
fn compute_average_reward(&self) -> f64 {
if self.experience_buffer.is_empty() {
return 0.0;
}
let sum: f64 = self.experience_buffer.iter()
.map(|e| e.reward)
.sum();
sum / self.experience_buffer.len() as f64
}
pub fn iteration_count(&self) -> u64 {
self.iterations
}
}
pub struct OnlineModel {
weights: HashMap<String, f64>,
bias: f64,
feature_stats: HashMap<String, FeatureStats>,
}
impl OnlineModel {
pub fn new() -> Self {
Self {
weights: HashMap::new(),
bias: 0.0,
feature_stats: HashMap::new(),
}
}
fn extract_features(&self, action: &Action, context: &str) -> HashMap<String, f64> {
let mut features = HashMap::new();
features.insert(
format!("action_{}", action.action_type),
1.0,
);
features.insert(
"param_count".to_string(),
action.parameters.len() as f64,
);
features.insert(
"tool_count".to_string(),
action.tool_calls.len() as f64,
);
features.insert(
"context_length".to_string(),
context.len() as f64 / 100.0, );
features.insert(
"expected_reward".to_string(),
action.expected_reward,
);
features
}
pub async fn predict(&self, action: &Action, context: &str) -> f64 {
let features = self.extract_features(action, context);
let mut prediction = self.bias;
for (feature, value) in features {
if let Some(weight) = self.weights.get(&feature) {
prediction += weight * value;
}
}
prediction
}
pub async fn update_immediate(
&mut self,
experience: &Experience,
learning_rate: f64,
) -> Result<(), String> {
let features = self.extract_features(&experience.action, &experience.context);
let prediction = self.predict(&experience.action, &experience.context).await;
let error = experience.reward - prediction;
self.bias += learning_rate * error;
for (feature, value) in features {
let weight = self.weights.entry(feature.clone()).or_insert(0.0);
*weight += learning_rate * error * value;
let stats = self.feature_stats.entry(feature).or_insert(FeatureStats::default());
stats.update(value);
}
Ok(())
}
pub async fn update_batch(
&mut self,
experiences: &VecDeque<Experience>,
learning_rate: f64,
) -> Result<(), String> {
for experience in experiences {
self.update_immediate(experience, learning_rate).await?;
}
Ok(())
}
pub fn parameter_count(&self) -> usize {
self.weights.len() + 1 }
}
#[derive(Debug, Clone)]
pub struct Experience {
pub action: Action,
pub reward: f64,
pub context: String,
pub timestamp: i64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AdaptationStrategy {
Immediate,
Batched { batch_size: u64 },
ExperienceReplay { replay_size: usize },
}
impl Default for AdaptationStrategy {
fn default() -> Self {
AdaptationStrategy::Immediate
}
}
#[derive(Debug, Clone, Default)]
struct FeatureStats {
count: u64,
sum: f64,
sum_squared: f64,
}
impl FeatureStats {
fn update(&mut self, value: f64) {
self.count += 1;
self.sum += value;
self.sum_squared += value * value;
}
fn mean(&self) -> f64 {
if self.count == 0 {
0.0
} else {
self.sum / self.count as f64
}
}
fn variance(&self) -> f64 {
if self.count == 0 {
0.0
} else {
let mean = self.mean();
(self.sum_squared / self.count as f64) - (mean * mean)
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningStats {
pub iterations: u64,
pub buffer_size: usize,
pub average_reward: f64,
pub model_parameters: usize,
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
#[tokio::test]
async fn test_stream_learner() {
let mut learner = StreamLearner::new(0.01);
let action = Action {
action_type: "test".to_string(),
description: "Test action".to_string(),
parameters: HashMap::new(),
tool_calls: vec![],
expected_outcome: None,
expected_reward: 0.5,
};
let result = learner.update(&action, 1.0, "test context").await;
assert!(result.is_ok());
let stats = learner.get_stats();
assert_eq!(stats.iterations, 1);
}
}