zyx_cpu/
lib.rs

//! CPU-only, pure Rust backend for zyx
2//!
3//! Initialize backend.
4//! ```rust
5//! let dev = zyx_cpu::device()?;
6//! # Ok::<(), zyx_cpu::ZyxError>(())
7//! ```
8//!
//! For README, quick tutorial and source code, please visit <https://www.github.com/zk4x/zyx>.
10//!
11//! For more details, there is a [book](https://www.github.com/zk4x/zyx/tree/main/zyx-book).
12
13#![no_std]
14//#![forbid(unsafe_code)]
15#![forbid(rustdoc::broken_intra_doc_links)]
16#![forbid(rustdoc::private_intra_doc_links)]
17#![forbid(missing_docs)]
18#![forbid(rustdoc::missing_crate_level_docs)]
19//#![forbid(rustdoc::missing_doc_code_examples)]
20#![forbid(rustdoc::private_doc_tests)]
21#![forbid(rustdoc::invalid_codeblock_attributes)]
22#![forbid(rustdoc::invalid_html_tags)]
23#![forbid(rustdoc::invalid_rust_codeblocks)]
24#![forbid(rustdoc::bare_urls)]
25#![forbid(rustdoc::unescaped_backticks)]
26#![forbid(rustdoc::redundant_explicit_links)]
27
28#[cfg(feature = "std")]
29extern crate std;
30
31mod interpreter;
32use crate::interpreter::Interpreter;
33
34extern crate alloc;
35use alloc::{
36    collections::{BTreeMap, BTreeSet},
37    vec::Vec,
38};
39use core::ops::Range;
40use std::cell::RefCell;
41#[cfg(feature = "std")]
42pub use zyx_core::io::save;
43use zyx_core::{
44    backend::Backend,
45    node::Node,
46    runtime::Runtime,
47    scalar::Scalar,
48    shape::Shape,
49    tensor::Id,
50    tensor::{tensor, IntoTensor},
51};
52pub use zyx_core::{dtype::DType, error::ZyxError, tensor::Tensor};
53
/// CPU backend
///
/// Wraps the interpreter runtime in a `RefCell` so the computation graph can
/// be mutated through shared `&CPU` references (single-threaded interior
/// mutability; `RefCell` is not `Sync`).
pub struct CPU(RefCell<Runtime<Interpreter>>);
56
57/// Create new CPU backend
58pub fn device() -> Result<CPU, ZyxError> {
59    Ok(CPU(RefCell::new(Runtime::new(Interpreter::new()))))
60}
61
// Infallible convenience constructors. Each delegates (via UFCS, since
// `Backend` is implemented for `&CPU` below) to the fallible `Backend`
// method and unwraps the result.
impl CPU {
    /// Create new tensor
    ///
    /// # Panics
    /// Panics if `Backend::tensor` returns an error.
    #[must_use]
    pub fn tensor<'a>(&'a self, data: impl IntoTensor<&'a Self>) -> Tensor<&'a Self> {
        <&Self as Backend>::tensor(self, data).unwrap()
    }

    /// Create new tensor using values from standard normal distribution
    ///
    /// # Panics
    /// Panics if `Backend::randn` returns an error.
    #[must_use]
    pub fn randn(&self, shape: impl Into<Shape>, dtype: DType) -> Tensor<&Self> {
        <&Self as Backend>::randn(self, shape, dtype).unwrap()
    }

    /// Create new tensor using values from uniform distribution
    ///
    /// # Panics
    /// Panics if `Backend::uniform` returns an error.
    #[must_use]
    pub fn uniform(&self, shape: impl Into<Shape>, range: Range<impl Scalar>) -> Tensor<&Self> {
        <&Self as Backend>::uniform(self, shape, range).unwrap()
    }

    /// Create new tensor by repeating single value
    ///
    /// # Panics
    /// Panics if `Backend::full` returns an error.
    #[must_use]
    pub fn full(&self, shape: impl Into<Shape>, value: impl Scalar) -> Tensor<&Self> {
        <&Self as Backend>::full(self, shape, value).unwrap()
    }

    /// Create new tensor by repeating zeroes
    ///
    /// # Panics
    /// Panics if `Backend::zeros` returns an error.
    #[must_use]
    pub fn zeros(&self, shape: impl Into<Shape>, dtype: DType) -> Tensor<&Self> {
        <&Self as Backend>::zeros(self, shape, dtype).unwrap()
    }

    /// Create new tensor by repeating ones
    ///
    /// # Panics
    /// Panics if `Backend::ones` returns an error.
    #[must_use]
    pub fn ones(&self, shape: impl Into<Shape>, dtype: DType) -> Tensor<&Self> {
        <&Self as Backend>::ones(self, shape, dtype).unwrap()
    }

    /// Create eye tensor
    ///
    /// # Panics
    /// Panics if `Backend::eye` returns an error.
    #[must_use]
    pub fn eye(&self, n: usize, dtype: DType) -> Tensor<&Self> {
        <&Self as Backend>::eye(self, n, dtype).unwrap()
    }

    /// Create graph of operations between tensors in dot format for visualization
    #[must_use]
    pub fn plot_graph<'a, B: Backend + 'a>(
        &self,
        tensors: impl IntoIterator<Item = &'a Tensor<B>>,
    ) -> alloc::string::String {
        <&Self as Backend>::plot_graph(self, tensors)
    }

