use crate::{ModelConfig, ModelStats, Triple};
use anyhow::Result;
use chrono::{DateTime, Utc};
#[allow(unused_imports)]
use scirs2_core::random::{Random, RngExt};
use std::collections::{HashMap, HashSet};
use uuid::Uuid;
/// Shared state for a triple-based model: string <-> integer id vocabularies
/// for entities and relations, the stored triples, and training metadata.
#[derive(Debug, Clone)]
pub struct BaseModel {
/// Configuration passed to `new`; its `dimensions` value is reported by `get_stats`.
pub config: ModelConfig,
/// Identifier generated with `Uuid::new_v4()` when the model is constructed.
pub model_id: Uuid,
/// Entity string -> integer id. `add_triple` assigns ids in first-seen order,
/// using the current map size as the next id (so ids start at 0).
pub entity_to_id: HashMap<String, usize>,
/// Reverse of `entity_to_id`: integer id -> entity string.
pub id_to_entity: HashMap<usize, String>,
/// Relation string -> integer id, assigned the same way as entity ids.
pub relation_to_id: HashMap<String, usize>,
/// Reverse of `relation_to_id`: integer id -> relation string.
pub id_to_relation: HashMap<usize, String>,
/// Stored triples as `(subject_id, predicate_id, object_id)`, in the order
/// `add_triple` first saw them.
pub triples: Vec<(usize, usize, usize)>,
/// Set view of `triples`, used for duplicate detection and membership checks.
pub positive_triples: HashSet<(usize, usize, usize)>,
/// `true` after `mark_trained`; reset to `false` by `clear`.
pub is_trained: bool,
/// UTC timestamp taken when the model was constructed.
pub creation_time: DateTime<Utc>,
/// UTC timestamp of the latest `mark_trained` call, or `None` if there has
/// been none since construction or the last `clear`.
pub last_training_time: Option<DateTime<Utc>>,
}
impl BaseModel {
/// Creates an empty, untrained model that takes ownership of `config`.
///
/// A fresh v4 UUID becomes `model_id`, `creation_time` is set to the current
/// UTC time, and every vocabulary and triple container starts out empty.
pub fn new(config: ModelConfig) -> Self {
    let model_id = Uuid::new_v4();
    let entity_to_id = HashMap::new();
    let id_to_entity = HashMap::new();
    let relation_to_id = HashMap::new();
    let id_to_relation = HashMap::new();
    let triples = Vec::new();
    let positive_triples = HashSet::new();
    let creation_time = Utc::now();

    Self {
        config,
        model_id,
        entity_to_id,
        id_to_entity,
        relation_to_id,
        id_to_relation,
        triples,
        positive_triples,
        is_trained: false,
        creation_time,
        last_training_time: None,
    }
}
/// Registers `triple` with the model.
///
/// The subject, predicate and object are converted to strings and mapped to
/// integer ids, creating new ids for names not seen before. The resulting id
/// triple is stored only if it is not already present, so adding the same
/// triple twice is a no-op the second time.
///
/// # Errors
///
/// The current implementation always returns `Ok(())`.
pub fn add_triple(&mut self, triple: Triple) -> Result<()> {
    // Stringify all three components up front, before any id is allocated.
    let [subject, predicate, object] = [
        triple.subject.to_string(),
        triple.predicate.to_string(),
        triple.object.to_string(),
    ];

    // Allocation order matters for the ids handed out: subject, then object,
    // then predicate.
    let subject_id = self.get_or_create_entity_id(subject);
    let object_id = self.get_or_create_entity_id(object);
    let predicate_id = self.get_or_create_relation_id(predicate);

    let key = (subject_id, predicate_id, object_id);
    // `insert` returns `true` only for a value that was not yet in the set.
    if self.positive_triples.insert(key) {
        self.triples.push(key);
    }
    Ok(())
}
/// Returns the id of `entity`, registering it first if it is unknown.
///
/// A new entity receives the current vocabulary size as its id and is
/// recorded in both `entity_to_id` and `id_to_entity`.
fn get_or_create_entity_id(&mut self, entity: String) -> usize {
    if let Some(&existing) = self.entity_to_id.get(&entity) {
        return existing;
    }

    let fresh_id = self.entity_to_id.len();
    self.id_to_entity.insert(fresh_id, entity.clone());
    self.entity_to_id.insert(entity, fresh_id);
    fresh_id
}
/// Returns the id of `relation`, registering it first if it is unknown.
///
/// A new relation receives the current vocabulary size as its id and is
/// recorded in both `relation_to_id` and `id_to_relation`.
fn get_or_create_relation_id(&mut self, relation: String) -> usize {
    match self.relation_to_id.get(&relation).copied() {
        Some(known) => known,
        None => {
            let next_id = self.relation_to_id.len();
            self.id_to_relation.insert(next_id, relation.clone());
            self.relation_to_id.insert(relation, next_id);
            next_id
        }
    }
}
/// Looks up the integer id of `entity`; `None` if the entity is not registered.
pub fn get_entity_id(&self, entity: &str) -> Option<usize> {
    let id = *self.entity_to_id.get(entity)?;
    Some(id)
}
/// Looks up the integer id of `relation`; `None` if the relation is not registered.
pub fn get_relation_id(&self, relation: &str) -> Option<usize> {
    let vocabulary = &self.relation_to_id;
    let id = *vocabulary.get(relation)?;
    Some(id)
}
pub fn get_entity(&self, id: usize) -> Option<&String> {
self.id_to_entity.get(&id)
}
/// Returns the relation string registered under `id`, or `None` for an unknown id.
pub fn get_relation(&self, id: usize) -> Option<&String> {
    let names = &self.id_to_relation;
    names.get(&id)
}
pub fn num_entities(&self) -> usize {
self.entity_to_id.len()
}
pub fn num_relations(&self) -> usize {
self.relation_to_id.len()
}
pub fn num_triples(&self) -> usize {
self.triples.len()
}
/// Returns an owned copy of every registered entity string.
///
/// The order follows the iteration order of the underlying `HashMap` and is
/// therefore unspecified; it does not correspond to id order.
pub fn get_entities(&self) -> Vec<String> {
    let mut names = Vec::with_capacity(self.entity_to_id.len());
    names.extend(self.entity_to_id.keys().cloned());
    names
}
/// Returns an owned copy of every registered relation string.
///
/// The order follows the iteration order of the underlying `HashMap` and is
/// therefore unspecified; it does not correspond to id order.
pub fn get_relations(&self) -> Vec<String> {
    let mut names = Vec::with_capacity(self.relation_to_id.len());
    for name in self.relation_to_id.keys() {
        names.push(name.clone());
    }
    names
}
/// Reports whether the id triple `(subject_id, predicate_id, object_id)` is
/// present in `positive_triples`.
pub fn has_triple(&self, subject_id: usize, predicate_id: usize, object_id: usize) -> bool {
    let key = (subject_id, predicate_id, object_id);
    self.positive_triples.contains(&key)
}
/// Draws up to `num_samples` negative triples by corrupting stored triples.
///
/// Each attempt picks a random stored triple, replaces either its subject or
/// its object (decided by `random_bool_with_chance(0.5)`) with a random
/// entity id, and keeps the result only if it is not a known positive triple.
/// The returned samples may contain duplicates.
///
/// Returns an empty vector when `num_samples` is zero or when the model has no
/// stored triples or no entities, because no corruption can be produced then.
/// The number of attempts is capped so the call always terminates: if (almost)
/// every possible corruption is itself a positive triple, the result may hold
/// fewer than `num_samples` elements. Otherwise it holds exactly
/// `num_samples`.
pub fn generate_negative_samples<R>(
    &self,
    num_samples: usize,
    rng: &mut Random<R>,
) -> Vec<(usize, usize, usize)>
where
    R: scirs2_core::random::Rng,
{
    // Retry budget. Generous enough that sparse graphs never hit it, while
    // still guaranteeing termination when no negative triple exists at all
    // (for example a single entity with a self-loop).
    const ATTEMPTS_PER_SAMPLE: usize = 100;
    const MIN_ATTEMPTS: usize = 1_000;

    let num_entities = self.num_entities();
    // Without triples there is nothing to corrupt, and without entities the
    // replacement range `0..0` would be empty.
    if num_samples == 0 || self.triples.is_empty() || num_entities == 0 {
        return Vec::new();
    }

    let max_attempts = num_samples
        .saturating_mul(ATTEMPTS_PER_SAMPLE)
        .max(MIN_ATTEMPTS);
    let mut negative_samples = Vec::with_capacity(num_samples);
    let mut attempts = 0usize;

    while negative_samples.len() < num_samples && attempts < max_attempts {
        attempts += 1;

        // The index is drawn from `0..len` with `len > 0`, so it is in bounds.
        let idx = rng.random_range(0..self.triples.len());
        let (s, p, o) = self.triples[idx];

        let corrupt_subject = rng.random_bool_with_chance(0.5);
        let replacement = rng.random_range(0..num_entities);
        let candidate = if corrupt_subject {
            (replacement, p, o)
        } else {
            (s, p, replacement)
        };

        // Reject corruptions that happen to be true (positive) triples.
        if !self.has_triple(candidate.0, candidate.1, candidate.2) {
            negative_samples.push(candidate);
        }
    }
    negative_samples
}
/// Builds a `ModelStats` snapshot of the model's current state.
///
/// `model_type` is copied verbatim into the snapshot; the counts, the
/// configured `dimensions`, the training flag and both timestamps are read
/// from `self`.
pub fn get_stats(&self, model_type: &str) -> ModelStats {
    let num_entities = self.num_entities();
    let num_relations = self.num_relations();
    let num_triples = self.num_triples();

    ModelStats {
        num_entities,
        num_relations,
        num_triples,
        dimensions: self.config.dimensions,
        is_trained: self.is_trained,
        model_type: String::from(model_type),
        creation_time: self.creation_time,
        last_training_time: self.last_training_time,
    }
}
/// Empties both vocabularies and all stored triples and resets the training
/// status.
///
/// `config`, `model_id` and `creation_time` are left untouched.
pub fn clear(&mut self) {
    // Training status.
    self.is_trained = false;
    self.last_training_time = None;

    // Triple storage.
    self.triples.clear();
    self.positive_triples.clear();

    // Entity and relation vocabularies, both directions.
    self.entity_to_id.clear();
    self.id_to_entity.clear();
    self.relation_to_id.clear();
    self.id_to_relation.clear();
}
/// Flags the model as trained and records the current UTC time as the last
/// training time.
pub fn mark_trained(&mut self) {
    let finished_at = Utc::now();
    self.last_training_time = Some(finished_at);
    self.is_trained = true;
}
}