1#![allow(dead_code)]
122
123pub mod ab_testing;
124pub mod acceleration;
125pub mod adaptive_learning;
126pub mod advanced_profiler;
127pub mod alignment;
128#[cfg(feature = "api-server")]
129pub mod api;
130pub mod application_tasks;
131pub mod batch_processing;
132pub mod biomedical_embeddings;
133pub mod caching;
134pub mod causal_representation_learning;
135pub mod cloud_integration;
136pub mod clustering;
137pub mod community_detection;
138pub mod compression;
139pub mod contextual;
140pub mod continual_learning;
141pub mod cross_domain_transfer;
142pub mod cross_module_performance;
143pub mod delta;
144pub mod diffusion_embeddings;
145pub mod distributed_training;
146pub mod embed_compression;
147pub mod enterprise_knowledge;
148pub mod entity_linking;
149pub mod evaluation;
150pub mod federated_learning;
151pub mod fine_tuning;
152#[cfg(feature = "gpu")]
153pub mod gpu_acceleration;
154pub mod graph_models;
155pub mod graphql_api;
156pub mod inference;
157pub mod integration;
158pub mod interpretability;
159pub mod kg_completion;
160pub mod link_prediction;
161pub mod mamba_attention;
162pub mod mixed_precision;
163pub mod model_registry;
164pub mod model_selection;
165pub mod models;
166pub mod monitoring;
167pub mod multimodal;
168pub mod neural_symbolic_integration;
169pub mod neuro_evolution;
170pub mod novel_architectures;
171pub mod performance_profiler;
172pub mod persistence;
173pub mod quantization;
174pub mod real_time_fine_tuning;
175pub mod real_time_optimization;
176pub mod research_networks;
177pub mod sparql_extension;
179pub mod storage_backend;
180pub mod temporal_embeddings;
181pub mod training;
182pub mod training_online;
183pub mod utils;
184pub mod validation;
185pub mod vector_search;
186pub mod vision_language_graph;
187pub mod visualization;
188pub mod contrastive_learning;
190
191pub mod procrustes_alignment;
193
194pub mod embedding_cache;
196
197pub mod dimensionality_reducer;
199
200pub mod pca_reducer;
202
203pub mod fine_tuner;
205
206pub mod vector_store;
208
209pub mod cross_encoder;
211
212pub mod projection_layer;
214pub use projection_layer::{ActivationFn, InitMethod, ProjectionLayer, ProjectionMatrix};
215
216pub mod embedding_store;
218
219pub mod tokenizer;
221
222pub mod embedding_aggregator;
224
225pub mod reranker;
227
228pub mod index_optimizer;
230
231pub mod batch_encoder;
233
234pub mod embedding_compressor;
236
237pub use oxirs_vec::Vector as VecVector;
239
240pub use adaptive_learning::{
242 AdaptationMetrics, AdaptationStrategy, AdaptiveLearningConfig, AdaptiveLearningSystem,
243 QualityFeedback,
244};
245
246use anyhow::Result;
247use chrono::{DateTime, Utc};
248use serde::{Deserialize, Serialize};
249use std::collections::HashMap;
250use std::ops::{Add, Sub};
251use uuid::Uuid;
252
253#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
256pub struct Vector {
257 pub values: Vec<f32>,
258 pub dimensions: usize,
259 #[serde(skip)]
260 inner: Option<VecVector>,
261}
262
263impl Vector {
264 pub fn new(values: Vec<f32>) -> Self {
265 let dimensions = values.len();
266 Self {
267 values,
268 dimensions,
269 inner: None,
270 }
271 }
272
273 fn get_inner(&self) -> VecVector {
275 if let Some(ref inner) = self.inner {
277 inner.clone()
278 } else {
279 VecVector::new(self.values.clone())
280 }
281 }
282
283 fn sync_internal(&mut self) {
285 self.dimensions = self.values.len();
286 self.inner = None; }
288
289 pub fn from_array1(array: &scirs2_core::ndarray_ext::Array1<f32>) -> Self {
291 Self::new(array.to_vec())
292 }
293
294 pub fn to_array1(&self) -> scirs2_core::ndarray_ext::Array1<f32> {
296 scirs2_core::ndarray_ext::Array1::from_vec(self.values.clone())
297 }
298
299 pub fn mapv<F>(&self, f: F) -> Self
301 where
302 F: Fn(f32) -> f32,
303 {
304 Self::new(self.values.iter().copied().map(f).collect())
305 }
306
307 pub fn sum(&self) -> f32 {
309 self.values.iter().sum()
310 }
311
312 pub fn sqrt(&self) -> f32 {
314 self.sum().sqrt()
315 }
316
317 pub fn inner(&self) -> VecVector {
319 self.get_inner()
320 }
321
322 pub fn into_inner(self) -> VecVector {
324 self.inner.unwrap_or_else(|| VecVector::new(self.values))
325 }
326
327 pub fn from_vec_vector(vec_vector: VecVector) -> Self {
329 let values = vec_vector.as_f32().to_vec();
330 let dimensions = values.len();
331 Self {
332 values,
333 dimensions,
334 inner: Some(vec_vector),
335 }
336 }
337
338 pub fn with_capacity(capacity: usize) -> Self {
340 Self {
341 values: Vec::with_capacity(capacity),
342 dimensions: 0,
343 inner: None,
344 }
345 }
346
347 pub fn extend_optimized(&mut self, other_values: &[f32]) {
349 self.values.reserve(other_values.len());
351 self.values.extend_from_slice(other_values);
352 self.sync_internal();
353 }
354
355 pub fn shrink_to_fit(&mut self) {
357 self.values.shrink_to_fit();
358 self.sync_internal();
359 }
360
361 pub fn memory_usage(&self) -> usize {
363 self.values.capacity() * std::mem::size_of::<f32>() + std::mem::size_of::<Self>()
364 }
365}
366
367impl Add for &Vector {
369 type Output = Vector;
370
371 fn add(self, other: &Vector) -> Vector {
372 if let (Some(self_inner), Some(other_inner)) = (&self.inner, &other.inner) {
374 if let Ok(result) = self_inner.add(other_inner) {
375 return Vector::from_vec_vector(result);
376 }
377 }
378 assert_eq!(
380 self.values.len(),
381 other.values.len(),
382 "Vector dimensions must match"
383 );
384 let result_values: Vec<f32> = self
385 .values
386 .iter()
387 .zip(other.values.iter())
388 .map(|(a, b)| a + b)
389 .collect();
390 Vector::new(result_values)
391 }
392}
393
394impl Sub for &Vector {
395 type Output = Vector;
396
397 fn sub(self, other: &Vector) -> Vector {
398 if let (Some(self_inner), Some(other_inner)) = (&self.inner, &other.inner) {
400 if let Ok(result) = self_inner.subtract(other_inner) {
401 return Vector::from_vec_vector(result);
402 }
403 }
404 assert_eq!(
406 self.values.len(),
407 other.values.len(),
408 "Vector dimensions must match"
409 );
410 let result_values: Vec<f32> = self
411 .values
412 .iter()
413 .zip(other.values.iter())
414 .map(|(a, b)| a - b)
415 .collect();
416 Vector::new(result_values)
417 }
418}
419
420#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
422pub struct Triple {
423 pub subject: NamedNode,
424 pub predicate: NamedNode,
425 pub object: NamedNode,
426}
427
428impl Triple {
429 pub fn new(subject: NamedNode, predicate: NamedNode, object: NamedNode) -> Self {
430 Self {
431 subject,
432 predicate,
433 object,
434 }
435 }
436}
437
438#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
440pub struct NamedNode {
441 pub iri: String,
442}
443
444impl NamedNode {
445 pub fn new(iri: &str) -> Result<Self> {
446 Ok(Self {
447 iri: iri.to_string(),
448 })
449 }
450}
451
452impl std::fmt::Display for NamedNode {
453 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
454 write!(f, "{}", self.iri)
455 }
456}
457
458#[derive(Debug, Clone, Serialize, Deserialize)]
460pub struct ModelConfig {
461 pub dimensions: usize,
462 pub learning_rate: f64,
463 pub l2_reg: f64,
464 pub max_epochs: usize,
465 pub batch_size: usize,
466 pub negative_samples: usize,
467 pub seed: Option<u64>,
468 pub use_gpu: bool,
469 pub model_params: HashMap<String, f64>,
470}
471
472impl Default for ModelConfig {
473 fn default() -> Self {
474 Self {
475 dimensions: 100,
476 learning_rate: 0.01,
477 l2_reg: 0.0001,
478 max_epochs: 1000,
479 batch_size: 1000,
480 negative_samples: 10,
481 seed: None,
482 use_gpu: false,
483 model_params: HashMap::new(),
484 }
485 }
486}
487
488impl ModelConfig {
489 pub fn with_dimensions(mut self, dimensions: usize) -> Self {
490 self.dimensions = dimensions;
491 self
492 }
493
494 pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
495 self.learning_rate = learning_rate;
496 self
497 }
498
499 pub fn with_max_epochs(mut self, max_epochs: usize) -> Self {
500 self.max_epochs = max_epochs;
501 self
502 }
503
504 pub fn with_seed(mut self, seed: u64) -> Self {
505 self.seed = Some(seed);
506 self
507 }
508
509 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
510 self.batch_size = batch_size;
511 self
512 }
513}
514
515#[derive(Debug, Clone, Serialize, Deserialize)]
517pub struct TrainingStats {
518 pub epochs_completed: usize,
519 pub final_loss: f64,
520 pub training_time_seconds: f64,
521 pub convergence_achieved: bool,
522 pub loss_history: Vec<f64>,
523}
524
525#[derive(Debug, Clone, Serialize, Deserialize)]
527pub struct ModelStats {
528 pub num_entities: usize,
529 pub num_relations: usize,
530 pub num_triples: usize,
531 pub dimensions: usize,
532 pub is_trained: bool,
533 pub model_type: String,
534 pub creation_time: DateTime<Utc>,
535 pub last_training_time: Option<DateTime<Utc>>,
536}
537
538impl Default for ModelStats {
539 fn default() -> Self {
540 Self {
541 num_entities: 0,
542 num_relations: 0,
543 num_triples: 0,
544 dimensions: 0,
545 is_trained: false,
546 model_type: "unknown".to_string(),
547 creation_time: Utc::now(),
548 last_training_time: None,
549 }
550 }
551}
552
553#[derive(Debug, thiserror::Error)]
555pub enum EmbeddingError {
556 #[error("Model not trained")]
557 ModelNotTrained,
558 #[error("Entity not found: {entity}")]
559 EntityNotFound { entity: String },
560 #[error("Relation not found: {relation}")]
561 RelationNotFound { relation: String },
562 #[error("Other error: {0}")]
563 Other(#[from] anyhow::Error),
564}
565
566#[async_trait::async_trait]
568pub trait EmbeddingModel: Send + Sync {
569 fn config(&self) -> &ModelConfig;
570 fn model_id(&self) -> &Uuid;
571 fn model_type(&self) -> &'static str;
572 fn add_triple(&mut self, triple: Triple) -> Result<()>;
573 async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats>;
574 fn get_entity_embedding(&self, entity: &str) -> Result<Vector>;
575 fn get_relation_embedding(&self, relation: &str) -> Result<Vector>;
576 fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64>;
577 fn predict_objects(
578 &self,
579 subject: &str,
580 predicate: &str,
581 k: usize,
582 ) -> Result<Vec<(String, f64)>>;
583 fn predict_subjects(
584 &self,
585 predicate: &str,
586 object: &str,
587 k: usize,
588 ) -> Result<Vec<(String, f64)>>;
589 fn predict_relations(
590 &self,
591 subject: &str,
592 object: &str,
593 k: usize,
594 ) -> Result<Vec<(String, f64)>>;
595 fn get_entities(&self) -> Vec<String>;
596 fn get_relations(&self) -> Vec<String>;
597 fn get_stats(&self) -> ModelStats;
598 fn save(&self, path: &str) -> Result<()>;
599 fn load(&mut self, path: &str) -> Result<()>;
600 fn clear(&mut self);
601 fn is_trained(&self) -> bool;
602
603 async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
605}
606
607pub use acceleration::{AdaptiveEmbeddingAccelerator, GpuEmbeddingAccelerator};
609#[cfg(feature = "api-server")]
610pub use api::{start_server, ApiConfig, ApiState};
611pub use batch_processing::{
612 BatchJob, BatchProcessingConfig, BatchProcessingManager, BatchProcessingResult,
613 BatchProcessingStats, IncrementalConfig, JobProgress, JobStatus, OutputFormat,
614 PartitioningStrategy, RetryConfig,
615};
616pub use biomedical_embeddings::{
617 BiomedicalEmbedding, BiomedicalEmbeddingConfig, BiomedicalEntityType, BiomedicalRelationType,
618 FineTuningConfig, PreprocessingRule, SpecializedTextConfig, SpecializedTextEmbedding,
619 SpecializedTextModel,
620};
621pub use caching::{CacheConfig, CacheManager, CachedEmbeddingModel};
622pub use causal_representation_learning::{
623 CausalDiscoveryAlgorithm, CausalDiscoveryConfig, CausalGraph, CausalRepresentationConfig,
624 CausalRepresentationModel, ConstraintSettings, CounterfactualConfig, CounterfactualQuery,
625 DisentanglementConfig, DisentanglementMethod, ExplanationType, IndependenceTest,
626 InterventionConfig, ScoreSettings, StructuralCausalModelConfig,
627};
628pub use cloud_integration::{
629 AWSSageMakerService, AutoScalingConfig, AzureMLService, BackupConfig, CloudIntegrationConfig,
630 CloudIntegrationManager, CloudProvider, CloudService, ClusterStatus, CostEstimate,
631 CostOptimizationResult, CostOptimizationStrategy, DeploymentConfig, DeploymentResult,
632 DeploymentStatus, EndpointInfo, FunctionInvocationResult, GPUClusterConfig, GPUClusterResult,
633 LifecyclePolicy, OptimizationAction, PerformanceTier, ReplicationType,
634 ServerlessDeploymentResult, ServerlessFunctionConfig, ServerlessStatus, StorageConfig,
635 StorageResult, StorageStatus, StorageType,
636};
637pub use compression::{
638 CompressedModel, CompressionStats, CompressionTarget, DistillationConfig,
639 ModelCompressionManager, NASConfig, OptimizationTarget, PruningConfig, PruningMethod,
640 QuantizationConfig, QuantizationMethod,
641};
642pub use continual_learning::{
646 ArchitectureConfig, BoundaryDetection, ConsolidationConfig, ContinualLearningConfig,
647 ContinualLearningModel, MemoryConfig, MemoryType, MemoryUpdateStrategy, RegularizationConfig,
648 ReplayConfig, ReplayMethod, TaskConfig, TaskDetection, TaskSwitching,
649};
650pub use cross_module_performance::{
651 CoordinatorConfig, CrossModulePerformanceCoordinator, GlobalPerformanceMetrics, ModuleMetrics,
652 ModulePerformanceMonitor, OptimizationCache, PerformanceSnapshot, PredictivePerformanceEngine,
653 ResourceAllocator, ResourceTracker,
654};
655pub use delta::{
656 ChangeRecord, ChangeStatistics, ChangeType, DeltaConfig, DeltaManager, DeltaResult, DeltaStats,
657 IncrementalStrategy,
658};
659pub use enterprise_knowledge::{
660 BehaviorMetrics, CareerPredictions, Category, CategoryHierarchy, CategoryPerformance,
661 ColdStartStrategy, CommunicationFrequency, CommunicationPreferences, CustomerEmbedding,
662 CustomerPreferences, CustomerRatings, CustomerSegment, Department, DepartmentPerformance,
663 EmployeeEmbedding, EnterpriseConfig, EnterpriseKnowledgeAnalyzer, EnterpriseMetrics,
664 ExperienceLevel, FeatureType, MarketAnalysis, OrganizationalStructure,
665 PerformanceMetrics as EnterprisePerformanceMetrics, ProductAvailability, ProductEmbedding,
666 ProductFeature, ProductRecommendation, Project, ProjectOutcome, ProjectParticipation,
667 ProjectPerformance, ProjectStatus, Purchase, PurchaseChannel, RecommendationConfig,
668 RecommendationEngine, RecommendationEngineType, RecommendationPerformance,
669 RecommendationReason, SalesMetrics, Skill, SkillCategory, Team, TeamPerformance,
670};
671pub use evaluation::{
672 AnalogicalReasoningBenchmark, AnalogyQuad, EmbeddingClusteringMetrics, EmbeddingEvaluator,
673 QueryAnsweringEvaluator, QueryEvaluationConfig, QueryEvaluationResults, QueryMetric,
674 QueryResult, QueryTemplate, QueryType, ReasoningChain, ReasoningEvaluationConfig,
675 ReasoningEvaluationResults, ReasoningRule, ReasoningStep, ReasoningTaskEvaluator,
676 ReasoningType, TypeSpecificResults,
677};
678pub use federated_learning::{
679 AggregationEngine, AggregationStrategy, AuthenticationConfig, AuthenticationMethod,
680 CertificateConfig, ClippingMechanisms, ClippingMethod, CommunicationConfig,
681 CommunicationManager, CommunicationProtocol, CompressionAlgorithm, CompressionConfig,
682 CompressionEngine, ConvergenceMetrics, ConvergenceStatus, DataSelectionStrategy,
683 DataStatistics, EncryptionScheme, FederatedConfig, FederatedCoordinator,
684 FederatedEmbeddingModel, FederatedMessage, FederatedRound, FederationStats, GlobalModelState,
685 HardwareAccelerator, KeyManager, LocalModelState, LocalTrainingStats, LocalUpdate,
686 MetaLearningConfig, NoiseGenerator, NoiseMechanism, OutlierAction, OutlierDetection,
687 OutlierDetectionMethod, Participant, ParticipantCapabilities, ParticipantStatus,
688 PersonalizationConfig, PersonalizationStrategy, PrivacyAccountant, PrivacyConfig,
689 PrivacyEngine, PrivacyMetrics, PrivacyParams, RoundMetrics, RoundStatus, SecurityConfig,
690 SecurityFeature, SecurityManager, TrainingConfig, VerificationEngine, VerificationMechanism,
691 VerificationResult, WeightingScheme,
692};
693#[cfg(feature = "gpu")]
694pub use gpu_acceleration::{
695 GpuAccelerationConfig, GpuAccelerationManager, GpuMemoryPool, GpuPerformanceStats,
696 MixedPrecisionProcessor, MultiStreamProcessor, TensorCache,
697};
698pub use graphql_api::{
699 create_schema, BatchEmbeddingInput, BatchEmbeddingResult, BatchStatus, DistanceMetric,
700 EmbeddingFormat, EmbeddingQueryInput, EmbeddingResult, EmbeddingSchema, GraphQLContext,
701 ModelInfo, ModelType, SimilarityResult, SimilaritySearchInput,
702};
703pub use kg_completion::{BatchedTrainingLoop, KgCompletionTask, NegativeSampler, TrainingBatch};
704pub use models::{
705 AggregationType, ComplEx, DistMult, GNNConfig, GNNEmbedding, GNNType, HoLE, HoLEConfig,
706 PoolingStrategy, RotatE, TransE, TransformerConfig, TransformerEmbedding, TransformerType,
707};
708
709pub use contextual::{
710 AccessibilityPreferences, ComplexityLevel, ContextualConfig, ContextualEmbeddingModel,
711 DomainContext, EmbeddingContext, PerformanceRequirements, PriorityLevel, PrivacySettings,
712 QueryContext, QueryType as ContextualQueryType, ResponseFormat, TaskConstraints, TaskContext,
713 TaskType, UserContext, UserHistory, UserPreferences,
714};
715pub use distributed_training::{
716 AggregationMethod, AllReduceStrategy, CommunicationBackend, DataParallelTrainer,
717 DistributedEmbeddingTrainer, DistributedStrategy, DistributedTrainingConfig,
718 DistributedTrainingCoordinator, DistributedTrainingSample, DistributedTrainingStats,
719 FaultToleranceConfig, GradientAggregator, GradientCompressor, ModelUpdate, SparseGradient,
720 WorkerInfo, WorkerStatus, WorkerUpdate,
721};
722#[cfg(feature = "conve")]
723pub use models::{ConvE, ConvEConfig};
724pub use monitoring::{
725 Alert, AlertSeverity, AlertThresholds, AlertType, CacheMetrics, ConsoleAlertHandler,
726 DriftMetrics, ErrorEvent, ErrorMetrics, ErrorSeverity, LatencyMetrics, MonitoringConfig,
727 PerformanceMetrics as MonitoringPerformanceMetrics, PerformanceMonitor, QualityAssessment,
728 QualityMetrics, ResourceMetrics, SlackAlertHandler, ThroughputMetrics,
729};
730pub use multimodal::{
731 AlignmentNetwork, AlignmentObjective, ContrastiveConfig, CrossDomainConfig, CrossModalConfig,
732 KGEncoder, MultiModalEmbedding, MultiModalStats, TextEncoder,
733};
734pub use neural_symbolic_integration::{
735 ConstraintSatisfactionConfig, ConstraintType, KnowledgeIntegrationConfig, KnowledgeRule,
736 LogicIntegrationConfig, LogicProgrammingConfig, LogicalFormula, NeuralSymbolicConfig,
737 NeuralSymbolicModel, NeuroSymbolicArchitectureConfig, OntologicalConfig, ReasoningEngine,
738 RuleBasedConfig, SymbolicReasoningConfig,
739};
740pub use novel_architectures::{
741 ActivationType, ArchitectureParams, ArchitectureState, ArchitectureType, CurvatureComputation,
742 CurvatureMethod, CurvatureType, DynamicsConfig, EntanglementStructure, EquivarianceGroup,
743 FlowType, GeometricConfig, GeometricParams, GeometricSpace, GeometricState,
744 GraphTransformerParams, GraphTransformerState, HyperbolicDistance, HyperbolicInit,
745 HyperbolicManifold, HyperbolicParams, HyperbolicState, IntegrationScheme, IntegrationStats,
746 ManifoldLearning, ManifoldMethod, ManifoldOptimizer, NeuralODEParams, NeuralODEState,
747 NovelArchitectureConfig, NovelArchitectureModel, ODERegularization, ODESolverType,
748 ParallelTransport, QuantumGateSet, QuantumMeasurement, QuantumNoise, QuantumParams,
749 QuantumState, StabilityConstraints, StructuralBias, TimeEvolution, TransportMethod,
750};
751pub use research_networks::{
752 AuthorEmbedding, Citation, CitationNetwork, CitationType, Collaboration, CollaborationNetwork,
753 NetworkMetrics, PaperSection, PublicationEmbedding, PublicationType, ResearchCommunity,
754 ResearchNetworkAnalyzer, ResearchNetworkConfig, TopicModel, TopicModelingConfig,
755};
756pub use sparql_extension::{
757 ExpandedQuery, Expansion, ExpansionType, QueryStatistics as SparqlQueryStatistics,
758 SparqlExtension, SparqlExtensionConfig,
759};
760pub use storage_backend::{
761 DiskBackend, EmbeddingMetadata, EmbeddingVersion, MemoryBackend, StorageBackend,
762 StorageBackendConfig, StorageBackendManager, StorageBackendType, StorageStats,
763};
764pub use temporal_embeddings::{
765 TemporalEmbeddingConfig, TemporalEmbeddingModel, TemporalEvent, TemporalForecast,
766 TemporalGranularity, TemporalScope, TemporalStats, TemporalTriple,
767};
768pub use vision_language_graph::{
769 AggregationFunction, CNNConfig, CrossAttentionConfig, DomainAdaptationConfig,
770 DomainAdaptationMethod, EpisodeConfig, FewShotConfig, FewShotMethod, FusionStrategy,
771 GraphArchitecture, GraphEncoder, GraphEncoderConfig, JointTrainingConfig, LanguageArchitecture,
772 LanguageEncoder, LanguageEncoderConfig, LanguageTransformerConfig, MetaLearner,
773 ModalityEncoding, MultiModalTransformer, MultiModalTransformerConfig, NormalizationType,
774 PoolingType, PositionEncodingType, ReadoutFunction, TaskCategory, TaskSpecificParams,
775 TrainingObjective, TransferLearningConfig, TransferStrategy, ViTConfig, VisionArchitecture,
776 VisionEncoder, VisionEncoderConfig, VisionLanguageGraphConfig, VisionLanguageGraphModel,
777 VisionLanguageGraphStats, ZeroShotConfig, ZeroShotMethod,
778};
779
780#[cfg(feature = "tucker")]
781pub use models::TuckER;
782
783#[cfg(feature = "quatd")]
784pub use models::QuatD;
785
786pub use crate::model_registry::{
788 ModelRegistry, ModelVersion, ResourceAllocation as ModelResourceAllocation,
789};
790
791pub use crate::model_selection::{
793 DatasetCharacteristics, MemoryRequirement, ModelComparison, ModelComparisonEntry,
794 ModelRecommendation, ModelSelector, ModelType as SelectionModelType, TrainingTime, UseCaseType,
795};
796
797pub use crate::performance_profiler::{
799 OperationStats, OperationTimer, OperationType, PerformanceProfiler, PerformanceReport,
800};
801
802pub mod quick_start {
816 use super::*;
817 use crate::models::TransE;
818
819 pub fn create_simple_transe_model() -> TransE {
821 let config = ModelConfig::default()
822 .with_dimensions(128)
823 .with_learning_rate(0.01)
824 .with_max_epochs(100);
825 TransE::new(config)
826 }
827
828 pub fn create_biomedical_model() -> BiomedicalEmbedding {
830 let config = BiomedicalEmbeddingConfig::default();
831 BiomedicalEmbedding::new(config)
832 }
833
834 pub fn parse_triple_from_string(triple_str: &str) -> Result<Triple> {
836 let parts: Vec<&str> = triple_str.split_whitespace().collect();
837 if parts.len() != 3 {
838 return Err(anyhow::anyhow!(
839 "Triple must have exactly 3 parts separated by spaces"
840 ));
841 }
842
843 let expand_uri = |s: &str| -> String {
845 if s.starts_with("http://") || s.starts_with("https://") {
846 s.to_string()
847 } else {
848 format!("http://example.org/{s}")
849 }
850 };
851
852 Ok(Triple::new(
853 NamedNode::new(&expand_uri(parts[0]))?,
854 NamedNode::new(&expand_uri(parts[1]))?,
855 NamedNode::new(&expand_uri(parts[2]))?,
856 ))
857 }
858
859 pub fn add_triples_from_strings<T: EmbeddingModel>(
861 model: &mut T,
862 triple_strings: &[&str],
863 ) -> Result<usize> {
864 let mut count = 0;
865 for triple_str in triple_strings {
866 let triple = parse_triple_from_string(triple_str)?;
867 model.add_triple(triple)?;
868 count += 1;
869 }
870 Ok(count)
871 }
872
873 pub fn cosine_similarity(a: &[f64], b: &[f64]) -> Result<f64> {
875 if a.len() != b.len() {
876 return Err(anyhow::anyhow!(
877 "Vector dimensions don't match: {} vs {}",
878 a.len(),
879 b.len()
880 ));
881 }
882
883 let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
884 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
885 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
886
887 if norm_a == 0.0 || norm_b == 0.0 {
888 return Ok(0.0);
889 }
890
891 Ok(dot_product / (norm_a * norm_b))
892 }
893
894 pub fn generate_sample_kg_data(
896 num_entities: usize,
897 num_relations: usize,
898 ) -> Vec<(String, String, String)> {
899 #[allow(unused_imports)]
900 use scirs2_core::random::{Random, Rng};
901
902 let mut random = Random::default();
903 let mut triples = Vec::new();
904
905 let entities: Vec<String> = (0..num_entities)
906 .map(|i| format!("http://example.org/entity_{i}"))
907 .collect();
908
909 let relations: Vec<String> = (0..num_relations)
910 .map(|i| format!("http://example.org/relation_{i}"))
911 .collect();
912
913 for _ in 0..(num_entities * 2) {
915 let subject_idx = random.random_range(0..entities.len());
916 let relation_idx = random.random_range(0..relations.len());
917 let object_idx = random.random_range(0..entities.len());
918
919 let subject = entities[subject_idx].clone();
920 let relation = relations[relation_idx].clone();
921 let object = entities[object_idx].clone();
922
923 if subject != object {
924 triples.push((subject, relation, object));
925 }
926 }
927
928 triples
929 }
930
/// Runs `operation` `iterations` times in a row, prints a throughput summary
/// to stdout and returns the total elapsed wall-clock time.
///
/// `name` only labels the printed line.
pub fn quick_performance_test<F>(
    name: &str,
    iterations: usize,
    operation: F,
) -> std::time::Duration
where
    F: Fn(),
{
    let started_at = std::time::Instant::now();
    (0..iterations).for_each(|_| operation());
    let duration = started_at.elapsed();

    let ops_per_sec = iterations as f64 / duration.as_secs_f64();
    println!(
        "Performance test '{name}': {iterations} iterations in {duration:?} ({:.2} ops/sec)",
        ops_per_sec
    );

    duration
}
953
/// Async counterpart of `quick_performance_test`: awaits the future produced
/// by `async_operation` `iterations` times, strictly one after another,
/// prints a throughput summary to stdout and returns the total elapsed
/// wall-clock time.
///
/// `name` only labels the printed line.
pub async fn quick_revolutionary_performance_test<F, Fut>(
    name: &str,
    iterations: usize,
    async_operation: F,
) -> std::time::Duration
where
    F: Fn() -> Fut,
    Fut: std::future::Future<Output = ()>,
{
    let started_at = std::time::Instant::now();
    let mut remaining = iterations;
    while remaining > 0 {
        async_operation().await;
        remaining -= 1;
    }
    let duration = started_at.elapsed();

    let ops_per_sec = iterations as f64 / duration.as_secs_f64();
    println!(
        "Revolutionary performance test '{name}': {iterations} iterations in {duration:?} ({:.2} ops/sec)",
        ops_per_sec
    );

    duration
}
1008}
1009
/// Unit tests for the helpers in `quick_start`.
#[cfg(test)]
mod quick_start_tests {
    use super::*;
    use crate::quick_start::*;

    /// The convenience constructor applies its three overrides.
    #[test]
    fn test_create_simple_transe_model() {
        let transe = create_simple_transe_model();
        let cfg = transe.config();
        assert_eq!(cfg.dimensions, 128);
        assert_eq!(cfg.learning_rate, 0.01);
        assert_eq!(cfg.max_epochs, 100);
    }

    /// Three absolute IRIs are parsed into subject, predicate and object.
    #[test]
    fn test_parse_triple_from_string() {
        let input = "http://example.org/alice http://example.org/knows http://example.org/bob";
        let parsed = parse_triple_from_string(input).unwrap();
        assert_eq!(parsed.subject.iri, "http://example.org/alice");
        assert_eq!(parsed.predicate.iri, "http://example.org/knows");
        assert_eq!(parsed.object.iri, "http://example.org/bob");
    }

    /// Inputs with fewer than three terms are rejected.
    #[test]
    fn test_parse_triple_from_string_invalid() {
        let two_terms = "http://example.org/alice http://example.org/knows";
        assert!(parse_triple_from_string(two_terms).is_err());
    }

    /// Every valid string is added and counted.
    #[test]
    fn test_add_triples_from_strings() {
        let mut transe = create_simple_transe_model();
        let inputs = [
            "http://example.org/alice http://example.org/knows http://example.org/bob",
            "http://example.org/bob http://example.org/likes http://example.org/music",
        ];

        let added = add_triples_from_strings(&mut transe, &inputs).unwrap();
        assert_eq!(added, 2);
    }

    /// Identical vectors score 1, orthogonal ones 0, and a length mismatch
    /// is an error.
    #[test]
    fn test_cosine_similarity() {
        let x_axis = [1.0, 0.0, 0.0];
        let same_axis = [1.0, 0.0, 0.0];
        let identical = cosine_similarity(&x_axis, &same_axis).unwrap();
        assert!((identical - 1.0).abs() < 1e-10);

        let y_axis = [0.0, 1.0, 0.0];
        let orthogonal = cosine_similarity(&x_axis, &y_axis).unwrap();
        assert!((orthogonal - 0.0).abs() < 1e-10);

        let shorter = [1.0, 0.0];
        assert!(cosine_similarity(&x_axis, &shorter).is_err());
    }

    /// Generated triples use the expected IRI prefixes and never link an
    /// entity to itself.
    #[test]
    fn test_generate_sample_kg_data() {
        let generated = generate_sample_kg_data(5, 3);
        assert!(!generated.is_empty());

        for (subject, relation, object) in generated.iter() {
            assert!(subject.starts_with("http://example.org/entity_"));
            assert!(relation.starts_with("http://example.org/relation_"));
            assert!(object.starts_with("http://example.org/entity_"));
            assert_ne!(subject, object);
        }
    }

    /// Smoke test: the timing helper runs and returns a duration.
    #[test]
    fn test_quick_performance_test() {
        let elapsed = quick_performance_test("test_operation", 100, || {
            let _sum: i32 = (1..10).sum();
        });

        let _nanos = elapsed.as_nanos();
    }
}