smelte_rs/lib.rs
#![deny(missing_docs)]
//! # What is smelte-rs ?
//!
//! Smelt is an ML library focusing on inference, with small dependencies and as many
//! optimizations as possible, while staying readable and easy to use.
//!
//! `unsafe` usage is kept limited, and only for performance.
//!
//! # Running models
//!
//! Try running the Bert text classification example.
//!
//! ```bash
//! # Download the model + tokenizer + config
//! # This is a clone of https://huggingface.co/ProsusAI/finbert with safetensors support.
//! curl https://huggingface.co/Narsil/finbert/resolve/main/model.safetensors -o model-Narsil-finbert.safetensors -L
//! curl https://huggingface.co/Narsil/finbert/resolve/main/tokenizer.json -o tokenizer-Narsil-finbert.json -L
//! curl https://huggingface.co/Narsil/finbert/resolve/main/config.json -o config-Narsil-finbert.json -L
//!
//! # Linux
//! cargo run --example bert --release --features intel-mkl -- "This is a test" -n 3
//!
//! # M1
//! cargo run --example bert --release -- "This is a test" -n 3
//! ```
//!
//! # Why not use library X ?
//!
//! There are many other ML libraries out there. Torch and TensorFlow are great, but
//! they are now extremely heavy, with no option to statically link against them.
//! Libraries like ONNX are great too, but when an operator is missing, it's
//! really hard to work around.
//!
//! For low-level libraries, [ggml](https://github.com/ggerganov/ggml) is a great
//! library: no dependencies, extremely small binary size. It's actually an
//! inspiration for this project! But I'm not a good enough C++ programmer to hack on it
//! efficiently. It's also hard to use outside of its intended scope, for
//! instance when writing a webserver/API, or if we wanted to use CUDA as a backend.
//!
//! [dfdx](https://github.com/coreylowman/dfdx) is another super nice project.
//! I drew inspiration from it too. The problem with dfdx was the typing system:
//! while extremely powerful (compile-time size checking), it was getting
//! in the way of getting things done, and optimizing for it is not trivial since
//! it's harder to know what's going on.
//!
//! # The architecture of this library:
//!
//! - [cpu] contains the tensor structs and the backend operations for the various precisions.
//!   This is your go-to if you want to code everything from scratch.
//! - [nn] contains all the basic layers, and the actual model implementations. The code should
//!   closely resemble the torch implementations.
//! - [traits] contains the glue that allows [nn] to be written independently of [cpu],
//!   which should hopefully make using different precisions (or backends) quite easy.
//!
//!
//! # What does a model look like ?
//!
//! ```ignore
//! pub struct BertClassifier<T: Tensor + TensorOps<T>> {
//!     bert: Bert<T>,
//!     pooler: BertPooler<T>,
//!     classifier: Linear<T>,
//! }
//!
//! impl<T: Tensor + TensorOps<T>> BertClassifier<T> {
//!     pub fn new(bert: Bert<T>, pooler: BertPooler<T>, classifier: Linear<T>) -> Self {
//!         Self {
//!             bert,
//!             pooler,
//!             classifier,
//!         }
//!     }
//!     pub fn forward(&self, input_ids: &[usize], type_ids: &[usize]) -> Result<T, SmeltError> {
//!         let tensor = self.bert.forward(input_ids, type_ids)?;
//!         let tensor = self.pooler.forward(&tensor)?;
//!         let mut logits = self.classifier.forward(&tensor)?;
//!         T::softmax(&mut logits)?;
//!         Ok(logits)
//!     }
//! }
//! ```
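//!
//! Calling it then looks roughly like this (a sketch; `model` is assumed to be an
//! already-constructed `BertClassifier`, and the ids are assumed to come from a tokenizer):
//!
//! ```ignore
//! // Tokenize the input with your favorite tokenizer, then run the forward pass.
//! let input_ids: Vec<usize> = vec![101, 2023, 2003, 1037, 3231, 102];
//! let type_ids: Vec<usize> = vec![0; input_ids.len()];
//! let probs = model.forward(&input_ids, &type_ids)?;
//! // `probs` already went through the softmax, so it holds class probabilities.
//! ```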
//!
//! # What's the performance like ?
//!
//! On a relatively old computer (i7-4790 CPU), this gives ~40ms/token for GPT-2
//! in full f32 precision.
//! For comparison, on the same hardware `torch` gives ~47ms/token and ggml ~37ms.
//!
//! The current implementation does *not* use threading, precomputed gelu/exp tables,
//! nor the f16 shortcuts that ggml can use (like for the softmax).
//!
//! So there is still lots of room for improvement, and most of the current performance
//! comes from using the `intel-mkl` library, which can be dropped once this crate implements
//! the various ops from ggml (hopefully to get the full performance).
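//!
//! To illustrate what "precomputed gelu" means, here is a minimal sketch of the kind of
//! table ggml builds (this is not part of this crate; `half::f16` is an assumed helper
//! crate used only for the illustration):
//!
//! ```ignore
//! fn gelu(x: f32) -> f32 {
//!     0.5 * x * (1.0 + ((2.0 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x * x * x)).tanh())
//! }
//!
//! // Tabulate gelu once for every possible f16 bit pattern; the hot loop then
//! // replaces a tanh/exp evaluation with a simple table lookup.
//! let table: Vec<f32> = (0..=u16::MAX)
//!     .map(|bits| gelu(half::f16::from_bits(bits).to_f32()))
//!     .collect();
//! ```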

/// The various CPU implementations
pub mod cpu;

/// The neural networks
pub mod nn;

/// The traits for generic implementations
pub mod traits;

/// Error linked to the tensor creation
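///
/// A minimal sketch of the kind of check that produces this error, assuming a
/// constructor that validates the buffer length against the shape (illustrative,
/// not this crate's exact API):
///
/// ```ignore
/// fn check(buffer: &[f32], shape: &[usize]) -> Result<(), TensorError> {
///     if shape.iter().product::<usize>() != buffer.len() {
///         return Err(TensorError::InvalidBuffer {
///             buffer_size: buffer.len(),
///             shape: shape.to_vec(),
///         });
///     }
///     Ok(())
/// }
/// ```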
#[derive(Debug)]
pub enum TensorError {
    /// The arguments to the tensor creation are invalid, the shape doesn't match
    /// the size of the buffer.
    InvalidBuffer {
        /// The size of the buffer sent
        buffer_size: usize,
        /// The shape of the tensor to create
        shape: Vec<usize>,
    },
}

/// Potential errors when using the library
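///
/// A hedged sketch of how a caller might surface these errors (`model` here is
/// assumed to be something like the `BertClassifier` from the crate docs):
///
/// ```ignore
/// match model.forward(&input_ids, &type_ids) {
///     Ok(logits) => println!("logits: {:?}", logits),
///     Err(SmeltError::DimensionMismatch { expected, got }) => {
///         eprintln!("shape mismatch: expected {expected:?}, got {got:?}");
///     }
///     Err(other) => eprintln!("inference failed: {other:?}"),
/// }
/// ```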
#[derive(Debug)]
pub enum SmeltError {
    /// The operation could not succeed because the shapes are not valid.
    DimensionMismatch {
        /// The shape that we should have seen
        expected: Vec<usize>,
        /// The shape that we received
        got: Vec<usize>,
    },
    /// The tensor given has insufficient rank (rank 2 means a tensor whose shape has length 2)
    InsufficientRank {
        /// The minimum rank that we expect
        minimum_rank: usize,
    },
    /// The tensor given does not have the expected rank (rank 2 means a tensor whose shape has length 2)
    InvalidRank {
        /// The rank that we expect
        expected_rank: usize,
    },
    /// The tensor given does not have enough room for the operations
    VectorTooSmall {
        /// The minimum size that we expect
        minimum: usize,
    },

    /// The select operation attempted to select outside of the tensor
    OutOfVocabulary {
        /// The vocabulary size
        vocab_size: usize,
        /// The culprit id
        id: usize,
    },

    /// Some slices do not have the expected lengths
    InvalidLength {
        /// The size we expected
        expected: usize,
        /// The size we got
        got: usize,
    },
}

#[cfg(test)]
mod tests {
    // Rounds every value to 4 decimal places (m == 10^4) so float comparisons
    // in tests stay stable across platforms.
    pub(crate) fn simplify(data: &[f32]) -> Vec<f32> {
        let precision = 3;
        let m = 10.0 * 10.0f32.powf(precision as f32);
        data.iter().map(|x| (x * m).round() / m).collect()
    }

    // fn assert_float_eq(left: &[f32], right: &[f32]) {
    //     assert_eq!(left.len(), right.len());

    //     left.iter().zip(right.iter()).for_each(|(l, r)| {
    //         assert!(
    //             (l - r).abs() / l.abs() < 1e-4,
    //             "{l} != {r}\n{left:?}\n{right:?}"
    //         );
    //     });
    // }
179}