use super::model::MultiModalEmbedding;
use anyhow::Result;
use scirs2_core::ndarray_ext::Array2;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Online (streaming) fine-tuning state for a multi-modal embedding model.
///
/// Labelled examples accumulate in `online_buffer`; every `update_frequency`
/// additions the model may be updated on the most recent batch, with an
/// Elastic Weight Consolidation (EWC) penalty to limit catastrophic
/// forgetting of previously learned parameters.
///
/// NOTE(review): field order matters for some serde formats — do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RealTimeFinetuning {
    /// Step size for online updates.
    pub learning_rate: f32,
    /// Maximum number of examples retained in `online_buffer` (FIFO eviction).
    pub buffer_size: usize,
    /// Number of `add_example` calls between model updates; also used as the
    /// update batch size in `update_model`.
    pub update_frequency: usize,
    /// EWC regularization state (lambda, Fisher information, anchor params).
    pub ewc_config: EWCConfig,
    /// Buffered `(text, entity, label)` training examples, oldest first.
    pub online_buffer: Vec<(String, String, String)>,
    /// Total number of examples ever added via `add_example`.
    pub update_count: usize,
}
/// Elastic Weight Consolidation (EWC) configuration and state.
///
/// The EWC penalty is `lambda * sum_i F_i * (theta_i - theta*_i)^2`, where
/// `F` is the (diagonal) Fisher information and `theta*` the anchor
/// parameters captured at the last consolidation point.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EWCConfig {
    /// Strength of the EWC penalty relative to the data loss.
    pub lambda: f32,
    /// Per-parameter Fisher information matrices, keyed by parameter name.
    pub fisher_information: HashMap<String, Array2<f32>>,
    /// Anchor ("optimal") parameter values captured at consolidation time.
    pub optimal_params: HashMap<String, Array2<f32>>,
}
impl Default for RealTimeFinetuning {
fn default() -> Self {
Self {
learning_rate: 0.001,
buffer_size: 1000,
update_frequency: 10,
ewc_config: EWCConfig::default(),
online_buffer: Vec::new(),
update_count: 0,
}
}
}
impl Default for EWCConfig {
    /// Default EWC state: a mild penalty weight and no consolidated
    /// parameters yet (both maps empty, so the penalty starts at zero).
    fn default() -> Self {
        Self {
            fisher_information: HashMap::default(),
            optimal_params: HashMap::default(),
            lambda: 0.1,
        }
    }
}
impl RealTimeFinetuning {
    /// Appends a `(text, entity, label)` example to the online buffer,
    /// evicting the oldest entry once `buffer_size` is exceeded, and
    /// advances the example counter.
    ///
    /// `Vec::remove(0)` is O(n), which is acceptable at the default buffer
    /// size; the buffer field is public, so its type is kept as `Vec`
    /// rather than switching to `VecDeque`.
    pub fn add_example(&mut self, text: String, entity: String, label: String) {
        self.online_buffer.push((text, entity, label));
        if self.online_buffer.len() > self.buffer_size {
            self.online_buffer.remove(0);
        }
        self.update_count += 1;
    }

    /// Returns true when enough examples have accumulated for a model
    /// update: the example count is a multiple of `update_frequency` and
    /// the buffer is non-empty.
    ///
    /// Guards against `update_frequency == 0`, which would otherwise panic
    /// with a division by zero; a zero frequency now means "never update".
    pub fn should_update(&self) -> bool {
        self.update_frequency > 0
            && self.update_count % self.update_frequency == 0
            && !self.online_buffer.is_empty()
    }

    /// Runs one online update pass over the most recent batch of buffered
    /// examples and returns the mean loss (data loss + EWC penalty).
    ///
    /// Returns `Ok(0.0)` without touching the model when `should_update()`
    /// is false. Errors from embedding generation or EWC bookkeeping are
    /// propagated.
    pub async fn update_model(&mut self, model: &mut MultiModalEmbedding) -> Result<f32> {
        if !self.should_update() {
            return Ok(0.0);
        }
        // Most recent `update_frequency` examples (or fewer if the buffer
        // is smaller). `should_update()` guarantees batch_size >= 1.
        let batch_size = self.update_frequency.min(self.online_buffer.len());
        let update_batch = &self.online_buffer[self.online_buffer.len() - batch_size..];
        let mut data_loss = 0.0;
        for (text, entity, _label) in update_batch {
            let unified = model.generate_unified_embedding(text, entity).await?;
            // Placeholder data loss: mean squared activation of the
            // unified embedding.
            data_loss += unified.iter().map(|&x| x * x).sum::<f32>() / unified.len() as f32;
        }
        // The EWC penalty depends only on the current parameters, not on
        // the batch examples, so compute it once instead of once per
        // example. Adding it after the mean is numerically identical to
        // the per-example form (batch_size * p / batch_size == p).
        let ewc_penalty =
            self.compute_ewc_loss(&model.text_encoder.parameters)? * self.ewc_config.lambda;
        let total_loss = data_loss / batch_size as f32 + ewc_penalty;
        self.update_fisher_information(model)?;
        Ok(total_loss)
    }

    /// Computes the (unweighted) EWC penalty
    /// `sum_i F_i * (theta_i - theta*_i)^2` over all parameters that have
    /// both Fisher information and anchor values recorded; parameters
    /// without recorded state contribute nothing.
    fn compute_ewc_loss(&self, current_params: &HashMap<String, Array2<f32>>) -> Result<f32> {
        let mut ewc_loss = 0.0;
        for (param_name, current_param) in current_params {
            if let (Some(fisher), Some(optimal)) = (
                self.ewc_config.fisher_information.get(param_name),
                self.ewc_config.optimal_params.get(param_name),
            ) {
                let diff = current_param - optimal;
                let weighted_diff = &diff * fisher;
                ewc_loss += (&diff * &weighted_diff).sum();
            }
        }
        Ok(ewc_loss)
    }

    /// Consolidates the current text-encoder parameters: records them as
    /// the EWC anchor and refreshes the Fisher information estimates.
    ///
    /// NOTE(review): the Fisher information here is a small random
    /// placeholder, not an estimate from gradients — replace with squared
    /// gradients over real batches for meaningful EWC behavior.
    fn update_fisher_information(&mut self, model: &MultiModalEmbedding) -> Result<()> {
        use scirs2_core::random::{Random, RngExt};
        // One RNG for the whole call; the original constructed a fresh
        // `Random::default()` for every matrix element inside the closure.
        let mut random = Random::default();
        for (param_name, param) in &model.text_encoder.parameters {
            let fisher =
                Array2::from_shape_fn(param.dim(), |(_, _)| random.random::<f32>() * 0.01);
            self.ewc_config
                .fisher_information
                .insert(param_name.clone(), fisher);
            self.ewc_config
                .optimal_params
                .insert(param_name.clone(), param.clone());
        }
        Ok(())
    }
}