use linear_map::LinearMap;
use device::{IDevice, DeviceType};
use memory::MemoryType;
use std::marker::PhantomData;
use std::{fmt, mem, error};
/// Describes the shape of a tensor: one `usize` per dimension, outermost first.
pub type TensorDesc = Vec<usize>;
/// A tensor whose backing memory is tracked (and synchronized) across devices.
///
/// One copy — the one holding the most recent data — is kept in
/// `latest_copy`/`latest_location`; all other copies live in `copies`.
#[derive(Debug)]
pub struct SharedTensor<T> {
// Shape of the tensor (one entry per dimension).
desc: TensorDesc,
// Device whose copy currently holds the most recent data.
latest_location: DeviceType,
// The memory copy with the most recent data (kept out of `copies`).
latest_copy: MemoryType,
// All other tracked copies, keyed by their device.
copies: LinearMap<DeviceType, MemoryType>,
// Marks the element type `T` without storing any `T` values directly.
phantom: PhantomData<T>,
}
/// Shape queries and stride computation for a tensor descriptor.
pub trait ITensorDesc {
    /// Returns the number of dimensions.
    fn rank(&self) -> usize;
    /// Returns the total number of elements described by the shape.
    fn size(&self) -> usize;
    /// Returns the dimension sizes.
    fn dims(&self) -> &Vec<usize>;
    /// Returns the dimension sizes converted to `i32` (for C/FFI-style APIs).
    fn dims_i32(&self) -> Vec<i32>;
    /// Returns the default (row-major, contiguous) strides for this shape.
    ///
    /// The last dimension has stride 1; each preceding stride is the product
    /// of all following dimensions. A rank-0 shape yields an empty vector.
    fn default_stride(&self) -> Vec<usize> {
        let dims = self.dims();
        let rank = dims.len();
        if rank == 0 {
            return Vec::new();
        }
        // Build right-to-left in O(rank); the previous implementation
        // re-multiplied a suffix product for every dimension (O(rank^2)).
        let mut strides = vec![1; rank];
        for i in (0..rank - 1).rev() {
            strides[i] = strides[i + 1] * dims[i + 1];
        }
        strides
    }
    /// Returns the default strides converted to `i32` (for C/FFI-style APIs).
    fn default_stride_i32(&self) -> Vec<i32> {
        self.default_stride().iter().map(|&e| e as i32).collect()
    }
}
/// Conversion into a `TensorDesc` (a shape); implemented below for scalars,
/// tuples, vectors and slices.
pub trait IntoTensorDesc {
/// Returns the shape this value describes.
fn into(&self) -> TensorDesc;
}
impl IntoTensorDesc for () {
/// A unit value describes a scalar: an empty (rank-0) shape.
/// NOTE(review): reserves capacity for one dimension rather than using
/// `Vec::new()` — presumably so a dimension can later be pushed without
/// reallocating; confirm whether the capacity is intentional.
fn into(&self) -> TensorDesc {
Vec::with_capacity(1)
}
}
impl IntoTensorDesc for usize {
    /// A single `usize` describes a one-dimensional shape.
    fn into(&self) -> TensorDesc {
        let mut desc = Vec::with_capacity(1);
        desc.push(*self);
        desc
    }
}
impl IntoTensorDesc for u32 {
    /// A single `u32` describes a one-dimensional shape.
    fn into(&self) -> TensorDesc {
        let mut desc = Vec::with_capacity(1);
        desc.push(*self as usize);
        desc
    }
}
impl IntoTensorDesc for isize {
    /// A single `isize` describes a one-dimensional shape.
    /// (A negative value wraps when cast to `usize`, as in the original.)
    fn into(&self) -> TensorDesc {
        let mut desc = Vec::with_capacity(1);
        desc.push(*self as usize);
        desc
    }
}
impl IntoTensorDesc for i32 {
    /// A single `i32` describes a one-dimensional shape.
    /// (A negative value wraps when cast to `usize`, as in the original.)
    fn into(&self) -> TensorDesc {
        let mut desc = Vec::with_capacity(1);
        desc.push(*self as usize);
        desc
    }
}
impl IntoTensorDesc for Vec<usize> {
    /// A `Vec<usize>` is already a shape; hand back an owned copy.
    fn into(&self) -> TensorDesc {
        self.to_vec()
    }
}
impl<'a> IntoTensorDesc for &'a [usize] {
    /// A slice of `usize` is copied into an owned shape.
    fn into(&self) -> TensorDesc {
        self.to_vec()
    }
}
impl IntoTensorDesc for (usize, usize) {
    /// A pair describes a two-dimensional shape.
    fn into(&self) -> TensorDesc {
        [self.0, self.1].to_vec()
    }
}
impl IntoTensorDesc for (usize, usize, usize) {
    /// A triple describes a three-dimensional shape.
    fn into(&self) -> TensorDesc {
        [self.0, self.1, self.2].to_vec()
    }
}
impl IntoTensorDesc for (usize, usize, usize, usize) {
    /// A 4-tuple describes a four-dimensional shape.
    fn into(&self) -> TensorDesc {
        [self.0, self.1, self.2, self.3].to_vec()
    }
}
impl IntoTensorDesc for (usize, usize, usize, usize, usize) {
    /// A 5-tuple describes a five-dimensional shape.
    fn into(&self) -> TensorDesc {
        [self.0, self.1, self.2, self.3, self.4].to_vec()
    }
}
impl IntoTensorDesc for (usize, usize, usize, usize, usize, usize) {
    /// A 6-tuple describes a six-dimensional shape.
    fn into(&self) -> TensorDesc {
        [self.0, self.1, self.2, self.3, self.4, self.5].to_vec()
    }
}
impl ITensorDesc for TensorDesc {
    /// The rank is the number of dimension entries.
    fn rank(&self) -> usize {
        self.len()
    }

