zyx/lib.rs
1#![doc = include_str!("../README.md")]
2#![forbid(rustdoc::broken_intra_doc_links)]
3#![forbid(rustdoc::private_intra_doc_links)]
4#![forbid(missing_docs)]
5#![forbid(rustdoc::missing_crate_level_docs)]
6#![forbid(rustdoc::private_doc_tests)]
7#![forbid(rustdoc::invalid_codeblock_attributes)]
8#![forbid(rustdoc::invalid_html_tags)]
9#![forbid(rustdoc::invalid_rust_codeblocks)]
10#![forbid(rustdoc::bare_urls)]
11#![forbid(rustdoc::unescaped_backticks)]
12#![forbid(rustdoc::redundant_explicit_links)]
13
14use std::{fs::File, path::Path};
15
16use crate::runtime::Runtime;
17
18mod dtype;
19mod index_map;
20mod mutex;
21#[cfg(feature = "py")]
22mod python_bindings;
23mod runtime;
24mod scalar;
25mod shape;
26mod tensor;
27
28pub use dtype::DType;
29pub use runtime::DeviceConfig;
30pub use runtime::ZyxError;
31pub use scalar::Scalar;
32pub use shape::IntoShape;
33pub use tensor::Tensor;
34
35// Works, but rust does not call drop on this when exiting the program, which causes all sorts of problems ...
36static RT: mutex::Mutex<Runtime, 1000000000> = mutex::Mutex::new(Runtime::new());
37//static RT: mutex::Mutex<Runtime> = mutex::Mutex::new(Runtime::new());
38
39/// Save tensors or modules
40pub trait TensorSave {
41 /// Save tensors or modules
42 fn save(self, path: impl AsRef<Path>) -> Result<(), ZyxError>;
43}
44
45impl<'a, I: IntoIterator<Item = &'a Tensor>> TensorSave for I {
46 fn save(self, path: impl AsRef<Path>) -> Result<(), ZyxError> {
47 use std::fmt::Write;
48 use std::io::Write as IOWrite;
49 let mut f = File::create(path)?;
50 let mut header = String::from("{");
51 let mut begin = 0;
52 let tensors: Vec<&Tensor> = self.into_iter().collect();
53 for tensor in &tensors {
54 let dtype = tensor.dtype();
55 //if let Some(label) = tensor.label() {
56 //write!(header, "\"{label}\":{{").unwrap();
57 //} else {
58 write!(header, "\"{}\":{{", tensor.id()).unwrap();
59 //}
60 write!(header, "\"dtype\":\"{}\",", dtype.safetensors()).unwrap();
61 let mut st_shape = format!("{:?}", tensor.shape());
62 st_shape.retain(|c| !c.is_whitespace());
63 write!(header, "\"shape\":{},", st_shape).unwrap();
64 let size = tensor.numel() * dtype.byte_size();
65 write!(header, "\"data_offsets\":[{},{}]", begin, begin + size).unwrap();
66 begin += size;
67 write!(header, "}},").unwrap();
68 }
69 header.pop();
70 write!(header, "}}").unwrap();
71 let header_bytes = header.as_bytes();
72 f.write_all(&(header_bytes.len() as u64).to_le_bytes())?;
73 f.write_all(header_bytes)?;
74 for tensor in tensors {
75 f.write_all(&tensor.to_le_bytes()?)?;
76 }
77 Ok(())
78 }
79}
80
81/*
82#[test]
83fn t0() {
84 let x = Tensor::from([[2, 3], [4, 5]]);
85 println!("{x}");
86 //assert_eq!(x, [[2, 3], [4, 5]]);
87}
88
89// Unary test
90#[test]
91fn t1() {
92 let x = Tensor::from([[2f32, 3.], [4., 5.]]).exp();
93 println!("{x}");
94 //assert_eq!(x, [[2, 3], [4, 5]]);
95}
96
97#[test]
98fn t2() {
99 //let x = Tensor::randn([2, 2], DType::F32).reshape(256).exp().expand([256, 4]);
100 let x = Tensor::from([[[2f32, 3.]], [[4., 5.]]])
101 .expand([2, 3, 2])
102 .exp()
103 .ln()
104 .reshape([2, 3, 2, 1]);
105 //let x = Tensor::from([[[[2f32], [3.]]], [[[4.], [5.]]]]).expand([2, 3, 2, 1]);
106 //println!("{x}");
107 let y = Tensor::from([[2f32, 3., 1.], [4., 3., 2.]])
108 .reshape([2, 3, 1, 1])
109 .expand([2, 3, 2, 1]);
110 //println!("{y}");
111 let z = (&x + &y).expand([2, 3, 2, 2]).sum([3, 0]);
112 let z = z.exp().ln().permute([1, 0]).sum(0);
113 //Tensor::plot_dot_graph([&x, &y, &z], "graph0");
114 //Tensor::realize([&x, &y, &z]);
115 //println!("{x}\n{y}\n{z}");
116 println!("{z}");
117
118 //let l0 = zyx_nn::Linear::new(1024, 1024, DType::F16);
119}
120
121#[cfg(feature = "rand")]
122#[test]
123#[should_panic]
124fn t3() {
125 let x = Tensor::randn([1024, 1024], DType::F32).expand([1024, 1024, 1024]);
126 Tensor::realize([&x]).unwrap();
127}
128
129#[cfg(feature = "rand")]
130#[test]
131fn t4() {
132 let x = Tensor::uniform([1024, 1024], 0f32..1f32);
133 let y = Tensor::uniform([1024, 1024], 0f32..1f32);
134 //let z = (x * y).sum(2);
135 for _ in 0..20 {
136 let z = x.dot(&y);
137 Tensor::realize([&z]).unwrap();
138 drop(z);
139 //Tensor::plot_graph([], &format!("graph{i}"));
140 }
141 //Tensor::plot_graph([], "graph0");
142 //Tensor::realize([&z]).unwrap();
143}
144
145#[test]
146fn t5() {
147 let x = Tensor::from([[2f32, 3.], [4., 5.]]);
148 let y = x.t();
149 let z = x.exp();
150 //Tensor::plot_dot_graph([&y, &z], "graph1");
151 Tensor::realize([&y, &z]).unwrap();
152 println!("{y}\n{z}");
153}
154
155#[cfg(feature = "rand")]
156#[test]
157fn t6() {
158 //let x = Tensor::from([[2, 3], [4, 5]]).pad_zeros([(1, 3)]);
159
160 let x = Tensor::randn([14, 16], DType::U8);
161 let x = x.get((.., 8..-2));
162 println!("{x}");
163}
164
165#[test]
166fn t7() {
167 let x = Tensor::from([[2, 3], [4, 5]]);
168 //let x = x.pad_zeros([(0, 1)]);
169 let x = x.pad_zeros([(4, 3), (1, 2)]);
170 //Tensor::plot_dot_graph([], "graph0");
171 println!("{x}")
172}
173
174#[test]
175fn t8() {
176 let x = Tensor::ones([2, 3], DType::F32);
177 println!("{x}");
178}
179
180#[test]
181fn t9() {
182 let mut x = Tensor::ones([1024, 1024], DType::F32);
183 let y = Tensor::ones([1024, 1024], DType::F32);
184 for _ in 0..10 {
185 x = x.dot(&y);
186 }
187 println!("{x}");
188}
189
190#[test]
191fn t_10() {
192 let x = Tensor::eye(8, DType::I32);
193 println!("{x}");
194}
195
196#[test]
197fn t_11() {
198 let x = Tensor::from([[2, 3, 1], [3, 4, 1]]);
199 let y = Tensor::from([[2, 3], [2, 1], [4, 1]]);
200 let x = x.dot(y);
201 //let x = x.reshape([2, 1, 3]) * y.t().reshape([1, 2, 3]);
202 //let x = x.sum(2);
203 println!("{x}");
204}
205
206#[test]
207fn t_12() {
208 let mut x = Tensor::from([2, 3, 1]);
209 let w = Tensor::from([[2, 3, 2], [2, 1, 1], [4, 1, 4]]);
210 let b = Tensor::from([2, 3, 5]);
211 for _ in 0..10 {
212 x = x.dot(&w) + &b;
213 Tensor::realize([&x]).unwrap();
214 }
215 println!("{x}");
216}
217
218#[test]
219fn t_14() {
220 let mut x = Tensor::from([[2, 3, 1], [2, 4, 1]]);
221 x = x.repeat([2, 4, 1]);
222 println!("{x}");
223}
224
225#[test]
226fn t_15() {
227 let mut x = Tensor::from([[2, 3, 1], [2, 4, 1]]);
228 for _ in 0..10 {
229 x = &x + &x;
230 println!("{x}");
231 //Tensor::plot_graph([], &format!("graph{i}"));
232 Tensor::realize([&x]).unwrap();
233 }
234 println!("{x}");
235}
236
237#[test]
238fn t_16() {
239 let mut x = Tensor::from([[2, 3, 1], [2, 4, 1]]);
240 let y = Tensor::from([[5, 6, 9], [4, 2, 0]]);
241 let _z = x.exp2() + &y;
242 x = -x * &y;
243 Tensor::plot_graph([], "graph0");
244 Tensor::realize([&x]).unwrap();
245 Tensor::plot_graph([], "graph1");
246}
247
248#[test]
249fn t_17() {
250 let mut x = Tensor::from([[2, 3, 1], [2, 4, 1]]);
251 println!("{x}");
252 x = x.sum([]);
253 println!("{x}");
254}
255
256#[test]
257fn t_18() {
258 let mut x = Tensor::from([[2, 3, 1], [2, 4, 1]]);
259 let y = Tensor::from([[2, 3], [1, 2], [4, 1]]);
260 x = x.dot(y).pad_zeros([(2, 1)]);
261 println!("{x}");
262}
263*/
264
265/*#[test]
266fn t_15() {
267 let mut x = Tensor::from([[2, 3, 1], [2, 4, 1]]);
268 for _ in 0..10 {
269 x = &x + &x;
270 //println!("{x}");
271 //Tensor::plot_graph([], &format!("graph{i}"));
272 Tensor::realize([&x]).unwrap();
273 }
274 println!("{x}");
275}
276
277#[test]
278fn t_12() {
279 let mut x = Tensor::from([2, 3, 1]);
280 let w = Tensor::from([[2, 3, 2], [2, 1, 1], [4, 1, 4]]);
281 let b = Tensor::from([2, 3, 5]);
282 for _ in 0..10 {
283 x = x.dot(&w) + &b;
284 //Tensor::realize([&x]).unwrap();
285 }
286 println!("{x}");
287}*/
288
289/*#[test]
290fn t1() {
291 use crate::DType;
292 let x = Tensor::from([0f32, 5., 1.]);
293 let y = Tensor::rand([3, 5], DType::F32);
294 let a = x.dot(y);
295 let x = Tensor::from([0f32, 5., 1.]);
296 let y = Tensor::rand([3, 5], DType::F32);
297 let b = x.dot(y);
298 println!("{a}, {b}");
299}*/
300
#[test]
fn t2() {
    // Build, add and print the same pair of small integer tensors twice. Every tensor
    // from the first round is released before the second round starts (presumably to
    // exercise freeing and reuse inside the runtime — confirm intent).
    let sum_and_print = || {
        let lhs = Tensor::from([4, 2, 3]);
        let rhs = Tensor::from([4, 2, 3]);
        let sum = lhs + rhs;
        println!("{sum}");
    };
    sum_and_print();
    sum_and_print();
}
313
#[test]
fn t3() {
    // Split a 2x3 tensor along axis 1 into pieces of width 2 and 1, then print each piece.
    let x = Tensor::from([[2, 3, 1], [2, 1, 4]]);
    x.split([2, 1], 1)
        .unwrap()
        .into_iter()
        .for_each(|part| println!("{part}"));
}
322
#[cfg(feature = "rand")]
#[test]
fn t4() {
    // Multiply two random 1024x1024 f32 matrices 20 times, forcing evaluation each time.
    // Each product is released at the end of its iteration, before the next one is built.
    let lhs = Tensor::rand([1024, 1024], DType::F32).unwrap();
    let rhs = Tensor::rand([1024, 1024], DType::F32).unwrap();
    for _round in 0..20 {
        let product = lhs.dot(&rhs).unwrap();
        Tensor::realize([&product]).unwrap();
    }
}
343
#[test]
fn t5() {
    // Index all rows and the column range 2..3 of a 2x3 tensor, then print the slice.
    let x = Tensor::from([[2, 3, 1], [2, 1, 4]]);
    let column = x.get((.., 2..3)).unwrap();
    println!("{}", column);
}
349
350/*#[test]
351fn t6() {
352 let x = Tensor::from([[2, 3, 1], [2, 1, 4]]);
353 let handle = std::thread::spawn(|| {
354 let y = Tensor::from([2, 3]);
355 println!("{y}");
356 });
357 println!("{x}");
358 handle.join().unwrap();
359}*/