//! CPU backend for zyx.
//!
//! Call [`device`] to obtain a [`CPU`] handle, then create tensors through its
//! methods, for example [`CPU::tensor`]. Tensor operations are handled by a
//! zyx_core `Runtime` that drives this crate's private interpreter.
#![no_std]
// Documentation lints are forbidden (not merely denied), so they cannot be
// relaxed with a local `allow`: every public item must carry valid docs.
#![forbid(rustdoc::broken_intra_doc_links)]
#![forbid(rustdoc::private_intra_doc_links)]
#![forbid(missing_docs)]
#![forbid(rustdoc::missing_crate_level_docs)]
#![forbid(rustdoc::private_doc_tests)]
#![forbid(rustdoc::invalid_codeblock_attributes)]
#![forbid(rustdoc::invalid_html_tags)]
#![forbid(rustdoc::invalid_rust_codeblocks)]
#![forbid(rustdoc::bare_urls)]
#![forbid(rustdoc::unescaped_backticks)]
#![forbid(rustdoc::redundant_explicit_links)]
#[cfg(feature = "std")]
extern crate std;
mod interpreter;
use crate::interpreter::Interpreter;
extern crate alloc;
use alloc::{
collections::{BTreeMap, BTreeSet},
vec::Vec,
};
// `core` rather than `std`: this crate is `#![no_std]`, and `std` is only linked
// when the "std" feature is enabled. `std::cell::RefCell` re-exports this type.
use core::cell::RefCell;
use core::ops::Range;
#[cfg(feature = "std")]
pub use zyx_core::io::save;
use zyx_core::{
backend::Backend,
node::Node,
runtime::Runtime,
scalar::Scalar,
shape::Shape,
tensor::Id,
tensor::{tensor, IntoTensor},
};
pub use zyx_core::{dtype::DType, error::ZyxError, tensor::Tensor};
/// CPU compute backend.
///
/// Holds a zyx_core `Runtime` driving this crate's interpreter inside a
/// `RefCell`, so the `Backend` implementation for `&CPU` can mutate the runtime
/// through a shared reference. Because of the `RefCell`, a `CPU` is not `Sync`,
/// and conflicting borrows of the runtime are detected (and panic) at runtime.
///
/// Create one with [`device`].
pub struct CPU(RefCell<Runtime<Interpreter>>);
/// Creates a new CPU backend with a freshly initialised runtime.
///
/// # Errors
///
/// The signature is fallible, but this implementation always returns `Ok`.
pub fn device() -> Result<CPU, ZyxError> {
    let runtime = Runtime::new(Interpreter::new());
    let backend = CPU(RefCell::new(runtime));
    Ok(backend)
}
// Every method calls the trait through UFCS (`<&Self as Backend>::..`): with a
// receiver of type `&CPU`, method syntax such as `self.zeros(..)` would resolve
// to the inherent method of the same name and recurse forever.
impl CPU {
    /// Creates a tensor on this backend from `data`.
    ///
    /// # Panics
    ///
    /// Panics if the backend returns an error while creating the tensor.
    #[must_use]
    pub fn tensor<'a>(&'a self, data: impl IntoTensor<&'a Self>) -> Tensor<&'a Self> {
        <&Self as Backend>::tensor(self, data)
            .expect("CPU::tensor: backend returned an error")
    }

    /// Creates a tensor of random values with the given `shape` and `dtype`.
    ///
    /// # Panics
    ///
    /// Panics if the backend returns an error while creating the tensor.
    #[must_use]
    pub fn randn(&self, shape: impl Into<Shape>, dtype: DType) -> Tensor<&Self> {
        // NOTE(review): the distribution (presumably standard normal, per the
        // name) is defined by zyx_core's runtime — confirm there.
        <&Self as Backend>::randn(self, shape, dtype)
            .expect("CPU::randn: backend returned an error")
    }

    /// Creates a tensor of random values with the given `shape`, parameterised
    /// by `range`.
    ///
    /// # Panics
    ///
    /// Panics if the backend returns an error while creating the tensor.
    #[must_use]
    pub fn uniform(&self, shape: impl Into<Shape>, range: Range<impl Scalar>) -> Tensor<&Self> {
        // NOTE(review): whether `range.end` is excluded is up to the runtime's
        // sampler — confirm in zyx_core.
        <&Self as Backend>::uniform(self, shape, range)
            .expect("CPU::uniform: backend returned an error")
    }

    /// Creates a tensor of the given `shape` filled with `value`.
    ///
    /// # Panics
    ///
    /// Panics if the backend returns an error while creating the tensor.
    #[must_use]
    pub fn full(&self, shape: impl Into<Shape>, value: impl Scalar) -> Tensor<&Self> {
        <&Self as Backend>::full(self, shape, value)
            .expect("CPU::full: backend returned an error")
    }

