use crate::{
graph::Graph,
op::{self, Function},
prelude::Data,
shape::*,
tensor::Tensor,
};
use std::marker::PhantomData;
use std::{fmt::Debug, path::Path};
use petgraph::graph::NodeIndex;
/// A typed, copyable handle to a tensor node inside a [`Graph`].
///
/// `S` is the compile-time shape type; `shape` is the runtime shape tracker for the node.
/// The handle reaches its graph through the raw pointer `graph_ref` rather than a borrow.
#[derive(Copy, Clone)]
pub struct GraphTensor<S: Shape> {
    /// Index of this tensor's node in the graph.
    pub id: NodeIndex,
    /// Raw pointer to the owning graph; dereferenced by `GraphTensor::graph`.
    pub graph_ref: *mut Graph,
    /// Zero-sized marker tying the handle to its compile-time shape type `S`.
    pub(crate) _phantom: PhantomData<S>,
    /// Runtime shape tracker used to index and validate this tensor's elements.
    pub shape: ShapeTracker,
}
impl<S: Shape> GraphTensor<S> {
    /// Creates a typed handle for the existing node `id` in the graph behind `graph_ref`.
    pub fn from_id(id: NodeIndex, shape: ShapeTracker, graph_ref: *mut Graph) -> Self {
        Self {
            id,
            graph_ref,
            shape,
            _phantom: Default::default(),
        }
    }

    /// Calls `Graph::keep_tensors` for this node and returns the handle for chaining.
    pub fn keep(self) -> Self {
        self.graph().keep_tensors(self.id);
        self
    }

    /// Keeps this node (via `keep`), then calls `Graph::retrieve_tensors` for it.
    /// Returns the handle for chaining.
    pub fn retrieve(self) -> Self {
        self.keep();
        self.graph().retrieve_tensors(self.id);
        self
    }

    /// Calls `Graph::drop_tensors` for this node.
    pub fn drop(&self) {
        self.graph().drop_tensors(self.id);
    }

    /// Returns a mutable reference to the graph this tensor lives in.
    ///
    /// # Panics
    ///
    /// Panics if `graph_ref` is null.
    // The `&mut` is derived from the raw `graph_ref` pointer, not from `&self`, hence the allow.
    #[allow(clippy::mut_from_ref)]
    pub fn graph(&self) -> &mut Graph {
        // SAFETY: NOTE(review) — not enforced here: assumes `graph_ref` points to a `Graph`
        // that outlives this handle, and that callers never keep two overlapping `&mut Graph`
        // borrows obtained through this method alive at once.
        unsafe { self.graph_ref.as_mut().unwrap() }
    }

    /// Returns this tensor's node downcast to a `Function` op.
    ///
    /// # Panics
    ///
    /// Panics if the node is missing from the graph or is not a `Function`.
    // Same raw-pointer-derived `&mut` as `graph`.
    #[allow(clippy::mut_from_ref)]
    fn function_op_mut(&self) -> &mut Function {
        self.graph()
            .graph
            .node_weight_mut(self.id)
            .expect("tensor node is missing from the graph")
            .as_any_mut()
            .downcast_mut::<Function>()
            .expect("tensor node is not a `Function` op")
    }

    /// Supplies `data` for this tensor and binds its dynamic dimensions.
    ///
    /// For each dimension of `S`'s realized shape, the last symbol returned by `to_symbols()`
    /// (if any) is bound to the matching entry of `shape` in the graph's `dyn_map`. The node's
    /// function is then replaced with one that returns a clone of `data`.
    ///
    /// # Panics
    ///
    /// Panics if `shape.len()` differs from the rank of `S`, or if this node is not a
    /// `Function` op.
    pub fn set_dyn<T: Data + Clone>(self, data: T, shape: &[usize]) -> Self {
        let realized = S::realized_shape();
        assert_eq!(
            realized.len(),
            shape.len(),
            "Number of dimensions don't match!"
        );
        for (dim, &size) in realized.iter().zip(shape) {
            if let Some(symbol) = dim.to_symbols().pop() {
                self.graph().dyn_map.insert(symbol, size);
            }
        }
        self.function_op_mut().1 = Box::new(move |_| {
            vec![Tensor {
                data: Box::new(data.clone()),
            }]
        });
        self
    }

