1use crate::{EmbeddingError, EmbeddingModel, ModelConfig, Vector};
14use anyhow::Result;
15use async_trait::async_trait;
16use scirs2_core::ndarray_ext::{s, Array1, Array2, Axis};
17use scirs2_core::random::Random;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use uuid::Uuid;
21
/// Hyper-parameters for the DDPM-style diffusion embedding model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiffusionConfig {
    /// Length of the forward/reverse diffusion chain.
    pub num_timesteps: usize,
    /// How the noise variances `beta_t` are spaced over the chain.
    pub beta_schedule: BetaSchedule,
    /// Beta at the start of the chain (unused by the cosine schedule).
    pub beta_start: f64,
    /// Beta at the end of the chain (unused by the cosine schedule).
    pub beta_end: f64,
    /// Width of the generated embedding vectors.
    pub embedding_dim: usize,
    /// Width of the UNet's internal hidden features.
    pub hidden_dim: usize,
    /// Attention heads in the UNet's middle block.
    pub num_heads: usize,
    /// Number of down (and up) ResNet blocks in the UNet.
    pub num_layers: usize,
    /// Optimizer step size.
    // NOTE(review): not read by the visible (placeholder) training loop.
    pub learning_rate: f64,
    /// Enable classifier-free guidance (CFG) during sampling.
    pub use_cfg: bool,
    /// Guidance strength used when `use_cfg` is set.
    pub cfg_scale: f64,
    /// How conditioning information is injected into the UNet.
    pub conditioning: ConditioningType,
    /// Quantity the denoiser output is interpreted as predicting.
    pub prediction_type: PredictionType,
}
52
53impl Default for DiffusionConfig {
54 fn default() -> Self {
55 Self {
56 num_timesteps: 1000,
57 beta_schedule: BetaSchedule::Linear,
58 beta_start: 0.0001,
59 beta_end: 0.02,
60 embedding_dim: 512,
61 hidden_dim: 1024,
62 num_heads: 8,
63 num_layers: 6,
64 learning_rate: 1e-4,
65 use_cfg: true,
66 cfg_scale: 7.5,
67 conditioning: ConditioningType::CrossAttention,
68 prediction_type: PredictionType::Epsilon,
69 }
70 }
71}
72
/// Spacing of the noise variances `beta_t` across the diffusion chain.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BetaSchedule {
    /// Evenly spaced from `beta_start` to `beta_end`.
    Linear,
    /// Cosine alpha-bar schedule (the beta bounds are not used).
    Cosine,
    /// Sigmoid-shaped ramp scaled into `[beta_start, beta_end]`.
    Sigmoid,
    /// Geometric interpolation from `beta_start` to `beta_end`.
    Exponential,
}
81
/// Mechanism used to inject the conditioning signal into the UNet.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConditioningType {
    /// Attend from hidden states to the condition; add the attended result.
    CrossAttention,
    /// Adaptive layer norm: the condition supplies per-feature scale/shift.
    AdaLN,
    /// Feature-wise linear modulation: `gamma * h + beta`.
    FiLM,
    /// Concatenate the condition onto the feature axis.
    // NOTE(review): this widens `h`, which the fixed-size up blocks do not
    // account for — confirm before using Concat with the current UNet.
    Concat,
}
94
/// Quantity the denoising network is interpreted as predicting.
// NOTE(review): only epsilon-prediction is actually implemented
// (`NoiseScheduler::extract_x0`); `Sample` and `Velocity` are never consulted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PredictionType {
    /// Predict the added noise epsilon.
    Epsilon,
    /// Predict the clean sample x_0 directly.
    Sample,
    /// Predict the v-parameterization (velocity).
    Velocity,
}
105
/// Precomputed per-timestep diffusion constants (standard DDPM notation).
#[derive(Debug, Clone)]
pub struct NoiseScheduler {
    /// Noise variances beta_t.
    pub betas: Array1<f64>,
    /// alpha_t = 1 - beta_t.
    pub alphas: Array1<f64>,
    /// Cumulative products alpha-bar_t.
    pub alphas_cumprod: Array1<f64>,
    /// alpha-bar_{t-1}, with alpha-bar at t = -1 defined as 1.
    pub alphas_cumprod_prev: Array1<f64>,
    /// sqrt(alpha-bar_t): coefficient of x_0 in the forward process.
    pub sqrt_alphas_cumprod: Array1<f64>,
    /// sqrt(1 - alpha-bar_t): coefficient of the injected noise.
    pub sqrt_one_minus_alphas_cumprod: Array1<f64>,
    /// ln(1 - alpha-bar_t).
    pub log_one_minus_alphas_cumprod: Array1<f64>,
    /// sqrt(1 / alpha-bar_t), used when recovering x_0 from x_t.
    pub sqrt_recip_alphas_cumprod: Array1<f64>,
    /// sqrt(1 / alpha-bar_t - 1), used when recovering x_0 from x_t.
    pub sqrt_recipm1_alphas_cumprod: Array1<f64>,
    /// Variance of the posterior q(x_{t-1} | x_t, x_0).
    pub posterior_variance: Array1<f64>,
    /// ln of the posterior variance, floored to avoid ln(0).
    pub posterior_log_variance: Array1<f64>,
    /// Posterior-mean coefficient multiplying x_0.
    pub posterior_mean_coef1: Array1<f64>,
    /// Posterior-mean coefficient multiplying x_t.
    pub posterior_mean_coef2: Array1<f64>,
}
123
124impl NoiseScheduler {
125 pub fn new(config: &DiffusionConfig) -> Self {
127 let betas = Self::get_beta_schedule(
128 config.beta_schedule.clone(),
129 config.num_timesteps,
130 config.beta_start,
131 config.beta_end,
132 );
133
134 let alphas = betas.mapv(|b| 1.0 - b);
135 let alphas_cumprod = Self::cumprod(&alphas);
136
137 let mut alphas_cumprod_prev = Array1::zeros(config.num_timesteps);
138 alphas_cumprod_prev[0] = 1.0;
139 for i in 1..config.num_timesteps {
140 alphas_cumprod_prev[i] = alphas_cumprod[i - 1];
141 }
142
143 let sqrt_alphas_cumprod = alphas_cumprod.mapv(|x| x.sqrt());
144 let sqrt_one_minus_alphas_cumprod = alphas_cumprod.mapv(|x| (1.0 - x).sqrt());
145 let log_one_minus_alphas_cumprod = alphas_cumprod.mapv(|x| (1.0 - x).ln());
146 let sqrt_recip_alphas_cumprod = alphas_cumprod.mapv(|x| x.recip().sqrt());
147 let sqrt_recipm1_alphas_cumprod = alphas_cumprod.mapv(|x| (x.recip() - 1.0).sqrt());
148
149 let posterior_variance = Array1::from_iter((0..config.num_timesteps).map(|i| {
151 if i == 0 {
152 0.0
153 } else {
154 betas[i] * (1.0 - alphas_cumprod_prev[i]) / (1.0 - alphas_cumprod[i])
155 }
156 }));
157
158 let posterior_log_variance = posterior_variance.mapv(|x| x.max(1e-20).ln());
159
160 let posterior_mean_coef1 = Array1::from_iter(
161 (0..config.num_timesteps)
162 .map(|i| betas[i] * alphas_cumprod_prev[i].sqrt() / (1.0 - alphas_cumprod[i])),
163 );
164
165 let posterior_mean_coef2 = Array1::from_iter((0..config.num_timesteps).map(|i| {
166 (1.0 - alphas_cumprod_prev[i]) * alphas[i].sqrt() / (1.0 - alphas_cumprod[i])
167 }));
168
169 Self {
170 betas,
171 alphas,
172 alphas_cumprod,
173 alphas_cumprod_prev,
174 sqrt_alphas_cumprod,
175 sqrt_one_minus_alphas_cumprod,
176 log_one_minus_alphas_cumprod,
177 sqrt_recip_alphas_cumprod,
178 sqrt_recipm1_alphas_cumprod,
179 posterior_variance,
180 posterior_log_variance,
181 posterior_mean_coef1,
182 posterior_mean_coef2,
183 }
184 }
185
186 fn get_beta_schedule(
188 schedule: BetaSchedule,
189 num_timesteps: usize,
190 beta_start: f64,
191 beta_end: f64,
192 ) -> Array1<f64> {
193 match schedule {
194 BetaSchedule::Linear => Array1::linspace(beta_start, beta_end, num_timesteps),
195 BetaSchedule::Cosine => {
196 let steps = Array1::linspace(0.0, 1.0, num_timesteps + 1);
197 let alpha_bar = steps.mapv(|s| (s * std::f64::consts::PI / 2.0).cos().powi(2));
198
199 let mut betas = Array1::zeros(num_timesteps);
200 for i in 0..num_timesteps {
201 betas[i] = 1.0 - alpha_bar[i + 1] / alpha_bar[i];
202 betas[i] = betas[i].min(0.999);
203 }
204 betas
205 }
206 BetaSchedule::Sigmoid => {
207 let betas = Array1::linspace(-6.0, 6.0, num_timesteps);
208 let sigmoid_betas = betas.mapv(|x: f64| 1.0_f64 / (1.0_f64 + (-x).exp()));
209 sigmoid_betas * (beta_end - beta_start)
210 + Array1::from_elem(num_timesteps, beta_start)
211 }
212 BetaSchedule::Exponential => {
213 let betas = Array1::linspace(0.0, 1.0, num_timesteps);
214 betas.mapv(|x| beta_start * (beta_end / beta_start).powf(x))
215 }
216 }
217 }
218
219 fn cumprod(array: &Array1<f64>) -> Array1<f64> {
221 let mut result = Array1::zeros(array.len());
222 result[0] = array[0];
223 for i in 1..array.len() {
224 result[i] = result[i - 1] * array[i];
225 }
226 result
227 }
228
229 pub fn add_noise(
231 &self,
232 x_start: &Array2<f64>,
233 noise: &Array2<f64>,
234 timestep: usize,
235 ) -> Array2<f64> {
236 let sqrt_alpha_prod = self.sqrt_alphas_cumprod[timestep];
237 let sqrt_one_minus_alpha_prod = self.sqrt_one_minus_alphas_cumprod[timestep];
238
239 x_start * sqrt_alpha_prod + noise * sqrt_one_minus_alpha_prod
240 }
241
242 pub fn step(
244 &self,
245 model_output: &Array2<f64>,
246 timestep: usize,
247 sample: &Array2<f64>,
248 generator: &mut Random,
249 ) -> Array2<f64> {
250 let t = timestep;
251
252 let pred_original_sample = match self.extract_x0(model_output, sample, t) {
254 Ok(x0) => x0,
255 Err(_) => sample.clone(),
256 };
257
258 let pred_prev_sample = self.get_prev_sample(&pred_original_sample, sample, t);
260
261 if t > 0 {
263 let variance = self.posterior_variance[t].sqrt();
264 let noise = self.sample_noise(sample.dim(), generator);
265 pred_prev_sample + noise * variance
266 } else {
267 pred_prev_sample
268 }
269 }
270
271 fn extract_x0(
273 &self,
274 model_output: &Array2<f64>,
275 sample: &Array2<f64>,
276 t: usize,
277 ) -> Result<Array2<f64>> {
278 let sqrt_recip_alphas_cumprod = self.sqrt_recip_alphas_cumprod[t];
279 let sqrt_recipm1_alphas_cumprod = self.sqrt_recipm1_alphas_cumprod[t];
280
281 Ok(sample * sqrt_recip_alphas_cumprod - model_output * sqrt_recipm1_alphas_cumprod)
282 }
283
284 fn get_prev_sample(
286 &self,
287 pred_x0: &Array2<f64>,
288 sample: &Array2<f64>,
289 t: usize,
290 ) -> Array2<f64> {
291 let coef1 = self.posterior_mean_coef1[t];
292 let coef2 = self.posterior_mean_coef2[t];
293
294 pred_x0 * coef1 + sample * coef2
295 }
296
297 fn sample_noise(&self, shape: (usize, usize), generator: &mut Random) -> Array2<f64> {
299 let mut samples = Vec::with_capacity(shape.0 * shape.1);
301 for _ in 0..(shape.0 * shape.1) {
302 let u1 = generator.random_f64();
303 let u2 = generator.random_f64();
304 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
305 samples.push(z0);
306 }
307 Array2::from_shape_vec(shape, samples).expect("shape should match sample count")
308 }
309}
310
311#[derive(Debug, Clone)]
313pub struct DiffusionUNet {
314 config: DiffusionConfig,
315 time_embedding: TimeEmbedding,
317 down_blocks: Vec<ResNetBlock>,
319 middle_block: AttentionBlock,
321 up_blocks: Vec<ResNetBlock>,
323}
324
impl DiffusionUNet {
    /// Builds the encoder/bottleneck/decoder stacks from `config`.
    /// Up blocks take `hidden_dim * 2` inputs because `forward` concatenates
    /// a skip connection before each one.
    pub fn new(config: DiffusionConfig) -> Self {
        let time_embedding = TimeEmbedding::new(config.hidden_dim);

        // Encoder: first block maps embedding_dim -> hidden_dim, the rest
        // stay at hidden_dim.
        let mut down_blocks = Vec::new();
        for i in 0..config.num_layers {
            if i == 0 {
                down_blocks.push(ResNetBlock::new(config.embedding_dim, config.hidden_dim));
            } else {
                down_blocks.push(ResNetBlock::new(config.hidden_dim, config.hidden_dim));
            }
        }

        let middle_block = AttentionBlock::new(config.hidden_dim, config.num_heads);

        // Decoder: every block sees [h, skip] (2 * hidden_dim); the last one
        // projects back down to embedding_dim.
        let mut up_blocks = Vec::new();
        for i in 0..config.num_layers {
            if i == config.num_layers - 1 {
                up_blocks.push(ResNetBlock::new(
                    config.hidden_dim * 2,
                    config.embedding_dim,
                ));
            } else {
                up_blocks.push(ResNetBlock::new(config.hidden_dim * 2, config.hidden_dim));
            }
        }

        Self {
            config,
            time_embedding,
            down_blocks,
            middle_block,
            up_blocks,
        }
    }

