1#![allow(dead_code)]
122
123pub mod acceleration;
124pub mod adaptive_learning;
125pub mod advanced_profiler;
126#[cfg(feature = "api-server")]
127pub mod api;
128pub mod application_tasks;
129pub mod batch_processing;
130pub mod biological_computing;
131pub mod biomedical_embeddings;
132pub mod caching;
133pub mod causal_representation_learning;
134pub mod cloud_integration;
135pub mod compression;
136pub mod consciousness_aware_embeddings;
137pub mod continual_learning;
139pub mod cross_domain_transfer;
140pub mod cross_module_performance;
141pub mod delta;
142pub mod diffusion_embeddings;
143pub mod enterprise_knowledge;
144pub mod evaluation;
145pub mod federated_learning;
146pub mod gpu_acceleration;
147pub mod graphql_api;
148pub mod inference;
149pub mod integration;
150pub mod mamba_attention;
151pub mod model_registry;
152pub mod models;
153pub mod monitoring;
154pub mod multimodal;
155pub mod neural_symbolic_integration;
156pub mod neuro_evolution;
157pub mod novel_architectures;
158pub mod persistence;
159pub mod quantum_circuits;
160pub mod real_time_fine_tuning;
161pub mod real_time_optimization;
162pub mod research_networks;
163pub mod training;
165pub mod utils;
166pub mod vision_language_graph;
167
168pub use oxirs_vec::Vector as VecVector;
170
171pub use adaptive_learning::{
173 AdaptationMetrics, AdaptationStrategy, AdaptiveLearningConfig, AdaptiveLearningSystem,
174 QualityFeedback,
175};
176
177use anyhow::Result;
178use chrono::{DateTime, Utc};
179use serde::{Deserialize, Serialize};
180use std::collections::HashMap;
181use std::ops::{Add, Sub};
182use uuid::Uuid;
183
184#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
187pub struct Vector {
188 pub values: Vec<f32>,
189 pub dimensions: usize,
190 #[serde(skip)]
191 inner: Option<VecVector>,
192}
193
194impl Vector {
195 pub fn new(values: Vec<f32>) -> Self {
196 let dimensions = values.len();
197 Self {
198 values,
199 dimensions,
200 inner: None,
201 }
202 }
203
204 fn get_inner(&self) -> VecVector {
206 if let Some(ref inner) = self.inner {
208 inner.clone()
209 } else {
210 VecVector::new(self.values.clone())
211 }
212 }
213
214 fn sync_internal(&mut self) {
216 self.dimensions = self.values.len();
217 self.inner = None; }
219
220 pub fn from_array1(array: &scirs2_core::ndarray_ext::Array1<f32>) -> Self {
222 Self::new(array.to_vec())
223 }
224
225 pub fn to_array1(&self) -> scirs2_core::ndarray_ext::Array1<f32> {
227 scirs2_core::ndarray_ext::Array1::from_vec(self.values.clone())
228 }
229
230 pub fn mapv<F>(&self, f: F) -> Self
232 where
233 F: Fn(f32) -> f32,
234 {
235 Self::new(self.values.iter().copied().map(f).collect())
236 }
237
238 pub fn sum(&self) -> f32 {
240 self.values.iter().sum()
241 }
242
243 pub fn sqrt(&self) -> f32 {
245 self.sum().sqrt()
246 }
247
248 pub fn inner(&self) -> VecVector {
250 self.get_inner()
251 }
252
253 pub fn into_inner(self) -> VecVector {
255 self.inner.unwrap_or_else(|| VecVector::new(self.values))
256 }
257
258 pub fn from_vec_vector(vec_vector: VecVector) -> Self {
260 let values = vec_vector.as_f32().to_vec();
261 let dimensions = values.len();
262 Self {
263 values,
264 dimensions,
265 inner: Some(vec_vector),
266 }
267 }
268
269 pub fn with_capacity(capacity: usize) -> Self {
271 Self {
272 values: Vec::with_capacity(capacity),
273 dimensions: 0,
274 inner: None,
275 }
276 }
277
278 pub fn extend_optimized(&mut self, other_values: &[f32]) {
280 self.values.reserve(other_values.len());
282 self.values.extend_from_slice(other_values);
283 self.sync_internal();
284 }
285
286 pub fn shrink_to_fit(&mut self) {
288 self.values.shrink_to_fit();
289 self.sync_internal();
290 }
291
292 pub fn memory_usage(&self) -> usize {
294 self.values.capacity() * std::mem::size_of::<f32>() + std::mem::size_of::<Self>()
295 }
296}
297
298impl Add for &Vector {
300 type Output = Vector;
301
302 fn add(self, other: &Vector) -> Vector {
303 if let (Some(self_inner), Some(other_inner)) = (&self.inner, &other.inner) {
305 if let Ok(result) = self_inner.add(other_inner) {
306 return Vector::from_vec_vector(result);
307 }
308 }
309 assert_eq!(
311 self.values.len(),
312 other.values.len(),
313 "Vector dimensions must match"
314 );
315 let result_values: Vec<f32> = self
316 .values
317 .iter()
318 .zip(other.values.iter())
319 .map(|(a, b)| a + b)
320 .collect();
321 Vector::new(result_values)
322 }
323}
324
325impl Sub for &Vector {
326 type Output = Vector;
327
328 fn sub(self, other: &Vector) -> Vector {
329 if let (Some(self_inner), Some(other_inner)) = (&self.inner, &other.inner) {
331 if let Ok(result) = self_inner.subtract(other_inner) {
332 return Vector::from_vec_vector(result);
333 }
334 }
335 assert_eq!(
337 self.values.len(),
338 other.values.len(),
339 "Vector dimensions must match"
340 );
341 let result_values: Vec<f32> = self
342 .values
343 .iter()
344 .zip(other.values.iter())
345 .map(|(a, b)| a - b)
346 .collect();
347 Vector::new(result_values)
348 }
349}
350
351#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
353pub struct Triple {
354 pub subject: NamedNode,
355 pub predicate: NamedNode,
356 pub object: NamedNode,
357}
358
359impl Triple {
360 pub fn new(subject: NamedNode, predicate: NamedNode, object: NamedNode) -> Self {
361 Self {
362 subject,
363 predicate,
364 object,
365 }
366 }
367}
368
369#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
371pub struct NamedNode {
372 pub iri: String,
373}
374
375impl NamedNode {
376 pub fn new(iri: &str) -> Result<Self> {
377 Ok(Self {
378 iri: iri.to_string(),
379 })
380 }
381}
382
383impl std::fmt::Display for NamedNode {
384 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
385 write!(f, "{}", self.iri)
386 }
387}
388
389#[derive(Debug, Clone, Serialize, Deserialize)]
391pub struct ModelConfig {
392 pub dimensions: usize,
393 pub learning_rate: f64,
394 pub l2_reg: f64,
395 pub max_epochs: usize,
396 pub batch_size: usize,
397 pub negative_samples: usize,
398 pub seed: Option<u64>,
399 pub use_gpu: bool,
400 pub model_params: HashMap<String, f64>,
401}
402
403impl Default for ModelConfig {
404 fn default() -> Self {
405 Self {
406 dimensions: 100,
407 learning_rate: 0.01,
408 l2_reg: 0.0001,
409 max_epochs: 1000,
410 batch_size: 1000,
411 negative_samples: 10,
412 seed: None,
413 use_gpu: false,
414 model_params: HashMap::new(),
415 }
416 }
417}
418
419impl ModelConfig {
420 pub fn with_dimensions(mut self, dimensions: usize) -> Self {
421 self.dimensions = dimensions;
422 self
423 }
424
425 pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
426 self.learning_rate = learning_rate;
427 self
428 }
429
430 pub fn with_max_epochs(mut self, max_epochs: usize) -> Self {
431 self.max_epochs = max_epochs;
432 self
433 }
434
435 pub fn with_seed(mut self, seed: u64) -> Self {
436 self.seed = Some(seed);
437 self
438 }
439
440 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
441 self.batch_size = batch_size;
442 self
443 }
444}
445
/// Summary of a training run, as returned by `EmbeddingModel::train`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingStats {
    /// Number of epochs the run completed.
    pub epochs_completed: usize,
    /// Loss reported at the end of the run.
    pub final_loss: f64,
    /// Training duration in seconds (per the field name — presumably wall-clock; confirm in implementations).
    pub training_time_seconds: f64,
    /// Whether the implementation considered training converged.
    pub convergence_achieved: bool,
    /// Recorded loss values (presumably one entry per epoch — confirm in implementations).
    pub loss_history: Vec<f64>,
}
455
456#[derive(Debug, Clone, Serialize, Deserialize)]
458pub struct ModelStats {
459 pub num_entities: usize,
460 pub num_relations: usize,
461 pub num_triples: usize,
462 pub dimensions: usize,
463 pub is_trained: bool,
464 pub model_type: String,
465 pub creation_time: DateTime<Utc>,
466 pub last_training_time: Option<DateTime<Utc>>,
467}
468
469impl Default for ModelStats {
470 fn default() -> Self {
471 Self {
472 num_entities: 0,
473 num_relations: 0,
474 num_triples: 0,
475 dimensions: 0,
476 is_trained: false,
477 model_type: "unknown".to_string(),
478 creation_time: Utc::now(),
479 last_training_time: None,
480 }
481 }
482}
483
/// Errors raised by embedding models.
///
/// Thanks to `#[from]` on `Other`, the `?` operator converts an
/// `anyhow::Error` into `EmbeddingError::Other` automatically.
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingError {
    /// An operation that needs a trained model was attempted before training.
    #[error("Model not trained")]
    ModelNotTrained,
    /// The named entity is not known to the model.
    #[error("Entity not found: {entity}")]
    EntityNotFound { entity: String },
    /// The named relation is not known to the model.
    #[error("Relation not found: {relation}")]
    RelationNotFound { relation: String },
    /// Any other failure, wrapped from `anyhow`.
    #[error("Other error: {0}")]
    Other(#[from] anyhow::Error),
}
496
/// Common interface for knowledge-graph embedding models.
///
/// Implementors must be `Send + Sync`. The `async fn` methods are supported
/// through the `async_trait` attribute.
#[async_trait::async_trait]
pub trait EmbeddingModel: Send + Sync {
    /// The configuration this model was built with.
    fn config(&self) -> &ModelConfig;
    /// Unique identifier of this model instance.
    fn model_id(&self) -> &Uuid;
    /// Static name of the model kind.
    fn model_type(&self) -> &'static str;
    /// Adds a triple to the model's training data.
    fn add_triple(&mut self, triple: Triple) -> Result<()>;
    /// Trains the model. `epochs` presumably overrides `ModelConfig::max_epochs`
    /// when `Some` — confirm in implementations.
    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats>;
    /// Embedding of `entity`.
    fn get_entity_embedding(&self, entity: &str) -> Result<Vector>;
    /// Embedding of `relation`.
    ///
    /// NOTE(review): the name lacks the underscore used by
    /// `get_entity_embedding`; renaming it would break implementors and callers.
    fn getrelation_embedding(&self, relation: &str) -> Result<Vector>;
    /// Model score for the triple; scale and direction are model-specific
    /// (confirm per implementation).
    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64>;
    /// Candidate objects for `(subject, predicate, ?)` with their scores,
    /// limited by `k` (ordering presumably best-first — confirm).
    fn predict_objects(
        &self,
        subject: &str,
        predicate: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>>;
    /// Candidate subjects for `(?, predicate, object)` with their scores,
    /// limited by `k` (ordering presumably best-first — confirm).
    fn predict_subjects(
        &self,
        predicate: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>>;
    /// Candidate relations for `(subject, ?, object)` with their scores,
    /// limited by `k` (ordering presumably best-first — confirm).
    fn predict_relations(
        &self,
        subject: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>>;
    /// Names of all entities known to the model.
    fn get_entities(&self) -> Vec<String>;
    /// Names of all relations known to the model.
    fn get_relations(&self) -> Vec<String>;
    /// Snapshot of model statistics.
    fn get_stats(&self) -> ModelStats;
    /// Persists the model to `path`.
    fn save(&self, path: &str) -> Result<()>;
    /// Restores the model from `path`.
    fn load(&mut self, path: &str) -> Result<()>;
    /// Resets the model (presumably dropping triples and learned state — confirm).
    fn clear(&mut self);
    /// Whether the model has been trained.
    fn is_trained(&self) -> bool;

    /// Encodes each input text into an `f32` vector (presumably one output per
    /// input, in order — confirm in implementations).
    async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
}
537
538pub use acceleration::{AdaptiveEmbeddingAccelerator, GpuEmbeddingAccelerator};
540#[cfg(feature = "api-server")]
541pub use api::{start_server, ApiConfig, ApiState};
542pub use batch_processing::{
543 BatchJob, BatchProcessingConfig, BatchProcessingManager, BatchProcessingResult,
544 BatchProcessingStats, IncrementalConfig, JobProgress, JobStatus, OutputFormat,
545 PartitioningStrategy, RetryConfig,
546};
547pub use biomedical_embeddings::{
548 BiomedicalEmbedding, BiomedicalEmbeddingConfig, BiomedicalEntityType, BiomedicalRelationType,
549 FineTuningConfig, PreprocessingRule, SpecializedTextConfig, SpecializedTextEmbedding,
550 SpecializedTextModel,
551};
552pub use caching::{CacheConfig, CacheManager, CachedEmbeddingModel};
553pub use causal_representation_learning::{
554 CausalDiscoveryAlgorithm, CausalDiscoveryConfig, CausalGraph, CausalRepresentationConfig,
555 CausalRepresentationModel, ConstraintSettings, CounterfactualConfig, CounterfactualQuery,
556 DisentanglementConfig, DisentanglementMethod, ExplanationType, IndependenceTest,
557 InterventionConfig, ScoreSettings, StructuralCausalModelConfig,
558};
559pub use cloud_integration::{
560 AWSSageMakerService, AutoScalingConfig, AzureMLService, BackupConfig, CloudIntegrationConfig,
561 CloudIntegrationManager, CloudProvider, CloudService, ClusterStatus, CostEstimate,
562 CostOptimizationResult, CostOptimizationStrategy, DeploymentConfig, DeploymentResult,
563 DeploymentStatus, EndpointInfo, FunctionInvocationResult, GPUClusterConfig, GPUClusterResult,
564 LifecyclePolicy, OptimizationAction, PerformanceTier, ReplicationType,
565 ServerlessDeploymentResult, ServerlessFunctionConfig, ServerlessStatus, StorageConfig,
566 StorageResult, StorageStatus, StorageType,
567};
568pub use compression::{
569 CompressedModel, CompressionStats, CompressionTarget, DistillationConfig,
570 ModelCompressionManager, NASConfig, OptimizationTarget, PruningConfig, PruningMethod,
571 QuantizationConfig, QuantizationMethod,
572};
573pub use consciousness_aware_embeddings::{
574 AttentionMechanism, ConsciousnessAwareEmbedding, ConsciousnessInsights, ConsciousnessLevel,
575 MetaCognition, WorkingMemory,
576};
577pub use continual_learning::{
581 ArchitectureConfig, BoundaryDetection, ConsolidationConfig, ContinualLearningConfig,
582 ContinualLearningModel, MemoryConfig, MemoryType, MemoryUpdateStrategy, RegularizationConfig,
583 ReplayConfig, ReplayMethod, TaskConfig, TaskDetection, TaskSwitching,
584};
585pub use cross_module_performance::{
586 CoordinatorConfig, CrossModulePerformanceCoordinator, GlobalPerformanceMetrics, ModuleMetrics,
587 ModulePerformanceMonitor, OptimizationCache, PerformanceSnapshot, PredictivePerformanceEngine,
588 ResourceAllocator, ResourceTracker,
589};
590pub use delta::{
591 ChangeRecord, ChangeStatistics, ChangeType, DeltaConfig, DeltaManager, DeltaResult, DeltaStats,
592 IncrementalStrategy,
593};
594pub use enterprise_knowledge::{
595 BehaviorMetrics, CareerPredictions, Category, CategoryHierarchy, CategoryPerformance,
596 ColdStartStrategy, CommunicationFrequency, CommunicationPreferences, CustomerEmbedding,
597 CustomerPreferences, CustomerRatings, CustomerSegment, Department, DepartmentPerformance,
598 EmployeeEmbedding, EnterpriseConfig, EnterpriseKnowledgeAnalyzer, EnterpriseMetrics,
599 ExperienceLevel, FeatureType, MarketAnalysis, OrganizationalStructure,
600 PerformanceMetrics as EnterprisePerformanceMetrics, ProductAvailability, ProductEmbedding,
601 ProductFeature, ProductRecommendation, Project, ProjectOutcome, ProjectParticipation,
602 ProjectPerformance, ProjectStatus, Purchase, PurchaseChannel, RecommendationConfig,
603 RecommendationEngine, RecommendationEngineType, RecommendationPerformance,
604 RecommendationReason, SalesMetrics, Skill, SkillCategory, Team, TeamPerformance,
605};
606pub use evaluation::{
607 QueryAnsweringEvaluator, QueryEvaluationConfig, QueryEvaluationResults, QueryMetric,
608 QueryResult, QueryTemplate, QueryType, ReasoningChain, ReasoningEvaluationConfig,
609 ReasoningEvaluationResults, ReasoningRule, ReasoningStep, ReasoningTaskEvaluator,
610 ReasoningType, TypeSpecificResults,
611};
612pub use federated_learning::{
613 AggregationEngine, AggregationStrategy, AuthenticationConfig, AuthenticationMethod,
614 CertificateConfig, ClippingMechanisms, ClippingMethod, CommunicationConfig,
615 CommunicationManager, CommunicationProtocol, CompressionAlgorithm, CompressionConfig,
616 CompressionEngine, ConvergenceMetrics, ConvergenceStatus, DataSelectionStrategy,
617 DataStatistics, EncryptionScheme, FederatedConfig, FederatedCoordinator,
618 FederatedEmbeddingModel, FederatedMessage, FederatedRound, FederationStats, GlobalModelState,
619 HardwareAccelerator, KeyManager, LocalModelState, LocalTrainingStats, LocalUpdate,
620 MetaLearningConfig, NoiseGenerator, NoiseMechanism, OutlierAction, OutlierDetection,
621 OutlierDetectionMethod, Participant, ParticipantCapabilities, ParticipantStatus,
622 PersonalizationConfig, PersonalizationStrategy, PrivacyAccountant, PrivacyConfig,
623 PrivacyEngine, PrivacyMetrics, PrivacyParams, RoundMetrics, RoundStatus, SecurityConfig,
624 SecurityFeature, SecurityManager, TrainingConfig, VerificationEngine, VerificationMechanism,
625 VerificationResult, WeightingScheme,
626};
627pub use gpu_acceleration::{
628 GpuAccelerationConfig, GpuAccelerationManager, GpuMemoryPool, GpuPerformanceStats,
629 MixedPrecisionProcessor, MultiStreamProcessor, TensorCache,
630};
631pub use graphql_api::{
632 create_schema, BatchEmbeddingInput, BatchEmbeddingResult, BatchStatus, DistanceMetric,
633 EmbeddingFormat, EmbeddingQueryInput, EmbeddingResult, EmbeddingSchema, GraphQLContext,
634 ModelInfo, ModelType, SimilarityResult, SimilaritySearchInput,
635};
636pub use models::{
637 AggregationType, ComplEx, DistMult, GNNConfig, GNNEmbedding, GNNType, PoolingStrategy, RotatE,
638 TransE, TransformerConfig, TransformerEmbedding, TransformerType,
639};
640pub use monitoring::{
641 Alert, AlertSeverity, AlertThresholds, AlertType, CacheMetrics, ConsoleAlertHandler,
642 DriftMetrics, ErrorEvent, ErrorMetrics, ErrorSeverity, LatencyMetrics, MonitoringConfig,
643 PerformanceMetrics as MonitoringPerformanceMetrics, PerformanceMonitor, QualityAssessment,
644 QualityMetrics, ResourceMetrics, SlackAlertHandler, ThroughputMetrics,
645};
646pub use multimodal::{
647 AlignmentNetwork, AlignmentObjective, ContrastiveConfig, CrossDomainConfig, CrossModalConfig,
648 KGEncoder, MultiModalEmbedding, MultiModalStats, TextEncoder,
649};
650pub use neural_symbolic_integration::{
651 ConstraintSatisfactionConfig, ConstraintType, KnowledgeIntegrationConfig, KnowledgeRule,
652 LogicIntegrationConfig, LogicProgrammingConfig, LogicalFormula, NeuralSymbolicConfig,
653 NeuralSymbolicModel, NeuroSymbolicArchitectureConfig, OntologicalConfig, ReasoningEngine,
654 RuleBasedConfig, SymbolicReasoningConfig,
655};
656pub use novel_architectures::{
657 ActivationType, ArchitectureParams, ArchitectureState, ArchitectureType, CurvatureComputation,
658 CurvatureMethod, CurvatureType, DynamicsConfig, EntanglementStructure, EquivarianceGroup,
659 FlowType, GeometricConfig, GeometricParams, GeometricSpace, GeometricState,
660 GraphTransformerParams, GraphTransformerState, HyperbolicDistance, HyperbolicInit,
661 HyperbolicManifold, HyperbolicParams, HyperbolicState, IntegrationScheme, IntegrationStats,
662 ManifoldLearning, ManifoldMethod, ManifoldOptimizer, NeuralODEParams, NeuralODEState,
663 NovelArchitectureConfig, NovelArchitectureModel, ODERegularization, ODESolverType,
664 ParallelTransport, QuantumGateSet, QuantumMeasurement, QuantumNoise, QuantumParams,
665 QuantumState, StabilityConstraints, StructuralBias, TimeEvolution, TransportMethod,
666};
667pub use quantum_circuits::{
668 Complex, MeasurementStrategy, QNNLayerType, QuantumApproximateOptimization, QuantumCircuit,
669 QuantumGate, QuantumNeuralNetwork, QuantumNeuralNetworkLayer, QuantumSimulator,
670 VariationalQuantumEigensolver,
671};
672pub use research_networks::{
673 AuthorEmbedding, Citation, CitationNetwork, CitationType, Collaboration, CollaborationNetwork,
674 NetworkMetrics, PaperSection, PublicationEmbedding, PublicationType, ResearchCommunity,
675 ResearchNetworkAnalyzer, ResearchNetworkConfig, TopicModel, TopicModelingConfig,
676};
677pub use vision_language_graph::{
678 AggregationFunction, CNNConfig, CrossAttentionConfig, DomainAdaptationConfig,
679 DomainAdaptationMethod, EpisodeConfig, FewShotConfig, FewShotMethod, FusionStrategy,
680 GraphArchitecture, GraphEncoder, GraphEncoderConfig, JointTrainingConfig, LanguageArchitecture,
681 LanguageEncoder, LanguageEncoderConfig, LanguageTransformerConfig, MetaLearner,
682 ModalityEncoding, MultiModalTransformer, MultiModalTransformerConfig, NormalizationType,
683 PoolingType, PositionEncodingType, ReadoutFunction, TaskCategory, TaskSpecificParams,
684 TrainingObjective, TransferLearningConfig, TransferStrategy, ViTConfig, VisionArchitecture,
685 VisionEncoder, VisionEncoderConfig, VisionLanguageGraphConfig, VisionLanguageGraphModel,
686 VisionLanguageGraphStats, ZeroShotConfig, ZeroShotMethod,
687};
688
689#[cfg(feature = "tucker")]
690pub use models::TuckER;
691
692#[cfg(feature = "quatd")]
693pub use models::QuatD;
694
695pub use crate::model_registry::{
697 ModelRegistry, ModelVersion, ResourceAllocation as ModelResourceAllocation,
698};
699
700pub mod quick_start {
714 use super::*;
715 use crate::models::TransE;
716
717 pub fn create_simple_transe_model() -> TransE {
719 let config = ModelConfig::default()
720 .with_dimensions(128)
721 .with_learning_rate(0.01)
722 .with_max_epochs(100);
723 TransE::new(config)
724 }
725
726 pub fn create_biomedical_model() -> BiomedicalEmbedding {
728 let config = BiomedicalEmbeddingConfig::default();
729 BiomedicalEmbedding::new(config)
730 }
731
732 pub fn parse_triple_from_string(triple_str: &str) -> Result<Triple> {
734 let parts: Vec<&str> = triple_str.split_whitespace().collect();
735 if parts.len() != 3 {
736 return Err(anyhow::anyhow!(
737 "Triple must have exactly 3 parts separated by spaces"
738 ));
739 }
740
741 let expand_uri = |s: &str| -> String {
743 if s.starts_with("http://") || s.starts_with("https://") {
744 s.to_string()
745 } else {
746 format!("http://example.org/{s}")
747 }
748 };
749
750 Ok(Triple::new(
751 NamedNode::new(&expand_uri(parts[0]))?,
752 NamedNode::new(&expand_uri(parts[1]))?,
753 NamedNode::new(&expand_uri(parts[2]))?,
754 ))
755 }
756
757 pub fn add_triples_from_strings<T: EmbeddingModel>(
759 model: &mut T,
760 triple_strings: &[&str],
761 ) -> Result<usize> {
762 let mut count = 0;
763 for triple_str in triple_strings {
764 let triple = parse_triple_from_string(triple_str)?;
765 model.add_triple(triple)?;
766 count += 1;
767 }
768 Ok(count)
769 }
770
771 pub fn cosine_similarity(a: &[f64], b: &[f64]) -> Result<f64> {
773 if a.len() != b.len() {
774 return Err(anyhow::anyhow!(
775 "Vector dimensions don't match: {} vs {}",
776 a.len(),
777 b.len()
778 ));
779 }
780
781 let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
782 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
783 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
784
785 if norm_a == 0.0 || norm_b == 0.0 {
786 return Ok(0.0);
787 }
788
789 Ok(dot_product / (norm_a * norm_b))
790 }
791
792 pub fn generate_sample_kg_data(
794 num_entities: usize,
795 num_relations: usize,
796 ) -> Vec<(String, String, String)> {
797 #[allow(unused_imports)]
798 use scirs2_core::random::{Random, Rng};
799
800 let mut random = Random::default();
801 let mut triples = Vec::new();
802
803 let entities: Vec<String> = (0..num_entities)
804 .map(|i| format!("http://example.org/entity_{i}"))
805 .collect();
806
807 let relations: Vec<String> = (0..num_relations)
808 .map(|i| format!("http://example.org/relation_{i}"))
809 .collect();
810
811 for _ in 0..(num_entities * 2) {
813 let subject_idx = random.random_range(0, entities.len());
814 let relation_idx = random.random_range(0, relations.len());
815 let object_idx = random.random_range(0, entities.len());
816
817 let subject = entities[subject_idx].clone();
818 let relation = relations[relation_idx].clone();
819 let object = entities[object_idx].clone();
820
821 if subject != object {
822 triples.push((subject, relation, object));
823 }
824 }
825
826 triples
827 }
828
829 pub fn quick_performance_test<F>(
831 name: &str,
832 iterations: usize,
833 operation: F,
834 ) -> std::time::Duration
835 where
836 F: Fn(),
837 {
838 let start = std::time::Instant::now();
839 for _ in 0..iterations {
840 operation();
841 }
842 let duration = start.elapsed();
843
844 println!(
845 "Performance test '{name}': {iterations} iterations in {duration:?} ({:.2} ops/sec)",
846 iterations as f64 / duration.as_secs_f64()
847 );
848
849 duration
850 }
851
852 pub async fn quick_revolutionary_performance_test<F, Fut>(
885 name: &str,
886 iterations: usize,
887 async_operation: F,
888 ) -> std::time::Duration
889 where
890 F: Fn() -> Fut,
891 Fut: std::future::Future<Output = ()>,
892 {
893 let start = std::time::Instant::now();
894 for _ in 0..iterations {
895 async_operation().await;
896 }
897 let duration = start.elapsed();
898
899 println!(
900 "Revolutionary performance test '{name}': {iterations} iterations in {duration:?} ({:.2} ops/sec)",
901 iterations as f64 / duration.as_secs_f64()
902 );
903
904 duration
905 }
906}
907
#[cfg(test)]
mod quick_start_tests {
    use super::*;
    use crate::quick_start::*;
    use std::cell::Cell;

    #[test]
    fn test_create_simple_transe_model() {
        let model = create_simple_transe_model();
        let config = model.config();
        assert_eq!(config.dimensions, 128);
        assert_eq!(config.learning_rate, 0.01);
        assert_eq!(config.max_epochs, 100);
    }

    #[test]
    fn test_parse_triple_from_string() {
        let triple_str = "http://example.org/alice http://example.org/knows http://example.org/bob";
        let triple = parse_triple_from_string(triple_str).unwrap();
        assert_eq!(triple.subject.iri, "http://example.org/alice");
        assert_eq!(triple.predicate.iri, "http://example.org/knows");
        assert_eq!(triple.object.iri, "http://example.org/bob");
    }

    #[test]
    fn test_parse_triple_expands_bare_names() {
        let triple = parse_triple_from_string("alice knows https://example.com/bob").unwrap();
        assert_eq!(triple.subject.iri, "http://example.org/alice");
        assert_eq!(triple.predicate.iri, "http://example.org/knows");
        // Already-absolute https IRIs are kept verbatim.
        assert_eq!(triple.object.iri, "https://example.com/bob");
    }

    #[test]
    fn test_parse_triple_from_string_invalid() {
        let triple_str = "http://example.org/alice http://example.org/knows";
        let result = parse_triple_from_string(triple_str);
        assert!(result.is_err());
    }

    #[test]
    fn test_add_triples_from_strings() {
        let mut model = create_simple_transe_model();
        let triple_strings = [
            "http://example.org/alice http://example.org/knows http://example.org/bob",
            "http://example.org/bob http://example.org/likes http://example.org/music",
        ];

        let count = add_triples_from_strings(&mut model, &triple_strings).unwrap();
        assert_eq!(count, 2);
    }

    #[test]
    fn test_add_triples_from_strings_propagates_parse_error() {
        let mut model = create_simple_transe_model();
        let triple_strings = ["alice knows"];
        assert!(add_triples_from_strings(&mut model, &triple_strings).is_err());
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let similarity = cosine_similarity(&a, &b).unwrap();
        assert!((similarity - 1.0).abs() < 1e-10);

        let c = vec![0.0, 1.0, 0.0];
        let similarity2 = cosine_similarity(&a, &c).unwrap();
        assert!((similarity2 - 0.0).abs() < 1e-10);

        // Mismatched lengths are rejected.
        let d = vec![1.0, 0.0];
        assert!(cosine_similarity(&a, &d).is_err());
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let zero = vec![0.0, 0.0];
        let other = vec![1.0, 2.0];
        assert_eq!(cosine_similarity(&zero, &other).unwrap(), 0.0);
    }

    #[test]
    fn test_generate_sample_kg_data() {
        let triples = generate_sample_kg_data(5, 3);
        assert!(!triples.is_empty());

        for (subject, relation, object) in &triples {
            assert!(subject.starts_with("http://example.org/entity_"));
            assert!(relation.starts_with("http://example.org/relation_"));
            assert!(object.starts_with("http://example.org/entity_"));
            // Self-loops are filtered out.
            assert_ne!(subject, object);
        }
    }

    #[test]
    fn test_generate_sample_kg_data_single_entity_has_no_triples() {
        // With one entity every draw is a self-loop, so nothing is kept.
        assert!(generate_sample_kg_data(1, 3).is_empty());
    }

    #[test]
    fn test_quick_performance_test() {
        // Count invocations rather than asserting a non-zero elapsed time: a
        // trivial closure can finish within a single tick of a coarse clock,
        // which made the old `duration.as_nanos() > 0` check flaky.
        let calls = Cell::new(0usize);
        let _duration = quick_performance_test("test_operation", 100, || {
            calls.set(calls.get() + 1);
        });

        assert_eq!(calls.get(), 100);
    }
}