oxirs_embed/models/transformer/
training.rs1use super::types::{ModelWeights, TransformerConfig, TransformerTrainingStats};
4use crate::Triple;
5use anyhow::Result;
6use scirs2_core::ndarray_ext::{Array1, Zip};
7use scirs2_core::random::{Random, Rng};
8use std::collections::HashMap;
9
/// Trainer for transformer-based knowledge-graph embeddings.
///
/// Owns the entity/relation embedding tables, the IRI-to-index vocabularies,
/// the (optional) transformer model weights, and the running training
/// statistics that the training loop updates in place.
#[derive(Debug)]
pub struct TransformerTrainer {
    /// Training hyperparameters (embedding dimensions, learning rate, ...).
    config: TransformerConfig,
    /// Entity IRI -> embedding vector.
    entity_embeddings: HashMap<String, Array1<f32>>,
    /// Relation IRI -> embedding vector.
    relation_embeddings: HashMap<String, Array1<f32>>,
    /// Entity IRI -> vocabulary index, assigned during embedding initialization.
    entity_to_idx: HashMap<String, usize>,
    /// Relation IRI -> vocabulary index.
    relation_to_idx: HashMap<String, usize>,
    /// Transformer weights; `None` until `initialize_weights` is called.
    model_weights: Option<ModelWeights>,
    /// Statistics (losses, gradient norm, epoch/batch counters) updated during training.
    training_stats: TransformerTrainingStats,
}
21
22impl TransformerTrainer {
23 pub fn new(config: TransformerConfig) -> Self {
24 Self {
25 config,
26 entity_embeddings: HashMap::new(),
27 relation_embeddings: HashMap::new(),
28 entity_to_idx: HashMap::new(),
29 relation_to_idx: HashMap::new(),
30 model_weights: None,
31 training_stats: TransformerTrainingStats::default(),
32 }
33 }
34
35 pub fn initialize_weights(&mut self, vocab_size: usize, hidden_size: usize) -> Result<()> {
37 self.model_weights = Some(ModelWeights::new(vocab_size, hidden_size));
38 Ok(())
39 }
40
41 pub async fn train(&mut self, triples: &[Triple], epochs: usize) -> Result<()> {
43 self.initialize_embeddings(triples)?;
45 let mut random = Random::default();
46
47 for epoch in 0..epochs {
48 self.training_stats.epoch = epoch;
49
50 let mut shuffled_triples = triples.to_vec();
52 for i in (1..shuffled_triples.len()).rev() {
54 let j = random.random_range(0..i + 1);
55 shuffled_triples.swap(i, j);
56 }
57
58 let batch_size = 32;
60 let batches = crate::models::common::create_batch_refs(&shuffled_triples, batch_size);
61
62 for (batch_idx, batch) in batches.enumerate() {
63 self.training_stats.batch_processed = batch_idx;
64
65 for triple in batch {
67 self.process_triple(triple).await?;
68 }
69
70 self.contrastive_learning(5).await?;
72
73 self.update_training_stats()?;
75 }
76
77 self.apply_regularization()?;
79 }
80
81 Ok(())
82 }
83
84 fn initialize_embeddings(&mut self, triples: &[Triple]) -> Result<()> {
86 let dimensions = self.config.base_config.dimensions;
87 let mut random = Random::default();
88
89 let mut entities = std::collections::HashSet::new();
91 let mut relations = std::collections::HashSet::new();
92
93 for triple in triples {
94 entities.insert(triple.subject.iri.clone());
95 entities.insert(triple.object.iri.clone());
96 relations.insert(triple.predicate.iri.clone());
97 }
98
99 let entities_vec: Vec<&String> = entities.iter().collect();
101 self.entity_embeddings.reserve(entities_vec.len());
102 self.entity_to_idx.reserve(entities_vec.len());
103
104 for (idx, entity) in entities_vec.iter().enumerate() {
105 let mut values = Vec::with_capacity(dimensions);
106 for _ in 0..dimensions {
107 values.push((random.random::<f64>() * 0.2 - 0.1) as f32);
108 }
109 let embedding = Array1::from_vec(values);
110 self.entity_embeddings.insert((*entity).clone(), embedding);
111 self.entity_to_idx.insert((*entity).clone(), idx);
112 }
113
114 let relations_vec: Vec<&String> = relations.iter().collect();
116 self.relation_embeddings.reserve(relations_vec.len());
117 self.relation_to_idx.reserve(relations_vec.len());
118
119 for (idx, relation) in relations_vec.iter().enumerate() {
120 let mut values = Vec::with_capacity(dimensions);
121 for _ in 0..dimensions {
122 values.push((random.random::<f64>() * 0.2 - 0.1) as f32);
123 }
124 let embedding = Array1::from_vec(values);
125 self.relation_embeddings
126 .insert((*relation).clone(), embedding);
127 self.relation_to_idx.insert((*relation).clone(), idx);
128 }
129
130 Ok(())
131 }
132
133 async fn process_triple(&mut self, triple: &Triple) -> Result<()> {
135 let subject_key = &triple.subject.iri;
136 let predicate_key = &triple.predicate.iri;
137 let object_key = &triple.object.iri;
138
139 let subject_emb = self.entity_embeddings.get(subject_key).cloned();
141 let predicate_emb = self.relation_embeddings.get(predicate_key).cloned();
142 let object_emb = self.entity_embeddings.get(object_key).cloned();
143
144 if let (Some(s_emb), Some(p_emb), Some(o_emb)) = (subject_emb, predicate_emb, object_emb) {
145 let predicted = &s_emb + &p_emb;
147 let diff = &predicted - &o_emb;
148 let loss = diff.mapv(|x| x * x).sum();
149
150 let learning_rate = self.config.base_config.learning_rate as f32;
152 self.apply_gradient_updates(&s_emb, &p_emb, &o_emb, &diff, learning_rate)?;
153
154 self.training_stats.reconstruction_loss = loss;
156 }
157
158 Ok(())
159 }
160
161 fn apply_gradient_updates(
163 &mut self,
164 _subject_emb: &Array1<f32>,
165 _predicate_emb: &Array1<f32>,
166 _object_emb: &Array1<f32>,
167 diff: &Array1<f32>,
168 learning_rate: f32,
169 ) -> Result<()> {
170 let subject_gradient = diff * 2.0;
172
173 let predicate_gradient = diff * 2.0;
175
176 let object_gradient = diff * -2.0;
178
179 let gradient_norm = subject_gradient.mapv(|x| x * x).sum().sqrt()
185 + predicate_gradient.mapv(|x| x * x).sum().sqrt()
186 + object_gradient.mapv(|x| x * x).sum().sqrt();
187
188 self.training_stats.gradient_norm = gradient_norm;
189 self.training_stats.learning_rate = learning_rate;
190
191 Ok(())
192 }
193
194 pub async fn contrastive_learning(&mut self, negative_samples: usize) -> Result<()> {
196 let temperature = 0.07;
197 let learning_rate = self.config.base_config.learning_rate as f32 * 0.5;
198 let mut random = Random::default();
199
200 let entity_keys: Vec<String> = self.entity_embeddings.keys().cloned().collect();
202
203 if entity_keys.len() < 2 {
204 return Ok(()); }
206
207 for (i, entity1) in entity_keys.iter().enumerate() {
209 for entity2 in entity_keys.iter().skip(i + 1) {
210 if let (Some(emb1), Some(emb2)) = (
211 self.entity_embeddings.get(entity1).cloned(),
212 self.entity_embeddings.get(entity2).cloned(),
213 ) {
214 let norm1 = emb1.mapv(|x| x * x).sum().sqrt();
216 let norm2 = emb2.mapv(|x| x * x).sum().sqrt();
217
218 if norm1 > 0.0 && norm2 > 0.0 {
219 let norm_factor = norm1 * norm2;
220
221 let positive_score = (&emb1 * &emb2).sum() / (norm_factor * temperature);
223
224 let mut negative_scores = Vec::new();
226 for _ in 0..negative_samples {
227 let neg_idx = random.random_range(0..entity_keys.len());
228 let neg_entity = &entity_keys[neg_idx];
229 {
230 if neg_entity != entity1 && neg_entity != entity2 {
231 if let Some(neg_emb) = self.entity_embeddings.get(neg_entity) {
232 let neg_norm = neg_emb.mapv(|x| x * x).sum().sqrt();
233 if neg_norm > 0.0 {
234 let neg_norm_factor = norm1 * neg_norm;
235 let neg_score = (&emb1 * neg_emb).sum()
236 / (neg_norm_factor * temperature);
237 negative_scores.push(neg_score);
238 }
239 }
240 }
241 }
242 }
243
244 if !negative_scores.is_empty() {
246 let max_neg_score = negative_scores
247 .iter()
248 .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
249 let loss_gradient = positive_score - max_neg_score;
250
251 let gradient_factor = if loss_gradient.abs() < 0.001 {
253 0.01 } else {
255 (loss_gradient / (1.0 + loss_gradient.abs())).clamp(-0.1, 0.1)
256 };
257
258 let update_factor = learning_rate * gradient_factor;
260
261 if let Some(embedding1) = self.entity_embeddings.get_mut(entity1) {
263 Zip::from(embedding1).and(&emb2).for_each(|e1, &e2| {
264 *e1 += e2 * update_factor;
265 });
266 }
267
268 if let Some(embedding2) = self.entity_embeddings.get_mut(entity2) {
269 Zip::from(embedding2).and(&emb1).for_each(|e2, &e1| {
270 *e2 += e1 * update_factor;
271 });
272 }
273
274 self.training_stats.contrastive_loss = loss_gradient.abs();
276 }
277 }
278 }
279 }
280 }
281
282 Ok(())
283 }
284
285 fn apply_regularization(&mut self) -> Result<()> {
287 let reg_strength = 0.01;
288 let mut total_reg_loss = 0.0;
289
290 for (_, embedding) in self.entity_embeddings.iter_mut() {
292 let reg_loss = embedding.mapv(|x| x * x).sum() * reg_strength;
293 total_reg_loss += reg_loss;
294
295 *embedding = embedding.mapv(|x| x * (1.0 - reg_strength));
297 }
298
299 for (_, embedding) in self.relation_embeddings.iter_mut() {
301 let reg_loss = embedding.mapv(|x| x * x).sum() * reg_strength;
302 total_reg_loss += reg_loss;
303
304 *embedding = embedding.mapv(|x| x * (1.0 - reg_strength));
306 }
307
308 self.training_stats.regularization_loss = total_reg_loss;
309 Ok(())
310 }
311
312 fn update_training_stats(&mut self) -> Result<()> {
314 let mut entity_norm_sum = 0.0;
316 let mut entity_count = 0;
317
318 for embedding in self.entity_embeddings.values() {
319 entity_norm_sum += embedding.mapv(|x| x * x).sum().sqrt();
320 entity_count += 1;
321 }
322
323 if entity_count > 0 {
324 let _avg_entity_norm = entity_norm_sum / entity_count as f32;
325 }
327
328 Ok(())
329 }
330
331 pub fn get_training_stats(&self) -> &TransformerTrainingStats {
333 &self.training_stats
334 }
335
336 pub fn get_entity_embeddings(&self) -> &HashMap<String, Array1<f32>> {
338 &self.entity_embeddings
339 }
340
341 pub fn get_relation_embeddings(&self) -> &HashMap<String, Array1<f32>> {
343 &self.relation_embeddings
344 }
345
346 pub fn set_entity_embedding(&mut self, entity: String, embedding: Array1<f32>) {
348 self.entity_embeddings.insert(entity, embedding);
349 }
350
351 pub fn setrelation_embedding(&mut self, relation: String, embedding: Array1<f32>) {
353 self.relation_embeddings.insert(relation, embedding);
354 }
355
356 pub fn is_trained(&self) -> bool {
358 !self.entity_embeddings.is_empty() && !self.relation_embeddings.is_empty()
359 }
360
361 pub fn reset(&mut self) {
363 self.entity_embeddings.clear();
364 self.relation_embeddings.clear();
365 self.entity_to_idx.clear();
366 self.relation_to_idx.clear();
367 self.model_weights = None;
368 self.training_stats = TransformerTrainingStats::default();
369 }
370
371 pub fn get_config(&self) -> &TransformerConfig {
373 &self.config
374 }
375
376 pub fn update_config(&mut self, config: TransformerConfig) {
378 self.config = config;
379 }
380}
381
/// Learning-rate schedule: linear warmup followed by a configurable decay.
///
/// Supported `schedule_type` values are "linear", "cosine", and "polynomial";
/// any other value yields the constant initial rate with no warmup. The decay
/// phase is measured against a fixed horizon of 10 000 post-warmup steps.
#[derive(Debug, Clone)]
pub struct LearningRateScheduler {
    /// Peak learning rate reached at the end of warmup.
    initial_lr: f32,
    /// Name of the decay curve to apply after warmup.
    schedule_type: String,
    /// Number of steps over which the rate ramps from 0 to `initial_lr`.
    warmup_steps: usize,
    /// Steps taken so far; advanced by `step`, cleared by `reset`.
    current_step: usize,
}