    /// The element count is the product of all dimensions.
    /// The product over an empty shape is 1 — a rank-0 tensor is a scalar.
    fn size(&self) -> usize {
        self.iter().product()
    }

    /// The descriptor is its own dimension vector.
    fn dims(&self) -> &Vec<usize> {
        self
    }

    /// Dimensions converted to `i32` for C/FFI-style APIs.
    fn dims_i32(&self) -> Vec<i32> {
        self.iter().map(|&dim| dim as i32).collect()
    }
}
impl<T> SharedTensor<T> {
    /// Creates a new `SharedTensor` shaped by `desc`, with its first memory
    /// copy allocated on device `dev`.
    pub fn new<D: IntoTensorDesc>(dev: &DeviceType, desc: &D) -> Result<SharedTensor<T>, Error> {
        let copies = LinearMap::<DeviceType, MemoryType>::new();
        let copy = try!(Self::alloc_on_device(dev, desc));
        let tensor_desc: TensorDesc = desc.into();
        Ok(SharedTensor {
            desc: tensor_desc,
            latest_location: dev.clone(),
            latest_copy: copy,
            copies: copies,
            phantom: PhantomData,
        })
    }

    /// Changes the shape of the tensor without touching its memory.
    ///
    /// Fails with `InvalidShape` if the new shape does not describe exactly
    /// as many elements as the current one.
    pub fn reshape<D: IntoTensorDesc>(&mut self, desc: &D) -> Result<(), Error> {
        let new_desc: TensorDesc = desc.into();
        if new_desc.size() == self.desc().size() {
            self.desc = new_desc;
            Ok(())
        } else {
            Err(Error::InvalidShape("Size of the provided shape is not equal to the old shape."))
        }
    }

    /// Changes the shape of the tensor, reallocating memory on the latest
    /// device. All other copies are dropped, so existing data is lost.
    pub fn resize<D: IntoTensorDesc>(&mut self, desc: &D) -> Result<(), Error> {
        self.copies.clear();
        self.latest_copy = try!(Self::alloc_on_device(self.latest_device(), desc));
        let new_desc: TensorDesc = desc.into();
        self.desc = new_desc;
        Ok(())
    }

    /// Allocates memory for `desc.size()` elements of `T` on device `dev`.
    fn alloc_on_device<D: IntoTensorDesc>(dev: &DeviceType, desc: &D) -> Result<MemoryType, Error> {
        let tensor_desc: TensorDesc = desc.into();
        let alloc_size = Self::mem_size(tensor_desc.size());
        let copy = match *dev {
            #[cfg(feature = "native")]
            DeviceType::Native(ref cpu) => MemoryType::Native(try!(cpu.alloc_memory(alloc_size))),
            #[cfg(feature = "opencl")]
            DeviceType::OpenCL(ref context) => MemoryType::OpenCL(try!(context.alloc_memory(alloc_size))),
            #[cfg(feature = "cuda")]
            DeviceType::Cuda(ref context) => MemoryType::Cuda(try!(context.alloc_memory(alloc_size))),
        };
        Ok(copy)
    }