    /// Sets the name stored on this tensor's `Function` op.
    ///
    /// # Panics
    ///
    /// Panics if this node is not a `Function` op.
    pub fn set_name(&self, name: &str) {
        self.function_op_mut().0 = name.to_string();
    }

    /// Adds an `op::Print` node fed by this tensor, labelled with `message`.
    ///
    /// The new node is inserted into the graph's `no_delete` set.
    pub fn print<T: ToString>(&self, message: T) {
        let id = self
            .graph()
            .add_op(op::Print(message.to_string()))
            .input(self.id, 0, self.shape)
            .finish();
        self.graph().no_delete.insert(id);
    }

    /// Adds an `op::Diff` node fed by this tensor, built from `file` and `threshold`.
    ///
    /// The new node is inserted into the graph's `no_delete` set.
    // NOTE(review): presumably compares this tensor against data stored in `file` within
    // `threshold` — confirm against `op::Diff`.
    pub fn diff<T: AsRef<Path>>(&self, file: T, threshold: f32) {
        let id = self
            .graph()
            .add_op(op::Diff(file.as_ref().into(), threshold))
            .input(self.id, 0, self.shape)
            .finish();
        self.graph().no_delete.insert(id);
    }

    /// Returns a handle to the same node with the static shape type erased to `()`.
    ///
    /// The runtime `ShapeTracker` is kept unchanged.
    pub fn no_shape(self) -> GraphTensor<()> {
        GraphTensor::from_id(self.id, self.shape, self.graph_ref)
    }

    /// Reads this tensor's output 0 from the graph as a dense `Vec<f32>`.
    ///
    /// Dynamic dimensions are resolved against the graph's `dyn_map`. Each logical element `i`
    /// is gathered from the stored buffer at the shape tracker's index expression. Elements
    /// whose valid expression evaluates to 0 are left as `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if any of the following holds:
    /// - `get_tensor_ref(self.id, 0)` returns `None`;
    /// - the stored data is not a `Vec<f32>`;
    /// - the element count cannot be converted to `usize`.
    pub fn data(&self) -> Vec<f32> {
        let mut st = self.shape;
        st.resolve_global_dyn_dims(&self.graph().dyn_map);
        let tensor = self.graph().get_tensor_ref(self.id, 0).unwrap();
        let orig_data = tensor.data.as_any().downcast_ref::<Vec<f32>>().unwrap();
        let mut data = vec![0.; st.n_elements().to_usize().unwrap()];
        let ind = st.index_expression();
        let val = st.valid_expression();
        // `r` is already `&mut f32`; no `mut` binding (or lint suppression) is needed.
        for (i, r) in data.iter_mut().enumerate() {
            if val.exec_single_var(i) != 0 {
                *r = orig_data[ind.exec_single_var(i)];
            }
        }
        data
    }
}
impl<S: ConstShape> GraphTensor<S> {
    /// Supplies statically-shaped `data` for this tensor.
    ///
    /// `data` is flattened through `ToData`. The node's function is replaced with one that
    /// returns a fresh `Tensor` holding a clone of the flattened data on every call.
    ///
    /// # Panics
    ///
    /// Panics if the node is missing or is not a `Function` op.
    pub fn set<T: Data + Clone, D: ToData<S, T>>(self, data: D) -> Self {
        let flattened = data.to_data_vec();
        let weight = self.graph().graph.node_weight_mut(self.id).unwrap();
        let function = weight.as_any_mut().downcast_mut::<Function>().unwrap();
        function.1 = Box::new(move |_| vec![Tensor::new(flattened.clone())]);
        self
    }

