#[cfg(not(feature = "no-std"))]
use crate::Ident;
use core::cell::{Ref, RefMut};
#[cfg(feature = "opt-cache")]
use crate::{CacheReturn, DeviceError};
pub use add_graph::*;
pub use node::*;
mod add_graph;
mod node;
#[cfg(not(feature = "no-std"))]
mod graph_struct;
#[cfg(not(feature = "no-std"))]
pub use graph_struct::*;
pub trait NodeIdx {
#[inline]
fn idx(nodes: &[Node]) -> usize {
nodes.len()
}
}
/// Marker type for the default node-indexing strategy.
///
/// It is the default `IdxFrom` type parameter of [`GraphReturn`].
#[derive(Default, Debug)]
pub struct GlobalCount;
/// In `no-std` builds `GlobalCount` uses the default [`NodeIdx::idx`], which
/// derives the index from the number of existing nodes.
#[cfg(feature = "no-std")]
impl NodeIdx for GlobalCount {}
/// Placeholder graph type for `no-std` builds.
///
/// It holds no nodes; `IdxFrom` is only carried at the type level so the
/// signature matches the `std` graph.
#[cfg(feature = "no-std")]
pub struct Graph<IdxFrom: NodeIdx> {
    _p: core::marker::PhantomData<IdxFrom>,
}
#[cfg(feature = "no-std")]
impl<IdxFrom: NodeIdx> Graph<IdxFrom> {
    /// Text shared by the panics of every method in this block.
    const NO_STD_MSG: &'static str = "Not available in no-std mode";

    /// `no-std` stand-in for adding a leaf node of `_len` elements.
    ///
    /// # Panics
    /// Always panics via `unimplemented!`.
    #[inline]
    pub fn add_leaf(&mut self, _len: usize) -> Node {
        unimplemented!("{}", Self::NO_STD_MSG)
    }

    /// `no-std` stand-in for adding a node of `_len` elements that depends on
    /// the nodes at `_lhs_idx` and `_rhs_idx`.
    ///
    /// # Panics
    /// Always panics via `unimplemented!`.
    #[inline]
    pub fn add_node(&mut self, _len: usize, _lhs_idx: usize, _rhs_idx: usize) -> Node {
        unimplemented!("{}", Self::NO_STD_MSG)
    }
}
/// A trace of nodes that can reuse one cache entry.
///
/// `cache_idx` is the index of the node the trace starts from, and
/// `use_cache_idx` lists the identifiers (index and length) of the nodes
/// following it along the trace. With the `opt-cache` feature,
/// `GraphOpt::optimize` points the cache entries of these nodes at a shared
/// value.
#[cfg(not(feature = "no-std"))]
// NOTE: field order is significant for the derived `PartialOrd`/`Ord`
// (comparison is lexicographic over `cache_idx`, then `use_cache_idx`).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct CacheTrace {
    /// Index of the node at which the trace begins.
    pub cache_idx: usize,
    /// Identifiers of the nodes that follow `cache_idx` in the trace.
    pub use_cache_idx: Vec<Ident>,
}
/// Access to a [`Graph`] held behind interior mutability.
///
/// Both methods take `&self` and hand out `core::cell` borrow guards, so
/// implementors (e.g. devices) keep the graph in a cell.
pub trait GraphReturn<IdxFrom: NodeIdx = GlobalCount> {
    /// Borrows the graph for reading.
    fn graph(&self) -> Ref<'_, Graph<IdxFrom>>;

