1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
use crate::tensor::{Tensor, TensorInternal};
use crate::error::OpError;
use crate::ndarray_ext::{NdArrayView, RawNdArrayView};
use crate::op;
use crate::variable::{VariableID, VariableNamespace};
use crate::{tensor_ops as T, Evaluator};
use crate::{Float, NdArray, VariableEnvironment};
use std::collections::{HashMap, HashSet};
use std::cell::{Ref, RefCell, RefMut};
use std::fmt;
use std::ops::Deref;
/// Unique index of a node within a [Graph]'s node set (its position in `node_set`).
pub type TensorID = usize;

/// Graph represents a computation graph holding tensors inside.
///
/// NOTE:
/// You won't be using this struct directly because this is generally accessed via `Context::deref()`.
pub struct Graph<F: Float> {
    // Every node ever installed in this graph; a node's `TensorID` is its index here.
    pub(crate) node_set: RefCell<Vec<TensorInternal<F>>>,
    // Cache mapping a variable id to the graph node wrapping it, so repeated
    // `variable_by_id` calls reuse one node instead of allocating duplicates.
    pub(crate) variable2node: RefCell<HashMap<VariableID, TensorID>>,
}

/// Node count at which `Graph::install` prints a one-time warning.
pub const NUM_NODES_WARN: usize = 50_000;
/// Node count beyond which `Graph::install` panics.
pub const NUM_NODES_CRITICAL: usize = 500_000;
impl<'graph, F: Float> Graph<F> {
    /// Evaluates `tensors` and returns one `Result` per requested tensor, in order.
    ///
    /// Nodes are first collected in topological order (dependencies before
    /// dependents), then computed one at a time: feed values are materialized
    /// first, variable nodes are read from the `VariableEnvironment`, and each
    /// op runs once all of its inputs are available. A node that cannot be
    /// computed records an error, which is returned for any requested tensor
    /// with that node's id.
    #[inline]
    pub fn eval_tensors(
        tensors: &[&Tensor<F>],
        feeds: &HashMap<TensorID, &RawNdArrayView<F>>,
        ctx: &Context<F>,
    ) -> Vec<Result<NdArray<F>, OpError>> {
        // Nothing to do for an empty request.
        if tensors.is_empty() {
            return Vec::new();
        }

        // All nodes needed for evaluation, dependencies before dependents.
        let mut eval_nodes = Vec::new();
        let mut visited = HashSet::new();

        // Depth-first post-order walk: pushes `node_id` after all of its
        // dependencies. `visited` guards shared sub-graphs against
        // re-processing and prevents infinite recursion on cycles.
        fn collect_nodes_topo<F: Float>(
            node_id: TensorID,
            graph: &Graph<F>,
            eval_nodes: &mut Vec<TensorID>,
            visited: &mut HashSet<TensorID>,
        ) {
            // `insert` returns false when the id was already present.
            if !visited.insert(node_id) {
                return;
            }
            // Clone the incoming-node list so the `Ref` borrow is released
            // before recursing (recursion re-borrows the node set).
            let incoming = graph.access_inner(node_id).incoming_nodes.clone();
            for incoming_node in &incoming {
                collect_nodes_topo(incoming_node.id, graph, eval_nodes, visited);
            }
            eval_nodes.push(node_id);
        }

        for tensor in tensors {
            collect_nodes_topo(tensor.id, ctx.as_graph(), &mut eval_nodes, &mut visited);
        }

        // Computed value per node id.
        let mut computed_values: HashMap<TensorID, NdArray<F>> = HashMap::new();
        // Error recorded per node id; consulted when collecting results so a
        // failed target tensor reports its root cause instead of a generic
        // message. (Previously these errors were computed and then discarded.)
        let mut errors: HashMap<TensorID, OpError> = HashMap::new();

        // Materialize feed values up front so fed placeholders count as computed.
        for (&id, &feed_view) in feeds.iter() {
            // SAFETY: RawNdArrayView is assumed layout-compatible with
            // NdArrayView (same representation, erased lifetime).
            // TODO(review): confirm this invariant where RawNdArrayView is defined.
            unsafe {
                let view: NdArrayView<F> = std::mem::transmute(feed_view.clone());
                computed_values.insert(id, view.to_owned());
            }
        }

        // Evaluate nodes in topological order.
        for node_id in eval_nodes {
            // Skip if already computed (e.g., from feeds).
            if computed_values.contains_key(&node_id) {
                continue;
            }
            let node = ctx.as_graph().access_inner(node_id);

            // Variable node: read its array out of the environment.
            if let Some(variable_id) = node.variable_id {
                match ctx.var_env_ref.get_array_by_id(variable_id) {
                    Some(var_array) => {
                        computed_values.insert(node_id, var_array.borrow().clone());
                    }
                    None => {
                        errors.insert(
                            node_id,
                            OpError::RuntimeError(format!(
                                "Variable with ID {variable_id} not found in VariableEnvironment"
                            )),
                        );
                    }
                }
                continue;
            }

            // Placeholder with no feed: record an error and skip the node.
            if node.placeholder_name.is_some() {
                let placeholder_name = node.placeholder_name.unwrap_or("<unnamed>");
                errors.insert(
                    node_id,
                    OpError::RuntimeError(format!(
                        "No feed value provided for placeholder '{placeholder_name}'"
                    )),
                );
                continue;
            }

            // Gather this op's inputs. If any input is missing we must skip
            // the node entirely — running the op with partial inputs would be
            // incorrect. A missing input indicates a cycle or an upstream error.
            let mut input_arrays = Vec::with_capacity(node.incoming_nodes.len());
            let mut missing_input = None;
            for input_node in &node.incoming_nodes {
                if let Some(input_array) = computed_values.get(&input_node.id) {
                    input_arrays.push(input_array.clone());
                } else {
                    missing_input = Some(input_node.id);
                    break;
                }
            }
            if let Some(missing_id) = missing_input {
                errors.insert(
                    node_id,
                    OpError::RuntimeError(format!(
                        "Input node {missing_id} for node {node_id} was not computed - possible cycle in graph"
                    )),
                );
                continue;
            }

            // Run the op and keep its first output as this node's value.
            let mut compute_ctx = op::ComputeContext::with_inputs(input_arrays);
            match node.get_op().compute(&mut compute_ctx) {
                Ok(()) => {
                    let outputs = compute_ctx.get_outputs();
                    if let Some(first) = outputs.first() {
                        computed_values.insert(node_id, first.clone());
                    } else {
                        errors.insert(
                            node_id,
                            OpError::RuntimeError(format!(
                                "Operation {} did not produce any output",
                                node.get_op().name()
                            )),
                        );
                    }
                }
                Err(err) => {
                    errors.insert(node_id, err);
                }
            }
        }

        // Collect results for the requested tensors, preferring the specific
        // error recorded for a node over the generic failure message.
        tensors
            .iter()
            .map(|tensor| {
                if let Some(value) = computed_values.get(&tensor.id) {
                    Ok(value.clone())
                } else if let Some(err) = errors.get(&tensor.id) {
                    Err(err.clone())
                } else {
                    Err(OpError::RuntimeError(format!(
                        "Failed to compute tensor {}",
                        tensor.id
                    )))
                }
            })
            .collect()
    }

    /// Returns the id of the first node whose placeholder name equals `name`,
    /// searching in installation order, or `None` if no such placeholder exists.
    #[inline]
    pub fn get_tensor_by_name(&self, name: &'static str) -> Option<TensorID> {
        // A node's id is its position in `node_set`, so `position` yields it directly.
        self.node_set
            .borrow()
            .iter()
            .position(|node| node.placeholder_name == Some(name))
    }

    /// Registers `node` in this graph, assigning and returning its id.
    ///
    /// Warns once when the node count reaches `NUM_NODES_WARN` and panics past
    /// `NUM_NODES_CRITICAL` — unbounded growth usually means tensors are being
    /// allocated inside a training loop within a single `run` block.
    #[inline]
    pub(crate) fn install(&'graph self, mut node: TensorInternal<F>) -> TensorID {
        let mut inner = self.node_set.borrow_mut();
        let id = inner.len();
        // `==` (not `>=`) so the warning fires exactly once, when crossed.
        if id == NUM_NODES_WARN {
            eprintln!(
                "Too many tensors in this graph: {NUM_NODES_WARN}. \
                 Use Graph::clear, or move the training loop out of the `run` block"
            )
        }
        if id > NUM_NODES_CRITICAL {
            panic!(
                "Maximum graph size exceeded: {NUM_NODES_CRITICAL}. \
                 Use Graph::clear, or move the training loop out of the `run` block"
            )
        }
        node.id = id;
        inner.push(node);
        id
    }

    /// Immutably borrows the internal node with the given id.
    /// Panics if a mutable borrow is active or `id` is out of bounds.
    #[inline(always)]
    pub(crate) fn access_inner(&self, id: TensorID) -> Ref<TensorInternal<F>> {
        let borrow = self.node_set.borrow();
        Ref::map(borrow, |t| &t[id])
    }

    /// Mutably borrows the internal node with the given id.
    /// Panics if any other borrow is active or `id` is out of bounds.
    #[inline(always)]
    pub(crate) fn access_inner_mut(&self, id: TensorID) -> RefMut<TensorInternal<F>> {
        let borrow = self.node_set.borrow_mut();
        RefMut::map(borrow, |t| &mut t[id])
    }

    /// Creates a lightweight `Tensor` handle for the node with the given id.
    #[inline(always)]
    pub(crate) fn tensor(&'graph self, id: TensorID) -> Tensor<'graph, F> {
        Tensor { id, graph: self }
    }

    /// Returns the topological rank recorded for the node with the given id.
    #[inline]
    pub(crate) fn topo_rank(&self, id: TensorID) -> usize {
        self.node_set.borrow()[id].topo_rank
    }

    /// Returns the tensor bound to variable `vid`, creating and caching a new
    /// `Variable` node on first access.
    #[inline]
    pub fn variable_by_id(&self, vid: VariableID) -> Tensor<F> {
        // Copy the id out so the map borrow is released before `build`
        // (which calls `install` and re-borrows graph internals).
        let tid = self.variable2node.borrow().get(&vid).cloned();
        if let Some(tid) = tid {
            // Reuse the existing tensor node.
            self.tensor(tid)
        } else {
            // Allocate a new Variable node and remember the vid -> tid mapping.
            let allocated = Tensor::builder(self)
                .set_variable(vid)
                .build(crate::tensor_ops::basic_source_ops::Variable);
            self.variable2node.borrow_mut().insert(vid, allocated.id);
            allocated
        }
    }
}
impl<T: Float> fmt::Debug for Graph<T> {
    /// Writes the node count followed by one line per node, in id order.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let nodes = self.node_set.borrow();
        writeln!(f, "graph size: {}", nodes.len())?;
        for node in nodes.iter() {
            writeln!(f, "{node}")?;
        }
        Ok(())
    }
}
/// Creates and runs a computation graph.
///
/// See [Context].
#[allow(dead_code)]
pub fn run<F, FN, R>(f: FN) -> R
where
    F: Float,
    FN: FnOnce(&mut Context<F>) -> R,
{
    // This entry point binds no variables, so a fresh empty environment is
    // created just for the duration of the closure.
    let env = VariableEnvironment::new();
    let mut ctx = Context {
        graph: Graph {
            // Preallocate room for a typical graph to avoid early regrowth.
            node_set: RefCell::new(Vec::with_capacity(512)),
            variable2node: RefCell::new(HashMap::new()),
        },
        var_env_ref: &env,
    };
    f(&mut ctx)
}
/// Generates and runs a computation graph
///
/// Each time [run] is invoked, a new `Context` allocating a [Graph] is passed to the closure, in which tensors are generated and evaluated.
/// It's faster to understand if you see [Tensor]'s documentation.
///
/// In order to bind `Tensor`s to pre-defined variable arrays, use [VariableEnvironment::run] instead.
/// See [crate::variable]
pub struct Context<'env, F: Float> {
    // The graph owned by this context; dropped when the `run` closure returns.
    pub(crate) graph: Graph<F>,
    // Borrowed variable storage; `'env` outlives the graph, so variables
    // survive across `clear_graph` calls.
    pub(crate) var_env_ref: &'env VariableEnvironment<F>,
}
impl<'graph, 'env, F: Float> Context<'env, F> {
    /// Get or create a variable namespace with the specified name.
    ///
    /// Use `namespace_mut` for mutable operations such as variables registrations.
    #[inline]
    pub fn namespace(&'env self, namespace_id: &'static str) -> VariableNamespace<'env, F> {
        self.var_env_ref.namespace(namespace_id)
    }

    /// Get or create the *default* variable namespace.
    ///
    /// Use `namespace_mut` for mutable operations such as variables registrations.
    #[inline]
    pub fn default_namespace(&'env self) -> VariableNamespace<'env, F> {
        self.var_env_ref.default_namespace()
    }

    /// Returns a reference to the current VariableEnvironment
    #[inline]
    pub fn env(&'graph self) -> &'env VariableEnvironment<F> {
        self.var_env_ref
    }

    /// Creates an evaluator for the graph.
    ///
    /// This method is used to evaluate tensors in the graph.
    #[inline]
    pub fn evaluator(&'graph self) -> Evaluator<'graph, 'graph, F> {
        Evaluator::new(self)
    }

    /// Evaluates tensors in the graph.
    ///
    /// This is an internal method used by tensor.eval().
    /// `_var_env` is currently unused: evaluation reads variables through
    /// `self.var_env_ref` instead.
    #[inline]
    pub fn eval(
        &'graph self,
        tensors: &[&Tensor<'graph, F>],
        feeds: &HashMap<TensorID, RawNdArrayView<F>>,
        _var_env: &'env VariableEnvironment<F>,
    ) -> Vec<Result<NdArray<F>, OpError>> {
        // Re-borrow the feed map: `eval_tensors` expects `&RawNdArrayView`
        // values rather than owned views.
        let temp_feeds: HashMap<TensorID, &RawNdArrayView<F>> =
            feeds.iter().map(|(k, v)| (*k, v)).collect();
        Graph::eval_tensors(tensors, &temp_feeds, self)
    }

    /// Removes all tensors in this graph.
    ///
    /// Note that any tensors allocated prior to this method call are invalid.
    #[inline]
    pub fn clear(&mut self) {
        self.graph.node_set.borrow_mut().clear();
        self.graph.variable2node.borrow_mut().clear();
    }

    /// Clears the computation graph while preserving the variables themselves.
    ///
    /// This is useful for training loops where you want to reset the graph between
    /// iterations but keep the variables stored in the [VariableEnvironment]. After
    /// calling this method:
    /// - All tensor nodes are removed from the graph
    /// - Variables (owned by the environment, not the graph) are preserved and
    ///   will be wrapped in new tensor nodes on next access
    /// - Any existing `Tensor` handles become invalid
    ///
    /// NOTE(review): this currently performs the same work as `clear()`; the
    /// two methods differ in documented intent only.
    ///
    /// # Example
    /// ```ignore
    /// for epoch in 0..1000 {
    ///     env.run(|ctx| {
    ///         // ... forward pass and backward pass ...
    ///         ctx.evaluator().run();
    ///
    ///         // Clear graph for next iteration, keeping variables in the environment
    ///         ctx.clear_graph();
    ///     });
    /// }
    /// ```
    ///
    /// See also: `clear()` for a complete graph reset.
    #[inline]
    pub fn clear_graph(&mut self) {
        self.graph.node_set.borrow_mut().clear();
        // variable2node must be cleared too: the tensor ids it stores point at
        // nodes that were just removed. The VariableIDs in the
        // VariableEnvironment remain valid, and `variable_by_id` repopulates
        // this map with fresh nodes on the next access.
        self.graph.variable2node.borrow_mut().clear();
    }

    /// Returns the current number of tensor nodes in the graph.
    ///
    /// This is useful for monitoring graph growth during training loops.
    /// If this number grows unboundedly, consider using `clear_graph()` or
    /// restructuring your training loop.
    #[inline]
    pub fn node_count(&self) -> usize {
        self.graph.node_set.borrow().len()
    }

    /// Creates a placeholder tensor in a [Graph].
    ///
    /// placeholder is a named tensor whose value can be specified when evaluating a computation graph.
    /// You can designate the `shape` of the placeholder and `shape[i]` can be a positive
    /// value or -1 which means a dim of arbitrary size.
    ///
    /// Use `Evaluator::feed` and `Feeder::push` in order to assign ArrayViews to placeholders.
    /// ```
    /// use scirs2_autograd as ag;
    /// use scirs2_core::ndarray::array;
    ///
    /// ag::run(|ctx| {
    ///     // be aware that x1 and x3 represent the same value
    ///     let x1 = ctx.placeholder("x", &[-1, 2]);
    ///     let x2 = ctx.placeholder("y", &[-1, 2]);
    ///     let x3 = ctx.placeholder("x", &[-1, 2]);
    ///     let sum = x1 + x2 + x3;
    ///
    ///     let arr = &array![[1., 1.]].into_dyn();
    ///
    ///     let result = ctx.evaluator()
    ///         .push(&sum)
    ///         .feed("x", arr.view()) // feed for x1 and x3
    ///         .feed("y", arr.view()) // feed for x2
    ///         .feed(x2, arr.view()) // same as .feed("y", ...)
    ///         .run();
    ///     assert_eq!(result[0], Ok(arr + arr + arr));
    /// });
    /// ```
    ///
    /// See also `tensor_ops::convert_to_tensor`.
    #[inline]
    pub fn placeholder(&'graph self, name: &'static str, shape: &[isize]) -> Tensor<'graph, F> {
        // Placeholders are deduplicated by name: a second request for the same
        // name returns the existing node rather than creating a new one.
        if let Some(existing_id) = self.get_tensor_by_name(name) {
            // Return the existing placeholder
            return self.tensor(existing_id);
        }
        // Create a new placeholder tensor with the given name and shape
        Tensor::builder(self)
            .set_placeholder_name(name)
            .set_knownshape(shape)
            .build(T::basic_source_ops::Placeholder)
    }

    /// Creates a constant tensor from an ndarray.
    ///
    /// This is a convenience method that wraps `tensor_ops::convert_to_tensor`.
    /// Accepts arrays of any dimension and automatically converts them to dynamic dimensions.
    ///
    /// # Example
    /// ```
    /// use scirs2_autograd as ag;
    /// use scirs2_core::ndarray::array;
    ///
    /// ag::run(|ctx| {
    ///     let c = ctx.constant(array![[1., 2.], [3., 4.]]);
    ///     // Use c in computations...
    /// });
    /// ```
    #[inline]
    pub fn constant<D>(&'graph self, arr: scirs2_core::ndarray::Array<F, D>) -> Tensor<'graph, F>
    where
        D: scirs2_core::ndarray::Dimension,
    {
        crate::tensor_ops::convert_to_tensor(arr, self)
    }
}
#[allow(clippy::needless_lifetimes)]
impl<'env, F: Float> Deref for Context<'env, F> {
    type Target = Graph<F>;

    /// Lets a `Context` be used anywhere a `&Graph` is expected, which is how
    /// most graph methods are reached in practice.
    #[inline]
    fn deref(&self) -> &Graph<F> {
        &self.graph
    }
}
/// Abstraction over values that can expose a [Graph] (and, optionally, a
/// [VariableEnvironment] or the owning [Context]).
pub trait AsGraph<F: Float> {
    /// Returns the underlying graph.
    fn as_graph(&self) -> &Graph<F>;

    /// Returns the associated variable environment.
    /// Implementations without one may panic (see the `Graph` impl).
    fn env_ref(&self) -> &VariableEnvironment<F>;

    /// Returns the owning [Context] when available; defaults to `None`.
    fn context_ref(&self) -> Option<&Context<F>> {
        None
    }

    /// Gets or creates the tensor bound to variable `vid`, delegating to the graph.
    fn variable_by_id(&self, vid: VariableID) -> Tensor<F> {
        self.as_graph().variable_by_id(vid)
    }
}
impl<F: Float> AsGraph<F> for Graph<F> {
    /// A graph is trivially its own graph.
    #[inline]
    fn as_graph(&self) -> &Graph<F> {
        self
    }

    /// A bare `Graph` carries no environment, so this unconditionally panics.
    /// Callers needing an environment must go through a `Context`.
    #[inline]
    fn env_ref(&self) -> &VariableEnvironment<F> {
        // This should never be called in practice since we simplified the variable function
        panic!("env_ref called on Graph, but Graph has no associated environment")
    }
}
impl<F: Float> AsGraph<F> for Context<'_, F> {
    /// A context always knows itself.
    #[inline]
    fn context_ref(&self) -> Option<&Context<F>> {
        Some(self)
    }

    /// The environment this context was created with.
    #[inline]
    fn env_ref(&self) -> &VariableEnvironment<F> {
        self.var_env_ref
    }

    /// The graph owned by this context.
    #[inline]
    fn as_graph(&self) -> &Graph<F> {
        &self.graph
    }
}
impl<F: Float> Default for Graph<F> {
fn default() -> Self {
Self {
node_set: RefCell::new(Vec::new()),
variable2node: RefCell::new(HashMap::new()),
}
}
}
#[inline]
pub(crate) fn assert_same_graph<F: Float>(a: &impl AsGraph<F>, b: &impl AsGraph<F>) {
assert_eq!(
a.as_graph() as *const _,
b.as_graph() as *const _,
"Detected tensors belonging to different graphs"
);
}
#[test]
#[should_panic]
#[allow(dead_code)]
fn test_mixed_graph() {
    // Combining tensors that belong to two different graphs must panic.
    VariableEnvironment::<f32>::new().run(|outer| {
        let a = T::zeros(&[1], outer);
        VariableEnvironment::<f32>::new().run(|inner| {
            let b = T::zeros(&[1], inner);
            let _ = a + b; // panics: operands come from different graphs
        });
    });
}