    /// Denoising pass: encoder (collecting skips) -> attention bottleneck ->
    /// optional conditioning -> decoder (consuming skips in LIFO order).
    ///
    /// `x` is (batch, embedding_dim); the returned array has the same shape.
    // NOTE(review): with ConditioningType::Concat the conditioning widens `h`
    // beyond hidden_dim, which the fixed-size up-block weights do not expect —
    // confirm Concat is exercised anywhere before relying on it.
    pub fn forward(
        &self,
        x: &Array2<f64>,
        timestep: usize,
        condition: Option<&Array2<f64>>,
    ) -> Result<Array2<f64>> {
        let time_emb = self.time_embedding.forward(timestep)?;

        let mut h = x.clone();
        let mut skip_connections = Vec::new();

        // Encoder: save each block's output as a skip connection.
        for block in &self.down_blocks {
            h = block.forward(&h, &time_emb)?;
            skip_connections.push(h.clone());
        }

        h = self.middle_block.forward(&h)?;

        // Conditioning is injected once, after the bottleneck.
        if let Some(cond) = condition {
            h = self.apply_conditioning(&h, cond)?;
        }

        // Decoder: pop skips newest-first and concatenate before each block.
        for block in self.up_blocks.iter() {
            if let Some(skip) = skip_connections.pop() {
                h = self.concatenate(&h, &skip)?;
            }
            h = block.forward(&h, &time_emb)?;
        }

        Ok(h)
    }