    /// Makes this tensor's data come from `loader`.
    ///
    /// The node's function is replaced with one that invokes `loader` every time it runs.
    ///
    /// # Panics
    ///
    /// Panics if the node is missing or is not a `Function` op.
    pub fn set_deferred(self, loader: impl Fn() -> Vec<f32> + 'static) -> Self {
        let weight = self.graph().graph.node_weight_mut(self.id).unwrap();
        let function = weight.as_any_mut().downcast_mut::<Function>().unwrap();
        function.1 = Box::new(move |_| {
            let loaded = loader();
            vec![Tensor {
                data: Box::new(loaded),
            }]
        });
        self
    }
}
/// Writes `data`, interpreted row-major with dimensions `shape`, as nested bracketed lists.
///
/// Values are printed with six decimal places. Any axis with more than 10 entries is elided
/// to its first and last 5 entries, with a `...` marker between them. `level` is the current
/// nesting depth and sets the indentation (one space per level). The outermost call
/// (`level == 0`) ends the output with a newline.
///
/// A rank-0 `shape` prints the single value in `data`, if there is one. An inner axis of size
/// zero prints one empty list per outer entry instead of panicking.
fn pretty_print_tensor_recursive(
    f: &mut std::fmt::Formatter<'_>,
    data: &[f32],
    shape: &[usize],
    level: usize,
) -> std::fmt::Result {
    // Axes with more entries than this are elided.
    const MAX_UNELIDED: usize = 10;
    // Number of entries shown at each end of an elided axis.
    const EDGE: usize = 5;

    // Writes `values` separated by ", ", with no trailing separator.
    fn write_values(f: &mut std::fmt::Formatter<'_>, values: &[f32]) -> std::fmt::Result {
        for (i, value) in values.iter().enumerate() {
            if i > 0 {
                write!(f, ", ")?;
            }
            write!(f, "{value:.6}")?;
        }
        Ok(())
    }

    // Writes each row as a nested sub-tensor, separated by ",\n", with no trailing separator.
    fn write_rows(
        f: &mut std::fmt::Formatter<'_>,
        rows: &[&[f32]],
        inner_shape: &[usize],
        level: usize,
    ) -> std::fmt::Result {
        for (i, row) in rows.iter().enumerate() {
            if i > 0 {
                writeln!(f, ",")?;
            }
            pretty_print_tensor_recursive(f, row, inner_shape, level)?;
        }
        Ok(())
    }

    if shape.is_empty() {
        // Rank-0 tensor: a single bare value, no brackets.
        if let Some(value) = data.first() {
            write!(f, "{value:.6}")?;
            if level == 0 {
                writeln!(f)?;
            }
        }
        return Ok(());
    }

    let indent = " ".repeat(level);
    if shape.len() == 1 {
        write!(f, "{indent}[")?;
        if data.len() > MAX_UNELIDED {
            write_values(f, &data[..EDGE])?;
            write!(f, ", ..., ")?;
            // Separators come from `write_values`, so the last value has no trailing comma.
            write_values(f, &data[data.len() - EDGE..])?;
        } else {
            write_values(f, data)?;
        }
        write!(f, "]")?;
    } else {
        writeln!(f, "{indent}[")?;
        let inner_shape = &shape[1..];
        let stride: usize = inner_shape.iter().product();
        // A zero-sized inner axis makes every row empty; dividing by / chunking with a
        // zero stride would panic.
        let rows: Vec<&[f32]> = if stride == 0 {
            let empty: &[f32] = &[];
            vec![empty; shape[0]]
        } else {
            data.chunks(stride).collect()
        };
        if rows.len() > MAX_UNELIDED {
            write_rows(f, &rows[..EDGE], inner_shape, level + 1)?;
            writeln!(f, ",")?;
            writeln!(f, "{indent} ..., ")?;
            write_rows(f, &rows[rows.len() - EDGE..], inner_shape, level + 1)?;
        } else {
            write_rows(f, &rows, inner_shape, level + 1)?;
        }
        writeln!(f)?;
        write!(f, "{indent}]")?;
    }
    if level == 0 {
        writeln!(f)?;
    }
    Ok(())
}
impl<S: Shape> Debug for GraphTensor<S> {
    /// Prints a header line with the tensor's concrete shape, then its values as nested lists.
    ///
    /// Values come from `GraphTensor::data`, so this panics whenever that does. It also panics
    /// if a dimension expression fails to evaluate against the graph's `dyn_map`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let values = self.data();
        let mut dims = Vec::new();
        for dim_expr in self.shape.shape().iter() {
            dims.push(dim_expr.exec(&self.graph().dyn_map).unwrap());
        }
        writeln!(f, "Tensor with Shape: {dims:?}")?;
        pretty_print_tensor_recursive(f, &values, &dims, 0)
    }
}
/// Graph-marking operations that apply uniformly to a single tensor or to a collection of them.
///
/// Collection impls forward each call to every element in order.
pub trait MarkTensors {
    /// Applies `keep` to every contained tensor.
    fn keep(&self);
    /// Applies `retrieve` to every contained tensor.
    fn retrieve(&self);
    /// Applies `drop` to every contained tensor.
    fn drop(&self);
    /// Applies `set_dyn` with a clone of `data` and the same `shape` to every contained tensor.
    fn set_dyn<T>(&self, data: T, shape: &[usize])
    where
        T: Data + Clone;
}
// Each method calls the inherent `GraphTensor` method through an explicit type path. Plain
// method syntax (`self.keep()`) on `&GraphTensor` would resolve to this trait method and
// recurse forever.
impl<S: Shape> MarkTensors for GraphTensor<S> {
    fn keep(&self) {
        let _ = GraphTensor::<S>::keep(*self);
    }
    fn retrieve(&self) {
        let _ = GraphTensor::<S>::retrieve(*self);
    }
    fn drop(&self) {
        GraphTensor::<S>::drop(self);
    }
    fn set_dyn<T: Data + Clone>(&self, data: T, shape: &[usize]) {
        let _ = GraphTensor::<S>::set_dyn(*self, data, shape);
    }
}
/// Forwards every `MarkTensors` call to each element of the vector, front to back.
impl<S: MarkTensors> MarkTensors for Vec<S> {
    fn keep(&self) {
        self.iter().for_each(<S as MarkTensors>::keep);
    }
    fn retrieve(&self) {
        self.iter().for_each(<S as MarkTensors>::retrieve);
    }
    fn drop(&self) {
        self.iter().for_each(<S as MarkTensors>::drop);
    }
    fn set_dyn<T: Data + Clone>(&self, data: T, shape: &[usize]) {
        self.iter()
            .for_each(|tensor| tensor.set_dyn(data.clone(), shape));
    }
}
/// Forwards every `MarkTensors` call to each element of the slice, front to back.
impl<S: MarkTensors> MarkTensors for &[S] {
    fn keep(&self) {
        self.iter().for_each(<S as MarkTensors>::keep);
    }
    fn retrieve(&self) {
        self.iter().for_each(<S as MarkTensors>::retrieve);
    }
    fn drop(&self) {
        self.iter().for_each(<S as MarkTensors>::drop);
    }
    fn set_dyn<T: Data + Clone>(&self, data: T, shape: &[usize]) {
        self.iter()
            .for_each(|tensor| tensor.set_dyn(data.clone(), shape));
    }
}
/// Implements `MarkTensors` for a tuple by forwarding each method to every element in field
/// order.
///
/// Usage: `tuple_impls!([T1, T2], [0, 1]);`. The first list names one type parameter per
/// element; the second lists the matching tuple field indices.
macro_rules! tuple_impls {
    ([$($elem:ident),+] , [$($field:tt),+]) => {
        impl<$($elem: MarkTensors),+> MarkTensors for ($($elem,)+) {
            fn keep(&self) {
                $( self.$field.keep(); )+
            }
            fn retrieve(&self) {
                $( self.$field.retrieve(); )+
            }
            fn drop(&self) {
                $( self.$field.drop(); )+
            }
            fn set_dyn<T: Data + Clone>(&self, data: T, shape: &[usize]) {
                $( self.$field.set_dyn(data.clone(), shape); )+
            }
        }
    };
}
// `MarkTensors` for tuples of arity 1 through 10.
tuple_impls!([T1], [0]);
tuple_impls!([T1, T2], [0, 1]);
tuple_impls!([T1, T2, T3], [0, 1, 2]);
tuple_impls!([T1, T2, T3, T4], [0, 1, 2, 3]);
tuple_impls!([T1, T2, T3, T4, T5], [0, 1, 2, 3, 4]);
tuple_impls!([T1, T2, T3, T4, T5, T6], [0, 1, 2, 3, 4, 5]);
tuple_impls!([T1, T2, T3, T4, T5, T6, T7], [0, 1, 2, 3, 4, 5, 6]);
tuple_impls!([T1, T2, T3, T4, T5, T6, T7, T8], [0, 1, 2, 3, 4, 5, 6, 7]);
tuple_impls!([T1, T2, T3, T4, T5, T6, T7, T8, T9], [0, 1, 2, 3, 4, 5, 6, 7, 8]);
tuple_impls!(
    [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10],
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
);
/// Converts host data into the storage value `T` used for a tensor of shape `S`.
///
/// The nested-array impls in this file flatten in row-major order (outermost index slowest).
pub trait ToData<S: Shape, T> {
    /// Consumes `self` and returns its flattened storage.
    fn to_data_vec(self) -> T;
}
/// An already-flat `Vec<f32>` is accepted for any shape and passed through unchanged.
impl<S: Shape> ToData<S, Vec<f32>> for Vec<f32> {
    fn to_data_vec(self) -> Vec<f32> {
        std::convert::identity(self)
    }
}
/// A rank-1 array copies directly into a `Vec` for an `(A,)`-shaped tensor.
impl<const A: usize> ToData<(Const<A>,), Vec<f32>> for [f32; A] {
    fn to_data_vec(self) -> Vec<f32> {
        Vec::from(self)
    }
}
/// A rank-2 array flattens row by row for an `(A, B)`-shaped tensor.
impl<const A: usize, const B: usize> ToData<(Const<A>, Const<B>), Vec<f32>> for [[f32; B]; A] {
    fn to_data_vec(self) -> Vec<f32> {
        self.concat()
    }
}
/// A rank-3 array flattens in row-major order for an `(A, B, C)`-shaped tensor.
impl<const A: usize, const B: usize, const C: usize>
    ToData<(Const<A>, Const<B>, Const<C>), Vec<f32>> for [[[f32; C]; B]; A]
{
    fn to_data_vec(self) -> Vec<f32> {
        self.iter().flatten().flatten().copied().collect()
    }
}
/// A rank-4 array flattens in row-major order for an `(A, B, C, D)`-shaped tensor.
impl<const A: usize, const B: usize, const C: usize, const D: usize>
    ToData<(Const<A>, Const<B>, Const<C>, Const<D>), Vec<f32>> for [[[[f32; D]; C]; B]; A]
{
    fn to_data_vec(self) -> Vec<f32> {
        self.iter()
            .flatten()
            .flatten()
            .flatten()
            .copied()
            .collect()
    }
}
/// A rank-5 array flattens in row-major order for an `(A, B, C, D, E)`-shaped tensor.
///
/// The target shape previously omitted `Const<E>` and so named a rank-4 shape. That let a
/// 5-D array be `set` on a 4-D tensor, and prevented setting a 5-D tensor from a 5-D array.
impl<const A: usize, const B: usize, const C: usize, const D: usize, const E: usize>
    ToData<(Const<A>, Const<B>, Const<C>, Const<D>, Const<E>), Vec<f32>>
    for [[[[[f32; E]; D]; C]; B]; A]
{
    fn to_data_vec(self) -> Vec<f32> {
        self.iter()
            .flatten()
            .flatten()
            .flatten()
            .flatten()
            .copied()
            .collect()
    }
}