1use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
8use anyhow::{anyhow, Result};
9use async_trait::async_trait;
10use chrono::Utc;
11use scirs2_core::ndarray_ext::Array1;
12use serde::{Deserialize, Serialize};
13use std::collections::{HashMap, HashSet};
14use uuid::Uuid;
15
/// RDFS / OWL vocabulary terms that this module interprets as ontology
/// constraints. `OntologyRelation::from_iri` maps the full IRIs to variants.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OntologyRelation {
    /// `rdfs:subClassOf` — subject class is a subclass of the object class.
    SubClassOf,
    /// `owl:equivalentClass` — subject and object classes are declared equivalent.
    EquivalentClass,
    /// `owl:disjointWith` — subject and object classes are declared disjoint.
    DisjointWith,
    /// `rdfs:domain` — object class is a domain of the subject property.
    Domain,
    /// `rdfs:range` — object class is a range of the subject property.
    Range,
    /// `owl:inverseOf` — subject and object properties are inverses of each other.
    InverseOf,
    /// `owl:FunctionalProperty` characteristic.
    FunctionalProperty,
    /// `owl:SymmetricProperty` characteristic.
    SymmetricProperty,
    /// `owl:TransitiveProperty` characteristic.
    TransitiveProperty,
}
38
/// Cached summary of what is known about a single property IRI.
///
/// `OntologyConstraints::build_property_characteristics_cache` fills in
/// `is_functional`, `is_symmetric`, `is_transitive`, `has_inverse`,
/// `domain_classes` and `range_classes`; the other flags are left at their
/// `Default` value (`false`) by that method.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PropertyCharacteristics {
    /// Property is declared functional.
    pub is_functional: bool,
    /// Inverse-functional flag (not populated by the cache builder in this file).
    pub is_inverse_functional: bool,
    /// Property is declared symmetric.
    pub is_symmetric: bool,
    /// Asymmetric flag (not populated by the cache builder in this file).
    pub is_asymmetric: bool,
    /// Property is declared transitive.
    pub is_transitive: bool,
    /// Reflexive flag (not populated by the cache builder in this file).
    pub is_reflexive: bool,
    /// Irreflexive flag (not populated by the cache builder in this file).
    pub is_irreflexive: bool,
    /// IRI of the inverse property, if one was declared.
    pub has_inverse: Option<String>,
    /// Allowed subject classes; an empty set means "unconstrained".
    pub domain_classes: HashSet<String>,
    /// Allowed object classes; an empty set means "unconstrained".
    pub range_classes: HashSet<String>,
}
53
54impl PropertyCharacteristics {
55 pub fn has_domain_constraints(&self) -> bool {
57 !self.domain_classes.is_empty()
58 }
59
60 pub fn has_range_constraints(&self) -> bool {
62 !self.range_classes.is_empty()
63 }
64
65 pub fn satisfies_domain(&self, entity_type: &str) -> bool {
67 if self.domain_classes.is_empty() {
68 true } else {
70 self.domain_classes.contains(entity_type)
71 }
72 }
73
74 pub fn satisfies_range(&self, entity_type: &str) -> bool {
76 if self.range_classes.is_empty() {
77 true } else {
79 self.range_classes.contains(entity_type)
80 }
81 }
82}
83
84impl OntologyRelation {
85 pub fn from_iri(iri: &str) -> Option<Self> {
87 match iri {
88 "http://www.w3.org/2000/01/rdf-schema#subClassOf" => Some(Self::SubClassOf),
89 "http://www.w3.org/2002/07/owl#equivalentClass" => Some(Self::EquivalentClass),
90 "http://www.w3.org/2002/07/owl#disjointWith" => Some(Self::DisjointWith),
91 "http://www.w3.org/2000/01/rdf-schema#domain" => Some(Self::Domain),
92 "http://www.w3.org/2000/01/rdf-schema#range" => Some(Self::Range),
93 "http://www.w3.org/2002/07/owl#inverseOf" => Some(Self::InverseOf),
94 "http://www.w3.org/2002/07/owl#FunctionalProperty" => Some(Self::FunctionalProperty),
95 "http://www.w3.org/2002/07/owl#SymmetricProperty" => Some(Self::SymmetricProperty),
96 "http://www.w3.org/2002/07/owl#TransitiveProperty" => Some(Self::TransitiveProperty),
97 _ => None,
98 }
99 }
100}
101
/// Hyper-parameters for [`OntologyAwareEmbedding`].
///
/// The `*_weight` fields scale the individual ontology loss terms when they
/// are added to the per-epoch total loss during training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OntologyAwareConfig {
    /// Shared model settings (dimensions, learning rate, max epochs, ...).
    pub base_config: ModelConfig,
    /// Weight of the subclass/superclass hierarchy loss.
    pub hierarchy_weight: f32,
    /// Weight of the equivalent-class distance loss.
    pub equivalence_weight: f32,
    /// Weight of the disjoint-class similarity loss.
    pub disjoint_weight: f32,
    /// Weight of the domain/range loss; also scales the property-chain loss (x0.8).
    pub property_constraint_weight: f32,
    /// Weight of the contrastive loss; also scales the mutual-information loss (x0.5).
    pub cross_modal_weight: f32,
    /// Whether to compute the transitive closure of the class hierarchy.
    pub use_transitive_closure: bool,
    /// Maximum number of closure passes.
    pub max_transitive_depth: usize,
    /// L2-normalise every entity embedding at the end of each epoch.
    pub normalize_for_hierarchy: bool,
    /// Margin used by both the hierarchy loss and the disjoint loss.
    pub hierarchy_margin: f32,
    /// Enables the contrastive loss over cross-modal alignments.
    pub enable_contrastive_learning: bool,
    /// Temperature dividing similarities in the contrastive loss.
    pub contrastive_temperature: f32,
    /// Enables the entity/relation mutual-information loss.
    pub enable_mutual_info_max: bool,
    /// Enables the property-chain composition loss.
    pub enable_property_chains: bool,
    // NOTE(review): not read by any code visible in this file — confirm
    // whether something elsewhere enforces this limit.
    /// Intended upper bound on property-chain length.
    pub max_property_chain_length: usize,
}
135
136impl Default for OntologyAwareConfig {
137 fn default() -> Self {
138 Self {
139 base_config: ModelConfig::default(),
140 hierarchy_weight: 1.0,
141 equivalence_weight: 2.0,
142 disjoint_weight: 1.5,
143 property_constraint_weight: 1.2,
144 cross_modal_weight: 0.8,
145 use_transitive_closure: true,
146 max_transitive_depth: 10,
147 normalize_for_hierarchy: true,
148 hierarchy_margin: 1.0,
149 enable_contrastive_learning: true,
150 contrastive_temperature: 0.1,
151 enable_mutual_info_max: false,
152 enable_property_chains: true,
153 max_property_chain_length: 3,
154 }
155 }
156}
157
/// TransE-style knowledge-graph embedding model that additionally tracks
/// (and reports losses for) RDFS/OWL ontology constraints found in its triples.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OntologyAwareEmbedding {
    /// Hyper-parameters.
    pub config: OntologyAwareConfig,
    /// Random identifier assigned at construction.
    pub model_id: Uuid,
    /// Embedding per entity IRI (every subject and object seen during training).
    pub entity_embeddings: HashMap<String, Array1<f32>>,
    /// Embedding per predicate IRI.
    pub relation_embeddings: HashMap<String, Array1<f32>>,
    /// Entity IRI -> index, assigned during `train`.
    pub entity_to_idx: HashMap<String, usize>,
    /// Relation IRI -> index, assigned during `train`.
    pub relation_to_idx: HashMap<String, usize>,
    /// All triples added via `add_triple`.
    pub triples: Vec<Triple>,
    /// Ontology axioms extracted from `triples` plus any added directly.
    pub ontology_constraints: OntologyConstraints,
    /// Statistics from the most recent `train` call.
    pub training_stats: TrainingStats,
    /// Summary statistics reported through `get_stats`.
    pub model_stats: ModelStats,
    /// Set once `train` has completed.
    pub is_trained: bool,
}
183
/// Ontology axioms collected from the training triples (and optionally added
/// by the caller), keyed by IRI.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OntologyConstraints {
    /// Subclass -> direct superclasses (from `rdfs:subClassOf`).
    pub class_hierarchy: HashMap<String, HashSet<String>>,
    /// Class -> equivalent classes; entries are recorded in both directions.
    pub equivalent_classes: HashMap<String, HashSet<String>>,
    /// Class -> disjoint classes; entries are recorded in both directions.
    pub disjoint_classes: HashMap<String, HashSet<String>>,
    /// Property -> domain classes (from `rdfs:domain`).
    pub property_domains: HashMap<String, HashSet<String>>,
    /// Property -> range classes (from `rdfs:range`).
    pub property_ranges: HashMap<String, HashSet<String>>,
    /// Property -> inverse property; entries are recorded in both directions.
    pub inverse_properties: HashMap<String, String>,
    /// Properties declared functional.
    pub functional_properties: HashSet<String>,
    /// Properties declared symmetric.
    pub symmetric_properties: HashSet<String>,
    /// Properties declared transitive.
    pub transitive_properties: HashSet<String>,
    /// Output of `compute_transitive_closure`: subclass -> all known superclasses.
    pub transitive_hierarchy: HashMap<String, HashSet<String>>,
    /// Target property -> chains of properties whose composition implies it.
    pub property_chains: HashMap<String, Vec<Vec<String>>>,
    /// Entity -> aligned entities (symmetric), used by the contrastive loss.
    pub cross_modal_alignments: HashMap<String, HashSet<String>>,
    /// Per-property cache built by `build_property_characteristics_cache`.
    pub property_characteristics: HashMap<String, PropertyCharacteristics>,
}
214
215impl OntologyConstraints {
216 pub fn compute_transitive_closure(&mut self, max_depth: usize) {
218 self.transitive_hierarchy = self.class_hierarchy.clone();
219
220 for _ in 0..max_depth {
221 let mut changed = false;
222 let current_hierarchy = self.transitive_hierarchy.clone();
223
224 for (subclass, superclasses) in ¤t_hierarchy {
225 let mut new_superclasses = superclasses.clone();
226
227 for superclass in superclasses {
229 if let Some(super_superclasses) = current_hierarchy.get(superclass) {
230 for super_superclass in super_superclasses {
231 if !new_superclasses.contains(super_superclass)
232 && super_superclass != subclass
233 {
234 new_superclasses.insert(super_superclass.clone());
235 changed = true;
236 }
237 }
238 }
239 }
240
241 self.transitive_hierarchy
242 .insert(subclass.clone(), new_superclasses);
243 }
244
245 if !changed {
246 break;
247 }
248 }
249 }
250
251 pub fn is_subclass_of(&self, subclass: &str, superclass: &str) -> bool {
253 if let Some(superclasses) = self.transitive_hierarchy.get(subclass) {
254 superclasses.contains(superclass)
255 } else {
256 false
257 }
258 }
259
260 pub fn are_equivalent(&self, class1: &str, class2: &str) -> bool {
262 if let Some(equivalent) = self.equivalent_classes.get(class1) {
263 equivalent.contains(class2)
264 } else {
265 false
266 }
267 }
268
269 pub fn are_disjoint(&self, class1: &str, class2: &str) -> bool {
271 if let Some(disjoint) = self.disjoint_classes.get(class1) {
272 disjoint.contains(class2)
273 } else {
274 false
275 }
276 }
277
278 pub fn add_property_chain(&mut self, target_property: &str, chain: Vec<String>) {
280 self.property_chains
281 .entry(target_property.to_string())
282 .or_default()
283 .push(chain);
284 }
285
286 pub fn get_property_chains(&self, property: &str) -> Option<&Vec<Vec<String>>> {
288 self.property_chains.get(property)
289 }
290
291 pub fn add_cross_modal_alignment(&mut self, entity1: &str, entity2: &str) {
293 self.cross_modal_alignments
294 .entry(entity1.to_string())
295 .or_default()
296 .insert(entity2.to_string());
297
298 self.cross_modal_alignments
300 .entry(entity2.to_string())
301 .or_default()
302 .insert(entity1.to_string());
303 }
304
305 pub fn get_cross_modal_alignments(&self, entity: &str) -> Option<&HashSet<String>> {
307 self.cross_modal_alignments.get(entity)
308 }
309
310 pub fn build_property_characteristics_cache(&mut self) {
312 let all_properties: HashSet<String> = self
314 .property_domains
315 .keys()
316 .chain(self.property_ranges.keys())
317 .chain(self.functional_properties.iter())
318 .chain(self.symmetric_properties.iter())
319 .chain(self.transitive_properties.iter())
320 .chain(self.inverse_properties.keys())
321 .cloned()
322 .collect();
323
324 for property in all_properties {
325 let mut characteristics = PropertyCharacteristics {
326 is_functional: self.functional_properties.contains(&property),
327 is_symmetric: self.symmetric_properties.contains(&property),
328 is_transitive: self.transitive_properties.contains(&property),
329 has_inverse: self.inverse_properties.get(&property).cloned(),
330 ..Default::default()
331 };
332
333 if let Some(domains) = self.property_domains.get(&property) {
334 characteristics.domain_classes = domains.clone();
335 }
336
337 if let Some(ranges) = self.property_ranges.get(&property) {
338 characteristics.range_classes = ranges.clone();
339 }
340
341 self.property_characteristics
342 .insert(property, characteristics);
343 }
344 }
345
346 pub fn validate_property_usage(
348 &self,
349 subject: &str,
350 property: &str,
351 object: &str,
352 entity_types: &HashMap<String, String>,
353 ) -> bool {
354 if let Some(characteristics) = self.property_characteristics.get(property) {
355 if characteristics.has_domain_constraints() {
357 if let Some(subject_type) = entity_types.get(subject) {
358 if !characteristics.satisfies_domain(subject_type) {
359 return false;
360 }
361 }
362 }
363
364 if characteristics.has_range_constraints() {
366 if let Some(object_type) = entity_types.get(object) {
367 if !characteristics.satisfies_range(object_type) {
368 return false;
369 }
370 }
371 }
372 }
373
374 true
375 }
376
377 pub fn infer_from_property_chains(&self, existing_triples: &[Triple]) -> Vec<Triple> {
379 let mut inferred_triples = Vec::new();
380
381 for (target_property, chains) in &self.property_chains {
382 for chain in chains {
383 if chain.len() >= 2 {
384 inferred_triples.extend(self.find_chain_matches(
386 existing_triples,
387 target_property,
388 chain,
389 ));
390 }
391 }
392 }
393
394 inferred_triples
395 }
396
397 fn find_chain_matches(
399 &self,
400 triples: &[Triple],
401 target_property: &str,
402 chain: &[String],
403 ) -> Vec<Triple> {
404 let mut matches = Vec::new();
405
406 let mut triples_by_predicate: HashMap<String, Vec<&Triple>> = HashMap::new();
408 for triple in triples {
409 triples_by_predicate
410 .entry(triple.predicate.iri.clone())
411 .or_default()
412 .push(triple);
413 }
414
415 if chain.len() == 2 {
417 let prop1 = &chain[0];
418 let prop2 = &chain[1];
419
420 if let (Some(triples1), Some(triples2)) = (
421 triples_by_predicate.get(prop1),
422 triples_by_predicate.get(prop2),
423 ) {
424 for t1 in triples1 {
425 for t2 in triples2 {
426 if t1.object.iri == t2.subject.iri {
428 if let (Ok(subject), Ok(predicate), Ok(object)) = (
430 crate::NamedNode::new(&t1.subject.iri),
431 crate::NamedNode::new(target_property),
432 crate::NamedNode::new(&t2.object.iri),
433 ) {
434 matches.push(Triple::new(subject, predicate, object));
435 }
436 }
437 }
438 }
439 }
440 }
441
442 matches
443 }
444}
445
446impl Default for TrainingStats {
447 fn default() -> Self {
448 Self {
449 epochs_completed: 0,
450 final_loss: 0.0,
451 training_time_seconds: 0.0,
452 convergence_achieved: false,
453 loss_history: Vec::new(),
454 }
455 }
456}
457
458impl OntologyAwareEmbedding {
459 pub fn new(config: OntologyAwareConfig) -> Self {
461 let model_id = Uuid::new_v4();
462 let now = Utc::now();
463
464 Self {
465 model_id,
466 entity_embeddings: HashMap::new(),
467 relation_embeddings: HashMap::new(),
468 entity_to_idx: HashMap::new(),
469 relation_to_idx: HashMap::new(),
470 triples: Vec::new(),
471 ontology_constraints: OntologyConstraints::default(),
472 training_stats: TrainingStats::default(),
473 model_stats: ModelStats {
474 num_entities: 0,
475 num_relations: 0,
476 num_triples: 0,
477 dimensions: config.base_config.dimensions,
478 is_trained: false,
479 model_type: "OntologyAware".to_string(),
480 creation_time: now,
481 last_training_time: None,
482 },
483 is_trained: false,
484 config,
485 }
486 }
487
488 pub fn hierarchy_optimized_config(dimensions: usize) -> OntologyAwareConfig {
490 OntologyAwareConfig {
491 base_config: ModelConfig::default().with_dimensions(dimensions),
492 hierarchy_weight: 2.0,
493 equivalence_weight: 1.0,
494 disjoint_weight: 1.0,
495 property_constraint_weight: 1.0,
496 cross_modal_weight: 0.5,
497 use_transitive_closure: true,
498 max_transitive_depth: 15,
499 normalize_for_hierarchy: true,
500 hierarchy_margin: 0.5,
501 enable_contrastive_learning: false,
502 contrastive_temperature: 0.1,
503 enable_mutual_info_max: false,
504 enable_property_chains: true,
505 max_property_chain_length: 2,
506 }
507 }
508
509 pub fn property_optimized_config(dimensions: usize) -> OntologyAwareConfig {
511 OntologyAwareConfig {
512 base_config: ModelConfig::default().with_dimensions(dimensions),
513 hierarchy_weight: 1.0,
514 equivalence_weight: 1.5,
515 disjoint_weight: 2.0,
516 property_constraint_weight: 2.5,
517 cross_modal_weight: 1.0,
518 use_transitive_closure: true,
519 max_transitive_depth: 8,
520 normalize_for_hierarchy: false,
521 hierarchy_margin: 1.0,
522 enable_contrastive_learning: true,
523 contrastive_temperature: 0.05,
524 enable_mutual_info_max: true,
525 enable_property_chains: true,
526 max_property_chain_length: 3,
527 }
528 }
529
530 fn extract_ontology_constraints(&mut self) {
532 for triple in &self.triples {
533 if let Some(relation_type) = OntologyRelation::from_iri(&triple.predicate.iri) {
534 match relation_type {
535 OntologyRelation::SubClassOf => {
536 self.ontology_constraints
537 .class_hierarchy
538 .entry(triple.subject.iri.clone())
539 .or_default()
540 .insert(triple.object.iri.clone());
541 }
542 OntologyRelation::EquivalentClass => {
543 self.ontology_constraints
544 .equivalent_classes
545 .entry(triple.subject.iri.clone())
546 .or_default()
547 .insert(triple.object.iri.clone());
548 self.ontology_constraints
550 .equivalent_classes
551 .entry(triple.object.iri.clone())
552 .or_default()
553 .insert(triple.subject.iri.clone());
554 }
555 OntologyRelation::DisjointWith => {
556 self.ontology_constraints
557 .disjoint_classes
558 .entry(triple.subject.iri.clone())
559 .or_default()
560 .insert(triple.object.iri.clone());
561 self.ontology_constraints
563 .disjoint_classes
564 .entry(triple.object.iri.clone())
565 .or_default()
566 .insert(triple.subject.iri.clone());
567 }
568 OntologyRelation::Domain => {
569 self.ontology_constraints
570 .property_domains
571 .entry(triple.subject.iri.clone())
572 .or_default()
573 .insert(triple.object.iri.clone());
574 }
575 OntologyRelation::Range => {
576 self.ontology_constraints
577 .property_ranges
578 .entry(triple.subject.iri.clone())
579 .or_default()
580 .insert(triple.object.iri.clone());
581 }
582 OntologyRelation::InverseOf => {
583 self.ontology_constraints
584 .inverse_properties
585 .insert(triple.subject.iri.clone(), triple.object.iri.clone());
586 self.ontology_constraints
587 .inverse_properties
588 .insert(triple.object.iri.clone(), triple.subject.iri.clone());
589 }
590 OntologyRelation::FunctionalProperty => {
591 self.ontology_constraints
592 .functional_properties
593 .insert(triple.subject.iri.clone());
594 }
595 OntologyRelation::SymmetricProperty => {
596 self.ontology_constraints
597 .symmetric_properties
598 .insert(triple.subject.iri.clone());
599 }
600 OntologyRelation::TransitiveProperty => {
601 self.ontology_constraints
602 .transitive_properties
603 .insert(triple.subject.iri.clone());
604 }
605 }
606 }
607 }
608
609 if self.config.use_transitive_closure {
611 self.ontology_constraints
612 .compute_transitive_closure(self.config.max_transitive_depth);
613 }
614 }
615
616 fn compute_hierarchy_loss(&self) -> f32 {
618 let mut total_loss = 0.0;
619 let mut count = 0;
620
621 for (subclass, superclasses) in &self.ontology_constraints.transitive_hierarchy {
622 if let Some(sub_emb) = self.entity_embeddings.get(subclass) {
623 for superclass in superclasses {
624 if let Some(super_emb) = self.entity_embeddings.get(superclass) {
625 let sub_norm = sub_emb.dot(sub_emb).sqrt();
628 let super_norm = super_emb.dot(super_emb).sqrt();
629 let similarity = sub_emb.dot(super_emb) / (sub_norm * super_norm + 1e-8);
630
631 let hierarchy_score = similarity + (super_norm - sub_norm) * 0.1;
633 let loss = (self.config.hierarchy_margin - hierarchy_score).max(0.0);
634 total_loss += loss;
635 count += 1;
636 }
637 }
638 }
639 }
640
641 if count > 0 {
642 total_loss / count as f32
643 } else {
644 0.0
645 }
646 }
647
648 fn compute_equivalence_loss(&self) -> f32 {
650 let mut total_loss = 0.0;
651 let mut count = 0;
652
653 for (class1, equivalent_classes) in &self.ontology_constraints.equivalent_classes {
654 if let Some(emb1) = self.entity_embeddings.get(class1) {
655 for class2 in equivalent_classes {
656 if let Some(emb2) = self.entity_embeddings.get(class2) {
657 let distance = (emb1 - emb2).mapv(|x| x * x).sum().sqrt();
659 total_loss += distance;
660 count += 1;
661 }
662 }
663 }
664 }
665
666 if count > 0 {
667 total_loss / count as f32
668 } else {
669 0.0
670 }
671 }
672
673 fn compute_disjoint_loss(&self) -> f32 {
675 let mut total_loss = 0.0;
676 let mut count = 0;
677
678 for (class1, disjoint_classes) in &self.ontology_constraints.disjoint_classes {
679 if let Some(emb1) = self.entity_embeddings.get(class1) {
680 for class2 in disjoint_classes {
681 if let Some(emb2) = self.entity_embeddings.get(class2) {
682 let norm1 = emb1.dot(emb1).sqrt();
684 let norm2 = emb2.dot(emb2).sqrt();
685 let similarity = emb1.dot(emb2) / (norm1 * norm2 + 1e-8);
686 let loss = (similarity + self.config.hierarchy_margin).max(0.0);
687 total_loss += loss;
688 count += 1;
689 }
690 }
691 }
692 }
693
694 if count > 0 {
695 total_loss / count as f32
696 } else {
697 0.0
698 }
699 }
700
701 fn compute_property_constraint_loss(&self) -> f32 {
703 let mut total_loss = 0.0;
704 let mut count = 0;
705
706 for (property, domains) in &self.ontology_constraints.property_domains {
708 if let Some(relation_emb) = self.relation_embeddings.get(property) {
709 for domain_class in domains {
710 if let Some(domain_emb) = self.entity_embeddings.get(domain_class) {
711 let compatibility = relation_emb.dot(domain_emb);
713 let loss = (1.0 - compatibility).max(0.0); total_loss += loss;
715 count += 1;
716 }
717 }
718 }
719 }
720
721 for (property, ranges) in &self.ontology_constraints.property_ranges {
723 if let Some(relation_emb) = self.relation_embeddings.get(property) {
724 for range_class in ranges {
725 if let Some(range_emb) = self.entity_embeddings.get(range_class) {
726 let compatibility = relation_emb.dot(range_emb);
728 let loss = (1.0 - compatibility).max(0.0);
729 total_loss += loss;
730 count += 1;
731 }
732 }
733 }
734 }
735
736 if count > 0 {
737 total_loss / count as f32
738 } else {
739 0.0
740 }
741 }
742
743 fn compute_contrastive_loss(&self) -> f32 {
745 if !self.config.enable_contrastive_learning {
746 return 0.0;
747 }
748
749 let mut total_loss = 0.0;
750 let mut count = 0;
751 let temperature = self.config.contrastive_temperature;
752
753 for (entity1, aligned_entities) in &self.ontology_constraints.cross_modal_alignments {
754 if let Some(emb1) = self.entity_embeddings.get(entity1) {
755 for entity2 in aligned_entities {
756 if let Some(emb2) = self.entity_embeddings.get(entity2) {
757 let pos_sim = emb1.dot(emb2) / temperature;
759
760 let mut neg_sims = Vec::new();
762 for (neg_entity, neg_emb) in self.entity_embeddings.iter().take(10) {
763 if neg_entity != entity1 && neg_entity != entity2 {
764 let neg_sim = emb1.dot(neg_emb) / temperature;
765 neg_sims.push(neg_sim);
766 }
767 }
768
769 if !neg_sims.is_empty() {
770 let exp_pos = pos_sim.exp();
772 let sum_exp_neg: f32 = neg_sims.iter().copied().map(|x| x.exp()).sum();
773 let loss = -(exp_pos / (exp_pos + sum_exp_neg)).ln();
774 total_loss += loss;
775 count += 1;
776 }
777 }
778 }
779 }
780 }
781
782 if count > 0 {
783 total_loss / count as f32
784 } else {
785 0.0
786 }
787 }
788
789 fn compute_mutual_info_loss(&self) -> f32 {
791 if !self.config.enable_mutual_info_max {
792 return 0.0;
793 }
794
795 let mut total_loss = 0.0;
796 let mut count = 0;
797
798 for (entity, entity_emb) in &self.entity_embeddings {
800 for relation_emb in self.relation_embeddings.values() {
801 let pair_exists = self
803 .triples
804 .iter()
805 .any(|t| t.subject.iri == *entity || t.object.iri == *entity);
806
807 if pair_exists {
808 let mi = entity_emb.dot(relation_emb);
810 let loss = (1.0 - mi).max(0.0);
811 total_loss += loss;
812 count += 1;
813 }
814 }
815 }
816
817 if count > 0 {
818 total_loss / count as f32
819 } else {
820 0.0
821 }
822 }
823
824 fn compute_property_chain_loss(&self) -> f32 {
826 if !self.config.enable_property_chains {
827 return 0.0;
828 }
829
830 let mut total_loss = 0.0;
831 let mut count = 0;
832
833 for (target_property, chains) in &self.ontology_constraints.property_chains {
834 if let Some(target_emb) = self.relation_embeddings.get(target_property) {
835 for chain in chains {
836 if chain.len() == 2 {
837 if let (Some(prop1_emb), Some(prop2_emb)) = (
839 self.relation_embeddings.get(&chain[0]),
840 self.relation_embeddings.get(&chain[1]),
841 ) {
842 let chain_emb = prop1_emb + prop2_emb;
843 let distance = (target_emb - &chain_emb).mapv(|x| x * x).sum().sqrt();
844 total_loss += distance;
845 count += 1;
846 }
847 }
848 }
849 }
850 }
851
852 if count > 0 {
853 total_loss / count as f32
854 } else {
855 0.0
856 }
857 }
858}
859
860#[async_trait]
861impl EmbeddingModel for OntologyAwareEmbedding {
862 fn config(&self) -> &ModelConfig {
863 &self.config.base_config
864 }
865
866 fn model_id(&self) -> &Uuid {
867 &self.model_id
868 }
869
870 fn model_type(&self) -> &'static str {
871 "OntologyAware"
872 }
873
874 fn add_triple(&mut self, triple: Triple) -> Result<()> {
875 self.triples.push(triple);
876 Ok(())
877 }
878
879 async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
880 let start_time = std::time::Instant::now();
881
882 self.extract_ontology_constraints();
884
885 self.ontology_constraints
887 .build_property_characteristics_cache();
888
889 let mut entity_set = HashSet::new();
891 let mut relation_set = HashSet::new();
892
893 for triple in &self.triples {
894 entity_set.insert(triple.subject.iri.clone());
895 entity_set.insert(triple.object.iri.clone());
896 relation_set.insert(triple.predicate.iri.clone());
897 }
898
899 for (idx, entity) in entity_set.iter().enumerate() {
901 self.entity_to_idx.insert(entity.clone(), idx);
902 }
903
904 for (idx, relation) in relation_set.iter().enumerate() {
905 self.relation_to_idx.insert(relation.clone(), idx);
906 }
907
908 let dimensions = self.config.base_config.dimensions;
910 for entity in &entity_set {
911 let embedding = Array1::from_vec(
912 (0..dimensions)
913 .map(|_| {
914 use scirs2_core::random::{Random, Rng};
915 let mut random = Random::default();
916 (random.random::<f32>() - 0.5) * 0.1
917 })
918 .collect(),
919 );
920 self.entity_embeddings.insert(entity.clone(), embedding);
921 }
922
923 for relation in &relation_set {
924 let embedding = Array1::from_vec(
925 (0..dimensions)
926 .map(|_| {
927 use scirs2_core::random::{Random, Rng};
928 let mut random = Random::default();
929 (random.random::<f32>() - 0.5) * 0.1
930 })
931 .collect(),
932 );
933 self.relation_embeddings.insert(relation.clone(), embedding);
934 }
935
936 let max_epochs = epochs.unwrap_or(self.config.base_config.max_epochs);
938 let learning_rate = self.config.base_config.learning_rate as f32;
939 let mut loss_history = Vec::new();
940
941 for epoch in 0..max_epochs {
942 let mut total_loss = 0.0;
943
944 for triple in &self.triples {
946 if let (Some(h), Some(r), Some(t)) = (
947 self.entity_embeddings.get(&triple.subject.iri).cloned(),
948 self.relation_embeddings.get(&triple.predicate.iri).cloned(),
949 self.entity_embeddings.get(&triple.object.iri).cloned(),
950 ) {
951 let predicted = &h + &r;
953 let error = &t - &predicted;
954 let loss = error.dot(&error).sqrt();
955 total_loss += loss;
956
957 let gradient_scale = learning_rate * 0.01;
959 let h_grad = &error * gradient_scale;
960 let r_grad = &error * gradient_scale;
961 let t_grad = &error * (-gradient_scale);
962
963 if let Some(h_emb) = self.entity_embeddings.get_mut(&triple.subject.iri) {
965 *h_emb += &h_grad;
966 }
967 if let Some(r_emb) = self.relation_embeddings.get_mut(&triple.predicate.iri) {
968 *r_emb += &r_grad;
969 }
970 if let Some(t_emb) = self.entity_embeddings.get_mut(&triple.object.iri) {
971 *t_emb += &t_grad;
972 }
973 }
974 }
975
976 let hierarchy_loss = self.compute_hierarchy_loss();
978 let equivalence_loss = self.compute_equivalence_loss();
979 let disjoint_loss = self.compute_disjoint_loss();
980 let property_loss = self.compute_property_constraint_loss();
981 let contrastive_loss = self.compute_contrastive_loss();
982 let mutual_info_loss = self.compute_mutual_info_loss();
983 let property_chain_loss = self.compute_property_chain_loss();
984
985 total_loss += hierarchy_loss * self.config.hierarchy_weight;
986 total_loss += equivalence_loss * self.config.equivalence_weight;
987 total_loss += disjoint_loss * self.config.disjoint_weight;
988 total_loss += property_loss * self.config.property_constraint_weight;
989 total_loss += contrastive_loss * self.config.cross_modal_weight;
990 total_loss += mutual_info_loss * self.config.cross_modal_weight * 0.5;
991 total_loss += property_chain_loss * self.config.property_constraint_weight * 0.8;
992
993 loss_history.push(total_loss as f64);
994
995 if self.config.normalize_for_hierarchy {
997 for embedding in self.entity_embeddings.values_mut() {
998 let norm = embedding.dot(embedding).sqrt();
999 if norm > 0.0 {
1000 *embedding /= norm;
1001 }
1002 }
1003 }
1004
1005 if epoch % 10 == 0 {
1006 tracing::info!(
1007 "Epoch {}: total_loss={:.6}, hierarchy={:.6}, equiv={:.6}, disjoint={:.6}",
1008 epoch,
1009 total_loss,
1010 hierarchy_loss,
1011 equivalence_loss,
1012 disjoint_loss
1013 );
1014 }
1015 }
1016
1017 let training_time = start_time.elapsed().as_secs_f64();
1018 self.is_trained = true;
1019
1020 self.model_stats.num_entities = entity_set.len();
1022 self.model_stats.num_relations = relation_set.len();
1023 self.model_stats.num_triples = self.triples.len();
1024 self.model_stats.is_trained = true;
1025 self.model_stats.last_training_time = Some(Utc::now());
1026
1027 self.training_stats = TrainingStats {
1029 epochs_completed: max_epochs,
1030 final_loss: loss_history.last().copied().unwrap_or(0.0),
1031 training_time_seconds: training_time,
1032 convergence_achieved: loss_history.last().copied().unwrap_or(0.0) < 0.01,
1033 loss_history,
1034 };
1035
1036 Ok(self.training_stats.clone())
1037 }
1038
1039 fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
1040 self.entity_embeddings
1041 .get(entity)
1042 .map(|arr| Vector::new(arr.to_vec()))
1043 .ok_or_else(|| anyhow!("Entity not found: {}", entity))
1044 }
1045
1046 fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
1047 self.relation_embeddings
1048 .get(relation)
1049 .map(|arr| Vector::new(arr.to_vec()))
1050 .ok_or_else(|| anyhow!("Relation not found: {}", relation))
1051 }
1052
1053 fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
1054 let h = self
1055 .entity_embeddings
1056 .get(subject)
1057 .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
1058 let r = self
1059 .relation_embeddings
1060 .get(predicate)
1061 .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
1062 let t = self
1063 .entity_embeddings
1064 .get(object)
1065 .ok_or_else(|| anyhow!("Object not found: {}", object))?;
1066
1067 let predicted = h + r;
1069 let distance = (&predicted - t).mapv(|x| x * x).sum().sqrt();
1070 Ok(-(distance as f64)) }
1072
1073 fn predict_objects(
1074 &self,
1075 subject: &str,
1076 predicate: &str,
1077 k: usize,
1078 ) -> Result<Vec<(String, f64)>> {
1079 let h = self
1080 .entity_embeddings
1081 .get(subject)
1082 .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
1083 let r = self
1084 .relation_embeddings
1085 .get(predicate)
1086 .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
1087 let predicted = h + r;
1088
1089 let mut scores = Vec::new();
1090 for (entity, embedding) in &self.entity_embeddings {
1091 let distance = (&predicted - embedding).mapv(|x| x * x).sum().sqrt();
1092 scores.push((entity.clone(), -(distance as f64)));
1093 }
1094
1095 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
1096 scores.truncate(k);
1097 Ok(scores)
1098 }
1099
1100 fn predict_subjects(
1101 &self,
1102 predicate: &str,
1103 object: &str,
1104 k: usize,
1105 ) -> Result<Vec<(String, f64)>> {
1106 let r = self
1107 .relation_embeddings
1108 .get(predicate)
1109 .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
1110 let t = self
1111 .entity_embeddings
1112 .get(object)
1113 .ok_or_else(|| anyhow!("Object not found: {}", object))?;
1114 let target = t - r; let mut scores = Vec::new();
1117 for (entity, embedding) in &self.entity_embeddings {
1118 let distance = (embedding - &target).mapv(|x| x * x).sum().sqrt();
1119 scores.push((entity.clone(), -(distance as f64)));
1120 }
1121
1122 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
1123 scores.truncate(k);
1124 Ok(scores)
1125 }
1126
1127 fn predict_relations(
1128 &self,
1129 subject: &str,
1130 object: &str,
1131 k: usize,
1132 ) -> Result<Vec<(String, f64)>> {
1133 let h = self
1134 .entity_embeddings
1135 .get(subject)
1136 .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
1137 let t = self
1138 .entity_embeddings
1139 .get(object)
1140 .ok_or_else(|| anyhow!("Object not found: {}", object))?;
1141 let target = t - h; let mut scores = Vec::new();
1144 for (relation, embedding) in &self.relation_embeddings {
1145 let distance = (embedding - &target).mapv(|x| x * x).sum().sqrt();
1146 scores.push((relation.clone(), -(distance as f64)));
1147 }
1148
1149 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
1150 scores.truncate(k);
1151 Ok(scores)
1152 }
1153
1154 fn get_entities(&self) -> Vec<String> {
1155 self.entity_embeddings.keys().cloned().collect()
1156 }
1157
1158 fn get_relations(&self) -> Vec<String> {
1159 self.relation_embeddings.keys().cloned().collect()
1160 }
1161
1162 fn get_stats(&self) -> ModelStats {
1163 self.model_stats.clone()
1164 }
1165
1166 fn save(&self, path: &str) -> Result<()> {
1167 let serialized = serde_json::to_string_pretty(self)?;
1168 std::fs::write(path, serialized)?;
1169 Ok(())
1170 }
1171
1172 fn load(&mut self, path: &str) -> Result<()> {
1173 let content = std::fs::read_to_string(path)?;
1174 let loaded: OntologyAwareEmbedding = serde_json::from_str(&content)?;
1175 *self = loaded;
1176 Ok(())
1177 }
1178
1179 fn clear(&mut self) {
1180 self.entity_embeddings.clear();
1181 self.relation_embeddings.clear();
1182 self.entity_to_idx.clear();
1183 self.relation_to_idx.clear();
1184 self.triples.clear();
1185 self.ontology_constraints = OntologyConstraints::default();
1186 self.training_stats = TrainingStats::default();
1187 self.is_trained = false;
1188 self.model_stats.is_trained = false;
1189 self.model_stats.num_entities = 0;
1190 self.model_stats.num_relations = 0;
1191 self.model_stats.num_triples = 0;
1192 }
1193
1194 fn is_trained(&self) -> bool {
1195 self.is_trained
1196 }
1197
1198 async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
1199 Err(anyhow!(
1200 "Knowledge graph embedding model does not support text encoding"
1201 ))
1202 }
1203}
1204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;

    const SUBCLASS_OF: &str = "http://www.w3.org/2000/01/rdf-schema#subClassOf";
    const EQUIVALENT_CLASS: &str = "http://www.w3.org/2002/07/owl#equivalentClass";
    const RDF_TYPE: &str = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type";

    /// Builds a triple from three IRIs, panicking on an invalid IRI.
    fn triple(subject: &str, predicate: &str, object: &str) -> Triple {
        Triple::new(
            NamedNode::new(subject).unwrap(),
            NamedNode::new(predicate).unwrap(),
            NamedNode::new(object).unwrap(),
        )
    }

    /// Builds a set of owned class names.
    fn class_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_ontology_relation_from_iri() {
        assert_eq!(
            OntologyRelation::from_iri(SUBCLASS_OF),
            Some(OntologyRelation::SubClassOf)
        );
        assert_eq!(
            OntologyRelation::from_iri(EQUIVALENT_CLASS),
            Some(OntologyRelation::EquivalentClass)
        );
        assert_eq!(
            OntologyRelation::from_iri("http://example.org/custom"),
            None
        );
    }

    #[test]
    fn test_ontology_constraint_extraction() {
        let mut model =
            OntologyAwareEmbedding::new(OntologyAwareEmbedding::hierarchy_optimized_config(50));

        model.triples = vec![
            triple("http://example.org/Dog", SUBCLASS_OF, "http://example.org/Animal"),
            triple("http://example.org/Cat", SUBCLASS_OF, "http://example.org/Animal"),
            triple("http://example.org/Canine", EQUIVALENT_CLASS, "http://example.org/Dog"),
        ];
        model.extract_ontology_constraints();

        let constraints = &model.ontology_constraints;
        for subclass in ["http://example.org/Dog", "http://example.org/Cat"].iter() {
            assert!(constraints.class_hierarchy.contains_key(*subclass));
        }
        // Equivalence is recorded in both directions.
        for class in ["http://example.org/Canine", "http://example.org/Dog"].iter() {
            assert!(constraints.equivalent_classes.contains_key(*class));
        }
    }

    #[test]
    fn test_transitive_closure_computation() {
        let mut constraints = OntologyConstraints::default();
        constraints
            .class_hierarchy
            .insert("A".to_string(), class_set(&["B"]));
        constraints
            .class_hierarchy
            .insert("B".to_string(), class_set(&["C"]));

        constraints.compute_transitive_closure(5);

        assert!(constraints.is_subclass_of("A", "B"));
        assert!(constraints.is_subclass_of("A", "C"));
        assert!(constraints.is_subclass_of("B", "C"));
        assert!(!constraints.is_subclass_of("C", "A"));
    }

    #[test]
    fn test_ontology_aware_config_factory_methods() {
        let hierarchy = OntologyAwareEmbedding::hierarchy_optimized_config(100);
        assert_eq!(hierarchy.base_config.dimensions, 100);
        assert_eq!(hierarchy.hierarchy_weight, 2.0);
        assert!(hierarchy.use_transitive_closure);

        let property = OntologyAwareEmbedding::property_optimized_config(100);
        assert_eq!(property.disjoint_weight, 2.0);
        assert_eq!(property.max_transitive_depth, 8);
    }

    #[tokio::test]
    async fn test_ontology_aware_embedding_training() {
        let mut model =
            OntologyAwareEmbedding::new(OntologyAwareEmbedding::hierarchy_optimized_config(32));

        let inputs = vec![
            triple("http://example.org/Dog", SUBCLASS_OF, "http://example.org/Animal"),
            triple("http://example.org/Fido", RDF_TYPE, "http://example.org/Dog"),
        ];
        for input in inputs {
            model.add_triple(input).unwrap();
        }

        let outcome = model.train(Some(10)).await;
        assert!(outcome.is_ok());

        for entity in [
            "http://example.org/Dog",
            "http://example.org/Animal",
            "http://example.org/Fido",
        ]
        .iter()
        {
            assert!(model.entity_embeddings.contains_key(*entity));
        }

        let dog = model.get_entity_embedding("http://example.org/Dog");
        assert!(dog.is_ok());
        assert_eq!(dog.unwrap().dimensions, 32);

        assert!(model.is_trained());
    }
}