    /// Borrows the graph for modification.
    fn graph_mut(&self) -> RefMut<'_, Graph<IdxFrom>>;
}
/// Cache optimization driven by the traces recorded in the computation graph.
#[cfg(feature = "opt-cache")]
pub trait GraphOpt {
    /// Makes all nodes of each cache trace share one cache entry.
    ///
    /// For every trace returned by the graph's `cache_traces`, the cache entry
    /// of the trace's first `use_cache_idx` node is cloned into the entries of
    /// the remaining nodes of that trace. Traces without nodes are skipped.
    ///
    /// # Errors
    /// Returns `DeviceError::GraphOptimization` (converted into the crate
    /// error via `?`) if the first node of a trace has no cache entry.
    fn optimize(&self) -> crate::Result<()>
    where
        Self: GraphReturn + CacheReturn + crate::PtrConv,
    {
        let mut cache = self.cache_mut();
        for trace in self.graph().cache_traces() {
            // A trace without nodes has nothing to rewire.
            let (first, rest) = match trace.use_cache_idx.split_first() {
                Some(split) => split,
                None => continue,
            };

            // Resolve the shared entry once per trace instead of once per node.
            let shared = cache
                .nodes
                .get(first)
                .ok_or(DeviceError::GraphOptimization)?
                .clone();

            // `first` already holds `shared`; only the other nodes are redirected.
            for ident in rest {
                cache.nodes.insert(*ident, shared.clone());
            }
        }
        Ok(())
    }
}
/// `GraphOpt` is implemented for every [`GraphReturn`] type. Calling
/// `optimize` additionally requires `CacheReturn + PtrConv`, as stated on the
/// method itself.
#[cfg(feature = "opt-cache")]
impl<D> GraphOpt for D
where
    D: GraphReturn,
{
}
// Unit tests for graph construction and cache-trace detection (std builds only).
//
// Most tests begin with `unsafe { set_count(0) }`; the expected values in those
// tests show nodes being numbered 0, 1, 2, ... in the order they are added, so
// reordering any `add_leaf`/`add_node` call changes every expected index.
#[cfg(not(feature = "no-std"))]
#[cfg(test)]
mod tests {
    use crate::{set_count, CacheTrace, Graph, Ident, Node, NodeCount};

    // `add_leaf` yields a leaf node; `add_node` with two dependencies does not.
    #[test]
    fn test_is_leaf() {
        let mut graph = Graph::<NodeCount>::new();
        let node = graph.add_leaf(0);
        assert!(node.is_leaf());
        let node = graph.add_node(10, 1, 2);
        assert!(!node.is_leaf());
    }

    // Chain a, b -> c -> d(c, c) -> e(d, b): the raw trace from `c` contains the
    // two downstream nodes `d` (idx 3) and `_e` (idx 4).
    #[test]
    fn test_cache_trace() {
        unsafe { set_count(0) };
        let mut graph = Graph::<NodeCount>::new();
        let a = graph.add_leaf(10); // idx 0
        let b = graph.add_leaf(10); // idx 1
        let c = graph.add_node(10, a.idx, b.idx); // idx 2
        let d = graph.add_node(10, c.idx, c.idx); // idx 3
        let _e = graph.add_node(10, d.idx, b.idx); // idx 4
        let trace = graph.trace_cache_path_raw(&c);
        assert_eq!(
            vec![
                Node {
                    idx: 3,
                    deps: [2, 2],
                    len: 10,
                },
                Node {
                    idx: 4,
                    deps: [3, 1],
                    len: 10,
                }
            ],
            trace
        );
        // Only checks that computing all traces does not panic.
        let _traces = graph.cache_traces();
    }

    // Same chain as `test_cache_trace` plus `_f`, which also consumes `c`; the raw
    // trace from `c` is then empty (presumably because `c` now has two consumers).
    #[test]
    fn test_no_cache_trace() {
        unsafe { set_count(0) };
        let mut graph = Graph::<NodeCount>::new();
        let a = graph.add_leaf(10);
        let b = graph.add_leaf(10);
        let c = graph.add_node(10, a.idx, b.idx);
        let d = graph.add_node(10, c.idx, c.idx);
        let _e = graph.add_node(10, d.idx, b.idx);
        let _f = graph.add_node(10, c.idx, b.idx);
        let trace = graph.trace_cache_path_raw(&c);
        assert_eq!(Vec::<Node>::new(), trace);
    }

