// oxirs_embed/models/conve.rs

1//! ConvE (Convolutional Embeddings) Model
2//!
3//! ConvE uses 2D convolutional neural networks to model interactions between
4//! entities and relations in knowledge graphs. This allows for expressive
5//! feature learning while maintaining parameter efficiency.
6//!
7//! Reference: Dettmers et al. "Convolutional 2D Knowledge Graph Embeddings." AAAI 2018.
8//!
9//! The model reshapes entity and relation embeddings into 2D matrices,
10//! concatenates them, applies 2D convolution, and projects to entity space.
11
12use anyhow::{anyhow, Result};
13use rayon::prelude::*;
14use scirs2_core::ndarray_ext::{Array1, Array2, Array3};
15use scirs2_core::random::Random;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::fs::File;
19use std::io::{BufReader, BufWriter};
20use std::path::Path;
21use tracing::{debug, info};
22
23#[cfg(test)]
24use crate::NamedNode;
25use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
26use uuid::Uuid;
27
/// ConvE model configuration.
///
/// `base.dimensions` should be divisible by `reshape_width`: embeddings are
/// reshaped to a `(dimensions / reshape_width) x reshape_width` matrix before
/// convolution, and trailing elements are dropped if it is not divisible.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvEConfig {
    /// Base model configuration (dimensions, learning rate, max epochs, ...)
    pub base: ModelConfig,
    /// Width of the 2D reshape (height = dimensions / width)
    pub reshape_width: usize,
    /// Number of output channels (filters) for convolution
    pub num_filters: usize,
    /// Kernel size for 2D convolution (square kernel)
    pub kernel_size: usize,
    /// Dropout rate for regularization (applied during training only)
    pub dropout_rate: f32,
    /// L2 regularization coefficient
    pub regularization: f32,
    /// Margin for ranking loss
    pub margin: f32,
    /// Number of negative samples per positive
    pub num_negatives: usize,
    /// Use batch normalization
    /// NOTE(review): not referenced anywhere in the visible forward pass —
    /// confirm whether batch norm is applied elsewhere or is a TODO.
    pub use_batch_norm: bool,
}
50
51impl Default for ConvEConfig {
52    fn default() -> Self {
53        Self {
54            base: ModelConfig::default().with_dimensions(200),
55            reshape_width: 20, // 200 dimensions -> 10x20 matrix
56            num_filters: 32,
57            kernel_size: 3,
58            dropout_rate: 0.3,
59            regularization: 0.0001,
60            margin: 1.0,
61            num_negatives: 10,
62            use_batch_norm: true,
63        }
64    }
65}
66
/// Serializable convolutional layer parameters
/// (plain nested `Vec`s so serde can persist them without ndarray support).
#[derive(Debug, Serialize, Deserialize)]
struct ConvLayerSerializable {
    filters: Vec<Vec<Vec<f32>>>, // num_filters x kernel_size x kernel_size, row-major
    biases: Vec<f32>,            // one bias per filter
}
73
/// Convolutional layer parameters (runtime representation; see
/// `ConvLayerSerializable` for the persisted form).
struct ConvLayer {
    /// Filters: `num_filters` square matrices of shape (kernel_size, kernel_size)
    filters: Vec<Array2<f32>>,
    /// Biases for each filter (length == num_filters)
    biases: Array1<f32>,
}
81
82impl ConvLayer {
83    fn new(num_filters: usize, kernel_size: usize, rng: &mut Random) -> Self {
84        let scale = (2.0 / (kernel_size * kernel_size) as f32).sqrt();
85        let mut filters = Vec::new();
86
87        for _ in 0..num_filters {
88            let filter = Array2::from_shape_fn((kernel_size, kernel_size), |_| {
89                rng.random_range(-scale..scale)
90            });
91            filters.push(filter);
92        }
93
94        let biases = Array1::zeros(num_filters);
95
96        Self { filters, biases }
97    }
98
99    /// Apply 2D convolution with valid padding
100    fn forward(&self, input: &Array2<f32>) -> Array3<f32> {
101        let kernel_size = self.filters[0].nrows();
102        let input_height = input.nrows();
103        let input_width = input.ncols();
104
105        let out_height = input_height.saturating_sub(kernel_size - 1);
106        let out_width = input_width.saturating_sub(kernel_size - 1);
107
108        if out_height == 0 || out_width == 0 {
109            // Return empty array if convolution cannot be performed
110            return Array3::zeros((self.filters.len(), 1, 1));
111        }
112
113        let mut output = Array3::zeros((self.filters.len(), out_height, out_width));
114
115        for (f_idx, filter) in self.filters.iter().enumerate() {
116            for i in 0..out_height {
117                for j in 0..out_width {
118                    let mut sum = 0.0;
119
120                    for ki in 0..kernel_size {
121                        for kj in 0..kernel_size {
122                            sum += input[[i + ki, j + kj]] * filter[[ki, kj]];
123                        }
124                    }
125
126                    output[[f_idx, i, j]] = sum + self.biases[f_idx];
127                }
128            }
129        }
130
131        output
132    }
133}
134
/// Serializable fully connected layer
/// (plain `Vec`s so serde can persist it without ndarray support).
#[derive(Debug, Serialize, Deserialize)]
struct FCLayerSerializable {
    weights: Vec<Vec<f32>>, // input_size x output_size, row-major
    bias: Vec<f32>,         // length == output_size
}
141
/// Fully connected (dense) layer: `output = input^T * weights + bias`.
struct FCLayer {
    /// Weight matrix, shape (input_size, output_size)
    weights: Array2<f32>,
    /// Bias vector, length output_size
    bias: Array1<f32>,
}
147
148impl FCLayer {
149    fn new(input_size: usize, output_size: usize, rng: &mut Random) -> Self {
150        let scale = (2.0 / input_size as f32).sqrt();
151        let weights = Array2::from_shape_fn((input_size, output_size), |_| {
152            rng.random_range(-scale..scale)
153        });
154        let bias = Array1::zeros(output_size);
155
156        Self { weights, bias }
157    }
158
159    fn forward(&self, input: &Array1<f32>) -> Array1<f32> {
160        let mut output = self.bias.clone();
161        for i in 0..output.len() {
162            for j in 0..input.len() {
163                output[i] += input[j] * self.weights[[j, i]];
164            }
165        }
166        output
167    }
168}
169
/// Serializable snapshot of a full ConvE model for persistence.
/// Mirrors `ConvE` with ndarray types flattened to plain `Vec`s.
#[derive(Debug, Serialize, Deserialize)]
struct ConvESerializable {
    model_id: Uuid,
    config: ConvEConfig,
    entity_embeddings: HashMap<String, Vec<f32>>,
    relation_embeddings: HashMap<String, Vec<f32>>,
    conv_layer: ConvLayerSerializable,
    fc_layer: FCLayerSerializable,
    triples: Vec<Triple>,
    // Bidirectional id maps, kept in sync with the embedding maps.
    entity_to_id: HashMap<String, usize>,
    relation_to_id: HashMap<String, usize>,
    id_to_entity: HashMap<usize, String>,
    id_to_relation: HashMap<usize, String>,
    is_trained: bool,
}
186
/// ConvE (Convolutional Embeddings) model.
///
/// Holds per-IRI entity/relation embeddings, the CNN + FC projection
/// parameters, the raw training triples, and bidirectional id maps.
pub struct ConvE {
    model_id: Uuid,
    config: ConvEConfig,
    /// Entity IRI -> embedding vector (length = config.base.dimensions)
    entity_embeddings: HashMap<String, Array1<f32>>,
    /// Relation IRI -> embedding vector (length = config.base.dimensions)
    relation_embeddings: HashMap<String, Array1<f32>>,
    conv_layer: ConvLayer,
    fc_layer: FCLayer,
    /// All triples added via `add_triple`
    triples: Vec<Triple>,
    // Ids are assigned in insertion order; the four maps stay in sync.
    entity_to_id: HashMap<String, usize>,
    relation_to_id: HashMap<String, usize>,
    id_to_entity: HashMap<usize, String>,
    id_to_relation: HashMap<usize, String>,
    is_trained: bool,
}
202
203impl ConvE {
    /// Create new ConvE model with configuration.
    ///
    /// Sizes the fully connected layer from the convolution output shape: the
    /// concatenated [head | relation] input is
    /// (dimensions / reshape_width) x (2 * reshape_width), and a valid
    /// convolution shrinks each side by (kernel_size - 1).
    ///
    /// NOTE(review): when the kernel is larger than the reshaped input,
    /// `fc_input_size` saturates to 0 here while `ConvLayer::forward` falls
    /// back to a (num_filters, 1, 1) map — the two degenerate paths disagree;
    /// confirm such configurations are rejected upstream.
    pub fn new(config: ConvEConfig) -> Self {
        let mut rng = Random::default();

        // Calculate feature map size after convolution
        let reshape_height = config.base.dimensions / config.reshape_width;
        let conv_out_height = reshape_height.saturating_sub(config.kernel_size - 1);
        let conv_out_width = (config.reshape_width * 2).saturating_sub(config.kernel_size - 1);
        let fc_input_size = config.num_filters * conv_out_height * conv_out_width;

        let conv_layer = ConvLayer::new(config.num_filters, config.kernel_size, &mut rng);
        let fc_layer = FCLayer::new(fc_input_size, config.base.dimensions, &mut rng);

        info!(
            "Initialized ConvE model: dim={}, filters={}, kernel={}, fc_input={}",
            config.base.dimensions, config.num_filters, config.kernel_size, fc_input_size
        );

        Self {
            model_id: Uuid::new_v4(),
            config,
            entity_embeddings: HashMap::new(),
            relation_embeddings: HashMap::new(),
            conv_layer,
            fc_layer,
            triples: Vec::new(),
            entity_to_id: HashMap::new(),
            relation_to_id: HashMap::new(),
            id_to_entity: HashMap::new(),
            id_to_relation: HashMap::new(),
            is_trained: false,
        }
    }
237
238    /// Reshape 1D embedding to 2D matrix
239    fn reshape_embedding(&self, embedding: &Array1<f32>) -> Array2<f32> {
240        let height = self.config.base.dimensions / self.config.reshape_width;
241        let width = self.config.reshape_width;
242
243        Array2::from_shape_fn((height, width), |(i, j)| embedding[i * width + j])
244    }
245
246    /// Apply ReLU activation
247    fn relu(&self, x: f32) -> f32 {
248        x.max(0.0)
249    }
250
    /// Apply inverted dropout in place (training mode only).
    ///
    /// Kept values are scaled by 1 / keep_prob so the expected activation
    /// magnitude is unchanged at inference time.
    ///
    /// NOTE(review): a fresh `Random::default()` is constructed on every call —
    /// if that is deterministically seeded, every call draws the same mask;
    /// confirm `Random::default()` is entropy-seeded.
    fn dropout(&mut self, values: &mut Array1<f32>, training: bool) {
        if !training || self.config.dropout_rate == 0.0 {
            return;
        }

        let mut local_rng = Random::default();
        let keep_prob = 1.0 - self.config.dropout_rate;
        for val in values.iter_mut() {
            // Drop with probability dropout_rate, otherwise rescale.
            if local_rng.random_range(0.0..1.0) > keep_prob {
                *val = 0.0;
            } else {
                *val /= keep_prob; // Inverted dropout
            }
        }
    }
267
268    /// Forward pass to compute score
269    fn forward(
270        &mut self,
271        head: &Array1<f32>,
272        relation: &Array1<f32>,
273        training: bool,
274    ) -> Array1<f32> {
275        // Reshape head and relation to 2D
276        let head_2d = self.reshape_embedding(head);
277        let rel_2d = self.reshape_embedding(relation);
278
279        // Concatenate horizontally: [head | relation]
280        let height = head_2d.nrows();
281        let width = head_2d.ncols() * 2;
282        let mut concat = Array2::zeros((height, width));
283
284        for i in 0..height {
285            for j in 0..head_2d.ncols() {
286                concat[[i, j]] = head_2d[[i, j]];
287            }
288            for j in 0..rel_2d.ncols() {
289                concat[[i, head_2d.ncols() + j]] = rel_2d[[i, j]];
290            }
291        }
292
293        // Apply 2D convolution
294        let conv_out = self.conv_layer.forward(&concat);
295
296        // Apply ReLU activation
297        let conv_out_relu = conv_out.mapv(|x| self.relu(x));
298
299        // Flatten the feature maps
300        let flattened_size = conv_out_relu.len();
301        let mut flattened = Array1::zeros(flattened_size);
302        for (idx, &val) in conv_out_relu.iter().enumerate() {
303            flattened[idx] = val;
304        }
305
306        // Apply dropout
307        self.dropout(&mut flattened, training);
308
309        // Fully connected layer
310        let mut output = self.fc_layer.forward(&flattened);
311
312        // Apply dropout again
313        self.dropout(&mut output, training);
314
315        output
316    }
317
318    /// Compute score for a triple
319    fn score_triple_internal(
320        &mut self,
321        head: &Array1<f32>,
322        relation: &Array1<f32>,
323        tail: &Array1<f32>,
324    ) -> f32 {
325        let projected = self.forward(head, relation, false);
326        // Score is dot product with tail entity
327        projected.dot(tail)
328    }
329
330    /// Initialize embeddings for an entity
331    fn init_entity(&mut self, entity: &str) {
332        if !self.entity_embeddings.contains_key(entity) {
333            let id = self.entity_embeddings.len();
334            self.entity_to_id.insert(entity.to_string(), id);
335            self.id_to_entity.insert(id, entity.to_string());
336
337            let mut local_rng = Random::default();
338            let scale = (6.0 / self.config.base.dimensions as f32).sqrt();
339            let embedding = Array1::from_vec(
340                (0..self.config.base.dimensions)
341                    .map(|_| local_rng.random_range(-scale..scale))
342                    .collect(),
343            );
344            self.entity_embeddings.insert(entity.to_string(), embedding);
345        }
346    }
347
348    /// Initialize embeddings for a relation
349    fn init_relation(&mut self, relation: &str) {
350        if !self.relation_embeddings.contains_key(relation) {
351            let id = self.relation_embeddings.len();
352            self.relation_to_id.insert(relation.to_string(), id);
353            self.id_to_relation.insert(id, relation.to_string());
354
355            let mut local_rng = Random::default();
356            let scale = (6.0 / self.config.base.dimensions as f32).sqrt();
357            let embedding = Array1::from_vec(
358                (0..self.config.base.dimensions)
359                    .map(|_| local_rng.random_range(-scale..scale))
360                    .collect(),
361            );
362            self.relation_embeddings
363                .insert(relation.to_string(), embedding);
364        }
365    }
366
367    /// Training step with simplified gradient updates
368    fn train_step(&mut self) -> f32 {
369        let mut total_loss = 0.0;
370        let mut local_rng = Random::default();
371
372        // Shuffle triples
373        let mut indices: Vec<usize> = (0..self.triples.len()).collect();
374        for i in (1..indices.len()).rev() {
375            let j = local_rng.random_range(0..i + 1);
376            indices.swap(i, j);
377        }
378
379        for &idx in &indices {
380            let triple = &self.triples[idx].clone();
381
382            let subject_str = &triple.subject.iri;
383            let predicate_str = &triple.predicate.iri;
384            let object_str = &triple.object.iri;
385
386            let head_emb = self.entity_embeddings[subject_str].clone();
387            let rel_emb = self.relation_embeddings[predicate_str].clone();
388            let tail_emb = self.entity_embeddings[object_str].clone();
389
390            // Positive score
391            let pos_score = self.score_triple_internal(&head_emb, &rel_emb, &tail_emb);
392
393            // Generate negative samples
394            let entity_list: Vec<String> = self.entity_embeddings.keys().cloned().collect();
395            for _ in 0..self.config.num_negatives {
396                let neg_tail_id = entity_list[local_rng.random_range(0..entity_list.len())].clone();
397                let neg_tail_emb = self.entity_embeddings[&neg_tail_id].clone();
398
399                let neg_score = self.score_triple_internal(&head_emb, &rel_emb, &neg_tail_emb);
400
401                // Margin ranking loss
402                let loss = (self.config.margin + neg_score - pos_score).max(0.0);
403                total_loss += loss;
404
405                // Simplified parameter update (in practice, use proper backpropagation)
406                if loss > 0.0 {
407                    let lr = self.config.base.learning_rate as f32;
408                    // Apply L2 regularization
409                    for emb in self.entity_embeddings.values_mut() {
410                        *emb = &*emb * (1.0 - self.config.regularization * lr);
411                    }
412                    for emb in self.relation_embeddings.values_mut() {
413                        *emb = &*emb * (1.0 - self.config.regularization * lr);
414                    }
415                }
416            }
417        }
418
419        total_loss / (self.triples.len() as f32 * self.config.num_negatives as f32)
420    }
421}
422
423#[async_trait::async_trait]
424impl EmbeddingModel for ConvE {
    /// Shared base model configuration.
    fn config(&self) -> &ModelConfig {
        &self.config.base
    }
428
    /// Unique id assigned at construction (regenerated on `load`).
    fn model_id(&self) -> &Uuid {
        &self.model_id
    }
432
    /// Static model-type tag used in stats and registries.
    fn model_type(&self) -> &'static str {
        "ConvE"
    }
