// ghostflow_nn/lib.rs — crate root

1//! GhostFlow Neural Network Layers
2//!
3//! High-level building blocks for neural networks.
4
5#![allow(dead_code)]
6
7pub mod module;
8pub mod linear;
9pub mod conv;
10pub mod norm;
11pub mod activation;
12pub mod dropout;
13pub mod loss;
14pub mod init;
15pub mod attention;
16pub mod transformer;
17pub mod embedding;
18pub mod pooling;
19pub mod rnn;
20pub mod quantization;
21pub mod distributed;
22pub mod serialization;
23pub mod onnx;
24pub mod inference;
25pub mod gnn;
26pub mod rl;
27pub mod federated;
28pub mod differential_privacy;
29pub mod adversarial;
30pub mod vision_transformer;
31pub mod bert;
32pub mod gpt;
33pub mod t5;
34pub mod diffusion;
35pub mod llama;
36pub mod clip;
37pub mod nerf;
38pub mod point_cloud;
39pub mod mesh;
40pub mod mixed_precision;
41pub mod gradient_checkpointing;
42pub mod lora;
43pub mod flash_attention;
44pub mod knowledge_distillation;
45pub mod zero_optimizer;
46pub mod ring_attention;
47pub mod mixture_of_experts;
48pub mod prompt_tuning;
49pub mod curriculum_learning;
50
51pub use module::Module;
52pub use linear::Linear;
53pub use conv::{Conv1d, Conv2d, Conv3d, TransposeConv2d};
54pub use norm::{BatchNorm1d, BatchNorm2d, LayerNorm, GroupNorm, InstanceNorm};
55pub use activation::*;
56pub use dropout::Dropout;
57pub use loss::*;
58pub use attention::{MultiHeadAttention, scaled_dot_product_attention};
59pub use transformer::{
60    TransformerEncoder, TransformerEncoderLayer,
61    TransformerDecoderLayer, FeedForward,
62    PositionalEncoding, RotaryEmbedding,
63};
64pub use embedding::Embedding;
65pub use pooling::*;
66pub use rnn::{LSTM, LSTMCell, GRU, GRUCell};
67pub use quantization::{
68    QuantizedTensor, QuantizationConfig, QuantizationScheme,
69    QuantizationAwareTraining, DynamicQuantization,
70};
71pub use distributed::{
72    DistributedConfig, DistributedBackend, DataParallel, ModelParallel,
73    GradientAccumulator, DistributedDataParallel, PipelineParallel,
74};
75pub use serialization::{
76    ModelCheckpoint, ModelMetadata, save_model, load_model,
77};
78pub use gnn::{
79    Graph, GCNLayer, GATLayer, GraphSAGELayer, MPNNLayer, AggregatorType,
80};
81pub use rl::{
82    ReplayBuffer, Experience, DQNAgent, QNetwork,
83    PolicyNetwork, REINFORCEAgent, ActorCriticAgent, ValueNetwork, PPOAgent,
84};
85pub use federated::{
86    FederatedClient, FederatedServer, AggregationStrategy,
87    SecureAggregation, DifferentialPrivacy,
88};
89pub use onnx::{
90    ONNXModel, ONNXNode, ONNXTensor, ONNXDataType, ONNXAttribute,
91    tensor_to_onnx, onnx_to_tensor,
92};
93pub use inference::{
94    InferenceConfig, InferenceOptimizer, InferenceSession,
95    BatchInference, warmup_model,
96};
97pub use differential_privacy::{
98    DPConfig, PrivacyAccountant, DPSGDOptimizer, PATEEnsemble, LocalDP,
99};
100pub use adversarial::{
101    AttackConfig, AttackType, AdversarialAttack, AdversarialTrainingConfig,
102    AdversarialTrainer, RandomizedSmoothing,
103};
104// pub use vision_transformer::{
105//     VisionTransformer, ViTConfig, PatchEmbedding,
106// };
107
108/// Prelude for convenient imports
109pub mod prelude {
110    pub use crate::{Module, Linear, Conv1d, Conv2d, Conv3d, TransposeConv2d};
111    pub use crate::{BatchNorm1d, BatchNorm2d, LayerNorm, GroupNorm, InstanceNorm, Dropout};
112    pub use crate::activation::*;
113    pub use crate::loss::*;
114    pub use crate::attention::MultiHeadAttention;
115    pub use crate::transformer::{TransformerEncoder, TransformerEncoderLayer};
116    pub use crate::embedding::Embedding;
117    pub use crate::rnn::{LSTM, GRU};
118}