    // An extra leaf `u` (idx 2) and an unrelated node `_z` (idx 4) shift the
    // indices but do not break the trace from `c` (idx 3) through `d` and `_e`.
    #[test]
    fn test_cache_trace_2() {
        unsafe { set_count(0) };
        let mut graph = Graph::<NodeCount>::new();
        let a = graph.add_leaf(10);
        let b = graph.add_leaf(10);
        let u = graph.add_leaf(10);
        let c = graph.add_node(10, a.idx, b.idx);
        let _z = graph.add_node(10, a.idx, u.idx);
        let d = graph.add_node(10, c.idx, c.idx);
        let _e = graph.add_node(10, d.idx, b.idx);
        let trace = graph.trace_cache_path_raw(&c);
        assert_eq!(
            vec![
                Node {
                    idx: 5,
                    deps: [3, 3],
                    len: 10,
                },
                Node {
                    idx: 6,
                    deps: [5, 1],
                    len: 10,
                }
            ],
            trace
        );
    }

    // `d` is consumed by both `_u` and `_e`: the path starting at `c` is still
    // optimizable, the path starting at `d` is not.
    #[test]
    fn test_cache_trace_break_not_anymore() {
        unsafe { set_count(0) };
        let mut graph = Graph::<NodeCount>::new();
        let a = graph.add_leaf(10);
        let b = graph.add_leaf(10);
        let c = graph.add_node(10, a.idx, b.idx);
        let d = graph.add_node(10, c.idx, c.idx);
        let _u = graph.add_node(10, d.idx, a.idx);
        let _e = graph.add_node(10, d.idx, b.idx);
        // Debug output only; no assertion is made on these values.
        println!("traces: {:?}", graph.cache_traces());
        let trace = graph.trace_cache_path_raw(&c);
        println!("c_trace: {:?}", trace);
        assert!(graph.is_path_optimizable(&c));
        assert!(!graph.is_path_optimizable(&d));
    }

    // On the `test_cache_trace` graph, `cache_traces` finds exactly one trace:
    // it starts at `c` (idx 2) and covers `d` (idx 3) and `_e` (idx 4).
    #[test]
    fn test_trace_all() {
        unsafe { set_count(0) };
        let mut graph = Graph::<NodeCount>::new();
        let a = graph.add_leaf(10);
        let b = graph.add_leaf(10);
        let c = graph.add_node(10, a.idx, b.idx);
        let d = graph.add_node(10, c.idx, c.idx);
        let _e = graph.add_node(10, d.idx, b.idx);
        let traces = graph.cache_traces();
        assert_eq!(traces.len(), 1);
        assert_eq!(
            CacheTrace {
                cache_idx: 2,
                use_cache_idx: vec![Ident { idx: 3, len: 10 }, Ident { idx: 4, len: 10 },],
            },
            traces[0]
        );
    }

    // Mixed lengths: the len-12 chain c -> d -> _e (idx 4..=6) forms the only
    // trace; the len-10 node `_b` (idx 1) is never consumed and yields none.
    #[test]
    fn test_leafed_diff_len_trace() {
        unsafe { set_count(0) };
        let mut graph = Graph::<NodeCount>::new();
        let a = graph.add_leaf(10);
        let _b = graph.add_node(10, a.idx, a.idx);
        let _z = graph.add_leaf(10);
        let _z = graph.add_leaf(10);
        let c = graph.add_node(12, a.idx, a.idx);
        let d = graph.add_node(12, c.idx, c.idx);
        let _e = graph.add_node(12, d.idx, a.idx);
        let traces = graph.cache_traces();
        assert_eq!(
            [CacheTrace {
                cache_idx: 4,
                use_cache_idx: vec![Ident { idx: 5, len: 12 }, Ident { idx: 6, len: 12 },],
            }],
            &*traces
        );
    }