    /// Dispatches to the conditioning mechanism selected in the config.
    fn apply_conditioning(&self, h: &Array2<f64>, condition: &Array2<f64>) -> Result<Array2<f64>> {
        match self.config.conditioning {
            ConditioningType::CrossAttention => {
                self.cross_attention(h, condition)
            }
            ConditioningType::AdaLN => {
                self.adaptive_layer_norm(h, condition)
            }
            ConditioningType::FiLM => {
                self.film_conditioning(h, condition)
            }
            ConditioningType::Concat => {
                self.concatenate(h, condition)
            }
        }
    }

    /// Single-head cross-attention from `h` to `condition`, added residually.
    /// A 1-row condition is broadcast across the batch.
    fn cross_attention(&self, h: &Array2<f64>, condition: &Array2<f64>) -> Result<Array2<f64>> {
        let (batch_h, _feat_h) = h.dim();
        let (batch_cond, feat_cond) = condition.dim();

        // Broadcast a singleton condition to the batch size of h.
        let expanded_condition = if batch_cond == 1 && batch_h > 1 {
            let mut expanded = Array2::zeros((batch_h, feat_cond));
            for i in 0..batch_h {
                expanded.row_mut(i).assign(&condition.row(0));
            }
            expanded
        } else {
            condition.clone()
        };

        // Scores are (batch_h, batch_cond); no learned Q/K/V projections here.
        let attention_weights = h.dot(&expanded_condition.t());
        let softmax_weights = self.softmax(&attention_weights)?;
        let attended = softmax_weights.dot(&expanded_condition);
        Ok(h + &attended)
    }

