//! High-level AI engine for knowledge graphs: embeddings, GNNs, vector
//! search, training, entity resolution, relation extraction, and temporal
//! reasoning.

pub mod embeddings;
pub mod entity_resolution;
pub mod gnn;
pub mod gpu_monitor;
pub mod neural;
pub mod relation_extraction;
pub mod temporal_reasoning;
pub mod training;
pub mod vector_store;
use crate::model::Triple;
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
pub use embeddings::{
create_embedding_model, ComplEx, DistMult, EmbeddingConfig, EmbeddingModelType,
KnowledgeGraphEmbedding, TransE,
};
pub use gnn::{
Aggregation, GnnArchitecture, GnnConfig, GraphNeuralNetwork, LayerType, MessagePassingType,
};
pub use training::{
DefaultTrainer, LossFunction, Optimizer, Trainer, TrainingConfig, TrainingMetrics,
};
pub use vector_store::{SimilarityMetric, VectorIndex, VectorQuery, VectorStore};
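
/// Top-level configuration for the AI engine, aggregating the per-subsystem
/// configs (embeddings, vector store, training, GPU, cache).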
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AiConfig {
pub enable_gnn: bool,
pub embedding_config: EmbeddingConfig,
pub vector_store_config: VectorStoreConfig,
pub training_config: TrainingConfig,
pub gpu_config: GpuConfig,
pub cache_config: CacheConfig,
}
impl Default for AiConfig {
fn default() -> Self {
Self {
enable_gnn: true,
embedding_config: EmbeddingConfig::default(),
vector_store_config: VectorStoreConfig::default(),
training_config: TrainingConfig::default(),
gpu_config: GpuConfig::default(),
cache_config: CacheConfig::default(),
}
}
}
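
/// Settings for the entity vector store: dimensionality, similarity metric,
/// index family, capacity, and approximate-nearest-neighbour behaviour.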
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreConfig {
pub dimension: usize,
pub metric: SimilarityMetric,
pub index_type: IndexType,
pub max_vectors: usize,
pub enable_ann: bool,
pub ann_neighbors: usize,
}
impl Default for VectorStoreConfig {
fn default() -> Self {
Self {
dimension: 128,
metric: SimilarityMetric::Cosine,
index_type: IndexType::HierarchicalNavigableSmallWorld,
max_vectors: 10_000_000,
enable_ann: true,
ann_neighbors: 16,
}
}
}
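
/// Vector index families: exact flat scan, inverted file (IVF),
/// locality-sensitive hashing (LSH), HNSW, and product quantization, with
/// their tuning parameters where applicable.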
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum IndexType {
Flat,
InvertedFile { clusters: usize },
LocalitySensitiveHashing {
hash_tables: usize,
hash_length: usize,
},
HierarchicalNavigableSmallWorld,
ProductQuantization { subquantizers: usize, bits: usize },
}
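
/// GPU execution settings: device selection, memory pool size, batch size,
/// and mixed-precision training.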
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuConfig {
pub enabled: bool,
pub device_id: u32,
pub memory_pool_mb: usize,
pub batch_size: usize,
pub mixed_precision: bool,
}
impl Default for GpuConfig {
fn default() -> Self {
Self {
enabled: true,
device_id: 0,
memory_pool_mb: 4096,
batch_size: 1024,
mixed_precision: true,
}
}
}
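
/// On-disk cache settings for AI artifacts such as embeddings and
/// intermediate results.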
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
pub enabled: bool,
pub cache_dir: String,
pub max_size_mb: usize,
pub ttl_seconds: u64,
pub compression: bool,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
enabled: true,
cache_dir: "/tmp/oxirs/ai_cache".to_string(),
max_size_mb: 10240,
ttl_seconds: 86400,
compression: true,
}
}
}
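
/// Central coordinator for the AI subsystem: owns the vector store, the
/// registered embedding models, an optional GNN, the trainer, and the
/// entity-resolution, relation-extraction, and temporal-reasoning services.
///
/// A minimal usage sketch (not compiled; the `oxirs::ai` path and `my_model`
/// value are assumptions for illustration):
///
/// ```ignore
/// use oxirs::ai::{AiConfig, AiEngine};
///
/// let mut engine = AiEngine::new(AiConfig::default())?;
/// engine.add_embedding_model("transe".to_string(), my_model).await?;
/// let vectors = engine.generate_embeddings("transe", &triples).await?;
/// ```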
pub struct AiEngine {
#[allow(dead_code)]
config: AiConfig,
gnn: Option<Arc<dyn GraphNeuralNetwork>>,
embeddings: HashMap<String, Arc<dyn KnowledgeGraphEmbedding>>,
vector_store: Arc<dyn VectorStore>,
trainer: Arc<Mutex<Box<dyn Trainer>>>,
entity_resolver: Arc<entity_resolution::EntityResolver>,
relation_extractor: Arc<relation_extraction::RelationExtractor>,
temporal_reasoner: Arc<temporal_reasoning::TemporalReasoner>,
}
impl AiEngine {
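/// Builds an engine from `config`, translating the public `IndexType` into
/// the internal vector-store index settings and wiring up the trainer and
/// the reasoning services.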
pub fn new(config: AiConfig) -> Result<Self> {
let vs_config = vector_store::VectorStoreConfig {
dimension: config.vector_store_config.dimension,
default_metric: config.vector_store_config.metric,
index_type: match config.vector_store_config.index_type {
IndexType::Flat => vector_store::IndexType::Flat,
IndexType::HierarchicalNavigableSmallWorld => vector_store::IndexType::HNSW {
max_connections: 16,
ef_construction: 200,
ef_search: 50,
},
IndexType::InvertedFile { clusters } => vector_store::IndexType::IVF {
num_clusters: clusters,
num_probes: 8,
},
IndexType::LocalitySensitiveHashing {
hash_tables,
hash_length,
} => vector_store::IndexType::LSH {
num_tables: hash_tables,
hash_length,
},
IndexType::ProductQuantization {
subquantizers,
bits,
} => vector_store::IndexType::PQ {
num_subquantizers: subquantizers,
bits_per_subquantizer: bits,
},
},
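// The public ANN flag doubles as the toggle for the store's result cache.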
enable_cache: config.vector_store_config.enable_ann,
cache_size: config.vector_store_config.max_vectors.min(10_000),
cache_ttl: 3600,
batch_size: 1000,
};
let vector_store = vector_store::create_vector_store(&vs_config)?;
let trainer = Arc::new(Mutex::new(Box::new(training::DefaultTrainer::new(
config.training_config.clone(),
)) as Box<dyn Trainer>));
let entity_resolver = Arc::new(entity_resolution::EntityResolver::new(&config)?);
let relation_extractor = Arc::new(relation_extraction::RelationExtractor::new(&config)?);
let temporal_reasoner = Arc::new(temporal_reasoning::TemporalReasoner::new(&config)?);
Ok(Self {
config,
gnn: None,
embeddings: HashMap::new(),
vector_store,
trainer,
entity_resolver,
relation_extractor,
temporal_reasoner,
})
}
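
/// Creates a graph neural network from `gnn_config` and attaches it to the
/// engine, replacing any previously initialized GNN.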
pub async fn initialize_gnn(&mut self, gnn_config: GnnConfig) -> Result<()> {
let gnn = gnn::create_gnn(gnn_config)?;
self.gnn = Some(gnn);
Ok(())
}
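
/// Registers an embedding model under `name` for later inference, training,
/// and evaluation; an existing model of the same name is replaced.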
pub async fn add_embedding_model(
&mut self,
name: String,
model: Arc<dyn KnowledgeGraphEmbedding>,
) -> Result<()> {
self.embeddings.insert(name, model);
Ok(())
}
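
/// Generates embedding vectors for `triples` using the named model.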
pub async fn generate_embeddings(
&self,
model_name: &str,
triples: &[Triple],
) -> Result<Vec<Vec<f32>>> {
let model = self
.embeddings
.get(model_name)
.ok_or_else(|| anyhow!("Embedding model not found: {}", model_name))?;
model.generate_embeddings(triples).await
}
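
/// Runs a k-nearest-neighbour query against the vector store and returns up
/// to `top_k` `(id, similarity)` pairs under the store's default metric.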
pub async fn find_similar_entities(
&self,
entity_vector: &[f32],
top_k: usize,
) -> Result<Vec<(String, f32)>> {
let query = VectorQuery {
vector: entity_vector.to_vec(),
k: top_k,
include_metadata: true,
metric: None,
filters: None,
min_similarity: None,
};
self.vector_store.search(&query).await
}
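
/// Uses the named model to score candidate links among `entities` and
/// `relations`, returning `(head, relation, tail, score)` tuples.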
pub async fn predict_links(
&self,
model_name: &str,
entities: &[String],
relations: &[String],
) -> Result<Vec<(String, String, String, f32)>> {
let model = self
.embeddings
.get(model_name)
.ok_or_else(|| anyhow!("Embedding model not found: {}", model_name))?;
model.predict_links(entities, relations).await
}
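
/// Clusters co-referent entities found in `entities` via the entity resolver.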
pub async fn resolve_entities(
&self,
entities: &[Triple],
) -> Result<Vec<entity_resolution::EntityCluster>> {
self.entity_resolver.resolve_entities(entities).await
}
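
/// Extracts candidate relations from free text via the relation extractor.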
pub async fn extract_relations_from_text(
&self,
text: &str,
) -> Result<Vec<relation_extraction::ExtractedRelation>> {
self.relation_extractor.extract_relations(text).await
}
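
/// Answers a temporal query via the temporal reasoner.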
pub async fn temporal_reasoning(
&self,
query: &temporal_reasoning::TemporalQuery,
) -> Result<temporal_reasoning::TemporalResult> {
self.temporal_reasoner.reason(query).await
}
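
/// Trains the named model on `training_data`, validating against
/// `validation_data`. The shared trainer is locked for the whole run, so
/// concurrent training calls serialize.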
pub async fn train_embedding_model(
&self,
model_name: &str,
training_data: &[Triple],
validation_data: &[Triple],
) -> Result<TrainingMetrics> {
let model = self
.embeddings
.get(model_name)
.ok_or_else(|| anyhow!("Embedding model not found: {}", model_name))?;
let trainer = self.trainer.clone();
let model = model.clone();
let training_data = training_data.to_vec();
let validation_data = validation_data.to_vec();
let mut trainer_guard = trainer.lock().await;
trainer_guard
.train_embedding_model(model, &training_data, &validation_data)
.await
}
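
/// Evaluates the named model on `test_data`, returning ranking metrics and
/// a sampled link-prediction accuracy.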
pub async fn evaluate_model(
&self,
model_name: &str,
test_data: &[Triple],
) -> Result<EvaluationMetrics> {
let model = self
.embeddings
.get(model_name)
.ok_or_else(|| anyhow!("Embedding model not found: {}", model_name))?;
EvaluationMetrics::evaluate(model.as_ref(), test_data).await
}
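
/// Collects a snapshot of engine state from the vector store and the global
/// GPU monitor; GPU utilization falls back to 0.0 if the monitor lock is
/// unavailable.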
pub async fn get_statistics(&self) -> Result<AiStatistics> {
let vs_stats = self.vector_store.get_statistics().await?;
let gpu_monitor = gpu_monitor::GpuMonitor::global();
let gpu_utilization = gpu_monitor
.lock()
.map(|monitor| monitor.get_utilization())
.unwrap_or(0.0);
Ok(AiStatistics {
gnn_enabled: self.gnn.is_some(),
embedding_models: self.embeddings.len(),
vector_store_size: self.vector_store.size(),
cache_hit_rate: vs_stats.cache_hit_rate,
gpu_utilization,
})
}
}
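
/// Aggregated evaluation results: filtered ranking metrics (MRR, Hits@k),
/// sampled link-prediction accuracy, and placeholder slots for
/// entity-resolution and relation-extraction quality.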
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationMetrics {
pub mrr: f32,
pub hits_at_1: f32,
pub hits_at_3: f32,
pub hits_at_10: f32,
pub link_prediction_accuracy: f32,
pub entity_resolution_f1: f32,
pub relation_extraction_precision: f32,
pub relation_extraction_recall: f32,
}
impl EvaluationMetrics {
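/// Computes filtered MRR and Hits@{1, 3, 10} over `test_data` via
/// `embeddings::evaluation::compute_kg_metrics`, plus a sampled
/// link-prediction accuracy.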
pub async fn evaluate(
model: &dyn KnowledgeGraphEmbedding,
test_data: &[Triple],
) -> Result<Self> {
let test_triples: Vec<(String, String, String)> = test_data
.iter()
.map(|t| {
(
t.subject().to_string(),
t.predicate().to_string(),
t.object().to_string(),
)
})
.collect();
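// Filtered ranking normally excludes every known true triple; only the
// test split is available here, so it stands in for the full graph.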
let all_triples = test_triples.clone();
let k_values = vec![1, 3, 10];
let kg_metrics = embeddings::evaluation::compute_kg_metrics(
model,
&test_triples,
&all_triples,
&k_values,
)
.await?;
let link_prediction_accuracy =
Self::compute_link_prediction_accuracy(model, &test_triples).await?;
let mrr = kg_metrics.mrr_filtered;
let hits_at_1 = *kg_metrics.hits_at_k_filtered.get(&1).unwrap_or(&0.0);
let hits_at_3 = *kg_metrics.hits_at_k_filtered.get(&3).unwrap_or(&0.0);
let hits_at_10 = *kg_metrics.hits_at_k_filtered.get(&10).unwrap_or(&0.0);
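// Entity-resolution and relation-extraction quality need their own labeled
// evaluation data; report zeros as placeholders until such harnesses exist.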
let entity_resolution_f1 = 0.0;
let relation_extraction_precision = 0.0;
let relation_extraction_recall = 0.0;
Ok(Self {
mrr,
hits_at_1,
hits_at_3,
hits_at_10,
link_prediction_accuracy,
entity_resolution_f1,
relation_extraction_precision,
relation_extraction_recall,
})
}
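
/// Estimates link-prediction accuracy on a sample of up to 100 test triples
/// by comparing each true triple's score against one randomly corrupted
/// negative (head or tail swapped for a random entity).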
async fn compute_link_prediction_accuracy(
model: &dyn KnowledgeGraphEmbedding,
test_triples: &[(String, String, String)],
) -> Result<f32> {
if test_triples.is_empty() {
return Ok(0.0);
}
let sample_size = test_triples.len().min(100);
let mut correct = 0;
let entities: std::collections::HashSet<String> = test_triples
.iter()
.flat_map(|(h, _, t)| vec![h.clone(), t.clone()])
.collect();
let entity_vec: Vec<String> = entities.into_iter().collect();
if entity_vec.len() < 2 {
return Ok(0.0);
}
// One RNG shared across iterations; rebuilding it per triple is wasteful
// and can correlate the corruption draws.
use scirs2_core::random::Random;
let mut rng = Random::default();
for triple in test_triples.iter().take(sample_size) {
let positive_score = model.score_triple(&triple.0, &triple.1, &triple.2).await?;
// Corrupt either the head or the tail with a randomly drawn entity.
let corrupt_entity = &entity_vec[rng.random_range(0..entity_vec.len())];
let negative_score = if rng.random_bool_with_chance(0.5) {
model
.score_triple(corrupt_entity, &triple.1, &triple.2)
.await?
} else {
model
.score_triple(&triple.0, &triple.1, corrupt_entity)
.await?
};
// A prediction counts as correct when the true triple outscores its
// corruption (assumes higher score = more plausible).
if positive_score > negative_score {
correct += 1;
}
}
Ok(correct as f32 / sample_size as f32)
}
}
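
/// Runtime snapshot of the engine: GNN status, model count, vector-store
/// size, cache hit rate, and GPU utilization.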
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AiStatistics {
pub gnn_enabled: bool,
pub embedding_models: usize,
pub vector_store_size: usize,
pub cache_hit_rate: f32,
pub gpu_utilization: f32,
}
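
/// AI-assisted query enhancement: rewriting, entity suggestion, and query
/// expansion.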
pub trait AiQueryEnhancement {
fn enhance_query(&self, query: &str) -> Result<String>;
fn suggest_entities(&self, entity: &str) -> Result<Vec<String>>;
fn expand_query(&self, query: &str) -> Result<Vec<String>>;
}
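
/// AI-assisted data-quality checks over triples: anomaly detection,
/// improvement suggestions, and consistency validation.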
pub trait AiDataValidation {
fn detect_anomalies(&self, triples: &[Triple]) -> Result<Vec<Anomaly>>;
fn suggest_improvements(&self, triples: &[Triple]) -> Result<Vec<Improvement>>;
fn validate_consistency(&self, triples: &[Triple]) -> Result<Vec<InconsistencyError>>;
}
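
/// A detected anomaly: its type, the offending triple, a confidence score,
/// and a human-readable description.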
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Anomaly {
pub anomaly_type: AnomalyType,
pub triple: Triple,
pub confidence: f32,
pub description: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AnomalyType {
Outlier,
MissingRelation,
InconsistentType,
DuplicateEntity,
InvalidFormat,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Improvement {
pub improvement_type: ImprovementType,
pub target: String,
pub suggestion: String,
pub impact: f32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ImprovementType {
AddRelation,
MergeEntities,
CorrectType,
AddConstraint,
NormalizeFormat,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InconsistencyError {
pub error_type: InconsistencyType,
pub triples: Vec<Triple>,
pub severity: Severity,
pub message: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InconsistencyType {
LogicalContradiction,
TypeViolation,
CardinalityViolation,
DomainRangeViolation,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Severity {
Low,
Medium,
High,
Critical,
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_ai_engine_creation() {
let config = AiConfig::default();
let engine = AiEngine::new(config);
assert!(engine.is_ok());
}
#[test]
fn test_config_serialization() {
let config = AiConfig::default();
let serialized = serde_json::to_string(&config).expect("serialization should succeed");
let deserialized: AiConfig =
serde_json::from_str(&serialized).expect("deserialization should succeed");
assert_eq!(config.enable_gnn, deserialized.enable_gnn);
}
}