use std::sync::atomic::{AtomicU64, Ordering};
use crate::buffer::Buffer;
use crate::device::Device;
use crate::dtypes::DType;
use crate::shape::Shape;
/// Process-wide counter backing [`TensorId::new`]. It starts at 1, so the first ID handed out is 1.
static NEXT_TENSOR_ID: AtomicU64 = AtomicU64::new(1);

/// Opaque identifier for a tensor, drawn from a global atomic counter.
///
/// Values are unique only across calls to [`TensorId::new`]. The inner field is public and the
/// type is `Copy`, so nothing stops code from fabricating or duplicating an ID.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TensorId(pub u64);

impl TensorId {
    /// Allocates the next identifier from the global counter.
    ///
    /// `fetch_add` is a single atomic read-modify-write, so concurrent callers never observe the
    /// same value even under `Ordering::Relaxed`. Only uniqueness is required here, not ordering
    /// relative to other memory operations. Per the std docs, `fetch_add` wraps on overflow.
    // Clippy's `new_without_default` is silenced. A `Default` impl would consume a fresh ID from
    // the global counter on every `default()` call, which is likely why it is omitted.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        let raw = NEXT_TENSOR_ID.fetch_add(1, Ordering::Relaxed);
        Self(raw)
    }
}
/// A tensor: a data buffer plus the shape, device, dtype and gradient-flag metadata describing it.
///
/// NOTE(review): `Clone` is derived, so a cloned `Tensor` keeps the same `id` as its source. In
/// this file only `Tensor::new` and `Tensor::transpose` assign fresh IDs. Confirm that code
/// keying on `id` expects clones to alias.
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Identifier. `Tensor::new` and `Tensor::transpose` each assign a fresh one.
    pub id: TensorId,
    /// Optional human-readable label, set via `Tensor::with_name`. `transpose` appends `_T`.
    pub name: Option<String>,
    /// Backing storage. `transpose` reuses it via `Buffer::clone` instead of building new data.
    pub data: Buffer,
    /// Logical dimensions of the tensor.
    pub shape: Shape,
    /// Device the tensor is associated with.
    pub device: Device,
    /// Element type. NOTE(review): stored separately from `data`, and nothing here enforces
    /// that the two agree. Confirm.
    pub dtype: DType,
    /// Gradient-tracking flag, presumably consumed by autograd code elsewhere. Confirm.
    pub requires_grad: bool,
}
impl Tensor {
    /// Creates a tensor from its parts and assigns it a fresh [`TensorId`].
    ///
    /// The tensor starts unnamed. Use [`Tensor::with_name`] to attach a label.
    pub fn new(
        data: Buffer,
        shape: Shape,
        device: Device,
        dtype: DType,
        requires_grad: bool,
    ) -> Self {
        Self {
            id: TensorId::new(),
            name: None,
            data,
            shape,
            device,
            dtype,
            requires_grad,
        }
    }

    /// Builder-style setter that attaches a human-readable name and returns the tensor.
    ///
    /// Any previously set name is replaced. The result must be used: the tensor is taken by
    /// value, so discarding the return value drops it.
    #[must_use]
    pub fn with_name(mut self, name: &str) -> Self {
        self.name = Some(name.to_string());
        self
    }

    /// Returns a new tensor with shape `self.shape.transpose()` that shares `self.data`.
    ///
    /// The result behaves as follows:
    /// - It gets a fresh [`TensorId`].
    /// - Its name, if any, is the original's with a `_T` suffix.
    /// - Device, dtype and `requires_grad` are copied unchanged.
    /// - Its buffer comes from `Buffer::clone`, which this file's tests expect to bump a
    ///   reference count rather than copy the data.
    ///
    /// NOTE(review): `Tensor` stores no strides or offset of its own. Whether `data` is read
    /// correctly under the new shape therefore depends on what `Shape::transpose` records. Confirm.
    #[must_use]
    pub fn transpose(&self) -> Self {
        // Computed first to keep the original evaluation order (shape, id, name, data).
        let new_shape = self.shape.transpose();
        Self {
            id: TensorId::new(),
            name: self.name.as_ref().map(|n| format!("{}_T", n)),
            data: self.data.clone(),
            shape: new_shape,
            device: self.device,
            dtype: self.dtype,
            requires_grad: self.requires_grad,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Two back-to-back allocations must yield distinct raw IDs.
    #[test]
    fn test_tensor_id_uniqueness() {
        let (first, second) = (TensorId::new(), TensorId::new());
        assert_ne!(first.0, second.0, "Sequential IDs must not overlap!");
    }

    /// Many threads allocating concurrently must never receive the same ID.
    #[test]
    fn test_atomic_id_thread_safety() {
        const THREADS: usize = 10;
        const IDS_PER_THREAD: usize = 100;

        // Spawn every worker before joining any, so allocations actually overlap.
        let workers: Vec<_> = (0..THREADS)
            .map(|_| {
                thread::spawn(|| {
                    (0..IDS_PER_THREAD)
                        .map(|_| TensorId::new().0)
                        .collect::<Vec<u64>>()
                })
            })
            .collect();

        let mut collected: Vec<u64> = workers
            .into_iter()
            .flat_map(|worker| worker.join().unwrap())
            .collect();
        collected.sort_unstable();
        collected.dedup();

        assert_eq!(
            collected.len(),
            THREADS * IDS_PER_THREAD,
            "Race condition detected! Duplicate IDs generated."
        );
    }

    /// `with_name` stores the given label on the tensor.
    #[test]
    fn test_tensor_explainability_name() {
        let dims = Shape::new(vec![2, 2]);
        let storage = Buffer::new_cpu_zeros(4, DType::F32);
        let tensor =
            Tensor::new(storage, dims, Device::Cpu, DType::F32, true).with_name("attention_weights");
        assert_eq!(tensor.name.as_deref(), Some("attention_weights"));
    }

    /// `transpose` swaps the dimensions, renames, gets a new ID, and shares the buffer.
    #[test]
    fn test_zero_copy_transpose_view() {
        let dims = Shape::new(vec![3, 4]);
        let storage = Buffer::new_cpu_zeros(12, DType::F32);
        let original =
            Tensor::new(storage, dims, Device::Cpu, DType::F32, true).with_name("matrix");
        let refs_before = original.data.strong_count().unwrap();

        let transposed = original.transpose();

        assert_ne!(original.id, transposed.id);
        assert_eq!(transposed.shape.dims, vec![4, 3]);
        assert_eq!(transposed.name.as_deref(), Some("matrix_T"));
        assert_eq!(transposed.data.strong_count().unwrap(), refs_before + 1);
    }
}