Skip to main content

Crate flodl

Crate flodl 

Source
Expand description

flodl — a deep learning framework for Rust, built on libtorch.

Stack: flodl-sys (C++ shim FFI) → tensor → autograd → nn → graph.

use flodl::*;

// Build a model as a computation graph
let model = FlowBuilder::from(Linear::new(4, 8)?)
    .through(GELU)
    .through(Linear::new(8, 2)?)
    .build()?;

// Forward pass
let x = Variable::new(Tensor::randn(&[1, 4], Default::default())?, false);
let target = Variable::new(Tensor::randn(&[1, 2], Default::default())?, false);
let pred = model.forward(&x)?;

// Backward + optimize
let params = model.parameters();
let mut optimizer = Adam::new(&params, 1e-3);
let loss = mse_loss(&pred, &target)?;
optimizer.zero_grad();
loss.backward()?;
optimizer.step()?;

Re-exports

pub use tensor::cuda_available;
pub use tensor::cuda_device_count;
pub use tensor::cuda_memory_info;
pub use tensor::cuda_memory_info_idx;
pub use tensor::cuda_utilization;
pub use tensor::cuda_utilization_idx;
pub use tensor::cuda_device_name;
pub use tensor::cuda_device_name_idx;
pub use tensor::cuda_devices;
pub use tensor::DeviceInfo;
pub use tensor::set_current_cuda_device;
pub use tensor::current_cuda_device;
pub use tensor::cuda_synchronize;
pub use tensor::hardware_summary;
pub use tensor::set_cudnn_benchmark;
pub use tensor::malloc_trim;
pub use tensor::live_tensor_count;
pub use tensor::rss_kb;
pub use tensor::Device;
pub use tensor::DType;
pub use tensor::Result;
pub use tensor::Tensor;
pub use tensor::TensorError;
pub use tensor::TensorOptions;
pub use autograd::Variable;
pub use autograd::no_grad;
pub use autograd::is_grad_enabled;
pub use autograd::NoGradGuard;
pub use autograd::adaptive_avg_pool2d;
pub use autograd::grid_sample;
pub use nn::Module;
pub use nn::NamedInputModule;
pub use nn::Parameter;
pub use nn::Buffer;
pub use nn::Linear;
pub use nn::Optimizer;
pub use nn::Stateful;
pub use nn::SGD;
pub use nn::SGDBuilder;
pub use nn::Adam;
pub use nn::AdamBuilder;
pub use nn::AdamW;
pub use nn::AdamWBuilder;
pub use nn::save_checkpoint;
pub use nn::load_checkpoint;
pub use nn::save_checkpoint_file;
pub use nn::load_checkpoint_file;
pub use nn::LoadReport;
pub use nn::GradScaler;
pub use nn::cast_parameters;
pub use nn::Identity;
pub use nn::ReLU;
pub use nn::Sigmoid;
pub use nn::Tanh;
pub use nn::GELU;
pub use nn::SiLU;
pub use nn::Dropout;
pub use nn::Dropout2d;
pub use nn::LayerNorm;
pub use nn::Embedding;
pub use nn::GRUCell;
pub use nn::LSTMCell;
pub use nn::Conv2d;
pub use nn::Conv2dBuilder;
pub use nn::ConvTranspose2d;
pub use nn::BatchNorm;
pub use nn::BatchNorm2d;
pub use nn::mse_loss;
pub use nn::cross_entropy_loss;
pub use nn::bce_with_logits_loss;
pub use nn::l1_loss;
pub use nn::smooth_l1_loss;
pub use nn::kl_div_loss;
pub use nn::clip_grad_norm;
pub use nn::clip_grad_value;
pub use nn::Scheduler;
pub use nn::StepDecay;
pub use nn::CosineScheduler;
pub use nn::WarmupScheduler;
pub use nn::PlateauScheduler;
pub use nn::xavier_uniform;
pub use nn::xavier_normal;
pub use nn::walk_modules;
pub use nn::walk_modules_visited;
pub use graph::FlowBuilder;
pub use graph::MergeOp;
pub use graph::Graph;
pub use graph::MapBuilder;
pub use graph::Trend;
pub use graph::TrendGroup;
pub use graph::Profile;
pub use graph::NodeTiming;
pub use graph::LevelTiming;
pub use graph::format_duration;
pub use graph::SoftmaxRouter;
pub use graph::SigmoidRouter;
pub use graph::FixedSelector;
pub use graph::ArgmaxSelector;
pub use graph::ThresholdHalt;
pub use graph::LearnedHalt;
pub use graph::Reshape;
pub use graph::StateAdd;
pub use graph::Reduce;
pub use graph::ModelSnapshot;
pub use worker::CpuWorker;

Modules

autograd
Reverse-mode automatic differentiation backed by libtorch.
graph
Computation graph: fluent builder, parallel execution, observation, profiling, and visualization.
monitor
Training monitor with human-readable ETA, resource tracking, and live dashboard.
nn
Neural network modules, losses, optimizers, and training utilities.
tensor
Tensor — immutable, chainable wrapper around a libtorch tensor.
worker
Background CPU work queue.

Macros

modules
Shorthand for building a `Vec<Box<dyn Module>>` from a list of modules. Use it with `split`, `gate`, and `switch` to avoid manual `Box::new()` wrapping.