impl Tensor {
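    /// Element-wise ReLU: `max(x, 0)`.
    ///
    /// When gradients are enabled, records a `ReluBackward` node that keeps a
    /// clone of the input for the backward pass.
    ///
    /// Illustrative sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// let x = Tensor::from_vec(vec![-1.0, 0.0, 2.0], &[3]);
    /// assert_eq!(x.relu().data(), &[0.0, 0.0, 2.0]);
    /// ```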
    #[must_use]
    pub fn relu(&self) -> Tensor {
        let data = trueno::blis::elementwise::relu_alloc(self.data());
        let mut result = Tensor::from_vec(data, self.shape());
        if is_grad_enabled() && self.requires_grad_enabled() {
            result.requires_grad_(true);
            let grad_fn = Arc::new(ReluBackward { x: self.clone() });
            result.set_grad_fn(grad_fn.clone());
            with_graph(|graph| {
                graph.register_tensor(self.clone());
                graph.record(result.id(), grad_fn, vec![self.id()]);
            });
        }
        result
    }
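    /// Element-wise logistic sigmoid: `1 / (1 + e^(-x))`.
    ///
    /// The backward node saves the output, since the derivative can be
    /// expressed as `sigmoid(x) * (1 - sigmoid(x))`.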
    #[must_use]
    pub fn sigmoid(&self) -> Tensor {
        let data: Vec<f32> = self.data().iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect();
        let mut result = Tensor::from_vec(data, self.shape());
        if is_grad_enabled() && self.requires_grad_enabled() {
            result.requires_grad_(true);
            let grad_fn = Arc::new(SigmoidBackward {
                output: result.clone(),
            });
            result.set_grad_fn(grad_fn.clone());
            with_graph(|graph| {
                graph.register_tensor(self.clone());
                graph.record(result.id(), grad_fn, vec![self.id()]);
            });
        }
        result
    }
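    /// Element-wise hyperbolic tangent.
    ///
    /// The backward node saves the output, since the derivative is
    /// `1 - tanh(x)^2`.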
    #[must_use]
    pub fn tanh(&self) -> Tensor {
        let data: Vec<f32> = self.data().iter().map(|&a| a.tanh()).collect();
        let mut result = Tensor::from_vec(data, self.shape());
        if is_grad_enabled() && self.requires_grad_enabled() {
            result.requires_grad_(true);
            let grad_fn = Arc::new(TanhBackward {
                output: result.clone(),
            });
            result.set_grad_fn(grad_fn.clone());
            with_graph(|graph| {
                graph.register_tensor(self.clone());
                graph.record(result.id(), grad_fn, vec![self.id()]);
            });
        }
        result
    }
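    /// Element-wise leaky ReLU: `x` for positive inputs, `negative_slope * x`
    /// otherwise.
    ///
    /// Illustrative sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// let x = Tensor::from_vec(vec![-2.0, 3.0], &[2]);
    /// assert_eq!(x.leaky_relu(0.1).data(), &[-0.2, 3.0]);
    /// ```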
    #[must_use]
    pub fn leaky_relu(&self, negative_slope: f32) -> Tensor {
        let data: Vec<f32> = self
            .data()
            .iter()
            .map(|&x| if x > 0.0 { x } else { negative_slope * x })
            .collect();
        let mut result = Tensor::from_vec(data, self.shape());
        if is_grad_enabled() && self.requires_grad_enabled() {
            result.requires_grad_(true);
            let grad_fn = Arc::new(LeakyReluBackward {
                x: self.clone(),
                negative_slope,
            });
            result.set_grad_fn(grad_fn.clone());
            with_graph(|graph| {
                graph.register_tensor(self.clone());
                graph.record(result.id(), grad_fn, vec![self.id()]);
            });
        }
        result
    }
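    /// Element-wise GELU using the tanh approximation:
    /// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.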
    #[must_use]
    pub fn gelu(&self) -> Tensor {
        let sqrt_2_over_pi = (2.0_f32 / std::f32::consts::PI).sqrt();
        let data: Vec<f32> = self
            .data()
            .iter()
            .map(|&x| {
                let inner = sqrt_2_over_pi * (x + 0.044715 * x.powi(3));
                0.5 * x * (1.0 + inner.tanh())
            })
            .collect();
        let mut result = Tensor::from_vec(data, self.shape());
        if is_grad_enabled() && self.requires_grad_enabled() {
            result.requires_grad_(true);
            let grad_fn = Arc::new(GeluBackward { x: self.clone() });
            result.set_grad_fn(grad_fn.clone());
            with_graph(|graph| {
                graph.register_tensor(self.clone());
                graph.record(result.id(), grad_fn, vec![self.id()]);
            });
        }
        result
    }
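    /// Softmax along the last dimension, delegating to
    /// `crate::nn::functional::softmax(self, -1)`.
    ///
    /// The backward node saves the output, from which the softmax Jacobian
    /// can be reconstructed.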
    #[must_use]
    pub fn softmax(&self) -> Tensor {
        let computed = crate::nn::functional::softmax(self, -1);
        let mut result = Tensor::from_vec(computed.data().to_vec(), self.shape());
        if is_grad_enabled() && self.requires_grad_enabled() {
            result.requires_grad_(true);
            let grad_fn = Arc::new(SoftmaxBackward {
                output: result.clone(),
            });
            result.set_grad_fn(grad_fn.clone());
            with_graph(|graph| {
                graph.register_tensor(self.clone());
                graph.record(result.id(), grad_fn, vec![self.id()]);
            });
        }
        result
    }
}
impl Tensor {
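    /// Row-major 2-D matrix product.
    ///
    /// Panics if either tensor is not 2-D or if the inner dimensions differ.
    /// The single-row case (`m == 1`) is dispatched to a GEMV kernel; larger
    /// shapes go through `trueno::Matrix::matmul`.
    ///
    /// Illustrative sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
    /// let b = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], &[2, 2]); // identity
    /// assert_eq!(a.matmul(&b).data(), a.data());
    /// ```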
    #[provable_contracts_macros::contract("matmul-kernel-v1", equation = "matmul")]
    #[must_use]
    pub fn matmul(&self, other: &Tensor) -> Tensor {
        assert_eq!(self.ndim(), 2, "matmul requires 2D tensors");
        assert_eq!(other.ndim(), 2, "matmul requires 2D tensors");
        let (m, k1) = (self.shape()[0], self.shape()[1]);
        let (k2, n) = (other.shape()[0], other.shape()[1]);
        assert_eq!(k1, k2, "matmul dimension mismatch: {k1} vs {k2}");
        let data = if m == 1 {
            // Single-row case: dispatch to the vector-matrix (GEMV) kernel.
            let mut c = vec![0.0f32; n];
            trueno::blis::gemv::gemv(k1, n, self.data(), other.data(), &mut c);
            c
        } else {
            let a_matrix = trueno::Matrix::from_vec(m, k1, self.data().to_vec())
                .expect("valid matrix dimensions");
            let b_matrix = trueno::Matrix::from_vec(k2, n, other.data().to_vec())
                .expect("valid matrix dimensions");
            let result_matrix = a_matrix.matmul(&b_matrix).expect("matmul should succeed");
            result_matrix.as_slice().to_vec()
        };
        let mut result = Tensor::from_vec(data, &[m, n]);
        if is_grad_enabled() && (self.requires_grad_enabled() || other.requires_grad_enabled()) {
            result.requires_grad_(true);
            let grad_fn = Arc::new(MatmulBackward {
                x: self.clone(),
                y: other.clone(),
            });
            result.set_grad_fn(grad_fn.clone());
            with_graph(|graph| {
                graph.register_tensor(self.clone());
                graph.register_tensor(other.clone());
                graph.record(result.id(), grad_fn, vec![self.id(), other.id()]);
            });
        }
        result
    }
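    /// Transpose of a 2-D tensor, producing a `[cols, rows]` tensor.
    ///
    /// Panics if the tensor is not 2-D.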
    #[must_use]
    pub fn transpose(&self) -> Tensor {
        assert_eq!(self.ndim(), 2, "transpose requires 2D tensor");
        let (rows, cols) = (self.shape()[0], self.shape()[1]);
        let src = self.data();
        let mut data = vec![0.0; rows * cols];
        trueno::blis::transpose::transpose(rows, cols, src, &mut data)
            .expect("transpose: dimension mismatch (should be impossible)");
        let mut result = Tensor::from_vec(data, &[cols, rows]);
        if is_grad_enabled() && self.requires_grad_enabled() {
            result.requires_grad_(true);
            let grad_fn = Arc::new(TransposeBackward);
            result.set_grad_fn(grad_fn.clone());
            with_graph(|graph| {
                graph.register_tensor(self.clone());
                graph.record(result.id(), grad_fn, vec![self.id()]);
            });
        }
        result
    }
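    /// Adds a 1-D vector to every row of a 2-D matrix (row-wise broadcast),
    /// e.g. a bias added to a batch of activations.
    ///
    /// Panics if `self` is not 2-D, `other` is not 1-D, or the vector length
    /// does not match the number of columns.
    ///
    /// Illustrative sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// let m = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
    /// let b = Tensor::from_vec(vec![10.0, 20.0], &[2]);
    /// assert_eq!(m.broadcast_add(&b).data(), &[11.0, 22.0, 13.0, 24.0]);
    /// ```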
    #[must_use]
    pub fn broadcast_add(&self, other: &Tensor) -> Tensor {
        assert_eq!(self.ndim(), 2, "broadcast_add requires 2D matrix");
        assert_eq!(other.ndim(), 1, "broadcast_add requires 1D vector");
        assert_eq!(
            self.shape()[1],
            other.shape()[0],
            "Matrix columns {} must match vector length {}",
            self.shape()[1],
            other.shape()[0]
        );
        let (rows, cols) = (self.shape()[0], self.shape()[1]);
        let src = self.data();
        let other_data = other.data();
        let mut data = vec![0.0; rows * cols];
        for i in 0..rows {
            for j in 0..cols {
                data[i * cols + j] = src[i * cols + j] + other_data[j];
            }
        }
        let mut result = Tensor::from_vec(data, self.shape());
        if is_grad_enabled() && (self.requires_grad_enabled() || other.requires_grad_enabled()) {
            result.requires_grad_(true);
            let grad_fn = Arc::new(BroadcastAddBackward {
                x_shape: self.shape().to_vec(),
                y_shape: other.shape().to_vec(),
            });
            result.set_grad_fn(grad_fn.clone());
            with_graph(|graph| {
                graph.register_tensor(self.clone());
                graph.register_tensor(other.clone());
                graph.record(result.id(), grad_fn, vec![self.id(), other.id()]);
            });
        }
        result
    }
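    /// Returns a tensor with the same elements laid out under `new_shape`.
    ///
    /// Panics if the total number of elements differs.
    ///
    /// Illustrative sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
    /// let y = x.view(&[3, 2]);
    /// assert_eq!(y.shape(), &[3, 2]);
    /// ```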
    #[must_use]
    pub fn view(&self, new_shape: &[usize]) -> Tensor {
        let old_numel: usize = self.shape().iter().product();
        let new_numel: usize = new_shape.iter().product();
        assert_eq!(
            old_numel, new_numel,
            "view: number of elements must match ({old_numel} vs {new_numel})"
        );
        let mut result = Tensor::new(self.data(), new_shape);
        if is_grad_enabled() && self.requires_grad_enabled() {
            result.requires_grad_(true);
            let grad_fn = Arc::new(ViewBackward {
                input_shape: self.shape().to_vec(),
            });
            result.set_grad_fn(grad_fn.clone());
            with_graph(|graph| {
                graph.register_tensor(self.clone());
                graph.record(result.id(), grad_fn, vec![self.id()]);
            });
        }
        result
    }
}
#[cfg(test)]
mod tests;