oxirs_embed/fine_tuning.rs

//! Fine-tuning Capabilities for Pre-trained Embedding Models
//!
//! This module provides tools for fine-tuning pre-trained knowledge graph embeddings
//! on domain-specific data, enabling transfer learning and model adaptation.
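//!
//! # Example
//!
//! A minimal usage sketch (illustrative only; assumes `model` is some value whose type
//! implements [`EmbeddingModel`] and `triples` is a `Vec<Triple>` of domain-specific data):
//!
//! ```ignore
//! let config = FineTuningConfig {
//!     strategy: FineTuningStrategy::FreezeEntities,
//!     learning_rate: 0.0005,
//!     ..Default::default()
//! };
//! let mut manager = FineTuningManager::new(config);
//! let result = manager.fine_tune(&mut model, triples).await?;
//! println!("fine-tuned for {} epochs", result.epochs_completed);
//! ```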

use anyhow::{anyhow, Result};
use rayon::prelude::*;
use scirs2_core::ndarray_ext::Array1;
use scirs2_core::random::Random;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, info, warn};

use crate::{EmbeddingModel, Triple};

/// Fine-tuning strategy
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FineTuningStrategy {
    /// Fine-tune all parameters
    FullFineTuning,
    /// Freeze entity embeddings, only update relation embeddings
    FreezeEntities,
    /// Freeze relation embeddings, only update entity embeddings
    FreezeRelations,
    /// Only fine-tune last N% of dimensions
    PartialDimensions,
    /// Adapter-based fine-tuning (add small adapter layers)
    AdapterBased,
    /// Layer-wise discriminative fine-tuning (different learning rates per layer)
    Discriminative,
}

/// Fine-tuning configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FineTuningConfig {
    /// Fine-tuning strategy
    pub strategy: FineTuningStrategy,
    /// Learning rate for fine-tuning (typically lower than pre-training)
    pub learning_rate: f64,
    /// Number of fine-tuning epochs
    pub max_epochs: usize,
    /// Regularization strength (prevents catastrophic forgetting)
    pub regularization: f64,
    /// Percentage of dimensions to fine-tune (for PartialDimensions strategy)
    pub partial_dimensions_pct: f32,
    /// Adapter dimension size (for AdapterBased strategy)
    pub adapter_dim: usize,
    /// Early stopping patience
    pub early_stopping_patience: usize,
    /// Minimum improvement threshold for early stopping
    pub min_improvement: f64,
    /// Validation split ratio (0.0 to 1.0)
    pub validation_split: f32,
    /// Whether to use knowledge distillation from the pre-trained model
    pub use_distillation: bool,
    /// Distillation temperature
    pub distillation_temperature: f32,
    /// Distillation weight (balance between task loss and distillation loss)
    pub distillation_weight: f32,
}

impl Default for FineTuningConfig {
    fn default() -> Self {
        Self {
            strategy: FineTuningStrategy::FullFineTuning,
            learning_rate: 0.001, // 10x lower than typical pre-training
            max_epochs: 50,
            regularization: 0.01,
            partial_dimensions_pct: 0.2, // Fine-tune top 20% of dimensions
            adapter_dim: 32,
            early_stopping_patience: 5,
            min_improvement: 0.001,
            validation_split: 0.1,
            use_distillation: false,
            distillation_temperature: 2.0,
            distillation_weight: 0.5,
        }
    }
}

/// Fine-tuning result with statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FineTuningResult {
    /// Number of epochs completed
    pub epochs_completed: usize,
    /// Final training loss
    pub final_training_loss: f64,
    /// Final validation loss
    pub final_validation_loss: f64,
    /// Training time in seconds
    pub training_time_seconds: f64,
    /// Whether early stopping was triggered
    pub early_stopped: bool,
    /// Best validation loss achieved
    pub best_validation_loss: f64,
    /// Training loss history
    pub training_loss_history: Vec<f64>,
    /// Validation loss history
    pub validation_loss_history: Vec<f64>,
    /// Number of parameters updated
    pub num_parameters_updated: usize,
}

/// Adapter layer for adapter-based fine-tuning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdapterLayer {
    /// Down-projection matrix (embed_dim -> adapter_dim)
    pub down_projection: Vec<Vec<f32>>,
    /// Up-projection matrix (adapter_dim -> embed_dim)
    pub up_projection: Vec<Vec<f32>>,
    /// Bias for down projection
    pub down_bias: Vec<f32>,
    /// Bias for up projection
    pub up_bias: Vec<f32>,
}

impl AdapterLayer {
    /// Create a new adapter layer with random initialization
    pub fn new(embed_dim: usize, adapter_dim: usize) -> Self {
        let mut rng = Random::default();
        let scale = (2.0 / embed_dim as f32).sqrt();

        let down_projection = (0..adapter_dim)
            .map(|_| {
                (0..embed_dim)
                    .map(|_| rng.gen_range(-scale..scale))
                    .collect()
            })
            .collect();

        let up_projection = (0..embed_dim)
            .map(|_| {
                (0..adapter_dim)
                    .map(|_| rng.gen_range(-scale..scale))
                    .collect()
            })
            .collect();

        let down_bias = vec![0.0; adapter_dim];
        let up_bias = vec![0.0; embed_dim];

        Self {
            down_projection,
            up_projection,
            down_bias,
            up_bias,
        }
    }

    /// Forward pass through the adapter
    pub fn forward(&self, input: &Array1<f32>) -> Array1<f32> {
        let embed_dim = input.len();

        // Down-projection: hidden = down @ input + down_bias
        let mut hidden: Vec<f32> = vec![0.0; self.down_bias.len()];
        for (i, h) in hidden.iter_mut().enumerate() {
            let mut sum = self.down_bias[i];
            for j in 0..embed_dim {
                sum += self.down_projection[i][j] * input[j];
            }
            // ReLU activation
            *h = sum.max(0.0);
        }

        // Up-projection: output = up @ hidden + up_bias + input (residual)
        let mut output = vec![0.0; embed_dim];
        for i in 0..embed_dim {
            let mut sum = self.up_bias[i];
            for (j, &h_val) in hidden.iter().enumerate() {
                sum += self.up_projection[i][j] * h_val;
            }
            // Residual connection
            output[i] = sum + input[i];
        }

        Array1::from_vec(output)
    }
}

/// Fine-tuning manager for embedding models
pub struct FineTuningManager {
    config: FineTuningConfig,
    /// Pre-trained embeddings for knowledge distillation
    pretrained_entities: HashMap<String, Array1<f32>>,
    pretrained_relations: HashMap<String, Array1<f32>>,
    /// Adapter layers (if using adapter-based strategy)
    entity_adapters: HashMap<String, AdapterLayer>,
    relation_adapters: HashMap<String, AdapterLayer>,
}

impl FineTuningManager {
    /// Create a new fine-tuning manager
    pub fn new(config: FineTuningConfig) -> Self {
        info!(
            "Initialized fine-tuning manager with strategy: {:?}",
            config.strategy
        );

        Self {
            config,
            pretrained_entities: HashMap::new(),
            pretrained_relations: HashMap::new(),
            entity_adapters: HashMap::new(),
            relation_adapters: HashMap::new(),
        }
    }

    /// Save pre-trained embeddings for distillation
    pub fn save_pretrained_embeddings<M: EmbeddingModel>(&mut self, model: &M) -> Result<()> {
        if !self.config.use_distillation {
            return Ok(());
        }

        info!("Saving pre-trained embeddings for knowledge distillation");

        // Save entity embeddings
        for entity in model.get_entities() {
            if let Ok(emb) = model.get_entity_embedding(&entity) {
                self.pretrained_entities
                    .insert(entity, Array1::from_vec(emb.values));
            }
        }

        // Save relation embeddings
        for relation in model.get_relations() {
            if let Ok(emb) = model.get_relation_embedding(&relation) {
                self.pretrained_relations
                    .insert(relation, Array1::from_vec(emb.values));
            }
        }

        info!(
            "Saved {} entity and {} relation embeddings",
            self.pretrained_entities.len(),
            self.pretrained_relations.len()
        );

        Ok(())
    }

    /// Initialize adapters for adapter-based fine-tuning
    pub fn initialize_adapters<M: EmbeddingModel>(
        &mut self,
        model: &M,
        embed_dim: usize,
    ) -> Result<()> {
        if self.config.strategy != FineTuningStrategy::AdapterBased {
            return Ok(());
        }

        info!(
            "Initializing adapters with dimension: embed_dim={}, adapter_dim={}",
            embed_dim, self.config.adapter_dim
        );

        // Initialize entity adapters
        for entity in model.get_entities() {
            let adapter = AdapterLayer::new(embed_dim, self.config.adapter_dim);
            self.entity_adapters.insert(entity, adapter);
        }

        // Initialize relation adapters
        for relation in model.get_relations() {
            let adapter = AdapterLayer::new(embed_dim, self.config.adapter_dim);
            self.relation_adapters.insert(relation, adapter);
        }

        info!(
            "Initialized {} entity and {} relation adapters",
            self.entity_adapters.len(),
            self.relation_adapters.len()
        );

        Ok(())
    }

    /// Fine-tune a model on domain-specific data
    pub async fn fine_tune<M: EmbeddingModel>(
        &mut self,
        model: &mut M,
        training_triples: Vec<Triple>,
    ) -> Result<FineTuningResult> {
        if training_triples.is_empty() {
            return Err(anyhow!("No training data provided for fine-tuning"));
        }

        info!(
            "Starting fine-tuning with {} triples using {:?} strategy",
            training_triples.len(),
            self.config.strategy
        );

        // Split into training and validation sets
        let (train_data, val_data) = self.split_data(&training_triples)?;

        info!(
            "Split data: {} training, {} validation",
            train_data.len(),
            val_data.len()
        );

        // Save pre-trained embeddings if using distillation
        if self.config.use_distillation {
            self.save_pretrained_embeddings(model)?;
        }

        // Initialize adapters if needed
        if self.config.strategy == FineTuningStrategy::AdapterBased {
            let config = model.config();
            self.initialize_adapters(model, config.dimensions)?;
        }

        // Add training triples to model
        for triple in &train_data {
            model.add_triple(triple.clone())?;
        }

        let start_time = std::time::Instant::now();
        let mut training_loss_history = Vec::new();
        let mut validation_loss_history = Vec::new();
        let mut best_val_loss = f64::INFINITY;
        let mut patience_counter = 0;
        let mut early_stopped = false;

        // Training loop
        for epoch in 0..self.config.max_epochs {
            // Train for one epoch
            let stats = model.train(Some(1)).await?;
            let train_loss = stats.final_loss;
            training_loss_history.push(train_loss);

            // Validate
            let val_loss = self.validate(model, &val_data)?;
            validation_loss_history.push(val_loss);

            debug!(
                "Epoch {}/{}: train_loss={:.6}, val_loss={:.6}",
                epoch + 1,
                self.config.max_epochs,
                train_loss,
                val_loss
            );

            // Early stopping check
            if val_loss < best_val_loss - self.config.min_improvement {
                best_val_loss = val_loss;
                patience_counter = 0;
                info!("New best validation loss: {:.6}", best_val_loss);
            } else {
                patience_counter += 1;
                if patience_counter >= self.config.early_stopping_patience {
                    warn!(
                        "Early stopping triggered at epoch {} (patience={})",
                        epoch + 1,
                        self.config.early_stopping_patience
                    );
                    early_stopped = true;
                    break;
                }
            }
        }

        let training_time = start_time.elapsed().as_secs_f64();

        // Count updated parameters
        let num_parameters_updated = self.count_updated_parameters(model)?;

        info!(
            "Fine-tuning complete: {} epochs, {:.2}s, {} parameters updated",
            training_loss_history.len(),
            training_time,
            num_parameters_updated
        );

        Ok(FineTuningResult {
            epochs_completed: training_loss_history.len(),
            final_training_loss: *training_loss_history.last().unwrap_or(&0.0),
            final_validation_loss: *validation_loss_history.last().unwrap_or(&0.0),
            training_time_seconds: training_time,
            early_stopped,
            best_validation_loss: best_val_loss,
            training_loss_history,
            validation_loss_history,
            num_parameters_updated,
        })
    }

    /// Split data into training and validation sets
    fn split_data(&self, data: &[Triple]) -> Result<(Vec<Triple>, Vec<Triple>)> {
        let val_size = (data.len() as f32 * self.config.validation_split) as usize;
        let train_size = data.len() - val_size;

        if val_size == 0 {
            warn!("Validation set is empty, using full data for training");
            return Ok((data.to_vec(), Vec::new()));
        }

        let mut indices: Vec<usize> = (0..data.len()).collect();
        let mut rng = Random::default();

        // Shuffle indices
        for i in (1..indices.len()).rev() {
            let j = rng.random_range(0..i + 1);
            indices.swap(i, j);
        }

        let train_data: Vec<Triple> = indices[..train_size]
            .iter()
            .map(|&i| data[i].clone())
            .collect();

        let val_data: Vec<Triple> = indices[train_size..]
            .iter()
            .map(|&i| data[i].clone())
            .collect();

        Ok((train_data, val_data))
    }

    /// Validate the model on validation data
    fn validate<M: EmbeddingModel>(&self, model: &M, val_data: &[Triple]) -> Result<f64> {
        if val_data.is_empty() {
            return Ok(0.0);
        }

        let total_loss: f64 = val_data
            .par_iter()
            .filter_map(|triple| {
                model
                    .score_triple(
                        &triple.subject.iri,
                        &triple.predicate.iri,
                        &triple.object.iri,
                    )
                    .ok()
            })
            .map(|score| {
                // Margin-based loss (higher score is better, so negative for minimization)
                -score
            })
            .sum();

        Ok(total_loss / val_data.len() as f64)
    }

    /// Count the number of parameters that would be updated
    fn count_updated_parameters<M: EmbeddingModel>(&self, model: &M) -> Result<usize> {
        let stats = model.get_stats();
        let embed_dim = stats.dimensions;

        match self.config.strategy {
            FineTuningStrategy::FullFineTuning => {
                Ok((stats.num_entities + stats.num_relations) * embed_dim)
            }
            FineTuningStrategy::FreezeEntities => Ok(stats.num_relations * embed_dim),
            FineTuningStrategy::FreezeRelations => Ok(stats.num_entities * embed_dim),
            FineTuningStrategy::PartialDimensions => {
                let partial_dim = (embed_dim as f32 * self.config.partial_dimensions_pct) as usize;
                Ok((stats.num_entities + stats.num_relations) * partial_dim)
            }
            FineTuningStrategy::AdapterBased => {
                let adapter_params =
                    2 * embed_dim * self.config.adapter_dim + embed_dim + self.config.adapter_dim;
                Ok((stats.num_entities + stats.num_relations) * adapter_params)
            }
            FineTuningStrategy::Discriminative => {
                // All parameters but with different learning rates
                Ok((stats.num_entities + stats.num_relations) * embed_dim)
            }
        }
    }

    /// Get fine-tuning statistics
    pub fn get_stats(&self) -> FineTuningStats {
        FineTuningStats {
            num_pretrained_entities: self.pretrained_entities.len(),
            num_pretrained_relations: self.pretrained_relations.len(),
            num_entity_adapters: self.entity_adapters.len(),
            num_relation_adapters: self.relation_adapters.len(),
            strategy: self.config.strategy,
        }
    }
}

/// Fine-tuning statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FineTuningStats {
    pub num_pretrained_entities: usize,
    pub num_pretrained_relations: usize,
    pub num_entity_adapters: usize,
    pub num_relation_adapters: usize,
    pub strategy: FineTuningStrategy,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;

    #[test]
    fn test_fine_tuning_config_default() {
        let config = FineTuningConfig::default();
        assert_eq!(config.strategy, FineTuningStrategy::FullFineTuning);
        assert!(config.learning_rate < 0.01); // Should be lower than pre-training
        assert_eq!(config.max_epochs, 50);
    }

    #[test]
    fn test_adapter_layer_creation() {
        let adapter = AdapterLayer::new(128, 32);
        assert_eq!(adapter.down_projection.len(), 32);
        assert_eq!(adapter.up_projection.len(), 128);
        assert_eq!(adapter.down_bias.len(), 32);
        assert_eq!(adapter.up_bias.len(), 128);
    }

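    // Illustrative consistency check (added as a sketch): the number of values stored
    // in a single adapter should match the per-adapter term used by
    // count_updated_parameters: 2 * embed_dim * adapter_dim + embed_dim + adapter_dim.
    #[test]
    fn test_adapter_parameter_count_matches_formula() {
        let (embed_dim, adapter_dim) = (128, 32);
        let adapter = AdapterLayer::new(embed_dim, adapter_dim);

        let actual: usize = adapter
            .down_projection
            .iter()
            .map(|row| row.len())
            .sum::<usize>()
            + adapter
                .up_projection
                .iter()
                .map(|row| row.len())
                .sum::<usize>()
            + adapter.down_bias.len()
            + adapter.up_bias.len();
        let expected = 2 * embed_dim * adapter_dim + embed_dim + adapter_dim;

        assert_eq!(actual, expected);
    }
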
    #[test]
    fn test_adapter_forward_pass() {
        let adapter = AdapterLayer::new(128, 32);
        let input = Array1::from_vec(vec![1.0; 128]);
        let output = adapter.forward(&input);
        assert_eq!(output.len(), 128);
        // Output should be different from input due to adapter transformation
    }

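    // Illustrative hand-worked example (added as a sketch): with small fixed weights,
    // forward() should compute up @ relu(down @ input + down_bias) + up_bias + input.
    // The values below are hand-picked for easy arithmetic, not meaningful weights.
    #[test]
    fn test_adapter_forward_hand_computed() {
        let adapter = AdapterLayer {
            down_projection: vec![vec![1.0, -1.0]], // 1 x 2: embed_dim = 2 -> adapter_dim = 1
            up_projection: vec![vec![0.5], vec![0.25]], // 2 x 1: adapter_dim = 1 -> embed_dim = 2
            down_bias: vec![0.0],
            up_bias: vec![0.0, 0.0],
        };

        let input = Array1::from_vec(vec![2.0, 1.0]);
        let output = adapter.forward(&input);

        // hidden = relu(1.0 * 2.0 + (-1.0) * 1.0) = 1.0
        // output[0] = 0.5 * 1.0 + 2.0 (residual) = 2.5
        // output[1] = 0.25 * 1.0 + 1.0 (residual) = 1.25
        assert!((output[0] - 2.5).abs() < 1e-6);
        assert!((output[1] - 1.25).abs() < 1e-6);
    }
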
    #[test]
    fn test_fine_tuning_manager_creation() {
        let config = FineTuningConfig::default();
        let manager = FineTuningManager::new(config);
        let stats = manager.get_stats();
        assert_eq!(stats.num_pretrained_entities, 0);
        assert_eq!(stats.strategy, FineTuningStrategy::FullFineTuning);
    }

    #[test]
    fn test_split_data() {
        let config = FineTuningConfig {
            validation_split: 0.2,
            ..Default::default()
        };
        let manager = FineTuningManager::new(config);

        let triples: Vec<Triple> = (0..100)
            .map(|i| Triple {
                subject: NamedNode {
                    iri: format!("s{}", i),
                },
                predicate: NamedNode {
                    iri: format!("p{}", i),
                },
                object: NamedNode {
                    iri: format!("o{}", i),
                },
            })
            .collect();

        let (train, val) = manager.split_data(&triples).unwrap();
        assert_eq!(train.len(), 80);
        assert_eq!(val.len(), 20);
    }
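
    // Illustrative edge-case sketch (added as an example): with validation_split = 0.0
    // the computed validation size is 0, so split_data falls back to using every
    // triple for training and returns an empty validation set.
    #[test]
    fn test_split_data_no_validation() {
        let config = FineTuningConfig {
            validation_split: 0.0,
            ..Default::default()
        };
        let manager = FineTuningManager::new(config);

        let triples: Vec<Triple> = (0..10)
            .map(|i| Triple {
                subject: NamedNode {
                    iri: format!("s{}", i),
                },
                predicate: NamedNode {
                    iri: format!("p{}", i),
                },
                object: NamedNode {
                    iri: format!("o{}", i),
                },
            })
            .collect();

        let (train, val) = manager.split_data(&triples).unwrap();
        assert_eq!(train.len(), 10);
        assert!(val.is_empty());
    }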
}