use std::fmt;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use crate::primitives::Vector;
use super::grad_fn::GradFn;
use super::with_graph;
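/// Unique identifier assigned to every tensor, used to locate its node in the autograd graph.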
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TensorId(u64);
impl TensorId {
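    /// Returns the next id from a process-wide atomic counter.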
pub fn new() -> Self {
static COUNTER: AtomicU64 = AtomicU64::new(0);
TensorId(COUNTER.fetch_add(1, Ordering::SeqCst))
}
}
impl Default for TensorId {
fn default() -> Self {
Self::new()
}
}
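/// A dense `f32` tensor that participates in reverse-mode automatic differentiation.
///
/// Tensors built from raw data are leaves; tensors produced by operations record a
/// `grad_fn` and are marked as non-leaves so the backward pass can reach their inputs.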
#[derive(Clone)]
pub struct Tensor {
    /// Flat element storage.
    data: Vector<f32>,
    /// Dimension sizes; the number of elements is their product.
    shape: Vec<usize>,
    /// Gradient accumulated by the backward pass, if any.
    grad: Option<Box<Tensor>>,
    /// Whether this tensor should track gradients.
    requires_grad: bool,
    /// True for tensors created directly from data rather than by an operation.
    is_leaf: bool,
    /// Backward function recorded by the operation that produced this tensor.
    grad_fn: Option<Arc<dyn GradFn>>,
    /// Unique node id in the autograd graph.
    id: TensorId,
}
impl Tensor {
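    /// Creates a leaf tensor by copying `data` into fresh storage.
    ///
    /// # Panics
    /// Panics if `data.len()` does not equal the product of `shape`.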
#[must_use]
pub fn new(data: &[f32], shape: &[usize]) -> Self {
let expected_len: usize = shape.iter().product();
assert_eq!(
data.len(),
expected_len,
"Data length {} doesn't match shape {:?} (expected {})",
data.len(),
shape,
expected_len
);
Self {
data: Vector::from_slice(data),
shape: shape.to_vec(),
grad: None,
requires_grad: false,
is_leaf: true,
grad_fn: None,
id: TensorId::new(),
}
}
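    /// Creates a leaf tensor that takes ownership of `data` without copying.
    ///
    /// # Panics
    /// Panics if `data.len()` does not equal the product of `shape`.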
#[must_use]
pub fn from_vec(data: Vec<f32>, shape: &[usize]) -> Self {
let expected_len: usize = shape.iter().product();
assert_eq!(
data.len(),
expected_len,
"Data length {} doesn't match shape {:?} (expected {})",
data.len(),
shape,
expected_len
);
Self {
data: Vector::from_vec(data),
shape: shape.to_vec(),
grad: None,
requires_grad: false,
is_leaf: true,
grad_fn: None,
id: TensorId::new(),
}
}
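    /// Creates a one-dimensional leaf tensor with shape `[data.len()]`.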
#[must_use]
pub fn from_slice(data: &[f32]) -> Self {
Self::new(data, &[data.len()])
}
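    /// Creates a tensor of the given shape filled with zeros.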
#[must_use]
pub fn zeros(shape: &[usize]) -> Self {
let len: usize = shape.iter().product();
        Self::from_vec(vec![0.0; len], shape)
}
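    /// Creates a tensor of the given shape filled with ones.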
#[must_use]
pub fn ones(shape: &[usize]) -> Self {
let len: usize = shape.iter().product();
        Self::from_vec(vec![1.0; len], shape)
}
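    /// Creates a zero-filled tensor with the same shape as `other`.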
#[must_use]
pub fn zeros_like(other: &Tensor) -> Self {
Self::zeros(&other.shape)
}
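    /// Creates a one-filled tensor with the same shape as `other`.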
#[must_use]
pub fn ones_like(other: &Tensor) -> Self {
Self::ones(&other.shape)
}
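    /// Enables gradient tracking and returns the tensor (builder-style).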
#[must_use]
pub fn requires_grad(mut self) -> Self {
self.requires_grad = true;
self
}
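    /// Sets gradient tracking in place; the trailing underscore marks in-place mutation.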
pub fn requires_grad_(&mut self, requires: bool) -> &mut Self {
self.requires_grad = requires;
self
}
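    /// Returns whether this tensor is tracking gradients.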
#[must_use]
pub fn requires_grad_enabled(&self) -> bool {
self.requires_grad
}
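    /// Returns whether this tensor is a leaf (created from data rather than by an operation).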
#[must_use]
pub fn is_leaf(&self) -> bool {
self.is_leaf
}
#[must_use]
pub fn id(&self) -> TensorId {
self.id
}
#[must_use]
pub fn shape(&self) -> &[usize] {
&self.shape
}
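    /// Returns the total number of elements, i.e. the product of the shape.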
#[must_use]
pub fn numel(&self) -> usize {
self.shape.iter().product()
}
#[must_use]
pub fn ndim(&self) -> usize {
self.shape.len()
}
#[must_use]
pub fn data(&self) -> &[f32] {
self.data.as_slice()
}
pub fn data_mut(&mut self) -> &mut [f32] {
self.data.as_mut_slice()
}
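    /// Returns the accumulated gradient, if one has been recorded.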
#[must_use]
pub fn grad(&self) -> Option<&Tensor> {
self.grad.as_deref()
}
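    /// Clears the accumulated gradient in place.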
pub fn zero_grad_(&mut self) {
self.grad = None;
}
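    /// Clears the accumulated gradient; equivalent to `zero_grad_`.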
pub fn clear_grad(&mut self) {
self.grad = None;
}
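    /// Adds `grad` element-wise into the stored gradient, initializing it on first use.
    /// Both the incoming and the stored gradient are expected to match this tensor's shape.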
pub(crate) fn accumulate_grad(&mut self, grad: Tensor) {
match &mut self.grad {
Some(existing) => {
let new_data: Vec<f32> = existing
.data()
.iter()
.zip(grad.data().iter())
.map(|(a, b)| a + b)
.collect();
**existing = Tensor::new(&new_data, &self.shape);
}
None => {
self.grad = Some(Box::new(grad));
}
}
}
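    /// Attaches the backward function produced by an operation and marks the tensor as non-leaf.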
pub(crate) fn set_grad_fn(&mut self, grad_fn: Arc<dyn GradFn>) {
self.grad_fn = Some(grad_fn);
self.is_leaf = false;
}
#[allow(dead_code)]
pub(crate) fn grad_fn(&self) -> Option<&Arc<dyn GradFn>> {
self.grad_fn.as_ref()
}
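    /// Returns a copy of this tensor's data detached from the autograd graph:
    /// a fresh leaf with a new id, no gradient, and gradient tracking disabled.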
#[must_use]
pub fn detach(&self) -> Tensor {
Tensor {
data: self.data.clone(),
shape: self.shape.clone(),
grad: None,
requires_grad: false,
is_leaf: true,
grad_fn: None,
id: TensorId::new(),
}
}
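    /// Extracts the value of a single-element tensor.
    ///
    /// # Panics
    /// Panics if the tensor holds more than one element.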
#[must_use]
pub fn item(&self) -> f32 {
assert_eq!(
self.numel(),
1,
"item() only works on tensors with exactly 1 element, got {}",
self.numel()
);
self.data[0]
}
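    /// Runs the backward pass from a scalar output, seeding it with a gradient of one.
    ///
    /// # Panics
    /// Panics if the tensor is not a scalar; use `backward_with_grad` for non-scalar outputs.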
pub fn backward(&self) {
assert_eq!(
self.numel(),
1,
"backward() requires scalar output, got shape {:?}. Use backward_with_grad() instead.",
self.shape
);
self.backward_with_grad(Tensor::ones(&self.shape));
}
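    /// Runs the backward pass from this tensor through the recorded graph, seeding it with `grad_output`.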
pub fn backward_with_grad(&self, grad_output: Tensor) {
with_graph(|graph| {
graph.backward(self.id, grad_output);
});
}
}
impl fmt::Debug for Tensor {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Tensor")
.field("shape", &self.shape)
.field("requires_grad", &self.requires_grad)
.field("is_leaf", &self.is_leaf)
.field("has_grad", &self.grad.is_some())
.field("id", &self.id)
.finish_non_exhaustive()
}
}
#[cfg(test)]
#[path = "tensor_tests.rs"]
mod tests;