1pub mod embeddings;
8pub mod entity_resolution;
9pub mod gnn;
10pub mod gpu_monitor;
11pub mod neural;
12pub mod relation_extraction;
13pub mod temporal_reasoning;
14pub mod training;
15pub mod vector_store;
16
17use crate::model::Triple;
18use anyhow::{anyhow, Result};
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use std::sync::Arc;
22use tokio::sync::Mutex;
23
24pub use embeddings::{
25 create_embedding_model, ComplEx, DistMult, EmbeddingConfig, EmbeddingModelType,
26 KnowledgeGraphEmbedding, TransE,
27};
28pub use gnn::{
29 Aggregation, GnnArchitecture, GnnConfig, GraphNeuralNetwork, LayerType, MessagePassingType,
30};
31pub use training::{
32 DefaultTrainer, LossFunction, Optimizer, Trainer, TrainingConfig, TrainingMetrics,
33};
34pub use vector_store::{SimilarityMetric, VectorIndex, VectorQuery, VectorStore};
35
/// Top-level configuration for the AI engine and all of its subsystems.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AiConfig {
    /// Whether graph neural network support is enabled.
    pub enable_gnn: bool,

    /// Configuration for knowledge-graph embedding models.
    pub embedding_config: EmbeddingConfig,

    /// Configuration for the vector similarity store.
    pub vector_store_config: VectorStoreConfig,

    /// Configuration for model training.
    pub training_config: TrainingConfig,

    /// GPU acceleration settings.
    pub gpu_config: GpuConfig,

    /// Cache settings.
    pub cache_config: CacheConfig,
}
57
58impl Default for AiConfig {
59 fn default() -> Self {
60 Self {
61 enable_gnn: true,
62 embedding_config: EmbeddingConfig::default(),
63 vector_store_config: VectorStoreConfig::default(),
64 training_config: TrainingConfig::default(),
65 gpu_config: GpuConfig::default(),
66 cache_config: CacheConfig::default(),
67 }
68 }
69}
70
/// Configuration for the vector similarity store.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreConfig {
    /// Dimensionality of stored vectors.
    pub dimension: usize,

    /// Similarity metric used for search.
    pub metric: SimilarityMetric,

    /// Index structure backing the store.
    pub index_type: IndexType,

    /// Maximum number of vectors the store may hold.
    pub max_vectors: usize,

    /// Whether approximate nearest-neighbor (ANN) search is enabled.
    pub enable_ann: bool,

    /// Number of neighbors considered by ANN search.
    pub ann_neighbors: usize,
}
92
93impl Default for VectorStoreConfig {
94 fn default() -> Self {
95 Self {
96 dimension: 128,
97 metric: SimilarityMetric::Cosine,
98 index_type: IndexType::HierarchicalNavigableSmallWorld,
99 max_vectors: 10_000_000,
100 enable_ann: true,
101 ann_neighbors: 16,
102 }
103 }
104}
105
/// Supported vector index structures.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum IndexType {
    /// Exhaustive (brute-force) search over all vectors.
    Flat,
    /// Inverted-file index partitioned into `clusters` coarse clusters.
    InvertedFile { clusters: usize },
    /// Locality-sensitive hashing.
    LocalitySensitiveHashing {
        /// Number of independent hash tables.
        hash_tables: usize,
        /// Per-table hash length (presumably in bits — confirm in `vector_store`).
        hash_length: usize,
    },
    /// Hierarchical Navigable Small World (HNSW) graph index.
    HierarchicalNavigableSmallWorld,
    /// Product quantization with the given subquantizer count and bit width.
    ProductQuantization { subquantizers: usize, bits: usize },
}
123
/// GPU acceleration settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuConfig {
    /// Whether GPU acceleration is enabled.
    pub enabled: bool,

    /// Identifier of the GPU device to use.
    pub device_id: u32,

    /// Size of the GPU memory pool, in megabytes.
    pub memory_pool_mb: usize,

    /// Batch size for GPU operations.
    pub batch_size: usize,

    /// Whether mixed-precision computation is enabled.
    pub mixed_precision: bool,
}
142
143impl Default for GpuConfig {
144 fn default() -> Self {
145 Self {
146 enabled: true,
147 device_id: 0,
148 memory_pool_mb: 4096,
149 batch_size: 1024,
150 mixed_precision: true,
151 }
152 }
153}
154
/// On-disk cache settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Whether caching is enabled.
    pub enabled: bool,

    /// Directory where cached data is stored.
    pub cache_dir: String,

    /// Maximum cache size, in megabytes.
    pub max_size_mb: usize,

    /// Time-to-live for cache entries, in seconds.
    pub ttl_seconds: u64,

    /// Whether cached data is compressed.
    pub compression: bool,
}
173
174impl Default for CacheConfig {
175 fn default() -> Self {
176 Self {
177 enabled: true,
178 cache_dir: "/tmp/oxirs/ai_cache".to_string(),
179 max_size_mb: 10240, ttl_seconds: 86400, compression: true,
182 }
183 }
184}
185
/// Central engine coordinating the AI subsystems: embeddings, optional GNN,
/// vector search, training, and the resolution/extraction/reasoning services.
pub struct AiEngine {
    // Retained configuration; currently unused after construction.
    #[allow(dead_code)]
    config: AiConfig,

