pub struct Graph { /* private fields */ }
Expand description
The core graph data structure.
This is a Directed Acyclic Graph (DAG) with values and their creating operations as nodes, and input operands as edges. The data structure is append-only, values cannot be removed and so will never become invalid.
This type implements Index<Value>
trait, so you can use graph[value]
to get information about the given value.
// create a new graph
let mut graph = Graph::new();
// define the inputs
let x = graph.input(shape![Size::BATCH, 4, 8, 8]);
// define constants
let w_data = vec![0.5; 4 * 4 * 3 * 3];
let w = graph.constant(shape![4, 4, 3, 3], w_data);
let b_data = vec![0.5; 4];
let b = graph.constant(shape![4, 1, 1], b_data);
// build operation graph
let y0 = graph.conv(x, w, 1, 1, 1, 1);
let y = graph.add(y0, b);
graph.output(y);
println!("{}", graph);
Results in the following output:
Graph {
check: 1504812640,
input_shapes: [Shape(B x 4 x 8 x 8)],
output_shapes: [Shape(B x 4 x 8 x 8)],
inputs: [Value(0)],
outputs: [Value(6)],
values: [
Value(0) = ValueInfo { shape: Shape(B x 4 x 8 x 8), operation: Input { index: 0 }, debug_id: "", non_output_uses: 1 },
Value(1) = ValueInfo { shape: Shape(4 x 4 x 3 x 3), operation: Constant { data: [..; 144] }, debug_id: "", non_output_uses: 1 },
Value(2) = ValueInfo { shape: Shape(4 x 1 x 1), operation: Constant { data: [0.5, 0.5, 0.5, 0.5] }, debug_id: "", non_output_uses: 1 },
Value(3) = ValueInfo { shape: Shape(B x 4 x 8 x 8), operation: Conv { input: Value(0), filter: Value(1), details: ConvDetails { batch_size: Size(B), input_channels: 4, output_channels: 4, input_h: 8, input_w: 8, kernel_h: 3, kernel_w: 3, stride_y: 1, stride_x: 1, padding_y: 1, padding_x: 1, output_h: 8, output_w: 8 } }, debug_id: "", non_output_uses: 1 },
Value(4) = ValueInfo { shape: Shape(1 x 4 x 1 x 1), operation: View { input: Value(2) }, debug_id: "", non_output_uses: 1 },
Value(5) = ValueInfo { shape: Shape(B x 4 x 8 x 8), operation: Broadcast { input: Value(4) }, debug_id: "", non_output_uses: 1 },
Value(6) = ValueInfo { shape: Shape(B x 4 x 8 x 8), operation: Binary { left: Value(3), right: Value(5), op: Add }, debug_id: "", non_output_uses: 0 },
],
}
Implementations§
source§impl Graph
impl Graph
pub fn new() -> Self
sourcepub fn values(&self) -> impl Iterator<Item = Value>
pub fn values(&self) -> impl Iterator<Item = Value>
Iterate over the values in this graph, in topological order, which means that nodes will only be visited after all of their inputs have been visited.
pub fn inputs(&self) -> &[Value]
pub fn input_shapes(&self) -> Vec<Shape>
pub fn outputs(&self) -> &[Value]
pub fn output_shapes(&self) -> Vec<Shape>
pub fn outputs_mut(&mut self) -> &mut Vec<Value>
pub fn is_const(&self, value: Value) -> bool
sourcepub fn is_const_filled_with(&self, value: Value, f: f32) -> bool
pub fn is_const_filled_with(&self, value: Value, f: f32) -> bool
Returns whether value
is effectively a constant with every element equal to f
.
sourcepub fn as_single_const(&self, value: Value) -> Option<f32>
pub fn as_single_const(&self, value: Value) -> Option<f32>
Returns Some(f)
if value
is effectively a constant with every element equal to f
.
sourcepub fn take_new_values(&mut self) -> Vec<Value>
pub fn take_new_values(&mut self) -> Vec<Value>
Return all newly created values since the last call to take_new_values
.
sourcepub fn set_debug_id(&mut self, value: Value, id: String)
pub fn set_debug_id(&mut self, value: Value, id: String)
Equivalent to self[value].debug_id = id
,
but that would not work since there is intentionally no implementation of IndexMut
for Graph
.
pub fn scalar(&mut self, value: f32) -> Value
pub fn constant_tensor(&mut self, tensor: Tensor) -> Value
sourcepub fn view(&mut self, input: Value, new_shape: Shape) -> Value
pub fn view(&mut self, input: Value, new_shape: Shape) -> Value
View an existing value as a new shape.
sourcepub fn broadcast(&mut self, input: Value, new_shape: Shape) -> Value
pub fn broadcast(&mut self, input: Value, new_shape: Shape) -> Value
Broadcast the input
towards new_shape
.
Additional unit axes are inserted at the front, and unit axes are repeated as necessary.
pub fn repeat_unary(&mut self, input: Value, axis: usize, count: Size) -> Value
sourcepub fn flatten(&mut self, input: Value, start_axis: usize) -> Value
pub fn flatten(&mut self, input: Value, start_axis: usize) -> Value
View a value with a flattened shape.
All axes starting from start_axis
inclusive are flattened into a single axis.
sourcepub fn permute(&mut self, input: Value, permutation: Vec<usize>) -> Value
pub fn permute(&mut self, input: Value, permutation: Vec<usize>) -> Value
Change the order of the axes in the shape.
sourcepub fn slice(&mut self, input: Value, axis: usize, range: SliceRange) -> Value
pub fn slice(&mut self, input: Value, axis: usize, range: SliceRange) -> Value
Slice a value along an axis.
sourcepub fn index(&mut self, input: Value, axis: usize, index: usize) -> Value
pub fn index(&mut self, input: Value, axis: usize, index: usize) -> Value
Index along a given axis. Similar to slice with a 1-sized interval, except that the resulting value doesn’t have the extra axis.
sourcepub fn repeat(&mut self, input: Value, axis: usize, count: Size) -> Value
pub fn repeat(&mut self, input: Value, axis: usize, count: Size) -> Value
Repeat input
along a given axis
, count
times.
This starts by emitting the entire tensor before repeating elements,
similar to torch.repeat
or numpy.tile
.
sourcepub fn repeat_interleave(
&mut self,
input: Value,
axis: usize,
count: Size
) -> Value
pub fn repeat_interleave( &mut self, input: Value, axis: usize, count: Size ) -> Value
Repeat elements of input
along a given axis
, count
times.
This starts by repeating each element before going to the next one,
similar to torch.repeat_interleave
or numpy.repeat
.
sourcepub fn gather(&mut self, input: Value, axis: usize, indices: Value) -> Value
pub fn gather(&mut self, input: Value, axis: usize, indices: Value) -> Value
Index input
along the given axis
with indices given by indices
.
The output
shape is the input
shape with axis
replaced by the shape of indices
.
sourcepub fn concat(
&mut self,
inputs: Vec<Value>,
axis: usize,
base_shape: Option<Shape>
) -> Value
pub fn concat( &mut self, inputs: Vec<Value>, axis: usize, base_shape: Option<Shape> ) -> Value
Concatenate inputs
along axis
.
base_shape
can be provided to allow the result shape to be inferred in case inputs
is empty.
sourcepub fn conv(
&mut self,
input: Value,
filter: Value,
stride_y: usize,
stride_x: usize,
padding_y: usize,
padding_x: usize
) -> Value
pub fn conv( &mut self, input: Value, filter: Value, stride_y: usize, stride_x: usize, padding_y: usize, padding_x: usize ) -> Value
Apply 2D convolution.
sourcepub fn linear(&mut self, input: Value, weight: Value) -> Value
pub fn linear(&mut self, input: Value, weight: Value) -> Value
Apply a linear transformation.
Input shape [b, Ci]
and weight shape [Co, Ci]
result in an output with shape [b, Co]
.
sourcepub fn mat_mul(&mut self, left: Value, right: Value) -> Value
pub fn mat_mul(&mut self, left: Value, right: Value) -> Value
General matrix multiply, with broadcasting.
- The last two axes should have shapes
[n, p]
and [p, m]
and will result in an output shape [n, m]
- The preceding axes are broadcast together and reappear in the output as-is.
sourcepub fn batched_mat_mul(&mut self, left: Value, right: Value) -> Value
pub fn batched_mat_mul(&mut self, left: Value, right: Value) -> Value
Batched matrix multiply, without any automatic broadcasting.
Inputs must have shapes [b, m, n]
, [b, n, p]
and the result has shape [b, m, p]
.
pub fn softmax(&mut self, input: Value, axis: usize) -> Value
pub fn layernorm(&mut self, input: Value, axis: usize, eps: f32) -> Value
sourcepub fn reduce(&mut self, input: Value, axes: Vec<usize>, op: ReduceOp) -> Value
pub fn reduce(&mut self, input: Value, axes: Vec<usize>, op: ReduceOp) -> Value
Reduce input
along the given axes
.
The result shape is the same as the input shape but without the reduced axes.
pub fn add(&mut self, left: Value, right: Value) -> Value
pub fn sub(&mut self, left: Value, right: Value) -> Value
pub fn mul(&mut self, left: Value, right: Value) -> Value
pub fn pow(&mut self, left: Value, right: Value) -> Value
pub fn unary(&mut self, op: UnaryOp, input: Value) -> Value
sourcepub fn binary(&mut self, op: BinaryOp, left: Value, right: Value) -> Value
pub fn binary(&mut self, op: BinaryOp, left: Value, right: Value) -> Value
Compute an elementwise binary operation. Both inputs must have the same rank (or right must have rank 0); the right shape is broadcast to the left shape.
sourcepub fn call(&mut self, graph: &Graph, inputs: &[Value]) -> Vec<Value>
pub fn call(&mut self, graph: &Graph, inputs: &[Value]) -> Vec<Value>
Computes the operations described by graph
on the given inputs.
This can be used to cleanly compose multiple graphs together.
sourcepub fn output_all(&mut self, values: &[Value])
pub fn output_all(&mut self, values: &[Value])
Register multiple values as output at once, in order.
Trait Implementations§
Auto Trait Implementations§
impl RefUnwindSafe for Graph
impl Send for Graph
impl Sync for Graph
impl Unpin for Graph
impl UnwindSafe for Graph
Blanket Implementations§
source§impl<S, D, Swp, Dwp, T> AdaptInto<D, Swp, Dwp, T> for Swhere
T: Real + Zero + Arithmetics + Clone,
Swp: WhitePoint<T>,
Dwp: WhitePoint<T>,
D: AdaptFrom<S, Swp, Dwp, T>,
impl<S, D, Swp, Dwp, T> AdaptInto<D, Swp, Dwp, T> for Swhere T: Real + Zero + Arithmetics + Clone, Swp: WhitePoint<T>, Dwp: WhitePoint<T>, D: AdaptFrom<S, Swp, Dwp, T>,
source§fn adapt_into_using<M>(self, method: M) -> Dwhere
M: TransformMatrix<T>,
fn adapt_into_using<M>(self, method: M) -> Dwhere M: TransformMatrix<T>,
source§fn adapt_into(self) -> D
fn adapt_into(self) -> D
source§impl<T, C> ArraysFrom<C> for Twhere
C: IntoArrays<T>,
impl<T, C> ArraysFrom<C> for Twhere C: IntoArrays<T>,
source§fn arrays_from(colors: C) -> T
fn arrays_from(colors: C) -> T
source§impl<T, C> ArraysInto<C> for Twhere
C: FromArrays<T>,
impl<T, C> ArraysInto<C> for Twhere C: FromArrays<T>,
source§fn arrays_into(self) -> C
fn arrays_into(self) -> C
source§impl<T> BorrowMut<T> for Twhere
T: ?Sized,
impl<T> BorrowMut<T> for Twhere T: ?Sized,
source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
source§impl<T, C> ComponentsFrom<C> for Twhere
C: IntoComponents<T>,
impl<T, C> ComponentsFrom<C> for Twhere C: IntoComponents<T>,
source§fn components_from(colors: C) -> T
fn components_from(colors: C) -> T
source§impl<T> FromAngle<T> for T
impl<T> FromAngle<T> for T
source§fn from_angle(angle: T) -> T
fn from_angle(angle: T) -> T
angle
.source§impl<T, U> FromStimulus<U> for Twhere
U: IntoStimulus<T>,
impl<T, U> FromStimulus<U> for Twhere U: IntoStimulus<T>,
source§fn from_stimulus(other: U) -> T
fn from_stimulus(other: U) -> T
other
into Self
, while performing the appropriate scaling,
rounding and clamping.source§impl<T, U> IntoAngle<U> for Twhere
U: FromAngle<T>,
impl<T, U> IntoAngle<U> for Twhere U: FromAngle<T>,
source§fn into_angle(self) -> U
fn into_angle(self) -> U
T
.source§impl<T, U> IntoColor<U> for Twhere
U: FromColor<T>,
impl<T, U> IntoColor<U> for Twhere U: FromColor<T>,
source§fn into_color(self) -> U
fn into_color(self) -> U
source§impl<T, U> IntoColorUnclamped<U> for Twhere
U: FromColorUnclamped<T>,
impl<T, U> IntoColorUnclamped<U> for Twhere U: FromColorUnclamped<T>,
source§fn into_color_unclamped(self) -> U
fn into_color_unclamped(self) -> U
source§impl<T> IntoStimulus<T> for T
impl<T> IntoStimulus<T> for T
source§fn into_stimulus(self) -> T
fn into_stimulus(self) -> T
self
into T
, while performing the appropriate scaling,
rounding and clamping.source§impl<'a, T, C> TryComponentsInto<C> for Twhere
C: TryFromComponents<T>,
impl<'a, T, C> TryComponentsInto<C> for Twhere C: TryFromComponents<T>,
§type Error = <C as TryFromComponents<T>>::Error
type Error = <C as TryFromComponents<T>>::Error
try_into_colors
fails to cast.source§fn try_components_into(self) -> Result<C, <T as TryComponentsInto<C>>::Error>
fn try_components_into(self) -> Result<C, <T as TryComponentsInto<C>>::Error>
source§impl<T, U> TryIntoColor<U> for Twhere
U: TryFromColor<T>,
impl<T, U> TryIntoColor<U> for Twhere U: TryFromColor<T>,
source§fn try_into_color(self) -> Result<U, OutOfBounds<U>>
fn try_into_color(self) -> Result<U, OutOfBounds<U>>
OutOfBounds
error is returned which contains
the unclamped color. Read more