autograph
A machine learning library for Rust.
To use autograph in your crate, add it as a dependency in Cargo.toml:
[dependencies]
autograph = { git = "https://github.com/charles-r-earp/autograph" }
Requirements
- Rust https://www.rust-lang.org/
- A device (typically a GPU) with drivers for a supported API:
  - Vulkan (all platforms) https://www.vulkan.org/
  - Metal (macOS / iOS) https://developer.apple.com/metal/
  - DX12 (Windows) https://docs.microsoft.com/windows/win32/directx
Tests
- To check that you have a valid device, run:
  cargo test device_new --features device_tests
- Run all the tests with:
  cargo test --features "full device_tests"
Custom Shader Code
You can write your own shaders and execute them with autograph.
// shader/src/lib.rs
// Declare the push constants. Use `#[repr(C)]` to ensure that fields
// are not reordered.
/// Computes `y` = `a` + `b`.
///
/// `threads` can have up to 3 dimensions (x, y, z); it is the size of the `WorkGroup`. Generally
/// this should be a multiple of the hardware-specific execution width. NVIDIA calls this the
/// `warp size`, which is 32 on NVIDIA GPUs; AMD calls it the `wavefront size`, which is
/// generally 64. 64 is a good default. Note that autograph automatically chooses the number of
/// work groups to execute from the global size, so the function submitting the shader does not
/// need to know the work group size.
///
/// # Note
/// autograph does check the size of the push constants, as well as the mutability of buffers. It
/// DOES NOT check their types. For example, a buffer can be declared like `&[u32]` but bound to a
/// `Slice<u8>`.
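The two notes above (use `#[repr(C)]` for push constants; sizes are checked but types are not) can be illustrated with a small host-side sketch in plain Rust. This is not autograph code: `AddPushConstants`, `push_constant_bytes`, `check_size`, and `bytes_as_u32_words` are hypothetical names, and the sketch only assumes that push constants are uploaded as raw bytes and that a shader reading a `&[u32]` buffer sees the bound bytes as 4-byte words.

```rust
use std::mem::{align_of, size_of};

// Hypothetical push constants for an element-wise kernel (not autograph's API).
// `#[repr(C)]` guarantees declared field order and C padding rules, so the host
// layout is predictable: here two 4-byte fields, 8 bytes, no padding.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
struct AddPushConstants {
    n: u32,
    alpha: f32,
}

// Serialize the struct field by field into the bytes that would be uploaded.
fn push_constant_bytes(pc: &AddPushConstants) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(size_of::<AddPushConstants>());
    bytes.extend_from_slice(&pc.n.to_ne_bytes());
    bytes.extend_from_slice(&pc.alpha.to_ne_bytes());
    bytes
}

// A size-only check: it rejects a wrong length but cannot tell u32 from f32.
fn check_size(bytes: &[u8], expected: usize) -> Result<(), String> {
    if bytes.len() == expected {
        Ok(())
    } else {
        Err(format!("expected {} bytes, got {}", expected, bytes.len()))
    }
}

// View a byte buffer as little-endian u32 words, the way a shader declared with
// `&[u32]` would read a buffer that actually holds u8 data.
fn bytes_as_u32_words(bytes: &[u8]) -> Vec<u32> {
    bytes
        .chunks_exact(4)
        .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}

fn main() {
    let pc = AddPushConstants { n: 1024, alpha: 0.5 };
    let bytes = push_constant_bytes(&pc);
    println!(
        "size = {}, align = {}",
        size_of::<AddPushConstants>(),
        align_of::<AddPushConstants>()
    );
    println!("size check: {:?}", check_size(&bytes, size_of::<AddPushConstants>()));

    // Eight u8 values are silently viewed as two u32 values; no type error occurs.
    let data: [u8; 8] = [1, 0, 0, 0, 2, 0, 0, 0];
    println!("{:?}", bytes_as_u32_words(&data));
}
```

The size check passes for any 8-byte payload, which is why a mismatched element type goes unnoticed.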
// main.rs
/// Adds `a` to `b`.
async
See the Hello Compute example.
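The shader's doc comment says autograph chooses the number of work groups from the global size. A common way to do that is ceiling division, sketched below on the CPU in plain Rust. `work_groups` and `add_dispatch` are hypothetical helpers, and autograph's actual dispatch logic may differ; the sketch also shows why a kernel usually bounds-checks its invocation index, since the last work group can extend past the end of the data.

```rust
/// Hypothetical helper: work groups needed to cover `global_size` invocations
/// with `threads` invocations per group (ceiling division, overflow-safe).
fn work_groups(global_size: u32, threads: u32) -> u32 {
    assert!(threads > 0, "work group size must be non-zero");
    global_size / threads + (global_size % threads != 0) as u32
}

/// CPU stand-in for an element-wise `add` kernel: each invocation computes one
/// element, and invocations past the end of the data do nothing.
fn add_dispatch(a: &[u32], b: &[u32], threads: u32) -> Vec<u32> {
    let n = a.len().min(b.len()) as u32;
    let mut y = vec![0u32; n as usize];
    for group in 0..work_groups(n, threads) {
        for local in 0..threads {
            let global_id = group * threads + local;
            // The last group may overshoot `n`, so guard the index.
            if global_id < n {
                let i = global_id as usize;
                // GPU integer addition wraps, so mirror that here.
                y[i] = a[i].wrapping_add(b[i]);
            }
        }
    }
    y
}

fn main() {
    // 100 elements with 64 threads per group: 2 groups (128 invocations).
    println!("work groups: {}", work_groups(100, 64));
    let a: Vec<u32> = (0..100).collect();
    let b = vec![1u32; 100];
    let y = add_dispatch(&a, &b, 64);
    println!("{:?}", &y[..5]);
}
```

Because the group count is derived from the global size, the caller only supplies the number of elements, matching the claim that the submitting function does not need to know the work group size.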
Machine Learning
KMeans
// Create the device.
let device = new?;
// Create the dataset.
let iris = new;
// The flower dimensions are the inputs to the model.
let x_array = iris.dimensions;
// Select only Petal Length + Petal Height
// These are the primary dimensions and it makes plotting easier.
let x_array = x_array.slice;
// Create the KMeans model.
let kmeans = new
.into_device
.await?;
// For small datasets, we can load the entire dataset into the device.
// For larger datasets, the data can be streamed as an iterator.
let x = from
.into_device
// Note that despite the await this will resolve immediately.
// Host -> Device transfers are batched with other operations
// asynchronously on the device thread.
.await?;
// Construct a trainer.
let mut trainer = from;
// Initialize the model (KMeans++).
// Here we provide an iterator of n iterators, such that the trainer can
// visit the data n times. In this case, once for each centroid.
trainer.init?;
// Train the model (1 epoch).
trainer.train?;
// Get the model back.
let kmeans = from;
// Get the trained centroids.
// For multiple reads, batch them by getting the futures first.
let centroids_fut = kmeans.centroids
// The centroids are in a FloatArcTensor, which can be either f32 or bf16.
// This will convert to f32 if necessary.
.?
.read;
// Get the predicted classes.
let pred = kmeans.predict?
.into_dimensionality?
.read
// Here we wait on all previous operations, including centroids_fut.
.await?;
// This will resolve immediately.
let centroids = centroids_fut.await?;
// Get the flower classes from the dataset.
let classes = iris.classes.map;
// Plot the results to "plot.png".
// Note that since KMeans is an unsupervised method, the predicted classes will be arbitrary and
// will not align with the order of the true classes (i.e. the colors won't match in the plot).
plot?;
See the KMeans Iris example.
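For readers who want to see what the initialization, training, and prediction steps above correspond to algorithmically, here is a minimal CPU-only sketch of KMeans++ initialization, one Lloyd's update (one epoch), and nearest-centroid prediction on 2-D points. It is not autograph's implementation: every function name is hypothetical, the data is made up, and a tiny xorshift generator stands in for a real random number generator so the block needs only the standard library. KMeans++ makes one pass over the data for each additional centroid, which is consistent with the comment that the trainer visits the data once per centroid.

```rust
type Point = [f32; 2];

// Squared Euclidean distance between two 2-D points.
fn dist2(a: &Point, b: &Point) -> f32 {
    let (dx, dy) = (a[0] - b[0], a[1] - b[1]);
    dx * dx + dy * dy
}

// Minimal xorshift64 PRNG (illustrative only). The seed must be non-zero.
struct XorShift(u64);

impl XorShift {
    // Returns a float in [0, 1) built from 24 random bits.
    fn next_f32(&mut self) -> f32 {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        (self.0 >> 40) as f32 / (1u64 << 24) as f32
    }
}

// KMeans++: first centroid uniform, each further centroid sampled with probability
// proportional to its squared distance from the nearest chosen centroid.
fn kmeans_pp_init(points: &[Point], k: usize, rng: &mut XorShift) -> Vec<Point> {
    assert!(!points.is_empty() && k >= 1);
    let first = ((rng.next_f32() * points.len() as f32) as usize).min(points.len() - 1);
    let mut centroids = vec![points[first]];
    while centroids.len() < k {
        // One pass over the data per new centroid.
        let d2: Vec<f32> = points
            .iter()
            .map(|p| centroids.iter().map(|c| dist2(p, c)).fold(f32::INFINITY, f32::min))
            .collect();
        let total: f32 = d2.iter().sum();
        let mut target = rng.next_f32() * total;
        // Fallback to the farthest point in case of floating-point round-off.
        let mut chosen = d2.iter().enumerate().fold(0, |best, (i, &d)| if d > d2[best] { i } else { best });
        for (i, &d) in d2.iter().enumerate() {
            if target < d {
                chosen = i;
                break;
            }
            target -= d;
        }
        centroids.push(points[chosen]);
    }
    centroids
}

// Index of the nearest centroid for one point.
fn nearest(p: &Point, centroids: &[Point]) -> usize {
    let mut best = 0;
    for (j, c) in centroids.iter().enumerate() {
        if dist2(p, c) < dist2(p, &centroids[best]) {
            best = j;
        }
    }
    best
}

// One Lloyd's iteration: assign each point to its nearest centroid, then move
// each centroid to the mean of its points. Empty clusters keep their centroid.
fn lloyd_step(points: &[Point], centroids: &mut [Point]) {
    let k = centroids.len();
    let mut sums = vec![[0.0f32; 2]; k];
    let mut counts = vec![0usize; k];
    for p in points {
        let j = nearest(p, centroids);
        sums[j][0] += p[0];
        sums[j][1] += p[1];
        counts[j] += 1;
    }
    for j in 0..k {
        if counts[j] > 0 {
            centroids[j] = [sums[j][0] / counts[j] as f32, sums[j][1] / counts[j] as f32];
        }
    }
}

fn predict(points: &[Point], centroids: &[Point]) -> Vec<usize> {
    points.iter().map(|p| nearest(p, centroids)).collect()
}

fn main() {
    // Two made-up, well-separated blobs (not the Iris dataset).
    let points: Vec<Point> = vec![
        [1.0, 0.2], [1.1, 0.3], [0.9, 0.2], [1.0, 0.1],
        [5.0, 1.8], [5.2, 2.0], [4.9, 1.9], [5.1, 2.1],
    ];
    let mut rng = XorShift(0x9E37_79B9_7F4A_7C15);
    let mut centroids = kmeans_pp_init(&points, 2, &mut rng);
    lloyd_step(&points, &mut centroids);
    println!("centroids = {:?}", centroids);
    println!("pred = {:?}", predict(&points, &centroids));
}
```

As in the example above, the cluster indices produced by `predict` are arbitrary labels and need not match any ground-truth class order.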
Neural Networks
See the Neural Network MNIST example.
Development Platforms
- Ubuntu 18.04 | (Vulkan) NVidia GeForce GTX 1060 with Max-Q Design
- Windows 10 Home | (Vulkan + DX12) AMD RX 580 / (DX12) Microsoft Basic Render Driver.
Shaders are tested on GitHub Actions:
- Windows Server 2019 | (DX12) Microsoft Basic Render Driver.
Metal
Shaders are untested on Metal / Apple platforms. If you have problems, please create an issue!
License
Dual-licensed to be compatible with the Rust project.
Licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 or the MIT license http://opensource.org/licenses/MIT, at your option. This file may not be copied, modified, or distributed except according to those terms.
Contribution
Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions.