oxirs_embed/models/
conve.rs

1//! ConvE (Convolutional Embeddings) Model
2//!
3//! ConvE uses 2D convolutional neural networks to model interactions between
4//! entities and relations in knowledge graphs. This allows for expressive
5//! feature learning while maintaining parameter efficiency.
6//!
7//! Reference: Dettmers et al. "Convolutional 2D Knowledge Graph Embeddings." AAAI 2018.
8//!
9//! The model reshapes entity and relation embeddings into 2D matrices,
10//! concatenates them, applies 2D convolution, and projects to entity space.
11
12use anyhow::{anyhow, Result};
13use rayon::prelude::*;
14use scirs2_core::ndarray_ext::{Array1, Array2, Array3};
15use scirs2_core::random::Random;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::fs::File;
19use std::io::{BufReader, BufWriter};
20use std::path::Path;
21use tracing::{debug, info};
22
23#[cfg(test)]
24use crate::NamedNode;
25use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
26use uuid::Uuid;
27
28/// ConvE model configuration
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ConvEConfig {
31    /// Base model configuration
32    pub base: ModelConfig,
33    /// Width of the 2D reshape (height = dimensions / width)
34    pub reshape_width: usize,
35    /// Number of output channels for convolution
36    pub num_filters: usize,
37    /// Kernel size for 2D convolution (square kernel)
38    pub kernel_size: usize,
39    /// Dropout rate for regularization
40    pub dropout_rate: f32,
41    /// L2 regularization coefficient
42    pub regularization: f32,
43    /// Margin for ranking loss
44    pub margin: f32,
45    /// Number of negative samples per positive
46    pub num_negatives: usize,
47    /// Use batch normalization
48    pub use_batch_norm: bool,
49}
50
51impl Default for ConvEConfig {
52    fn default() -> Self {
53        Self {
54            base: ModelConfig::default().with_dimensions(200),
55            reshape_width: 20, // 200 dimensions -> 10x20 matrix
56            num_filters: 32,
57            kernel_size: 3,
58            dropout_rate: 0.3,
59            regularization: 0.0001,
60            margin: 1.0,
61            num_negatives: 10,
62            use_batch_norm: true,
63        }
64    }
65}
66
67/// Serializable convolutional layer parameters
68#[derive(Debug, Serialize, Deserialize)]
69struct ConvLayerSerializable {
70    filters: Vec<Vec<Vec<f32>>>, // num_filters x kernel_size x kernel_size
71    biases: Vec<f32>,
72}
73
74/// Convolutional layer parameters
75struct ConvLayer {
76    /// Filters: shape (num_filters, kernel_size, kernel_size)
77    filters: Vec<Array2<f32>>,
78    /// Biases for each filter
79    biases: Array1<f32>,
80}
81
82impl ConvLayer {
83    fn new(num_filters: usize, kernel_size: usize, rng: &mut Random) -> Self {
84        let scale = (2.0 / (kernel_size * kernel_size) as f32).sqrt();
85        let mut filters = Vec::new();
86
87        for _ in 0..num_filters {
88            let filter =
89                Array2::from_shape_fn((kernel_size, kernel_size), |_| rng.gen_range(-scale..scale));
90            filters.push(filter);
91        }
92
93        let biases = Array1::zeros(num_filters);
94
95        Self { filters, biases }
96    }
97
98    /// Apply 2D convolution with valid padding
99    fn forward(&self, input: &Array2<f32>) -> Array3<f32> {
100        let kernel_size = self.filters[0].nrows();
101        let input_height = input.nrows();
102        let input_width = input.ncols();
103
104        let out_height = input_height.saturating_sub(kernel_size - 1);
105        let out_width = input_width.saturating_sub(kernel_size - 1);
106
107        if out_height == 0 || out_width == 0 {
108            // Return empty array if convolution cannot be performed
109            return Array3::zeros((self.filters.len(), 1, 1));
110        }
111
112        let mut output = Array3::zeros((self.filters.len(), out_height, out_width));
113
114        for (f_idx, filter) in self.filters.iter().enumerate() {
115            for i in 0..out_height {
116                for j in 0..out_width {
117                    let mut sum = 0.0;
118
119                    for ki in 0..kernel_size {
120                        for kj in 0..kernel_size {
121                            sum += input[[i + ki, j + kj]] * filter[[ki, kj]];
122                        }
123                    }
124
125                    output[[f_idx, i, j]] = sum + self.biases[f_idx];
126                }
127            }
128        }
129
130        output
131    }
132}
133
134/// Serializable fully connected layer
135#[derive(Debug, Serialize, Deserialize)]
136struct FCLayerSerializable {
137    weights: Vec<Vec<f32>>, // input_size x output_size
138    bias: Vec<f32>,
139}
140
141/// Fully connected layer
142struct FCLayer {
143    weights: Array2<f32>,
144    bias: Array1<f32>,
145}
146
147impl FCLayer {
148    fn new(input_size: usize, output_size: usize, rng: &mut Random) -> Self {
149        let scale = (2.0 / input_size as f32).sqrt();
150        let weights =
151            Array2::from_shape_fn((input_size, output_size), |_| rng.gen_range(-scale..scale));
152        let bias = Array1::zeros(output_size);
153
154        Self { weights, bias }
155    }
156
157    fn forward(&self, input: &Array1<f32>) -> Array1<f32> {
158        let mut output = self.bias.clone();
159        for i in 0..output.len() {
160            for j in 0..input.len() {
161                output[i] += input[j] * self.weights[[j, i]];
162            }
163        }
164        output
165    }
166}
167
168/// Serializable representation of ConvE model for persistence
169#[derive(Debug, Serialize, Deserialize)]
170struct ConvESerializable {
171    model_id: Uuid,
172    config: ConvEConfig,
173    entity_embeddings: HashMap<String, Vec<f32>>,
174    relation_embeddings: HashMap<String, Vec<f32>>,
175    conv_layer: ConvLayerSerializable,
176    fc_layer: FCLayerSerializable,
177    triples: Vec<Triple>,
178    entity_to_id: HashMap<String, usize>,
179    relation_to_id: HashMap<String, usize>,
180    id_to_entity: HashMap<usize, String>,
181    id_to_relation: HashMap<usize, String>,
182    is_trained: bool,
183}
184
185/// ConvE (Convolutional Embeddings) model
186pub struct ConvE {
187    model_id: Uuid,
188    config: ConvEConfig,
189    entity_embeddings: HashMap<String, Array1<f32>>,
190    relation_embeddings: HashMap<String, Array1<f32>>,
191    conv_layer: ConvLayer,
192    fc_layer: FCLayer,
193    triples: Vec<Triple>,
194    entity_to_id: HashMap<String, usize>,
195    relation_to_id: HashMap<String, usize>,
196    id_to_entity: HashMap<usize, String>,
197    id_to_relation: HashMap<usize, String>,
198    is_trained: bool,
199}
200
201impl ConvE {
202    /// Create new ConvE model with configuration
203    pub fn new(config: ConvEConfig) -> Self {
204        let mut rng = Random::default();
205
206        // Calculate feature map size after convolution
207        let reshape_height = config.base.dimensions / config.reshape_width;
208        let conv_out_height = reshape_height.saturating_sub(config.kernel_size - 1);
209        let conv_out_width = (config.reshape_width * 2).saturating_sub(config.kernel_size - 1);
210        let fc_input_size = config.num_filters * conv_out_height * conv_out_width;
211
212        let conv_layer = ConvLayer::new(config.num_filters, config.kernel_size, &mut rng);
213        let fc_layer = FCLayer::new(fc_input_size, config.base.dimensions, &mut rng);
214
215        info!(
216            "Initialized ConvE model: dim={}, filters={}, kernel={}, fc_input={}",
217            config.base.dimensions, config.num_filters, config.kernel_size, fc_input_size
218        );
219
220        Self {
221            model_id: Uuid::new_v4(),
222            config,
223            entity_embeddings: HashMap::new(),
224            relation_embeddings: HashMap::new(),
225            conv_layer,
226            fc_layer,
227            triples: Vec::new(),
228            entity_to_id: HashMap::new(),
229            relation_to_id: HashMap::new(),
230            id_to_entity: HashMap::new(),
231            id_to_relation: HashMap::new(),
232            is_trained: false,
233        }
234    }
235
236    /// Reshape 1D embedding to 2D matrix
237    fn reshape_embedding(&self, embedding: &Array1<f32>) -> Array2<f32> {
238        let height = self.config.base.dimensions / self.config.reshape_width;
239        let width = self.config.reshape_width;
240
241        Array2::from_shape_fn((height, width), |(i, j)| embedding[i * width + j])
242    }
243
244    /// Apply ReLU activation
245    fn relu(&self, x: f32) -> f32 {
246        x.max(0.0)
247    }
248
249    /// Apply dropout (during training)
250    fn dropout(&mut self, values: &mut Array1<f32>, training: bool) {
251        if !training || self.config.dropout_rate == 0.0 {
252            return;
253        }
254
255        let mut local_rng = Random::default();
256        let keep_prob = 1.0 - self.config.dropout_rate;
257        for val in values.iter_mut() {
258            if local_rng.gen_range(0.0..1.0) > keep_prob {
259                *val = 0.0;
260            } else {
261                *val /= keep_prob; // Inverted dropout
262            }
263        }
264    }
265
266    /// Forward pass to compute score
267    fn forward(
268        &mut self,
269        head: &Array1<f32>,
270        relation: &Array1<f32>,
271        training: bool,
272    ) -> Array1<f32> {
273        // Reshape head and relation to 2D
274        let head_2d = self.reshape_embedding(head);
275        let rel_2d = self.reshape_embedding(relation);
276
277        // Concatenate horizontally: [head | relation]
278        let height = head_2d.nrows();
279        let width = head_2d.ncols() * 2;
280        let mut concat = Array2::zeros((height, width));
281
282        for i in 0..height {
283            for j in 0..head_2d.ncols() {
284                concat[[i, j]] = head_2d[[i, j]];
285            }
286            for j in 0..rel_2d.ncols() {
287                concat[[i, head_2d.ncols() + j]] = rel_2d[[i, j]];
288            }
289        }
290
291        // Apply 2D convolution
292        let conv_out = self.conv_layer.forward(&concat);
293
294        // Apply ReLU activation
295        let conv_out_relu = conv_out.mapv(|x| self.relu(x));
296
297        // Flatten the feature maps
298        let flattened_size = conv_out_relu.len();
299        let mut flattened = Array1::zeros(flattened_size);
300        for (idx, &val) in conv_out_relu.iter().enumerate() {
301            flattened[idx] = val;
302        }
303
304        // Apply dropout
305        self.dropout(&mut flattened, training);
306
307        // Fully connected layer
308        let mut output = self.fc_layer.forward(&flattened);
309
310        // Apply dropout again
311        self.dropout(&mut output, training);
312
313        output
314    }
315
316    /// Compute score for a triple
317    fn score_triple_internal(
318        &mut self,
319        head: &Array1<f32>,
320        relation: &Array1<f32>,
321        tail: &Array1<f32>,
322    ) -> f32 {
323        let projected = self.forward(head, relation, false);
324        // Score is dot product with tail entity
325        projected.dot(tail)
326    }
327
328    /// Initialize embeddings for an entity
329    fn init_entity(&mut self, entity: &str) {
330        if !self.entity_embeddings.contains_key(entity) {
331            let id = self.entity_embeddings.len();
332            self.entity_to_id.insert(entity.to_string(), id);
333            self.id_to_entity.insert(id, entity.to_string());
334
335            let mut local_rng = Random::default();
336            let scale = (6.0 / self.config.base.dimensions as f32).sqrt();
337            let embedding = Array1::from_vec(
338                (0..self.config.base.dimensions)
339                    .map(|_| local_rng.gen_range(-scale..scale))
340                    .collect(),
341            );
342            self.entity_embeddings.insert(entity.to_string(), embedding);
343        }
344    }
345
346    /// Initialize embeddings for a relation
347    fn init_relation(&mut self, relation: &str) {
348        if !self.relation_embeddings.contains_key(relation) {
349            let id = self.relation_embeddings.len();
350            self.relation_to_id.insert(relation.to_string(), id);
351            self.id_to_relation.insert(id, relation.to_string());
352
353            let mut local_rng = Random::default();
354            let scale = (6.0 / self.config.base.dimensions as f32).sqrt();
355            let embedding = Array1::from_vec(
356                (0..self.config.base.dimensions)
357                    .map(|_| local_rng.gen_range(-scale..scale))
358                    .collect(),
359            );
360            self.relation_embeddings
361                .insert(relation.to_string(), embedding);
362        }
363    }
364
365    /// Training step with simplified gradient updates
366    fn train_step(&mut self) -> f32 {
367        let mut total_loss = 0.0;
368        let mut local_rng = Random::default();
369
370        // Shuffle triples
371        let mut indices: Vec<usize> = (0..self.triples.len()).collect();
372        for i in (1..indices.len()).rev() {
373            let j = local_rng.random_range(0..i + 1);
374            indices.swap(i, j);
375        }
376
377        for &idx in &indices {
378            let triple = &self.triples[idx].clone();
379
380            let subject_str = &triple.subject.iri;
381            let predicate_str = &triple.predicate.iri;
382            let object_str = &triple.object.iri;
383
384            let head_emb = self.entity_embeddings[subject_str].clone();
385            let rel_emb = self.relation_embeddings[predicate_str].clone();
386            let tail_emb = self.entity_embeddings[object_str].clone();
387
388            // Positive score
389            let pos_score = self.score_triple_internal(&head_emb, &rel_emb, &tail_emb);
390
391            // Generate negative samples
392            let entity_list: Vec<String> = self.entity_embeddings.keys().cloned().collect();
393            for _ in 0..self.config.num_negatives {
394                let neg_tail_id = entity_list[local_rng.random_range(0..entity_list.len())].clone();
395                let neg_tail_emb = self.entity_embeddings[&neg_tail_id].clone();
396
397                let neg_score = self.score_triple_internal(&head_emb, &rel_emb, &neg_tail_emb);
398
399                // Margin ranking loss
400                let loss = (self.config.margin + neg_score - pos_score).max(0.0);
401                total_loss += loss;
402
403                // Simplified parameter update (in practice, use proper backpropagation)
404                if loss > 0.0 {
405                    let lr = self.config.base.learning_rate as f32;
406                    // Apply L2 regularization
407                    for emb in self.entity_embeddings.values_mut() {
408                        *emb = &*emb * (1.0 - self.config.regularization * lr);
409                    }
410                    for emb in self.relation_embeddings.values_mut() {
411                        *emb = &*emb * (1.0 - self.config.regularization * lr);
412                    }
413                }
414            }
415        }
416
417        total_loss / (self.triples.len() as f32 * self.config.num_negatives as f32)
418    }
419}
420
421#[async_trait::async_trait]
422impl EmbeddingModel for ConvE {
423    fn config(&self) -> &ModelConfig {
424        &self.config.base
425    }
426
427    fn model_id(&self) -> &Uuid {
428        &self.model_id
429    }
430
431    fn model_type(&self) -> &'static str {
432        "ConvE"
433    }
434
435    fn add_triple(&mut self, triple: Triple) -> Result<()> {
436        self.init_entity(&triple.subject.iri);
437        self.init_entity(&triple.object.iri);
438        self.init_relation(&triple.predicate.iri);
439        self.triples.push(triple);
440        Ok(())
441    }
442
443    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
444        let num_epochs = epochs.unwrap_or(self.config.base.max_epochs);
445
446        if self.triples.is_empty() {
447            return Err(anyhow!("No training data available"));
448        }
449
450        info!(
451            "Training ConvE model for {} epochs on {} triples",
452            num_epochs,
453            self.triples.len()
454        );
455
456        let start_time = std::time::Instant::now();
457        let mut loss_history = Vec::new();
458
459        for epoch in 0..num_epochs {
460            let loss = self.train_step();
461            loss_history.push(loss as f64);
462
463            if epoch % 10 == 0 {
464                debug!("Epoch {}/{}: loss = {:.6}", epoch + 1, num_epochs, loss);
465            }
466
467            if loss < 0.001 {
468                info!("Converged at epoch {}", epoch);
469                break;
470            }
471        }
472
473        let training_time = start_time.elapsed().as_secs_f64();
474        self.is_trained = true;
475
476        Ok(TrainingStats {
477            epochs_completed: num_epochs,
478            final_loss: *loss_history.last().unwrap_or(&0.0),
479            training_time_seconds: training_time,
480            convergence_achieved: loss_history.last().unwrap_or(&1.0) < &0.001,
481            loss_history,
482        })
483    }
484
485    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
486        self.entity_embeddings
487            .get(entity)
488            .map(Vector::from_array1)
489            .ok_or_else(|| anyhow!("Unknown entity: {}", entity))
490    }
491
492    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
493        self.relation_embeddings
494            .get(relation)
495            .map(Vector::from_array1)
496            .ok_or_else(|| anyhow!("Unknown relation: {}", relation))
497    }
498
499    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
500        let head_emb = self
501            .entity_embeddings
502            .get(subject)
503            .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
504        let rel_emb = self
505            .relation_embeddings
506            .get(predicate)
507            .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;
508        let tail_emb = self
509            .entity_embeddings
510            .get(object)
511            .ok_or_else(|| anyhow!("Unknown object: {}", object))?;
512
513        // Simplified scoring: (head + relation) ยท tail
514        // Note: Full ConvE scoring requires mutable access for CNN forward pass
515        let score = (head_emb + rel_emb).dot(tail_emb);
516        Ok(score as f64)
517    }
518
519    fn predict_objects(
520        &self,
521        subject: &str,
522        predicate: &str,
523        k: usize,
524    ) -> Result<Vec<(String, f64)>> {
525        let head_emb = self
526            .entity_embeddings
527            .get(subject)
528            .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
529        let rel_emb = self
530            .relation_embeddings
531            .get(predicate)
532            .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;
533
534        let combined = head_emb + rel_emb;
535        let mut scored_objects: Vec<(String, f64)> = self
536            .entity_embeddings
537            .par_iter()
538            .map(|(entity, tail_emb)| {
539                let score = combined.dot(tail_emb);
540                (entity.clone(), score as f64)
541            })
542            .collect();
543
544        scored_objects.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
545        scored_objects.truncate(k);
546        Ok(scored_objects)
547    }
548
549    fn predict_subjects(
550        &self,
551        predicate: &str,
552        object: &str,
553        k: usize,
554    ) -> Result<Vec<(String, f64)>> {
555        let rel_emb = self
556            .relation_embeddings
557            .get(predicate)
558            .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;
559        let tail_emb = self
560            .entity_embeddings
561            .get(object)
562            .ok_or_else(|| anyhow!("Unknown object: {}", object))?;
563
564        let mut scored_subjects: Vec<(String, f64)> = self
565            .entity_embeddings
566            .par_iter()
567            .map(|(entity, head_emb)| {
568                let score = (head_emb + rel_emb).dot(tail_emb);
569                (entity.clone(), score as f64)
570            })
571            .collect();
572
573        scored_subjects.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
574        scored_subjects.truncate(k);
575        Ok(scored_subjects)
576    }
577
578    fn predict_relations(
579        &self,
580        subject: &str,
581        object: &str,
582        k: usize,
583    ) -> Result<Vec<(String, f64)>> {
584        let head_emb = self
585            .entity_embeddings
586            .get(subject)
587            .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
588        let tail_emb = self
589            .entity_embeddings
590            .get(object)
591            .ok_or_else(|| anyhow!("Unknown object: {}", object))?;
592
593        let mut scored_relations: Vec<(String, f64)> = self
594            .relation_embeddings
595            .par_iter()
596            .map(|(relation, rel_emb)| {
597                let score = (head_emb + rel_emb).dot(tail_emb);
598                (relation.clone(), score as f64)
599            })
600            .collect();
601
602        scored_relations.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
603        scored_relations.truncate(k);
604        Ok(scored_relations)
605    }
606
607    fn get_entities(&self) -> Vec<String> {
608        self.entity_embeddings.keys().cloned().collect()
609    }
610
611    fn get_relations(&self) -> Vec<String> {
612        self.relation_embeddings.keys().cloned().collect()
613    }
614
615    fn get_stats(&self) -> ModelStats {
616        ModelStats {
617            num_entities: self.entity_embeddings.len(),
618            num_relations: self.relation_embeddings.len(),
619            num_triples: self.triples.len(),
620            dimensions: self.config.base.dimensions,
621            is_trained: self.is_trained,
622            model_type: "ConvE".to_string(),
623            creation_time: chrono::Utc::now(),
624            last_training_time: if self.is_trained {
625                Some(chrono::Utc::now())
626            } else {
627                None
628            },
629        }
630    }
631
632    fn save(&self, path: &str) -> Result<()> {
633        info!("Saving ConvE model to {}", path);
634
635        // Convert Array1 to Vec for serialization
636        let entity_embeddings_vec: HashMap<String, Vec<f32>> = self
637            .entity_embeddings
638            .iter()
639            .map(|(k, v)| (k.clone(), v.to_vec()))
640            .collect();
641
642        let relation_embeddings_vec: HashMap<String, Vec<f32>> = self
643            .relation_embeddings
644            .iter()
645            .map(|(k, v)| (k.clone(), v.to_vec()))
646            .collect();
647
648        // Serialize convolutional layer
649        let conv_filters: Vec<Vec<Vec<f32>>> = self
650            .conv_layer
651            .filters
652            .iter()
653            .map(|filter| {
654                let mut rows = Vec::new();
655                for i in 0..filter.nrows() {
656                    let mut row = Vec::new();
657                    for j in 0..filter.ncols() {
658                        row.push(filter[[i, j]]);
659                    }
660                    rows.push(row);
661                }
662                rows
663            })
664            .collect();
665
666        let conv_layer_ser = ConvLayerSerializable {
667            filters: conv_filters,
668            biases: self.conv_layer.biases.to_vec(),
669        };
670
671        // Serialize fully connected layer
672        let mut fc_weights = Vec::new();
673        for i in 0..self.fc_layer.weights.nrows() {
674            let mut row = Vec::new();
675            for j in 0..self.fc_layer.weights.ncols() {
676                row.push(self.fc_layer.weights[[i, j]]);
677            }
678            fc_weights.push(row);
679        }
680
681        let fc_layer_ser = FCLayerSerializable {
682            weights: fc_weights,
683            bias: self.fc_layer.bias.to_vec(),
684        };
685
686        let serializable = ConvESerializable {
687            model_id: self.model_id,
688            config: self.config.clone(),
689            entity_embeddings: entity_embeddings_vec,
690            relation_embeddings: relation_embeddings_vec,
691            conv_layer: conv_layer_ser,
692            fc_layer: fc_layer_ser,
693            triples: self.triples.clone(),
694            entity_to_id: self.entity_to_id.clone(),
695            relation_to_id: self.relation_to_id.clone(),
696            id_to_entity: self.id_to_entity.clone(),
697            id_to_relation: self.id_to_relation.clone(),
698            is_trained: self.is_trained,
699        };
700
701        let file = File::create(path)?;
702        let writer = BufWriter::new(file);
703        oxicode::serde::encode_into_std_write(&serializable, writer, oxicode::config::standard())
704            .map_err(|e| anyhow!("Failed to serialize model: {}", e))?;
705
706        info!("Model saved successfully");
707        Ok(())
708    }
709
710    fn load(&mut self, path: &str) -> Result<()> {
711        info!("Loading ConvE model from {}", path);
712
713        if !Path::new(path).exists() {
714            return Err(anyhow!("Model file not found: {}", path));
715        }
716
717        let file = File::open(path)?;
718        let reader = BufReader::new(file);
719        let (serializable, _): (ConvESerializable, _) =
720            oxicode::serde::decode_from_std_read(reader, oxicode::config::standard())
721                .map_err(|e| anyhow!("Failed to deserialize model: {}", e))?;
722
723        // Convert Vec back to Array1
724        let entity_embeddings: HashMap<String, Array1<f32>> = serializable
725            .entity_embeddings
726            .into_iter()
727            .map(|(k, v)| (k, Array1::from_vec(v)))
728            .collect();
729
730        let relation_embeddings: HashMap<String, Array1<f32>> = serializable
731            .relation_embeddings
732            .into_iter()
733            .map(|(k, v)| (k, Array1::from_vec(v)))
734            .collect();
735
736        // Deserialize convolutional layer
737        let conv_filters: Vec<Array2<f32>> = serializable
738            .conv_layer
739            .filters
740            .into_iter()
741            .map(|filter_vec| {
742                let kernel_size = filter_vec.len();
743                Array2::from_shape_fn((kernel_size, kernel_size), |(i, j)| filter_vec[i][j])
744            })
745            .collect();
746
747        let conv_layer = ConvLayer {
748            filters: conv_filters,
749            biases: Array1::from_vec(serializable.conv_layer.biases),
750        };
751
752        // Deserialize fully connected layer
753        let fc_weights_vec = serializable.fc_layer.weights;
754        let input_size = fc_weights_vec.len();
755        let output_size = if input_size > 0 {
756            fc_weights_vec[0].len()
757        } else {
758            0
759        };
760
761        let fc_weights =
762            Array2::from_shape_fn((input_size, output_size), |(i, j)| fc_weights_vec[i][j]);
763
764        let fc_layer = FCLayer {
765            weights: fc_weights,
766            bias: Array1::from_vec(serializable.fc_layer.bias),
767        };
768
769        // Update model state
770        self.model_id = serializable.model_id;
771        self.config = serializable.config;
772        self.entity_embeddings = entity_embeddings;
773        self.relation_embeddings = relation_embeddings;
774        self.conv_layer = conv_layer;
775        self.fc_layer = fc_layer;
776        self.triples = serializable.triples;
777        self.entity_to_id = serializable.entity_to_id;
778        self.relation_to_id = serializable.relation_to_id;
779        self.id_to_entity = serializable.id_to_entity;
780        self.id_to_relation = serializable.id_to_relation;
781        self.is_trained = serializable.is_trained;
782
783        info!("Model loaded successfully");
784        Ok(())
785    }
786
787    fn clear(&mut self) {
788        self.entity_embeddings.clear();
789        self.relation_embeddings.clear();
790        self.triples.clear();
791        self.entity_to_id.clear();
792        self.relation_to_id.clear();
793        self.id_to_entity.clear();
794        self.id_to_relation.clear();
795        self.is_trained = false;
796    }
797
798    fn is_trained(&self) -> bool {
799        self.is_trained
800    }
801
802    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
803        // TODO: Implement text encoding
804        Err(anyhow!("Text encoding not implemented for ConvE"))
805    }
806}
807
#[cfg(test)]
mod tests {
    use super::*;

