1#![allow(dead_code)]
122
123pub mod acceleration;
124pub mod adaptive_learning;
125pub mod advanced_profiler;
126#[cfg(feature = "api-server")]
127pub mod api;
128pub mod application_tasks;
129pub mod batch_processing;
130pub mod biological_computing;
131pub mod biomedical_embeddings;
132pub mod caching;
133pub mod causal_representation_learning;
134pub mod cloud_integration;
135pub mod clustering;
136pub mod community_detection;
137pub mod compression;
138pub mod consciousness_aware_embeddings;
139pub mod contextual;
140pub mod continual_learning;
141pub mod cross_domain_transfer;
142pub mod cross_module_performance;
143pub mod delta;
144pub mod diffusion_embeddings;
145pub mod distributed_training;
146pub mod enterprise_knowledge;
147pub mod entity_linking;
148pub mod evaluation;
149pub mod federated_learning;
150pub mod fine_tuning;
151pub mod gpu_acceleration;
152pub mod graphql_api;
153pub mod inference;
154pub mod integration;
155pub mod interpretability;
156pub mod link_prediction;
157pub mod mamba_attention;
158pub mod mixed_precision;
159pub mod model_registry;
160pub mod model_selection;
161pub mod models;
162pub mod monitoring;
163pub mod multimodal;
164pub mod neural_symbolic_integration;
165pub mod neuro_evolution;
166pub mod novel_architectures;
167pub mod performance_profiler;
168pub mod persistence;
169pub mod quantization;
170pub mod quantum_circuits;
171pub mod real_time_fine_tuning;
172pub mod real_time_optimization;
173pub mod research_networks;
174pub mod sparql_extension;
176pub mod storage_backend;
177pub mod temporal_embeddings;
178pub mod training;
179pub mod utils;
180pub mod vector_search;
181pub mod vision_language_graph;
182pub mod visualization;
183
184pub use oxirs_vec::Vector as VecVector;
186
187pub use adaptive_learning::{
189 AdaptationMetrics, AdaptationStrategy, AdaptiveLearningConfig, AdaptiveLearningSystem,
190 QualityFeedback,
191};
192
193use anyhow::Result;
194use chrono::{DateTime, Utc};
195use serde::{Deserialize, Serialize};
196use std::collections::HashMap;
197use std::ops::{Add, Sub};
198use uuid::Uuid;
199
200#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
203pub struct Vector {
204 pub values: Vec<f32>,
205 pub dimensions: usize,
206 #[serde(skip)]
207 inner: Option<VecVector>,
208}
209
210impl Vector {
211 pub fn new(values: Vec<f32>) -> Self {
212 let dimensions = values.len();
213 Self {
214 values,
215 dimensions,
216 inner: None,
217 }
218 }
219
220 fn get_inner(&self) -> VecVector {
222 if let Some(ref inner) = self.inner {
224 inner.clone()
225 } else {
226 VecVector::new(self.values.clone())
227 }
228 }
229
230 fn sync_internal(&mut self) {
232 self.dimensions = self.values.len();
233 self.inner = None; }
235
236 pub fn from_array1(array: &scirs2_core::ndarray_ext::Array1<f32>) -> Self {
238 Self::new(array.to_vec())
239 }
240
241 pub fn to_array1(&self) -> scirs2_core::ndarray_ext::Array1<f32> {
243 scirs2_core::ndarray_ext::Array1::from_vec(self.values.clone())
244 }
245
246 pub fn mapv<F>(&self, f: F) -> Self
248 where
249 F: Fn(f32) -> f32,
250 {
251 Self::new(self.values.iter().copied().map(f).collect())
252 }
253
254 pub fn sum(&self) -> f32 {
256 self.values.iter().sum()
257 }
258
259 pub fn sqrt(&self) -> f32 {
261 self.sum().sqrt()
262 }
263
264 pub fn inner(&self) -> VecVector {
266 self.get_inner()
267 }
268
269 pub fn into_inner(self) -> VecVector {
271 self.inner.unwrap_or_else(|| VecVector::new(self.values))
272 }
273
274 pub fn from_vec_vector(vec_vector: VecVector) -> Self {
276 let values = vec_vector.as_f32().to_vec();
277 let dimensions = values.len();
278 Self {
279 values,
280 dimensions,
281 inner: Some(vec_vector),
282 }
283 }
284
285 pub fn with_capacity(capacity: usize) -> Self {
287 Self {
288 values: Vec::with_capacity(capacity),
289 dimensions: 0,
290 inner: None,
291 }
292 }
293
294 pub fn extend_optimized(&mut self, other_values: &[f32]) {
296 self.values.reserve(other_values.len());
298 self.values.extend_from_slice(other_values);
299 self.sync_internal();
300 }
301
302 pub fn shrink_to_fit(&mut self) {
304 self.values.shrink_to_fit();
305 self.sync_internal();
306 }
307
308 pub fn memory_usage(&self) -> usize {
310 self.values.capacity() * std::mem::size_of::<f32>() + std::mem::size_of::<Self>()
311 }
312}
313
314impl Add for &Vector {
316 type Output = Vector;
317
318 fn add(self, other: &Vector) -> Vector {
319 if let (Some(self_inner), Some(other_inner)) = (&self.inner, &other.inner) {
321 if let Ok(result) = self_inner.add(other_inner) {
322 return Vector::from_vec_vector(result);
323 }
324 }
325 assert_eq!(
327 self.values.len(),
328 other.values.len(),
329 "Vector dimensions must match"
330 );
331 let result_values: Vec<f32> = self
332 .values
333 .iter()
334 .zip(other.values.iter())
335 .map(|(a, b)| a + b)
336 .collect();
337 Vector::new(result_values)
338 }
339}
340
341impl Sub for &Vector {
342 type Output = Vector;
343
344 fn sub(self, other: &Vector) -> Vector {
345 if let (Some(self_inner), Some(other_inner)) = (&self.inner, &other.inner) {
347 if let Ok(result) = self_inner.subtract(other_inner) {
348 return Vector::from_vec_vector(result);
349 }
350 }
351 assert_eq!(
353 self.values.len(),
354 other.values.len(),
355 "Vector dimensions must match"
356 );
357 let result_values: Vec<f32> = self
358 .values
359 .iter()
360 .zip(other.values.iter())
361 .map(|(a, b)| a - b)
362 .collect();
363 Vector::new(result_values)
364 }
365}
366
367#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
369pub struct Triple {
370 pub subject: NamedNode,
371 pub predicate: NamedNode,
372 pub object: NamedNode,
373}
374
375impl Triple {
376 pub fn new(subject: NamedNode, predicate: NamedNode, object: NamedNode) -> Self {
377 Self {
378 subject,
379 predicate,
380 object,
381 }
382 }
383}
384
385#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
387pub struct NamedNode {
388 pub iri: String,
389}
390
391impl NamedNode {
392 pub fn new(iri: &str) -> Result<Self> {
393 Ok(Self {
394 iri: iri.to_string(),
395 })
396 }
397}
398
399impl std::fmt::Display for NamedNode {
400 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
401 write!(f, "{}", self.iri)
402 }
403}
404
405#[derive(Debug, Clone, Serialize, Deserialize)]
407pub struct ModelConfig {
408 pub dimensions: usize,
409 pub learning_rate: f64,
410 pub l2_reg: f64,
411 pub max_epochs: usize,
412 pub batch_size: usize,
413 pub negative_samples: usize,
414 pub seed: Option<u64>,
415 pub use_gpu: bool,
416 pub model_params: HashMap<String, f64>,
417}
418
419impl Default for ModelConfig {
420 fn default() -> Self {
421 Self {
422 dimensions: 100,
423 learning_rate: 0.01,
424 l2_reg: 0.0001,
425 max_epochs: 1000,
426 batch_size: 1000,
427 negative_samples: 10,
428 seed: None,
429 use_gpu: false,
430 model_params: HashMap::new(),
431 }
432 }
433}
434
435impl ModelConfig {
436 pub fn with_dimensions(mut self, dimensions: usize) -> Self {
437 self.dimensions = dimensions;
438 self
439 }
440
441 pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
442 self.learning_rate = learning_rate;
443 self
444 }
445
446 pub fn with_max_epochs(mut self, max_epochs: usize) -> Self {
447 self.max_epochs = max_epochs;
448 self
449 }
450
451 pub fn with_seed(mut self, seed: u64) -> Self {
452 self.seed = Some(seed);
453 self
454 }
455
456 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
457 self.batch_size = batch_size;
458 self
459 }
460}
461
/// Summary of a training run, as returned by `EmbeddingModel::train`.
///
/// The values are filled in by each model implementation; their exact semantics are
/// defined there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingStats {
    /// Number of epochs the run completed.
    pub epochs_completed: usize,
    /// Loss reported at the end of training.
    pub final_loss: f64,
    /// Training duration, in seconds.
    pub training_time_seconds: f64,
    /// Whether the implementation considered training converged.
    pub convergence_achieved: bool,
    /// Recorded loss values; presumably one per epoch in order. TODO confirm with implementations.
    pub loss_history: Vec<f64>,
}
471
472#[derive(Debug, Clone, Serialize, Deserialize)]
474pub struct ModelStats {
475 pub num_entities: usize,
476 pub num_relations: usize,
477 pub num_triples: usize,
478 pub dimensions: usize,
479 pub is_trained: bool,
480 pub model_type: String,
481 pub creation_time: DateTime<Utc>,
482 pub last_training_time: Option<DateTime<Utc>>,
483}
484
485impl Default for ModelStats {
486 fn default() -> Self {
487 Self {
488 num_entities: 0,
489 num_relations: 0,
490 num_triples: 0,
491 dimensions: 0,
492 is_trained: false,
493 model_type: "unknown".to_string(),
494 creation_time: Utc::now(),
495 last_training_time: None,
496 }
497 }
498}
499
/// Typed errors for embedding-model operations.
///
/// Display text comes from the `#[error]` attributes.
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingError {
    /// The operation requires a trained model.
    #[error("Model not trained")]
    ModelNotTrained,
    /// The named entity is unknown to the model.
    #[error("Entity not found: {entity}")]
    EntityNotFound { entity: String },
    /// The named relation is unknown to the model.
    #[error("Relation not found: {relation}")]
    RelationNotFound { relation: String },
    /// Any other failure.
    ///
    /// `#[from]` lets `?` convert an `anyhow::Error` into this variant.
    #[error("Other error: {0}")]
    Other(#[from] anyhow::Error),
}
512
/// Common interface for knowledge-graph embedding models.
///
/// Declared with `async_trait` so that `train` and `encode` can be `async`. Implementors
/// must be `Send + Sync`. Fallible methods return `anyhow::Result`.
#[async_trait::async_trait]
pub trait EmbeddingModel: Send + Sync {
    /// Configuration the model uses.
    fn config(&self) -> &ModelConfig;
    /// Unique identifier of this model instance.
    fn model_id(&self) -> &Uuid;
    /// Static name of the model type.
    fn model_type(&self) -> &'static str;
    /// Adds a triple to the model's training data.
    fn add_triple(&mut self, triple: Triple) -> Result<()>;
    /// Trains the model.
    ///
    /// `epochs` optionally overrides the configured epoch count (presumably; confirm per
    /// implementation).
    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats>;
    /// Embedding vector for the entity identified by `entity`.
    fn get_entity_embedding(&self, entity: &str) -> Result<Vector>;
    /// Embedding vector for the relation identified by `relation`.
    fn get_relation_embedding(&self, relation: &str) -> Result<Vector>;
    /// Model score for the triple `(subject, predicate, object)`.
    ///
    /// Whether higher means more plausible is model-specific; verify per implementation.
    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64>;
    /// Up to `k` candidate objects for `(subject, predicate, ?)` as `(name, score)` pairs.
    fn predict_objects(
        &self,
        subject: &str,
        predicate: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>>;
    /// Up to `k` candidate subjects for `(?, predicate, object)` as `(name, score)` pairs.
    fn predict_subjects(
        &self,
        predicate: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>>;
    /// Up to `k` candidate relations for `(subject, ?, object)` as `(name, score)` pairs.
    fn predict_relations(
        &self,
        subject: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>>;
    /// Identifiers of all known entities.
    fn get_entities(&self) -> Vec<String>;
    /// Identifiers of all known relations.
    fn get_relations(&self) -> Vec<String>;
    /// Snapshot of model statistics.
    fn get_stats(&self) -> ModelStats;
    /// Persists the model to `path`.
    fn save(&self, path: &str) -> Result<()>;
    /// Loads model state from `path` into `self`.
    fn load(&mut self, path: &str) -> Result<()>;
    /// Resets the model; exactly what is cleared is implementation-defined.
    fn clear(&mut self);
    /// Whether the model has been trained.
    fn is_trained(&self) -> bool;

    /// Encodes each input text into an embedding, one `Vec<f32>` per text.
    async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
}
553
554pub use acceleration::{AdaptiveEmbeddingAccelerator, GpuEmbeddingAccelerator};
556#[cfg(feature = "api-server")]
557pub use api::{start_server, ApiConfig, ApiState};
558pub use batch_processing::{
559 BatchJob, BatchProcessingConfig, BatchProcessingManager, BatchProcessingResult,
560 BatchProcessingStats, IncrementalConfig, JobProgress, JobStatus, OutputFormat,
561 PartitioningStrategy, RetryConfig,
562};
563pub use biomedical_embeddings::{
564 BiomedicalEmbedding, BiomedicalEmbeddingConfig, BiomedicalEntityType, BiomedicalRelationType,
565 FineTuningConfig, PreprocessingRule, SpecializedTextConfig, SpecializedTextEmbedding,
566 SpecializedTextModel,
567};
568pub use caching::{CacheConfig, CacheManager, CachedEmbeddingModel};
569pub use causal_representation_learning::{
570 CausalDiscoveryAlgorithm, CausalDiscoveryConfig, CausalGraph, CausalRepresentationConfig,
571 CausalRepresentationModel, ConstraintSettings, CounterfactualConfig, CounterfactualQuery,
572 DisentanglementConfig, DisentanglementMethod, ExplanationType, IndependenceTest,
573 InterventionConfig, ScoreSettings, StructuralCausalModelConfig,
574};
575pub use cloud_integration::{
576 AWSSageMakerService, AutoScalingConfig, AzureMLService, BackupConfig, CloudIntegrationConfig,
577 CloudIntegrationManager, CloudProvider, CloudService, ClusterStatus, CostEstimate,
578 CostOptimizationResult, CostOptimizationStrategy, DeploymentConfig, DeploymentResult,
579 DeploymentStatus, EndpointInfo, FunctionInvocationResult, GPUClusterConfig, GPUClusterResult,
580 LifecyclePolicy, OptimizationAction, PerformanceTier, ReplicationType,
581 ServerlessDeploymentResult, ServerlessFunctionConfig, ServerlessStatus, StorageConfig,
582 StorageResult, StorageStatus, StorageType,
583};
584pub use compression::{
585 CompressedModel, CompressionStats, CompressionTarget, DistillationConfig,
586 ModelCompressionManager, NASConfig, OptimizationTarget, PruningConfig, PruningMethod,
587 QuantizationConfig, QuantizationMethod,
588};
589pub use consciousness_aware_embeddings::{
590 AttentionMechanism, ConsciousnessAwareEmbedding, ConsciousnessInsights, ConsciousnessLevel,
591 MetaCognition, WorkingMemory,
592};
593pub use continual_learning::{
597 ArchitectureConfig, BoundaryDetection, ConsolidationConfig, ContinualLearningConfig,
598 ContinualLearningModel, MemoryConfig, MemoryType, MemoryUpdateStrategy, RegularizationConfig,
599 ReplayConfig, ReplayMethod, TaskConfig, TaskDetection, TaskSwitching,
600};
601pub use cross_module_performance::{
602 CoordinatorConfig, CrossModulePerformanceCoordinator, GlobalPerformanceMetrics, ModuleMetrics,
603 ModulePerformanceMonitor, OptimizationCache, PerformanceSnapshot, PredictivePerformanceEngine,
604 ResourceAllocator, ResourceTracker,
605};
606pub use delta::{
607 ChangeRecord, ChangeStatistics, ChangeType, DeltaConfig, DeltaManager, DeltaResult, DeltaStats,
608 IncrementalStrategy,
609};
610pub use enterprise_knowledge::{
611 BehaviorMetrics, CareerPredictions, Category, CategoryHierarchy, CategoryPerformance,
612 ColdStartStrategy, CommunicationFrequency, CommunicationPreferences, CustomerEmbedding,
613 CustomerPreferences, CustomerRatings, CustomerSegment, Department, DepartmentPerformance,
614 EmployeeEmbedding, EnterpriseConfig, EnterpriseKnowledgeAnalyzer, EnterpriseMetrics,
615 ExperienceLevel, FeatureType, MarketAnalysis, OrganizationalStructure,
616 PerformanceMetrics as EnterprisePerformanceMetrics, ProductAvailability, ProductEmbedding,
617 ProductFeature, ProductRecommendation, Project, ProjectOutcome, ProjectParticipation,
618 ProjectPerformance, ProjectStatus, Purchase, PurchaseChannel, RecommendationConfig,
619 RecommendationEngine, RecommendationEngineType, RecommendationPerformance,
620 RecommendationReason, SalesMetrics, Skill, SkillCategory, Team, TeamPerformance,
621};
622pub use evaluation::{
623 QueryAnsweringEvaluator, QueryEvaluationConfig, QueryEvaluationResults, QueryMetric,
624 QueryResult, QueryTemplate, QueryType, ReasoningChain, ReasoningEvaluationConfig,
625 ReasoningEvaluationResults, ReasoningRule, ReasoningStep, ReasoningTaskEvaluator,
626 ReasoningType, TypeSpecificResults,
627};
628pub use federated_learning::{
629 AggregationEngine, AggregationStrategy, AuthenticationConfig, AuthenticationMethod,
630 CertificateConfig, ClippingMechanisms, ClippingMethod, CommunicationConfig,
631 CommunicationManager, CommunicationProtocol, CompressionAlgorithm, CompressionConfig,
632 CompressionEngine, ConvergenceMetrics, ConvergenceStatus, DataSelectionStrategy,
633 DataStatistics, EncryptionScheme, FederatedConfig, FederatedCoordinator,
634 FederatedEmbeddingModel, FederatedMessage, FederatedRound, FederationStats, GlobalModelState,
635 HardwareAccelerator, KeyManager, LocalModelState, LocalTrainingStats, LocalUpdate,
636 MetaLearningConfig, NoiseGenerator, NoiseMechanism, OutlierAction, OutlierDetection,
637 OutlierDetectionMethod, Participant, ParticipantCapabilities, ParticipantStatus,
638 PersonalizationConfig, PersonalizationStrategy, PrivacyAccountant, PrivacyConfig,
639 PrivacyEngine, PrivacyMetrics, PrivacyParams, RoundMetrics, RoundStatus, SecurityConfig,
640 SecurityFeature, SecurityManager, TrainingConfig, VerificationEngine, VerificationMechanism,
641 VerificationResult, WeightingScheme,
642};
643pub use gpu_acceleration::{
644 GpuAccelerationConfig, GpuAccelerationManager, GpuMemoryPool, GpuPerformanceStats,
645 MixedPrecisionProcessor, MultiStreamProcessor, TensorCache,
646};
647pub use graphql_api::{
648 create_schema, BatchEmbeddingInput, BatchEmbeddingResult, BatchStatus, DistanceMetric,
649 EmbeddingFormat, EmbeddingQueryInput, EmbeddingResult, EmbeddingSchema, GraphQLContext,
650 ModelInfo, ModelType, SimilarityResult, SimilaritySearchInput,
651};
652pub use models::{
653 AggregationType, ComplEx, DistMult, GNNConfig, GNNEmbedding, GNNType, HoLE, HoLEConfig,
654 PoolingStrategy, RotatE, TransE, TransformerConfig, TransformerEmbedding, TransformerType,
655};
656
657pub use contextual::{
658 AccessibilityPreferences, ComplexityLevel, ContextualConfig, ContextualEmbeddingModel,
659 DomainContext, EmbeddingContext, PerformanceRequirements, PriorityLevel, PrivacySettings,
660 QueryContext, QueryType as ContextualQueryType, ResponseFormat, TaskConstraints, TaskContext,
661 TaskType, UserContext, UserHistory, UserPreferences,
662};
663pub use distributed_training::{
664 AggregationMethod, CommunicationBackend, DistributedEmbeddingTrainer, DistributedStrategy,
665 DistributedTrainingConfig, DistributedTrainingCoordinator, DistributedTrainingStats,
666 FaultToleranceConfig, WorkerInfo, WorkerStatus,
667};
668#[cfg(feature = "conve")]
669pub use models::{ConvE, ConvEConfig};
670pub use monitoring::{
671 Alert, AlertSeverity, AlertThresholds, AlertType, CacheMetrics, ConsoleAlertHandler,
672 DriftMetrics, ErrorEvent, ErrorMetrics, ErrorSeverity, LatencyMetrics, MonitoringConfig,
673 PerformanceMetrics as MonitoringPerformanceMetrics, PerformanceMonitor, QualityAssessment,
674 QualityMetrics, ResourceMetrics, SlackAlertHandler, ThroughputMetrics,
675};
676pub use multimodal::{
677 AlignmentNetwork, AlignmentObjective, ContrastiveConfig, CrossDomainConfig, CrossModalConfig,
678 KGEncoder, MultiModalEmbedding, MultiModalStats, TextEncoder,
679};
680pub use neural_symbolic_integration::{
681 ConstraintSatisfactionConfig, ConstraintType, KnowledgeIntegrationConfig, KnowledgeRule,
682 LogicIntegrationConfig, LogicProgrammingConfig, LogicalFormula, NeuralSymbolicConfig,
683 NeuralSymbolicModel, NeuroSymbolicArchitectureConfig, OntologicalConfig, ReasoningEngine,
684 RuleBasedConfig, SymbolicReasoningConfig,
685};
686pub use novel_architectures::{
687 ActivationType, ArchitectureParams, ArchitectureState, ArchitectureType, CurvatureComputation,
688 CurvatureMethod, CurvatureType, DynamicsConfig, EntanglementStructure, EquivarianceGroup,
689 FlowType, GeometricConfig, GeometricParams, GeometricSpace, GeometricState,
690 GraphTransformerParams, GraphTransformerState, HyperbolicDistance, HyperbolicInit,
691 HyperbolicManifold, HyperbolicParams, HyperbolicState, IntegrationScheme, IntegrationStats,
692 ManifoldLearning, ManifoldMethod, ManifoldOptimizer, NeuralODEParams, NeuralODEState,
693 NovelArchitectureConfig, NovelArchitectureModel, ODERegularization, ODESolverType,
694 ParallelTransport, QuantumGateSet, QuantumMeasurement, QuantumNoise, QuantumParams,
695 QuantumState, StabilityConstraints, StructuralBias, TimeEvolution, TransportMethod,
696};
697pub use quantum_circuits::{
698 Complex, MeasurementStrategy, QNNLayerType, QuantumApproximateOptimization, QuantumCircuit,
699 QuantumGate, QuantumNeuralNetwork, QuantumNeuralNetworkLayer, QuantumSimulator,
700 VariationalQuantumEigensolver,
701};
702pub use research_networks::{
703 AuthorEmbedding, Citation, CitationNetwork, CitationType, Collaboration, CollaborationNetwork,
704 NetworkMetrics, PaperSection, PublicationEmbedding, PublicationType, ResearchCommunity,
705 ResearchNetworkAnalyzer, ResearchNetworkConfig, TopicModel, TopicModelingConfig,
706};
707pub use sparql_extension::{
708 ExpandedQuery, Expansion, ExpansionType, QueryStatistics as SparqlQueryStatistics,
709 SparqlExtension, SparqlExtensionConfig,
710};
711pub use storage_backend::{
712 DiskBackend, EmbeddingMetadata, EmbeddingVersion, MemoryBackend, StorageBackend,
713 StorageBackendConfig, StorageBackendManager, StorageBackendType, StorageStats,
714};
715pub use temporal_embeddings::{
716 TemporalEmbeddingConfig, TemporalEmbeddingModel, TemporalEvent, TemporalForecast,
717 TemporalGranularity, TemporalScope, TemporalStats, TemporalTriple,
718};
719pub use vision_language_graph::{
720 AggregationFunction, CNNConfig, CrossAttentionConfig, DomainAdaptationConfig,
721 DomainAdaptationMethod, EpisodeConfig, FewShotConfig, FewShotMethod, FusionStrategy,
722 GraphArchitecture, GraphEncoder, GraphEncoderConfig, JointTrainingConfig, LanguageArchitecture,
723 LanguageEncoder, LanguageEncoderConfig, LanguageTransformerConfig, MetaLearner,
724 ModalityEncoding, MultiModalTransformer, MultiModalTransformerConfig, NormalizationType,
725 PoolingType, PositionEncodingType, ReadoutFunction, TaskCategory, TaskSpecificParams,
726 TrainingObjective, TransferLearningConfig, TransferStrategy, ViTConfig, VisionArchitecture,
727 VisionEncoder, VisionEncoderConfig, VisionLanguageGraphConfig, VisionLanguageGraphModel,
728 VisionLanguageGraphStats, ZeroShotConfig, ZeroShotMethod,
729};
730
731#[cfg(feature = "tucker")]
732pub use models::TuckER;
733
734#[cfg(feature = "quatd")]
735pub use models::QuatD;
736
737pub use crate::model_registry::{
739 ModelRegistry, ModelVersion, ResourceAllocation as ModelResourceAllocation,
740};
741
742pub use crate::model_selection::{
744 DatasetCharacteristics, MemoryRequirement, ModelComparison, ModelComparisonEntry,
745 ModelRecommendation, ModelSelector, ModelType as SelectionModelType, TrainingTime, UseCaseType,
746};
747
748pub use crate::performance_profiler::{
750 OperationStats, OperationTimer, OperationType, PerformanceProfiler, PerformanceReport,
751};
752
753pub mod quick_start {
767 use super::*;
768 use crate::models::TransE;
769
770 pub fn create_simple_transe_model() -> TransE {
772 let config = ModelConfig::default()
773 .with_dimensions(128)
774 .with_learning_rate(0.01)
775 .with_max_epochs(100);
776 TransE::new(config)
777 }
778
779 pub fn create_biomedical_model() -> BiomedicalEmbedding {
781 let config = BiomedicalEmbeddingConfig::default();
782 BiomedicalEmbedding::new(config)
783 }
784
785 pub fn parse_triple_from_string(triple_str: &str) -> Result<Triple> {
787 let parts: Vec<&str> = triple_str.split_whitespace().collect();
788 if parts.len() != 3 {
789 return Err(anyhow::anyhow!(
790 "Triple must have exactly 3 parts separated by spaces"
791 ));
792 }
793
794 let expand_uri = |s: &str| -> String {
796 if s.starts_with("http://") || s.starts_with("https://") {
797 s.to_string()
798 } else {
799 format!("http://example.org/{s}")
800 }
801 };
802
803 Ok(Triple::new(
804 NamedNode::new(&expand_uri(parts[0]))?,
805 NamedNode::new(&expand_uri(parts[1]))?,
806 NamedNode::new(&expand_uri(parts[2]))?,
807 ))
808 }
809
810 pub fn add_triples_from_strings<T: EmbeddingModel>(
812 model: &mut T,
813 triple_strings: &[&str],
814 ) -> Result<usize> {
815 let mut count = 0;
816 for triple_str in triple_strings {
817 let triple = parse_triple_from_string(triple_str)?;
818 model.add_triple(triple)?;
819 count += 1;
820 }
821 Ok(count)
822 }
823
824 pub fn cosine_similarity(a: &[f64], b: &[f64]) -> Result<f64> {
826 if a.len() != b.len() {
827 return Err(anyhow::anyhow!(
828 "Vector dimensions don't match: {} vs {}",
829 a.len(),
830 b.len()
831 ));
832 }
833
834 let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
835 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
836 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
837
838 if norm_a == 0.0 || norm_b == 0.0 {
839 return Ok(0.0);
840 }
841
842 Ok(dot_product / (norm_a * norm_b))
843 }
844
845 pub fn generate_sample_kg_data(
847 num_entities: usize,
848 num_relations: usize,
849 ) -> Vec<(String, String, String)> {
850 #[allow(unused_imports)]
851 use scirs2_core::random::{Random, Rng};
852
853 let mut random = Random::default();
854 let mut triples = Vec::new();
855
856 let entities: Vec<String> = (0..num_entities)
857 .map(|i| format!("http://example.org/entity_{i}"))
858 .collect();
859
860 let relations: Vec<String> = (0..num_relations)
861 .map(|i| format!("http://example.org/relation_{i}"))
862 .collect();
863
864 for _ in 0..(num_entities * 2) {
866 let subject_idx = random.random_range(0..entities.len());
867 let relation_idx = random.random_range(0..relations.len());
868 let object_idx = random.random_range(0..entities.len());
869
870 let subject = entities[subject_idx].clone();
871 let relation = relations[relation_idx].clone();
872 let object = entities[object_idx].clone();
873
874 if subject != object {
875 triples.push((subject, relation, object));
876 }
877 }
878
879 triples
880 }
881
882 pub fn quick_performance_test<F>(
884 name: &str,
885 iterations: usize,
886 operation: F,
887 ) -> std::time::Duration
888 where
889 F: Fn(),
890 {
891 let start = std::time::Instant::now();
892 for _ in 0..iterations {
893 operation();
894 }
895 let duration = start.elapsed();
896
897 println!(
898 "Performance test '{name}': {iterations} iterations in {duration:?} ({:.2} ops/sec)",
899 iterations as f64 / duration.as_secs_f64()
900 );
901
902 duration
903 }
904
905 pub async fn quick_revolutionary_performance_test<F, Fut>(
938 name: &str,
939 iterations: usize,
940 async_operation: F,
941 ) -> std::time::Duration
942 where
943 F: Fn() -> Fut,
944 Fut: std::future::Future<Output = ()>,
945 {
946 let start = std::time::Instant::now();
947 for _ in 0..iterations {
948 async_operation().await;
949 }
950 let duration = start.elapsed();
951
952 println!(
953 "Revolutionary performance test '{name}': {iterations} iterations in {duration:?} ({:.2} ops/sec)",
954 iterations as f64 / duration.as_secs_f64()
955 );
956
957 duration
958 }
959}
960
#[cfg(test)]
mod quick_start_tests {
    use super::*;
    use crate::quick_start::*;

    /// The quick-start TransE model carries the documented overrides.
    #[test]
    fn test_create_simple_transe_model() {
        let transe = create_simple_transe_model();
        let cfg = transe.config();
        assert_eq!(cfg.dimensions, 128);
        assert_eq!(cfg.learning_rate, 0.01);
        assert_eq!(cfg.max_epochs, 100);
    }

    /// Absolute IRIs pass through parsing unchanged.
    #[test]
    fn test_parse_triple_from_string() {
        let parsed = parse_triple_from_string(
            "http://example.org/alice http://example.org/knows http://example.org/bob",
        )
        .unwrap();
        assert_eq!(parsed.subject.iri, "http://example.org/alice");
        assert_eq!(parsed.predicate.iri, "http://example.org/knows");
        assert_eq!(parsed.object.iri, "http://example.org/bob");
    }

    /// Input with only two terms is rejected.
    #[test]
    fn test_parse_triple_from_string_invalid() {
        let two_parts = "http://example.org/alice http://example.org/knows";
        assert!(parse_triple_from_string(two_parts).is_err());
    }

    /// Every parsed triple is added and counted.
    #[test]
    fn test_add_triples_from_strings() {
        let mut transe = create_simple_transe_model();
        let inputs = [
            "http://example.org/alice http://example.org/knows http://example.org/bob",
            "http://example.org/bob http://example.org/likes http://example.org/music",
        ];

        let added = add_triples_from_strings(&mut transe, &inputs).unwrap();
        assert_eq!(added, 2);
    }

    /// Identical vectors score 1, orthogonal ones 0, mismatched lengths error out.
    #[test]
    fn test_cosine_similarity() {
        let x_axis = [1.0, 0.0, 0.0];
        let y_axis = [0.0, 1.0, 0.0];
        let short = [1.0, 0.0];

        let same = cosine_similarity(&x_axis, &x_axis).unwrap();
        assert!((same - 1.0).abs() < 1e-10);

        let orthogonal = cosine_similarity(&x_axis, &y_axis).unwrap();
        assert!(orthogonal.abs() < 1e-10);

        assert!(cosine_similarity(&x_axis, &short).is_err());
    }

    /// Generated triples use the synthetic IRI scheme and contain no self-loops.
    #[test]
    fn test_generate_sample_kg_data() {
        let triples = generate_sample_kg_data(5, 3);
        assert!(!triples.is_empty());

        for (subject, relation, object) in triples.iter() {
            assert!(subject.starts_with("http://example.org/entity_"));
            assert!(relation.starts_with("http://example.org/relation_"));
            assert!(object.starts_with("http://example.org/entity_"));
            assert_ne!(subject, object);
        }
    }

    /// Smoke test: the timing helper runs the closure and returns a duration.
    #[test]
    fn test_quick_performance_test() {
        let elapsed = quick_performance_test("test_operation", 100, || {
            let _sum: i32 = (1..10).sum();
        });

        let _nanos = elapsed.as_nanos();
    }
}