use std::mem;
use std::cell::{Cell, RefCell};
use std::marker::PhantomData;
use {Alloc, ComputeDevice, Device, ErrorKind, Memory, Result, Synch};
use utility::Has;
/// A tensor whose backing memory may be replicated across several compute
/// devices, with lazy synchronization between the copies.
#[derive(Debug)]
pub struct SharedTensor<T = f32> {
/// The logical shape shared by every copy.
pub shape: Shape,
/// One entry per device currently holding an allocation for this tensor.
/// Interior mutability lets `read`/`autosync` add copies through `&self`.
copies: RefCell<Vec<(ComputeDevice, Memory<T>)>>,
/// Bit set over `copies`: bit `i` is set when copy `i` holds an
/// up-to-date version of the data (see `u64Map`).
versions: u64Map,
/// Marks logical ownership of `T` elements stored behind `Memory<T>`.
phantom: PhantomData<T>,
}
impl<T> SharedTensor<T> where Device: Alloc<T> + Synch<T> {
/// Constructs an empty tensor of shape `sh`: no device memory is allocated
/// and no version bit is set, so reads fail with `UninitializedMemory`
/// until some copy is written.
pub fn new<A>(sh: A) -> Self where A: Into<Shape> {
let shape = sh.into();
let copies = RefCell::new(vec![]);
let versions = u64Map::new();
SharedTensor { shape, copies, versions, phantom: PhantomData }
}
/// Allocates memory on `con`'s device and initializes it from `chunk`.
pub fn with<H, I>(con: &H, sh: I, chunk: Vec<T>) -> Result<Self>
where H: Has<Device>,
I: Into<Shape>,
{
let shape = sh.into();
let device = con.get_ref();
let buffer = device.allocwrite(&shape, chunk)?;
let copies = RefCell::new(vec![(device.view(), buffer)]);
// Bit 0 set: the single just-written copy is the current version.
let versions = u64Map::with(1);
Ok(SharedTensor { shape, copies, versions, phantom: PhantomData })
}
/// Allocates (but does not initialize) memory on `con`'s device.
///
/// NOTE(review): the version bit is set exactly as in `with`, even though
/// this buffer was never written — confirm that treating uninitialized
/// memory as the current version is intentional.
pub fn alloc<H, I>(con: &H, sh: I) -> Result<Self>
where H: Has<Device>,
I: Into<Shape>
{
let shape = sh.into();
let device = con.get_ref();
let buffer = device.alloc(&shape)?;
let copies = RefCell::new(vec![(device.view(), buffer)]);
let versions = u64Map::with(1);
Ok(SharedTensor { shape, copies, versions, phantom: PhantomData })
}
/// Removes and returns the copy held on `con`'s device, or errors when no
/// copy exists there.
pub fn dealloc<H>(&mut self, con: &H) -> Result<Memory<T>> where H: Has<Device> {
let device = con.get_ref();
let location = device.view();
match self.get_location_index(&location) {
Some(i) => {
let (_, memory) = self.copies.borrow_mut().remove(i);
// `remove(i)` shifted every later copy down one slot, so the version
// bitmap is compacted to match: bits below `i` are kept, bit `i` is
// dropped, and bits above `i` are shifted down one position.
let version = self.versions.get();
let mask = (1 << i) - 1;
let lower = version & mask;
let upper = (version >> 1) & (!mask);
self.versions.set(lower | upper);
Ok(memory)
},
_ => Err(ErrorKind::AllocatedMemoryNotFoundForDevice.into())
}
}
/// Reallocates the tensor on `dev`'s device with a new shape.
/// TODO: unimplemented stub — both parameters are currently unused.
pub fn realloc<H, I>(&mut self, dev: &H, sh: I) -> Result
where H: Has<Device>,
I: Into<Shape>
{
unimplemented!()
}
/// Changes the logical shape without touching memory; the new shape must
/// cover exactly the same number of elements as the old one.
pub fn reshape<I>(&mut self, sh: I) -> Result where I: Into<Shape> {
let shape = sh.into();
if shape.capacity() != self.shape.capacity() {
return Err(ErrorKind::InvalidReshapedTensorSize.into());
}
self.shape = shape;
Ok(())
}
/// Total number of elements the tensor holds.
pub fn capacity(&self) -> usize {
self.shape.capacity()
}
}
impl<T> SharedTensor<T> where Device: Alloc<T> + Synch<T> {
/// Synchronizes (if needed) and returns a shared view of the copy on
/// `dev`'s device, without invalidating any other copy.
pub fn read<'shared, H>(&'shared self, dev: &H) -> Result<&'shared Memory<T>>
where H: Has<Device> {
let i = self.autosync(dev, false)?;
let borrowed_copies = self.copies.borrow();
let (_, ref buffer) = borrowed_copies[i];
// The RefCell guard cannot be returned alongside the reference, so the
// borrow is unsafely promoted to `'shared`.
// NOTE(review): this is sound only if `Memory<T>` values never move
// while the returned reference is live; a later `read`/`autosync`
// through `&self` may push into `copies` and reallocate the Vec —
// confirm the stability of these entries (or box them).
let memory = unsafe { extend_lifetime::<'shared>(buffer) };
Ok(memory)
}
/// Synchronizes (if needed) and returns a mutable view of the copy on
/// `dev`'s device; every other copy is marked stale (`tick = true`).
pub fn read_write<'shared, H>(&'shared mut self, dev: &H) -> Result<&'shared mut Memory<T>>
where H: Has<Device> {
let i = self.autosync(dev, true)?;
let mut borrowed_copies = self.copies.borrow_mut();
let (_, ref mut buffer) = borrowed_copies[i];
// Same lifetime promotion as in `read`; here `&mut self` additionally
// prevents any concurrent access for the duration of `'shared`.
let memory = unsafe { extend_lifetime_mut::<'shared>(buffer) };
Ok(memory)
}
/// Returns a mutable view of the copy on `con`'s device *without* syncing
/// its contents first — the caller is expected to overwrite it.
pub fn write<'shared, H>(&'shared mut self, con: &H) -> Result<&'shared mut Memory<T>>
where H: Has<Device> {
let i = self.get_or_create_location_index(con)?;
// Only bit `i` remains set: every other copy is now considered stale.
self.versions.set(1 << i);
let mut borrowed_copies = self.copies.borrow_mut();
let (_, ref mut buffer) = borrowed_copies[i];
let memory = unsafe { extend_lifetime_mut::<'shared>(buffer) };
Ok(memory)
}
}
impl<T> SharedTensor<T> where Device: Alloc<T> + Synch<T> {
/// Returns the index into `copies` of the copy held at `location`, if any.
fn get_location_index(&self, location: &ComputeDevice) -> Option<usize> {
for (i, l) in self.copies.borrow().iter().map(|&(ref l, _)| l).enumerate() {
if l.eq(location) {
return Some(i);
}
}
None
}
/// Returns the index of the copy on `con`'s device, allocating a fresh
/// (not-yet-synchronized, version bit clear) copy there when none exists.
fn get_or_create_location_index<H>(&self, con: &H) -> Result<usize> where H: Has<Device> {
let device = con.get_ref();
let location = device.view();
if let Some(i) = self.get_location_index(&location) {
return Ok(i);
}
// The version map is a 64-bit set, so at most 64 copies can be tracked.
if self.copies.borrow().len() == u64Map::CAPACITY {
return Err(ErrorKind::BitMapCapacityExceeded.into());
}
let memory = device.alloc(&self.shape)?;
self.copies.borrow_mut().push((location, memory));
Ok(self.copies.borrow().len() - 1)
}
/// Ensures the copy on `dev`'s device is up to date and returns its index.
///
/// `tick == true` marks that copy as the *only* valid version afterwards
/// (the caller intends to mutate it); `tick == false` merely adds it to
/// the set of valid versions (read-only access).
pub fn autosync<H>(&self, dev: &H, tick: bool) -> Result<usize> where H: Has<Device> {
if self.versions.empty() {
return Err(ErrorKind::UninitializedMemory.into());
}
let i = self.get_or_create_location_index(dev)?;
self.autosync_(i)?;
if tick {
self.versions.set(1 << i);
} else {
self.versions.insert(i);
}
Ok(i)
}
/// Copies data into `copies[destination_index]` from an up-to-date copy,
/// unless the destination is already current.
fn autosync_(&self, destination_index: usize) -> Result {
if self.versions.contains(destination_index) {
return Ok(());
}
// Source: the lowest-indexed up-to-date copy (`latest` is the position
// of the lowest set bit; it is 64 only for an empty map, which the
// caller has already ruled out).
let source_index = self.versions.latest() as usize;
assert_ne!(source_index, u64Map::CAPACITY);
assert_ne!(source_index, destination_index);
let mut borrowed_copies = self.copies.borrow_mut();
// `split_at_mut` yields disjoint halves, letting a shared borrow of the
// source coexist with a unique borrow of the destination.
let (source, mut destination) = {
if source_index < destination_index {
let (left, right) = borrowed_copies.split_at_mut(destination_index);
(&left[source_index], &mut right[0])
} else {
let (left, right) = borrowed_copies.split_at_mut(source_index);
(&right[0], &mut left[destination_index])
}
};
// Prefer a transfer driven by the source device; when no such route is
// available, fall back to a transfer driven by the destination device.
match source.0.device().read(&source.1, &mut destination.0, &mut destination.1) {
Err(ref e) if e.kind() == ErrorKind::NoAvailableSynchronizationRouteFound => { },
ret @ _ => return ret,
}
destination.0.device().write(&mut destination.1, &source.0, &source.1)
}
}
/// Describes the dimensionality of a tensor.
#[derive(Clone, Debug)]
pub struct Shape {
    /// Total number of elements spanned (the product of `dims`).
    pub capacity: usize,
    /// Number of dimensions.
    rank: usize,
    /// The extent of each dimension.
    pub dims: Vec<usize>,
}

