use std::cell::RefCell;
use std::collections::HashSet;
use std::fmt;
use std::rc::Rc;
use scivex_core::Tensor;
use scivex_gpu::{GpuDevice, GpuTensor};
use crate::error::Result;
/// Returns an identifier that differs from every value previously returned
/// by this function in the current process (until the counter wraps).
///
/// The ids are used purely as identity keys, so a relaxed atomic increment
/// is enough: no other memory is synchronised through the counter.
fn next_id() -> usize {
    static NEXT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
    NEXT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
}
/// Boxed backward closure stored on operation nodes.
///
/// It receives a reference to the gradient accumulated on a node and returns
/// a `Vec` of gradients; `GpuVariable::backward` pairs that `Vec`
/// positionally with the node's `parents`.
type GpuGradFn = Box<dyn Fn(&GpuTensor) -> Vec<GpuTensor>>;
/// Shared state behind a `GpuVariable`: one node of the computation graph.
struct GpuNode {
/// The tensor value held by this node.
data: GpuTensor,
/// Accumulated gradient; `None` until something is accumulated and again
/// after `zero_grad`.
grad: Option<GpuTensor>,
/// Whether `backward` accumulates gradients into this node.
requires_grad: bool,
/// Backward closure; `None` for variables built with `GpuVariable::new`.
grad_fn: Option<GpuGradFn>,
/// Variables this node was computed from (empty for variables built with
/// `GpuVariable::new`).
parents: Vec<GpuVariable>,
/// Identifier obtained from `next_id`, used as the visited-set key during
/// the topological sort.
id: usize,
}
/// Handle to a node of the GPU computation graph.
///
/// The node lives behind `Rc<RefCell<..>>`, so cloning a `GpuVariable`
/// yields another handle to the *same* node (data, gradient and graph links
/// are shared), not a copy of the tensor.
#[derive(Clone)]
pub struct GpuVariable {
/// Reference-counted, interior-mutable node shared by all clones.
inner: Rc<RefCell<GpuNode>>,
}
impl GpuVariable {
    /// Creates a leaf variable: no parents, no backward closure, no gradient.
    ///
    /// `requires_grad` controls whether [`backward`](Self::backward)
    /// accumulates a gradient into this variable.
    pub fn new(data: GpuTensor, requires_grad: bool) -> Self {
        Self {
            inner: Rc::new(RefCell::new(GpuNode {
                data,
                grad: None,
                requires_grad,
                grad_fn: None,
                parents: Vec::new(),
                id: next_id(),
            })),
        }
    }

    /// Creates the result node of an operation.
    ///
    /// The node always has `requires_grad == true`. `grad_fn` is expected to
    /// return one gradient per entry of `parents`, in the same order, because
    /// [`backward`](Self::backward) pairs them positionally.
    pub(crate) fn from_op(data: GpuTensor, parents: Vec<GpuVariable>, grad_fn: GpuGradFn) -> Self {
        Self {
            inner: Rc::new(RefCell::new(GpuNode {
                data,
                grad: None,
                requires_grad: true,
                grad_fn: Some(grad_fn),
                parents,
                id: next_id(),
            })),
        }
    }

    /// Converts the variable's data into a CPU tensor via `GpuTensor::to_tensor`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `to_tensor`.
    pub fn data_cpu(&self) -> Result<Tensor<f32>> {
        Ok(self.inner.borrow().data.to_tensor()?)
    }

    /// Runs `f` with a shared reference to the underlying GPU tensor and
    /// returns whatever `f` returns.
    ///
    /// The node stays immutably borrowed for the duration of `f`, so calling
    /// a mutating method of this same variable (e.g. `set_data`) from inside
    /// `f` panics.
    pub fn with_data<R>(&self, f: impl FnOnce(&GpuTensor) -> R) -> R {
        let node = self.inner.borrow();
        f(&node.data)
    }

    /// Returns an owned copy of the data tensor's shape.
    pub fn shape(&self) -> Vec<usize> {
        self.inner.borrow().data.shape().to_vec()
    }

    /// Returns the value reported by the data tensor's `numel()`.
    pub fn numel(&self) -> usize {
        self.inner.borrow().data.numel()
    }

    /// Returns a clone of the device handle of the data tensor.
    pub fn device(&self) -> GpuDevice {
        self.inner.borrow().data.device().clone()
    }

    /// Returns whether gradients are accumulated into this variable.
    pub fn requires_grad(&self) -> bool {
        self.inner.borrow().requires_grad
    }

    /// Converts the accumulated gradient into a CPU tensor.
    ///
    /// Returns `None` when no gradient is stored, otherwise `Some` with the
    /// outcome of `GpuTensor::to_tensor` (which may itself be an error).
    pub fn grad_cpu(&self) -> Option<Result<Tensor<f32>>> {
        let node = self.inner.borrow();
        node.grad.as_ref().map(|g| Ok(g.to_tensor()?))
    }

    /// Identifier of the underlying node; shared by all clones of a variable.
    fn id(&self) -> usize {
        self.inner.borrow().id
    }

