oxirs_embed/models/
base.rs

1//! Base functionality shared across embedding models
2
3use crate::{ModelConfig, ModelStats, Triple};
4use anyhow::Result;
5use chrono::{DateTime, Utc};
6#[allow(unused_imports)]
7use scirs2_core::random::{Random, Rng};
8use std::collections::{HashMap, HashSet};
9use uuid::Uuid;
10
/// Core data structures and functionality shared by all embedding models.
///
/// Holds the string <-> index vocabularies for entities and relations, the
/// indexed training triples, and basic lifecycle metadata. The vocabularies
/// and triple collections are populated by `add_triple` and emptied by
/// `clear`.
#[derive(Debug, Clone)]
pub struct BaseModel {
    /// Model configuration, stored as passed to `new`.
    pub config: ModelConfig,
    /// Unique model identifier, generated with `Uuid::new_v4()` in `new`.
    pub model_id: Uuid,
    /// Entity string -> index mapping. A newly seen entity is assigned the
    /// size of this map at insertion time as its index.
    pub entity_to_id: HashMap<String, usize>,
    /// Index -> entity string mapping (inverse of `entity_to_id`).
    pub id_to_entity: HashMap<usize, String>,
    /// Relation string -> index mapping. A newly seen relation is assigned the
    /// size of this map at insertion time as its index.
    pub relation_to_id: HashMap<String, usize>,
    /// Index -> relation string mapping (inverse of `relation_to_id`).
    pub id_to_relation: HashMap<usize, String>,
    /// Training triples as `(subject_id, predicate_id, object_id)`, in the
    /// order they were first added; `add_triple` skips duplicates.
    pub triples: Vec<(usize, usize, usize)>,
    /// Set of all positive triples for fast membership lookup; `add_triple`
    /// keeps it in step with `triples`.
    pub positive_triples: HashSet<(usize, usize, usize)>,
    /// Whether the model has been trained; set by `mark_trained`, reset by
    /// `clear`.
    pub is_trained: bool,
    /// Model creation time (UTC), captured in `new`.
    pub creation_time: DateTime<Utc>,
    /// Last training time (UTC); `None` until `mark_trained` is called and
    /// again after `clear`.
    pub last_training_time: Option<DateTime<Utc>>,
}
37
38impl BaseModel {
39    /// Create a new base model
40    pub fn new(config: ModelConfig) -> Self {
41        Self {
42            model_id: Uuid::new_v4(),
43            config,
44            entity_to_id: HashMap::new(),
45            id_to_entity: HashMap::new(),
46            relation_to_id: HashMap::new(),
47            id_to_relation: HashMap::new(),
48            triples: Vec::new(),
49            positive_triples: HashSet::new(),
50            is_trained: false,
51            creation_time: Utc::now(),
52            last_training_time: None,
53        }
54    }
55
56    /// Add a triple to the model
57    pub fn add_triple(&mut self, triple: Triple) -> Result<()> {
58        let subject_str = triple.subject.to_string();
59        let predicate_str = triple.predicate.to_string();
60        let object_str = triple.object.to_string();
61
62        // Get or create entity IDs
63        let subject_id = self.get_or_create_entity_id(subject_str);
64        let object_id = self.get_or_create_entity_id(object_str);
65
66        // Get or create relation ID
67        let predicate_id = self.get_or_create_relation_id(predicate_str);
68
69        // Add triple
70        let triple_ids = (subject_id, predicate_id, object_id);
71        if !self.positive_triples.contains(&triple_ids) {
72            self.triples.push(triple_ids);
73            self.positive_triples.insert(triple_ids);
74        }
75
76        Ok(())
77    }
78
79    /// Get or create entity ID
80    fn get_or_create_entity_id(&mut self, entity: String) -> usize {
81        if let Some(&id) = self.entity_to_id.get(&entity) {
82            id
83        } else {
84            let id = self.entity_to_id.len();
85            self.entity_to_id.insert(entity.clone(), id);
86            self.id_to_entity.insert(id, entity);
87            id
88        }
89    }
90
91    /// Get or create relation ID
92    fn get_or_create_relation_id(&mut self, relation: String) -> usize {
93        if let Some(&id) = self.relation_to_id.get(&relation) {
94            id
95        } else {
96            let id = self.relation_to_id.len();
97            self.relation_to_id.insert(relation.clone(), id);
98            self.id_to_relation.insert(id, relation);
99            id
100        }
101    }
102
103    /// Get entity ID
104    pub fn get_entity_id(&self, entity: &str) -> Option<usize> {
105        self.entity_to_id.get(entity).copied()
106    }
107
108    /// Get relation ID
109    pub fn get_relation_id(&self, relation: &str) -> Option<usize> {
110        self.relation_to_id.get(relation).copied()
111    }
112
113    /// Get entity string from ID
114    pub fn get_entity(&self, id: usize) -> Option<&String> {
115        self.id_to_entity.get(&id)
116    }
117
118    /// Get relation string from ID
119    pub fn get_relation(&self, id: usize) -> Option<&String> {
120        self.id_to_relation.get(&id)
121    }
122
123    /// Get number of entities
124    pub fn num_entities(&self) -> usize {
125        self.entity_to_id.len()
126    }
127
128    /// Get number of relations
129    pub fn num_relations(&self) -> usize {
130        self.relation_to_id.len()
131    }
132
133    /// Get number of triples
134    pub fn num_triples(&self) -> usize {
135        self.triples.len()
136    }
137
138    /// Get all entity strings
139    pub fn get_entities(&self) -> Vec<String> {
140        self.entity_to_id.keys().cloned().collect()
141    }
142
143    /// Get all relation strings
144    pub fn get_relations(&self) -> Vec<String> {
145        self.relation_to_id.keys().cloned().collect()
146    }
147
148    /// Check if a triple exists in the knowledge base
149    pub fn has_triple(&self, subject_id: usize, predicate_id: usize, object_id: usize) -> bool {
150        self.positive_triples
151            .contains(&(subject_id, predicate_id, object_id))
152    }
153
154    /// Generate negative samples for training
155    pub fn generate_negative_samples<R>(
156        &self,
157        num_samples: usize,
158        rng: &mut Random<R>,
159    ) -> Vec<(usize, usize, usize)>
160    where
161        R: scirs2_core::random::RngCore,
162    {
163        let mut negative_samples = Vec::new();
164        let num_entities = self.num_entities();
165
166        while negative_samples.len() < num_samples {
167            // Choose a random positive triple
168            if !self.triples.is_empty() {
169                let idx = rng.random_range(0..self.triples.len());
170                let &(s, p, o) = &self.triples[idx];
171
172                // Corrupt either subject or object
173                let corrupt_subject = rng.random_bool_with_chance(0.5);
174
175                let negative_triple = if corrupt_subject {
176                    let new_subject = rng.random_range(0..num_entities);
177                    (new_subject, p, o)
178                } else {
179                    let new_object = rng.random_range(0..num_entities);
180                    (s, p, new_object)
181                };
182
183                // Make sure it's actually negative
184                if !self.has_triple(negative_triple.0, negative_triple.1, negative_triple.2) {
185                    negative_samples.push(negative_triple);
186                }
187            }
188        }
189
190        negative_samples
191    }
192
193    /// Get model statistics
194    pub fn get_stats(&self, model_type: &str) -> ModelStats {
195        ModelStats {
196            num_entities: self.num_entities(),
197            num_relations: self.num_relations(),
198            num_triples: self.num_triples(),
199            dimensions: self.config.dimensions,
200            is_trained: self.is_trained,
201            model_type: model_type.to_string(),
202            creation_time: self.creation_time,
203            last_training_time: self.last_training_time,
204        }
205    }
206
207    /// Clear all data
208    pub fn clear(&mut self) {
209        self.entity_to_id.clear();
210        self.id_to_entity.clear();
211        self.relation_to_id.clear();
212        self.id_to_relation.clear();
213        self.triples.clear();
214        self.positive_triples.clear();
215        self.is_trained = false;
216        self.last_training_time = None;
217    }
218
219    /// Mark model as trained
220    pub fn mark_trained(&mut self) {
221        self.is_trained = true;
222        self.last_training_time = Some(Utc::now());
223    }
224}