1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
// Copyright (C) 2025 zk4x
// SPDX-License-Identifier: LGPL-3.0-only
#![allow(unused)]
use crate::{
RT, Set,
tensor::{Tensor, TensorId},
};
/// A graph compiled once from a fixed set of input and output tensors.
///
/// The ids stored in `inputs` and `outputs` are retained in the runtime by
/// [`StaticGraph::new`] and released again when the value is dropped.
pub struct StaticGraph {
    /// Ids of the tensors passed as `inputs` to [`StaticGraph::new`].
    inputs: Set<TensorId>,
    /// Ids of the tensors passed as `outputs` to [`StaticGraph::new`].
    outputs: Set<TensorId>,
    /// Operations returned by the runtime's `compile_graph` for the two id sets above.
    graph: Vec<GraphOp>,
}
impl Drop for StaticGraph {
    /// Releases, under the runtime lock, every tensor id this graph holds.
    ///
    /// Each id in the union of `inputs` and `outputs` is released exactly once,
    /// mirroring the single `retain` per id performed in [`StaticGraph::new`].
    fn drop(&mut self) {
        let mut runtime = RT.lock();
        for id in self.inputs.union(&self.outputs).copied() {
            runtime.release(id);
        }
    }
}
impl StaticGraph {
    /// Create new static graph using inputs and outputs.
    ///
    /// Inputs are tensors that can be changed during each forward pass.
    /// Outputs are tensors that get realized during forward pass.
    ///
    /// Every distinct tensor id found in `inputs` or `outputs` is retained in the
    /// runtime exactly once; the matching release happens when the `StaticGraph`
    /// is dropped. The graph is compiled immediately via the runtime's
    /// `compile_graph`.
    pub fn new(
        inputs: impl IntoIterator<Item = Tensor>,
        outputs: impl IntoIterator<Item = Tensor>,
    ) -> Self {
        // TODO: keep the order of inputs and deal with the fact that input IDs can change;
        // that might call for some interior mutability to keep the graph valid.
        // However, we probably do not need that: it is enough to work on the level of buffer IDs.
        // The inputs need to be realized before passing them through the compiler and the
        // forward pass, and we only need to map buffers correctly. Once we are compiled down
        // to kernels, only buffers matter, not tensors.

        // Keep the caller's handles alive until their ids have been retained.
        // NOTE(review): `Tensor` is defined outside this file; this assumes that dropping
        // a `Tensor` may release its id in the runtime (confirm). Under that assumption,
        // consuming the handles while extracting the ids would release a tensor passed by
        // its only handle before `retain` below gets a chance to run.
        let input_tensors: Vec<Tensor> = inputs.into_iter().collect();
        let output_tensors: Vec<Tensor> = outputs.into_iter().collect();

        let inputs: Set<TensorId> = input_tensors.iter().map(|t| t.id).collect();
        let outputs: Set<TensorId> = output_tensors.iter().map(|t| t.id).collect();

        let mut rt = RT.lock();
        // An id present in both sets is retained only once, matching the release in `Drop`.
        for &tid in inputs.union(&outputs) {
            rt.retain(tid);
        }
        let graph = rt.compile_graph(&inputs, &outputs);
        // Give the runtime lock back before the handles go away, in case dropping a
        // `Tensor` needs to take that lock itself.
        drop(rt);
        drop(input_tensors);
        drop(output_tensors);

        Self { inputs, outputs, graph }
    }

    /// Launch the graph with given inputs.
    ///
    /// # Panics
    ///
    /// Always panics for now: the forward pass is not implemented yet (`todo!()`).
    #[allow(clippy::needless_pass_by_value)]
    pub fn forward(&mut self, inputs: impl IntoIterator<Item = Tensor>) {
        let _ = inputs;
        todo!()
    }
}
/// A single step of a compiled [`StaticGraph`] (the element type of its `graph` list).
///
/// The variants carry no payload yet.
/// NOTE(review): the per-variant descriptions below are inferred from the
/// variant names only — confirm once the operations are implemented.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GraphOp {
    /// Presumably allocates memory for a buffer.
    MemoryAllocate,
    /// Presumably frees previously allocated memory.
    MemoryFree,
    /// Presumably copies data between buffers.
    MemoryCopy,
    /// Presumably launches a compiled kernel.
    KernelLaunch,
}