Skip to main content

oxirs_embed/
novel_arch_types.rs

//! Configuration and state types for novel embedding architectures
//!
//! Contains every config struct, enum, parameter container and runtime state
//! shared by the [`NovelArchitectureModel`](crate::novel_arch_impl::NovelArchitectureModel)
//! implementation: graph transformers, neural ODEs, hyperbolic embeddings,
//! geometric deep learning, quantum-inspired layers and continuous flows.
7
8use crate::{ModelConfig, TrainingStats};
9use scirs2_core::ndarray_ext::{Array1, Array2, Array3};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use uuid::Uuid;
13
14/// Configuration for novel architectures
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct NovelArchitectureConfig {
17    pub base_config: ModelConfig,
18    /// Architecture type
19    pub architecture: ArchitectureType,
20    /// Specialized parameters per architecture
21    pub architecture_params: ArchitectureParams,
22    /// Training dynamics configuration
23    pub dynamics_config: DynamicsConfig,
24    /// Geometric learning settings
25    pub geometric_config: GeometricConfig,
26}
27
28impl Default for NovelArchitectureConfig {
29    fn default() -> Self {
30        Self {
31            base_config: ModelConfig::default(),
32            architecture: ArchitectureType::GraphTransformer,
33            architecture_params: ArchitectureParams::default(),
34            dynamics_config: DynamicsConfig::default(),
35            geometric_config: GeometricConfig::default(),
36        }
37    }
38}
39
40/// Types of novel architectures
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub enum ArchitectureType {
43    /// Graph Transformer with structural attention
44    GraphTransformer,
45    /// Neural ODE for continuous dynamics
46    NeuralODE,
47    /// Hyperbolic embeddings for hierarchical structures
48    HyperbolicEmbedding,
49    /// Geometric deep learning on manifolds
50    GeometricDeepLearning,
51    /// Quantum-inspired embedding methods
52    QuantumInspired,
53    /// Continuous normalizing flows
54    ContinuousNormalizingFlow,
55}
56
57/// Architecture-specific parameters
58#[derive(Debug, Clone, Serialize, Deserialize, Default)]
59pub struct ArchitectureParams {
60    /// Graph Transformer parameters
61    pub transformer_params: GraphTransformerParams,
62    /// Neural ODE parameters
63    pub ode_params: NeuralODEParams,
64    /// Hyperbolic parameters
65    pub hyperbolic_params: HyperbolicParams,
66    /// Geometric parameters
67    pub geometric_params: GeometricParams,
68    /// Quantum parameters
69    pub quantum_params: QuantumParams,
70}
71
72/// Graph Transformer configuration
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct GraphTransformerParams {
75    /// Number of attention heads
76    pub num_heads: usize,
77    /// Number of transformer layers
78    pub num_layers: usize,
79    /// Attention dimension
80    pub attention_dim: usize,
81    /// Feed-forward dimension
82    pub ff_dim: usize,
83    /// Structural encoding dimension
84    pub structural_dim: usize,
85    /// Use positional encoding
86    pub use_positional_encoding: bool,
87    /// Attention mechanism
88    pub attention_mechanism: AttentionMechanism,
89    /// Structural bias type
90    pub structural_bias: StructuralBias,
91}
92
93impl Default for GraphTransformerParams {
94    fn default() -> Self {
95        Self {
96            num_heads: 8,
97            num_layers: 6,
98            attention_dim: 512,
99            ff_dim: 2048,
100            structural_dim: 128,
101            use_positional_encoding: true,
102            attention_mechanism: AttentionMechanism::SparseAttention,
103            structural_bias: StructuralBias::SpectralFeatures,
104        }
105    }
106}
107
108/// Attention mechanisms for Graph Transformers
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub enum AttentionMechanism {
111    /// Standard multi-head attention
112    MultiHeadAttention,
113    /// Sparse attention for large graphs
114    SparseAttention,
115    /// Linear attention for efficiency
116    LinearAttention,
117    /// Performer-style attention
118    PerformerAttention,
119    /// Graph-aware attention
120    GraphAwareAttention,
121}
122
123/// Structural bias types
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub enum StructuralBias {
126    /// Spectral features from graph Laplacian
127    SpectralFeatures,
128    /// Shortest path distances
129    ShortestPath,
130    /// Random walk features
131    RandomWalk,
132    /// Centrality measures
133    CentralityMeasures,
134    /// Graph motif features
135    GraphMotifs,
136}
137
138/// Neural ODE configuration
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct NeuralODEParams {
141    /// ODE solver type
142    pub solver_type: ODESolverType,
143    /// Integration time steps
144    pub time_steps: usize,
145    /// Tolerance for adaptive solvers
146    pub tolerance: f64,
147    /// Hidden dimensions for ODE function
148    pub hidden_dims: Vec<usize>,
149    /// Activation function
150    pub activation: ActivationType,
151    /// Adjoint method for backprop
152    pub use_adjoint: bool,
153    /// Regularization type
154    pub regularization: ODERegularization,
155}
156
157impl Default for NeuralODEParams {
158    fn default() -> Self {
159        Self {
160            solver_type: ODESolverType::DormandPrince,
161            time_steps: 100,
162            tolerance: 1e-6,
163            hidden_dims: vec![512, 256, 128],
164            activation: ActivationType::Swish,
165            use_adjoint: true,
166            regularization: ODERegularization::None,
167        }
168    }
169}
170
171/// ODE solver types
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub enum ODESolverType {
174    /// Euler method
175    Euler,
176    /// Runge-Kutta 4th order
177    RungeKutta4,
178    /// Dormand-Prince adaptive method
179    DormandPrince,
180    /// Adams-Bashforth
181    AdamsBashforth,
182    /// Implicit methods
183    BackwardEuler,
184}
185
186/// ODE regularization techniques
187#[derive(Debug, Clone, Serialize, Deserialize)]
188pub enum ODERegularization {
189    None,
190    /// Kinetic energy regularization
191    KineticEnergy,
192    /// Jacobian regularization
193    JacobianFrobenius,
194    /// Spectral normalization
195    SpectralNormalization,
196}
197
198/// Activation types for neural networks
199#[derive(Debug, Clone, Serialize, Deserialize)]
200pub enum ActivationType {
201    ReLU,
202    Swish,
203    Mish,
204    GELU,
205    ELU,
206    LeakyReLU,
207    Tanh,
208}
209
210/// Hyperbolic embedding configuration
211#[derive(Debug, Clone, Serialize, Deserialize)]
212pub struct HyperbolicParams {
213    /// Hyperbolic manifold type
214    pub manifold: HyperbolicManifold,
215    /// Curvature parameter
216    pub curvature: f64,
217    /// Manifold dimension
218    pub manifold_dim: usize,
219    /// Optimization method on manifold
220    pub optimizer: ManifoldOptimizer,
221    /// Distance function
222    pub distance_function: HyperbolicDistance,
223    /// Initialization strategy
224    pub initialization: HyperbolicInit,
225}
226
227impl Default for HyperbolicParams {
228    fn default() -> Self {
229        Self {
230            manifold: HyperbolicManifold::Poincare,
231            curvature: -1.0,
232            manifold_dim: 128,
233            optimizer: ManifoldOptimizer::RiemannianAdam,
234            distance_function: HyperbolicDistance::Poincare,
235            initialization: HyperbolicInit::RandomNormal,
236        }
237    }
238}
239
240/// Hyperbolic manifold types
241#[derive(Debug, Clone, Serialize, Deserialize)]
242pub enum HyperbolicManifold {
243    /// Poincaré ball model
244    Poincare,
245    /// Klein model
246    Klein,
247    /// Hyperboloid model
248    Hyperboloid,
249    /// Upper half-space model
250    UpperHalfSpace,
251}
252
253/// Manifold optimizers
254#[derive(Debug, Clone, Serialize, Deserialize)]
255pub enum ManifoldOptimizer {
256    /// Riemannian SGD
257    RiemannianSGD,
258    /// Riemannian Adam
259    RiemannianAdam,
260    /// Riemannian AdaGrad
261    RiemannianAdaGrad,
262    /// Exponential map based
263    ExponentialMap,
264}
265
266/// Hyperbolic distance functions
267#[derive(Debug, Clone, Serialize, Deserialize)]
268pub enum HyperbolicDistance {
269    /// Poincaré distance
270    Poincare,
271    /// Hyperbolic distance in hyperboloid model
272    Hyperboloid,
273    /// Geodesic distance
274    Geodesic,
275}
276
277/// Hyperbolic initialization strategies
278#[derive(Debug, Clone, Serialize, Deserialize)]
279pub enum HyperbolicInit {
280    /// Random normal initialization
281    RandomNormal,
282    /// Wrapped normal distribution
283    WrappedNormal,
284    /// Uniform on hyperbolic space
285    UniformHyperbolic,
286    /// Tree-based initialization
287    TreeBased,
288}
289
290/// Geometric deep learning parameters
291#[derive(Debug, Clone, Serialize, Deserialize)]
292pub struct GeometricParams {
293    /// Geometric space type
294    pub space_type: GeometricSpace,
295    /// Equivariance groups
296    pub equivariance_groups: Vec<EquivarianceGroup>,
297    /// Gauge equivariant layers
298    pub use_gauge_equivariance: bool,
299    /// Fiber bundle dimension
300    pub fiber_dim: usize,
301    /// Connection learning
302    pub learn_connection: bool,
303    /// Curvature regularization
304    pub curvature_regularization: f64,
305}
306
307impl Default for GeometricParams {
308    fn default() -> Self {
309        Self {
310            space_type: GeometricSpace::RiemannianManifold,
311            equivariance_groups: vec![EquivarianceGroup::SO3, EquivarianceGroup::SE3],
312            use_gauge_equivariance: true,
313            fiber_dim: 64,
314            learn_connection: true,
315            curvature_regularization: 0.01,
316        }
317    }
318}
319
320/// Geometric space types
321#[derive(Debug, Clone, Serialize, Deserialize)]
322pub enum GeometricSpace {
323    /// Riemannian manifolds
324    RiemannianManifold,
325    /// Lie groups
326    LieGroup,
327    /// Fiber bundles
328    FiberBundle,
329    /// Homogeneous spaces
330    HomogeneousSpace,
331    /// Simplicial complexes
332    SimplicialComplex,
333}
334
335/// Equivariance groups
336#[derive(Debug, Clone, Serialize, Deserialize)]
337pub enum EquivarianceGroup {
338    /// Special orthogonal group SO(3)
339    SO3,
340    /// Special Euclidean group SE(3)
341    SE3,
342    /// General linear group GL(n)
343    GLn,
344    /// Symmetric group
345    SymmetricGroup,
346    /// Lorentz group
347    LorentzGroup,
348}
349
350/// Quantum-inspired parameters
351#[derive(Debug, Clone, Serialize, Deserialize)]
352pub struct QuantumParams {
353    /// Number of qubits for quantum state
354    pub num_qubits: usize,
355    /// Quantum gate set
356    pub gate_set: QuantumGateSet,
357    /// Entanglement structure
358    pub entanglement: EntanglementStructure,
359    /// Measurement strategy
360    pub measurement: QuantumMeasurement,
361    /// Quantum noise model
362    pub noise_model: QuantumNoise,
363    /// Classical-quantum interface
364    pub hybrid_layers: bool,
365}
366
367impl Default for QuantumParams {
368    fn default() -> Self {
369        Self {
370            num_qubits: 10,
371            gate_set: QuantumGateSet::Universal,
372            entanglement: EntanglementStructure::Linear,
373            measurement: QuantumMeasurement::Computational,
374            noise_model: QuantumNoise::None,
375            hybrid_layers: true,
376        }
377    }
378}
379
380/// Quantum gate sets
381#[derive(Debug, Clone, Serialize, Deserialize)]
382pub enum QuantumGateSet {
383    /// Universal gate set
384    Universal,
385    /// Clifford gates
386    Clifford,
387    /// Variational gates
388    Variational,
389    /// Adiabatic evolution
390    Adiabatic,
391}
392
393/// Entanglement structures
394#[derive(Debug, Clone, Serialize, Deserialize)]
395pub enum EntanglementStructure {
396    /// Linear entanglement
397    Linear,
398    /// All-to-all entanglement
399    AllToAll,
400    /// Tree entanglement
401    Tree,
402    /// Hardware-efficient
403    HardwareEfficient,
404}
405
406/// Quantum measurement strategies
407#[derive(Debug, Clone, Serialize, Deserialize)]
408pub enum QuantumMeasurement {
409    /// Computational basis
410    Computational,
411    /// Pauli measurements
412    Pauli,
413    /// Quantum state tomography
414    Tomography,
415    /// Shadow measurements
416    Shadow,
417}
418
419/// Quantum noise models
420#[derive(Debug, Clone, Serialize, Deserialize)]
421pub enum QuantumNoise {
422    None,
423    /// Depolarizing noise
424    Depolarizing,
425    /// Amplitude damping
426    AmplitudeDamping,
427    /// Phase damping
428    PhaseDamping,
429    /// Realistic device noise
430    DeviceNoise,
431}
432
433/// Dynamics configuration for continuous models
434#[derive(Debug, Clone, Serialize, Deserialize)]
435pub struct DynamicsConfig {
436    /// Time evolution parameters
437    pub time_evolution: TimeEvolution,
438    /// Continuous flow type
439    pub flow_type: FlowType,
440    /// Integration scheme
441    pub integration_scheme: IntegrationScheme,
442    /// Stability constraints
443    pub stability_constraints: StabilityConstraints,
444}
445
446impl Default for DynamicsConfig {
447    fn default() -> Self {
448        Self {
449            time_evolution: TimeEvolution::default(),
450            flow_type: FlowType::NormalizingFlow,
451            integration_scheme: IntegrationScheme::AdaptiveRungeKutta,
452            stability_constraints: StabilityConstraints::default(),
453        }
454    }
455}
456
457/// Time evolution parameters
458#[derive(Debug, Clone, Serialize, Deserialize)]
459pub struct TimeEvolution {
460    /// Start time
461    pub t_start: f64,
462    /// End time
463    pub t_end: f64,
464    /// Time steps
465    pub time_steps: usize,
466    /// Adaptive time stepping
467    pub adaptive: bool,
468}
469
470impl Default for TimeEvolution {
471    fn default() -> Self {
472        Self {
473            t_start: 0.0,
474            t_end: 1.0,
475            time_steps: 100,
476            adaptive: true,
477        }
478    }
479}
480
481/// Flow types for continuous models
482#[derive(Debug, Clone, Serialize, Deserialize)]
483pub enum FlowType {
484    /// Normalizing flows
485    NormalizingFlow,
486    /// Continuous normalizing flows
487    ContinuousNormalizingFlow,
488    /// Neural flows
489    NeuralFlow,
490    /// Hamiltonian flows
491    HamiltonianFlow,
492}
493
494/// Integration schemes
495#[derive(Debug, Clone, Serialize, Deserialize)]
496pub enum IntegrationScheme {
497    /// Fixed-step Runge-Kutta
498    FixedRungeKutta,
499    /// Adaptive Runge-Kutta
500    AdaptiveRungeKutta,
501    /// Symplectic integrators
502    SymplecticIntegrator,
503    /// Implicit methods
504    ImplicitMethods,
505}
506
507/// Stability constraints
508#[derive(Debug, Clone, Serialize, Deserialize)]
509pub struct StabilityConstraints {
510    /// Maximum eigenvalue
511    pub max_eigenvalue: f64,
512    /// Lyapunov regularization
513    pub lyapunov_reg: f64,
514    /// Spectral normalization
515    pub spectral_norm: bool,
516}
517
518impl Default for StabilityConstraints {
519    fn default() -> Self {
520        Self {
521            max_eigenvalue: 1.0,
522            lyapunov_reg: 0.01,
523            spectral_norm: true,
524        }
525    }
526}
527
528/// Geometric configuration
529#[derive(Debug, Clone, Serialize, Deserialize, Default)]
530pub struct GeometricConfig {
531    /// Manifold learning parameters
532    pub manifold_learning: ManifoldLearning,
533    /// Curvature computation
534    pub curvature_computation: CurvatureComputation,
535    /// Parallel transport
536    pub parallel_transport: ParallelTransport,
537}
538
539/// Manifold learning configuration
540#[derive(Debug, Clone, Serialize, Deserialize)]
541pub struct ManifoldLearning {
542    /// Intrinsic dimension
543    pub intrinsic_dim: usize,
544    /// Neighborhood size
545    pub neighborhood_size: usize,
546    /// Embedding method
547    pub embedding_method: ManifoldMethod,
548}
549
550impl Default for ManifoldLearning {
551    fn default() -> Self {
552        Self {
553            intrinsic_dim: 64,
554            neighborhood_size: 10,
555            embedding_method: ManifoldMethod::Isomap,
556        }
557    }
558}
559
560/// Manifold embedding methods
561#[derive(Debug, Clone, Serialize, Deserialize)]
562pub enum ManifoldMethod {
563    /// Isomap
564    Isomap,
565    /// Locally Linear Embedding
566    LLE,
567    /// Laplacian Eigenmaps
568    LaplacianEigenmaps,
569    /// Diffusion Maps
570    DiffusionMaps,
571    /// t-SNE
572    TSNE,
573    /// UMAP
574    UMAP,
575}
576
577/// Curvature computation
578#[derive(Debug, Clone, Serialize, Deserialize)]
579pub struct CurvatureComputation {
580    /// Curvature type
581    pub curvature_type: CurvatureType,
582    /// Computation method
583    pub computation_method: CurvatureMethod,
584    /// Regularization
585    pub regularization: f64,
586}
587
588impl Default for CurvatureComputation {
589    fn default() -> Self {
590        Self {
591            curvature_type: CurvatureType::Ricci,
592            computation_method: CurvatureMethod::FormanRicci,
593            regularization: 0.01,
594        }
595    }
596}
597
598/// Curvature types
599#[derive(Debug, Clone, Serialize, Deserialize)]
600pub enum CurvatureType {
601    /// Gaussian curvature
602    Gaussian,
603    /// Mean curvature
604    Mean,
605    /// Ricci curvature
606    Ricci,
607    /// Scalar curvature
608    Scalar,
609    /// Sectional curvature
610    Sectional,
611}
612
613/// Curvature computation methods
614#[derive(Debug, Clone, Serialize, Deserialize)]
615pub enum CurvatureMethod {
616    /// Forman-Ricci curvature
617    FormanRicci,
618    /// Ollivier-Ricci curvature
619    OllivierRicci,
620    /// Discrete Gaussian curvature
621    DiscreteGaussian,
622    /// Graph-based methods
623    GraphBased,
624}
625
626/// Parallel transport configuration
627#[derive(Debug, Clone, Serialize, Deserialize)]
628pub struct ParallelTransport {
629    /// Transport method
630    pub method: TransportMethod,
631    /// Path discretization
632    pub path_steps: usize,
633    /// Tolerance
634    pub tolerance: f64,
635}
636
637impl Default for ParallelTransport {
638    fn default() -> Self {
639        Self {
640            method: TransportMethod::SchildLadder,
641            path_steps: 50,
642            tolerance: 1e-6,
643        }
644    }
645}
646
647/// Parallel transport methods
648#[derive(Debug, Clone, Serialize, Deserialize)]
649pub enum TransportMethod {
650    /// Schild's ladder
651    SchildLadder,
652    /// Pole ladder
653    PoleLadder,
654    /// Geodesic parallel transport
655    GeodesicTransport,
656    /// Discrete transport
657    DiscreteTransport,
658}
659
660/// Novel architecture embedding model
661#[derive(Debug, Clone)]
662pub struct NovelArchitectureModel {
663    pub config: NovelArchitectureConfig,
664    pub model_id: Uuid,
665    pub entities: HashMap<String, usize>,
666    pub relations: HashMap<String, usize>,
667    pub entity_embeddings: Array2<f64>,
668    pub relation_embeddings: Array2<f64>,
669    pub architecture_state: ArchitectureState,
670    pub training_stats: Option<TrainingStats>,
671    pub is_trained: bool,
672}
673
674/// Architecture-specific state
675#[derive(Debug, Clone)]
676pub struct ArchitectureState {
677    /// Graph transformer state
678    pub transformer_state: Option<GraphTransformerState>,
679    /// Neural ODE state
680    pub ode_state: Option<NeuralODEState>,
681    /// Hyperbolic state
682    pub hyperbolic_state: Option<HyperbolicState>,
683    /// Geometric state
684    pub geometric_state: Option<GeometricState>,
685    /// Quantum state
686    pub quantum_state: Option<QuantumState>,
687}
688
689/// Graph transformer state
690#[derive(Debug, Clone)]
691pub struct GraphTransformerState {
692    /// Attention weights
693    pub attention_weights: Array3<f64>,
694    /// Layer outputs
695    pub layer_outputs: Vec<Array2<f64>>,
696    /// Structural features
697    pub structural_features: Array2<f64>,
698    /// Position encodings
699    pub position_encodings: Option<Array2<f64>>,
700}
701
702/// Neural ODE state
703#[derive(Debug, Clone)]
704pub struct NeuralODEState {
705    /// Current time
706    pub current_time: f64,
707    /// State trajectory
708    pub trajectory: Vec<Array2<f64>>,
709    /// ODE function parameters
710    pub ode_params: Array2<f64>,
711    /// Integration statistics
712    pub integration_stats: IntegrationStats,
713}
714
/// Integration statistics.
///
/// Plain counters plus a final error value, carried in
/// `NeuralODEState::integration_stats`. The exact meaning of each counter is
/// defined by the code that updates it, which is not part of this type.
#[derive(Debug, Clone)]
pub struct IntegrationStats {
    /// Count of steps taken.
    pub steps_taken: usize,
    /// Count of function evaluations.
    pub function_evaluations: usize,
    /// Count of Jacobian evaluations.
    pub jacobian_evaluations: usize,
    /// Count of failed steps.
    pub failed_steps: usize,
    /// Final error value.
    pub final_error: f64,
}
724
725/// Hyperbolic state
726#[derive(Debug, Clone)]
727pub struct HyperbolicState {
728    /// Manifold embeddings
729    pub manifold_embeddings: Array2<f64>,
730    /// Curvature parameter
731    pub curvature: f64,
732    /// Tangent vectors
733    pub tangent_vectors: Array2<f64>,
734    /// Metric tensor
735    pub metric_tensor: Array3<f64>,
736}
737
738/// Geometric state
739#[derive(Debug, Clone)]
740pub struct GeometricState {
741    /// Connection coefficients
742    pub connection: Array3<f64>,
743    /// Curvature tensor
744    pub curvature_tensor: Array3<f64>,
745    /// Parallel transport maps
746    pub transport_maps: HashMap<String, Array2<f64>>,
747    /// Equivariance maps
748    pub equivariance_maps: Vec<Array2<f64>>,
749}
750
751/// Quantum state
752#[derive(Debug, Clone)]
753pub struct QuantumState {
754    /// Quantum state vector
755    pub state_vector: Array1<f64>,
756    /// Quantum gates
757    pub gates: Vec<Array2<f64>>,
758    /// Measurement outcomes
759    pub measurements: Vec<f64>,
760    /// Entanglement measures
761    pub entanglement: f64,
762}