1use std::collections::HashMap;
4use crate::tensor::DenseTensor;
5use crate::tensor::traits::{TensorBase, TensorOps};
6
/// Identifier of a recorded operation node inside a [`ComputeGraph`].
///
/// A thin `Copy` wrapper around the raw counter value handed out by
/// `ComputeGraph::next_op_id`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct OpId(pub usize);
10
/// Identifier of a tensor flowing through a [`ComputeGraph`].
///
/// Used as the key for both stored forward values and gradients; the raw number comes
/// from `ComputeGraph::next_tensor_id`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TensorId(pub usize);
14
15#[derive(Debug, Clone)]
17pub struct OpRef {
18 pub id: OpId,
20 pub op_type: OpType,
22 pub inputs: Vec<TensorId>,
24 pub output: TensorId,
26}
27
/// The kind of operation a recorded node performs.
///
/// The enum is fieldless, so it derives `Copy`, `PartialEq`, `Eq` and `Hash`. That lets
/// callers compare op kinds with `==` and use them as map keys without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpType {
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// Matrix multiplication.
    MatMul,
    /// Transposition.
    Transpose,
    /// Sum reduction.
    Sum,
    /// Mean reduction.
    Mean,
    /// Rectified linear unit.
    ReLU,
    /// Gaussian error linear unit.
    GELU,
    /// Logistic sigmoid.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// Sigmoid linear unit.
    SiLU,
    /// Softmax.
    Softmax,
    /// Layer normalization.
    LayerNorm,
    /// Root-mean-square normalization.
    RMSNorm,
    /// Linear (affine) layer.
    Linear,
    /// Embedding lookup.
    Embedding,
    /// Rotary position embedding.
    RoPE,
    /// Scaled dot product (presumably the attention-score op — confirm with callers).
    ScaledDotProduct,
}
72
73#[derive(Debug, Clone)]
75pub struct OpNode {
76 pub id: OpId,
78 pub op_type: OpType,
80 pub inputs: Vec<TensorId>,
82 pub output: TensorId,
84}
85
86#[derive(Debug, Clone)]
88pub struct DataEdge {
89 pub from: OpId,
91 pub to: OpId,
93 pub tensor_id: TensorId,
95}
96
97#[derive(Debug, Clone)]
99pub struct Checkpoint {
100 pub tensors: HashMap<TensorId, DenseTensor>,
102}
103
104#[derive(Debug, Default, Clone)]
106pub struct ComputeGraph {
107 nodes: Vec<OpNode>,
109 edges: Vec<DataEdge>,
111 gradients: HashMap<TensorId, DenseTensor>,
113 values: HashMap<TensorId, DenseTensor>,
115 checkpoint: Option<Checkpoint>,
117 next_op_id: usize,
119 next_tensor_id: usize,
121 recording: bool,
123}
124
125impl ComputeGraph {
126 pub fn new() -> Self {
128 Self {
129 nodes: Vec::new(),
130 edges: Vec::new(),
131 gradients: HashMap::new(),
132 values: HashMap::new(),
133 checkpoint: None,
134 next_op_id: 0,
135 next_tensor_id: 0,
136 recording: true,
137 }
138 }
139
140 pub fn next_op_id(&mut self) -> OpId {
142 let id = OpId(self.next_op_id);
143 self.next_op_id += 1;
144 id
145 }
146
147 pub fn next_tensor_id(&mut self) -> TensorId {
149 let id = TensorId(self.next_tensor_id);
150 self.next_tensor_id += 1;
151 id
152 }
153
154 pub fn record_op(&mut self, op_type: OpType, inputs: &[TensorId], output: TensorId) {
156 if !self.recording {
157 return;
158 }
159
160 let op_id = self.next_op_id();
161 let node = OpNode {
162 id: op_id,
163 op_type: op_type.clone(),
164 inputs: inputs.to_vec(),
165 output,
166 };
167 self.nodes.push(node);
168
169 for &input_id in inputs {
171 if let Some(producer_op) = self.nodes.iter().rev().find(|n| {
173 n.output == input_id
174 }) {
175 let edge = DataEdge {
176 from: producer_op.id,
177 to: op_id,
178 tensor_id: input_id,
179 };
180 self.edges.push(edge);
181 }
182 }
183 }
184
185 pub fn store_value(&mut self, tensor_id: TensorId, value: DenseTensor) {
187 self.values.insert(tensor_id, value);
188 }
189
190 pub fn get_value(&self, tensor_id: TensorId) -> Option<&DenseTensor> {
192 self.values.get(&tensor_id)
193 }
194
195 pub fn get_value_mut(&mut self, tensor_id: TensorId) -> Option<&mut DenseTensor> {
197 self.values.get_mut(&tensor_id)
198 }
199
200 pub fn store_gradient(&mut self, tensor_id: TensorId, gradient: DenseTensor) {
202 self.gradients.insert(tensor_id, gradient);
203 }
204
205 pub fn get_gradient(&self, tensor_id: TensorId) -> Option<&DenseTensor> {
207 self.gradients.get(&tensor_id)
208 }
209
210 pub fn backward(&mut self, loss: TensorId) -> HashMap<TensorId, DenseTensor> {
218 if let Some(loss_tensor) = self.values.get(&loss) {
220 let shape = loss_tensor.shape().to_vec();
221 let ones = DenseTensor::ones(shape);
222 self.gradients.insert(loss, ones);
223 }
224
225 let topo_order = self.topological_sort();
227
228 for op_id in topo_order.into_iter().rev() {
230 let (node_op_type, node_inputs, node_output) = if let Some(node) = self.nodes.iter().find(|n| n.id == op_id) {
232 (node.op_type.clone(), node.inputs.clone(), node.output)
233 } else {
234 continue;
235 };
236
237 let grad_output = self.gradients.get(&node_output).cloned();
238
239 if let Some(grad) = grad_output {
240 let input_grads = self.compute_gradients(&node_op_type, &node_inputs, &grad);
242
243 for (i, &input_id) in node_inputs.iter().enumerate() {
245 if let Some(input_grad) = input_grads.get(&i) {
246 self.accumulate_gradient(input_id, input_grad.clone());
247 }
248 }
249 }
250 }
251
252 self.gradients.clone()
253 }
254
255 fn compute_gradients(
257 &self,
258 op_type: &OpType,
259 inputs: &[TensorId],
260 grad_output: &DenseTensor,
261 ) -> HashMap<usize, DenseTensor> {
262 let mut grads = HashMap::new();
263
264 match op_type {
265 OpType::Add => {
266 for (i, _) in inputs.iter().enumerate() {
268 grads.insert(i, grad_output.clone());
269 }
270 }
271 OpType::Sub => {
272 for (i, _) in inputs.iter().enumerate() {
274 if i == 0 {
275 grads.insert(i, grad_output.clone());
276 } else {
277 grads.insert(i, grad_output.neg());
278 }
279 }
280 }
281 OpType::Mul => {
282 if inputs.len() >= 2 {
284 if let (Some(x), Some(y)) = (
285 self.values.get(&inputs[0]),
286 self.values.get(&inputs[1]),
287 ) {
288 grads.insert(0, grad_output.mul(y));
289 grads.insert(1, grad_output.mul(x));
290 }
291 }
292 }
293 OpType::MatMul => {
294 if inputs.len() >= 2 {
296 if let (Some(x), Some(w)) = (
297 self.values.get(&inputs[0]),
298 self.values.get(&inputs[1]),
299 ) {
300 let w_t = w.transpose(None);
302 let grad_x = grad_output.matmul(&w_t);
303 grads.insert(0, grad_x);
304
305 let x_t = x.transpose(None);
307 let grad_w = x_t.matmul(grad_output);
308 grads.insert(1, grad_w);
309 }
310 }
311 }
312 OpType::ReLU => {
313 if let Some(x) = inputs.first().and_then(|id| self.values.get(id)) {
315 let mask = x.gt(0.0);
316 let grad = grad_output.mul(&mask);
317 grads.insert(0, grad);
318 }
319 }
320 OpType::GELU => {
321 if let Some(x) = inputs.first().and_then(|id| self.values.get(id)) {
323 let gelu_grad = x.gelu_derivative();
324 let grad = grad_output.mul(&gelu_grad);
325 grads.insert(0, grad);
326 }
327 }
328 OpType::Softmax => {
329 if let Some(softmax_out) = inputs.first().and_then(|id| self.values.get(id)) {
331 let sum_grad_dot_s = grad_output.mul(softmax_out).sum(None);
332 let ones = DenseTensor::ones(softmax_out.shape().to_vec());
333 let ones_scaled = ones.scale(sum_grad_dot_s.data()[0]);
334 let diff = grad_output.sub(&ones_scaled);
335 let grad = softmax_out.mul(&diff);
336 grads.insert(0, grad);
337 }
338 }
339 OpType::Transpose => {
340 if !inputs.is_empty() {
342 grads.insert(0, grad_output.transpose(None));
343 }
344 }
345 OpType::LayerNorm | OpType::RMSNorm => {
346 if inputs.first().and_then(|id| self.values.get(id)).is_some() {
348 grads.insert(0, grad_output.clone());
350 }
351 }
352 _ => {
353 for (i, _) in inputs.iter().enumerate() {
355 grads.insert(i, grad_output.clone());
356 }
357 }
358 }
359
360 grads
361 }
362
363 pub fn accumulate_gradient(&mut self, tensor_id: TensorId, gradient: DenseTensor) {
365 self.gradients
366 .entry(tensor_id)
367 .and_modify(|existing| {
368 *existing = existing.add(&gradient);
369 })
370 .or_insert(gradient);
371 }
372
373 pub fn topological_sort(&self) -> Vec<OpId> {
375 let mut result = Vec::new();
376 let mut visited = std::collections::HashSet::new();
377
378 fn visit(
379 node: &OpNode,
380 nodes: &[OpNode],
381 visited: &mut std::collections::HashSet<OpId>,
382 result: &mut Vec<OpId>,
383 ) {
384 if visited.contains(&node.id) {
385 return;
386 }
387 visited.insert(node.id);
388
389 for &input_id in &node.inputs {
391 if let Some(producer) = nodes.iter().find(|n| n.output == input_id) {
392 visit(producer, nodes, visited, result);
393 }
394 }
395
396 result.push(node.id);
397 }
398
399 for node in &self.nodes {
400 visit(node, &self.nodes, &mut visited, &mut result);
401 }
402
403 result
404 }
405
406 pub fn clear(&mut self) {
408 self.nodes.clear();
409 self.edges.clear();
410 self.gradients.clear();
411 self.values.clear();
412 self.checkpoint = None;
413 }
414
415 pub fn set_recording(&mut self, recording: bool) {
417 self.recording = recording;
418 }
419
420 pub fn is_recording(&self) -> bool {
422 self.recording
423 }
424
425 pub fn num_ops(&self) -> usize {
427 self.nodes.len()
428 }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Recording a single MatMul registers exactly one op.
    #[test]
    fn test_compute_graph_basic() {
        let mut graph = ComputeGraph::new();

        let x_id = graph.next_tensor_id();
        let w_id = graph.next_tensor_id();

        graph.store_value(x_id, DenseTensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]));
        graph.store_value(w_id, DenseTensor::new(vec![0.1, 0.2, 0.3], vec![3, 1]));

        let out_id = graph.next_tensor_id();
        graph.record_op(OpType::MatMul, &[x_id, w_id], out_id);

        // Compute the forward result first, then store it once the shared borrows are done.
        let product = match (graph.get_value(x_id), graph.get_value(w_id)) {
            (Some(x), Some(w)) => Some(x.matmul(w)),
            _ => None,
        };
        if let Some(out) = product {
            graph.store_value(out_id, out);
        }

        assert_eq!(graph.num_ops(), 1);
    }

    /// A MatMul feeding a ReLU must be ordered before it.
    #[test]
    fn test_topological_sort() {
        let mut graph = ComputeGraph::new();

        let x_id = graph.next_tensor_id();
        let w_id = graph.next_tensor_id();
        let matmul_out = graph.next_tensor_id();
        let relu_out = graph.next_tensor_id();

        graph.store_value(x_id, DenseTensor::new(vec![1.0, 2.0], vec![1, 2]));
        graph.store_value(w_id, DenseTensor::new(vec![0.1, 0.2], vec![2, 1]));

        graph.record_op(OpType::MatMul, &[x_id, w_id], matmul_out);
        graph.record_op(OpType::ReLU, &[matmul_out], relu_out);

        let order = graph.topological_sort();
        assert_eq!(order.len(), 2);

        // Index within `order` of the op whose kind satisfies `is_kind`.
        let position_of = |is_kind: fn(&OpType) -> bool| {
            order
                .iter()
                .position(|&id| graph.nodes.iter().any(|n| n.id == id && is_kind(&n.op_type)))
        };
        let matmul_pos = position_of(|t| matches!(t, OpType::MatMul)).unwrap();
        let relu_pos = position_of(|t| matches!(t, OpType::ReLU)).unwrap();
        assert!(matmul_pos < relu_pos);
    }
}