436
437    fn add_triple(&mut self, triple: Triple) -> Result<()> {
438        self.init_entity(&triple.subject.iri);
439        self.init_entity(&triple.object.iri);
440        self.init_relation(&triple.predicate.iri);
441        self.triples.push(triple);
442        Ok(())
443    }
444
    /// Train for `epochs` epochs (defaults to `config.base.max_epochs`),
    /// stopping early once the epoch loss drops below 1e-3.
    ///
    /// NOTE(review): `epochs_completed` always reports the *requested* epoch
    /// count, even when the early-convergence break fires — confirm whether
    /// callers expect the requested or the actual number (the existing unit
    /// test asserts the requested count).
    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
        let num_epochs = epochs.unwrap_or(self.config.base.max_epochs);

        if self.triples.is_empty() {
            return Err(anyhow!("No training data available"));
        }

        info!(
            "Training ConvE model for {} epochs on {} triples",
            num_epochs,
            self.triples.len()
        );

        let start_time = std::time::Instant::now();
        let mut loss_history = Vec::new();

        for epoch in 0..num_epochs {
            let loss = self.train_step();
            loss_history.push(loss as f64);

            // Log every 10th epoch to keep output manageable.
            if epoch % 10 == 0 {
                debug!("Epoch {}/{}: loss = {:.6}", epoch + 1, num_epochs, loss);
            }

            // Early stop on convergence.
            if loss < 0.001 {
                info!("Converged at epoch {}", epoch);
                break;
            }
        }

        let training_time = start_time.elapsed().as_secs_f64();
        self.is_trained = true;

        Ok(TrainingStats {
            epochs_completed: num_epochs,
            final_loss: *loss_history.last().unwrap_or(&0.0),
            training_time_seconds: training_time,
            convergence_achieved: loss_history.last().unwrap_or(&1.0) < &0.001,
            loss_history,
        })
    }
