oxirs_embed/models/
ontology.rs

1//! Ontology-aware embedding models
2//!
3//! This module provides embedding models that understand and respect RDF schema
4//! relationships such as class hierarchies, property constraints, and semantic
5//! relationships defined in OWL and RDFS ontologies.
6
7use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
8use anyhow::{anyhow, Result};
9use async_trait::async_trait;
10use chrono::Utc;
11use scirs2_core::ndarray_ext::Array1;
12use serde::{Deserialize, Serialize};
13use std::collections::{HashMap, HashSet};
14use uuid::Uuid;
15
/// Ontology relationship types that are semantically meaningful.
///
/// Each variant corresponds to exactly one RDFS/OWL vocabulary IRI; the mapping
/// from IRI string to variant lives in `OntologyRelation::from_iri`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OntologyRelation {
    /// rdfs:subClassOf - class hierarchy relationships
    SubClassOf,
    /// owl:equivalentClass - equivalent class relationships
    EquivalentClass,
    /// owl:disjointWith - disjoint class relationships
    DisjointWith,
    /// rdfs:domain - property domain constraints
    Domain,
    /// rdfs:range - property range constraints
    Range,
    /// owl:inverseOf - inverse property relationships
    InverseOf,
    /// owl:FunctionalProperty - functional property constraints
    FunctionalProperty,
    /// owl:SymmetricProperty - symmetric property constraints
    SymmetricProperty,
    /// owl:TransitiveProperty - transitive property constraints
    TransitiveProperty,
}
38
/// Property characteristics for enhanced constraint handling.
///
/// One value describes a single property. All flags default to `false` and both
/// class sets default to empty (see the derived `Default`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PropertyCharacteristics {
    /// Property is declared functional.
    pub is_functional: bool,
    /// Property is declared inverse-functional.
    pub is_inverse_functional: bool,
    /// Property is declared symmetric.
    pub is_symmetric: bool,
    /// Property is declared asymmetric.
    pub is_asymmetric: bool,
    /// Property is declared transitive.
    pub is_transitive: bool,
    /// Property is declared reflexive.
    pub is_reflexive: bool,
    /// Property is declared irreflexive.
    pub is_irreflexive: bool,
    /// IRI of the inverse property, if one is known.
    pub has_inverse: Option<String>,
    /// Classes allowed as subject type; empty means "unconstrained".
    pub domain_classes: HashSet<String>,
    /// Classes allowed as object type; empty means "unconstrained".
    pub range_classes: HashSet<String>,
}
53
54impl PropertyCharacteristics {
55    /// Check if property has any domain constraints
56    pub fn has_domain_constraints(&self) -> bool {
57        !self.domain_classes.is_empty()
58    }
59
60    /// Check if property has any range constraints
61    pub fn has_range_constraints(&self) -> bool {
62        !self.range_classes.is_empty()
63    }
64
65    /// Check if entity satisfies domain constraints for this property
66    pub fn satisfies_domain(&self, entity_type: &str) -> bool {
67        if self.domain_classes.is_empty() {
68            true // No constraints means always satisfied
69        } else {
70            self.domain_classes.contains(entity_type)
71        }
72    }
73
74    /// Check if entity satisfies range constraints for this property
75    pub fn satisfies_range(&self, entity_type: &str) -> bool {
76        if self.range_classes.is_empty() {
77            true // No constraints means always satisfied
78        } else {
79            self.range_classes.contains(entity_type)
80        }
81    }
82}
83
84impl OntologyRelation {
85    /// Convert from RDF predicate IRI to ontology relation type
86    pub fn from_iri(iri: &str) -> Option<Self> {
87        match iri {
88            "http://www.w3.org/2000/01/rdf-schema#subClassOf" => Some(Self::SubClassOf),
89            "http://www.w3.org/2002/07/owl#equivalentClass" => Some(Self::EquivalentClass),
90            "http://www.w3.org/2002/07/owl#disjointWith" => Some(Self::DisjointWith),
91            "http://www.w3.org/2000/01/rdf-schema#domain" => Some(Self::Domain),
92            "http://www.w3.org/2000/01/rdf-schema#range" => Some(Self::Range),
93            "http://www.w3.org/2002/07/owl#inverseOf" => Some(Self::InverseOf),
94            "http://www.w3.org/2002/07/owl#FunctionalProperty" => Some(Self::FunctionalProperty),
95            "http://www.w3.org/2002/07/owl#SymmetricProperty" => Some(Self::SymmetricProperty),
96            "http://www.w3.org/2002/07/owl#TransitiveProperty" => Some(Self::TransitiveProperty),
97            _ => None,
98        }
99    }
100}
101
/// Configuration for ontology-aware embeddings.
///
/// Combines the generic model configuration with per-constraint loss weights
/// and feature switches. See the `Default` impl for the baseline values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OntologyAwareConfig {
    /// Generic embedding-model settings (its `dimensions` field is copied into
    /// the model statistics when a model is constructed)
    pub base_config: ModelConfig,
    /// Weight for hierarchy constraint enforcement
    pub hierarchy_weight: f32,
    /// Weight for equivalence constraint enforcement
    pub equivalence_weight: f32,
    /// Weight for disjointness constraint enforcement
    pub disjoint_weight: f32,
    /// Weight for property domain/range constraint enforcement
    pub property_constraint_weight: f32,
    /// Weight for multi-modal alignment
    pub cross_modal_weight: f32,
    /// Whether to use transitive closure for hierarchies
    pub use_transitive_closure: bool,
    /// Maximum depth for transitive closure computation
    pub max_transitive_depth: usize,
    /// Whether to normalize embeddings for hierarchy preservation
    pub normalize_for_hierarchy: bool,
    /// Margin for ranking loss in hierarchy constraints
    pub hierarchy_margin: f32,
    /// Enable contrastive learning for cross-modal alignment
    pub enable_contrastive_learning: bool,
    /// Temperature parameter for contrastive learning
    pub contrastive_temperature: f32,
    /// Enable mutual information maximization
    pub enable_mutual_info_max: bool,
    /// Enable property chain inference
    pub enable_property_chains: bool,
    /// Maximum length for property chains
    pub max_property_chain_length: usize,
}
135
136impl Default for OntologyAwareConfig {
137    fn default() -> Self {
138        Self {
139            base_config: ModelConfig::default(),
140            hierarchy_weight: 1.0,
141            equivalence_weight: 2.0,
142            disjoint_weight: 1.5,
143            property_constraint_weight: 1.2,
144            cross_modal_weight: 0.8,
145            use_transitive_closure: true,
146            max_transitive_depth: 10,
147            normalize_for_hierarchy: true,
148            hierarchy_margin: 1.0,
149            enable_contrastive_learning: true,
150            contrastive_temperature: 0.1,
151            enable_mutual_info_max: false,
152            enable_property_chains: true,
153            max_property_chain_length: 3,
154        }
155    }
156}
157
/// Ontology-aware embedding model that respects semantic relationships.
///
/// Holds the learned entity/relation vectors together with the ontology
/// constraints extracted from the training triples.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OntologyAwareEmbedding {
    /// Configuration this model was created with
    pub config: OntologyAwareConfig,
    /// Unique model identifier
    pub model_id: Uuid,
    /// Entity embeddings
    pub entity_embeddings: HashMap<String, Array1<f32>>,
    /// Relation embeddings
    pub relation_embeddings: HashMap<String, Array1<f32>>,
    /// Entity to index mapping
    pub entity_to_idx: HashMap<String, usize>,
    /// Relation to index mapping
    pub relation_to_idx: HashMap<String, usize>,
    /// Training triples
    pub triples: Vec<Triple>,
    /// Ontology constraints extracted from triples
    pub ontology_constraints: OntologyConstraints,
    /// Training statistics
    pub training_stats: TrainingStats,
    /// Model statistics
    pub model_stats: ModelStats,
    /// Whether the model has been trained
    // NOTE(review): `model_stats` carries its own `is_trained` flag as well;
    // presumably the two are kept in sync by the training code — confirm.
    pub is_trained: bool,
}
183
/// Ontology constraints extracted from RDF data.
///
/// All keys and set members are IRI strings. The derived `Default` yields an
/// instance with every collection empty.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OntologyConstraints {
    /// Class hierarchy: subclass -> set of superclasses
    pub class_hierarchy: HashMap<String, HashSet<String>>,
    /// Equivalent classes: class -> set of equivalent classes
    pub equivalent_classes: HashMap<String, HashSet<String>>,
    /// Disjoint classes: class -> set of disjoint classes
    pub disjoint_classes: HashMap<String, HashSet<String>>,
    /// Property domains: property -> set of domain classes
    pub property_domains: HashMap<String, HashSet<String>>,
    /// Property ranges: property -> set of range classes
    pub property_ranges: HashMap<String, HashSet<String>>,
    /// Inverse properties: property -> inverse property
    pub inverse_properties: HashMap<String, String>,
    /// Functional properties
    pub functional_properties: HashSet<String>,
    /// Symmetric properties
    pub symmetric_properties: HashSet<String>,
    /// Transitive properties
    pub transitive_properties: HashSet<String>,
    /// Transitive closure of class hierarchy
    /// (overwritten by `compute_transitive_closure`; empty until that runs)
    pub transitive_hierarchy: HashMap<String, HashSet<String>>,
    /// Property chains: property -> list of property sequences that imply it
    pub property_chains: HashMap<String, Vec<Vec<String>>>,
    /// Cross-modal alignments: entity -> set of aligned entities from other modalities
    /// (`add_cross_modal_alignment` records both directions)
    pub cross_modal_alignments: HashMap<String, HashSet<String>>,
    /// Property characteristics cache for fast lookup
    /// (filled by `build_property_characteristics_cache`)
    pub property_characteristics: HashMap<String, PropertyCharacteristics>,
}
214
215impl OntologyConstraints {
216    /// Compute transitive closure of class hierarchy
217    pub fn compute_transitive_closure(&mut self, max_depth: usize) {
218        self.transitive_hierarchy = self.class_hierarchy.clone();
219
220        for _ in 0..max_depth {
221            let mut changed = false;
222            let current_hierarchy = self.transitive_hierarchy.clone();
223
224            for (subclass, superclasses) in &current_hierarchy {
225                let mut new_superclasses = superclasses.clone();
226
227                // Add superclasses of superclasses
228                for superclass in superclasses {
229                    if let Some(super_superclasses) = current_hierarchy.get(superclass) {
230                        for super_superclass in super_superclasses {
231                            if !new_superclasses.contains(super_superclass)
232                                && super_superclass != subclass
233                            {
234                                new_superclasses.insert(super_superclass.clone());
235                                changed = true;
236                            }
237                        }
238                    }
239                }
240
241                self.transitive_hierarchy
242                    .insert(subclass.clone(), new_superclasses);
243            }
244
245            if !changed {
246                break;
247            }
248        }
249    }
250
251    /// Check if one class is a subclass of another (considering transitive closure)
252    pub fn is_subclass_of(&self, subclass: &str, superclass: &str) -> bool {
253        if let Some(superclasses) = self.transitive_hierarchy.get(subclass) {
254            superclasses.contains(superclass)
255        } else {
256            false
257        }
258    }
259
260    /// Check if two classes are equivalent
261    pub fn are_equivalent(&self, class1: &str, class2: &str) -> bool {
262        if let Some(equivalent) = self.equivalent_classes.get(class1) {
263            equivalent.contains(class2)
264        } else {
265            false
266        }
267    }
268
269    /// Check if two classes are disjoint
270    pub fn are_disjoint(&self, class1: &str, class2: &str) -> bool {
271        if let Some(disjoint) = self.disjoint_classes.get(class1) {
272            disjoint.contains(class2)
273        } else {
274            false
275        }
276    }
277
278    /// Add property chain constraint
279    pub fn add_property_chain(&mut self, target_property: &str, chain: Vec<String>) {
280        self.property_chains
281            .entry(target_property.to_string())
282            .or_default()
283            .push(chain);
284    }
285
286    /// Get property chains that imply the given property
287    pub fn get_property_chains(&self, property: &str) -> Option<&Vec<Vec<String>>> {
288        self.property_chains.get(property)
289    }
290
291    /// Add cross-modal alignment between entities
292    pub fn add_cross_modal_alignment(&mut self, entity1: &str, entity2: &str) {
293        self.cross_modal_alignments
294            .entry(entity1.to_string())
295            .or_default()
296            .insert(entity2.to_string());
297
298        // Symmetric relationship
299        self.cross_modal_alignments
300            .entry(entity2.to_string())
301            .or_default()
302            .insert(entity1.to_string());
303    }
304
305    /// Get cross-modal alignments for an entity
306    pub fn get_cross_modal_alignments(&self, entity: &str) -> Option<&HashSet<String>> {
307        self.cross_modal_alignments.get(entity)
308    }
309
310    /// Build property characteristics cache
311    pub fn build_property_characteristics_cache(&mut self) {
312        // Build comprehensive property characteristics
313        let all_properties: HashSet<String> = self
314            .property_domains
315            .keys()
316            .chain(self.property_ranges.keys())
317            .chain(self.functional_properties.iter())
318            .chain(self.symmetric_properties.iter())
319            .chain(self.transitive_properties.iter())
320            .chain(self.inverse_properties.keys())
321            .cloned()
322            .collect();
323
324        for property in all_properties {
325            let mut characteristics = PropertyCharacteristics {
326                is_functional: self.functional_properties.contains(&property),
327                is_symmetric: self.symmetric_properties.contains(&property),
328                is_transitive: self.transitive_properties.contains(&property),
329                has_inverse: self.inverse_properties.get(&property).cloned(),
330                ..Default::default()
331            };
332
333            if let Some(domains) = self.property_domains.get(&property) {
334                characteristics.domain_classes = domains.clone();
335            }
336
337            if let Some(ranges) = self.property_ranges.get(&property) {
338                characteristics.range_classes = ranges.clone();
339            }
340
341            self.property_characteristics
342                .insert(property, characteristics);
343        }
344    }
345
346    /// Validate property usage based on domain/range constraints
347    pub fn validate_property_usage(
348        &self,
349        subject: &str,
350        property: &str,
351        object: &str,
352        entity_types: &HashMap<String, String>,
353    ) -> bool {
354        if let Some(characteristics) = self.property_characteristics.get(property) {
355            // Check domain constraints
356            if characteristics.has_domain_constraints() {
357                if let Some(subject_type) = entity_types.get(subject) {
358                    if !characteristics.satisfies_domain(subject_type) {
359                        return false;
360                    }
361                }
362            }
363
364            // Check range constraints
365            if characteristics.has_range_constraints() {
366                if let Some(object_type) = entity_types.get(object) {
367                    if !characteristics.satisfies_range(object_type) {
368                        return false;
369                    }
370                }
371            }
372        }
373
374        true
375    }
376
377    /// Infer new triples based on property chains
378    pub fn infer_from_property_chains(&self, existing_triples: &[Triple]) -> Vec<Triple> {
379        let mut inferred_triples = Vec::new();
380
381        for (target_property, chains) in &self.property_chains {
382            for chain in chains {
383                if chain.len() >= 2 {
384                    // Find sequences of triples that match the property chain
385                    inferred_triples.extend(self.find_chain_matches(
386                        existing_triples,
387                        target_property,
388                        chain,
389                    ));
390                }
391            }
392        }
393
394        inferred_triples
395    }
396
397    /// Find triple sequences that match a property chain pattern
398    fn find_chain_matches(
399        &self,
400        triples: &[Triple],
401        target_property: &str,
402        chain: &[String],
403    ) -> Vec<Triple> {
404        let mut matches = Vec::new();
405
406        // Build index of triples by predicate for faster lookup
407        let mut triples_by_predicate: HashMap<String, Vec<&Triple>> = HashMap::new();
408        for triple in triples {
409            triples_by_predicate
410                .entry(triple.predicate.iri.clone())
411                .or_default()
412                .push(triple);
413        }
414
415        // For simplicity, handle 2-property chains
416        if chain.len() == 2 {
417            let prop1 = &chain[0];
418            let prop2 = &chain[1];
419
420            if let (Some(triples1), Some(triples2)) = (
421                triples_by_predicate.get(prop1),
422                triples_by_predicate.get(prop2),
423            ) {
424                for t1 in triples1 {
425                    for t2 in triples2 {
426                        // Check if t1.object == t2.subject (chain connection)
427                        if t1.object.iri == t2.subject.iri {
428                            // Create inferred triple: t1.subject target_property t2.object
429                            if let (Ok(subject), Ok(predicate), Ok(object)) = (
430                                crate::NamedNode::new(&t1.subject.iri),
431                                crate::NamedNode::new(target_property),
432                                crate::NamedNode::new(&t2.object.iri),
433                            ) {
434                                matches.push(Triple::new(subject, predicate, object));
435                            }
436                        }
437                    }
438                }
439            }
440        }
441
442        matches
443    }
444}
445
446impl Default for TrainingStats {
447    fn default() -> Self {
448        Self {
449            epochs_completed: 0,
450            final_loss: 0.0,
451            training_time_seconds: 0.0,
452            convergence_achieved: false,
453            loss_history: Vec::new(),
454        }
455    }
456}
457
458impl OntologyAwareEmbedding {
459    /// Create new ontology-aware embedding model
460    pub fn new(config: OntologyAwareConfig) -> Self {
461        let model_id = Uuid::new_v4();
462        let now = Utc::now();
463
464        Self {
465            model_id,
466            entity_embeddings: HashMap::new(),
467            relation_embeddings: HashMap::new(),
468            entity_to_idx: HashMap::new(),
469            relation_to_idx: HashMap::new(),
470            triples: Vec::new(),
471            ontology_constraints: OntologyConstraints::default(),
472            training_stats: TrainingStats::default(),
473            model_stats: ModelStats {
474                num_entities: 0,
475                num_relations: 0,
476                num_triples: 0,
477                dimensions: config.base_config.dimensions,
478                is_trained: false,
479                model_type: "OntologyAware".to_string(),
480                creation_time: now,
481                last_training_time: None,
482            },
483            is_trained: false,
484            config,
485        }
486    }
487
488    /// Create configuration optimized for class hierarchies
489    pub fn hierarchy_optimized_config(dimensions: usize) -> OntologyAwareConfig {
490        OntologyAwareConfig {
491            base_config: ModelConfig::default().with_dimensions(dimensions),
492            hierarchy_weight: 2.0,
493            equivalence_weight: 1.0,
494            disjoint_weight: 1.0,
495            property_constraint_weight: 1.0,
496            cross_modal_weight: 0.5,
497            use_transitive_closure: true,
498            max_transitive_depth: 15,
499            normalize_for_hierarchy: true,
500            hierarchy_margin: 0.5,
501            enable_contrastive_learning: false,
502            contrastive_temperature: 0.1,
503            enable_mutual_info_max: false,
504            enable_property_chains: true,
505            max_property_chain_length: 2,
506        }
507    }
508
509    /// Create configuration optimized for property constraints
510    pub fn property_optimized_config(dimensions: usize) -> OntologyAwareConfig {
511        OntologyAwareConfig {
512            base_config: ModelConfig::default().with_dimensions(dimensions),
513            hierarchy_weight: 1.0,
514            equivalence_weight: 1.5,
515            disjoint_weight: 2.0,
516            property_constraint_weight: 2.5,
517            cross_modal_weight: 1.0,
518            use_transitive_closure: true,
519            max_transitive_depth: 8,
520            normalize_for_hierarchy: false,
521            hierarchy_margin: 1.0,
522            enable_contrastive_learning: true,
523            contrastive_temperature: 0.05,
524            enable_mutual_info_max: true,
525            enable_property_chains: true,
526            max_property_chain_length: 3,
527        }
528    }
529
530    /// Extract ontology constraints from triples
531    fn extract_ontology_constraints(&mut self) {
532        for triple in &self.triples {
533            if let Some(relation_type) = OntologyRelation::from_iri(&triple.predicate.iri) {
534                match relation_type {
535                    OntologyRelation::SubClassOf => {
536                        self.ontology_constraints
537                            .class_hierarchy
538                            .entry(triple.subject.iri.clone())
539                            .or_default()
540                            .insert(triple.object.iri.clone());
541                    }
542                    OntologyRelation::EquivalentClass => {
543                        self.ontology_constraints
544                            .equivalent_classes
545                            .entry(triple.subject.iri.clone())
546                            .or_default()
547                            .insert(triple.object.iri.clone());
548                        // Symmetric relationship
549                        self.ontology_constraints
550                            .equivalent_classes
551                            .entry(triple.object.iri.clone())
552                            .or_default()
553                            .insert(triple.subject.iri.clone());
554                    }
555                    OntologyRelation::DisjointWith => {
556                        self.ontology_constraints
557                            .disjoint_classes
558                            .entry(triple.subject.iri.clone())
559                            .or_default()
560                            .insert(triple.object.iri.clone());
561                        // Symmetric relationship
562                        self.ontology_constraints
563                            .disjoint_classes
564                            .entry(triple.object.iri.clone())
565                            .or_default()
566                            .insert(triple.subject.iri.clone());
567                    }
568                    OntologyRelation::Domain => {
569                        self.ontology_constraints
570                            .property_domains
571                            .entry(triple.subject.iri.clone())
572                            .or_default()
573                            .insert(triple.object.iri.clone());
574                    }
575                    OntologyRelation::Range => {
576                        self.ontology_constraints
577                            .property_ranges
578                            .entry(triple.subject.iri.clone())
579                            .or_default()
580                            .insert(triple.object.iri.clone());
581                    }
582                    OntologyRelation::InverseOf => {
583                        self.ontology_constraints
584                            .inverse_properties
585                            .insert(triple.subject.iri.clone(), triple.object.iri.clone());
586                        self.ontology_constraints
587                            .inverse_properties
588                            .insert(triple.object.iri.clone(), triple.subject.iri.clone());
589                    }
590                    OntologyRelation::FunctionalProperty => {
591                        self.ontology_constraints
592                            .functional_properties
593                            .insert(triple.subject.iri.clone());
594                    }
595                    OntologyRelation::SymmetricProperty => {
596                        self.ontology_constraints
597                            .symmetric_properties
598                            .insert(triple.subject.iri.clone());
599                    }
600                    OntologyRelation::TransitiveProperty => {
601                        self.ontology_constraints
602                            .transitive_properties
603                            .insert(triple.subject.iri.clone());
604                    }
605                }
606            }
607        }
608
609        // Compute transitive closure if enabled
610        if self.config.use_transitive_closure {
611            self.ontology_constraints
612                .compute_transitive_closure(self.config.max_transitive_depth);
613        }
614    }
615
616    /// Compute hierarchy-preserving loss
617    fn compute_hierarchy_loss(&self) -> f32 {
618        let mut total_loss = 0.0;
619        let mut count = 0;
620
621        for (subclass, superclasses) in &self.ontology_constraints.transitive_hierarchy {
622            if let Some(sub_emb) = self.entity_embeddings.get(subclass) {
623                for superclass in superclasses {
624                    if let Some(super_emb) = self.entity_embeddings.get(superclass) {
625                        // Subclass embedding should be "closer" to origin than superclass
626                        // in the direction of the hierarchy
627                        let sub_norm = sub_emb.dot(sub_emb).sqrt();
628                        let super_norm = super_emb.dot(super_emb).sqrt();
629                        let similarity = sub_emb.dot(super_emb) / (sub_norm * super_norm + 1e-8);
630
631                        // Hierarchy loss: encourage high similarity and proper ordering
632                        let hierarchy_score = similarity + (super_norm - sub_norm) * 0.1;
633                        let loss = (self.config.hierarchy_margin - hierarchy_score).max(0.0);
634                        total_loss += loss;
635                        count += 1;
636                    }
637                }
638            }
639        }
640
641        if count > 0 {
642            total_loss / count as f32
643        } else {
644            0.0
645        }
646    }
647
648    /// Compute equivalence loss
649    fn compute_equivalence_loss(&self) -> f32 {
650        let mut total_loss = 0.0;
651        let mut count = 0;
652
653        for (class1, equivalent_classes) in &self.ontology_constraints.equivalent_classes {
654            if let Some(emb1) = self.entity_embeddings.get(class1) {
655                for class2 in equivalent_classes {
656                    if let Some(emb2) = self.entity_embeddings.get(class2) {
657                        // Equivalent classes should have very similar embeddings
658                        let distance = (emb1 - emb2).mapv(|x| x * x).sum().sqrt();
659                        total_loss += distance;
660                        count += 1;
661                    }
662                }
663            }
664        }
665
666        if count > 0 {
667            total_loss / count as f32
668        } else {
669            0.0
670        }
671    }
672
673    /// Compute disjointness loss
674    fn compute_disjoint_loss(&self) -> f32 {
675        let mut total_loss = 0.0;
676        let mut count = 0;
677
678        for (class1, disjoint_classes) in &self.ontology_constraints.disjoint_classes {
679            if let Some(emb1) = self.entity_embeddings.get(class1) {
680                for class2 in disjoint_classes {
681                    if let Some(emb2) = self.entity_embeddings.get(class2) {
682                        // Disjoint classes should have low similarity
683                        let norm1 = emb1.dot(emb1).sqrt();
684                        let norm2 = emb2.dot(emb2).sqrt();
685                        let similarity = emb1.dot(emb2) / (norm1 * norm2 + 1e-8);
686                        let loss = (similarity + self.config.hierarchy_margin).max(0.0);
687                        total_loss += loss;
688                        count += 1;
689                    }
690                }
691            }
692        }
693
694        if count > 0 {
695            total_loss / count as f32
696        } else {
697            0.0
698        }
699    }
700
701    /// Compute property constraint loss (domain/range constraints)
702    fn compute_property_constraint_loss(&self) -> f32 {
703        let mut total_loss = 0.0;
704        let mut count = 0;
705
706        // Check domain constraints
707        for (property, domains) in &self.ontology_constraints.property_domains {
708            if let Some(relation_emb) = self.relation_embeddings.get(property) {
709                for domain_class in domains {
710                    if let Some(domain_emb) = self.entity_embeddings.get(domain_class) {
711                        // Domain entities should be compatible with property
712                        let compatibility = relation_emb.dot(domain_emb);
713                        let loss = (1.0 - compatibility).max(0.0); // Encourage positive compatibility
714                        total_loss += loss;
715                        count += 1;
716                    }
717                }
718            }
719        }
720
721        // Check range constraints
722        for (property, ranges) in &self.ontology_constraints.property_ranges {
723            if let Some(relation_emb) = self.relation_embeddings.get(property) {
724                for range_class in ranges {
725                    if let Some(range_emb) = self.entity_embeddings.get(range_class) {
726                        // Range entities should be compatible with property
727                        let compatibility = relation_emb.dot(range_emb);
728                        let loss = (1.0 - compatibility).max(0.0);
729                        total_loss += loss;
730                        count += 1;
731                    }
732                }
733            }
734        }
735
736        if count > 0 {
737            total_loss / count as f32
738        } else {
739            0.0
740        }
741    }
742
743    /// Compute contrastive learning loss for cross-modal alignment
744    fn compute_contrastive_loss(&self) -> f32 {
745        if !self.config.enable_contrastive_learning {
746            return 0.0;
747        }
748
749        let mut total_loss = 0.0;
750        let mut count = 0;
751        let temperature = self.config.contrastive_temperature;
752
753        for (entity1, aligned_entities) in &self.ontology_constraints.cross_modal_alignments {
754            if let Some(emb1) = self.entity_embeddings.get(entity1) {
755                for entity2 in aligned_entities {
756                    if let Some(emb2) = self.entity_embeddings.get(entity2) {
757                        // Positive similarity
758                        let pos_sim = emb1.dot(emb2) / temperature;
759
760                        // Negative similarities (sample random entities)
761                        let mut neg_sims = Vec::new();
762                        for (neg_entity, neg_emb) in self.entity_embeddings.iter().take(10) {
763                            if neg_entity != entity1 && neg_entity != entity2 {
764                                let neg_sim = emb1.dot(neg_emb) / temperature;
765                                neg_sims.push(neg_sim);
766                            }
767                        }
768
769                        if !neg_sims.is_empty() {
770                            // InfoNCE loss
771                            let exp_pos = pos_sim.exp();
772                            let sum_exp_neg: f32 = neg_sims.iter().copied().map(|x| x.exp()).sum();
773                            let loss = -(exp_pos / (exp_pos + sum_exp_neg)).ln();
774                            total_loss += loss;
775                            count += 1;
776                        }
777                    }
778                }
779            }
780        }
781
782        if count > 0 {
783            total_loss / count as f32
784        } else {
785            0.0
786        }
787    }
788
789    /// Compute mutual information maximization loss
790    fn compute_mutual_info_loss(&self) -> f32 {
791        if !self.config.enable_mutual_info_max {
792            return 0.0;
793        }
794
795        let mut total_loss = 0.0;
796        let mut count = 0;
797
798        // Simplified mutual information between entity and relation embeddings
799        for (entity, entity_emb) in &self.entity_embeddings {
800            for relation_emb in self.relation_embeddings.values() {
801                // Check if this entity-relation pair appears in training data
802                let pair_exists = self
803                    .triples
804                    .iter()
805                    .any(|t| t.subject.iri == *entity || t.object.iri == *entity);
806
807                if pair_exists {
808                    // Encourage high mutual information for related pairs
809                    let mi = entity_emb.dot(relation_emb);
810                    let loss = (1.0 - mi).max(0.0);
811                    total_loss += loss;
812                    count += 1;
813                }
814            }
815        }
816
817        if count > 0 {
818            total_loss / count as f32
819        } else {
820            0.0
821        }
822    }
823
824    /// Compute property chain consistency loss
825    fn compute_property_chain_loss(&self) -> f32 {
826        if !self.config.enable_property_chains {
827            return 0.0;
828        }
829
830        let mut total_loss = 0.0;
831        let mut count = 0;
832
833        for (target_property, chains) in &self.ontology_constraints.property_chains {
834            if let Some(target_emb) = self.relation_embeddings.get(target_property) {
835                for chain in chains {
836                    if chain.len() == 2 {
837                        // For 2-property chains: target ≈ prop1 + prop2
838                        if let (Some(prop1_emb), Some(prop2_emb)) = (
839                            self.relation_embeddings.get(&chain[0]),
840                            self.relation_embeddings.get(&chain[1]),
841                        ) {
842                            let chain_emb = prop1_emb + prop2_emb;
843                            let distance = (target_emb - &chain_emb).mapv(|x| x * x).sum().sqrt();
844                            total_loss += distance;
845                            count += 1;
846                        }
847                    }
848                }
849            }
850        }
851
852        if count > 0 {
853            total_loss / count as f32
854        } else {
855            0.0
856        }
857    }
858}
859
860#[async_trait]
861impl EmbeddingModel for OntologyAwareEmbedding {
862    fn config(&self) -> &ModelConfig {
863        &self.config.base_config
864    }
865
866    fn model_id(&self) -> &Uuid {
867        &self.model_id
868    }
869
870    fn model_type(&self) -> &'static str {
871        "OntologyAware"
872    }
873
874    fn add_triple(&mut self, triple: Triple) -> Result<()> {
875        self.triples.push(triple);
876        Ok(())
877    }
878
879    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
880        let start_time = std::time::Instant::now();
881
882        // Extract ontology constraints first
883        self.extract_ontology_constraints();
884
885        // Build property characteristics cache for enhanced constraint handling
886        self.ontology_constraints
887            .build_property_characteristics_cache();
888
889        // Build entity and relation vocabularies
890        let mut entity_set = HashSet::new();
891        let mut relation_set = HashSet::new();
892
893        for triple in &self.triples {
894            entity_set.insert(triple.subject.iri.clone());
895            entity_set.insert(triple.object.iri.clone());
896            relation_set.insert(triple.predicate.iri.clone());
897        }
898
899        // Create mappings
900        for (idx, entity) in entity_set.iter().enumerate() {
901            self.entity_to_idx.insert(entity.clone(), idx);
902        }
903
904        for (idx, relation) in relation_set.iter().enumerate() {
905            self.relation_to_idx.insert(relation.clone(), idx);
906        }
907
908        // Initialize embeddings
909        let dimensions = self.config.base_config.dimensions;
910        for entity in &entity_set {
911            let embedding = Array1::from_vec(
912                (0..dimensions)
913                    .map(|_| {
914                        use scirs2_core::random::{Random, Rng};
915                        let mut random = Random::default();
916                        (random.random::<f32>() - 0.5) * 0.1
917                    })
918                    .collect(),
919            );
920            self.entity_embeddings.insert(entity.clone(), embedding);
921        }
922
923        for relation in &relation_set {
924            let embedding = Array1::from_vec(
925                (0..dimensions)
926                    .map(|_| {
927                        use scirs2_core::random::{Random, Rng};
928                        let mut random = Random::default();
929                        (random.random::<f32>() - 0.5) * 0.1
930                    })
931                    .collect(),
932            );
933            self.relation_embeddings.insert(relation.clone(), embedding);
934        }
935
936        // Training loop with ontology constraints
937        let max_epochs = epochs.unwrap_or(self.config.base_config.max_epochs);
938        let learning_rate = self.config.base_config.learning_rate as f32;
939        let mut loss_history = Vec::new();
940
941        for epoch in 0..max_epochs {
942            let mut total_loss = 0.0;
943
944            // Standard TransE-style training
945            for triple in &self.triples {
946                if let (Some(h), Some(r), Some(t)) = (
947                    self.entity_embeddings.get(&triple.subject.iri).cloned(),
948                    self.relation_embeddings.get(&triple.predicate.iri).cloned(),
949                    self.entity_embeddings.get(&triple.object.iri).cloned(),
950                ) {
951                    // Compute score: ||h + r - t||
952                    let predicted = &h + &r;
953                    let error = &t - &predicted;
954                    let loss = error.dot(&error).sqrt();
955                    total_loss += loss;
956
957                    // Gradient updates
958                    let gradient_scale = learning_rate * 0.01;
959                    let h_grad = &error * gradient_scale;
960                    let r_grad = &error * gradient_scale;
961                    let t_grad = &error * (-gradient_scale);
962
963                    // Update embeddings
964                    if let Some(h_emb) = self.entity_embeddings.get_mut(&triple.subject.iri) {
965                        *h_emb += &h_grad;
966                    }
967                    if let Some(r_emb) = self.relation_embeddings.get_mut(&triple.predicate.iri) {
968                        *r_emb += &r_grad;
969                    }
970                    if let Some(t_emb) = self.entity_embeddings.get_mut(&triple.object.iri) {
971                        *t_emb += &t_grad;
972                    }
973                }
974            }
975
976            // Add ontology constraint losses
977            let hierarchy_loss = self.compute_hierarchy_loss();
978            let equivalence_loss = self.compute_equivalence_loss();
979            let disjoint_loss = self.compute_disjoint_loss();
980            let property_loss = self.compute_property_constraint_loss();
981            let contrastive_loss = self.compute_contrastive_loss();
982            let mutual_info_loss = self.compute_mutual_info_loss();
983            let property_chain_loss = self.compute_property_chain_loss();
984
985            total_loss += hierarchy_loss * self.config.hierarchy_weight;
986            total_loss += equivalence_loss * self.config.equivalence_weight;
987            total_loss += disjoint_loss * self.config.disjoint_weight;
988            total_loss += property_loss * self.config.property_constraint_weight;
989            total_loss += contrastive_loss * self.config.cross_modal_weight;
990            total_loss += mutual_info_loss * self.config.cross_modal_weight * 0.5;
991            total_loss += property_chain_loss * self.config.property_constraint_weight * 0.8;
992
993            loss_history.push(total_loss as f64);
994
995            // Normalize embeddings if configured
996            if self.config.normalize_for_hierarchy {
997                for embedding in self.entity_embeddings.values_mut() {
998                    let norm = embedding.dot(embedding).sqrt();
999                    if norm > 0.0 {
1000                        *embedding /= norm;
1001                    }
1002                }
1003            }
1004
1005            if epoch % 10 == 0 {
1006                tracing::info!(
1007                    "Epoch {}: total_loss={:.6}, hierarchy={:.6}, equiv={:.6}, disjoint={:.6}",
1008                    epoch,
1009                    total_loss,
1010                    hierarchy_loss,
1011                    equivalence_loss,
1012                    disjoint_loss
1013                );
1014            }
1015        }
1016
1017        let training_time = start_time.elapsed().as_secs_f64();
1018        self.is_trained = true;
1019
1020        // Update model stats
1021        self.model_stats.num_entities = entity_set.len();
1022        self.model_stats.num_relations = relation_set.len();
1023        self.model_stats.num_triples = self.triples.len();
1024        self.model_stats.is_trained = true;
1025        self.model_stats.last_training_time = Some(Utc::now());
1026
1027        // Update training stats
1028        self.training_stats = TrainingStats {
1029            epochs_completed: max_epochs,
1030            final_loss: loss_history.last().copied().unwrap_or(0.0),
1031            training_time_seconds: training_time,
1032            convergence_achieved: loss_history.last().copied().unwrap_or(0.0) < 0.01,
1033            loss_history,
1034        };
1035
1036        Ok(self.training_stats.clone())
1037    }
1038
1039    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
1040        self.entity_embeddings
1041            .get(entity)
1042            .map(|arr| Vector::new(arr.to_vec()))
1043            .ok_or_else(|| anyhow!("Entity not found: {}", entity))
1044    }
1045
1046    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
1047        self.relation_embeddings
1048            .get(relation)
1049            .map(|arr| Vector::new(arr.to_vec()))
1050            .ok_or_else(|| anyhow!("Relation not found: {}", relation))
1051    }
1052
1053    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
1054        let h = self
1055            .entity_embeddings
1056            .get(subject)
1057            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
1058        let r = self
1059            .relation_embeddings
1060            .get(predicate)
1061            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
1062        let t = self
1063            .entity_embeddings
1064            .get(object)
1065            .ok_or_else(|| anyhow!("Object not found: {}", object))?;
1066
1067        // TransE scoring: ||h + r - t||
1068        let predicted = h + r;
1069        let distance = (&predicted - t).mapv(|x| x * x).sum().sqrt();
1070        Ok(-(distance as f64)) // Negative distance as higher scores are better
1071    }
1072
1073    fn predict_objects(
1074        &self,
1075        subject: &str,
1076        predicate: &str,
1077        k: usize,
1078    ) -> Result<Vec<(String, f64)>> {
1079        let h = self
1080            .entity_embeddings
1081            .get(subject)
1082            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
1083        let r = self
1084            .relation_embeddings
1085            .get(predicate)
1086            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
1087        let predicted = h + r;
1088
1089        let mut scores = Vec::new();
1090        for (entity, embedding) in &self.entity_embeddings {
1091            let distance = (&predicted - embedding).mapv(|x| x * x).sum().sqrt();
1092            scores.push((entity.clone(), -(distance as f64)));
1093        }
1094
1095        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
1096        scores.truncate(k);
1097        Ok(scores)
1098    }
1099
1100    fn predict_subjects(
1101        &self,
1102        predicate: &str,
1103        object: &str,
1104        k: usize,
1105    ) -> Result<Vec<(String, f64)>> {
1106        let r = self
1107            .relation_embeddings
1108            .get(predicate)
1109            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
1110        let t = self
1111            .entity_embeddings
1112            .get(object)
1113            .ok_or_else(|| anyhow!("Object not found: {}", object))?;
1114        let target = t - r; // h = t - r
1115
1116        let mut scores = Vec::new();
1117        for (entity, embedding) in &self.entity_embeddings {
1118            let distance = (embedding - &target).mapv(|x| x * x).sum().sqrt();
1119            scores.push((entity.clone(), -(distance as f64)));
1120        }
1121
1122        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
1123        scores.truncate(k);
1124        Ok(scores)
1125    }
1126
1127    fn predict_relations(
1128        &self,
1129        subject: &str,
1130        object: &str,
1131        k: usize,
1132    ) -> Result<Vec<(String, f64)>> {
1133        let h = self
1134            .entity_embeddings
1135            .get(subject)
1136            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
1137        let t = self
1138            .entity_embeddings
1139            .get(object)
1140            .ok_or_else(|| anyhow!("Object not found: {}", object))?;
1141        let target = t - h; // r = t - h
1142
1143        let mut scores = Vec::new();
1144        for (relation, embedding) in &self.relation_embeddings {
1145            let distance = (embedding - &target).mapv(|x| x * x).sum().sqrt();
1146            scores.push((relation.clone(), -(distance as f64)));
1147        }
1148
1149        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
1150        scores.truncate(k);
1151        Ok(scores)
1152    }
1153
1154    fn get_entities(&self) -> Vec<String> {
1155        self.entity_embeddings.keys().cloned().collect()
1156    }
1157
1158    fn get_relations(&self) -> Vec<String> {
1159        self.relation_embeddings.keys().cloned().collect()
1160    }
1161
1162    fn get_stats(&self) -> ModelStats {
1163        self.model_stats.clone()
1164    }
1165
1166    fn save(&self, path: &str) -> Result<()> {
1167        let serialized = serde_json::to_string_pretty(self)?;
1168        std::fs::write(path, serialized)?;
1169        Ok(())
1170    }
1171
1172    fn load(&mut self, path: &str) -> Result<()> {
1173        let content = std::fs::read_to_string(path)?;
1174        let loaded: OntologyAwareEmbedding = serde_json::from_str(&content)?;
1175        *self = loaded;
1176        Ok(())
1177    }
1178
1179    fn clear(&mut self) {
1180        self.entity_embeddings.clear();
1181        self.relation_embeddings.clear();
1182        self.entity_to_idx.clear();
1183        self.relation_to_idx.clear();
1184        self.triples.clear();
1185        self.ontology_constraints = OntologyConstraints::default();
1186        self.training_stats = TrainingStats::default();
1187        self.is_trained = false;
1188        self.model_stats.is_trained = false;
1189        self.model_stats.num_entities = 0;
1190        self.model_stats.num_relations = 0;
1191        self.model_stats.num_triples = 0;
1192    }
1193
1194    fn is_trained(&self) -> bool {
1195        self.is_trained
1196    }
1197
1198    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
1199        Err(anyhow!(
1200            "Knowledge graph embedding model does not support text encoding"
1201        ))
1202    }
1203}
1204
1205#[cfg(test)]
1206mod tests {
1207    use super::*;
1208    use crate::NamedNode;
1209
1210    #[test]
1211    fn test_ontology_relation_from_iri() {
1212        assert_eq!(
1213            OntologyRelation::from_iri("http://www.w3.org/2000/01/rdf-schema#subClassOf"),
1214            Some(OntologyRelation::SubClassOf)
1215        );
1216        assert_eq!(
1217            OntologyRelation::from_iri("http://www.w3.org/2002/07/owl#equivalentClass"),
1218            Some(OntologyRelation::EquivalentClass)
1219        );
1220        assert_eq!(
1221            OntologyRelation::from_iri("http://example.org/custom"),
1222            None
1223        );
1224    }
1225
1226    #[test]
1227    fn test_ontology_constraint_extraction() {
1228        let config = OntologyAwareEmbedding::hierarchy_optimized_config(50);
1229        let mut model = OntologyAwareEmbedding::new(config);
1230
1231        // Create test triples with ontology relationships
1232        let triples = vec![
1233            Triple::new(
1234                NamedNode::new("http://example.org/Dog").unwrap(),
1235                NamedNode::new("http://www.w3.org/2000/01/rdf-schema#subClassOf").unwrap(),
1236                NamedNode::new("http://example.org/Animal").unwrap(),
1237            ),
1238            Triple::new(
1239                NamedNode::new("http://example.org/Cat").unwrap(),
1240                NamedNode::new("http://www.w3.org/2000/01/rdf-schema#subClassOf").unwrap(),
1241                NamedNode::new("http://example.org/Animal").unwrap(),
1242            ),
1243            Triple::new(
1244                NamedNode::new("http://example.org/Canine").unwrap(),
1245                NamedNode::new("http://www.w3.org/2002/07/owl#equivalentClass").unwrap(),
1246                NamedNode::new("http://example.org/Dog").unwrap(),
1247            ),
1248        ];
1249
1250        model.triples = triples;
1251        model.extract_ontology_constraints();
1252
1253        // Check class hierarchy extraction
1254        assert!(model
1255            .ontology_constraints
1256            .class_hierarchy
1257            .contains_key("http://example.org/Dog"));
1258        assert!(model
1259            .ontology_constraints
1260            .class_hierarchy
1261            .contains_key("http://example.org/Cat"));
1262
1263        // Check equivalent classes
1264        assert!(model
1265            .ontology_constraints
1266            .equivalent_classes
1267            .contains_key("http://example.org/Canine"));
1268        assert!(model
1269            .ontology_constraints
1270            .equivalent_classes
1271            .contains_key("http://example.org/Dog"));
1272    }
1273
1274    #[test]
1275    fn test_transitive_closure_computation() {
1276        let mut constraints = OntologyConstraints::default();
1277
1278        // A -> B -> C hierarchy
1279        constraints.class_hierarchy.insert("A".to_string(), {
1280            let mut set = HashSet::new();
1281            set.insert("B".to_string());
1282            set
1283        });
1284        constraints.class_hierarchy.insert("B".to_string(), {
1285            let mut set = HashSet::new();
1286            set.insert("C".to_string());
1287            set
1288        });
1289
1290        constraints.compute_transitive_closure(5);
1291
1292        // A should be a subclass of both B and C
1293        assert!(constraints.is_subclass_of("A", "B"));
1294        assert!(constraints.is_subclass_of("A", "C"));
1295        assert!(constraints.is_subclass_of("B", "C"));
1296        assert!(!constraints.is_subclass_of("C", "A"));
1297    }
1298
1299    #[test]
1300    fn test_ontology_aware_config_factory_methods() {
1301        let hierarchy_config = OntologyAwareEmbedding::hierarchy_optimized_config(100);
1302        assert_eq!(hierarchy_config.base_config.dimensions, 100);
1303        assert_eq!(hierarchy_config.hierarchy_weight, 2.0);
1304        assert!(hierarchy_config.use_transitive_closure);
1305
1306        let property_config = OntologyAwareEmbedding::property_optimized_config(100);
1307        assert_eq!(property_config.disjoint_weight, 2.0);
1308        assert_eq!(property_config.max_transitive_depth, 8);
1309    }
1310
1311    #[tokio::test]
1312    async fn test_ontology_aware_embedding_training() {
1313        let config = OntologyAwareEmbedding::hierarchy_optimized_config(32);
1314        let mut model = OntologyAwareEmbedding::new(config);
1315
1316        // Add triples using the trait method
1317        model
1318            .add_triple(Triple::new(
1319                NamedNode::new("http://example.org/Dog").unwrap(),
1320                NamedNode::new("http://www.w3.org/2000/01/rdf-schema#subClassOf").unwrap(),
1321                NamedNode::new("http://example.org/Animal").unwrap(),
1322            ))
1323            .unwrap();
1324
1325        model
1326            .add_triple(Triple::new(
1327                NamedNode::new("http://example.org/Fido").unwrap(),
1328                NamedNode::new("http://www.w3.org/1999/02/22-rdf-syntax-ns#type").unwrap(),
1329                NamedNode::new("http://example.org/Dog").unwrap(),
1330            ))
1331            .unwrap();
1332
1333        let result = model.train(Some(10)).await;
1334        assert!(result.is_ok());
1335
1336        // Check that embeddings were created
1337        assert!(model
1338            .entity_embeddings
1339            .contains_key("http://example.org/Dog"));
1340        assert!(model
1341            .entity_embeddings
1342            .contains_key("http://example.org/Animal"));
1343        assert!(model
1344            .entity_embeddings
1345            .contains_key("http://example.org/Fido"));
1346
1347        // Test embedding retrieval using the trait method
1348        let dog_embedding = model.get_entity_embedding("http://example.org/Dog");
1349        assert!(dog_embedding.is_ok());
1350        assert_eq!(dog_embedding.unwrap().dimensions, 32);
1351
1352        // Test that model is trained
1353        assert!(model.is_trained());
1354    }
1355}