1use crate::{
12 embeddings::{EmbeddableContent, EmbeddingManager, EmbeddingStrategy},
13 kg_embeddings::KGEmbeddingModel,
14 Vector,
15};
16use anyhow::{anyhow, Result};
17use parking_lot::RwLock;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::sync::Arc;
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct RdfContentConfig {
25 pub enable_uri_embeddings: bool,
27 pub enable_property_aggregation: bool,
29 pub enable_multi_language: bool,
31 pub enable_temporal_encoding: bool,
33 pub max_path_length: usize,
35 pub default_language: String,
37 pub context_window_size: usize,
39 pub component_weights: ComponentWeights,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ComponentWeights {
45 pub uri_weight: f32,
46 pub label_weight: f32,
47 pub description_weight: f32,
48 pub property_weight: f32,
49 pub context_weight: f32,
50 pub temporal_weight: f32,
51}
52
53impl Default for ComponentWeights {
54 fn default() -> Self {
55 Self {
56 uri_weight: 0.1,
57 label_weight: 0.3,
58 description_weight: 0.3,
59 property_weight: 0.2,
60 context_weight: 0.05,
61 temporal_weight: 0.05,
62 }
63 }
64}
65
66impl Default for RdfContentConfig {
67 fn default() -> Self {
68 Self {
69 enable_uri_embeddings: true,
70 enable_property_aggregation: true,
71 enable_multi_language: true,
72 enable_temporal_encoding: false,
73 max_path_length: 3,
74 default_language: "en".to_string(),
75 context_window_size: 5,
76 component_weights: ComponentWeights::default(),
77 }
78 }
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct RdfEntity {
84 pub uri: String,
85 pub labels: HashMap<String, String>, pub descriptions: HashMap<String, String>, pub properties: HashMap<String, Vec<RdfValue>>,
88 pub types: Vec<String>,
89 pub context: Option<RdfContext>,
90 pub temporal_info: Option<TemporalInfo>,
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
95pub enum RdfValue {
96 IRI(String),
97 Literal(String, Option<String>), LangString(String, String), Boolean(bool),
100 Integer(i64),
101 Float(f64),
102 Date(String),
103 DateTime(String),
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct RdfContext {
109 pub graph_uri: Option<String>,
110 pub neighbors: Vec<String>, pub subgraph_signature: Option<String>, pub semantic_distance: HashMap<String, f32>, }
114
115#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct TemporalInfo {
118 pub valid_from: Option<String>,
119 pub valid_to: Option<String>,
120 pub created_at: Option<String>,
121 pub modified_at: Option<String>,
122 pub version: Option<String>,
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct PropertyPath {
128 pub path: Vec<String>, pub direction: Vec<PathDirection>, pub constraints: Vec<PathConstraint>,
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
134pub enum PathDirection {
135 Forward,
136 Backward,
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub enum PathConstraint {
141 TypeFilter(String),
142 PropertyFilter(String, RdfValue),
143 LanguageFilter(String),
144}
145
146pub struct RdfContentProcessor {
148 config: RdfContentConfig,
149 embedding_manager: Arc<RwLock<EmbeddingManager>>,
150 kg_embeddings: Option<Box<dyn KGEmbeddingModel>>,
151 entity_cache: HashMap<String, Vector>,
152 relationship_cache: HashMap<String, Vector>,
153 property_aggregator: PropertyAggregator,
154 multi_language_processor: MultiLanguageProcessor,
155}
156
157impl RdfContentProcessor {
158 pub fn new(config: RdfContentConfig, embedding_strategy: EmbeddingStrategy) -> Result<Self> {
160 let embedding_manager = Arc::new(RwLock::new(EmbeddingManager::new(
161 embedding_strategy,
162 1000,
163 )?));
164
165 Ok(Self {
166 config,
167 embedding_manager,
168 kg_embeddings: None,
169 entity_cache: HashMap::new(),
170 relationship_cache: HashMap::new(),
171 property_aggregator: PropertyAggregator::new(),
172 multi_language_processor: MultiLanguageProcessor::new(),
173 })
174 }
175
176 pub fn generate_entity_embedding(&mut self, entity: &RdfEntity) -> Result<Vector> {
178 if let Some(cached) = self.entity_cache.get(&entity.uri) {
180 return Ok(cached.clone());
181 }
182
183 let mut embedding_components = Vec::new();
184 let weights = &self.config.component_weights;
185
186 if self.config.enable_uri_embeddings {
188 let uri_embedding = self.generate_uri_embedding(&entity.uri)?;
189 embedding_components.push((uri_embedding, weights.uri_weight));
190 }
191
192 if !entity.labels.is_empty() {
194 let label_embedding = self.generate_label_embedding(&entity.labels)?;
195 embedding_components.push((label_embedding, weights.label_weight));
196 }
197
198 if !entity.descriptions.is_empty() {
200 let desc_embedding = self.generate_description_embedding(&entity.descriptions)?;
201 embedding_components.push((desc_embedding, weights.description_weight));
202 }
203
204 if self.config.enable_property_aggregation && !entity.properties.is_empty() {
206 let prop_embedding = self
207 .property_aggregator
208 .aggregate_properties(&entity.properties)?;
209 embedding_components.push((prop_embedding, weights.property_weight));
210 }
211
212 if let Some(context) = &entity.context {
214 let context_embedding = self.generate_context_embedding(context)?;
215 embedding_components.push((context_embedding, weights.context_weight));
216 }
217
218 if self.config.enable_temporal_encoding {
220 if let Some(temporal) = &entity.temporal_info {
221 let temporal_embedding = self.generate_temporal_embedding(temporal)?;
222 embedding_components.push((temporal_embedding, weights.temporal_weight));
223 }
224 }
225
226 let final_embedding = self.combine_embeddings(embedding_components)?;
228
229 self.entity_cache
231 .insert(entity.uri.clone(), final_embedding.clone());
232
233 Ok(final_embedding)
234 }
235
236 pub fn generate_property_path_embedding(&mut self, path: &PropertyPath) -> Result<Vector> {
238 let path_key = format!("{path:?}");
239
240 if let Some(cached) = self.relationship_cache.get(&path_key) {
242 return Ok(cached.clone());
243 }
244
245 let mut path_embeddings = Vec::new();
246
247 for (i, property) in path.path.iter().enumerate() {
249 let mut prop_text = property.clone();
250
251 match path.direction.get(i) {
253 Some(PathDirection::Forward) => prop_text.push_str(" ->"),
254 Some(PathDirection::Backward) => prop_text.push_str(" <-"),
255 None => {}
256 }
257
258 for constraint in &path.constraints {
260 match constraint {
261 PathConstraint::TypeFilter(type_uri) => {
262 prop_text.push_str(&format!(" [type:{type_uri}]"));
263 }
264 PathConstraint::PropertyFilter(prop, value) => {
265 prop_text.push_str(&format!(" [{prop}:{value:?}]"));
266 }
267 PathConstraint::LanguageFilter(lang) => {
268 prop_text.push_str(&format!(" [@{lang}]"));
269 }
270 }
271 }
272
273 let content = EmbeddableContent::Text(prop_text);
274 let prop_embedding = self.embedding_manager.write().get_embedding(&content)?;
275 path_embeddings.push(prop_embedding);
276 }
277
278 let path_embedding = self.combine_path_embeddings(path_embeddings)?;
280
281 self.relationship_cache
283 .insert(path_key, path_embedding.clone());
284
285 Ok(path_embedding)
286 }
287
288 pub fn generate_subgraph_embedding(&mut self, entities: &[RdfEntity]) -> Result<Vector> {
290 if entities.is_empty() {
291 return Err(anyhow!("Cannot generate embedding for empty subgraph"));
292 }
293
294 let mut entity_embeddings = Vec::new();
295
296 for entity in entities {
298 let embedding = self.generate_entity_embedding(entity)?;
299 entity_embeddings.push(embedding);
300 }
301
302 self.aggregate_subgraph_embeddings(entity_embeddings)
304 }
305
306 fn generate_uri_embedding(&self, uri: &str) -> Result<Vector> {
308 let components = self.decompose_uri(uri);
310 let text_content = components.join(" ");
311
312 let content = EmbeddableContent::Text(text_content);
313 self.embedding_manager.write().get_embedding(&content)
314 }
315
316 fn generate_label_embedding(&self, labels: &HashMap<String, String>) -> Result<Vector> {
318 let preferred_lang = &self.config.default_language;
319
320 let label_text = if let Some(preferred_label) = labels.get(preferred_lang) {
322 preferred_label.clone()
323 } else if let Some((_, first_label)) = labels.iter().next() {
324 first_label.clone()
325 } else {
326 return Err(anyhow!("No labels available"));
327 };
328
329 let content = EmbeddableContent::Text(label_text);
330 self.embedding_manager.write().get_embedding(&content)
331 }
332
333 fn generate_description_embedding(
335 &self,
336 descriptions: &HashMap<String, String>,
337 ) -> Result<Vector> {
338 let preferred_lang = &self.config.default_language;
339
340 let desc_text = if let Some(preferred_desc) = descriptions.get(preferred_lang) {
341 preferred_desc.clone()
342 } else if let Some((_, first_desc)) = descriptions.iter().next() {
343 first_desc.clone()
344 } else {
345 return Err(anyhow!("No descriptions available"));
346 };
347
348 let content = EmbeddableContent::Text(desc_text);
349 self.embedding_manager.write().get_embedding(&content)
350 }
351
352 fn generate_context_embedding(&self, context: &RdfContext) -> Result<Vector> {
354 let mut context_text = String::new();
355
356 if let Some(graph_uri) = &context.graph_uri {
357 context_text.push_str(&format!("graph:{graph_uri} "));
358 }
359
360 if !context.neighbors.is_empty() {
362 context_text.push_str("neighbors:");
363 for neighbor in &context.neighbors {
364 context_text.push_str(&format!(" {neighbor}"));
365 }
366 }
367
368 if context_text.is_empty() {
369 return Ok(Vector::new(vec![0.0; 384])); }
372
373 let content = EmbeddableContent::Text(context_text);
374 self.embedding_manager.write().get_embedding(&content)
375 }
376
377 fn generate_temporal_embedding(&self, temporal: &TemporalInfo) -> Result<Vector> {
379 let mut temporal_text = String::new();
380
381 if let Some(valid_from) = &temporal.valid_from {
382 temporal_text.push_str(&format!("from:{valid_from} "));
383 }
384
385 if let Some(valid_to) = &temporal.valid_to {
386 temporal_text.push_str(&format!("to:{valid_to} "));
387 }
388
389 if let Some(created) = &temporal.created_at {
390 temporal_text.push_str(&format!("created:{created} "));
391 }
392
393 if temporal_text.is_empty() {
394 return Ok(Vector::new(vec![0.0; 384])); }
397
398 let content = EmbeddableContent::Text(temporal_text);
399 self.embedding_manager.write().get_embedding(&content)
400 }
401
402 fn combine_embeddings(&self, embeddings: Vec<(Vector, f32)>) -> Result<Vector> {
404 if embeddings.is_empty() {
405 return Err(anyhow!("No embeddings to combine"));
406 }
407
408 let dimensions = embeddings[0].0.dimensions;
409 let mut combined = vec![0.0; dimensions];
410 let mut total_weight = 0.0;
411
412 for (embedding, weight) in embeddings {
413 if embedding.dimensions != dimensions {
414 return Err(anyhow!("Dimension mismatch in embedding combination"));
415 }
416
417 let values = embedding.as_f32();
418 for (i, value) in values.iter().enumerate() {
419 combined[i] += value * weight;
420 }
421 total_weight += weight;
422 }
423
424 if total_weight > 0.0 {
426 for value in &mut combined {
427 *value /= total_weight;
428 }
429 }
430
431 Ok(Vector::new(combined))
432 }
433
434 fn combine_path_embeddings(&self, embeddings: Vec<Vector>) -> Result<Vector> {
436 if embeddings.is_empty() {
437 return Err(anyhow!("No path embeddings to combine"));
438 }
439
440 let dimensions = embeddings[0].dimensions;
441 let mut combined = vec![0.0; dimensions];
442
443 for (i, embedding) in embeddings.iter().enumerate() {
445 let position_weight = 1.0 / (i as f32 + 1.0); let values = embedding.as_f32();
447
448 for (j, value) in values.iter().enumerate() {
449 combined[j] += value * position_weight;
450 }
451 }
452
453 let total_positions = embeddings.len() as f32;
455 for value in &mut combined {
456 *value /= total_positions;
457 }
458
459 Ok(Vector::new(combined))
460 }
461
462 fn aggregate_subgraph_embeddings(&self, embeddings: Vec<Vector>) -> Result<Vector> {
464 if embeddings.is_empty() {
465 return Err(anyhow!("No subgraph embeddings to aggregate"));
466 }
467
468 let dimensions = embeddings[0].dimensions;
470 let mut centroid = vec![0.0; dimensions];
471
472 for embedding in &embeddings {
473 let values = embedding.as_f32();
474 for (i, value) in values.iter().enumerate() {
475 centroid[i] += value;
476 }
477 }
478
479 let count = embeddings.len() as f32;
480 for value in &mut centroid {
481 *value /= count;
482 }
483
484 Ok(Vector::new(centroid))
485 }
486
487 fn decompose_uri(&self, uri: &str) -> Vec<String> {
489 let mut components = Vec::new();
490
491 if let Some(domain_start) = uri.find("://") {
493 if let Some(domain_end) = uri[domain_start + 3..].find('/') {
494 let domain = &uri[domain_start + 3..domain_start + 3 + domain_end];
495 components.push(domain.to_string());
496 }
497 }
498
499 if let Some(path_start) = uri.rfind('/') {
501 let fragment = &uri[path_start + 1..];
502 if !fragment.is_empty() {
503 components.extend(self.split_identifier(fragment));
505 }
506 }
507
508 if let Some(fragment_start) = uri.find('#') {
510 let fragment = &uri[fragment_start + 1..];
511 if !fragment.is_empty() {
512 components.extend(self.split_identifier(fragment));
513 }
514 }
515
516 components
517 }
518
519 fn split_identifier(&self, identifier: &str) -> Vec<String> {
521 let mut words = Vec::new();
522 let mut current_word = String::new();
523
524 for ch in identifier.chars() {
525 if ch.is_uppercase() && !current_word.is_empty() {
526 words.push(current_word.to_lowercase());
527 current_word = ch.to_string();
528 } else if ch == '_' || ch == '-' {
529 if !current_word.is_empty() {
530 words.push(current_word.to_lowercase());
531 current_word.clear();
532 }
533 } else {
534 current_word.push(ch);
535 }
536 }
537
538 if !current_word.is_empty() {
539 words.push(current_word.to_lowercase());
540 }
541
542 words
543 }
544
545 pub fn clear_cache(&mut self) {
547 self.entity_cache.clear();
548 self.relationship_cache.clear();
549 }
550
551 pub fn cache_stats(&self) -> (usize, usize) {
553 (self.entity_cache.len(), self.relationship_cache.len())
554 }
555}
556
557pub struct PropertyAggregator {
559 aggregation_strategy: AggregationStrategy,
560}
561
562#[derive(Debug, Clone)]
563pub enum AggregationStrategy {
564 Mean,
565 WeightedMean,
566 Attention,
567 Concatenation,
568}
569
570impl PropertyAggregator {
571 pub fn new() -> Self {
572 Self {
573 aggregation_strategy: AggregationStrategy::WeightedMean,
574 }
575 }
576
577 pub fn aggregate_properties(
578 &self,
579 properties: &HashMap<String, Vec<RdfValue>>,
580 ) -> Result<Vector> {
581 let mut property_embeddings = Vec::new();
582
583 for (property_uri, values) in properties {
584 let mut property_text = property_uri.clone();
585
586 for value in values {
588 match value {
589 RdfValue::IRI(iri) => property_text.push_str(&format!(" {iri}")),
590 RdfValue::Literal(lit, _) => property_text.push_str(&format!(" {lit}")),
591 RdfValue::LangString(lit, _) => property_text.push_str(&format!(" {lit}")),
592 RdfValue::Boolean(b) => property_text.push_str(&format!(" {b}")),
593 RdfValue::Integer(i) => property_text.push_str(&format!(" {i}")),
594 RdfValue::Float(f) => property_text.push_str(&format!(" {f}")),
595 RdfValue::Date(d) | RdfValue::DateTime(d) => {
596 property_text.push_str(&format!(" {d}"))
597 }
598 }
599 }
600
601 let embedding = self.create_simple_embedding(&property_text);
604 property_embeddings.push(embedding);
605 }
606
607 if property_embeddings.is_empty() {
608 return Ok(Vector::new(vec![0.0; 384])); }
610
611 match self.aggregation_strategy {
613 AggregationStrategy::Mean => self.mean_aggregation(property_embeddings),
614 AggregationStrategy::WeightedMean => {
615 self.weighted_mean_aggregation(property_embeddings)
616 }
617 _ => self.mean_aggregation(property_embeddings), }
619 }
620
621 fn mean_aggregation(&self, embeddings: Vec<Vector>) -> Result<Vector> {
622 if embeddings.is_empty() {
623 return Err(anyhow!("No embeddings to aggregate"));
624 }
625
626 let dimensions = embeddings[0].dimensions;
627 let mut mean = vec![0.0; dimensions];
628
629 for embedding in &embeddings {
630 let values = embedding.as_f32();
631 for (i, value) in values.iter().enumerate() {
632 mean[i] += value;
633 }
634 }
635
636 let count = embeddings.len() as f32;
637 for value in &mut mean {
638 *value /= count;
639 }
640
641 Ok(Vector::new(mean))
642 }
643
644 fn weighted_mean_aggregation(&self, embeddings: Vec<Vector>) -> Result<Vector> {
645 self.mean_aggregation(embeddings)
648 }
649
650 fn create_simple_embedding(&self, text: &str) -> Vector {
651 use std::collections::hash_map::DefaultHasher;
653 use std::hash::{Hash, Hasher};
654
655 let mut hasher = DefaultHasher::new();
656 text.hash(&mut hasher);
657 let hash = hasher.finish();
658
659 let mut values = Vec::with_capacity(384);
660 let mut seed = hash;
661
662 for _ in 0..384 {
663 seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
664 let normalized = (seed as f32) / (u64::MAX as f32);
665 values.push((normalized - 0.5) * 2.0);
666 }
667
668 Vector::new(values)
669 }
670}
671
672impl Default for PropertyAggregator {
673 fn default() -> Self {
674 Self::new()
675 }
676}
677
/// Chooses among language-tagged variants of a text using a fixed per-language priority table.
pub struct MultiLanguageProcessor {
    /// Priority per language tag; higher values are preferred.
    language_weights: HashMap<String, f32>,
    /// Language tried after every language in the priority table (`"en"` by default).
    fallback_language: String,
}

impl MultiLanguageProcessor {
    /// Creates a processor with the built-in priority table and `"en"` as the fallback.
    ///
    /// The table is `en` 1.0; `es`, `fr`, `de` 0.8; `zh`, `ja` 0.7.
    pub fn new() -> Self {
        let language_weights = [
            ("en", 1.0),
            ("es", 0.8),
            ("fr", 0.8),
            ("de", 0.8),
            ("zh", 0.7),
            ("ja", 0.7),
        ]
        .iter()
        .map(|&(lang, weight)| (lang.to_string(), weight))
        .collect();

        Self {
            language_weights,
            fallback_language: "en".to_string(),
        }
    }

    /// Returns the text of the most preferred language present in `texts`.
    ///
    /// Candidates are tried in this order:
    /// 1. languages from the priority table, from highest to lowest weight, with equal
    ///    weights tried in lexicographic tag order;
    /// 2. the fallback language;
    /// 3. the entry with the lexicographically smallest tag.
    ///
    /// The choice never depends on `HashMap` iteration order. Returns `None` only when `texts`
    /// is empty.
    pub fn get_preferred_text(&self, texts: &HashMap<String, String>) -> Option<String> {
        // Both maps iterate in a randomised order, so relying on iteration order would make
        // the result vary between runs whenever weights tie (e.g. es/fr/de).
        let mut ranked: Vec<(&String, f32)> = self
            .language_weights
            .iter()
            .map(|(lang, &weight)| (lang, weight))
            .collect();
        ranked.sort_by(|(lang_a, weight_a), (lang_b, weight_b)| {
            weight_b
                .partial_cmp(weight_a)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| lang_a.cmp(lang_b))
        });

        ranked
            .iter()
            .map(|&(lang, _)| lang)
            .chain(std::iter::once(&self.fallback_language))
            .find_map(|lang| texts.get(lang))
            .or_else(|| {
                texts
                    .iter()
                    .min_by_key(|&(lang, _)| lang)
                    .map(|(_, text)| text)
            })
            .cloned()
    }

    /// Returns the priority weight of `language`, or `0.5` for languages not in the table.
    pub fn get_language_weight(&self, language: &str) -> f32 {
        self.language_weights.get(language).copied().unwrap_or(0.5)
    }
}

impl Default for MultiLanguageProcessor {
    fn default() -> Self {
        Self::new()
    }
}
726
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::EmbeddingStrategy;

    /// Processor with the default configuration and the TF-IDF strategy.
    fn tfidf_processor() -> RdfContentProcessor {
        RdfContentProcessor::new(RdfContentConfig::default(), EmbeddingStrategy::TfIdf)
            .expect("constructing a TF-IDF processor should succeed")
    }

    /// Builds a language-tag → text map from string pairs.
    fn texts(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(lang, text)| (lang.to_string(), text.to_string()))
            .collect()
    }

    #[test]
    fn test_rdf_entity_creation() {
        let entity = RdfEntity {
            uri: "http://example.org/Person".to_string(),
            labels: texts(&[("en", "Person"), ("es", "Persona")]),
            descriptions: HashMap::new(),
            properties: HashMap::new(),
            types: vec!["http://www.w3.org/2000/01/rdf-schema#Class".to_string()],
            context: None,
            temporal_info: None,
        };

        assert_eq!(entity.uri, "http://example.org/Person");
        assert_eq!(entity.labels.len(), 2);
    }

    #[test]
    fn test_property_path() {
        let hops = ["http://example.org/knows", "http://example.org/worksAt"];
        let path = PropertyPath {
            path: hops.iter().map(|hop| hop.to_string()).collect(),
            direction: vec![PathDirection::Forward; 2],
            constraints: vec![PathConstraint::TypeFilter(
                "http://example.org/Person".to_string(),
            )],
        };

        assert_eq!(path.path.len(), 2);
        assert_eq!(path.direction.len(), 2);
        assert_eq!(path.constraints.len(), 1);
    }

    #[test]
    fn test_uri_decomposition() {
        let components =
            tfidf_processor().decompose_uri("http://example.org/ontology/PersonClass");

        for expected in ["example.org", "person", "class"] {
            assert!(
                components.iter().any(|component| component == expected),
                "missing component {expected:?} in {components:?}"
            );
        }
    }

    #[test]
    fn test_identifier_splitting() {
        let processor = tfidf_processor();

        for identifier in ["PersonClass", "person_class", "person-class"] {
            assert_eq!(
                processor.split_identifier(identifier),
                vec!["person", "class"],
                "splitting {identifier:?}"
            );
        }
    }

    #[test]
    fn test_multi_language_processor() {
        let processor = MultiLanguageProcessor::new();
        let preferred = processor.get_preferred_text(&texts(&[("en", "Hello"), ("es", "Hola")]));

        assert_eq!(preferred.as_deref(), Some("Hello"));
    }
}