486
487    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
488        self.entity_embeddings
489            .get(entity)
490            .map(Vector::from_array1)
491            .ok_or_else(|| anyhow!("Unknown entity: {}", entity))
492    }
493
494    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
495        self.relation_embeddings
496            .get(relation)
497            .map(Vector::from_array1)
498            .ok_or_else(|| anyhow!("Unknown relation: {}", relation))
499    }
500
501    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
502        let head_emb = self
503            .entity_embeddings
504            .get(subject)
505            .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
506        let rel_emb = self
507            .relation_embeddings
508            .get(predicate)
509            .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;
510        let tail_emb = self
511            .entity_embeddings
512            .get(object)
513            .ok_or_else(|| anyhow!("Unknown object: {}", object))?;
514
515        // Simplified scoring: (head + relation) ยท tail
516        // Note: Full ConvE scoring requires mutable access for CNN forward pass
517        let score = (head_emb + rel_emb).dot(tail_emb);
518        Ok(score as f64)
519    }
520
521    fn predict_objects(
522        &self,
523        subject: &str,
524        predicate: &str,
525        k: usize,
526    ) -> Result<Vec<(String, f64)>> {
527        let head_emb = self
528            .entity_embeddings
529            .get(subject)
530            .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
531        let rel_emb = self
532            .relation_embeddings
533            .get(predicate)
534            .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;
535
536        let combined = head_emb + rel_emb;
537        let mut scored_objects: Vec<(String, f64)> = self
538            .entity_embeddings
539            .par_iter()
540            .map(|(entity, tail_emb)| {
541                let score = combined.dot(tail_emb);
542                (entity.clone(), score as f64)
543            })
544            .collect();
545
546        scored_objects.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
547        scored_objects.truncate(k);
548        Ok(scored_objects)
549    }
550
551    fn predict_subjects(
552        &self,
553        predicate: &str,
554        object: &str,
555        k: usize,
556    ) -> Result<Vec<(String, f64)>> {
557        let rel_emb = self
558            .relation_embeddings
559            .get(predicate)
560            .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;
561        let tail_emb = self
562            .entity_embeddings
563            .get(object)
564            .ok_or_else(|| anyhow!("Unknown object: {}", object))?;
565
566        let mut scored_subjects: Vec<(String, f64)> = self
567            .entity_embeddings
568            .par_iter()
569            .map(|(entity, head_emb)| {
570                let score = (head_emb + rel_emb).dot(tail_emb);
571                (entity.clone(), score as f64)
572            })
573            .collect();
574
575        scored_subjects.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
576        scored_subjects.truncate(k);
577        Ok(scored_subjects)
578    }
579
580    fn predict_relations(
581        &self,
582        subject: &str,
583        object: &str,
584        k: usize,
585    ) -> Result<Vec<(String, f64)>> {
586        let head_emb = self
587            .entity_embeddings
588            .get(subject)
589            .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
590        let tail_emb = self
591            .entity_embeddings
592            .get(object)
593            .ok_or_else(|| anyhow!("Unknown object: {}", object))?;
594
595        let mut scored_relations: Vec<(String, f64)> = self
596            .relation_embeddings
597            .par_iter()
598            .map(|(relation, rel_emb)| {
599                let score = (head_emb + rel_emb).dot(tail_emb);
600                (relation.clone(), score as f64)
601            })
602            .collect();
603
604        scored_relations.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
605        scored_relations.truncate(k);
606        Ok(scored_relations)
607    }
608
    /// All known entity IRIs (order unspecified — HashMap iteration).
    fn get_entities(&self) -> Vec<String> {
        self.entity_embeddings.keys().cloned().collect()
    }
