// oxirs_embed/multimodal/impl/adaptation.rs

use super::model::MultiModalEmbedding;
use anyhow::Result;
use scirs2_core::ndarray_ext::Array2;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Online (real-time) fine-tuning state for a multi-modal embedding model.
///
/// Maintains a sliding buffer of recent training examples plus EWC
/// (Elastic Weight Consolidation) state used to penalize drift away from
/// previously learned parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RealTimeFinetuning {
    /// Step size for online parameter updates.
    pub learning_rate: f32,
    /// Maximum number of examples retained in `online_buffer`
    /// (the oldest example is evicted first).
    pub buffer_size: usize,
    /// An update is due every `update_frequency` added examples.
    pub update_frequency: usize,
    /// EWC regularization settings and cached per-parameter statistics.
    pub ewc_config: EWCConfig,
    /// Recent `(text, entity, label)` training triples, oldest first.
    pub online_buffer: Vec<(String, String, String)>,
    /// Total number of examples ever added via `add_example`.
    pub update_count: usize,
}
25
/// Elastic Weight Consolidation (EWC) configuration and cached statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EWCConfig {
    /// Strength of the EWC penalty term.
    pub lambda: f32,
    /// Per-parameter Fisher information estimates, keyed by parameter name.
    pub fisher_information: HashMap<String, Array2<f32>>,
    /// Parameter snapshots to stay close to, keyed by parameter name.
    pub optimal_params: HashMap<String, Array2<f32>>,
}
36
37impl Default for RealTimeFinetuning {
38 fn default() -> Self {
39 Self {
40 learning_rate: 0.001,
41 buffer_size: 1000,
42 update_frequency: 10,
43 ewc_config: EWCConfig::default(),
44 online_buffer: Vec::new(),
45 update_count: 0,
46 }
47 }
48}
49
50impl Default for EWCConfig {
51 fn default() -> Self {
52 Self {
53 lambda: 0.1,
54 fisher_information: HashMap::new(),
55 optimal_params: HashMap::new(),
56 }
57 }
58}
59
60impl RealTimeFinetuning {
61 pub fn add_example(&mut self, text: String, entity: String, label: String) {
63 self.online_buffer.push((text, entity, label));
64
65 if self.online_buffer.len() > self.buffer_size {
67 self.online_buffer.remove(0);
68 }
69
70 self.update_count += 1;
71 }
72
73 pub fn should_update(&self) -> bool {
75 self.update_count % self.update_frequency == 0 && !self.online_buffer.is_empty()
76 }
77
78 pub async fn update_model(&mut self, model: &mut MultiModalEmbedding) -> Result<f32> {
80 if !self.should_update() {
81 return Ok(0.0);
82 }
83
84 let mut total_loss = 0.0;
85 let batch_size = self.update_frequency.min(self.online_buffer.len());
86
87 let update_batch = &self.online_buffer[self.online_buffer.len() - batch_size..];
89
90 for (text, entity, _label) in update_batch {
91 let unified = model.generate_unified_embedding(text, entity).await?;
93
94 let loss = unified.iter().map(|&x| x * x).sum::<f32>() / unified.len() as f32;
96 total_loss += loss;
97
98 let ewc_loss = self.compute_ewc_loss(&model.text_encoder.parameters)?;
100 total_loss += ewc_loss * self.ewc_config.lambda;
101 }
102
103 total_loss /= batch_size as f32;
104
105 self.update_fisher_information(model)?;
107
108 Ok(total_loss)
109 }
110
111 fn compute_ewc_loss(&self, current_params: &HashMap<String, Array2<f32>>) -> Result<f32> {
113 let mut ewc_loss = 0.0;
114
115 for (param_name, current_param) in current_params {
116 if let (Some(fisher), Some(optimal)) = (
117 self.ewc_config.fisher_information.get(param_name),
118 self.ewc_config.optimal_params.get(param_name),
119 ) {
120 let diff = current_param - optimal;
121 let weighted_diff = &diff * fisher;
122 ewc_loss += (&diff * &weighted_diff).sum();
123 }
124 }
125
126 Ok(ewc_loss)
127 }
128
129 fn update_fisher_information(&mut self, model: &MultiModalEmbedding) -> Result<()> {
131 for (param_name, param) in &model.text_encoder.parameters {
132 let fisher = Array2::from_shape_fn(param.dim(), |(_, _)| {
134 use scirs2_core::random::{Random, Rng};
135 let mut random = Random::default();
136 random.random::<f32>() * 0.01
137 });
138 self.ewc_config
139 .fisher_information
140 .insert(param_name.clone(), fisher);
141 self.ewc_config
142 .optimal_params
143 .insert(param_name.clone(), param.clone());
144 }
145
146 Ok(())
147 }
148}