autograd/lib.rs
//! Differentiable operations and tensors backed by [ndarray](https://github.com/rust-ndarray/ndarray).
//!
//! ## Motivation
//! Machine learning is one of the fields where Rust lags behind other languages.
//! The aim of this crate is to show that Rust can naturally express an efficient and full-featured dataflow graph.
//! Moreover, the core of this crate is quite small compared to others (since it is implemented in pure Rust on top of ndarray),
//! so it may be a good starting point for those who are not familiar with how this kind of library works.
//!
//! ## Features
//! ### Lazy, lightweight tensor evaluation
//! Computation graphs are created on the fly (a.k.a. *define-by-run*), but are not evaluated until `eval` is called.
//! This mechanism strikes a balance between performance and flexibility.
//!
//! ```rust
//! use autograd as ag;
//!
//! ag::with(|g: &mut ag::Graph<_>| {
//!     let a: ag::Tensor<f32> = g.ones(&[60]);
//!     let b: ag::Tensor<f32> = g.ones(&[24]);
//!     let c: ag::Tensor<f32> = g.reshape(a, &[3, 4, 5]);
//!     let d: ag::Tensor<f32> = g.reshape(b, &[4, 3, 2]);
//!     let e: ag::Tensor<f32> = g.tensordot(c, d, &[1, 0], &[0, 1]);
//!     let result: ag::ndarray::Array<_, _> = e.eval(&[]).unwrap(); // Getting `ndarray::Array` here.
//! });
//! ```
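//!
//! Multiple tensors can also be evaluated in a single pass with `Graph::eval`.
//! A minimal sketch, reusing only calls that appear elsewhere in these docs
//! (`ones`, `zeros`, and the `Graph::eval` call from the training-loop example below):
//!
//! ```rust
//! use autograd as ag;
//!
//! ag::with(|g: &mut ag::Graph<f32>| {
//!     let a = g.ones(&[3]);
//!     let b = g.zeros(&[3]);
//!     // Both tensors are evaluated together; results come back in the same order.
//!     let results = g.eval(&[a, b], &[]);
//!     println!("{:?}", results);
//! });
//! ```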
//!
//! ### Reverse-mode automatic differentiation
//! There are a lot of [built-in operations](https://docs.rs/autograd/1.0.0/autograd/struct.Graph.html)
//! that support *higher-order* derivatives, and
//! you can also [define your own differentiable ops](https://docs.rs/autograd/1.0.0/autograd/op/trait.Op.html) with ndarrays easily.
//!
//! Here we are just computing partial derivatives of `z = 2x^2 + 3y + 1`.
//!
//! ```rust
//! use autograd as ag;
//!
//! # fn main() {
//! ag::with(|g: &mut ag::Graph<_>| {
//!     let x = g.placeholder(&[]);
//!     let y = g.placeholder(&[]);
//!     let z = 2.*x*x + 3.*y + 1.;
//!
//!     // dz/dy
//!     let gy = &g.grad(&[z], &[y])[0];
//!     println!("{:?}", gy.eval(&[])); // => Ok(3.)
//!
//!     // dz/dx (requires the placeholder `x` to be filled)
//!     let gx = &g.grad(&[z], &[x])[0];
//!     let feed = ag::ndarray::arr0(2.);
//!     println!("{:?}", gx.eval(&[x.given(feed.view())])); // => Ok(8.)
//!     // d²z/dx² (differentiates `z` again)
//!     let ggx = &g.grad(&[gx], &[x])[0];
//!     println!("{:?}", ggx.eval(&[])); // => Ok(4.)
//! });
//! # }
//! ```
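//!
//! Gradients are not limited to scalar variables; the same `grad` call works for
//! tensor-valued inputs. A minimal sketch composed only of ops shown in these docs
//! (`placeholder`, `reduce_mean`, `grad`, `given`):
//!
//! ```rust
//! use autograd as ag;
//!
//! ag::with(|g: &mut ag::Graph<f32>| {
//!     let x = g.placeholder(&[3]);
//!     // y = mean(x * x), so dy/dx = 2x / 3.
//!     let y = g.reduce_mean(x * x, &[0], false);
//!     let gx = &g.grad(&[y], &[x])[0];
//!     let feed = ag::ndarray::arr1(&[1.0f32, 2.0, 3.0]);
//!     println!("{:?}", gx.eval(&[x.given(feed.view())]));
//! });
//! ```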
//!
//! ### Neural networks
//! This crate has various low-level features inspired by TensorFlow/Theano for training neural networks.
//! Since computation graphs require only a bare minimum of heap allocations, the overhead is small, even for complex networks.
//! ```rust
//! // This is a softmax regression for MNIST digits classification with Adam.
//! // This achieves 0.918 test accuracy after 3 epochs (0.11 sec/epoch on 2.7GHz Intel Core i5).
//! use autograd::{self as ag, Graph, optimizers::adam, ndarray_ext as arr, tensor::Variable};
//!
//! let rng = ag::ndarray_ext::ArrayRng::<f32>::default();
//! let w_arr = arr::into_shared(rng.glorot_uniform(&[28 * 28, 10]));
//! let b_arr = arr::into_shared(arr::zeros(&[1, 10]));
//! let adam_state = adam::AdamState::new(&[&w_arr, &b_arr]);
//!
//! let max_epoch = 3;
//!
//! for epoch in 0..max_epoch {
//!     ag::with(|g| {
//!         let w = g.variable(w_arr.clone());
//!         let b = g.variable(b_arr.clone());
//!         let x = g.placeholder(&[-1, 28*28]);
//!         let y = g.placeholder(&[-1]);
//!         let z = g.matmul(x, w) + b;
//!         let mean_loss = g.reduce_mean(g.sparse_softmax_cross_entropy(z, &y), &[0], false);
//!         let grads = &g.grad(&[&mean_loss], &[w, b]);
//!         let update_ops: &[ag::Tensor<f32>] =
//!             &adam::Adam::default().compute_updates(&[w, b], grads, &adam_state, g);
//!
//!         // let batch_size = 200isize;
//!         // let num_samples = x_train.shape()[0];
//!         // let num_batches = num_samples / batch_size as usize;
//!         // for i in get_permutation(num_batches) {
//!         //     let i = i as isize * batch_size;
//!         //     let x_batch = x_train.slice(s![i..i + batch_size, ..]).into_dyn();
//!         //     let y_batch = y_train.slice(s![i..i + batch_size, ..]).into_dyn();
//!         //     g.eval(update_ops, &[x.given(x_batch), y.given(y_batch)]);
//!         // }
//!     });
//! }
//! ```
//!
//! ### Hooks
//! You can register hooks on `ag::Tensor` objects for debugging.
//!
//! ```rust
//! use autograd as ag;
//!
//! ag::with(|g| {
//!     let a: ag::Tensor<f32> = g.zeros(&[4, 2]).show();
//!     let b: ag::Tensor<f32> = g.ones(&[2, 3]).show_shape();
//!     let c = g.matmul(a, b).show_with("MatMul:");
//!
//!     c.eval(&[]);
//!     // [[0.0, 0.0],
//!     //  [0.0, 0.0],
//!     //  [0.0, 0.0],
//!     //  [0.0, 0.0]] shape=[4, 2], strides=[2, 1], layout=C (0x1)
//!     //
//!     // [2, 3]
//!     //
//!     // MatMul:
//!     // [[0.0, 0.0, 0.0],
//!     //  [0.0, 0.0, 0.0],
//!     //  [0.0, 0.0, 0.0],
//!     //  [0.0, 0.0, 0.0]] shape=[4, 3], strides=[3, 1], layout=C (0x1), dynamic ndim=2
//! });
//! ```
//!

#[allow(unused_imports)]
// Expose to prevent version conflict
#[macro_use(s)]
/// re-exported for convenience and version-compatibility
pub extern crate ndarray;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
extern crate libc;
#[cfg(not(feature = "mkl"))]
extern crate matrixmultiply;
extern crate num;
extern crate num_traits;
/// re-exported for convenience and version-compatibility
pub extern crate rand;
extern crate rand_distr;
extern crate rayon;
extern crate rustc_hash;
pub(crate) extern crate smallvec;

mod gradient;
pub(crate) mod graph;
mod hook;
pub mod ndarray_ext;
pub mod op;
mod ops;
pub mod optimizers;
mod runtime;
pub mod tensor;
pub mod test_helper;

use rustc_hash::FxHasher;
use std::any::TypeId;
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::hash::BuildHasherDefault;

pub(crate) type FxHashMap<K, V> = HashMap<K, V, BuildHasherDefault<FxHasher>>;
pub(crate) type FxHashSet<K> = HashSet<K, BuildHasherDefault<FxHasher>>;

/// The primitive floating-point type in this crate; essentially a decorated `num_traits::Float`.
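///
/// For example, numerical helpers can be written once against this trait and used with
/// both `f32` and `f64`. A minimal sketch (the `squared_l2` helper below is only an
/// illustration, not part of this crate):
///
/// ```rust
/// use autograd as ag;
///
/// fn squared_l2<F: ag::Float>(xs: &[F]) -> F {
///     xs.iter().fold(F::zero(), |acc, &x| acc + x * x)
/// }
///
/// assert_eq!(squared_l2(&[1.0f32, 2.0]), 5.0);
/// assert_eq!(squared_l2(&[1.0f64, 2.0]), 5.0);
/// ```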
pub trait Float:
    num_traits::Float
    + num_traits::NumAssignOps
    + Copy
    + Send
    + Sync
    + fmt::Display
    + fmt::Debug
    + Sized
    + 'static
{
}

#[doc(hidden)]
/// Internal trait.
pub trait Int:
    num::Integer
    + num_traits::NumAssignOps
    + num_traits::ToPrimitive
    + Copy
    + Send
    + fmt::Display
    + Sized
    + 'static
{
}

impl<T> Float for T where
    T: num::Float
        + num_traits::NumAssignOps
        + Copy
        + Send
        + Sync
        + fmt::Display
        + fmt::Debug
        + Sized
        + 'static
{
}

impl<T> Int for T where
    T: num::Integer
        + num_traits::NumAssignOps
        + num_traits::ToPrimitive
        + Copy
        + Send
        + Sync
        + fmt::Display
        + Sized
        + 'static
{
}

#[inline(always)]
/// Return `true` if `A` and `B` are the same type
pub(crate) fn same_type<A: 'static, B: 'static>() -> bool {
    TypeId::of::<A>() == TypeId::of::<B>()
}

pub use crate::ndarray_ext::array_gen;

pub use crate::ndarray_ext::{NdArray, NdArrayView, NdArrayViewMut};

pub use crate::runtime::{Eval, Feed};

pub use crate::tensor::Tensor;

pub(crate) use crate::ndarray_ext::ArrRepr;

pub use crate::graph::{run, with, Graph};

/// Error during a tensor's evaluation.
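///
/// A minimal sketch of handling evaluation results (it reuses only the `with`, `zeros`
/// and `eval` calls shown in the crate-level docs):
///
/// ```rust
/// use autograd as ag;
///
/// ag::with(|g: &mut ag::Graph<f32>| {
///     let a = g.zeros(&[2, 2]);
///     match a.eval(&[]) {
///         Ok(arr) => println!("{:?}", arr),
///         Err(ag::EvalError::Empty) => eprintln!("no value was produced (e.g. an in-place op)"),
///         Err(e) => eprintln!("evaluation failed: {}", e),
///     }
/// });
/// ```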
#[derive(Debug, PartialEq)]
pub enum EvalError {
    /// Error during an `Op`'s computation.
    OpError(op::OpError),
    /// The tensor's value is empty.
    ///
    /// For example, the compute results of in-place ops (e.g. optimizers) are not available
    /// and are represented as `Empty`.
    Empty,
}

impl std::error::Error for EvalError {}

impl fmt::Display for EvalError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            EvalError::OpError(e) => e.fmt(f),
            EvalError::Empty => write!(f, "Empty return value from a stateful op"),
        }
    }
}