612
    /// All known relation IRIs (order unspecified — HashMap iteration).
    fn get_relations(&self) -> Vec<String> {
        self.relation_embeddings.keys().cloned().collect()
    }
616
    /// Summary statistics for the model.
    ///
    /// NOTE(review): `creation_time` and `last_training_time` are both set to
    /// the time of this call rather than the actual creation/training
    /// timestamps — consider storing real timestamps on the model.
    fn get_stats(&self) -> ModelStats {
        ModelStats {
            num_entities: self.entity_embeddings.len(),
            num_relations: self.relation_embeddings.len(),
            num_triples: self.triples.len(),
            dimensions: self.config.base.dimensions,
            is_trained: self.is_trained,
            model_type: "ConvE".to_string(),
            creation_time: chrono::Utc::now(),
            last_training_time: if self.is_trained {
                Some(chrono::Utc::now())
            } else {
                None
            },
        }
    }
633
634    fn save(&self, path: &str) -> Result<()> {
635        info!("Saving ConvE model to {}", path);
636
637        // Convert Array1 to Vec for serialization
638        let entity_embeddings_vec: HashMap<String, Vec<f32>> = self
639            .entity_embeddings
640            .iter()
641            .map(|(k, v)| (k.clone(), v.to_vec()))
642            .collect();
643
644        let relation_embeddings_vec: HashMap<String, Vec<f32>> = self
645            .relation_embeddings
646            .iter()
647            .map(|(k, v)| (k.clone(), v.to_vec()))
648            .collect();
649
650        // Serialize convolutional layer
651        let conv_filters: Vec<Vec<Vec<f32>>> = self
652            .conv_layer
653            .filters
654            .iter()
655            .map(|filter| {
656                let mut rows = Vec::new();
657                for i in 0..filter.nrows() {
658                    let mut row = Vec::new();
659                    for j in 0..filter.ncols() {
660                        row.push(filter[[i, j]]);
661                    }
662                    rows.push(row);
663                }
664                rows
665            })
666            .collect();
667
668        let conv_layer_ser = ConvLayerSerializable {
669            filters: conv_filters,
670            biases: self.conv_layer.biases.to_vec(),
671        };
672
673        // Serialize fully connected layer
674        let mut fc_weights = Vec::new();
675        for i in 0..self.fc_layer.weights.nrows() {
676            let mut row = Vec::new();
677            for j in 0..self.fc_layer.weights.ncols() {
678                row.push(self.fc_layer.weights[[i, j]]);
679            }
680            fc_weights.push(row);
681        }
682
683        let fc_layer_ser = FCLayerSerializable {
684            weights: fc_weights,
685            bias: self.fc_layer.bias.to_vec(),
686        };
687
688        let serializable = ConvESerializable {
689            model_id: self.model_id,
690            config: self.config.clone(),
691            entity_embeddings: entity_embeddings_vec,
692            relation_embeddings: relation_embeddings_vec,
693            conv_layer: conv_layer_ser,
694            fc_layer: fc_layer_ser,
695            triples: self.triples.clone(),
696            entity_to_id: self.entity_to_id.clone(),
697            relation_to_id: self.relation_to_id.clone(),
698            id_to_entity: self.id_to_entity.clone(),
699            id_to_relation: self.id_to_relation.clone(),
700            is_trained: self.is_trained,
701        };
702
703        let file = File::create(path)?;
704        let writer = BufWriter::new(file);
705        oxicode::serde::encode_into_std_write(&serializable, writer, oxicode::config::standard())
706            .map_err(|e| anyhow!("Failed to serialize model: {}", e))?;
707
708        info!("Model saved successfully");
709        Ok(())
710    }
711
    /// Load a previously saved model from `path`, replacing all current state
    /// (embeddings, layer parameters, triples, id maps, trained flag).
    fn load(&mut self, path: &str) -> Result<()> {
        info!("Loading ConvE model from {}", path);

        if !Path::new(path).exists() {
            return Err(anyhow!("Model file not found: {}", path));
        }

        let file = File::open(path)?;
        let reader = BufReader::new(file);
        // decode_from_std_read returns (value, bytes_read); the length is unused.
        let (serializable, _): (ConvESerializable, _) =
            oxicode::serde::decode_from_std_read(reader, oxicode::config::standard())
                .map_err(|e| anyhow!("Failed to deserialize model: {}", e))?;

        // Convert Vec back to Array1
        let entity_embeddings: HashMap<String, Array1<f32>> = serializable
            .entity_embeddings
            .into_iter()
            .map(|(k, v)| (k, Array1::from_vec(v)))
            .collect();

        let relation_embeddings: HashMap<String, Array1<f32>> = serializable
            .relation_embeddings
            .into_iter()
            .map(|(k, v)| (k, Array1::from_vec(v)))
            .collect();

        // Deserialize convolutional layer.
        // Kernels are square, so the outer Vec length is the kernel size.
        let conv_filters: Vec<Array2<f32>> = serializable
            .conv_layer
            .filters
            .into_iter()
            .map(|filter_vec| {
                let kernel_size = filter_vec.len();
                Array2::from_shape_fn((kernel_size, kernel_size), |(i, j)| filter_vec[i][j])
            })
            .collect();

        let conv_layer = ConvLayer {
            filters: conv_filters,
            biases: Array1::from_vec(serializable.conv_layer.biases),
        };

        // Deserialize fully connected layer (row-major: rows = input size).
        let fc_weights_vec = serializable.fc_layer.weights;
        let input_size = fc_weights_vec.len();
        let output_size = if input_size > 0 {
            fc_weights_vec[0].len()
        } else {
            0
        };

        let fc_weights =
            Array2::from_shape_fn((input_size, output_size), |(i, j)| fc_weights_vec[i][j]);

        let fc_layer = FCLayer {
            weights: fc_weights,
            bias: Array1::from_vec(serializable.fc_layer.bias),
        };

        // Update model state (wholesale replacement, including model_id).
        self.model_id = serializable.model_id;
        self.config = serializable.config;
        self.entity_embeddings = entity_embeddings;
        self.relation_embeddings = relation_embeddings;
        self.conv_layer = conv_layer;
        self.fc_layer = fc_layer;
        self.triples = serializable.triples;
        self.entity_to_id = serializable.entity_to_id;
        self.relation_to_id = serializable.relation_to_id;
        self.id_to_entity = serializable.id_to_entity;
        self.id_to_relation = serializable.id_to_relation;
        self.is_trained = serializable.is_trained;

        info!("Model loaded successfully");
        Ok(())
    }
788
789    fn clear(&mut self) {
790        self.entity_embeddings.clear();
791        self.relation_embeddings.clear();
792        self.triples.clear();
793        self.entity_to_id.clear();
794        self.relation_to_id.clear();
795        self.id_to_entity.clear();
796        self.id_to_relation.clear();
797        self.is_trained = false;
798    }
799
    /// Whether `train` has completed at least once since creation/clear.
    fn is_trained(&self) -> bool {
        self.is_trained
    }
803
    /// Text encoding is not supported by this model.
    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
        // TODO: Implement text encoding
        Err(anyhow!("Text encoding not implemented for ConvE"))
    }
808}
809
810#[cfg(test)]
811mod tests {
812    use super::*;
813
    #[test]
    fn test_conve_creation() {
        // A freshly constructed model has an empty vocabulary.
        let config = ConvEConfig::default();
        let model = ConvE::new(config);

        assert_eq!(model.entity_embeddings.len(), 0);
        assert_eq!(model.relation_embeddings.len(), 0);
    }
822
    #[tokio::test]
    async fn test_conve_training() {
        // Small config so the test runs quickly; dimensions must stay a
        // multiple of reshape_width (50 = 5 x 10).
        let config = ConvEConfig {
            base: ModelConfig {
                dimensions: 50, // Reduced from 100 for faster tests
                learning_rate: 0.001,
                max_epochs: 5, // Reduced from 20 for faster tests
                ..Default::default()
            },
            reshape_width: 10,
            num_filters: 8, // Reduced from 16 for faster tests
            ..Default::default()
        };

        let mut model = ConvE::new(config);

        model
            .add_triple(Triple::new(
                NamedNode::new("alice").expect("should succeed"),
                NamedNode::new("knows").expect("should succeed"),
                NamedNode::new("bob").expect("should succeed"),
            ))
            .expect("should succeed");

        model
            .add_triple(Triple::new(
                NamedNode::new("bob").expect("should succeed"),
                NamedNode::new("likes").expect("should succeed"),
                NamedNode::new("charlie").expect("should succeed"),
            ))
            .expect("should succeed");

        let stats = model.train(Some(5)).await.expect("should succeed"); // Reduced from 20 for faster tests

        // Two triples over {alice, bob, charlie} and {knows, likes}.
        assert_eq!(stats.epochs_completed, 5);
        assert!(stats.final_loss >= 0.0);
        assert_eq!(model.entity_embeddings.len(), 3);
        assert_eq!(model.relation_embeddings.len(), 2);
    }
862
    #[tokio::test]
    async fn test_conve_save_load() {
        use std::env::temp_dir;

        // Round-trip test: train, save, load into a fresh model, and verify
        // that embeddings and scores are preserved bit-for-bit (within f32
        // round-off of the simplified additive score).
        let config = ConvEConfig {
            base: ModelConfig {
                dimensions: 50,
                learning_rate: 0.001,
                max_epochs: 15,
                ..Default::default()
            },
            reshape_width: 10,
            num_filters: 8,
            kernel_size: 2,
            ..Default::default()
        };

        let mut model = ConvE::new(config);

        // Add and train
        model
            .add_triple(Triple::new(
                NamedNode::new("alice").expect("should succeed"),
                NamedNode::new("knows").expect("should succeed"),
                NamedNode::new("bob").expect("should succeed"),
            ))
            .expect("should succeed");

        model
            .add_triple(Triple::new(
                NamedNode::new("bob").expect("should succeed"),
                NamedNode::new("likes").expect("should succeed"),
                NamedNode::new("charlie").expect("should succeed"),
            ))
            .expect("should succeed");

        model.train(Some(15)).await.expect("should succeed");

        // Get embedding before save
        let emb_before = model.get_entity_embedding("alice").expect("should succeed");
        let score_before = model
            .score_triple("alice", "knows", "bob")
            .expect("should succeed");

        // Save model
        let model_path = temp_dir().join("test_conve_model.bin");
        let path_str = model_path.to_str().expect("should succeed");
        model.save(path_str).expect("should succeed");

        // Create new model and load; load() fully replaces the default config.
        let mut loaded_model = ConvE::new(ConvEConfig::default());
        loaded_model.load(path_str).expect("should succeed");

        // Verify loaded model
        assert!(loaded_model.is_trained());
        assert_eq!(loaded_model.get_entities().len(), 3);
        assert_eq!(loaded_model.get_relations().len(), 2);

        // Verify embeddings are preserved
        let emb_after = loaded_model
            .get_entity_embedding("alice")
            .expect("should succeed");
        assert_eq!(emb_before.dimensions, emb_after.dimensions);
        for i in 0..emb_before.values.len() {
            assert!((emb_before.values[i] - emb_after.values[i]).abs() < 1e-6);
        }

        // Verify scoring is consistent
        let score_after = loaded_model
            .score_triple("alice", "knows", "bob")
            .expect("should succeed");
        assert!((score_before - score_after).abs() < 1e-5);

        // Cleanup
        std::fs::remove_file(model_path).ok();
    }
939
    #[test]
    fn test_conve_load_nonexistent() {
        // Loading from a missing path must surface an error, not panic.
        let mut model = ConvE::new(ConvEConfig::default());
        let result = model.load("/nonexistent/path/model.bin");
        assert!(result.is_err());
    }
946}