    // Graph shaped like a small dense network: 10 leaves (inputs, targets,
    // weights `w*`, biases `b*`; idx 0..=9), then nodes of len 6400 (idx 10..=18)
    // and len 100 (idx 19..=21). The expected result is one trace per run of
    // equal-length nodes.
    #[test]
    fn test_cache_trace_neural_net() {
        unsafe { set_count(0) };
        let mut graph = Graph::<NodeCount>::new();
        let inputs = graph.add_leaf(100 * 10);
        let targets = graph.add_leaf(100);
        let w1 = graph.add_leaf(10 * 64);
        let b1 = graph.add_leaf(64);
        let w2 = graph.add_leaf(64 * 64);
        let b2 = graph.add_leaf(64);
        let w3 = graph.add_leaf(64 * 64);
        let b3 = graph.add_leaf(64);
        let w4 = graph.add_leaf(64 * 1);
        let b4 = graph.add_leaf(1);
        let a1 = graph.add_node(100 * 64, inputs.idx, w1.idx);
        let a2 = graph.add_node(100 * 64, a1.idx, b1.idx);
        let a2 = graph.add_node(100 * 64, a2.idx, a2.idx);
        let a3 = graph.add_node(100 * 64, a2.idx, w2.idx);
        let a4 = graph.add_node(100 * 64, a3.idx, b2.idx);
        let a4 = graph.add_node(100 * 64, a4.idx, a4.idx);
        let a5 = graph.add_node(100 * 64, a4.idx, w3.idx);
        let a6 = graph.add_node(100 * 64, a5.idx, b3.idx);
        let a6 = graph.add_node(100 * 64, a6.idx, a6.idx);
        let a7 = graph.add_node(100 * 1, a6.idx, w4.idx);
        let a8 = graph.add_node(100 * 1, a7.idx, b4.idx);
        let _loss = graph.add_node(100, a8.idx, targets.idx);
        let traces = graph.cache_traces();
        assert_eq!(
            traces,
            [
                CacheTrace {
                    cache_idx: 10,
                    use_cache_idx: vec![
                        Ident { idx: 11, len: 6400 },
                        Ident { idx: 12, len: 6400 },
                        Ident { idx: 13, len: 6400 },
                        Ident { idx: 14, len: 6400 },
                        Ident { idx: 15, len: 6400 },
                        Ident { idx: 16, len: 6400 },
                        Ident { idx: 17, len: 6400 },
                        Ident { idx: 18, len: 6400 }
                    ]
                },
                CacheTrace {
                    cache_idx: 19,
                    use_cache_idx: vec![
                        Ident { idx: 20, len: 100 },
                        Ident { idx: 21, len: 100 },
                    ]
                }
            ]
        )
    }

    // The last node `_u` takes `d` as its right-hand dependency (deps [0, 3]);
    // the trace from `c` still reaches it, and both `c` and `d` are optimizable.
    #[test]
    fn test_cache_trace_d() {
        unsafe { set_count(0) };
        let mut graph = Graph::<NodeCount>::new();
        let a = graph.add_leaf(10);
        let b = graph.add_leaf(10);
        let c = graph.add_node(10, a.idx, b.idx);
        let d = graph.add_node(10, c.idx, c.idx);
        let _u = graph.add_node(10, a.idx, d.idx);
        let trace = graph.trace_cache_path_raw(&c);
        assert_eq!(
            vec![
                Node {
                    idx: 3,
                    deps: [2, 2],
                    len: 10,
                },
                Node {
                    idx: 4,
                    deps: [0, 3],
                    len: 10,
                }
            ],
            trace
        );
        assert!(graph.is_path_optimizable(&c));
        assert!(graph.is_path_optimizable(&d));
        let trace = graph.cache_traces();
        assert_eq!(
            trace,
            [CacheTrace {
                cache_idx: 2,
                use_cache_idx: vec![Ident { idx: 3, len: 10 }, Ident { idx: 4, len: 10 }]
            }]
        );
    }

