use anyhow::{anyhow, Result};
use rayon::prelude::*;
use scirs2_core::ndarray_ext::Array1;
use scirs2_core::random::Random;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, info, warn};

use crate::{EmbeddingModel, Triple};

/// Strategy determining which parameters are updated during fine-tuning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FineTuningStrategy {
    /// Update all entity and relation embeddings.
    FullFineTuning,
    /// Freeze entity embeddings and update only relation embeddings.
    FreezeEntities,
    /// Freeze relation embeddings and update only entity embeddings.
    FreezeRelations,
    /// Update only a fraction of each embedding's dimensions.
    PartialDimensions,
    /// Keep embeddings frozen and train lightweight adapter layers instead.
    AdapterBased,
    /// Discriminative fine-tuning with layer-wise learning rates.
    Discriminative,
}

/// Configuration for fine-tuning a pre-trained embedding model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FineTuningConfig {
    /// Fine-tuning strategy to apply.
    pub strategy: FineTuningStrategy,
    /// Learning rate (typically smaller than the pre-training rate).
    pub learning_rate: f64,
    /// Maximum number of fine-tuning epochs.
    pub max_epochs: usize,
    /// Regularization coefficient.
    pub regularization: f64,
    /// Fraction of dimensions to update under `PartialDimensions`.
    pub partial_dimensions_pct: f32,
    /// Bottleneck dimension of adapter layers under `AdapterBased`.
    pub adapter_dim: usize,
    /// Epochs without sufficient improvement before early stopping.
    pub early_stopping_patience: usize,
    /// Minimum validation-loss improvement that resets the patience counter.
    pub min_improvement: f64,
    /// Fraction of training triples held out for validation.
    pub validation_split: f32,
    /// Whether to snapshot pre-trained embeddings for knowledge distillation.
    pub use_distillation: bool,
    /// Temperature for the distillation loss.
    pub distillation_temperature: f32,
    /// Weight of the distillation term in the combined loss.
    pub distillation_weight: f32,
}

impl Default for FineTuningConfig {
    fn default() -> Self {
        Self {
            strategy: FineTuningStrategy::FullFineTuning,
            learning_rate: 0.001,
            max_epochs: 50,
            regularization: 0.01,
            partial_dimensions_pct: 0.2,
            adapter_dim: 32,
            early_stopping_patience: 5,
            min_improvement: 0.001,
            validation_split: 0.1,
            use_distillation: false,
            distillation_temperature: 2.0,
            distillation_weight: 0.5,
        }
    }
}
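
// Example (illustrative, not from the original source): override selected
// fields and inherit the remaining defaults via struct-update syntax:
//
//     let config = FineTuningConfig {
//         strategy: FineTuningStrategy::AdapterBased,
//         learning_rate: 1e-4,
//         ..FineTuningConfig::default()
//     };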

/// Summary of a completed fine-tuning run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FineTuningResult {
    /// Number of epochs actually completed.
    pub epochs_completed: usize,
    /// Training loss after the final epoch.
    pub final_training_loss: f64,
    /// Validation loss after the final epoch.
    pub final_validation_loss: f64,
    /// Wall-clock training time in seconds.
    pub training_time_seconds: f64,
    /// Whether early stopping ended training before `max_epochs`.
    pub early_stopped: bool,
    /// Best validation loss observed during training.
    pub best_validation_loss: f64,
    /// Per-epoch training losses.
    pub training_loss_history: Vec<f64>,
    /// Per-epoch validation losses.
    pub validation_loss_history: Vec<f64>,
    /// Number of parameters updated by the chosen strategy.
    pub num_parameters_updated: usize,
}

/// A bottleneck adapter layer: a down-projection into a small hidden space,
/// a ReLU, and an up-projection back to the embedding dimension.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdapterLayer {
    /// Down-projection weights, shape `[adapter_dim][embed_dim]`.
    pub down_projection: Vec<Vec<f32>>,
    /// Up-projection weights, shape `[embed_dim][adapter_dim]`.
    pub up_projection: Vec<Vec<f32>>,
    /// Down-projection bias, length `adapter_dim`.
    pub down_bias: Vec<f32>,
    /// Up-projection bias, length `embed_dim`.
    pub up_bias: Vec<f32>,
}

impl AdapterLayer {
    /// Creates an adapter with uniform He-style initialization.
    pub fn new(embed_dim: usize, adapter_dim: usize) -> Self {
        let mut rng = Random::default();
        // He-style scale keeps activation variance roughly stable.
        let scale = (2.0 / embed_dim as f32).sqrt();

        let down_projection = (0..adapter_dim)
            .map(|_| {
                (0..embed_dim)
                    .map(|_| rng.gen_range(-scale..scale))
                    .collect()
            })
            .collect();

        let up_projection = (0..embed_dim)
            .map(|_| {
                (0..adapter_dim)
                    .map(|_| rng.gen_range(-scale..scale))
                    .collect()
            })
            .collect();

        let down_bias = vec![0.0; adapter_dim];
        let up_bias = vec![0.0; embed_dim];

        Self {
            down_projection,
            up_projection,
            down_bias,
            up_bias,
        }
    }

    /// Forward pass with a residual connection:
    /// `output = input + up_projection(relu(down_projection(input)))`.
    pub fn forward(&self, input: &Array1<f32>) -> Array1<f32> {
        let embed_dim = input.len();

        // Down-projection into the bottleneck, followed by ReLU.
        let mut hidden: Vec<f32> = vec![0.0; self.down_bias.len()];
        for (i, h) in hidden.iter_mut().enumerate() {
            let mut sum = self.down_bias[i];
            for j in 0..embed_dim {
                sum += self.down_projection[i][j] * input[j];
            }
            *h = sum.max(0.0); // ReLU
        }

        // Up-projection back to the embedding dimension, plus residual input.
        let mut output = vec![0.0; embed_dim];
        for i in 0..embed_dim {
            let mut sum = self.up_bias[i];
            for (j, &h_val) in hidden.iter().enumerate() {
                sum += self.up_projection[i][j] * h_val;
            }
            output[i] = sum + input[i]; // residual connection
        }

        Array1::from_vec(output)
    }
}

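/// Manages fine-tuning of pre-trained embedding models.
///
/// Minimal usage sketch (assumes a mutable `model` implementing
/// [`EmbeddingModel`] and a `triples: Vec<Triple>` training set):
///
/// ```ignore
/// let mut manager = FineTuningManager::new(FineTuningConfig::default());
/// let result = manager.fine_tune(&mut model, triples).await?;
/// println!("Best validation loss: {:.6}", result.best_validation_loss);
/// ```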
pub struct FineTuningManager {
    config: FineTuningConfig,
    /// Snapshot of pre-trained entity embeddings (used for distillation).
    pretrained_entities: HashMap<String, Array1<f32>>,
    /// Snapshot of pre-trained relation embeddings (used for distillation).
    pretrained_relations: HashMap<String, Array1<f32>>,
    /// Per-entity adapters (populated only under `AdapterBased`).
    entity_adapters: HashMap<String, AdapterLayer>,
    /// Per-relation adapters (populated only under `AdapterBased`).
    relation_adapters: HashMap<String, AdapterLayer>,
}

impl FineTuningManager {
    /// Creates a new fine-tuning manager with the given configuration.
    pub fn new(config: FineTuningConfig) -> Self {
        info!(
            "Initialized fine-tuning manager with strategy: {:?}",
            config.strategy
        );

        Self {
            config,
            pretrained_entities: HashMap::new(),
            pretrained_relations: HashMap::new(),
            entity_adapters: HashMap::new(),
            relation_adapters: HashMap::new(),
        }
    }

    /// Snapshots the model's current embeddings so they can act as the
    /// distillation teacher. No-op unless `use_distillation` is set.
    pub fn save_pretrained_embeddings<M: EmbeddingModel>(&mut self, model: &M) -> Result<()> {
        if !self.config.use_distillation {
            return Ok(());
        }

        info!("Saving pre-trained embeddings for knowledge distillation");

        for entity in model.get_entities() {
            if let Ok(emb) = model.get_entity_embedding(&entity) {
                self.pretrained_entities
                    .insert(entity, Array1::from_vec(emb.values));
            }
        }

        for relation in model.get_relations() {
            if let Ok(emb) = model.get_relation_embedding(&relation) {
                self.pretrained_relations
                    .insert(relation, Array1::from_vec(emb.values));
            }
        }

        info!(
            "Saved {} entity and {} relation embeddings",
            self.pretrained_entities.len(),
            self.pretrained_relations.len()
        );

        Ok(())
    }
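
    // Sketch (not wired into the training loop here): a typical distillation
    // objective would blend the task loss with a penalty keeping each
    // fine-tuned embedding `e` close to its snapshot `e_pre`, e.g.
    //   loss = (1.0 - distillation_weight) * task_loss
    //        + distillation_weight * ||e - e_pre||^2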

    /// Initializes one adapter per entity and relation. No-op unless the
    /// strategy is `AdapterBased`; called automatically by `fine_tune`.
    pub fn initialize_adapters<M: EmbeddingModel>(
        &mut self,
        model: &M,
        embed_dim: usize,
    ) -> Result<()> {
        if self.config.strategy != FineTuningStrategy::AdapterBased {
            return Ok(());
        }

        info!(
            "Initializing adapters with dimension: embed_dim={}, adapter_dim={}",
            embed_dim, self.config.adapter_dim
        );

        for entity in model.get_entities() {
            let adapter = AdapterLayer::new(embed_dim, self.config.adapter_dim);
            self.entity_adapters.insert(entity, adapter);
        }

        for relation in model.get_relations() {
            let adapter = AdapterLayer::new(embed_dim, self.config.adapter_dim);
            self.relation_adapters.insert(relation, adapter);
        }

        info!(
            "Initialized {} entity and {} relation adapters",
            self.entity_adapters.len(),
            self.relation_adapters.len()
        );

        Ok(())
    }

    /// Fine-tunes `model` on `training_triples`, with early stopping on the
    /// validation loss. Returns per-epoch histories and summary statistics.
    pub async fn fine_tune<M: EmbeddingModel>(
        &mut self,
        model: &mut M,
        training_triples: Vec<Triple>,
    ) -> Result<FineTuningResult> {
        if training_triples.is_empty() {
            return Err(anyhow!("No training data provided for fine-tuning"));
        }

        info!(
            "Starting fine-tuning with {} triples using {:?} strategy",
            training_triples.len(),
            self.config.strategy
        );

        let (train_data, val_data) = self.split_data(&training_triples)?;

        info!(
            "Split data: {} training, {} validation",
            train_data.len(),
            val_data.len()
        );

        if self.config.use_distillation {
            self.save_pretrained_embeddings(model)?;
        }

        if self.config.strategy == FineTuningStrategy::AdapterBased {
            let config = model.config();
            self.initialize_adapters(model, config.dimensions)?;
        }

        for triple in &train_data {
            model.add_triple(triple.clone())?;
        }

        let start_time = std::time::Instant::now();
        let mut training_loss_history = Vec::new();
        let mut validation_loss_history = Vec::new();
        let mut best_val_loss = f64::INFINITY;
        let mut patience_counter = 0;
        let mut early_stopped = false;

        for epoch in 0..self.config.max_epochs {
            // Train for a single epoch, then evaluate on the held-out set.
            let stats = model.train(Some(1)).await?;
            let train_loss = stats.final_loss;
            training_loss_history.push(train_loss);

            let val_loss = self.validate(model, &val_data)?;
            validation_loss_history.push(val_loss);

            debug!(
                "Epoch {}/{}: train_loss={:.6}, val_loss={:.6}",
                epoch + 1,
                self.config.max_epochs,
                train_loss,
                val_loss
            );

            // Early stopping: only a sufficiently large improvement resets patience.
            if val_loss < best_val_loss - self.config.min_improvement {
                best_val_loss = val_loss;
                patience_counter = 0;
                info!("New best validation loss: {:.6}", best_val_loss);
            } else {
                patience_counter += 1;
                if patience_counter >= self.config.early_stopping_patience {
                    warn!(
                        "Early stopping triggered at epoch {} (patience={})",
                        epoch + 1,
                        self.config.early_stopping_patience
                    );
                    early_stopped = true;
                    break;
                }
            }
        }

        let training_time = start_time.elapsed().as_secs_f64();

        let num_parameters_updated = self.count_updated_parameters(model)?;

        info!(
            "Fine-tuning complete: {} epochs, {:.2}s, {} parameters updated",
            training_loss_history.len(),
            training_time,
            num_parameters_updated
        );

        Ok(FineTuningResult {
            epochs_completed: training_loss_history.len(),
            final_training_loss: *training_loss_history.last().unwrap_or(&0.0),
            final_validation_loss: *validation_loss_history.last().unwrap_or(&0.0),
            training_time_seconds: training_time,
            early_stopped,
            best_validation_loss: best_val_loss,
            training_loss_history,
            validation_loss_history,
            num_parameters_updated,
        })
    }

    /// Splits the triples into training and validation sets after a
    /// Fisher-Yates shuffle.
    fn split_data(&self, data: &[Triple]) -> Result<(Vec<Triple>, Vec<Triple>)> {
        let val_size = (data.len() as f32 * self.config.validation_split) as usize;
        let train_size = data.len() - val_size;

        if val_size == 0 {
            warn!("Validation set is empty, using full data for training");
            return Ok((data.to_vec(), Vec::new()));
        }

        // Fisher-Yates shuffle over the index vector.
        let mut indices: Vec<usize> = (0..data.len()).collect();
        let mut rng = Random::default();

        for i in (1..indices.len()).rev() {
            let j = rng.random_range(0..i + 1);
            indices.swap(i, j);
        }

        let train_data: Vec<Triple> = indices[..train_size]
            .iter()
            .map(|&i| data[i].clone())
            .collect();

        let val_data: Vec<Triple> = indices[train_size..]
            .iter()
            .map(|&i| data[i].clone())
            .collect();

        Ok((train_data, val_data))
    }

    /// Computes the validation loss as the mean negative triple score
    /// (higher scores indicate a better fit, so lower loss is better).
    fn validate<M: EmbeddingModel>(&self, model: &M, val_data: &[Triple]) -> Result<f64> {
        if val_data.is_empty() {
            return Ok(0.0);
        }

        // Score validation triples in parallel, skipping any that fail.
        let total_loss: f64 = val_data
            .par_iter()
            .filter_map(|triple| {
                model
                    .score_triple(
                        &triple.subject.iri,
                        &triple.predicate.iri,
                        &triple.object.iri,
                    )
                    .ok()
            })
            .map(|score| -score)
            .sum();

        Ok(total_loss / val_data.len() as f64)
    }

    /// Estimates how many parameters the chosen strategy updates.
    fn count_updated_parameters<M: EmbeddingModel>(&self, model: &M) -> Result<usize> {
        let stats = model.get_stats();
        let embed_dim = stats.dimensions;

        match self.config.strategy {
            FineTuningStrategy::FullFineTuning => {
                Ok((stats.num_entities + stats.num_relations) * embed_dim)
            }
            FineTuningStrategy::FreezeEntities => Ok(stats.num_relations * embed_dim),
            FineTuningStrategy::FreezeRelations => Ok(stats.num_entities * embed_dim),
            FineTuningStrategy::PartialDimensions => {
                let partial_dim = (embed_dim as f32 * self.config.partial_dimensions_pct) as usize;
                Ok((stats.num_entities + stats.num_relations) * partial_dim)
            }
            FineTuningStrategy::AdapterBased => {
                // Per adapter: down (adapter_dim x embed_dim) and up
                // (embed_dim x adapter_dim) weight matrices plus both biases.
                let adapter_params =
                    2 * embed_dim * self.config.adapter_dim + embed_dim + self.config.adapter_dim;
                Ok((stats.num_entities + stats.num_relations) * adapter_params)
            }
            FineTuningStrategy::Discriminative => {
                // All parameters are updated, just with layer-wise learning rates.
                Ok((stats.num_entities + stats.num_relations) * embed_dim)
            }
        }
    }
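
    // Illustrative counts (assumed sizes: 1,000 entities, 100 relations,
    // embed_dim = 128, adapter_dim = 32, partial_dimensions_pct = 0.2):
    //   FullFineTuning / Discriminative: (1000 + 100) * 128 = 140,800
    //   FreezeEntities:                  100 * 128          =  12,800
    //   PartialDimensions:               (1000 + 100) * 25  =  27,500
    //   AdapterBased:  (1000 + 100) * (2*128*32 + 128 + 32) = 9,187,200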

    /// Returns summary statistics about the manager's current state.
    pub fn get_stats(&self) -> FineTuningStats {
        FineTuningStats {
            num_pretrained_entities: self.pretrained_entities.len(),
            num_pretrained_relations: self.pretrained_relations.len(),
            num_entity_adapters: self.entity_adapters.len(),
            num_relation_adapters: self.relation_adapters.len(),
            strategy: self.config.strategy,
        }
    }
}

/// Snapshot of the fine-tuning manager's internal state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FineTuningStats {
    pub num_pretrained_entities: usize,
    pub num_pretrained_relations: usize,
    pub num_entity_adapters: usize,
    pub num_relation_adapters: usize,
    pub strategy: FineTuningStrategy,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;

    #[test]
    fn test_fine_tuning_config_default() {
        let config = FineTuningConfig::default();
        assert_eq!(config.strategy, FineTuningStrategy::FullFineTuning);
        assert!(config.learning_rate < 0.01);
        assert_eq!(config.max_epochs, 50);
    }

    #[test]
    fn test_adapter_layer_creation() {
        let adapter = AdapterLayer::new(128, 32);
        assert_eq!(adapter.down_projection.len(), 32);
        assert_eq!(adapter.up_projection.len(), 128);
        assert_eq!(adapter.down_bias.len(), 32);
        assert_eq!(adapter.up_bias.len(), 128);
    }

    #[test]
    fn test_adapter_forward_pass() {
        let adapter = AdapterLayer::new(128, 32);
        let input = Array1::from_vec(vec![1.0; 128]);
        let output = adapter.forward(&input);
        assert_eq!(output.len(), 128);
    }
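
    // Added sketch: with all weights and biases zeroed, the residual
    // connection should make `forward` behave as the identity map.
    #[test]
    fn test_adapter_identity_when_zeroed() {
        let mut adapter = AdapterLayer::new(8, 4);
        adapter.down_projection = vec![vec![0.0; 8]; 4];
        adapter.up_projection = vec![vec![0.0; 4]; 8];
        adapter.down_bias = vec![0.0; 4];
        adapter.up_bias = vec![0.0; 8];

        let input = Array1::from_vec(vec![0.5; 8]);
        let output = adapter.forward(&input);
        for (o, i) in output.iter().zip(input.iter()) {
            assert!((o - i).abs() < 1e-6);
        }
    }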

    #[test]
    fn test_fine_tuning_manager_creation() {
        let config = FineTuningConfig::default();
        let manager = FineTuningManager::new(config);
        let stats = manager.get_stats();
        assert_eq!(stats.num_pretrained_entities, 0);
        assert_eq!(stats.strategy, FineTuningStrategy::FullFineTuning);
    }

    #[test]
    fn test_split_data() {
        let config = FineTuningConfig {
            validation_split: 0.2,
            ..Default::default()
        };
        let manager = FineTuningManager::new(config);

        let triples: Vec<Triple> = (0..100)
            .map(|i| Triple {
                subject: NamedNode {
                    iri: format!("s{}", i),
                },
                predicate: NamedNode {
                    iri: format!("p{}", i),
                },
                object: NamedNode {
                    iri: format!("o{}", i),
                },
            })
            .collect();

        let (train, val) = manager.split_data(&triples).unwrap();
        assert_eq!(train.len(), 80);
        assert_eq!(val.len(), 20);
    }
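
    // Added sketch: with only 5 triples and the default 0.1 split, the
    // validation size rounds down to zero and split_data falls back to
    // training-only.
    #[test]
    fn test_split_data_empty_validation() {
        let manager = FineTuningManager::new(FineTuningConfig::default());

        let triples: Vec<Triple> = (0..5)
            .map(|i| Triple {
                subject: NamedNode {
                    iri: format!("s{}", i),
                },
                predicate: NamedNode {
                    iri: format!("p{}", i),
                },
                object: NamedNode {
                    iri: format!("o{}", i),
                },
            })
            .collect();

        let (train, val) = manager.split_data(&triples).unwrap();
        assert_eq!(train.len(), 5);
        assert!(val.is_empty());
    }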
}