impl LearningRateScheduler {
    /// Creates a scheduler positioned at step 0.
    pub fn new(initial_lr: f32, schedule_type: String, warmup_steps: usize) -> Self {
        Self {
            initial_lr,
            schedule_type,
            warmup_steps,
            current_step: 0,
        }
    }

    /// Returns the learning rate for the current step according to the
    /// configured schedule.
    pub fn get_learning_rate(&self) -> f32 {
        match self.schedule_type.as_str() {
            "linear" => self.linear_schedule(),
            "cosine" => self.cosine_schedule(),
            "polynomial" => self.polynomial_schedule(),
            // Unknown schedule names fall back to a constant rate.
            _ => self.initial_lr,
        }
    }

    /// Linearly ramped rate while still inside warmup, `None` afterwards.
    fn warmup_lr(&self) -> Option<f32> {
        if self.current_step < self.warmup_steps {
            Some(self.initial_lr * (self.current_step as f32 / self.warmup_steps as f32))
        } else {
            None
        }
    }

    /// Fraction of the fixed 10 000-step decay horizon completed.
    /// Only valid once `current_step >= warmup_steps`.
    fn decay_progress(&self) -> f32 {
        (self.current_step - self.warmup_steps) as f32 / 10000.0
    }

    /// Linear decay toward zero, floored at 10% of the peak rate.
    fn linear_schedule(&self) -> f32 {
        self.warmup_lr()
            .unwrap_or_else(|| self.initial_lr * (1.0 - self.decay_progress()).max(0.1))
    }