    /// Load tensors from disk.
    ///
    /// # Errors
    /// Forwards any error returned by `zyx_core::io::load`.
    #[cfg(feature = "std")]
    pub fn load(&self, path: impl AsRef<std::path::Path>) -> Result<Vec<Tensor<&CPU>>, ZyxError> {
        zyx_core::io::load(self, path)
    }
}
120
121impl Backend for &CPU {
122    fn plot_graph<'a, B: Backend + 'a>(
123        self,
124        tensors: impl IntoIterator<Item = &'a Tensor<B>>,
125    ) -> alloc::string::String {
126        let ids: Vec<Id> = tensors.into_iter().map(|t| t.id()).collect();
127        self.0.borrow().plot_graph_dot(&ids)
128    }
129
130    fn randn(self, shape: impl Into<Shape>, dtype: DType) -> Result<Tensor<Self>, ZyxError> {
131        Ok(tensor(
132            self.0.borrow_mut().randn(shape.into(), dtype)?,
133            self,
134        ))
135    }
136
137    fn uniform(
138        self,
139        shape: impl Into<Shape>,
140        range: Range<impl Scalar>,
141    ) -> Result<Tensor<Self>, ZyxError> {
142        Ok(tensor(
143            self.0.borrow_mut().uniform(shape.into(), range)?,
144            self,
145        ))
146    }
147
148    fn shape(self, x: Id) -> Shape {
149        self.0.borrow().shape(x).clone()
150    }
151
152    fn dtype(self, x: Id) -> DType {
153        self.0.borrow().dtype(x)
154    }
155
156    fn backward(self, x: Id, sources: &BTreeSet<Id>) -> Result<BTreeMap<Id, Id>, ZyxError> {
157        self.0.borrow_mut().backward(x, sources)
158    }
159
160    fn load<T: Scalar>(self, x: Id) -> Result<Vec<T>, ZyxError> {
161        self.0.borrow_mut().load(x)
162    }
163
164    fn store<T: Scalar, IT>(self, iter: IT) -> Result<Id, ZyxError>
165    where
166        IT: IntoIterator<Item = T>,
167        IT::IntoIter: ExactSizeIterator,
168    {
169        self.0.borrow_mut().store(iter)
170    }
171
172    fn push(self, node: Node) -> Result<Id, ZyxError> {
173        self.0.borrow_mut().push(node)
174    }
175
176    fn release(self, x: Id) -> Result<(), ZyxError> {
177        self.0.borrow_mut().release(x)
178    }
179
180    fn retain(self, x: Id) {
181        self.0.borrow_mut().retain(x);
182    }
183}
184
185/*#[test]
186fn t0() -> Result<(), ZyxError> {
187    let dev = crate::device()?;
188    //let x = dev.tensor([[3, 2, 4], [4, 2, 3]]);
189    //crate::save([&x], "../x.safetensors")?;
190    let x = crate::load(&dev, "../x.safetensors")?.next().unwrap();
191    std::println!("{x}");
192    Ok(())
193}*/
194
195/*#[test]
196fn t0() -> Result<(), ZyxError> {
197    let dev = device()?;
198    let x = dev.randn([2, 3], DType::F32);
199    let y = dev.randn([2, 3], DType::F32);
200    let z = (&x + &y).exp() + &x;
201    let _grads = z.backward([&y]);
202    Ok(())
203}*/
204
205/*#[test]
206fn test_layer_norm() -> Result<(), ZyxError> {
207    let dev = device()?;
208    let x = dev.randn([2, 3], DType::F32);
209    let _n = x.shape()[-1];
210
211    //let z = (x - (x.sum(-1)/n).expand())/(((x - (x.sum(-1)/n).expand()).sum(-1)/n + 0.00001.expand()).sqrt()).expand();
212
213    //let x = x.dot(w);
214    //let x = a * (x - x.mean(-1))/(x.var(-1) + 0.00001).sqrt() + b;
215    //let x = x.tanh();
216    //let x = x.dropout(0.3);
217
218    Ok(())
219}*/
220
221/*#[test]
222fn t0() -> Result<(), ZyxError> {
223    let dev = device()?;
224    let Q = dev.tensor([[0.0, 0.75], [0.0, 0.0]]);
225    let E = dev.eye(2, DType::F32);
226    let p = dev.tensor([30000., 20000.]);
227    let d = 0.909090909090909090;
228    let R = dev.tensor([[0.0, 0.25], [0.8, 0.2]]);
229    let n = dev.tensor([0.0, 100.0]);
230    let inv = dev.tensor([[1.0, 0.681818], [0.0, 1.0]]);
231
232    let y = p.dot(inv).dot(R).dot(n);
233
234    std::println!("{y}");
235    //std::println!("{}", n.transpose());
236    panic!();
237    Ok(())
238}*/
239
240/*#[test]
241fn t5() -> Result<(), ZyxError> {
242    let dev = device()?;
243    //let mut x = dev.randn([1024, 1024], DType::F32);
244    let mut x = dev.tensor(3);
245    //let y = dev.randn([1024, 1024], DType::F32);
246    /*let x = x + 1;
247    let x = x + 1;
248    let x = x + 1;
249    let x = x + 1;*/
250    for i in 0..1000000 {
251        if i % 100000 == 0 {
252            std::println!("i: {i}");
253        }
254        x = x + 1;
255    }
256
257    std::println!("{x}");
258    panic!();
259    Ok(())
260}*/