    /// Synchronizes the latest data to `destination`, making it the new
    /// latest location; the previously-latest copy stays tracked.
    ///
    /// Fails with `MissingDestination` if no copy is tracked for
    /// `destination` (allocate one first via `add_device`).
    pub fn sync(&mut self, destination: &DeviceType) -> Result<(), Error> {
        if &self.latest_location != destination {
            let latest = self.latest_location.clone();
            try!(self.sync_from_to(&latest, &destination));
            // Swap roles: the destination copy becomes the latest copy, and
            // the previously-latest copy is filed back under its old device.
            let mut swap_location = destination.clone();
            let mut swap_copy = try!(self.copies.remove(destination).ok_or(Error::MissingDestination("Tensor does not hold a copy on destination device.")));
            mem::swap(&mut self.latest_location, &mut swap_location);
            mem::swap(&mut self.latest_copy, &mut swap_copy);
            self.copies.insert(swap_location, swap_copy);
        }
        Ok(())
    }

    /// Returns the memory copy tracked for `device`, if any.
    ///
    /// Does not synchronize; unless `device` is the latest location, the
    /// returned copy may hold stale data.
    pub fn get(&self, device: &DeviceType) -> Option<&MemoryType> {
        if &self.latest_location == device {
            return Some(&self.latest_copy)
        }
        self.copies.get(device)
    }

    /// Mutable variant of `get`; same staleness caveat applies.
    pub fn get_mut(&mut self, device: &DeviceType) -> Option<&mut MemoryType> {
        if &self.latest_location == device {
            return Some(&mut self.latest_copy)
        }
        self.copies.get_mut(device)
    }

    /// Copies the latest data into the copy tracked for `destination`.
    ///
    /// A no-op when `source == destination`; fails with `MissingDestination`
    /// when no copy is tracked for `destination`, or `InvalidMemory` when the
    /// tracked copy's memory flavor does not match the destination device.
    fn sync_from_to(&mut self, source: &DeviceType, destination: &DeviceType) -> Result<(), Error> {
        if source != destination {
            match self.copies.get_mut(destination) {
                Some(mut destination_copy) => {
                    match destination {
                        #[cfg(feature = "native")]
                        &DeviceType::Native(ref cpu) => {
                            match destination_copy.as_mut_native() {
                                Some(ref mut mem) => try!(cpu.sync_in(&self.latest_location, &self.latest_copy, mem)),
                                None => return Err(Error::InvalidMemory("Expected Native Memory (FlatBox)"))
                            }
                        },
                        #[cfg(feature = "cuda")]
                        &DeviceType::Cuda(ref context) => {
                            match destination_copy.as_mut_cuda() {
                                Some(ref mut mem) => try!(context.sync_in(&self.latest_location, &self.latest_copy, mem)),
                                None => return Err(Error::InvalidMemory("Expected CUDA Memory."))
                            }
                        },
                        #[cfg(feature = "opencl")]
                        &DeviceType::OpenCL(ref context) => {
                            match destination_copy.as_mut_opencl() {
                                Some(ref mut mem) => try!(context.sync_in(&self.latest_location, &self.latest_copy, mem)),
                                None => return Err(Error::InvalidMemory("Expected OpenCL Memory."))
                            }
                        }
                    }
                    Ok(())
                },
                None => Err(Error::MissingDestination("Tensor does not hold a copy on destination device."))
            }
        } else {
            Ok(())
        }
    }

    /// Removes the copy tracked for `destination` and returns it.
    ///
    /// If `destination` currently holds the latest data, the data is first
    /// synced to another tracked device so it is not lost. Fails with
    /// `InvalidRemove` when no other copy exists to sync to (this case
    /// previously panicked on an `unwrap`).
    pub fn remove_copy(&mut self, destination: &DeviceType) -> Result<(MemoryType), Error> {
        if &self.latest_location == destination {
            let first = match self.copies.keys().nth(0) {
                Some(device) => device.clone(),
                None => return Err(Error::InvalidRemove("Tensor holds no other copy; removing the latest copy would lose its data.")),
            };
            try!(self.sync(&first));
        }
        match self.copies.remove(destination) {
            Some(destination_cpy) => Ok(destination_cpy),
            None => Err(Error::MissingDestination("Tensor does not hold a copy on destination device."))
        }
    }

    /// Re-registers a copy (e.g. one previously taken via `remove_copy`)
    /// under `dest`, replacing any copy already tracked there.
    fn return_copy(&mut self, dest: &DeviceType, dest_mem: MemoryType) {
        self.copies.insert(dest.clone(), dest_mem);
    }