    /// Discards the stored gradient, if any.
    pub fn zero_grad(&self) {
        self.inner.borrow_mut().grad = None;
    }

    /// Replaces the data tensor in place; gradient and graph links are left
    /// untouched.
    pub fn set_data(&self, data: GpuTensor) {
        self.inner.borrow_mut().data = data;
    }

    /// Builds a new leaf variable (`requires_grad == false`, no parents, no
    /// backward closure) from this variable's data.
    ///
    /// The data is converted to a CPU tensor and a new GPU tensor is created
    /// from it on the same device, so the result does not share its tensor
    /// with `self`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`data_cpu`](Self::data_cpu).
    pub fn detach(&self) -> Result<Self> {
        let cpu = self.data_cpu()?;
        let device = self.device();
        let gpu = GpuTensor::from_tensor(&device, &cpu);
        Ok(Self::new(gpu, false))
    }

    /// Adds `g` to the stored gradient, or stores `g` if there is none yet.
    ///
    /// # Panics
    ///
    /// Panics if `scivex_gpu::ops::add` fails on the existing gradient and `g`.
    pub(crate) fn acc_grad(&self, g: GpuTensor) {
        let mut node = self.inner.borrow_mut();
        node.grad = Some(match node.grad.take() {
            Some(existing) => {
                scivex_gpu::ops::add(&existing, &g).expect("gradient shapes must match")
            }
            None => g,
        });
    }

    /// Back-propagates from this variable through the graph.
    ///
    /// A tensor of ones with this variable's shape is accumulated into its own
    /// gradient as the seed. Nodes are then visited so that every node comes
    /// before its parents. Each node that has both a backward closure and a
    /// gradient passes that gradient to the closure, and the results are
    /// accumulated into the parents that require gradients.
    ///
    /// Gradients are accumulated, never reset: call
    /// [`zero_grad`](Self::zero_grad) beforehand if a fresh gradient is wanted.
    ///
    /// # Panics
    ///
    /// Panics if the seed tensor cannot be created, or if accumulating a
    /// gradient fails (see [`acc_grad`](Self::acc_grad)).
    pub fn backward(&self) {
        // `topo_sort` yields parents before children; reverse it so every node
        // has received all of its incoming gradients before it is processed.
        let mut order = self.topo_sort();
        order.reverse();

        let seed = {
            let node = self.inner.borrow();
            let shape = node.data.shape().to_vec();
            let device = node.data.device().clone();
            scivex_gpu::ops::fill(&device, shape, 1.0).expect("fill for seed gradient")
        };
        self.acc_grad(seed);

        for var in &order {
            let node = var.inner.borrow();
            // Leaves (no closure) and nodes that never received a gradient
            // have nothing to propagate.
            let (grad_fn, grad) = match (node.grad_fn.as_ref(), node.grad.as_ref()) {
                (Some(grad_fn), Some(grad)) => (grad_fn, grad),
                _ => continue,
            };
            // The stored gradient is handed to the closure by reference; no
            // host round-trip is needed since the closure only gets `&GpuTensor`.
            let parent_grads = grad_fn(grad);
            let parents = node.parents.clone();
            // Release this node before mutating the parents' gradients.
            drop(node);

            // Positional pairing: surplus parents or surplus gradients are ignored.
            for (parent, parent_grad) in parents.iter().zip(parent_grads) {
                if parent.requires_grad() {
                    parent.acc_grad(parent_grad);
                }
            }
        }
    }

    /// Returns every node reachable from `self` (including `self`) exactly
    /// once, in post-order: each node appears after all of its parents.
    ///
    /// Implemented iteratively with an explicit stack so deep graphs cannot
    /// overflow the call stack. Each stack entry carries a flag telling whether
    /// the node's parents have already been scheduled.
    fn topo_sort(&self) -> Vec<GpuVariable> {
        let mut visited = HashSet::new();
        let mut order = Vec::new();
        let mut stack: Vec<(GpuVariable, bool)> = vec![(self.clone(), false)];

        while let Some((var, processed)) = stack.pop() {
            let vid = var.id();
            if processed {
                // Second visit: all parents are done. `insert` returns `false`
                // if the node was already emitted through another path.
                if visited.insert(vid) {
                    order.push(var);
                }
                continue;
            }
            if visited.contains(&vid) {
                continue;
            }
            // First visit: come back to this node once its parents are emitted.
            stack.push((var.clone(), true));
            let node = var.inner.borrow();
            for parent in &node.parents {
                if !visited.contains(&parent.id()) {
                    stack.push((parent.clone(), false));
                }
            }
        }
        order
    }
}
/// Compact debug view: prints the shape, the `requires_grad` flag and whether
/// a gradient is currently stored, without dumping any tensor contents.
impl fmt::Debug for GpuVariable {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let guard = self.inner.borrow();
        let gradient_present = guard.grad.is_some();

        let mut builder = f.debug_struct("GpuVariable");
        builder.field("shape", &guard.data.shape());
        builder.field("requires_grad", &guard.requires_grad);
        builder.field("has_grad", &gradient_present);
        builder.finish()
    }
}