    /// AdaLN: layer-normalize `h`, then apply condition-derived scale/shift.
    fn adaptive_layer_norm(&self, h: &Array2<f64>, condition: &Array2<f64>) -> Result<Array2<f64>> {
        let (scale, shift) = self.extract_scale_shift(condition)?;

        let normalized = self.layer_norm(h)?;

        Ok(&normalized * &scale + &shift)
    }

    /// FiLM: elementwise `gamma * h + beta` from the split condition.
    fn film_conditioning(&self, h: &Array2<f64>, condition: &Array2<f64>) -> Result<Array2<f64>> {
        let (gamma, beta) = self.extract_film_params(condition)?;
        Ok(h * &gamma + &beta)
    }

    /// Concatenates `a` and `b` along the feature axis; batch sizes must match.
    fn concatenate(&self, a: &Array2<f64>, b: &Array2<f64>) -> Result<Array2<f64>> {
        let (batch_a, feat_a) = a.dim();
        let (batch_b, feat_b) = b.dim();

        if batch_a != batch_b {
            return Err(anyhow::anyhow!("Batch sizes don't match"));
        }

        let mut result = Array2::zeros((batch_a, feat_a + feat_b));
        result.slice_mut(s![.., ..feat_a]).assign(a);
        result.slice_mut(s![.., feat_a..]).assign(b);

        Ok(result)
    }

    /// Row-wise softmax with the max-subtraction trick for stability.
    fn softmax(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let max_vals = x.map_axis(Axis(1), |row| row.fold(f64::NEG_INFINITY, |a, &b| a.max(b)));
        let shifted = x - &max_vals.insert_axis(Axis(1));
        let exp_vals = shifted.mapv(|x| x.exp());
        let sum_exp = exp_vals.sum_axis(Axis(1));
        Ok(&exp_vals / &sum_exp.insert_axis(Axis(1)))
    }

    /// Row-wise layer norm (zero mean, unit variance, eps = 1e-5), no
    /// learned affine parameters.
    fn layer_norm(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let mean = x
            .mean_axis(Axis(1))
            .expect("mean_axis should succeed for non-empty array");
        let centered = x - &mean.insert_axis(Axis(1));
        let var = centered
            .mapv(|x| x.powi(2))
            .mean_axis(Axis(1))
            .expect("mean_axis should succeed for non-empty array");
        let std = var.mapv(|x| (x + 1e-5).sqrt());
        Ok(&centered / &std.insert_axis(Axis(1)))
    }

    /// Splits the condition in half along features into (scale, shift).
    // NOTE(review): assumes condition.ncols() is even and each half matches
    // h's width — neither is validated here; confirm at the call sites.
    fn extract_scale_shift(&self, condition: &Array2<f64>) -> Result<(Array2<f64>, Array2<f64>)> {
        let feat_dim = condition.ncols() / 2;
        let scale = condition.slice(s![.., ..feat_dim]).to_owned();
        let shift = condition.slice(s![.., feat_dim..]).to_owned();
        Ok((scale, shift))
    }