impl Shape {
    /// Returns the total number of elements the shape spans.
    pub fn capacity(&self) -> usize {
        self.capacity
    }
}

impl From<usize> for Shape {
    /// A bare length is treated as a one-dimensional shape.
    fn from(length: usize) -> Shape {
        Shape::from([length])
    }
}

impl From<[usize; 1]> for Shape {
    fn from(dimensions: [usize; 1]) -> Shape {
        Shape {
            capacity: dimensions.iter().product(),
            rank: 1,
            dims: dimensions.to_vec(),
        }
    }
}

impl From<[usize; 2]> for Shape {
    fn from(dimensions: [usize; 2]) -> Shape {
        Shape {
            capacity: dimensions.iter().product(),
            rank: 2,
            dims: dimensions.to_vec(),
        }
    }
}

impl From<[usize; 3]> for Shape {
    fn from(dimensions: [usize; 3]) -> Shape {
        Shape {
            capacity: dimensions.iter().product(),
            rank: 3,
            dims: dimensions.to_vec(),
        }
    }
}
/// A 64-slot bit set backed by a `Cell<u64>`, used to track which device
/// copies currently hold an up-to-date version of the tensor data.
#[allow(non_camel_case_types)]
#[derive(Debug)]
pub struct u64Map(Cell<u64>);

