use super::element::Element;
use std::sync::Arc;
/// Describes how a tensor's logical flat indices map onto its storage buffer.
#[cfg(feature = "dynamic")]
#[derive(Debug, Clone)]
// Not every variant or field is necessarily constructed/read in every build of
// the crate, so the dead-code lint is silenced for this type.
#[allow(dead_code)]
pub(crate) enum ViewKind {
    /// Logical index `i` lives at storage index `offset + i`.
    Contiguous {
        offset: usize,
    },
    /// Logical index `i` lives at the storage index stored in the `i`-th entry.
    Indexed(Vec<usize>),
}
/// A dynamically shaped tensor whose elements live in a reference-counted buffer.
///
/// Several tensors can share one `storage` buffer; `view` decides which storage
/// slots belong to this tensor and in which order they are read.
#[cfg(feature = "dynamic")]
#[derive(Debug, Clone)]
pub(crate) struct DynamicTensor {
    /// Backing buffer, possibly shared with other tensors through `Arc`.
    pub(crate) storage: Arc<Vec<Element>>,
    /// Logical shape of this tensor.
    pub(crate) shape: Vec<usize>,
    /// Number of logical elements addressable through `view`.
    pub(crate) len: usize,
    /// Mapping from a logical flat index to an index into `storage`.
    pub(crate) view: ViewKind,
}
#[cfg(feature = "dynamic")]
// Some of these helpers may be unused in a given build of the crate; the lint
// is silenced once for the whole impl instead of per method.
#[allow(dead_code)]
impl DynamicTensor {
    /// Storage index recorded by `slice_indices` for a requested position that
    /// lies outside the source view. `Vec::get` never resolves it (a `Vec` of a
    /// non-zero-sized type holds at most `isize::MAX` elements — assumes
    /// `Element` is not zero-sized), so such positions read back as missing.
    const HOLE: usize = usize::MAX;

    /// Wraps `data` as a uniquely owned, contiguous tensor with the given `shape`.
    ///
    /// `len` is taken from `data.len()`; `shape` is NOT checked against it.
    pub(crate) fn from_vec(data: Vec<Element>, shape: Vec<usize>) -> Self {
        let len = data.len();
        DynamicTensor {
            storage: Arc::new(data),
            shape,
            len,
            view: ViewKind::Contiguous { offset: 0 },
        }
    }

    /// Translates a logical flat index into an index into `storage`.
    ///
    /// Returns `None` when `flat` is outside `0..self.len`, when the view has no
    /// entry for it, or when `offset + flat` would overflow.
    fn storage_index(&self, flat: usize) -> Option<usize> {
        if flat >= self.len {
            return None;
        }
        match &self.view {
            ViewKind::Contiguous { offset } => offset.checked_add(flat),
            ViewKind::Indexed(idxs) => idxs.get(flat).copied(),
        }
    }

    /// Returns the element at logical flat index `flat`, or `None` if the index
    /// is out of range or does not resolve to a slot in `storage`.
    pub(crate) fn get_flat(&self, flat: usize) -> Option<&Element> {
        self.storage_index(flat)
            .and_then(|idx| self.storage.get(idx))
    }

    /// Returns `true` when no other strong `Arc` reference to `storage` exists.
    pub(crate) fn is_unique(&self) -> bool {
        Arc::strong_count(&self.storage) == 1
    }

    /// Ensures this tensor owns a contiguous, offset-0 buffer of exactly its
    /// logical elements.
    ///
    /// It is a no-op when the view is already `Contiguous { offset: 0 }` and the
    /// buffer is unshared. Otherwise the logical elements are copied (in flat
    /// order, unresolvable positions becoming `Element::None`) into a new buffer.
    pub(crate) fn materialize(&mut self) {
        if matches!(self.view, ViewKind::Contiguous { offset: 0 }) && self.is_unique() {
            return;
        }
        self.storage = Arc::new(self.to_vec());
        self.view = ViewKind::Contiguous { offset: 0 };
    }

    /// Builds a new tensor that shares this tensor's storage and selects the
    /// logical positions listed in `indices` (in that order), with `new_shape`.
    ///
    /// An index outside `0..self.len` does not panic and never reads storage
    /// outside this tensor's view: the corresponding element of the result reads
    /// back as missing (`get_flat` returns `None`, `to_vec` yields `Element::None`).
    pub(crate) fn slice_indices(
        &self,
        indices: Vec<usize>,
        new_shape: Vec<usize>,
    ) -> DynamicTensor {
        // Consuming `indices` lets the collect reuse its allocation.
        let storage_indices: Vec<usize> = indices
            .into_iter()
            .map(|i| self.storage_index(i).unwrap_or(Self::HOLE))
            .collect();
        DynamicTensor {
            storage: Arc::clone(&self.storage),
            shape: new_shape,
            len: storage_indices.len(),
            view: ViewKind::Indexed(storage_indices),
        }
    }

    /// Returns a tensor sharing this tensor's storage and view but with
    /// `new_shape`, or `None` if `new_shape`'s element count differs from
    /// `self.len`.
    ///
    /// An empty shape counts as one element. A shape whose element count
    /// overflows `usize` is rejected instead of wrapping around.
    pub(crate) fn reshape(&self, new_shape: Vec<usize>) -> Option<DynamicTensor> {
        // Any zero dimension makes the product 0 regardless of the others, so
        // check it first; otherwise multiply with overflow detection.
        let new_len = if new_shape.contains(&0) {
            0
        } else {
            new_shape
                .iter()
                .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))?
        };
        if new_len != self.len {
            return None;
        }
        Some(DynamicTensor {
            storage: Arc::clone(&self.storage),
            shape: new_shape,
            len: new_len,
            view: self.view.clone(),
        })
    }

    /// Copies the logical elements, in flat order, into a new `Vec`.
    /// Positions that do not resolve to a storage slot become `Element::None`.
    pub(crate) fn to_vec(&self) -> Vec<Element> {
        (0..self.len)
            .map(|i| self.get_flat(i).cloned().unwrap_or(Element::None))
            .collect()
    }
}