zyx/
lib.rs

1#![doc = include_str!("../README.md")]
2#![forbid(rustdoc::broken_intra_doc_links)]
3#![forbid(rustdoc::private_intra_doc_links)]
4#![forbid(missing_docs)]
5#![forbid(rustdoc::missing_crate_level_docs)]
6#![forbid(rustdoc::private_doc_tests)]
7#![forbid(rustdoc::invalid_codeblock_attributes)]
8#![forbid(rustdoc::invalid_html_tags)]
9#![forbid(rustdoc::invalid_rust_codeblocks)]
10#![forbid(rustdoc::bare_urls)]
11#![forbid(rustdoc::unescaped_backticks)]
12#![forbid(rustdoc::redundant_explicit_links)]
13
14use std::{fs::File, path::Path};
15
16use crate::runtime::Runtime;
17
18mod dtype;
19mod index_map;
20mod mutex;
21#[cfg(feature = "py")]
22mod python_bindings;
23mod runtime;
24mod scalar;
25mod shape;
26mod tensor;
27
28pub use dtype::DType;
29pub use runtime::DeviceConfig;
30pub use runtime::ZyxError;
31pub use scalar::Scalar;
32pub use shape::IntoShape;
33pub use tensor::Tensor;
34
35// Works, but rust does not call drop on this when exiting the program, which causes all sorts of problems ...
36static RT: mutex::Mutex<Runtime, 1000000000> = mutex::Mutex::new(Runtime::new());
37//static RT: mutex::Mutex<Runtime> = mutex::Mutex::new(Runtime::new());
38
39/// Save tensors or modules
40pub trait TensorSave {
41    /// Save tensors or modules
42    fn save(self, path: impl AsRef<Path>) -> Result<(), ZyxError>;
43}
44
45impl<'a, I: IntoIterator<Item = &'a Tensor>> TensorSave for I {
46    fn save(self, path: impl AsRef<Path>) -> Result<(), ZyxError> {
47        use std::fmt::Write;
48        use std::io::Write as IOWrite;
49        let mut f = File::create(path)?;
50        let mut header = String::from("{");
51        let mut begin = 0;
52        let tensors: Vec<&Tensor> = self.into_iter().collect();
53        for tensor in &tensors {
54            let dtype = tensor.dtype();
55            //if let Some(label) = tensor.label() {
56            //write!(header, "\"{label}\":{{").unwrap();
57            //} else {
58            write!(header, "\"{}\":{{", tensor.id()).unwrap();
59            //}
60            write!(header, "\"dtype\":\"{}\",", dtype.safetensors()).unwrap();
61            let mut st_shape = format!("{:?}", tensor.shape());
62            st_shape.retain(|c| !c.is_whitespace());
63            write!(header, "\"shape\":{},", st_shape).unwrap();
64            let size = tensor.numel() * dtype.byte_size();
65            write!(header, "\"data_offsets\":[{},{}]", begin, begin + size).unwrap();
66            begin += size;
67            write!(header, "}},").unwrap();
68        }
69        header.pop();
70        write!(header, "}}").unwrap();
71        let header_bytes = header.as_bytes();
72        f.write_all(&(header_bytes.len() as u64).to_le_bytes())?;
73        f.write_all(header_bytes)?;
74        for tensor in tensors {
75            f.write_all(&tensor.to_le_bytes()?)?;
76        }
77        Ok(())
78    }
79}
80
81/*
82#[test]
83fn t0() {
84    let x = Tensor::from([[2, 3], [4, 5]]);
85    println!("{x}");
86    //assert_eq!(x, [[2, 3], [4, 5]]);
87}
88
89// Unary test
90#[test]
91fn t1() {
92    let x = Tensor::from([[2f32, 3.], [4., 5.]]).exp();
93    println!("{x}");
94    //assert_eq!(x, [[2, 3], [4, 5]]);
95}
96
97#[test]
98fn t2() {
99    //let x = Tensor::randn([2, 2], DType::F32).reshape(256).exp().expand([256, 4]);
100    let x = Tensor::from([[[2f32, 3.]], [[4., 5.]]])
101        .expand([2, 3, 2])
102        .exp()
103        .ln()
104        .reshape([2, 3, 2, 1]);
105    //let x = Tensor::from([[[[2f32], [3.]]], [[[4.], [5.]]]]).expand([2, 3, 2, 1]);
106    //println!("{x}");
107    let y = Tensor::from([[2f32, 3., 1.], [4., 3., 2.]])
108        .reshape([2, 3, 1, 1])
109        .expand([2, 3, 2, 1]);
110    //println!("{y}");
111    let z = (&x + &y).expand([2, 3, 2, 2]).sum([3, 0]);
112    let z = z.exp().ln().permute([1, 0]).sum(0);
113    //Tensor::plot_dot_graph([&x, &y, &z], "graph0");
114    //Tensor::realize([&x, &y, &z]);
115    //println!("{x}\n{y}\n{z}");
116    println!("{z}");
117
118    //let l0 = zyx_nn::Linear::new(1024, 1024, DType::F16);
119}
120
121#[cfg(feature = "rand")]
122#[test]
123#[should_panic]
124fn t3() {
125    let x = Tensor::randn([1024, 1024], DType::F32).expand([1024, 1024, 1024]);
126    Tensor::realize([&x]).unwrap();
127}
128
129#[cfg(feature = "rand")]
130#[test]
131fn t4() {
132    let x = Tensor::uniform([1024, 1024], 0f32..1f32);
133    let y = Tensor::uniform([1024, 1024], 0f32..1f32);
134    //let z = (x * y).sum(2);
135    for _ in 0..20 {
136        let z = x.dot(&y);
137        Tensor::realize([&z]).unwrap();
138        drop(z);
139        //Tensor::plot_graph([], &format!("graph{i}"));
140    }
141    //Tensor::plot_graph([], "graph0");
142    //Tensor::realize([&z]).unwrap();
143}
144
145#[test]
146fn t5() {
147    let x = Tensor::from([[2f32, 3.], [4., 5.]]);
148    let y = x.t();
149    let z = x.exp();
150    //Tensor::plot_dot_graph([&y, &z], "graph1");
151    Tensor::realize([&y, &z]).unwrap();
152    println!("{y}\n{z}");
153}
154
155#[cfg(feature = "rand")]
156#[test]
157fn t6() {
158    //let x = Tensor::from([[2, 3], [4, 5]]).pad_zeros([(1, 3)]);
159
160    let x = Tensor::randn([14, 16], DType::U8);
161    let x = x.get((.., 8..-2));
162    println!("{x}");
163}
164
165#[test]
166fn t7() {
167    let x = Tensor::from([[2, 3], [4, 5]]);
168    //let x = x.pad_zeros([(0, 1)]);
169    let x = x.pad_zeros([(4, 3), (1, 2)]);
170    //Tensor::plot_dot_graph([], "graph0");
171    println!("{x}")
172}
173
174#[test]
175fn t8() {
176    let x = Tensor::ones([2, 3], DType::F32);
177    println!("{x}");
178}
179
180#[test]
181fn t9() {
182    let mut x = Tensor::ones([1024, 1024], DType::F32);
183    let y = Tensor::ones([1024, 1024], DType::F32);
184    for _ in 0..10 {
185        x = x.dot(&y);
186    }
187    println!("{x}");
188}
189
190#[test]
191fn t_10() {
192    let x = Tensor::eye(8, DType::I32);
193    println!("{x}");
194}
195
196#[test]
197fn t_11() {
198    let x = Tensor::from([[2, 3, 1], [3, 4, 1]]);
199    let y = Tensor::from([[2, 3], [2, 1], [4, 1]]);
200    let x = x.dot(y);
201    //let x = x.reshape([2, 1, 3]) * y.t().reshape([1, 2, 3]);
202    //let x = x.sum(2);
203    println!("{x}");
204}
205
206#[test]
207fn t_12() {
208    let mut x = Tensor::from([2, 3, 1]);
209    let w = Tensor::from([[2, 3, 2], [2, 1, 1], [4, 1, 4]]);
210    let b = Tensor::from([2, 3, 5]);
211    for _ in 0..10 {
212        x = x.dot(&w) + &b;
213        Tensor::realize([&x]).unwrap();
214    }
215    println!("{x}");
216}
217
218#[test]
219fn t_14() {
220    let mut x = Tensor::from([[2, 3, 1], [2, 4, 1]]);
221    x = x.repeat([2, 4, 1]);
222    println!("{x}");
223}
224
225#[test]
226fn t_15() {
227    let mut x = Tensor::from([[2, 3, 1], [2, 4, 1]]);
228    for _ in 0..10 {
229        x = &x + &x;
230        println!("{x}");
231        //Tensor::plot_graph([], &format!("graph{i}"));
232        Tensor::realize([&x]).unwrap();
233    }
234    println!("{x}");
235}
236
237#[test]
238fn t_16() {
239    let mut x = Tensor::from([[2, 3, 1], [2, 4, 1]]);
240    let y = Tensor::from([[5, 6, 9], [4, 2, 0]]);
241    let _z = x.exp2() + &y;
242    x = -x * &y;
243    Tensor::plot_graph([], "graph0");
244    Tensor::realize([&x]).unwrap();
245    Tensor::plot_graph([], "graph1");
246}
247
248#[test]
249fn t_17() {
250    let mut x = Tensor::from([[2, 3, 1], [2, 4, 1]]);
251    println!("{x}");
252    x = x.sum([]);
253    println!("{x}");
254}
255
256#[test]
257fn t_18() {
258    let mut x = Tensor::from([[2, 3, 1], [2, 4, 1]]);
259    let y = Tensor::from([[2, 3], [1, 2], [4, 1]]);
260    x = x.dot(y).pad_zeros([(2, 1)]);
261    println!("{x}");
262}
263*/
264
265/*#[test]
266fn t_15() {
267    let mut x = Tensor::from([[2, 3, 1], [2, 4, 1]]);
268    for _ in 0..10 {
269        x = &x + &x;
270        //println!("{x}");
271        //Tensor::plot_graph([], &format!("graph{i}"));
272        Tensor::realize([&x]).unwrap();
273    }
274    println!("{x}");
275}
276
277#[test]
278fn t_12() {
279    let mut x = Tensor::from([2, 3, 1]);
280    let w = Tensor::from([[2, 3, 2], [2, 1, 1], [4, 1, 4]]);
281    let b = Tensor::from([2, 3, 5]);
282    for _ in 0..10 {
283        x = x.dot(&w) + &b;
284        //Tensor::realize([&x]).unwrap();
285    }
286    println!("{x}");
287}*/
288
289/*#[test]
290fn t1() {
291    use crate::DType;
292    let x = Tensor::from([0f32, 5., 1.]);
293    let y = Tensor::rand([3, 5], DType::F32);
294    let a = x.dot(y);
295    let x = Tensor::from([0f32, 5., 1.]);
296    let y = Tensor::rand([3, 5], DType::F32);
297    let b = x.dot(y);
298    println!("{a}, {b}");
299}*/
300
/// Adds two small integer tensors in two separate rounds.
///
/// The first result is dropped before the second round is computed. This presumably
/// exercises the runtime across a free-and-recompute cycle.
#[test]
fn t2() {
    for _ in 0..2 {
        let lhs = Tensor::from([4, 2, 3]);
        let rhs = Tensor::from([4, 2, 3]);
        let sum = lhs + rhs;
        println!("{sum}");
    }
}
313
/// Splits a 2x3 tensor and prints every resulting piece.
///
/// The arguments are presumably sizes `[2, 1]` along axis 1.
#[test]
fn t3() {
    let x = Tensor::from([[2, 3, 1], [2, 1, 4]]);
    x.split([2, 1], 1)
        .unwrap()
        .into_iter()
        .for_each(|part| println!("{part}"));
}
322
/// Multiplies two random 1024x1024 `f32` tensors 20 times.
///
/// Each product is realized and then released before the next one is built. This presumably
/// checks that repeated realization keeps working without accumulating state.
#[cfg(feature = "rand")]
#[test]
fn t4() {
    let a = Tensor::rand([1024, 1024], DType::F32).unwrap();
    let b = Tensor::rand([1024, 1024], DType::F32).unwrap();
    for _ in 0..20 {
        let product = a.dot(&b).unwrap();
        Tensor::realize([&product]).unwrap();
        // `product` is dropped here, at the end of each iteration, right after realization.
    }
}
343
/// Indexes a 2x3 tensor with `(.., 2..3)` (all rows, the last column) and prints the slice.
#[test]
fn t5() {
    let x = Tensor::from([[2, 3, 1], [2, 1, 4]]);
    let last_col = x.get((.., 2..3)).unwrap();
    println!("{last_col}");
}
349
350/*#[test]
351fn t6() {
352    let x = Tensor::from([[2, 3, 1], [2, 1, 4]]);
353    let handle = std::thread::spawn(|| {
354        let y = Tensor::from([2, 3]);
355        println!("{y}");
356    });
357    println!("{x}");
358    handle.join().unwrap();
359}*/