    /// Creates a tensor of the given `shape` and `dtype` filled with zeros.
    ///
    /// # Panics
    ///
    /// Panics if the backend returns an error while creating the tensor.
    #[must_use]
    pub fn zeros(&self, shape: impl Into<Shape>, dtype: DType) -> Tensor<&Self> {
        <&Self as Backend>::zeros(self, shape, dtype)
            .expect("CPU::zeros: backend returned an error")
    }

    /// Creates a tensor of the given `shape` and `dtype` filled with ones.
    ///
    /// # Panics
    ///
    /// Panics if the backend returns an error while creating the tensor.
    #[must_use]
    pub fn ones(&self, shape: impl Into<Shape>, dtype: DType) -> Tensor<&Self> {
        <&Self as Backend>::ones(self, shape, dtype)
            .expect("CPU::ones: backend returned an error")
    }

    /// Creates an identity-matrix tensor of size `n` with the given `dtype`.
    ///
    /// # Panics
    ///
    /// Panics if the backend returns an error while creating the tensor.
    #[must_use]
    pub fn eye(&self, n: usize, dtype: DType) -> Tensor<&Self> {
        <&Self as Backend>::eye(self, n, dtype)
            .expect("CPU::eye: backend returned an error")
    }

    /// Renders the graph behind `tensors` as a string, using the runtime's
    /// `plot_graph_dot` (presumably Graphviz DOT output).
    #[must_use]
    pub fn plot_graph<'a, B: Backend + 'a>(
        &self,
        tensors: impl IntoIterator<Item = &'a Tensor<B>>,
    ) -> alloc::string::String {
        <&Self as Backend>::plot_graph(self, tensors)
    }

    /// Loads tensors from `path` onto this backend.
    ///
    /// Only available with the `std` feature.
    ///
    /// # Errors
    ///
    /// Returns any error reported by `zyx_core::io::load`.
    #[cfg(feature = "std")]
    pub fn load(&self, path: impl AsRef<std::path::Path>) -> Result<Vec<Tensor<&CPU>>, ZyxError> {
        zyx_core::io::load(self, path)
    }
}
// `Backend` is implemented for the shared reference `&CPU`, which is `Copy`, so
// every `Tensor` can carry its own handle to the backend. The runtime sits behind
// a `RefCell`: read-only queries take `borrow()`, methods that need mutable access
// take `borrow_mut()`. Each borrow is confined to a single runtime call, because
// overlapping a `borrow_mut()` with any other borrow panics (standard `RefCell`
// rules).
impl Backend for &CPU {
    fn plot_graph<'a, B: Backend + 'a>(
        self,
        tensors: impl IntoIterator<Item = &'a Tensor<B>>,
    ) -> alloc::string::String {
        // NOTE(review): tensors may belong to any backend `B`, but their ids are
        // looked up in *this* runtime; ids from another backend would name
        // unrelated nodes — confirm callers only pass tensors created here.
        let ids: Vec<Id> = tensors.into_iter().map(|t| t.id()).collect();
        self.0.borrow().plot_graph_dot(&ids)
    }

    fn randn(self, shape: impl Into<Shape>, dtype: DType) -> Result<Tensor<Self>, ZyxError> {
        // Bind the id first so the `RefMut` is dropped at the end of this
        // statement. As a tail-expression temporary it would otherwise stay alive
        // while `tensor` wraps `self`.
        let id = self.0.borrow_mut().randn(shape.into(), dtype)?;
        Ok(tensor(id, self))
    }

    fn uniform(
        self,
        shape: impl Into<Shape>,
        range: Range<impl Scalar>,
    ) -> Result<Tensor<Self>, ZyxError> {
        // Same borrow-scoping as `randn`: release the runtime before wrapping.
        let id = self.0.borrow_mut().uniform(shape.into(), range)?;
        Ok(tensor(id, self))
    }

    fn shape(self, x: Id) -> Shape {
        // Clone because the runtime's reference cannot outlive the `Ref` guard.
        self.0.borrow().shape(x).clone()
    }

    fn dtype(self, x: Id) -> DType {
        self.0.borrow().dtype(x)
    }

    fn backward(self, x: Id, sources: &BTreeSet<Id>) -> Result<BTreeMap<Id, Id>, ZyxError> {
        // NOTE(review): the returned map presumably sends each source id to the id
        // of its gradient — confirm in zyx_core's `Runtime::backward`.
        self.0.borrow_mut().backward(x, sources)
    }

    fn load<T: Scalar>(self, x: Id) -> Result<Vec<T>, ZyxError> {
        self.0.borrow_mut().load(x)
    }

    fn store<T: Scalar, IT>(self, iter: IT) -> Result<Id, ZyxError>
    where
        IT: IntoIterator<Item = T>,
        IT::IntoIter: ExactSizeIterator,
    {
        self.0.borrow_mut().store(iter)
    }

    fn push(self, node: Node) -> Result<Id, ZyxError> {
        self.0.borrow_mut().push(node)
    }

    fn release(self, x: Id) -> Result<(), ZyxError> {
        self.0.borrow_mut().release(x)
    }

    fn retain(self, x: Id) {
        self.0.borrow_mut().retain(x);
    }
}