    /// Allocates memory for this tensor on `device` and starts tracking it.
    ///
    /// Fails with `InvalidMemoryAllocation` if memory for `device` is
    /// already tracked (either as the latest copy or in `copies`).
    pub fn add_device(&mut self, device: &DeviceType) -> Result<&mut Self, Error> {
        if &self.latest_location == device {
            return Err(Error::InvalidMemoryAllocation("Tensor already tracks memory for this device. No memory allocation."))
        }
        match self.copies.get(device) {
            Some(_) => Err(Error::InvalidMemoryAllocation("Tensor already tracks memory for this device. No memory allocation.")),
            None => {
                let copy: MemoryType;
                match *device {
                    #[cfg(feature = "native")]
                    DeviceType::Native(ref cpu) => copy = MemoryType::Native(try!(cpu.alloc_memory(Self::mem_size(self.capacity())))),
                    #[cfg(feature = "opencl")]
                    DeviceType::OpenCL(ref context) => copy = MemoryType::OpenCL(try!(context.alloc_memory(Self::mem_size(self.capacity())))),
                    #[cfg(feature = "cuda")]
                    DeviceType::Cuda(ref context) => copy = MemoryType::Cuda(try!(context.alloc_memory(Self::mem_size(self.capacity())))),
                };
                self.copies.insert(device.clone(), copy);
                Ok(self)
            }
        }
    }

    /// Returns the device whose copy holds the most recent data.
    pub fn latest_device(&self) -> &DeviceType {
        &self.latest_location
    }

    /// Returns the number of elements the tensor holds.
    pub fn capacity(&self) -> usize {
        self.desc.size()
    }

    /// Returns the shape of the tensor.
    pub fn desc(&self) -> &TensorDesc {
        &self.desc
    }

    /// Returns the size in bytes of `capacity` elements of `T`.
    pub fn mem_size(capacity: usize) -> usize {
        mem::size_of::<T>() * capacity
    }
}
/// Errors that can occur while managing a `SharedTensor`'s memory copies.
#[derive(Debug, Copy, Clone)]
pub enum Error {
/// No copy exists on the requested source device.
MissingSource(&'static str),
/// No copy exists on the requested destination device.
MissingDestination(&'static str),
/// A tracked copy has a different memory flavor than the device expects.
InvalidMemory(&'static str),
/// Memory is already tracked for the device; nothing was allocated.
InvalidMemoryAllocation(&'static str),
/// A copy could not be removed (e.g. doing so would lose data).
InvalidRemove(&'static str),
/// The device failed to allocate memory.
MemoryAllocationError(::device::Error),
/// The device failed to synchronize memory.
MemorySynchronizationError(::device::Error),
/// The provided shape is incompatible with the tensor.
InvalidShape(&'static str)
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Error::MissingSource(ref err) => write!(f, "{:?}", err),
Error::MissingDestination(ref err) => write!(f, "{:?}", err),
Error::InvalidMemory(ref err) => write!(f, "{:?}", err),
Error::InvalidMemoryAllocation(ref err) => write!(f, "{:?}", err),
Error::InvalidRemove(ref err) => write!(f, "{:?}", err),
Error::MemoryAllocationError(ref err) => write!(f, "{}", err),
Error::MemorySynchronizationError(ref err) => write!(f, "{}", err),
Error::InvalidShape(ref err) => write!(f, "{}", err),
}
}
}
impl error::Error for Error {
fn description(&self) -> &str {
match *self {
Error::MissingSource(ref err) => err,
Error::MissingDestination(ref err) => err,
Error::InvalidMemory(ref err) => err,
Error::InvalidMemoryAllocation(ref err) => err,
Error::InvalidRemove(ref err) => err,
Error::MemoryAllocationError(ref err) => err.description(),
Error::MemorySynchronizationError(ref err) => err.description(),
Error::InvalidShape(ref err) => err,
}
}
fn cause(&self) -> Option<&error::Error> {
match *self {
Error::MissingSource(_) => None,
Error::MissingDestination(_) => None,
Error::InvalidMemory(_) => None,
Error::InvalidMemoryAllocation(_) => None,
Error::InvalidRemove(_) => None,
Error::MemoryAllocationError(ref err) => Some(err),
Error::MemorySynchronizationError(ref err) => Some(err),
Error::InvalidShape(_) => None,
}
}
}
impl From<Error> for ::error::Error {
/// Wraps a tensor error into the crate-wide error type.
fn from(err: Error) -> ::error::Error {
::error::Error::Tensor(err)
}
}