1#![allow(dead_code)]
122
123pub mod ab_testing;
124pub mod acceleration;
125pub mod adaptive_learning;
126pub mod advanced_profiler;
127pub mod alignment;
128#[cfg(feature = "api-server")]
129pub mod api;
130pub mod application_tasks;
131pub mod batch_processing;
132pub mod biomedical_embeddings;
133pub mod caching;
134pub mod causal_representation_learning;
135pub mod cloud_integration;
136pub mod clustering;
137pub mod community_detection;
138pub mod compression;
139pub mod contextual;
140pub mod continual_learning;
141pub mod cross_domain_transfer;
142pub mod cross_module_performance;
143pub mod delta;
144pub mod diffusion_embeddings;
145pub mod distributed_training;
146pub mod embed_compression;
147pub mod enterprise_knowledge;
148pub mod entity_linking;
149pub mod evaluation;
150pub mod federated_learning;
151pub mod fine_tuning;
152#[cfg(feature = "gpu")]
153pub mod gpu_acceleration;
154pub mod graph_models;
155pub mod graphql_api;
156pub mod inference;
157pub mod integration;
158pub mod interpretability;
159pub mod kg_completion;
160pub mod link_prediction;
161pub mod mamba_attention;
162pub mod mixed_precision;
163pub mod model_registry;
164pub mod model_selection;
165pub mod models;
166pub mod monitoring;
167pub mod multimodal;
168pub mod neural_symbolic_integration;
169pub mod neuro_evolution;
170pub mod novel_architectures;
171pub mod performance_profiler;
172pub mod persistence;
173pub mod quantization;
174pub mod real_time_fine_tuning;
175pub mod real_time_optimization;
176pub mod research_networks;
177pub mod sparql_extension;
179pub mod storage_backend;
180pub mod temporal_embeddings;
181pub mod training;
182pub mod training_online;
183pub mod utils;
184pub mod validation;
185pub mod vector_search;
186pub mod vision_language_graph;
187pub mod visualization;
188pub mod contrastive_learning;
190
191pub mod procrustes_alignment;
193
194pub mod embedding_cache;
196
197pub mod dimensionality_reducer;
199
200pub mod pca_reducer;
202
203pub mod fine_tuner;
205
206pub mod vector_store;
208
209pub mod cross_encoder;
211
212pub mod projection_layer;
214pub use projection_layer::{ActivationFn, InitMethod, ProjectionLayer, ProjectionMatrix};
215
216pub mod embedding_store;
218
219pub mod tokenizer;
221
222pub mod embedding_aggregator;
224
225pub mod reranker;
227
228pub mod index_optimizer;
230
231pub mod batch_encoder;
233
234pub mod ensemble;
236
237pub mod embedding_compressor;
239
240pub mod model_zoo;
242pub use model_zoo::{sha256_hex, ModelManifest, ModelZoo, ModelZooError, ModelZooLoader};
243
244pub use oxirs_vec::Vector as VecVector;
246
247pub use adaptive_learning::{
249 AdaptationMetrics, AdaptationStrategy, AdaptiveLearningConfig, AdaptiveLearningSystem,
250 QualityFeedback,
251};
252
253use anyhow::Result;
254use chrono::{DateTime, Utc};
255use serde::{Deserialize, Serialize};
256use std::collections::HashMap;
257use std::ops::{Add, Sub};
258use uuid::Uuid;
259
260#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
263pub struct Vector {
264 pub values: Vec<f32>,
265 pub dimensions: usize,
266 #[serde(skip)]
267 inner: Option<VecVector>,
268}
269
270impl Vector {
271 pub fn new(values: Vec<f32>) -> Self {
272 let dimensions = values.len();
273 Self {
274 values,
275 dimensions,
276 inner: None,
277 }
278 }
279
280 fn get_inner(&self) -> VecVector {
282 if let Some(ref inner) = self.inner {
284 inner.clone()
285 } else {
286 VecVector::new(self.values.clone())
287 }
288 }
289
290 fn sync_internal(&mut self) {
292 self.dimensions = self.values.len();
293 self.inner = None; }
295
296 pub fn from_array1(array: &scirs2_core::ndarray_ext::Array1<f32>) -> Self {
298 Self::new(array.to_vec())
299 }
300
301 pub fn to_array1(&self) -> scirs2_core::ndarray_ext::Array1<f32> {
303 scirs2_core::ndarray_ext::Array1::from_vec(self.values.clone())
304 }
305
306 pub fn mapv<F>(&self, f: F) -> Self
308 where
309 F: Fn(f32) -> f32,
310 {
311 Self::new(self.values.iter().copied().map(f).collect())
312 }
313
314 pub fn sum(&self) -> f32 {
316 self.values.iter().sum()
317 }
318
319 pub fn sqrt(&self) -> f32 {
321 self.sum().sqrt()
322 }
323
324 pub fn inner(&self) -> VecVector {
326 self.get_inner()
327 }
328
329 pub fn into_inner(self) -> VecVector {
331 self.inner.unwrap_or_else(|| VecVector::new(self.values))
332 }
333
334 pub fn from_vec_vector(vec_vector: VecVector) -> Self {
336 let values = vec_vector.as_f32().to_vec();
337 let dimensions = values.len();
338 Self {
339 values,
340 dimensions,
341 inner: Some(vec_vector),
342 }
343 }
344
345 pub fn with_capacity(capacity: usize) -> Self {
347 Self {
348 values: Vec::with_capacity(capacity),
349 dimensions: 0,
350 inner: None,
351 }
352 }
353
354 pub fn extend_optimized(&mut self, other_values: &[f32]) {
356 self.values.reserve(other_values.len());
358 self.values.extend_from_slice(other_values);
359 self.sync_internal();
360 }
361
362 pub fn shrink_to_fit(&mut self) {
364 self.values.shrink_to_fit();
365 self.sync_internal();
366 }
367
368 pub fn memory_usage(&self) -> usize {
370 self.values.capacity() * std::mem::size_of::<f32>() + std::mem::size_of::<Self>()
371 }
372}
373
374impl Add for &Vector {
376 type Output = Vector;
377
378 fn add(self, other: &Vector) -> Vector {
379 if let (Some(self_inner), Some(other_inner)) = (&self.inner, &other.inner) {
381 if let Ok(result) = self_inner.add(other_inner) {
382 return Vector::from_vec_vector(result);
383 }
384 }
385 assert_eq!(
387 self.values.len(),
388 other.values.len(),
389 "Vector dimensions must match"
390 );
391 let result_values: Vec<f32> = self
392 .values
393 .iter()
394 .zip(other.values.iter())
395 .map(|(a, b)| a + b)
396 .collect();
397 Vector::new(result_values)
398 }
399}
400
401impl Sub for &Vector {
402 type Output = Vector;
403
404 fn sub(self, other: &Vector) -> Vector {
405 if let (Some(self_inner), Some(other_inner)) = (&self.inner, &other.inner) {
407 if let Ok(result) = self_inner.subtract(other_inner) {
408 return Vector::from_vec_vector(result);
409 }
410 }
411 assert_eq!(
413 self.values.len(),
414 other.values.len(),
415 "Vector dimensions must match"
416 );
417 let result_values: Vec<f32> = self
418 .values
419 .iter()
420 .zip(other.values.iter())
421 .map(|(a, b)| a - b)
422 .collect();
423 Vector::new(result_values)
424 }
425}
426
427#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
429pub struct Triple {
430 pub subject: NamedNode,
431 pub predicate: NamedNode,
432 pub object: NamedNode,
433}
434
435impl Triple {
436 pub fn new(subject: NamedNode, predicate: NamedNode, object: NamedNode) -> Self {
437 Self {
438 subject,
439 predicate,
440 object,
441 }
442 }
443}
444
445#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
447pub struct NamedNode {
448 pub iri: String,
449}
450
451impl NamedNode {
452 pub fn new(iri: &str) -> Result<Self> {
453 Ok(Self {
454 iri: iri.to_string(),
455 })
456 }
457}
458
459impl std::fmt::Display for NamedNode {
460 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
461 write!(f, "{}", self.iri)
462 }
463}
464
465#[derive(Debug, Clone, Serialize, Deserialize)]
467pub struct ModelConfig {
468 pub dimensions: usize,
469 pub learning_rate: f64,
470 pub l2_reg: f64,
471 pub max_epochs: usize,
472 pub batch_size: usize,
473 pub negative_samples: usize,
474 pub seed: Option<u64>,
475 pub use_gpu: bool,
476 pub model_params: HashMap<String, f64>,
477}
478
479impl Default for ModelConfig {
480 fn default() -> Self {
481 Self {
482 dimensions: 100,
483 learning_rate: 0.01,
484 l2_reg: 0.0001,
485 max_epochs: 1000,
486 batch_size: 1000,
487 negative_samples: 10,
488 seed: None,
489 use_gpu: false,
490 model_params: HashMap::new(),
491 }
492 }
493}
494
495impl ModelConfig {
496 pub fn with_dimensions(mut self, dimensions: usize) -> Self {
497 self.dimensions = dimensions;
498 self
499 }
500
501 pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
502 self.learning_rate = learning_rate;
503 self
504 }
505
506 pub fn with_max_epochs(mut self, max_epochs: usize) -> Self {
507 self.max_epochs = max_epochs;
508 self
509 }
510
511 pub fn with_seed(mut self, seed: u64) -> Self {
512 self.seed = Some(seed);
513 self
514 }
515
516 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
517 self.batch_size = batch_size;
518 self
519 }
520}
521
/// Summary returned by [`EmbeddingModel::train`].
///
/// The field descriptions below follow the field names; how each value is
/// computed is up to the individual model implementations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingStats {
    /// Number of epochs that were completed.
    pub epochs_completed: usize,
    /// Loss value at the end of training.
    pub final_loss: f64,
    /// Training duration in seconds.
    pub training_time_seconds: f64,
    /// Whether the model reported convergence.
    pub convergence_achieved: bool,
    /// Recorded loss values; presumably one per epoch — confirm with
    /// implementors.
    pub loss_history: Vec<f64>,
}
531
532#[derive(Debug, Clone, Serialize, Deserialize)]
534pub struct ModelStats {
535 pub num_entities: usize,
536 pub num_relations: usize,
537 pub num_triples: usize,
538 pub dimensions: usize,
539 pub is_trained: bool,
540 pub model_type: String,
541 pub creation_time: DateTime<Utc>,
542 pub last_training_time: Option<DateTime<Utc>>,
543}
544
545impl Default for ModelStats {
546 fn default() -> Self {
547 Self {
548 num_entities: 0,
549 num_relations: 0,
550 num_triples: 0,
551 dimensions: 0,
552 is_trained: false,
553 model_type: "unknown".to_string(),
554 creation_time: Utc::now(),
555 last_training_time: None,
556 }
557 }
558}
559
/// Errors specific to embedding operations.
///
/// `Display` text comes from the `#[error(...)]` attribute on each variant.
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingError {
    /// Displayed as "Model not trained".
    #[error("Model not trained")]
    ModelNotTrained,
    /// The named entity was not found; `entity` is included in the message.
    #[error("Entity not found: {entity}")]
    EntityNotFound { entity: String },
    /// The named relation was not found; `relation` is included in the message.
    #[error("Relation not found: {relation}")]
    RelationNotFound { relation: String },
    /// Wraps any other `anyhow::Error`; `#[from]` lets `?` convert into it.
    #[error("Other error: {0}")]
    Other(#[from] anyhow::Error),
}
572
/// Common interface implemented by the embedding models in this crate.
///
/// Implementors must be `Send + Sync`. The `async` methods are supported via
/// `async_trait`. The notes on individual methods describe what the
/// signatures show; exact semantics (score direction, ordering of
/// predictions, file formats) are defined by each implementor.
#[async_trait::async_trait]
pub trait EmbeddingModel: Send + Sync {
    /// Borrows the model's configuration.
    fn config(&self) -> &ModelConfig;
    /// Borrows the model's identifier.
    fn model_id(&self) -> &Uuid;
    /// Static name of the model type.
    fn model_type(&self) -> &'static str;
    /// Adds a triple to the model.
    fn add_triple(&mut self, triple: Triple) -> Result<()>;
    /// Trains the model and reports statistics. `epochs` is optional;
    /// presumably `None` falls back to the configured epoch count — confirm
    /// with implementors.
    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats>;
    /// Looks up the embedding of `entity`.
    fn get_entity_embedding(&self, entity: &str) -> Result<Vector>;
    /// Looks up the embedding of `relation`.
    fn get_relation_embedding(&self, relation: &str) -> Result<Vector>;
    /// Scores the triple `(subject, predicate, object)`.
    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64>;
    /// Returns `(object, score)` candidates for `(subject, predicate)`;
    /// `k` presumably caps the number of results.
    fn predict_objects(
        &self,
        subject: &str,
        predicate: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>>;
    /// Returns `(subject, score)` candidates for `(predicate, object)`;
    /// `k` presumably caps the number of results.
    fn predict_subjects(
        &self,
        predicate: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>>;
    /// Returns `(relation, score)` candidates for `(subject, object)`;
    /// `k` presumably caps the number of results.
    fn predict_relations(
        &self,
        subject: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>>;
    /// Lists the entities as strings.
    fn get_entities(&self) -> Vec<String>;
    /// Lists the relations as strings.
    fn get_relations(&self) -> Vec<String>;
    /// Returns a statistics snapshot for the model.
    fn get_stats(&self) -> ModelStats;
    /// Saves the model to `path`.
    fn save(&self, path: &str) -> Result<()>;
    /// Loads the model from `path` into `self`.
    fn load(&mut self, path: &str) -> Result<()>;
    /// Clears the model's state.
    fn clear(&mut self);
    /// Reports whether the model has been trained.
    fn is_trained(&self) -> bool;

    /// Encodes each input text into an `f32` vector, returning one vector per
    /// result entry.
    async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
}
613
614pub use acceleration::{AdaptiveEmbeddingAccelerator, GpuEmbeddingAccelerator};
616#[cfg(feature = "api-server")]
617pub use api::{start_server, ApiConfig, ApiState};
618pub use batch_processing::{
619 BatchJob, BatchProcessingConfig, BatchProcessingManager, BatchProcessingResult,
620 BatchProcessingStats, IncrementalConfig, JobProgress, JobStatus, OutputFormat,
621 PartitioningStrategy, RetryConfig,
622};
623pub use biomedical_embeddings::{
624 BiomedicalEmbedding, BiomedicalEmbeddingConfig, BiomedicalEntityType, BiomedicalRelationType,
625 FineTuningConfig, PreprocessingRule, SpecializedTextConfig, SpecializedTextEmbedding,
626 SpecializedTextModel,
627};
628pub use caching::{CacheConfig, CacheManager, CachedEmbeddingModel};
629pub use causal_representation_learning::{
630 CausalDiscoveryAlgorithm, CausalDiscoveryConfig, CausalGraph, CausalRepresentationConfig,
631 CausalRepresentationModel, ConstraintSettings, CounterfactualConfig, CounterfactualQuery,
632 DisentanglementConfig, DisentanglementMethod, ExplanationType, IndependenceTest,
633 InterventionConfig, ScoreSettings, StructuralCausalModelConfig,
634};
635pub use cloud_integration::{
636 AWSSageMakerService, AutoScalingConfig, AzureMLService, BackupConfig, CloudIntegrationConfig,
637 CloudIntegrationManager, CloudProvider, CloudService, ClusterStatus, CostEstimate,
638 CostOptimizationResult, CostOptimizationStrategy, DeploymentConfig, DeploymentResult,
639 DeploymentStatus, EndpointInfo, FunctionInvocationResult, GPUClusterConfig, GPUClusterResult,
640 LifecyclePolicy, OptimizationAction, PerformanceTier, ReplicationType,
641 ServerlessDeploymentResult, ServerlessFunctionConfig, ServerlessStatus, StorageConfig,
642 StorageResult, StorageStatus, StorageType,
643};
644pub use compression::{
645 CompressedModel, CompressionStats, CompressionTarget, DistillationConfig,
646 ModelCompressionManager, NASConfig, OptimizationTarget, PruningConfig, PruningMethod,
647 QuantizationConfig, QuantizationMethod,
648};
649pub use continual_learning::{
653 ArchitectureConfig, BoundaryDetection, ConsolidationConfig, ContinualLearningConfig,
654 ContinualLearningModel, MemoryConfig, MemoryType, MemoryUpdateStrategy, RegularizationConfig,
655 ReplayConfig, ReplayMethod, TaskConfig, TaskDetection, TaskSwitching,
656};
657pub use cross_module_performance::{
658 CoordinatorConfig, CrossModulePerformanceCoordinator, GlobalPerformanceMetrics, ModuleMetrics,
659 ModulePerformanceMonitor, OptimizationCache, PerformanceSnapshot, PredictivePerformanceEngine,
660 ResourceAllocator, ResourceTracker,
661};
662pub use delta::{
663 ChangeRecord, ChangeStatistics, ChangeType, DeltaConfig, DeltaManager, DeltaResult, DeltaStats,
664 IncrementalStrategy,
665};
666pub use enterprise_knowledge::{
667 BehaviorMetrics, CareerPredictions, Category, CategoryHierarchy, CategoryPerformance,
668 ColdStartStrategy, CommunicationFrequency, CommunicationPreferences, CustomerEmbedding,
669 CustomerPreferences, CustomerRatings, CustomerSegment, Department, DepartmentPerformance,
670 EmployeeEmbedding, EnterpriseConfig, EnterpriseKnowledgeAnalyzer, EnterpriseMetrics,
671 ExperienceLevel, FeatureType, MarketAnalysis, OrganizationalStructure,
672 PerformanceMetrics as EnterprisePerformanceMetrics, ProductAvailability, ProductEmbedding,
673 ProductFeature, ProductRecommendation, Project, ProjectOutcome, ProjectParticipation,
674 ProjectPerformance, ProjectStatus, Purchase, PurchaseChannel, RecommendationConfig,
675 RecommendationEngine, RecommendationEngineType, RecommendationPerformance,
676 RecommendationReason, SalesMetrics, Skill, SkillCategory, Team, TeamPerformance,
677};
678pub use evaluation::{
679 AnalogicalReasoningBenchmark, AnalogyQuad, EmbeddingClusteringMetrics, EmbeddingEvaluator,
680 QueryAnsweringEvaluator, QueryEvaluationConfig, QueryEvaluationResults, QueryMetric,
681 QueryResult, QueryTemplate, QueryType, ReasoningChain, ReasoningEvaluationConfig,
682 ReasoningEvaluationResults, ReasoningRule, ReasoningStep, ReasoningTaskEvaluator,
683 ReasoningType, TypeSpecificResults,
684};
685pub use federated_learning::{
686 AggregationEngine, AggregationStrategy, AuthenticationConfig, AuthenticationMethod,
687 CertificateConfig, ClippingMechanisms, ClippingMethod, CommunicationConfig,
688 CommunicationManager, CommunicationProtocol, CompressionAlgorithm, CompressionConfig,
689 CompressionEngine, ConvergenceMetrics, ConvergenceStatus, DataSelectionStrategy,
690 DataStatistics, EncryptionScheme, FederatedConfig, FederatedCoordinator,
691 FederatedEmbeddingModel, FederatedMessage, FederatedRound, FederationStats, GlobalModelState,
692 HardwareAccelerator, KeyManager, LocalModelState, LocalTrainingStats, LocalUpdate,
693 MetaLearningConfig, NoiseGenerator, NoiseMechanism, OutlierAction, OutlierDetection,
694 OutlierDetectionMethod, Participant, ParticipantCapabilities, ParticipantStatus,
695 PersonalizationConfig, PersonalizationStrategy, PrivacyAccountant, PrivacyConfig,
696 PrivacyEngine, PrivacyMetrics, PrivacyParams, RoundMetrics, RoundStatus, SecurityConfig,
697 SecurityFeature, SecurityManager, TrainingConfig, VerificationEngine, VerificationMechanism,
698 VerificationResult, WeightingScheme,
699};
700#[cfg(feature = "gpu")]
701pub use gpu_acceleration::{
702 GpuAccelerationConfig, GpuAccelerationManager, GpuMemoryPool, GpuPerformanceStats,
703 MixedPrecisionProcessor, MultiStreamProcessor, TensorCache,
704};
705pub use graphql_api::{
706 create_schema, BatchEmbeddingInput, BatchEmbeddingResult, BatchStatus, DistanceMetric,
707 EmbeddingFormat, EmbeddingQueryInput, EmbeddingResult, EmbeddingSchema, GraphQLContext,
708 ModelInfo, ModelType, SimilarityResult, SimilaritySearchInput,
709};
710pub use kg_completion::{BatchedTrainingLoop, KgCompletionTask, NegativeSampler, TrainingBatch};
711pub use models::{
712 AggregationType, ComplEx, DistMult, GNNConfig, GNNEmbedding, GNNType, HoLE, HoLEConfig,
713 PoolingStrategy, RotatE, TransE, TransformerConfig, TransformerEmbedding, TransformerType,
714};
715
716pub use contextual::{
717 AccessibilityPreferences, ComplexityLevel, ContextualConfig, ContextualEmbeddingModel,
718 DomainContext, EmbeddingContext, PerformanceRequirements, PriorityLevel, PrivacySettings,
719 QueryContext, QueryType as ContextualQueryType, ResponseFormat, TaskConstraints, TaskContext,
720 TaskType, UserContext, UserHistory, UserPreferences,
721};
722pub use distributed_training::{
723 AggregationMethod, AllReduceStrategy, CommunicationBackend, DataParallelTrainer,
724 DistributedEmbeddingTrainer, DistributedStrategy, DistributedTrainingConfig,
725 DistributedTrainingCoordinator, DistributedTrainingSample, DistributedTrainingStats,
726 FaultToleranceConfig, GradientAggregator, GradientCompressor, ModelShardManager, ModelUpdate,
727 ParameterServer, ParameterServerConfig, ParameterServerStats, ShardAssignment, ShardSnapshot,
728 ShardingStrategy, SparseGradient, TripleSample, UpdateMode, Worker, WorkerConfig, WorkerInfo,
729 WorkerLoss, WorkerStatus, WorkerUpdate,
730};
731#[cfg(feature = "conve")]
732pub use models::{ConvE, ConvEConfig};
733pub use monitoring::{
734 Alert, AlertSeverity, AlertThresholds, AlertType, CacheMetrics, ConsoleAlertHandler,
735 DriftMetrics, ErrorEvent, ErrorMetrics, ErrorSeverity, LatencyMetrics, MonitoringConfig,
736 PerformanceMetrics as MonitoringPerformanceMetrics, PerformanceMonitor, QualityAssessment,
737 QualityMetrics, ResourceMetrics, SlackAlertHandler, ThroughputMetrics,
738};
739pub use multimodal::{
740 AlignmentNetwork, AlignmentObjective, ContrastiveConfig, CrossDomainConfig, CrossModalConfig,
741 KGEncoder, MultiModalEmbedding, MultiModalStats, TextEncoder,
742};
743pub use neural_symbolic_integration::{
744 ConstraintSatisfactionConfig, ConstraintType, KnowledgeIntegrationConfig, KnowledgeRule,
745 LogicIntegrationConfig, LogicProgrammingConfig, LogicalFormula, NeuralSymbolicConfig,
746 NeuralSymbolicModel, NeuroSymbolicArchitectureConfig, OntologicalConfig, ReasoningEngine,
747 RuleBasedConfig, SymbolicReasoningConfig,
748};
749pub use novel_architectures::{
750 ActivationType, ArchitectureParams, ArchitectureState, ArchitectureType, CurvatureComputation,
751 CurvatureMethod, CurvatureType, DynamicsConfig, EntanglementStructure, EquivarianceGroup,
752 FlowType, GeometricConfig, GeometricParams, GeometricSpace, GeometricState,
753 GraphTransformerParams, GraphTransformerState, HyperbolicDistance, HyperbolicInit,
754 HyperbolicManifold, HyperbolicParams, HyperbolicState, IntegrationScheme, IntegrationStats,
755 ManifoldLearning, ManifoldMethod, ManifoldOptimizer, NeuralODEParams, NeuralODEState,
756 NovelArchitectureConfig, NovelArchitectureModel, ODERegularization, ODESolverType,
757 ParallelTransport, QuantumGateSet, QuantumMeasurement, QuantumNoise, QuantumParams,
758 QuantumState, StabilityConstraints, StructuralBias, TimeEvolution, TransportMethod,
759};
760pub use research_networks::{
761 AuthorEmbedding, Citation, CitationNetwork, CitationType, Collaboration, CollaborationNetwork,
762 NetworkMetrics, PaperSection, PublicationEmbedding, PublicationType, ResearchCommunity,
763 ResearchNetworkAnalyzer, ResearchNetworkConfig, TopicModel, TopicModelingConfig,
764};
765pub use sparql_extension::{
766 ExpandedQuery, Expansion, ExpansionType, QueryStatistics as SparqlQueryStatistics,
767 SparqlExtension, SparqlExtensionConfig,
768};
769pub use storage_backend::{
770 DiskBackend, EmbeddingMetadata, EmbeddingVersion, MemoryBackend, StorageBackend,
771 StorageBackendConfig, StorageBackendManager, StorageBackendType, StorageStats,
772};
773pub use temporal_embeddings::{
774 TemporalEmbeddingConfig, TemporalEmbeddingModel, TemporalEvent, TemporalForecast,
775 TemporalGranularity, TemporalScope, TemporalStats, TemporalTriple,
776};
777pub use vision_language_graph::{
778 AggregationFunction, CNNConfig, CrossAttentionConfig, DomainAdaptationConfig,
779 DomainAdaptationMethod, EpisodeConfig, FewShotConfig, FewShotMethod, FusionStrategy,
780 GraphArchitecture, GraphEncoder, GraphEncoderConfig, JointTrainingConfig, LanguageArchitecture,
781 LanguageEncoder, LanguageEncoderConfig, LanguageTransformerConfig, MetaLearner,
782 ModalityEncoding, MultiModalTransformer, MultiModalTransformerConfig, NormalizationType,
783 PoolingType, PositionEncodingType, ReadoutFunction, TaskCategory, TaskSpecificParams,
784 TrainingObjective, TransferLearningConfig, TransferStrategy, ViTConfig, VisionArchitecture,
785 VisionEncoder, VisionEncoderConfig, VisionLanguageGraphConfig, VisionLanguageGraphModel,
786 VisionLanguageGraphStats, ZeroShotConfig, ZeroShotMethod,
787};
788
789#[cfg(feature = "tucker")]
790pub use models::TuckER;
791
792#[cfg(feature = "quatd")]
793pub use models::QuatD;
794
795pub use crate::model_registry::{
797 ModelRegistry, ModelVersion, ResourceAllocation as ModelResourceAllocation,
798};
799
800pub use crate::model_selection::{
802 DatasetCharacteristics, MemoryRequirement, ModelComparison, ModelComparisonEntry,
803 ModelRecommendation, ModelSelector, ModelType as SelectionModelType, TrainingTime, UseCaseType,
804};
805
806pub use crate::performance_profiler::{
808 OperationStats, OperationTimer, OperationType, PerformanceProfiler, PerformanceReport,
809};
810
811pub mod quick_start {
825 use super::*;
826 use crate::models::TransE;
827
828 pub fn create_simple_transe_model() -> TransE {
830 let config = ModelConfig::default()
831 .with_dimensions(128)
832 .with_learning_rate(0.01)
833 .with_max_epochs(100);
834 TransE::new(config)
835 }
836
837 pub fn create_biomedical_model() -> BiomedicalEmbedding {
839 let config = BiomedicalEmbeddingConfig::default();
840 BiomedicalEmbedding::new(config)
841 }
842
843 pub fn parse_triple_from_string(triple_str: &str) -> Result<Triple> {
845 let parts: Vec<&str> = triple_str.split_whitespace().collect();
846 if parts.len() != 3 {
847 return Err(anyhow::anyhow!(
848 "Triple must have exactly 3 parts separated by spaces"
849 ));
850 }
851
852 let expand_uri = |s: &str| -> String {
854 if s.starts_with("http://") || s.starts_with("https://") {
855 s.to_string()
856 } else {
857 format!("http://example.org/{s}")
858 }
859 };
860
861 Ok(Triple::new(
862 NamedNode::new(&expand_uri(parts[0]))?,
863 NamedNode::new(&expand_uri(parts[1]))?,
864 NamedNode::new(&expand_uri(parts[2]))?,
865 ))
866 }
867
868 pub fn add_triples_from_strings<T: EmbeddingModel>(
870 model: &mut T,
871 triple_strings: &[&str],
872 ) -> Result<usize> {
873 let mut count = 0;
874 for triple_str in triple_strings {
875 let triple = parse_triple_from_string(triple_str)?;
876 model.add_triple(triple)?;
877 count += 1;
878 }
879 Ok(count)
880 }
881
882 pub fn cosine_similarity(a: &[f64], b: &[f64]) -> Result<f64> {
884 if a.len() != b.len() {
885 return Err(anyhow::anyhow!(
886 "Vector dimensions don't match: {} vs {}",
887 a.len(),
888 b.len()
889 ));
890 }
891
892 let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
893 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
894 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
895
896 if norm_a == 0.0 || norm_b == 0.0 {
897 return Ok(0.0);
898 }
899
900 Ok(dot_product / (norm_a * norm_b))
901 }
902
903 pub fn generate_sample_kg_data(
905 num_entities: usize,
906 num_relations: usize,
907 ) -> Vec<(String, String, String)> {
908 #[allow(unused_imports)]
909 use scirs2_core::random::{Random, RngExt};
910
911 let mut random = Random::default();
912 let mut triples = Vec::new();
913
914 let entities: Vec<String> = (0..num_entities)
915 .map(|i| format!("http://example.org/entity_{i}"))
916 .collect();
917
918 let relations: Vec<String> = (0..num_relations)
919 .map(|i| format!("http://example.org/relation_{i}"))
920 .collect();
921
922 for _ in 0..(num_entities * 2) {
924 let subject_idx = random.random_range(0..entities.len());
925 let relation_idx = random.random_range(0..relations.len());
926 let object_idx = random.random_range(0..entities.len());
927
928 let subject = entities[subject_idx].clone();
929 let relation = relations[relation_idx].clone();
930 let object = entities[object_idx].clone();
931
932 if subject != object {
933 triples.push((subject, relation, object));
934 }
935 }
936
937 triples
938 }
939
940 pub fn quick_performance_test<F>(
942 name: &str,
943 iterations: usize,
944 operation: F,
945 ) -> std::time::Duration
946 where
947 F: Fn(),
948 {
949 let start = std::time::Instant::now();
950 for _ in 0..iterations {
951 operation();
952 }
953 let duration = start.elapsed();
954
955 println!(
956 "Performance test '{name}': {iterations} iterations in {duration:?} ({:.2} ops/sec)",
957 iterations as f64 / duration.as_secs_f64()
958 );
959
960 duration
961 }
962
963 pub async fn quick_revolutionary_performance_test<F, Fut>(
996 name: &str,
997 iterations: usize,
998 async_operation: F,
999 ) -> std::time::Duration
1000 where
1001 F: Fn() -> Fut,
1002 Fut: std::future::Future<Output = ()>,
1003 {
1004 let start = std::time::Instant::now();
1005 for _ in 0..iterations {
1006 async_operation().await;
1007 }
1008 let duration = start.elapsed();
1009
1010 println!(
1011 "Revolutionary performance test '{name}': {iterations} iterations in {duration:?} ({:.2} ops/sec)",
1012 iterations as f64 / duration.as_secs_f64()
1013 );
1014
1015 duration
1016 }
1017}
1018
#[cfg(test)]
mod quick_start_tests {
    use super::*;
    use crate::quick_start::*;

    /// The TransE factory applies its fixed hyperparameters.
    #[test]
    fn test_create_simple_transe_model() {
        let model = create_simple_transe_model();
        let cfg = model.config();
        assert_eq!(cfg.dimensions, 128);
        assert_eq!(cfg.learning_rate, 0.01);
        assert_eq!(cfg.max_epochs, 100);
    }

    /// Three absolute IRIs parse into the matching triple terms.
    #[test]
    fn test_parse_triple_from_string() {
        let input = "http://example.org/alice http://example.org/knows http://example.org/bob";
        let Triple {
            subject,
            predicate,
            object,
        } = parse_triple_from_string(input).expect("should succeed");
        assert_eq!(subject.iri, "http://example.org/alice");
        assert_eq!(predicate.iri, "http://example.org/knows");
        assert_eq!(object.iri, "http://example.org/bob");
    }

    /// Two terms are not enough to form a triple.
    #[test]
    fn test_parse_triple_from_string_invalid() {
        let outcome = parse_triple_from_string("http://example.org/alice http://example.org/knows");
        assert!(outcome.is_err());
    }

    /// Every well-formed line is added and counted.
    #[test]
    fn test_add_triples_from_strings() {
        let mut model = create_simple_transe_model();
        let lines = [
            "http://example.org/alice http://example.org/knows http://example.org/bob",
            "http://example.org/bob http://example.org/likes http://example.org/music",
        ];

        let added = add_triples_from_strings(&mut model, &lines).expect("should succeed");
        assert_eq!(added, 2);
    }

    /// Identical vectors score 1, orthogonal vectors 0, mismatched lengths fail.
    #[test]
    fn test_cosine_similarity() {
        let a = [1.0, 0.0, 0.0];
        let same = [1.0, 0.0, 0.0];
        let identical = cosine_similarity(&a, &same).expect("should succeed");
        assert!((identical - 1.0).abs() < 1e-10);

        let orthogonal_vec = [0.0, 1.0, 0.0];
        let orthogonal = cosine_similarity(&a, &orthogonal_vec).expect("should succeed");
        assert!((orthogonal - 0.0).abs() < 1e-10);

        let shorter = [1.0, 0.0];
        assert!(cosine_similarity(&a, &shorter).is_err());
    }

    /// Generated triples use the expected namespaces and never loop on one entity.
    #[test]
    fn test_generate_sample_kg_data() {
        let triples = generate_sample_kg_data(5, 3);
        assert!(!triples.is_empty());

        for (subject, relation, object) in triples.iter() {
            assert!(subject.starts_with("http://example.org/entity_"));
            assert!(relation.starts_with("http://example.org/relation_"));
            assert!(object.starts_with("http://example.org/entity_"));
            assert_ne!(subject, object);
        }
    }

    /// The timing helper runs to completion and hands back a duration.
    #[test]
    fn test_quick_performance_test() {
        let duration = quick_performance_test("test_operation", 100, || {
            let _sum: i32 = (1..10).sum();
        });

        // Only checks that a duration value is produced and readable.
        let _nanos = duration.as_nanos();
    }
}