    // Rebuilds the `test_cache_trace_neural_net` graph through `CPU::retrieve`
    // (10 input buffers, then 12 retrieves) and expects the device graph to
    // report the same two traces. Requires the `cpu` and `opt-cache` features.
    #[cfg(feature = "cpu")]
    #[cfg(feature = "opt-cache")]
    #[test]
    fn test_from_retrieve() {
        use crate::{Buffer, Device, GraphReturn, CPU};
        let device = CPU::new();
        let w1 = Buffer::from((&device, [1; 10 * 64]));
        let b1 = Buffer::from((&device, [1; 64]));
        let w2 = Buffer::from((&device, [1; 64 * 64]));
        let b2 = Buffer::from((&device, [1; 64]));
        let w3 = Buffer::from((&device, [1; 64 * 64]));
        let b3 = Buffer::from((&device, [1; 64]));
        let w4 = Buffer::from((&device, [1; 64 * 1]));
        let b4 = Buffer::from((&device, [1; 1]));
        let inputs = Buffer::from((&device, [1; 10 * 100]));
        let targets = Buffer::from((&device, [2; 100]));
        let a1 = device.retrieve::<i32, ()>(100 * 64, (&inputs, &w1));
        let a2 = device.retrieve::<i32, ()>(100 * 64, (&a1, &b1));
        let a2 = device.retrieve::<i32, ()>(100 * 64, (&a2, &a2));
        let a3 = device.retrieve::<i32, ()>(100 * 64, (&a2, &w2));
        let a4 = device.retrieve::<i32, ()>(100 * 64, (&a3, &b2));
        let a4 = device.retrieve::<i32, ()>(100 * 64, (&a4, &a4));
        let a5 = device.retrieve::<i32, ()>(100 * 64, (&a4, &w3));
        let a6 = device.retrieve::<i32, ()>(100 * 64, (&a5, &b3));
        let a6 = device.retrieve::<i32, ()>(100 * 64, (&a6, &a6));
        let a7 = device.retrieve::<i32, ()>(100 * 1, (&a6, &w4));
        let a8 = device.retrieve::<i32, ()>(100 * 1, (&a7, &b4));
        let _loss = device.retrieve::<i32, ()>(100, (&a8, &targets));
        let cts = device.graph().cache_traces();
        assert_eq!(
            cts,
            [
                CacheTrace {
                    cache_idx: 10,
                    use_cache_idx: vec![
                        Ident { idx: 11, len: 6400 },
                        Ident { idx: 12, len: 6400 },
                        Ident { idx: 13, len: 6400 },
                        Ident { idx: 14, len: 6400 },
                        Ident { idx: 15, len: 6400 },
                        Ident { idx: 16, len: 6400 },
                        Ident { idx: 17, len: 6400 },
                        Ident { idx: 18, len: 6400 }
                    ]
                },
                CacheTrace {
                    cache_idx: 19,
                    use_cache_idx: vec![
                        Ident { idx: 20, len: 100 },
                        Ident { idx: 21, len: 100 },
                    ]
                }
            ]
        )
    }

    // `c` has no consumers, so its raw trace is empty; `cache_traces` is also
    // called to check that it does not panic on such a graph.
    #[test]
    fn test_no_cache_trace_in_graph() {
        let mut graph = Graph::<NodeCount>::new();
        let a = graph.add_leaf(10);
        let b = graph.add_leaf(10);
        let c = graph.add_node(10, a.idx, b.idx);
        let trace = graph.trace_cache_path_raw(&c);
        graph.cache_traces();
        assert_eq!(Vec::<Node>::new(), trace);
    }

    // Two independent traces: a len-10 one from `_b` (idx 1) through `f` and `_g`
    // (idx 7, 8), and a len-12 one from `c` (idx 4) through `d` and `_e`.
    // `_z` in `f`/`_g` refers to the second (shadowing) `_z` leaf (idx 3).
    #[test]
    fn test_multiple_traces() {
        unsafe { set_count(0) };
        let mut graph = Graph::<NodeCount>::new();
        let a = graph.add_leaf(10);
        let _b = graph.add_node(10, a.idx, a.idx);
        let _z = graph.add_leaf(10);
        let _z = graph.add_leaf(10);
        let c = graph.add_node(12, a.idx, a.idx);
        let d = graph.add_node(12, c.idx, c.idx);
        let _e = graph.add_node(12, d.idx, a.idx);
        let f = graph.add_node(10, _b.idx, _z.idx);
        let _g = graph.add_node(10, f.idx, _z.idx);
        let traces = graph.cache_traces();
        assert_eq!(
            [
                CacheTrace {
                    cache_idx: 1,
                    use_cache_idx: vec![Ident { idx: 7, len: 10 }, Ident { idx: 8, len: 10 }]
                },
                CacheTrace {
                    cache_idx: 4,
                    use_cache_idx: vec![Ident { idx: 5, len: 12 }, Ident { idx: 6, len: 12 }]
                }
            ],
            &*traces
        );
    }
}