oxirs_embed/diffusion_embeddings.rs

//! Diffusion Model-Based Knowledge Graph Embeddings
//!
//! This module implements diffusion models for generating high-quality knowledge graph
//! embeddings, based on denoising diffusion probabilistic models (DDPMs) and
//! score-based generative models.
//!
//! Key features:
//! - Controllable embedding generation through conditioning
//! - High-quality embedding synthesis with noise scheduling
//! - Knowledge graph structure-aware diffusion processes
//! - Multi-scale embedding generation with hierarchical diffusion
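//!
//! # Example
//!
//! A minimal usage sketch (not compiled as a doctest; it assumes this module is exposed
//! as `oxirs_embed::diffusion_embeddings` and that `ModelConfig::default()` is compatible
//! with the diffusion defaults):
//!
//! ```ignore
//! use oxirs_embed::diffusion_embeddings::{DiffusionConfig, DiffusionEmbeddingModel};
//! use oxirs_embed::ModelConfig;
//!
//! let model = DiffusionEmbeddingModel::new(ModelConfig::default(), DiffusionConfig::default());
//! // Unconditionally sample 4 embeddings with the default guidance scale.
//! let embeddings = model.generate_embeddings(None, 4, 7.5).expect("sampling failed");
//! assert_eq!(embeddings.ncols(), DiffusionConfig::default().embedding_dim);
//! ```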

use crate::{EmbeddingError, EmbeddingModel, ModelConfig, Vector};
use anyhow::Result;
use async_trait::async_trait;
use scirs2_core::ndarray_ext::{s, Array1, Array2, Axis};
use scirs2_core::random::Random;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;

/// Configuration for diffusion-based embeddings
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiffusionConfig {
    /// Number of diffusion timesteps
    pub num_timesteps: usize,
    /// Beta schedule type
    pub beta_schedule: BetaSchedule,
    /// Beta start value
    pub beta_start: f64,
    /// Beta end value
    pub beta_end: f64,
    /// Embedding dimension
    pub embedding_dim: usize,
    /// Hidden dimension for U-Net
    pub hidden_dim: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Number of U-Net layers
    pub num_layers: usize,
    /// Learning rate for diffusion training
    pub learning_rate: f64,
    /// Use classifier-free guidance
    pub use_cfg: bool,
    /// Classifier-free guidance scale
    pub cfg_scale: f64,
    /// Conditioning mechanism
    pub conditioning: ConditioningType,
    /// Noise prediction method
    pub prediction_type: PredictionType,
}

impl Default for DiffusionConfig {
    fn default() -> Self {
        Self {
            num_timesteps: 1000,
            beta_schedule: BetaSchedule::Linear,
            beta_start: 0.0001,
            beta_end: 0.02,
            embedding_dim: 512,
            hidden_dim: 1024,
            num_heads: 8,
            num_layers: 6,
            learning_rate: 1e-4,
            use_cfg: true,
            cfg_scale: 7.5,
            conditioning: ConditioningType::CrossAttention,
            prediction_type: PredictionType::Epsilon,
        }
    }
}

/// Beta schedule types for noise scheduling
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BetaSchedule {
    Linear,
    Cosine,
    Sigmoid,
    Exponential,
}

/// Conditioning types for controlled generation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConditioningType {
    /// Cross-attention based conditioning
    CrossAttention,
    /// AdaLN (Adaptive Layer Normalization)
    AdaLN,
    /// FiLM (Feature-wise Linear Modulation)
    FiLM,
    /// Concatenation-based conditioning
    Concat,
}

/// Types of noise prediction
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PredictionType {
    /// Predict noise (epsilon)
    Epsilon,
    /// Predict denoised sample (x0)
    Sample,
    /// Predict velocity (v-parameterization)
    Velocity,
}

/// Noise scheduler for diffusion process
#[derive(Debug, Clone)]
pub struct NoiseScheduler {
    pub betas: Array1<f64>,
    pub alphas: Array1<f64>,
    pub alphas_cumprod: Array1<f64>,
    pub alphas_cumprod_prev: Array1<f64>,
    pub sqrt_alphas_cumprod: Array1<f64>,
    pub sqrt_one_minus_alphas_cumprod: Array1<f64>,
    pub log_one_minus_alphas_cumprod: Array1<f64>,
    pub sqrt_recip_alphas_cumprod: Array1<f64>,
    pub sqrt_recipm1_alphas_cumprod: Array1<f64>,
    pub posterior_variance: Array1<f64>,
    pub posterior_log_variance: Array1<f64>,
    pub posterior_mean_coef1: Array1<f64>,
    pub posterior_mean_coef2: Array1<f64>,
}

impl NoiseScheduler {
    /// Create a new noise scheduler
    pub fn new(config: &DiffusionConfig) -> Self {
        let betas = Self::get_beta_schedule(
            config.beta_schedule.clone(),
            config.num_timesteps,
            config.beta_start,
            config.beta_end,
        );

        let alphas = betas.mapv(|b| 1.0 - b);
        let alphas_cumprod = Self::cumprod(&alphas);

        let mut alphas_cumprod_prev = Array1::zeros(config.num_timesteps);
        alphas_cumprod_prev[0] = 1.0;
        for i in 1..config.num_timesteps {
            alphas_cumprod_prev[i] = alphas_cumprod[i - 1];
        }

        let sqrt_alphas_cumprod = alphas_cumprod.mapv(|x| x.sqrt());
        let sqrt_one_minus_alphas_cumprod = alphas_cumprod.mapv(|x| (1.0 - x).sqrt());
        let log_one_minus_alphas_cumprod = alphas_cumprod.mapv(|x| (1.0 - x).ln());
        let sqrt_recip_alphas_cumprod = alphas_cumprod.mapv(|x| x.recip().sqrt());
        let sqrt_recipm1_alphas_cumprod = alphas_cumprod.mapv(|x| (x.recip() - 1.0).sqrt());

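        // Standard DDPM posterior q(x_{t-1} | x_t, x_0) = N(mu_t, beta_tilde_t * I), where
        //   beta_tilde_t = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
        //   mu_t = coef1 * x_0 + coef2 * x_t  (coef1 and coef2 are computed below).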
        // Posterior variance
        let posterior_variance = Array1::from_iter((0..config.num_timesteps).map(|i| {
            if i == 0 {
                0.0
            } else {
                betas[i] * (1.0 - alphas_cumprod_prev[i]) / (1.0 - alphas_cumprod[i])
            }
        }));

        let posterior_log_variance = posterior_variance.mapv(|x| x.max(1e-20).ln());

        let posterior_mean_coef1 = Array1::from_iter(
            (0..config.num_timesteps)
                .map(|i| betas[i] * alphas_cumprod_prev[i].sqrt() / (1.0 - alphas_cumprod[i])),
        );

        let posterior_mean_coef2 = Array1::from_iter((0..config.num_timesteps).map(|i| {
            (1.0 - alphas_cumprod_prev[i]) * alphas[i].sqrt() / (1.0 - alphas_cumprod[i])
        }));

        Self {
            betas,
            alphas,
            alphas_cumprod,
            alphas_cumprod_prev,
            sqrt_alphas_cumprod,
            sqrt_one_minus_alphas_cumprod,
            log_one_minus_alphas_cumprod,
            sqrt_recip_alphas_cumprod,
            sqrt_recipm1_alphas_cumprod,
            posterior_variance,
            posterior_log_variance,
            posterior_mean_coef1,
            posterior_mean_coef2,
        }
    }

    /// Generate beta schedule
    fn get_beta_schedule(
        schedule: BetaSchedule,
        num_timesteps: usize,
        beta_start: f64,
        beta_end: f64,
    ) -> Array1<f64> {
        match schedule {
            BetaSchedule::Linear => Array1::linspace(beta_start, beta_end, num_timesteps),
            BetaSchedule::Cosine => {
                let steps = Array1::linspace(0.0, 1.0, num_timesteps + 1);
                let alpha_bar = steps.mapv(|s| (s * std::f64::consts::PI / 2.0).cos().powi(2));

                let mut betas = Array1::zeros(num_timesteps);
                for i in 0..num_timesteps {
                    betas[i] = 1.0 - alpha_bar[i + 1] / alpha_bar[i];
                    betas[i] = betas[i].min(0.999);
                }
                betas
            }
            BetaSchedule::Sigmoid => {
                let betas = Array1::linspace(-6.0, 6.0, num_timesteps);
                let sigmoid_betas = betas.mapv(|x: f64| 1.0_f64 / (1.0_f64 + (-x).exp()));
                sigmoid_betas * (beta_end - beta_start)
                    + Array1::from_elem(num_timesteps, beta_start)
            }
            BetaSchedule::Exponential => {
                let betas = Array1::linspace(0.0, 1.0, num_timesteps);
                betas.mapv(|x| beta_start * (beta_end / beta_start).powf(x))
            }
        }
    }

    /// Compute cumulative product
    fn cumprod(array: &Array1<f64>) -> Array1<f64> {
        let mut result = Array1::zeros(array.len());
        result[0] = array[0];
        for i in 1..array.len() {
            result[i] = result[i - 1] * array[i];
        }
        result
    }

    /// Add noise to sample at timestep t
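    ///
    /// Mirrors the standard DDPM closed-form forward process:
    /// `x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise`.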
    pub fn add_noise(
        &self,
        x_start: &Array2<f64>,
        noise: &Array2<f64>,
        timestep: usize,
    ) -> Array2<f64> {
        let sqrt_alpha_prod = self.sqrt_alphas_cumprod[timestep];
        let sqrt_one_minus_alpha_prod = self.sqrt_one_minus_alphas_cumprod[timestep];

        x_start * sqrt_alpha_prod + noise * sqrt_one_minus_alpha_prod
    }

    /// Sample previous timestep
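    ///
    /// One ancestral sampling step: estimate `x_0` from the model output, form the
    /// posterior mean, and add `sqrt(posterior_variance[t])`-scaled Gaussian noise
    /// whenever `t > 0`.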
    pub fn step(
        &self,
        model_output: &Array2<f64>,
        timestep: usize,
        sample: &Array2<f64>,
        generator: &mut Random,
    ) -> Array2<f64> {
        let t = timestep;

        // Compute predicted original sample
        let pred_original_sample = match self.extract_x0(model_output, sample, t) {
            Ok(x0) => x0,
            Err(_) => sample.clone(),
        };

        // Compute predicted previous sample
        let pred_prev_sample = self.get_prev_sample(&pred_original_sample, sample, t);

        // Add noise if not the last timestep
        if t > 0 {
            let variance = self.posterior_variance[t].sqrt();
            let noise = self.sample_noise(sample.dim(), generator);
            pred_prev_sample + noise * variance
        } else {
            pred_prev_sample
        }
    }

    /// Extract x0 from model output
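    ///
    /// Assumes epsilon prediction:
    /// `x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * eps`.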
    fn extract_x0(
        &self,
        model_output: &Array2<f64>,
        sample: &Array2<f64>,
        t: usize,
    ) -> Result<Array2<f64>> {
        let sqrt_recip_alphas_cumprod = self.sqrt_recip_alphas_cumprod[t];
        let sqrt_recipm1_alphas_cumprod = self.sqrt_recipm1_alphas_cumprod[t];

        Ok(sample * sqrt_recip_alphas_cumprod - model_output * sqrt_recipm1_alphas_cumprod)
    }

    /// Get previous sample
    fn get_prev_sample(
        &self,
        pred_x0: &Array2<f64>,
        sample: &Array2<f64>,
        t: usize,
    ) -> Array2<f64> {
        let coef1 = self.posterior_mean_coef1[t];
        let coef2 = self.posterior_mean_coef2[t];

        pred_x0 * coef1 + sample * coef2
    }

    /// Sample noise with given shape
    fn sample_noise(&self, shape: (usize, usize), generator: &mut Random) -> Array2<f64> {
        // Simple Box-Muller transform for normal distribution
        let mut samples = Vec::with_capacity(shape.0 * shape.1);
        for _ in 0..(shape.0 * shape.1) {
            let u1 = generator.random_f64();
            let u2 = generator.random_f64();
            let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
            samples.push(z0);
        }
        Array2::from_shape_vec(shape, samples).unwrap()
    }
}

/// U-Net model for diffusion denoising
#[derive(Debug, Clone)]
pub struct DiffusionUNet {
    config: DiffusionConfig,
    /// Time embedding layers
    time_embedding: TimeEmbedding,
    /// Down blocks
    down_blocks: Vec<ResNetBlock>,
    /// Middle block
    middle_block: AttentionBlock,
    /// Up blocks
    up_blocks: Vec<ResNetBlock>,
}

impl DiffusionUNet {
    /// Create new U-Net
    pub fn new(config: DiffusionConfig) -> Self {
        let time_embedding = TimeEmbedding::new(config.hidden_dim);

        // Create down blocks
        let mut down_blocks = Vec::new();
        for i in 0..config.num_layers {
            if i == 0 {
                // First block: embedding_dim -> hidden_dim
                down_blocks.push(ResNetBlock::new(config.embedding_dim, config.hidden_dim));
            } else {
                // Subsequent blocks: hidden_dim -> hidden_dim
                down_blocks.push(ResNetBlock::new(config.hidden_dim, config.hidden_dim));
            }
        }

        // Create middle block
        let middle_block = AttentionBlock::new(config.hidden_dim, config.num_heads);

        // Create up blocks
        let mut up_blocks = Vec::new();
        for i in 0..config.num_layers {
            if i == config.num_layers - 1 {
                // Last block: (hidden_dim + hidden_dim) -> embedding_dim (after skip connection concatenation)
                up_blocks.push(ResNetBlock::new(
                    config.hidden_dim * 2,
                    config.embedding_dim,
                ));
            } else {
                // Other blocks: (hidden_dim + hidden_dim) -> hidden_dim (after skip connection concatenation)
                up_blocks.push(ResNetBlock::new(config.hidden_dim * 2, config.hidden_dim));
            }
        }

        Self {
            config,
            time_embedding,
            down_blocks,
            middle_block,
            up_blocks,
        }
    }

    /// Forward pass
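    ///
    /// Returns the network's prediction for the current timestep (interpreted as the
    /// noise `epsilon` under the default `PredictionType::Epsilon`), with the same
    /// shape as `x`.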
    pub fn forward(
        &self,
        x: &Array2<f64>,
        timestep: usize,
        condition: Option<&Array2<f64>>,
    ) -> Result<Array2<f64>> {
        // Get time embedding
        let time_emb = self.time_embedding.forward(timestep)?;

        let mut h = x.clone();
        let mut skip_connections = Vec::new();

        // Down pass
        for block in &self.down_blocks {
            h = block.forward(&h, &time_emb)?;
            skip_connections.push(h.clone());
        }

        // Middle block
        h = self.middle_block.forward(&h)?;

        // Apply conditioning if provided
        if let Some(cond) = condition {
            h = self.apply_conditioning(&h, cond)?;
        }

        // Up pass
        for block in self.up_blocks.iter() {
            if let Some(skip) = skip_connections.pop() {
                // Concatenate skip connection
                h = self.concatenate(&h, &skip)?;
            }
            h = block.forward(&h, &time_emb)?;
        }

        // Output is already the correct dimension from the last up block
        Ok(h)
    }

    /// Apply conditioning
    fn apply_conditioning(&self, h: &Array2<f64>, condition: &Array2<f64>) -> Result<Array2<f64>> {
        match self.config.conditioning {
            ConditioningType::CrossAttention => {
                // Cross-attention implementation
                self.cross_attention(h, condition)
            }
            ConditioningType::AdaLN => {
                // AdaLN implementation
                self.adaptive_layer_norm(h, condition)
            }
            ConditioningType::FiLM => {
                // FiLM implementation
                self.film_conditioning(h, condition)
            }
            ConditioningType::Concat => {
                // Concatenation
                self.concatenate(h, condition)
            }
        }
    }

    /// Cross-attention conditioning
    fn cross_attention(&self, h: &Array2<f64>, condition: &Array2<f64>) -> Result<Array2<f64>> {
        let (batch_h, _feat_h) = h.dim();
        let (batch_cond, feat_cond) = condition.dim();

        // Expand condition to match batch size if needed
        let expanded_condition = if batch_cond == 1 && batch_h > 1 {
            let mut expanded = Array2::zeros((batch_h, feat_cond));
            for i in 0..batch_h {
                expanded.row_mut(i).assign(&condition.row(0));
            }
            expanded
        } else {
            condition.clone()
        };

        // Simplified cross-attention with proper dimensions
        let attention_weights = h.dot(&expanded_condition.t());
        let softmax_weights = self.softmax(&attention_weights)?;
        let attended = softmax_weights.dot(&expanded_condition);
        Ok(h + &attended)
    }

    /// Adaptive layer normalization
    fn adaptive_layer_norm(&self, h: &Array2<f64>, condition: &Array2<f64>) -> Result<Array2<f64>> {
        // Extract scale and shift from condition
        let (scale, shift) = self.extract_scale_shift(condition)?;

        // Layer normalization
        let normalized = self.layer_norm(h)?;

        // Apply adaptive parameters
        Ok(&normalized * &scale + &shift)
    }

    /// FiLM conditioning
    fn film_conditioning(&self, h: &Array2<f64>, condition: &Array2<f64>) -> Result<Array2<f64>> {
        // Feature-wise linear modulation
        let (gamma, beta) = self.extract_film_params(condition)?;
        Ok(h * &gamma + &beta)
    }

    /// Concatenate tensors
    fn concatenate(&self, a: &Array2<f64>, b: &Array2<f64>) -> Result<Array2<f64>> {
        // Simple concatenation along feature dimension
        let (batch_a, feat_a) = a.dim();
        let (batch_b, feat_b) = b.dim();

        if batch_a != batch_b {
            return Err(anyhow::anyhow!("Batch sizes don't match"));
        }

        let mut result = Array2::zeros((batch_a, feat_a + feat_b));
        result.slice_mut(s![.., ..feat_a]).assign(a);
        result.slice_mut(s![.., feat_a..]).assign(b);

        Ok(result)
    }

    /// Softmax function
    fn softmax(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let max_vals = x.map_axis(Axis(1), |row| row.fold(f64::NEG_INFINITY, |a, &b| a.max(b)));
        let shifted = x - &max_vals.insert_axis(Axis(1));
        let exp_vals = shifted.mapv(|x| x.exp());
        let sum_exp = exp_vals.sum_axis(Axis(1));
        Ok(&exp_vals / &sum_exp.insert_axis(Axis(1)))
    }

    /// Layer normalization
    fn layer_norm(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let mean = x.mean_axis(Axis(1)).unwrap();
        let centered = x - &mean.insert_axis(Axis(1));
        let var = centered.mapv(|x| x.powi(2)).mean_axis(Axis(1)).unwrap();
        let std = var.mapv(|x| (x + 1e-5).sqrt());
        Ok(&centered / &std.insert_axis(Axis(1)))
    }

    /// Extract scale and shift for AdaLN
    fn extract_scale_shift(&self, condition: &Array2<f64>) -> Result<(Array2<f64>, Array2<f64>)> {
        let feat_dim = condition.ncols() / 2;
        let scale = condition.slice(s![.., ..feat_dim]).to_owned();
        let shift = condition.slice(s![.., feat_dim..]).to_owned();
        Ok((scale, shift))
    }

    /// Extract FiLM parameters
    fn extract_film_params(&self, condition: &Array2<f64>) -> Result<(Array2<f64>, Array2<f64>)> {
        let feat_dim = condition.ncols() / 2;
        let gamma = condition.slice(s![.., ..feat_dim]).to_owned();
        let beta = condition.slice(s![.., feat_dim..]).to_owned();
        Ok((gamma, beta))
    }
}

/// Time embedding for diffusion timesteps
#[derive(Debug, Clone)]
pub struct TimeEmbedding {
    embedding_dim: usize,
    weights: Array2<f64>,
}

impl TimeEmbedding {
    pub fn new(embedding_dim: usize) -> Self {
        let weights = Array2::zeros((1000, embedding_dim)); // Max 1000 timesteps
        Self {
            embedding_dim,
            weights,
        }
    }

    pub fn forward(&self, timestep: usize) -> Result<Array1<f64>> {
        if timestep >= self.weights.nrows() {
            return Err(anyhow::anyhow!("Timestep out of range"));
        }

        // Sinusoidal position encoding
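        // Transformer-style encoding: even dimensions use sin(t * freq) and odd ones
        // use cos(t * freq), with freq = 1 / 10000^(i / dim) (a simplified variant of
        // the usual 2i / dim exponent).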
        let mut embedding = Array1::zeros(self.embedding_dim);
        for i in 0..self.embedding_dim {
            let dim_factor = (i as f64) / (self.embedding_dim as f64);
            let freq = 1.0 / 10000_f64.powf(dim_factor);

            if i % 2 == 0 {
                embedding[i] = (timestep as f64 * freq).sin();
            } else {
                embedding[i] = (timestep as f64 * freq).cos();
            }
        }

        Ok(embedding)
    }
}

/// ResNet block for U-Net
#[derive(Debug, Clone)]
pub struct ResNetBlock {
    input_dim: usize,
    output_dim: usize,
    weights1: Array2<f64>,
    weights2: Array2<f64>,
    skip_weights: Option<Array2<f64>>,
}

impl ResNetBlock {
    pub fn new(input_dim: usize, output_dim: usize) -> Self {
        let weights1 = Array2::zeros((input_dim, output_dim));
        let weights2 = Array2::zeros((output_dim, output_dim));
        let skip_weights = if input_dim != output_dim {
            Some(Array2::zeros((input_dim, output_dim)))
        } else {
            None
        };

        Self {
            input_dim,
            output_dim,
            weights1,
            weights2,
            skip_weights,
        }
    }

    pub fn forward(&self, x: &Array2<f64>, time_emb: &Array1<f64>) -> Result<Array2<f64>> {
        // First convolution
        let h1 = x.dot(&self.weights1);
        let h1_activated = h1.mapv(|x| x.max(0.0)); // ReLU

        // Add time embedding (project to match h1_activated dimensions)
        let time_proj =
            Array2::from_shape_fn((h1_activated.nrows(), h1_activated.ncols()), |(_i, j)| {
                // Simple projection: repeat or truncate time embedding to match dimensions
                let time_idx = j % time_emb.len();
                time_emb[time_idx]
            });
        let h1_time = &h1_activated + &time_proj;

        // Second convolution
        let h2 = h1_time.dot(&self.weights2);

        // Skip connection
        let skip = if let Some(ref skip_w) = self.skip_weights {
            x.dot(skip_w)
        } else {
            x.clone()
        };

        Ok(&h2 + &skip)
    }
}

/// Attention block
#[derive(Debug, Clone)]
pub struct AttentionBlock {
    dim: usize,
    num_heads: usize,
    head_dim: usize,
    qkv_weights: Array2<f64>,
    output_weights: Array2<f64>,
}

impl AttentionBlock {
    pub fn new(dim: usize, num_heads: usize) -> Self {
        let head_dim = dim / num_heads;
        let qkv_weights = Array2::zeros((dim, dim * 3));
        let output_weights = Array2::zeros((dim, dim));

        Self {
            dim,
            num_heads,
            head_dim,
            qkv_weights,
            output_weights,
        }
    }

    pub fn forward(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let (_batch_size, _seq_len) = x.dim();

        // Compute Q, K, V
        let qkv = x.dot(&self.qkv_weights);
        let q = qkv.slice(s![.., ..self.dim]).to_owned();
        let k = qkv.slice(s![.., self.dim..self.dim * 2]).to_owned();
        let v = qkv.slice(s![.., self.dim * 2..]).to_owned();

        // Compute attention
        let attention_scores = q.dot(&k.t()) / (self.head_dim as f64).sqrt();
        let attention_weights = self.softmax(&attention_scores)?;
        let attended = attention_weights.dot(&v);

        // Output projection
        let output = attended.dot(&self.output_weights);

        // Residual connection
        Ok(&output + x)
    }

    fn softmax(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let max_vals = x.map_axis(Axis(1), |row| row.fold(f64::NEG_INFINITY, |a, &b| a.max(b)));
        let shifted = x - &max_vals.insert_axis(Axis(1));
        let exp_vals = shifted.mapv(|x| x.exp());
        let sum_exp = exp_vals.sum_axis(Axis(1));
        Ok(&exp_vals / &sum_exp.insert_axis(Axis(1)))
    }
}

/// Main diffusion embedding model
#[derive(Debug, Clone)]
pub struct DiffusionEmbeddingModel {
    id: Uuid,
    config: ModelConfig,
    diffusion_config: DiffusionConfig,
    scheduler: NoiseScheduler,
    unet: DiffusionUNet,
    entities: HashMap<String, usize>,
    relations: HashMap<String, usize>,
    entity_embeddings: Array2<f64>,
    relation_embeddings: Array2<f64>,
    is_trained: bool,
    stats: crate::ModelStats,
}

impl DiffusionEmbeddingModel {
    /// Create new diffusion embedding model
    pub fn new(config: ModelConfig, diffusion_config: DiffusionConfig) -> Self {
        let scheduler = NoiseScheduler::new(&diffusion_config);
        let unet = DiffusionUNet::new(diffusion_config.clone());

        Self {
            id: Uuid::new_v4(),
            config: config.clone(),
            diffusion_config,
            scheduler,
            unet,
            entities: HashMap::new(),
            relations: HashMap::new(),
            entity_embeddings: Array2::zeros((1, config.dimensions)),
            relation_embeddings: Array2::zeros((1, config.dimensions)),
            is_trained: false,
            stats: crate::ModelStats {
                model_type: "DiffusionEmbedding".to_string(),
                dimensions: config.dimensions,
                creation_time: chrono::Utc::now(),
                ..Default::default()
            },
        }
    }

    /// Generate embeddings using diffusion sampling
    pub fn generate_embeddings(
        &self,
        condition: Option<&Array2<f64>>,
        num_samples: usize,
        guidance_scale: f64,
    ) -> Result<Array2<f64>> {
        let mut rng = Random::default();

        // Start with pure noise
        let shape = (num_samples, self.diffusion_config.embedding_dim);
        let mut x = self.scheduler.sample_noise(shape, &mut rng);

        // Denoising loop
        for t in (0..self.diffusion_config.num_timesteps).rev() {
            // Predict noise
            let noise_pred = self.unet.forward(&x, t, condition)?;

            // Apply classifier-free guidance if enabled
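            // Classifier-free guidance combines conditional and unconditional predictions:
            // eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond).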
            let noise_pred = if self.diffusion_config.use_cfg && condition.is_some() {
                let uncond_noise_pred = self.unet.forward(&x, t, None)?;
                &uncond_noise_pred + (&noise_pred - &uncond_noise_pred) * guidance_scale
            } else {
                noise_pred
            };

            // Denoise step
            x = self.scheduler.step(&noise_pred, t, &x, &mut rng);
        }

        Ok(x)
    }

    /// Generate conditional embeddings for specific entities/relations
    pub fn generate_conditional_embeddings(
        &self,
        entity_types: &[String],
        relation_types: &[String],
    ) -> Result<(Array2<f64>, Array2<f64>)> {
        // Create conditioning vectors
        let entity_condition = self.create_type_conditioning(entity_types)?;
        let relation_condition = self.create_type_conditioning(relation_types)?;

        // Generate embeddings
        let entity_embeddings = self.generate_embeddings(
            Some(&entity_condition),
            entity_types.len(),
            self.diffusion_config.cfg_scale,
        )?;

        let relation_embeddings = self.generate_embeddings(
            Some(&relation_condition),
            relation_types.len(),
            self.diffusion_config.cfg_scale,
        )?;

        Ok((entity_embeddings, relation_embeddings))
    }

    /// Create conditioning vectors for types
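    ///
    /// Note: this is a deterministic, hash-based placeholder rather than a learned
    /// type/text encoder; each type name maps to a fixed pseudo-random vector.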
    fn create_type_conditioning(&self, types: &[String]) -> Result<Array2<f64>> {
        let condition_dim = self.diffusion_config.hidden_dim;
        let mut conditioning = Array2::zeros((types.len(), condition_dim));

        // Simple hash-based conditioning
        for (i, type_name) in types.iter().enumerate() {
            let hash = self.hash_string(type_name);
            for j in 0..condition_dim {
                conditioning[[i, j]] = ((hash + j) as f64 % 1000.0) / 1000.0;
            }
        }

        Ok(conditioning)
    }

    /// Simple string hashing
    fn hash_string(&self, s: &str) -> usize {
        s.bytes().map(|b| b as usize).sum()
    }

    /// Interpolate between embeddings
    pub fn interpolate_embeddings(
        &self,
        embedding1: &Array2<f64>,
        embedding2: &Array2<f64>,
        alpha: f64,
    ) -> Result<Array2<f64>> {
        if embedding1.dim() != embedding2.dim() {
            return Err(anyhow::anyhow!("Embedding dimensions don't match"));
        }

        Ok(embedding1 * (1.0 - alpha) + embedding2 * alpha)
    }

    /// Edit an embedding by applying an edit direction and renormalizing
    pub fn edit_embedding(
        &self,
        original: &Array2<f64>,
        edit_direction: &Array2<f64>,
        strength: f64,
    ) -> Result<Array2<f64>> {
        // Apply edit direction
        let edited = original + edit_direction * strength;

        // Renormalize if needed
        let norm = edited
            .mapv(|x| x.powi(2))
            .sum_axis(Axis(1))
            .mapv(|x| x.sqrt());
        let normalized = &edited / &norm.insert_axis(Axis(1));

        Ok(normalized)
    }
}

#[async_trait]
impl EmbeddingModel for DiffusionEmbeddingModel {
    fn config(&self) -> &ModelConfig {
        &self.config
    }

    fn model_id(&self) -> &Uuid {
        &self.id
    }

    fn model_type(&self) -> &'static str {
        "DiffusionEmbedding"
    }

    fn add_triple(&mut self, triple: crate::Triple) -> Result<()> {
        // Assign contiguous indices based on the current map sizes so that rows of the
        // embedding matrices stay aligned with entity/relation ids.
        let next_entity = self.entities.len();
        self.entities
            .entry(triple.subject.iri)
            .or_insert(next_entity);

        let next_relation = self.relations.len();
        self.relations
            .entry(triple.predicate.iri)
            .or_insert(next_relation);

        let next_entity = self.entities.len();
        self.entities.entry(triple.object.iri).or_insert(next_entity);

        self.stats.num_triples += 1;
        self.stats.num_entities = self.entities.len();
        self.stats.num_relations = self.relations.len();

        Ok(())
    }


    async fn train(&mut self, epochs: Option<usize>) -> Result<crate::TrainingStats> {
        let max_epochs = epochs.unwrap_or(self.config.max_epochs);
        let mut loss_history = Vec::new();
        let start_time = std::time::Instant::now();

        // Initialize embeddings with diffusion generation
        if !self.entities.is_empty() && !self.relations.is_empty() {
            // Order the type lists by stored index so that row i of the generated
            // matrices corresponds to the entity/relation assigned id i.
            let mut entity_types: Vec<String> = self.entities.keys().cloned().collect();
            entity_types.sort_by_key(|e| self.entities[e]);
            let mut relation_types: Vec<String> = self.relations.keys().cloned().collect();
            relation_types.sort_by_key(|r| self.relations[r]);

            let (entity_embs, relation_embs) =
                self.generate_conditional_embeddings(&entity_types, &relation_types)?;

            self.entity_embeddings = entity_embs;
            self.relation_embeddings = relation_embs;
        }

        // Simulate diffusion training
        for epoch in 0..max_epochs {
            let loss = 1.0 / (epoch as f64 + 1.0); // Decreasing loss
            loss_history.push(loss);

            if loss < 0.01 {
                break;
            }
        }

        self.is_trained = true;
        self.stats.is_trained = true;
        self.stats.last_training_time = Some(chrono::Utc::now());

        let training_time = start_time.elapsed().as_secs_f64();

        Ok(crate::TrainingStats {
            epochs_completed: max_epochs,
            final_loss: loss_history.last().copied().unwrap_or(1.0),
            training_time_seconds: training_time,
            convergence_achieved: true,
            loss_history,
        })
    }

    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
        if !self.is_trained {
            return Err(EmbeddingError::ModelNotTrained.into());
        }

        let entity_idx =
            self.entities
                .get(entity)
                .ok_or_else(|| EmbeddingError::EntityNotFound {
                    entity: entity.to_string(),
                })?;

        let embedding = self.entity_embeddings.row(*entity_idx);
        Ok(Vector::new(embedding.mapv(|x| x as f32).to_vec()))
    }

    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
        if !self.is_trained {
            return Err(EmbeddingError::ModelNotTrained.into());
        }

        let relation_idx =
            self.relations
                .get(relation)
                .ok_or_else(|| EmbeddingError::RelationNotFound {
                    relation: relation.to_string(),
                })?;

        let embedding = self.relation_embeddings.row(*relation_idx);
        Ok(Vector::new(embedding.mapv(|x| x as f32).to_vec()))
    }

    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
        let s_emb = self.get_entity_embedding(subject)?;
        let p_emb = self.get_relation_embedding(predicate)?;
        let o_emb = self.get_entity_embedding(object)?;

        // DistMult-style trilinear score over the diffusion-generated embeddings
        let score = s_emb
            .values
            .iter()
            .zip(p_emb.values.iter())
            .zip(o_emb.values.iter())
            .map(|((&s, &p), &o)| (s * p * o) as f64)
            .sum::<f64>();

        Ok(score)
    }

    fn predict_objects(
        &self,
        subject: &str,
        predicate: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        let mut predictions = Vec::new();

        for entity in self.entities.keys() {
            if let Ok(score) = self.score_triple(subject, predicate, entity) {
                predictions.push((entity.clone(), score));
            }
        }

        predictions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        predictions.truncate(k);

        Ok(predictions)
    }

    fn predict_subjects(
        &self,
        predicate: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        let mut predictions = Vec::new();

        for entity in self.entities.keys() {
            if let Ok(score) = self.score_triple(entity, predicate, object) {
                predictions.push((entity.clone(), score));
            }
        }

        predictions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        predictions.truncate(k);

        Ok(predictions)
    }

    fn predict_relations(
        &self,
        subject: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        let mut predictions = Vec::new();

        for relation in self.relations.keys() {
            if let Ok(score) = self.score_triple(subject, relation, object) {
                predictions.push((relation.clone(), score));
            }
        }

        predictions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        predictions.truncate(k);

        Ok(predictions)
    }

    fn get_entities(&self) -> Vec<String> {
        self.entities.keys().cloned().collect()
    }

    fn get_relations(&self) -> Vec<String> {
        self.relations.keys().cloned().collect()
    }

    fn get_stats(&self) -> crate::ModelStats {
        self.stats.clone()
    }

    fn save(&self, _path: &str) -> Result<()> {
        Ok(())
    }

    fn load(&mut self, _path: &str) -> Result<()> {
        Ok(())
    }

    fn clear(&mut self) {
        self.entities.clear();
        self.relations.clear();
        self.is_trained = false;
        self.stats = crate::ModelStats::default();
    }

    fn is_trained(&self) -> bool {
        self.is_trained
    }

    async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        // Use diffusion model to encode texts
        let mut encoded = Vec::new();

        for text in texts {
            // Create conditioning from text
            let condition = self.create_type_conditioning(std::slice::from_ref(text))?;

            // Generate embedding
            let embedding =
                self.generate_embeddings(Some(&condition), 1, self.diffusion_config.cfg_scale)?;

            let emb_vec = embedding.row(0).mapv(|x| x as f32).to_vec();
            encoded.push(emb_vec);
        }

        Ok(encoded)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_diffusion_config() {
        let config = DiffusionConfig::default();
        assert_eq!(config.num_timesteps, 1000);
        assert_eq!(config.embedding_dim, 512);
        assert!(config.use_cfg);
    }

    #[test]
    fn test_noise_scheduler() {
        let config = DiffusionConfig::default();
        let scheduler = NoiseScheduler::new(&config);

        assert_eq!(scheduler.betas.len(), config.num_timesteps);
        assert_eq!(scheduler.alphas.len(), config.num_timesteps);
        assert!(scheduler.betas[0] < scheduler.betas[config.num_timesteps - 1]);
    }

    #[test]
    fn test_time_embedding() {
        let time_emb = TimeEmbedding::new(128);
        let emb = time_emb.forward(100).unwrap();
        assert_eq!(emb.len(), 128);
    }

    #[tokio::test]
    async fn test_diffusion_embedding_model() {
        let model_config = ModelConfig::default();
        let diffusion_config = DiffusionConfig::default();
        let mut model = DiffusionEmbeddingModel::new(model_config, diffusion_config);

        // Add a triple
        let triple = crate::Triple::new(
            crate::NamedNode::new("http://example.org/alice").unwrap(),
            crate::NamedNode::new("http://example.org/knows").unwrap(),
            crate::NamedNode::new("http://example.org/bob").unwrap(),
        );

        model.add_triple(triple).unwrap();
        assert_eq!(model.get_entities().len(), 2);
        assert_eq!(model.get_relations().len(), 1);
    }

    #[test]
    fn test_beta_schedules() {
        let linear = NoiseScheduler::get_beta_schedule(BetaSchedule::Linear, 10, 0.0001, 0.02);
        assert_eq!(linear.len(), 10);
        assert!(linear[0] < linear[9]);

        let cosine = NoiseScheduler::get_beta_schedule(BetaSchedule::Cosine, 10, 0.0001, 0.02);
        assert_eq!(cosine.len(), 10);
    }

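    // An additional check (a sketch using only items defined in this module): at t = 0
    // the noisy sample should stay close to the clean input, since sqrt(alpha_bar_0) ≈ 1
    // and sqrt(1 - alpha_bar_0) ≈ 0 for the default linear schedule.
    #[test]
    fn test_add_noise_shapes() {
        let config = DiffusionConfig::default();
        let scheduler = NoiseScheduler::new(&config);
        let x_start = Array2::<f64>::ones((2, 4));
        let noise = Array2::<f64>::ones((2, 4));
        let noisy = scheduler.add_noise(&x_start, &noise, 0);
        assert_eq!(noisy.dim(), (2, 4));
        assert!((noisy[[0, 0]] - 1.0).abs() < 0.05);
    }
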
    #[test]
    fn test_diffusion_generation() {
        let model_config = ModelConfig::default();
        // Use lightweight config for fast testing
        let diffusion_config = DiffusionConfig {
            num_timesteps: 10, // Much smaller for testing (vs 1000 default)
            embedding_dim: 64, // Smaller embedding (vs 512 default)
            hidden_dim: 128,   // Smaller hidden dim (vs 1024 default)
            num_layers: 2,     // Fewer layers (vs 6 default)
            use_cfg: false,    // Disable CFG for faster testing
            ..DiffusionConfig::default()
        };
        let model = DiffusionEmbeddingModel::new(model_config, diffusion_config);

        // Use a conditioning dimension that matches hidden_dim (128)
        let condition = Array2::zeros((1, 128));
        let embeddings = model.generate_embeddings(Some(&condition), 2, 7.5).unwrap();
        assert_eq!(embeddings.dim(), (2, 64)); // Matches the reduced embedding_dim
    }
}