//! flodl — a deep learning framework built on libtorch, from Rust.
//!
//! Stack: `flodl-sys` (C++ shim FFI) → `tensor` → `autograd` → `nn` → `graph`.
//!
//! ```ignore
//! use flodl::*;
//!
//! // Build a model as a computation graph
//! let model = FlowBuilder::from(Linear::new(4, 8)?)
//!     .through(GELU)
//!     .through(Linear::new(8, 2)?)
//!     .build()?;
//!
//! // Forward pass
//! let x = Variable::new(Tensor::randn(&[1, 4], Default::default())?, false);
//! let target = Variable::new(Tensor::randn(&[1, 2], Default::default())?, false);
//! let pred = model.forward(&x)?;
//!
//! // Backward + optimize
//! let params = model.parameters();
//! let mut optimizer = Adam::new(&params, 1e-3);
//! let loss = mse_loss(&pred, &target)?;
//! optimizer.zero_grad();
//! loss.backward()?;
//! optimizer.step()?;
//! ```

28pub mod tensor;
29pub mod autograd;
30pub mod nn;
31pub mod graph;
32pub mod monitor;
33pub mod worker;
34
/// Shorthand for building `Vec<Box<dyn Module>>` from a list of modules.
/// Use with `split`, `gate`, and `switch` to avoid manual `Box::new()` wrapping.
///
/// Accepts zero or more module expressions, with an optional trailing comma.
/// Each expression is boxed and coerced to `Box<dyn Module>`; `$crate::` keeps
/// the expansion working regardless of the caller's imports.
///
/// ```ignore
/// .split(modules![read_head(H), read_head(H)])
/// .gate(router, modules![Linear::new(H, H)?, Linear::new(H, H)?])
/// ```
#[macro_export]
macro_rules! modules {
    ($($module:expr),* $(,)?) => {
        vec![$(Box::new($module) as Box<dyn $crate::Module>),*]
    };
}
49pub use tensor::{cuda_available, cuda_device_count, cuda_memory_info, cuda_memory_info_idx, cuda_utilization, cuda_utilization_idx, cuda_device_name, cuda_device_name_idx, cuda_devices, DeviceInfo, set_current_cuda_device, current_cuda_device, cuda_synchronize, hardware_summary, set_cudnn_benchmark, malloc_trim, live_tensor_count, rss_kb, Device, DType, Result, Tensor, TensorError, TensorOptions};
50pub use autograd::{Variable, no_grad, is_grad_enabled, NoGradGuard, adaptive_avg_pool2d, grid_sample};
51pub use nn::{
52    Module, NamedInputModule,
53    Parameter, Buffer, Linear, Optimizer, Stateful, SGD, SGDBuilder, Adam, AdamBuilder, AdamW, AdamWBuilder,
54    save_checkpoint, load_checkpoint, save_checkpoint_file, load_checkpoint_file,
55    LoadReport,
56    GradScaler, cast_parameters,
57    Identity, ReLU, Sigmoid, Tanh, GELU, SiLU,
58    Dropout, Dropout2d, LayerNorm, Embedding, GRUCell, LSTMCell,
59    Conv2d, Conv2dBuilder, ConvTranspose2d, BatchNorm, BatchNorm2d,
60    mse_loss, cross_entropy_loss, bce_with_logits_loss, l1_loss, smooth_l1_loss, kl_div_loss,
61    clip_grad_norm, clip_grad_value,
62    Scheduler, StepDecay, CosineScheduler, WarmupScheduler, PlateauScheduler,
63    xavier_uniform, xavier_normal,
64    walk_modules, walk_modules_visited,
65};
66pub use graph::{
67    FlowBuilder, MergeOp, Graph, MapBuilder, Trend, TrendGroup,
68    Profile, NodeTiming, LevelTiming, format_duration,
69    SoftmaxRouter, SigmoidRouter, FixedSelector, ArgmaxSelector,
70    ThresholdHalt, LearnedHalt,
71    Reshape, StateAdd, Reduce, ModelSnapshot,
72};
73pub use worker::CpuWorker;