impl u64Map {
    /// Maximum number of distinct slots the bit set can track.
    const CAPACITY: usize = 64;

    /// Creates an empty bit set.
    fn new() -> u64Map {
        u64Map::with(0)
    }

    /// Creates a bit set preloaded with the raw bits `n`.
    fn with(n: u64) -> u64Map {
        u64Map(Cell::new(n))
    }

    /// Returns the raw bits.
    fn get(&self) -> u64 {
        self.0.get()
    }

    /// Overwrites the raw bits with `v`.
    fn set(&self, v: u64) {
        self.0.set(v)
    }

    /// `true` when no slot is marked.
    fn empty(&self) -> bool {
        self.get() == 0
    }

    /// Marks slot `k`.
    fn insert(&self, k: usize) {
        let bits = self.get();
        self.set(bits | 1 << k)
    }

    /// `true` when slot `k` is marked; out-of-range slots are never marked.
    fn contains(&self, k: usize) -> bool {
        k < Self::CAPACITY && self.get() >> k & 1 == 1
    }

    /// Index of the lowest marked slot (`CAPACITY` when the set is empty).
    fn latest(&self) -> u32 {
        self.get().trailing_zeros()
    }
}
/// Promotes a shared reference from lifetime `'a` to the caller-chosen `'b`
/// via a raw-pointer round trip.
///
/// # Safety
///
/// The caller must guarantee the referent actually outlives `'b`; the
/// compiler can no longer verify this.
unsafe fn extend_lifetime<'a, 'b, T>(t: &'a T) -> &'b T {
    &*(t as *const T)
}

/// Promotes a unique reference from lifetime `'a` to the caller-chosen `'b`
/// via a raw-pointer round trip.
///
/// # Safety
///
/// The caller must guarantee the referent outlives `'b` and that no other
/// reference to it is used while the returned reference is live.
unsafe fn extend_lifetime_mut<'a, 'b, T>(t: &'a mut T) -> &'b mut T {
    &mut *(t as *mut T)
}