//! Differentiable operations and tensors backed by [ndarray](https://github.com/rust-ndarray/ndarray).
//!
//! ## Motivation
//! Machine learning is one of the fields where Rust lags behind other languages.
//! The aim of this crate is to show that Rust has the capability to implement efficient and full-featured dataflow graphs naturally.
//! Moreover, the core of this crate is quite small compared to others (thanks to being implemented in pure Rust on top of ndarray),
//! so it can also serve as a reasonable starting point for those who are not familiar with how this kind of library works.
//!
//! ## Features
//! ### Lazy, lightweight tensor evaluation
//! Computation graphs are created on the fly (a.k.a. *define-by-run*), but are not evaluated until `eval` is called.
//! This mechanism strikes a balance between performance and flexibility.
//!
//! ```rust
//! use autograd as ag;
//!
//! ag::with(|g: &mut ag::Graph<_>| {
//!     let a: ag::Tensor<f32> = g.ones(&[60]);
//!     let b: ag::Tensor<f32> = g.ones(&[24]);
//!     let c: ag::Tensor<f32> = g.reshape(a, &[3, 4, 5]);
//!     let d: ag::Tensor<f32> = g.reshape(b, &[4, 3, 2]);
//!     let e: ag::Tensor<f32> = g.tensordot(c, d, &[1, 0], &[0, 1]);
//!     let result: ag::ndarray::Array<_, _> = e.eval(&[]).unwrap();  // Getting `ndarray::Array` here.
//! });
//! ```
//!
//! ### Reverse-mode automatic differentiation
//! There are many [built-in operations](https://docs.rs/autograd/1.0.0/autograd/struct.Graph.html)
//! that support *higher-order* derivatives, and
//! you can also easily [define your own differentiable ops](https://docs.rs/autograd/1.0.0/autograd/op/trait.Op.html) using ndarrays.
//!
//! Here we simply compute the partial derivatives of `z = 2x^2 + 3y + 1`.
//!
//! ```rust
//! use autograd as ag;
//!
//! # fn main() {
//! ag::with(|g: &mut ag::Graph<_>| {
//!     let x = g.placeholder(&[]);
//!     let y = g.placeholder(&[]);
//!     let z = 2.*x*x + 3.*y + 1.;
//!
//!     // dz/dy
//!     let gy = &g.grad(&[z], &[y])[0];
//!     println!("{:?}", gy.eval(&[]));   // => Ok(3.)
//!
//!     // dz/dx (requires filling the placeholder `x`)
//!     let gx = &g.grad(&[z], &[x])[0];
//!     let feed = ag::ndarray::arr0(2.);
//!     println!("{:?}", gx.eval(&[x.given(feed.view())]));  // => Ok(8.)
//!
//!     // ddz/dx (differentiates `z` again)
//!     let ggx = &g.grad(&[gx], &[x])[0];
//!     println!("{:?}", ggx.eval(&[]));  // => Ok(4.)
//! });
//! # }
//! ```
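//!
//! `grad` also accepts multiple target variables in one call. The sketch below reuses the
//! same `z = 2x^2 + 3y + 1` and, as in the training example further down, assumes that the
//! returned gradients are ordered like the variables passed in:
//!
//! ```rust
//! use autograd as ag;
//!
//! ag::with(|g: &mut ag::Graph<_>| {
//!     let x = g.placeholder(&[]);
//!     let y = g.placeholder(&[]);
//!     let z = 2.*x*x + 3.*y + 1.;
//!
//!     // dz/dx and dz/dy, computed together
//!     let grads = g.grad(&[z], &[x, y]);
//!     let feed = ag::ndarray::arr0(2.);
//!     println!("{:?}", grads[0].eval(&[x.given(feed.view())]));  // dz/dx at x=2 => Ok(8.)
//!     println!("{:?}", grads[1].eval(&[]));                      // dz/dy => Ok(3.)
//! });
//! ```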
//!
//! ### Neural networks
//! This crate has various low-level features inspired by TensorFlow/Theano for training neural networks.
//! Since computation graphs require only a bare minimum of heap allocations, the overhead is small, even for complex networks.
//! ```rust
//! // This is a softmax regression for MNIST digits classification with Adam.
//! // This achieves 0.918 test accuracy after 3 epochs (0.11 sec/epoch on 2.7GHz Intel Core i5).
//! use autograd::{self as ag, Graph, optimizers::adam, ndarray_ext as arr, tensor::Variable};
//!
//! let rng = ag::ndarray_ext::ArrayRng::<f32>::default();
//! let w_arr = arr::into_shared(rng.glorot_uniform(&[28 * 28, 10]));
//! let b_arr = arr::into_shared(arr::zeros(&[1, 10]));
//! let adam_state = adam::AdamState::new(&[&w_arr, &b_arr]);
//!
//! let max_epoch = 3;
//!
//! for epoch in 0..max_epoch {
//!     ag::with(|g| {
//!         let w = g.variable(w_arr.clone());
//!         let b = g.variable(b_arr.clone());
//!         let x = g.placeholder(&[-1, 28*28]);
//!         let y = g.placeholder(&[-1]);
//!         let z = g.matmul(x, w) + b;
//!         let mean_loss = g.reduce_mean(g.sparse_softmax_cross_entropy(z, &y), &[0], false);
//!         let grads = &g.grad(&[&mean_loss], &[w, b]);
//!         let update_ops: &[ag::Tensor<f32>] =
//!             &adam::Adam::default().compute_updates(&[w, b], grads, &adam_state, g);
//!
//!         // Mini-batch training loop, left commented out because `x_train`, `y_train`
//!         // and `get_permutation` are not defined in this doc example:
//!         // let batch_size = 200isize;
//!         // let num_samples = x_train.shape()[0];
//!         // let num_batches = num_samples / batch_size as usize;
//!         // for i in get_permutation(num_batches) {
//!         //     let i = i as isize * batch_size;
//!         //     let x_batch = x_train.slice(s![i..i + batch_size, ..]).into_dyn();
//!         //     let y_batch = y_train.slice(s![i..i + batch_size, ..]).into_dyn();
//!         //     g.eval(update_ops, &[x.given(x_batch), y.given(y_batch)]);
//!         // }
//!     });
//! }
//! ```
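//!
//! After training, the same graph definition can be rebuilt for inference. The sketch
//! below is hypothetical (hence not compiled as a doctest): it reuses `w_arr`/`b_arr` from
//! the example above and assumes an `x_test` array of shape `[n, 28*28]`:
//!
//! ```rust,ignore
//! ag::with(|g| {
//!     let w = g.variable(w_arr.clone());
//!     let b = g.variable(b_arr.clone());
//!     let x = g.placeholder(&[-1, 28 * 28]);
//!     // Per-class scores for each test sample.
//!     let logits = g.matmul(x, w) + b;
//!     let scores = logits.eval(&[x.given(x_test.view())]);
//! });
//! ```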
//!
//! ### Hooks
//! You can register hooks on `ag::Tensor` objects for debugging.
//!
//! ```rust
//! use autograd as ag;
//!
//! ag::with(|g| {
//!     let a: ag::Tensor<f32> = g.zeros(&[4, 2]).show();
//!     let b: ag::Tensor<f32> = g.ones(&[2, 3]).show_shape();
//!     let c = g.matmul(a, b).show_with("MatMul:");
//!
//!     c.eval(&[]);
//!     // [[0.0, 0.0],
//!     //  [0.0, 0.0],
//!     //  [0.0, 0.0],
//!     //  [0.0, 0.0]] shape=[4, 2], strides=[2, 1], layout=C (0x1)
//!     //
//!     // [2, 3]
//!     //
//!     // MatMul:
//!     //  [[0.0, 0.0, 0.0],
//!     //  [0.0, 0.0, 0.0],
//!     //  [0.0, 0.0, 0.0],
//!     //  [0.0, 0.0, 0.0]] shape=[4, 3], strides=[3, 1], layout=C (0x1), dynamic ndim=2
//! });
//! ```
//!

#[allow(unused_imports)]
// Expose to prevent version conflict
#[macro_use(s)]
/// re-exported for convenience and version-compatibility
pub extern crate ndarray;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
extern crate libc;
#[cfg(not(feature = "mkl"))]
extern crate matrixmultiply;
extern crate num;
extern crate num_traits;
/// re-exported for convenience and version-compatibility
pub extern crate rand;
extern crate rand_distr;
extern crate rayon;
extern crate rustc_hash;
pub(crate) extern crate smallvec;

mod gradient;
pub(crate) mod graph;
mod hook;
pub mod ndarray_ext;
pub mod op;
mod ops;
pub mod optimizers;
mod runtime;
pub mod tensor;
pub mod test_helper;

use rustc_hash::FxHasher;
use std::any::TypeId;
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::hash::BuildHasherDefault;

pub(crate) type FxHashMap<K, V> = HashMap<K, V, BuildHasherDefault<FxHasher>>;
pub(crate) type FxHashSet<K> = HashSet<K, BuildHasherDefault<FxHasher>>;

/// Primitive type in this crate, which is actually a decorated `num_traits::Float`.
pub trait Float:
    num_traits::Float
    + num_traits::NumAssignOps
    + Copy
    + Send
    + Sync
    + fmt::Display
    + fmt::Debug
    + Sized
    + 'static
{
}

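// The `Float` bound makes it easy to write numeric helpers that work uniformly for
// `f32` and `f64`. A minimal, hypothetical sketch (not part of the public API):
#[allow(dead_code)]
fn squared_l2_norm<F: Float>(xs: &[F]) -> F {
    // `F::zero()` and the arithmetic operators come from the `num_traits` super-traits.
    xs.iter().fold(F::zero(), |acc, &x| acc + x * x)
}
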
#[doc(hidden)]
/// Internal trait.
pub trait Int:
    num::Integer
    + num_traits::NumAssignOps
    + num_traits::ToPrimitive
    + Copy
    + Send
    + fmt::Display
    + Sized
    + 'static
{
}

impl<T> Float for T where
    T: num::Float
        + num_traits::NumAssignOps
        + Copy
        + Send
        + Sync
        + fmt::Display
        + fmt::Debug
        + Sized
        + 'static
{
}

impl<T> Int for T where
    T: num::Integer
        + num_traits::NumAssignOps
        + num_traits::ToPrimitive
        + Copy
        + Send
        + Sync
        + fmt::Display
        + Sized
        + 'static
{
}

#[inline(always)]
/// Returns `true` if `A` and `B` are the same type.
pub(crate) fn same_type<A: 'static, B: 'static>() -> bool {
    TypeId::of::<A>() == TypeId::of::<B>()
}

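// `same_type` is typically used for static dispatch inside numeric kernels, e.g. to pick
// an f32/f64-specialized code path over a generic fallback. A hypothetical illustration:
#[allow(dead_code)]
fn kernel_kind<A: 'static>() -> &'static str {
    if same_type::<A, f32>() {
        "f32-specialized"
    } else if same_type::<A, f64>() {
        "f64-specialized"
    } else {
        "generic fallback"
    }
}
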
pub use crate::ndarray_ext::array_gen;

pub use crate::ndarray_ext::{NdArray, NdArrayView, NdArrayViewMut};

pub use crate::runtime::{Eval, Feed};

pub use crate::tensor::Tensor;

pub(crate) use crate::ndarray_ext::ArrRepr;

pub use crate::graph::{run, with, Graph};

/// Error during tensor's evaluation.
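///
/// A minimal sketch of handling this error at a call site, using only the `zeros`/`eval`
/// API shown in the crate-level examples:
///
/// ```rust
/// use autograd as ag;
///
/// ag::with(|g: &mut ag::Graph<f32>| {
///     match g.zeros(&[2, 2]).eval(&[]) {
///         Ok(arr) => assert_eq!(arr.shape(), &[2, 2]),
///         Err(ag::EvalError::OpError(e)) => eprintln!("op failed: {:?}", e),
///         Err(ag::EvalError::Empty) => eprintln!("no value available (e.g. an in-place op)"),
///     }
/// });
/// ```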
#[derive(Debug, PartialEq)]
pub enum EvalError {
    /// Error during an `Op`'s computation.
    OpError(op::OpError),
    /// The tensor's value is empty.
    ///
    /// For example, the computed results of in-place ops (e.g. optimizers) are not available
    /// and are represented as `Empty`.
    Empty,
}

impl std::error::Error for EvalError {}

impl fmt::Display for EvalError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            EvalError::OpError(e) => e.fmt(f),
            EvalError::Empty => write!(f, "Empty return value from a stateful op"),
        }
    }
}