    /// Optional graph neural network, populated by `initialize_gnn`.
    gnn: Option<Arc<dyn GraphNeuralNetwork>>,

    /// Registered embedding models, keyed by user-supplied name.
    embeddings: HashMap<String, Arc<dyn KnowledgeGraphEmbedding>>,

    /// Vector store used for similarity search.
    vector_store: Arc<dyn VectorStore>,

    /// Trainer guarded by an async mutex so training runs are serialized.
    trainer: Arc<Mutex<Box<dyn Trainer>>>,

    /// Entity resolution subsystem.
    entity_resolver: Arc<entity_resolution::EntityResolver>,

    /// Relation extraction subsystem.
    relation_extractor: Arc<relation_extraction::RelationExtractor>,

    /// Temporal reasoning subsystem.
    temporal_reasoner: Arc<temporal_reasoning::TemporalReasoner>,
}
213
214impl AiEngine {
215 pub fn new(config: AiConfig) -> Result<Self> {
217 let vs_config = vector_store::VectorStoreConfig {
218 dimension: config.vector_store_config.dimension,
219 default_metric: config.vector_store_config.metric,
220 index_type: match config.vector_store_config.index_type {
221 IndexType::Flat => vector_store::IndexType::Flat,
222 IndexType::HierarchicalNavigableSmallWorld => vector_store::IndexType::HNSW {
223 max_connections: 16,
224 ef_construction: 200,
225 ef_search: 50,
226 },
227 IndexType::InvertedFile { clusters } => vector_store::IndexType::IVF {
228 num_clusters: clusters,
229 num_probes: 8,
230 },
231 IndexType::LocalitySensitiveHashing {
232 hash_tables,
233 hash_length,
234 } => vector_store::IndexType::LSH {
235 num_tables: hash_tables,
236 hash_length,
237 },
238 IndexType::ProductQuantization {
239 subquantizers,
240 bits,
241 } => vector_store::IndexType::PQ {
242 num_subquantizers: subquantizers,
243 bits_per_subquantizer: bits,
244 },
245 },
246 enable_cache: config.vector_store_config.enable_ann,
247 cache_size: if config.vector_store_config.max_vectors > 10000 {
248 10000
249 } else {
250 config.vector_store_config.max_vectors
251 },
252 cache_ttl: 3600,
253 batch_size: 1000,
254 };
255 let vector_store = vector_store::create_vector_store(&vs_config)?;
256 let trainer = Arc::new(Mutex::new(Box::new(training::DefaultTrainer::new(
258 config.training_config.clone(),
259 )) as Box<dyn Trainer>));
260 let entity_resolver = Arc::new(entity_resolution::EntityResolver::new(&config)?);
261 let relation_extractor = Arc::new(relation_extraction::RelationExtractor::new(&config)?);
262 let temporal_reasoner = Arc::new(temporal_reasoning::TemporalReasoner::new(&config)?);
263
264 Ok(Self {
265 config,
266 gnn: None,
267 embeddings: HashMap::new(),
268 vector_store,
269 trainer,
270 entity_resolver,
271 relation_extractor,
272 temporal_reasoner,
273 })
274 }
275
276 pub async fn initialize_gnn(&mut self, gnn_config: GnnConfig) -> Result<()> {
278 let gnn = gnn::create_gnn(gnn_config)?;
279 self.gnn = Some(gnn);
280 Ok(())
281 }
282
283 pub async fn add_embedding_model(
285 &mut self,
286 name: String,
287 model: Arc<dyn KnowledgeGraphEmbedding>,
288 ) -> Result<()> {
289 self.embeddings.insert(name, model);
290 Ok(())
291 }
292
293 pub async fn generate_embeddings(
295 &self,
296 model_name: &str,
297 triples: &[Triple],
298 ) -> Result<Vec<Vec<f32>>> {
299 let model = self
300 .embeddings
301 .get(model_name)
302 .ok_or_else(|| anyhow!("Embedding model not found: {}", model_name))?;
303
304 model.generate_embeddings(triples).await
305 }
306
307 pub async fn find_similar_entities(
309 &self,
310 entity_vector: &[f32],
311 top_k: usize,
312 ) -> Result<Vec<(String, f32)>> {
313 let query = VectorQuery {
314 vector: entity_vector.to_vec(),
315 k: top_k,
316 include_metadata: true,
317 metric: None,
318 filters: None,
319 min_similarity: None,
320 };
321
322 self.vector_store.search(&query).await
323 }
324
325 pub async fn predict_links(
327 &self,
328 model_name: &str,
329 entities: &[String],
330 relations: &[String],
331 ) -> Result<Vec<(String, String, String, f32)>> {
332 let model = self
333 .embeddings
334 .get(model_name)
335 .ok_or_else(|| anyhow!("Embedding model not found: {}", model_name))?;
336
337 model.predict_links(entities, relations).await
338 }
339
340 pub async fn resolve_entities(
342 &self,
343 entities: &[Triple],
344 ) -> Result<Vec<entity_resolution::EntityCluster>> {
345 self.entity_resolver.resolve_entities(entities).await
346 }
347
348 pub async fn extract_relations_from_text(
350 &self,
351 text: &str,
352 ) -> Result<Vec<relation_extraction::ExtractedRelation>> {
353 self.relation_extractor.extract_relations(text).await
354 }
355
356 pub async fn temporal_reasoning(
358 &self,
359 query: &temporal_reasoning::TemporalQuery,
360 ) -> Result<temporal_reasoning::TemporalResult> {
361 self.temporal_reasoner.reason(query).await
362 }
363
364 pub async fn train_embedding_model(
366 &self,
367 model_name: &str,
368 training_data: &[Triple],
369 validation_data: &[Triple],
370 ) -> Result<TrainingMetrics> {
371 let model = self
372 .embeddings
373 .get(model_name)
374 .ok_or_else(|| anyhow!("Embedding model not found: {}", model_name))?;
375
376 let trainer = self.trainer.clone();
378 let model = model.clone();
379 let training_data = training_data.to_vec();
380 let validation_data = validation_data.to_vec();
381
382 let mut trainer_guard = trainer.lock().await;
384 trainer_guard
385 .train_embedding_model(model, &training_data, &validation_data)
386 .await
387 }
388
389 pub async fn evaluate_model(
391 &self,
392 model_name: &str,
393 test_data: &[Triple],
394 ) -> Result<EvaluationMetrics> {
395 let model = self
396 .embeddings
397 .get(model_name)
398 .ok_or_else(|| anyhow!("Embedding model not found: {}", model_name))?;
399
400 EvaluationMetrics::evaluate(model.as_ref(), test_data).await
401 }
402
403 pub async fn get_statistics(&self) -> Result<AiStatistics> {
405 let vs_stats = self.vector_store.get_statistics().await?;
407
408 let gpu_monitor = gpu_monitor::GpuMonitor::global();
410 let gpu_utilization = gpu_monitor
411 .lock()
412 .map(|monitor| monitor.get_utilization())
413 .unwrap_or(0.0);
414
415 Ok(AiStatistics {
416 gnn_enabled: self.gnn.is_some(),
417 embedding_models: self.embeddings.len(),
418 vector_store_size: self.vector_store.size(),
419 cache_hit_rate: vs_stats.cache_hit_rate,
420 gpu_utilization,
421 })
422 }
423}
424
/// Evaluation results for a knowledge-graph embedding model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationMetrics {
    /// Mean reciprocal rank.
    pub mrr: f32,

    /// Fraction of test triples ranked in the top 1 / 3 / 10.
    pub hits_at_1: f32,
    pub hits_at_3: f32,
    pub hits_at_10: f32,

    /// Sampled link-prediction accuracy.
    pub link_prediction_accuracy: f32,

    /// Entity resolution F1 score (not computed by `evaluate`; reported as 0.0).
    pub entity_resolution_f1: f32,

    /// Relation extraction precision/recall (not computed by `evaluate`;
    /// reported as 0.0).
    pub relation_extraction_precision: f32,
    pub relation_extraction_recall: f32,
}
446
447impl EvaluationMetrics {
448 pub async fn evaluate(
450 model: &dyn KnowledgeGraphEmbedding,
451 test_data: &[Triple],
452 ) -> Result<Self> {
453 let test_triples: Vec<(String, String, String)> = test_data
455 .iter()
456 .map(|t| {
457 (
458 t.subject().to_string(),
459 t.predicate().to_string(),
460 t.object().to_string(),
461 )
462 })
463 .collect();
464
465 let all_triples = test_triples.clone();
468
469 let k_values = vec![1, 3, 10];
471
472 let kg_metrics = embeddings::evaluation::compute_kg_metrics(
474 model,
475 &test_triples,
476 &all_triples,
477 &k_values,
478 )
479 .await?;
480
481 let link_prediction_accuracy =
483 Self::compute_link_prediction_accuracy(model, &test_triples).await?;
484
485 let mrr = kg_metrics.mrr_filtered;
487 let hits_at_1 = *kg_metrics.hits_at_k_filtered.get(&1).unwrap_or(&0.0);
488 let hits_at_3 = *kg_metrics.hits_at_k_filtered.get(&3).unwrap_or(&0.0);
489 let hits_at_10 = *kg_metrics.hits_at_k_filtered.get(&10).unwrap_or(&0.0);
490
491 let entity_resolution_f1 = 0.0;
494 let relation_extraction_precision = 0.0;
495 let relation_extraction_recall = 0.0;
496
497 Ok(Self {
498 mrr,
499 hits_at_1,
500 hits_at_3,
501 hits_at_10,
502 link_prediction_accuracy,
503 entity_resolution_f1,
504 relation_extraction_precision,
505 relation_extraction_recall,
506 })
507 }
508
509 async fn compute_link_prediction_accuracy(
511 model: &dyn KnowledgeGraphEmbedding,
512 test_triples: &[(String, String, String)],
513 ) -> Result<f32> {
514 if test_triples.is_empty() {
515 return Ok(0.0);
516 }
517
518 let sample_size = test_triples.len().min(100);
520 let mut correct = 0;
521
522 let entities: std::collections::HashSet<String> = test_triples
524 .iter()
525 .flat_map(|(h, _, t)| vec![h.clone(), t.clone()])
526 .collect();
527 let entity_vec: Vec<String> = entities.into_iter().collect();
528
529 if entity_vec.len() < 2 {
530 return Ok(0.0);
531 }
532
533 for triple in test_triples.iter().take(sample_size) {
534 let positive_score = model.score_triple(&triple.0, &triple.1, &triple.2).await?;
535
536 let corrupt_idx = {
538 use scirs2_core::random::Random;
539 let mut rng = Random::default();
540 rng.random_range(0..entity_vec.len())
541 };
542 let corrupt_entity = &entity_vec[corrupt_idx];
543
544 let negative_score = {
545 use scirs2_core::random::Random;
546 let mut rng = Random::default();
547 if rng.random_bool_with_chance(0.5) {
548 model
550 .score_triple(corrupt_entity, &triple.1, &triple.2)
551 .await?
552 } else {
553 model
555 .score_triple(&triple.0, &triple.1, corrupt_entity)
556 .await?
557 }
558 };
559
560 if (positive_score - negative_score).abs() > 0.01 {
564 correct += 1;
565 }
566 }
567
568 Ok(correct as f32 / sample_size as f32)
569 }
570}
571
/// Snapshot of engine-wide statistics returned by `AiEngine::get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AiStatistics {
    /// Whether a GNN has been initialized.
    pub gnn_enabled: bool,

    /// Number of registered embedding models.
    pub embedding_models: usize,

    /// Number of vectors currently in the vector store.
    pub vector_store_size: usize,

    /// Vector store cache hit rate.
    pub cache_hit_rate: f32,

    /// Current GPU utilization (0.0 when unavailable).
    pub gpu_utilization: f32,
}
590
/// AI-backed query enhancement capabilities.
pub trait AiQueryEnhancement {
    /// Rewrites `query` into an enhanced form.
    fn enhance_query(&self, query: &str) -> Result<String>;

    /// Suggests entities related to `entity`.
    fn suggest_entities(&self, entity: &str) -> Result<Vec<String>>;

    /// Expands `query` into multiple alternative queries.
    fn expand_query(&self, query: &str) -> Result<Vec<String>>;
}
602
/// AI-backed data quality validation capabilities.
pub trait AiDataValidation {
    /// Detects anomalous triples in `triples`.
    fn detect_anomalies(&self, triples: &[Triple]) -> Result<Vec<Anomaly>>;

    /// Suggests improvements for `triples`.
    fn suggest_improvements(&self, triples: &[Triple]) -> Result<Vec<Improvement>>;

    /// Checks `triples` for logical/type/cardinality inconsistencies.
    fn validate_consistency(&self, triples: &[Triple]) -> Result<Vec<InconsistencyError>>;
}
614
/// A detected anomaly in the knowledge graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Anomaly {
    /// Kind of anomaly detected.
    pub anomaly_type: AnomalyType,

    /// The triple the anomaly was found in.
    pub triple: Triple,

    /// Detection confidence.
    pub confidence: f32,

    /// Human-readable description of the anomaly.
    pub description: String,
}
630
/// Kinds of anomalies that can be detected.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AnomalyType {
    /// Statistical outlier.
    Outlier,

    /// An expected relation appears to be missing.
    MissingRelation,

    /// Entity type is inconsistent with its usage.
    InconsistentType,

    /// Two entities appear to be duplicates.
    DuplicateEntity,

    /// Value or identifier has an invalid format.
    InvalidFormat,
}
649
/// A suggested improvement to the knowledge graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Improvement {
    /// Kind of improvement suggested.
    pub improvement_type: ImprovementType,

    /// The entity or triple the suggestion targets.
    pub target: String,

    /// Human-readable suggestion text.
    pub suggestion: String,

    /// Estimated impact of applying the suggestion.
    pub impact: f32,
}
665
/// Kinds of improvements that can be suggested.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ImprovementType {
    /// Add a missing relation.
    AddRelation,

    /// Merge duplicate entities.
    MergeEntities,

    /// Correct an entity's type.
    CorrectType,

    /// Add a constraint to the schema.
    AddConstraint,

    /// Normalize a value's format.
    NormalizeFormat,
}
684
/// A detected inconsistency in the knowledge graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InconsistencyError {
    /// Kind of inconsistency.
    pub error_type: InconsistencyType,

    /// Triples involved in the inconsistency.
    pub triples: Vec<Triple>,

    /// Severity of the inconsistency.
    pub severity: Severity,

    /// Human-readable description.
    pub message: String,
}
700
/// Kinds of inconsistencies that can be detected.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InconsistencyType {
    /// Two or more triples logically contradict each other.
    LogicalContradiction,

    /// A value violates its declared type.
    TypeViolation,

    /// A cardinality constraint is violated.
    CardinalityViolation,

    /// A property is used outside its declared domain or range.
    DomainRangeViolation,
}
716
/// Severity level of a detected inconsistency.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Severity {
    Low,
    Medium,
    High,
    Critical,
}
725
#[cfg(test)]
mod tests {
    use super::*;

    /// Engine construction with the default configuration should succeed.
    #[tokio::test]
    async fn test_ai_engine_creation() {
        let config = AiConfig::default();
        let engine = AiEngine::new(config);
        assert!(engine.is_ok());
    }

    /// `AiConfig` should round-trip through JSON.
    #[test]
    fn test_config_serialization() {
        let config = AiConfig::default();
        // The previous expect messages said "construction should succeed",
        // which was misleading on serialization failures.
        let serialized = serde_json::to_string(&config).expect("serialization should succeed");
        let deserialized: AiConfig =
            serde_json::from_str(&serialized).expect("deserialization should succeed");
        assert_eq!(config.enable_gnn, deserialized.enable_gnn);
    }
}