    /// Small, fast configuration shared by the training tests.
    fn small_config(max_epochs: usize, kernel_size: usize) -> ConvEConfig {
        ConvEConfig {
            base: ModelConfig {
                dimensions: 50,
                learning_rate: 0.001,
                max_epochs,
                ..Default::default()
            },
            reshape_width: 10,
            num_filters: 8,
            kernel_size,
            ..Default::default()
        }
    }

    /// Builds a model holding `alice -knows-> bob` and `bob -likes-> charlie`.
    fn model_with_triples(config: ConvEConfig) -> ConvE {
        let mut model = ConvE::new(config);
        for &(s, p, o) in &[("alice", "knows", "bob"), ("bob", "likes", "charlie")] {
            model
                .add_triple(Triple::new(
                    NamedNode::new(s).unwrap(),
                    NamedNode::new(p).unwrap(),
                    NamedNode::new(o).unwrap(),
                ))
                .unwrap();
        }
        model
    }

    /// Deletes the wrapped file on drop, so a failing assertion does not leak it.
    struct TempFile(std::path::PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_conve_creation() {
        let model = ConvE::new(ConvEConfig::default());

        assert_eq!(model.entity_embeddings.len(), 0);
        assert_eq!(model.relation_embeddings.len(), 0);
    }

    #[tokio::test]
    async fn test_conve_training() {
        // Default kernel size (3); few epochs keep the test fast.
        let mut model = model_with_triples(small_config(5, ConvEConfig::default().kernel_size));

        let stats = model.train(Some(5)).await.unwrap();

        assert_eq!(stats.epochs_completed, 5);
        assert!(stats.final_loss >= 0.0);
        assert_eq!(model.entity_embeddings.len(), 3);
        assert_eq!(model.relation_embeddings.len(), 2);
    }

    #[tokio::test]
    async fn test_conve_save_load() {
        let mut model = model_with_triples(small_config(15, 2));
        model.train(Some(15)).await.unwrap();

        let emb_before = model.get_entity_embedding("alice").unwrap();
        let score_before = model.score_triple("alice", "knows", "bob").unwrap();

        // Unique file name so concurrent test runs cannot clobber each other.
        let temp = TempFile(
            std::env::temp_dir().join(format!("test_conve_model_{}.bin", Uuid::new_v4())),
        );
        let path_str = temp.0.to_str().unwrap();
        model.save(path_str).unwrap();

        // Load into a model built with a different (default) configuration.
        let mut loaded_model = ConvE::new(ConvEConfig::default());
        loaded_model.load(path_str).unwrap();

        assert!(loaded_model.is_trained());
        assert_eq!(loaded_model.get_entities().len(), 3);
        assert_eq!(loaded_model.get_relations().len(), 2);

        let emb_after = loaded_model.get_entity_embedding("alice").unwrap();
        assert_eq!(emb_before.dimensions, emb_after.dimensions);
        for (before, after) in emb_before.values.iter().zip(emb_after.values.iter()) {
            assert!((before - after).abs() < 1e-6);
        }

        let score_after = loaded_model.score_triple("alice", "knows", "bob").unwrap();
        assert!((score_before - score_after).abs() < 1e-5);
    }

    #[test]
    fn test_conve_load_nonexistent() {
        let mut model = ConvE::new(ConvEConfig::default());
        let result = model.load("/nonexistent/path/model.bin");
        assert!(result.is_err());
    }
}