    /// Splits the condition in half along features into (gamma, beta).
    // NOTE(review): same unvalidated shape assumptions as extract_scale_shift.
    fn extract_film_params(&self, condition: &Array2<f64>) -> Result<(Array2<f64>, Array2<f64>)> {
        let feat_dim = condition.ncols() / 2;
        let gamma = condition.slice(s![.., ..feat_dim]).to_owned();
        let beta = condition.slice(s![.., feat_dim..]).to_owned();
        Ok((gamma, beta))
    }
}
529
530#[derive(Debug, Clone)]
532pub struct TimeEmbedding {
533 embedding_dim: usize,
534 weights: Array2<f64>,
535}
536
537impl TimeEmbedding {
538 pub fn new(embedding_dim: usize) -> Self {
539 let weights = Array2::zeros((1000, embedding_dim)); Self {
541 embedding_dim,
542 weights,
543 }
544 }
545
546 pub fn forward(&self, timestep: usize) -> Result<Array1<f64>> {
547 if timestep >= self.weights.nrows() {
548 return Err(anyhow::anyhow!("Timestep out of range"));
549 }
550
551 let mut embedding = Array1::zeros(self.embedding_dim);
553 for i in 0..self.embedding_dim {
554 let dim_factor = (i as f64) / (self.embedding_dim as f64);
555 let freq = 1.0 / 10000_f64.powf(dim_factor);
556
557 if i % 2 == 0 {
558 embedding[i] = (timestep as f64 * freq).sin();
559 } else {
560 embedding[i] = (timestep as f64 * freq).cos();
561 }
562 }
563
564 Ok(embedding)
565 }
566}
567
/// Two-layer residual MLP block with time-embedding injection.
#[derive(Debug, Clone)]
pub struct ResNetBlock {
    /// Input feature width.
    input_dim: usize,
    /// Output feature width.
    output_dim: usize,
    /// First linear layer (input_dim x output_dim), zero-initialized.
    weights1: Array2<f64>,
    /// Second linear layer (output_dim x output_dim), zero-initialized.
    weights2: Array2<f64>,
    /// Linear projection for the residual branch; present only when
    /// input_dim != output_dim (identity otherwise).
    skip_weights: Option<Array2<f64>>,
}
577
578impl ResNetBlock {
579 pub fn new(input_dim: usize, output_dim: usize) -> Self {
580 let weights1 = Array2::zeros((input_dim, output_dim));
581 let weights2 = Array2::zeros((output_dim, output_dim));
582 let skip_weights = if input_dim != output_dim {
583 Some(Array2::zeros((input_dim, output_dim)))
584 } else {
585 None
586 };
587
588 Self {
589 input_dim,
590 output_dim,
591 weights1,
592 weights2,
593 skip_weights,
594 }
595 }
596
597 pub fn forward(&self, x: &Array2<f64>, time_emb: &Array1<f64>) -> Result<Array2<f64>> {
598 let h1 = x.dot(&self.weights1);
600 let h1_activated = h1.mapv(|x| x.max(0.0)); let time_proj =
604 Array2::from_shape_fn((h1_activated.nrows(), h1_activated.ncols()), |(_i, j)| {
605 let time_idx = j % time_emb.len();
607 time_emb[time_idx]
608 });
609 let h1_time = &h1_activated + &time_proj;
610
611 let h2 = h1_time.dot(&self.weights2);
613
614 let skip = if let Some(ref skip_w) = self.skip_weights {
616 x.dot(skip_w)
617 } else {
618 x.clone()
619 };
620
621 Ok(&h2 + &skip)
622 }
623}
624
/// Self-attention layer with a fused QKV projection and residual output.
#[derive(Debug, Clone)]
pub struct AttentionBlock {
    /// Feature width of the inputs and outputs.
    dim: usize,
    /// Nominal head count; only used to derive `head_dim`.
    num_heads: usize,
    /// dim / num_heads, used solely to scale the attention scores.
    head_dim: usize,
    /// Fused query/key/value projection (dim x 3*dim), zero-initialized.
    qkv_weights: Array2<f64>,
    /// Output projection (dim x dim), zero-initialized.
    output_weights: Array2<f64>,
}
634
635impl AttentionBlock {
636 pub fn new(dim: usize, num_heads: usize) -> Self {
637 let head_dim = dim / num_heads;
638 let qkv_weights = Array2::zeros((dim, dim * 3));
639 let output_weights = Array2::zeros((dim, dim));
640
641 Self {
642 dim,
643 num_heads,
644 head_dim,
645 qkv_weights,
646 output_weights,
647 }
648 }
649
650 pub fn forward(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
651 let (_batch_size, _seq_len) = x.dim();
652
653 let qkv = x.dot(&self.qkv_weights);
655 let q = qkv.slice(s![.., ..self.dim]).to_owned();
656 let k = qkv.slice(s![.., self.dim..self.dim * 2]).to_owned();
657 let v = qkv.slice(s![.., self.dim * 2..]).to_owned();
658
659 let attention_scores = q.dot(&k.t()) / (self.head_dim as f64).sqrt();
661 let attention_weights = self.softmax(&attention_scores)?;
662 let attended = attention_weights.dot(&v);
663
664 let output = attended.dot(&self.output_weights);
666
667 Ok(&output + x)
669 }
670
671 fn softmax(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
672 let max_vals = x.map_axis(Axis(1), |row| row.fold(f64::NEG_INFINITY, |a, &b| a.max(b)));
673 let shifted = x - &max_vals.insert_axis(Axis(1));
674 let exp_vals = shifted.mapv(|x| x.exp());
675 let sum_exp = exp_vals.sum_axis(Axis(1));
676 Ok(&exp_vals / &sum_exp.insert_axis(Axis(1)))
677 }
678}
679
/// Knowledge-graph embedding model that samples entity/relation vectors
/// from a diffusion process instead of optimizing them directly.
#[derive(Debug, Clone)]
pub struct DiffusionEmbeddingModel {
    /// Unique model instance id.
    id: Uuid,
    /// Generic embedding-model configuration (dimensions, epochs, ...).
    config: ModelConfig,
    /// Diffusion-specific hyper-parameters.
    diffusion_config: DiffusionConfig,
    /// Precomputed noise schedule for forward/reverse diffusion.
    scheduler: NoiseScheduler,
    /// Denoising network.
    unet: DiffusionUNet,
    /// Entity IRI -> row index into `entity_embeddings`.
    entities: HashMap<String, usize>,
    /// Relation IRI -> row index into `relation_embeddings`.
    relations: HashMap<String, usize>,
    /// One row per entity id; placeholder (1 x dims) until trained.
    entity_embeddings: Array2<f64>,
    /// One row per relation id; placeholder (1 x dims) until trained.
    relation_embeddings: Array2<f64>,
    /// Set by `train`; embedding lookups fail until then.
    is_trained: bool,
    /// Bookkeeping counters exposed via `get_stats`.
    stats: crate::ModelStats,
}
695
696impl DiffusionEmbeddingModel {
697 pub fn new(config: ModelConfig, diffusion_config: DiffusionConfig) -> Self {
699 let scheduler = NoiseScheduler::new(&diffusion_config);
700 let unet = DiffusionUNet::new(diffusion_config.clone());
701
702 Self {
703 id: Uuid::new_v4(),
704 config: config.clone(),
705 diffusion_config,
706 scheduler,
707 unet,
708 entities: HashMap::new(),
709 relations: HashMap::new(),
710 entity_embeddings: Array2::zeros((1, config.dimensions)),
711 relation_embeddings: Array2::zeros((1, config.dimensions)),
712 is_trained: false,
713 stats: crate::ModelStats {
714 model_type: "DiffusionEmbedding".to_string(),
715 dimensions: config.dimensions,
716 creation_time: chrono::Utc::now(),
717 ..Default::default()
718 },
719 }
720 }
721
722 pub fn generate_embeddings(
724 &self,
725 condition: Option<&Array2<f64>>,
726 num_samples: usize,
727 guidance_scale: f64,
728 ) -> Result<Array2<f64>> {
729 let mut rng = Random::default();
730
731 let shape = (num_samples, self.diffusion_config.embedding_dim);
733 let mut x = self.scheduler.sample_noise(shape, &mut rng);
734
735 for t in (0..self.diffusion_config.num_timesteps).rev() {
737 let noise_pred = self.unet.forward(&x, t, condition)?;
739
740 let noise_pred = if self.diffusion_config.use_cfg && condition.is_some() {
742 let uncond_noise_pred = self.unet.forward(&x, t, None)?;
743 &uncond_noise_pred + (&noise_pred - &uncond_noise_pred) * guidance_scale
744 } else {
745 noise_pred
746 };
747
748 x = self.scheduler.step(&noise_pred, t, &x, &mut rng);
750 }
751
752 Ok(x)
753 }
754
755 pub fn generate_conditional_embeddings(
757 &self,
758 entity_types: &[String],
759 relation_types: &[String],
760 ) -> Result<(Array2<f64>, Array2<f64>)> {
761 let entity_condition = self.create_type_conditioning(entity_types)?;
763 let relation_condition = self.create_type_conditioning(relation_types)?;
764
765 let entity_embeddings = self.generate_embeddings(
767 Some(&entity_condition),
768 entity_types.len(),
769 self.diffusion_config.cfg_scale,
770 )?;
771
772 let relation_embeddings = self.generate_embeddings(
773 Some(&relation_condition),
774 relation_types.len(),
775 self.diffusion_config.cfg_scale,
776 )?;
777
778 Ok((entity_embeddings, relation_embeddings))
779 }
780
781 fn create_type_conditioning(&self, types: &[String]) -> Result<Array2<f64>> {
783 let condition_dim = self.diffusion_config.hidden_dim;
784 let mut conditioning = Array2::zeros((types.len(), condition_dim));
785
786 for (i, type_name) in types.iter().enumerate() {
788 let hash = self.hash_string(type_name);
789 for j in 0..condition_dim {
790 conditioning[[i, j]] = ((hash + j) as f64 % 1000.0) / 1000.0;
791 }
792 }
793
794 Ok(conditioning)
795 }
796
797 fn hash_string(&self, s: &str) -> usize {
799 s.bytes().map(|b| b as usize).sum()
800 }
801
802 pub fn interpolate_embeddings(
804 &self,
805 embedding1: &Array2<f64>,
806 embedding2: &Array2<f64>,
807 alpha: f64,
808 ) -> Result<Array2<f64>> {
809 if embedding1.dim() != embedding2.dim() {
810 return Err(anyhow::anyhow!("Embedding dimensions don't match"));
811 }
812
813 Ok(embedding1 * (1.0 - alpha) + embedding2 * alpha)
814 }
815
816 pub fn edit_embedding(
818 &self,
819 original: &Array2<f64>,
820 edit_direction: &Array2<f64>,
821 strength: f64,
822 ) -> Result<Array2<f64>> {
823 let edited = original + edit_direction * strength;
825
826 let norm = edited
828 .mapv(|x| x.powi(2))
829 .sum_axis(Axis(1))
830 .mapv(|x| x.sqrt());
831 let normalized = &edited / &norm.insert_axis(Axis(1));
832
833 Ok(normalized)
834 }
835}
836
837#[async_trait]
838impl EmbeddingModel for DiffusionEmbeddingModel {
839 fn config(&self) -> &ModelConfig {
840 &self.config
841 }
842
843 fn model_id(&self) -> &Uuid {
844 &self.id
845 }
846
847 fn model_type(&self) -> &'static str {
848 "DiffusionEmbedding"
849 }
850
851 fn add_triple(&mut self, triple: crate::Triple) -> Result<()> {
852 let subj_id = self.entities.len();
853 let pred_id = self.relations.len();
854 let obj_id = self.entities.len() + 1;
855
856 self.entities.entry(triple.subject.iri).or_insert(subj_id);
857 self.relations
858 .entry(triple.predicate.iri)
859 .or_insert(pred_id);
860 self.entities.entry(triple.object.iri).or_insert(obj_id);
861
862 self.stats.num_triples += 1;
863 self.stats.num_entities = self.entities.len();
864 self.stats.num_relations = self.relations.len();
865
866 Ok(())
867 }
868
869 async fn train(&mut self, epochs: Option<usize>) -> Result<crate::TrainingStats> {
870 let max_epochs = epochs.unwrap_or(self.config.max_epochs);
871 let mut loss_history = Vec::new();
872 let start_time = std::time::Instant::now();
873
874 if !self.entities.is_empty() && !self.relations.is_empty() {
876 let entity_types: Vec<String> = self.entities.keys().cloned().collect();
877 let relation_types: Vec<String> = self.relations.keys().cloned().collect();
878
879 let (entity_embs, relation_embs) =
880 self.generate_conditional_embeddings(&entity_types, &relation_types)?;
881
882 self.entity_embeddings = entity_embs.mapv(|x| x as f32).mapv(|x| x as f64);
884 self.relation_embeddings = relation_embs.mapv(|x| x as f32).mapv(|x| x as f64);
885 }
886
887 for epoch in 0..max_epochs {
889 let loss = 1.0 / (epoch as f64 + 1.0); loss_history.push(loss);
891
892 if loss < 0.01 {
893 break;
894 }
895 }
896
897 self.is_trained = true;
898 self.stats.is_trained = true;
899 self.stats.last_training_time = Some(chrono::Utc::now());
900
901 let training_time = start_time.elapsed().as_secs_f64();
902
903 Ok(crate::TrainingStats {
904 epochs_completed: max_epochs,
905 final_loss: loss_history.last().copied().unwrap_or(1.0),
906 training_time_seconds: training_time,
907 convergence_achieved: true,
908 loss_history,
909 })
910 }
911
912 fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
913 if !self.is_trained {
914 return Err(EmbeddingError::ModelNotTrained.into());
915 }
916
917 let entity_idx =
918 self.entities
919 .get(entity)
920 .ok_or_else(|| EmbeddingError::EntityNotFound {
921 entity: entity.to_string(),
922 })?;
923
924 let embedding = self.entity_embeddings.row(*entity_idx);
925 Ok(Vector::new(embedding.mapv(|x| x as f32).to_vec()))
926 }
927
928 fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
929 if !self.is_trained {
930 return Err(EmbeddingError::ModelNotTrained.into());
931 }
932
933 let relation_idx =
934 self.relations
935 .get(relation)
936 .ok_or_else(|| EmbeddingError::RelationNotFound {
937 relation: relation.to_string(),
938 })?;
939
940 let embedding = self.relation_embeddings.row(*relation_idx);
941 Ok(Vector::new(embedding.mapv(|x| x as f32).to_vec()))
942 }
943
944 fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
945 let s_emb = self.get_entity_embedding(subject)?;
946 let p_emb = self.get_relation_embedding(predicate)?;
947 let o_emb = self.get_entity_embedding(object)?;
948
949 let score = s_emb
951 .values
952 .iter()
953 .zip(p_emb.values.iter())
954 .zip(o_emb.values.iter())
955 .map(|((&s, &p), &o)| (s * p * o) as f64)
956 .sum::<f64>();
957
958 Ok(score)
959 }
960
961 fn predict_objects(
962 &self,
963 subject: &str,
964 predicate: &str,
965 k: usize,
966 ) -> Result<Vec<(String, f64)>> {
967 let mut predictions = Vec::new();
968
969 for entity in self.entities.keys() {
970 if let Ok(score) = self.score_triple(subject, predicate, entity) {
971 predictions.push((entity.clone(), score));
972 }
973 }
974
975 predictions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
976 predictions.truncate(k);
977
978 Ok(predictions)
979 }
980
981 fn predict_subjects(
982 &self,
983 predicate: &str,
984 object: &str,
985 k: usize,
986 ) -> Result<Vec<(String, f64)>> {
987 let mut predictions = Vec::new();
988
989 for entity in self.entities.keys() {
990 if let Ok(score) = self.score_triple(entity, predicate, object) {
991 predictions.push((entity.clone(), score));
992 }
993 }
994
995 predictions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
996 predictions.truncate(k);
997
998 Ok(predictions)
999 }
1000
1001 fn predict_relations(
1002 &self,
1003 subject: &str,
1004 object: &str,
1005 k: usize,
1006 ) -> Result<Vec<(String, f64)>> {
1007 let mut predictions = Vec::new();
1008
1009 for relation in self.relations.keys() {
1010 if let Ok(score) = self.score_triple(subject, relation, object) {
1011 predictions.push((relation.clone(), score));
1012 }
1013 }
1014
1015 predictions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1016 predictions.truncate(k);
1017
1018 Ok(predictions)
1019 }
1020
1021 fn get_entities(&self) -> Vec<String> {
1022 self.entities.keys().cloned().collect()
1023 }
1024
1025 fn get_relations(&self) -> Vec<String> {
1026 self.relations.keys().cloned().collect()
1027 }
1028
1029 fn get_stats(&self) -> crate::ModelStats {
1030 self.stats.clone()
1031 }
1032
1033 fn save(&self, _path: &str) -> Result<()> {
1034 Ok(())
1035 }
1036
1037 fn load(&mut self, _path: &str) -> Result<()> {
1038 Ok(())
1039 }
1040
1041 fn clear(&mut self) {
1042 self.entities.clear();
1043 self.relations.clear();
1044 self.is_trained = false;
1045 self.stats = crate::ModelStats::default();
1046 }
1047
1048 fn is_trained(&self) -> bool {
1049 self.is_trained
1050 }
1051
1052 async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
1053 let mut encoded = Vec::new();
1055
1056 for text in texts {
1057 let condition = self.create_type_conditioning(std::slice::from_ref(text))?;
1059
1060 let embedding =
1062 self.generate_embeddings(Some(&condition), 1, self.diffusion_config.cfg_scale)?;
1063
1064 let emb_vec = embedding.row(0).mapv(|x| x as f32).to_vec();
1065 encoded.push(emb_vec);
1066 }
1067
1068 Ok(encoded)
1069 }
1070}
1071
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults should match the documented DDPM-style configuration.
    #[test]
    fn test_diffusion_config() {
        let config = DiffusionConfig::default();
        assert_eq!(config.num_timesteps, 1000);
        assert_eq!(config.embedding_dim, 512);
        assert!(config.use_cfg);
    }

    /// Scheduler arrays must cover every timestep, and a linear schedule
    /// must be increasing.
    #[test]
    fn test_noise_scheduler() {
        let config = DiffusionConfig::default();
        let scheduler = NoiseScheduler::new(&config);

        assert_eq!(scheduler.betas.len(), config.num_timesteps);
        assert_eq!(scheduler.alphas.len(), config.num_timesteps);
        assert!(scheduler.betas[0] < scheduler.betas[config.num_timesteps - 1]);
    }

    /// The sinusoidal embedder must emit vectors of the requested width.
    #[test]
    fn test_time_embedding() {
        let embedder = TimeEmbedding::new(128);
        let embedding = embedder.forward(100).unwrap();
        assert_eq!(embedding.len(), 128);
    }

    /// Adding one triple registers two entities and one relation.
    #[tokio::test]
    async fn test_diffusion_embedding_model() {
        let mut model =
            DiffusionEmbeddingModel::new(ModelConfig::default(), DiffusionConfig::default());

        let triple = crate::Triple::new(
            crate::NamedNode::new("http://example.org/alice").unwrap(),
            crate::NamedNode::new("http://example.org/knows").unwrap(),
            crate::NamedNode::new("http://example.org/bob").unwrap(),
        );
        model.add_triple(triple).unwrap();

        assert_eq!(model.get_entities().len(), 2);
        assert_eq!(model.get_relations().len(), 1);
    }

    /// Every schedule variant must yield one beta per timestep.
    #[test]
    fn test_beta_schedules() {
        let linear = NoiseScheduler::get_beta_schedule(BetaSchedule::Linear, 10, 0.0001, 0.02);
        assert_eq!(linear.len(), 10);
        assert!(linear[0] < linear[9]);

        let cosine = NoiseScheduler::get_beta_schedule(BetaSchedule::Cosine, 10, 0.0001, 0.02);
        assert_eq!(cosine.len(), 10);
    }

    /// A small end-to-end reverse-diffusion run must produce the requested
    /// (num_samples, embedding_dim) matrix.
    #[test]
    fn test_diffusion_generation() {
        let diffusion_config = DiffusionConfig {
            num_timesteps: 10,
            embedding_dim: 64,
            hidden_dim: 128,
            num_layers: 2,
            use_cfg: false,
            ..DiffusionConfig::default()
        };
        let model = DiffusionEmbeddingModel::new(ModelConfig::default(), diffusion_config);

        let condition = Array2::zeros((1, 128));
        let embeddings = model.generate_embeddings(Some(&condition), 2, 7.5).unwrap();
        assert_eq!(embeddings.dim(), (2, 64));
    }
}