oxirs_embed/models/
base.rs1use crate::{ModelConfig, ModelStats, Triple};
4use anyhow::Result;
5use chrono::{DateTime, Utc};
6#[allow(unused_imports)]
7use scirs2_core::random::{Random, Rng};
8use std::collections::{HashMap, HashSet};
9use uuid::Uuid;
10
11#[derive(Debug, Clone)]
13pub struct BaseModel {
14 pub config: ModelConfig,
16 pub model_id: Uuid,
18 pub entity_to_id: HashMap<String, usize>,
20 pub id_to_entity: HashMap<usize, String>,
22 pub relation_to_id: HashMap<String, usize>,
24 pub id_to_relation: HashMap<usize, String>,
26 pub triples: Vec<(usize, usize, usize)>,
28 pub positive_triples: HashSet<(usize, usize, usize)>,
30 pub is_trained: bool,
32 pub creation_time: DateTime<Utc>,
34 pub last_training_time: Option<DateTime<Utc>>,
36}
37
38impl BaseModel {
39 pub fn new(config: ModelConfig) -> Self {
41 Self {
42 model_id: Uuid::new_v4(),
43 config,
44 entity_to_id: HashMap::new(),
45 id_to_entity: HashMap::new(),
46 relation_to_id: HashMap::new(),
47 id_to_relation: HashMap::new(),
48 triples: Vec::new(),
49 positive_triples: HashSet::new(),
50 is_trained: false,
51 creation_time: Utc::now(),
52 last_training_time: None,
53 }
54 }
55
56 pub fn add_triple(&mut self, triple: Triple) -> Result<()> {
58 let subject_str = triple.subject.to_string();
59 let predicate_str = triple.predicate.to_string();
60 let object_str = triple.object.to_string();
61
62 let subject_id = self.get_or_create_entity_id(subject_str);
64 let object_id = self.get_or_create_entity_id(object_str);
65
66 let predicate_id = self.get_or_create_relation_id(predicate_str);
68
69 let triple_ids = (subject_id, predicate_id, object_id);
71 if !self.positive_triples.contains(&triple_ids) {
72 self.triples.push(triple_ids);
73 self.positive_triples.insert(triple_ids);
74 }
75
76 Ok(())
77 }
78
79 fn get_or_create_entity_id(&mut self, entity: String) -> usize {
81 if let Some(&id) = self.entity_to_id.get(&entity) {
82 id
83 } else {
84 let id = self.entity_to_id.len();
85 self.entity_to_id.insert(entity.clone(), id);
86 self.id_to_entity.insert(id, entity);
87 id
88 }
89 }
90
91 fn get_or_create_relation_id(&mut self, relation: String) -> usize {
93 if let Some(&id) = self.relation_to_id.get(&relation) {
94 id
95 } else {
96 let id = self.relation_to_id.len();
97 self.relation_to_id.insert(relation.clone(), id);
98 self.id_to_relation.insert(id, relation);
99 id
100 }
101 }
102
103 pub fn get_entity_id(&self, entity: &str) -> Option<usize> {
105 self.entity_to_id.get(entity).copied()
106 }
107
108 pub fn get_relation_id(&self, relation: &str) -> Option<usize> {
110 self.relation_to_id.get(relation).copied()
111 }
112
113 pub fn get_entity(&self, id: usize) -> Option<&String> {
115 self.id_to_entity.get(&id)
116 }
117
118 pub fn get_relation(&self, id: usize) -> Option<&String> {
120 self.id_to_relation.get(&id)
121 }
122
123 pub fn num_entities(&self) -> usize {
125 self.entity_to_id.len()
126 }
127
128 pub fn num_relations(&self) -> usize {
130 self.relation_to_id.len()
131 }
132
133 pub fn num_triples(&self) -> usize {
135 self.triples.len()
136 }
137
138 pub fn get_entities(&self) -> Vec<String> {
140 self.entity_to_id.keys().cloned().collect()
141 }
142
143 pub fn get_relations(&self) -> Vec<String> {
145 self.relation_to_id.keys().cloned().collect()
146 }
147
148 pub fn has_triple(&self, subject_id: usize, predicate_id: usize, object_id: usize) -> bool {
150 self.positive_triples
151 .contains(&(subject_id, predicate_id, object_id))
152 }
153
154 pub fn generate_negative_samples<R>(
156 &self,
157 num_samples: usize,
158 rng: &mut Random<R>,
159 ) -> Vec<(usize, usize, usize)>
160 where
161 R: scirs2_core::random::RngCore,
162 {
163 let mut negative_samples = Vec::new();
164 let num_entities = self.num_entities();
165
166 while negative_samples.len() < num_samples {
167 if !self.triples.is_empty() {
169 let idx = rng.random_range(0..self.triples.len());
170 let &(s, p, o) = &self.triples[idx];
171
172 let corrupt_subject = rng.random_bool_with_chance(0.5);
174
175 let negative_triple = if corrupt_subject {
176 let new_subject = rng.random_range(0..num_entities);
177 (new_subject, p, o)
178 } else {
179 let new_object = rng.random_range(0..num_entities);
180 (s, p, new_object)
181 };
182
183 if !self.has_triple(negative_triple.0, negative_triple.1, negative_triple.2) {
185 negative_samples.push(negative_triple);
186 }
187 }
188 }
189
190 negative_samples
191 }
192
193 pub fn get_stats(&self, model_type: &str) -> ModelStats {
195 ModelStats {
196 num_entities: self.num_entities(),
197 num_relations: self.num_relations(),
198 num_triples: self.num_triples(),
199 dimensions: self.config.dimensions,
200 is_trained: self.is_trained,
201 model_type: model_type.to_string(),
202 creation_time: self.creation_time,
203 last_training_time: self.last_training_time,
204 }
205 }
206
207 pub fn clear(&mut self) {
209 self.entity_to_id.clear();
210 self.id_to_entity.clear();
211 self.relation_to_id.clear();
212 self.id_to_relation.clear();
213 self.triples.clear();
214 self.positive_triples.clear();
215 self.is_trained = false;
216 self.last_training_time = None;
217 }
218
219 pub fn mark_trained(&mut self) {
221 self.is_trained = true;
222 self.last_training_time = Some(Utc::now());
223 }
224}