    /// Cosine annealing from the peak rate toward zero.
    fn cosine_schedule(&self) -> f32 {
        self.warmup_lr().unwrap_or_else(|| {
            self.initial_lr * 0.5 * (1.0 + (std::f32::consts::PI * self.decay_progress()).cos())
        })
    }

    /// Quadratic polynomial decay, floored at 10% of the peak rate.
    fn polynomial_schedule(&self) -> f32 {
        self.warmup_lr()
            .unwrap_or_else(|| self.initial_lr * (1.0 - self.decay_progress()).powf(2.0).max(0.1))
    }

    /// Advances the schedule by one step.
    pub fn step(&mut self) {
        self.current_step += 1;
    }

    /// Rewinds the schedule to step 0.
    pub fn reset(&mut self) {
        self.current_step = 0;
    }
}
445
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed trainer accepts weight initialization and
    /// reports itself as untrained.
    #[tokio::test]
    async fn test_trainer_initialization() {
        let mut trainer = TransformerTrainer::new(TransformerConfig::default());

        assert!(trainer.initialize_weights(1000, 768).is_ok());
        assert!(!trainer.is_trained());
    }

    /// Contrastive learning runs cleanly once at least two entities exist.
    #[tokio::test]
    async fn test_contrastive_learning() {
        let mut trainer = TransformerTrainer::new(TransformerConfig::default());

        trainer.set_entity_embedding("entity1".to_string(), Array1::from_vec(vec![1.0, 0.0, 0.0]));
        trainer.set_entity_embedding("entity2".to_string(), Array1::from_vec(vec![0.0, 1.0, 0.0]));

        assert!(trainer.contrastive_learning(3).await.is_ok());
    }

    /// Linear schedule: zero at step 0, ramps during warmup, and reaches the
    /// configured rate once warmup completes.
    #[test]
    fn test_learning_rate_scheduler() {
        let mut scheduler = LearningRateScheduler::new(0.001, "linear".to_string(), 100);

        assert_eq!(scheduler.get_learning_rate(), 0.0);

        scheduler.step();
        let warmup_lr = scheduler.get_learning_rate();
        assert!(warmup_lr > 0.0 && warmup_lr < 0.001);

        scheduler.current_step = 100;
        assert_eq!(scheduler.get_learning_rate(), 0.001);
    }
}