1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
use crate::autograd::Variable;
use crate::tensor::{Result, TensorError};
use super::trend::{Trend, TrendGroup};
use super::Graph;
/// Reduction strategy for non-scalar tagged outputs in `collect_with()`.
///
/// Each variant names the tensor reduction that turns a multi-element output
/// into the single `f64` stored in the batch buffer. The type is a plain
/// field-less enum, so it is cheap to copy, compare and print.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Reduce {
    /// Arithmetic mean over all elements.
    Mean,
    /// Sum of all elements.
    Sum,
    /// Maximum element value.
    Max,
    /// Minimum element value.
    Min,
    /// L2 norm (Frobenius norm for matrices).
    Norm,
}
impl Reduce {
    /// Collapse the tensor behind `var` into one `f64` using this strategy.
    ///
    /// # Errors
    /// Fails with `"cannot reduce empty tensor"` when the tensor has zero
    /// elements, and forwards any error raised by the underlying reduction or
    /// by the final scalar extraction.
    fn apply(&self, var: &Variable) -> Result<f64> {
        let tensor = var.data();

        // Reject empty tensors before any reduction is attempted.
        if tensor.numel() == 0 {
            return Err(TensorError::new("cannot reduce empty tensor"));
        }

        let reduced = match self {
            Reduce::Norm => tensor.norm()?,
            Reduce::Min => tensor.min()?,
            Reduce::Max => tensor.max()?,
            Reduce::Sum => tensor.sum()?,
            Reduce::Mean => tensor.mean()?,
        };

        reduced.item()
    }
}
impl Graph {
/// Look up the output stored for a tagged node by the last forward pass.
///
/// Returns a clone of the stored [`Variable`], or `None` when nothing is
/// currently stored under `tag`.
pub fn tagged(&self, tag: &str) -> Option<Variable> {
    let outputs = self.tagged_outputs.borrow();
    match outputs.get(tag) {
        Some(var) => Some(var.clone()),
        None => None,
    }
}
/// List every tag name registered on this graph.
///
/// The names are cloned out of the graph's tag table; their order is
/// whatever order that table iterates in.
pub fn tag_names(&self) -> Vec<String> {
    let mut names = Vec::new();
    for name in self.tag_names.keys() {
        names.push(name.clone());
    }
    names
}
/// Snapshot current scalar values of tagged nodes into the batch buffer.
///
/// Tags that have no stored output are skipped. The call is all-or-nothing:
/// every requested value is read first, and the batch buffer is only touched
/// once all of them succeeded, so a failing tag never leaves the earlier
/// tags half-recorded (which would double-count them on a retry).
///
/// # Errors
/// Returns an error if any tag has a non-scalar output — use `collect_with()`
/// with an explicit reduction for non-scalar tags.
pub fn collect(&self, tags: &[&str]) -> Result<()> {
    // Phase 1: read all values without mutating any shared state.
    let mut staged: Vec<(&str, f64)> = Vec::with_capacity(tags.len());
    {
        let tagged = self.tagged_outputs.borrow();
        for &tag in tags {
            let Some(var) = tagged.get(tag) else {
                continue;
            };
            match var.item() {
                Ok(val) => staged.push((tag, val)),
                Err(_) => {
                    return Err(TensorError::new(&format!(
                        "tag {:?} has shape {:?} (not scalar); use collect_with() to specify a reduction",
                        tag, var.shape()
                    )));
                }
            }
        }
    }

    // Phase 2: commit. Nothing below can fail.
    let mut buffer = self.batch_buffer.borrow_mut();
    let mut order = self.metric_order.borrow_mut();
    for (tag, val) in staged {
        // Remember first-seen order so reporting is stable across epochs.
        if !buffer.contains_key(tag) && !order.iter().any(|n| n == tag) {
            order.push(tag.to_string());
        }
        buffer.entry(tag.to_string()).or_default().push(val);
    }
    Ok(())
}
/// Snapshot tagged node values into the batch buffer using a reduction.
///
/// Tag group names are automatically expanded to their members. Each tag's
/// output is turned into a scalar and recorded individually: outputs that are
/// already scalar are used as-is, everything else goes through `reduce`.
/// Tags that have no stored output are skipped.
///
/// The call is all-or-nothing: all values are computed before the batch
/// buffer is modified, so a reduction failure on one tag does not leave the
/// preceding tags already recorded.
///
/// # Errors
/// Propagates any error returned by the reduction (for example when a tagged
/// output is an empty tensor).
pub fn collect_with(&self, tags: &[&str], reduce: Reduce) -> Result<()> {
    let expanded = self.expand_groups(tags);

    // Phase 1: compute every scalar without mutating any shared state.
    let mut staged: Vec<(String, f64)> = Vec::with_capacity(expanded.len());
    {
        let tagged = self.tagged_outputs.borrow();
        for tag in expanded {
            let Some(var) = tagged.get(&tag) else {
                continue;
            };
            // Scalar tags work directly, non-scalar get reduced.
            let val = match var.item() {
                Ok(v) => v,
                Err(_) => reduce.apply(var)?,
            };
            staged.push((tag, val));
        }
    }

    // Phase 2: commit. Nothing below can fail.
    let mut buffer = self.batch_buffer.borrow_mut();
    let mut order = self.metric_order.borrow_mut();
    for (tag, val) in staged {
        if !buffer.contains_key(tag.as_str()) && !order.iter().any(|n| *n == tag) {
            order.push(tag.clone());
        }
        buffer.entry(tag).or_default().push(val);
    }
    Ok(())
}
/// Append externally computed scalar values to the batch buffer under `tag`.
///
/// Values accumulate until the next [`flush()`](Self::flush), which averages
/// them into one epoch entry. Use [`trend()`](Self::trend) to read the epoch
/// history for training decisions (early stopping, LR scheduling).
///
/// The first time a tag is seen it is also appended to the metric order, so
/// metrics keep a stable first-seen ordering.
///
/// For human-facing output (terminal, live dashboard), use
/// [`Monitor::log()`](crate::monitor::Monitor::log) instead.
pub fn record(&self, tag: &str, values: &[f64]) {
    let mut pending = self.batch_buffer.borrow_mut();

    let first_seen = !pending.contains_key(tag);
    if first_seen {
        let mut names = self.metric_order.borrow_mut();
        let already_ordered = names.iter().any(|existing| existing == tag);
        if !already_ordered {
            names.push(tag.to_string());
        }
    }

    let slot = pending.entry(tag.to_string()).or_default();
    slot.extend_from_slice(values);
}
/// Record one scalar under `tag`.
///
/// Shorthand for calling [`record`](Self::record) with a one-element slice.
pub fn record_scalar(&self, tag: &str, value: f64) {
    let single = [value];
    self.record(tag, &single);
}
/// Return the latest epoch value for every tag in the epoch history.
///
/// **Tree-aware**: automatically collects from labeled child subgraphs
/// with dotted prefixes (e.g. a child labeled `"subscan"` with tag `"ce"`
/// appears as `"subscan.ce"`). Parent metrics come first, then children
/// in registration order.
///
/// Useful for bridging graph observation into
/// [`Monitor::log()`](crate::monitor::Monitor::log). Returns an empty
/// vec if no epochs have been flushed yet.
///
/// Use [`latest_metrics_local()`](Self::latest_metrics_local) if you
/// only want this graph's own metrics.
pub fn latest_metrics(&self) -> Vec<(String, f64)> {
let mut metrics = self.latest_metrics_local();
// Collect from labeled children with dotted prefixes
for (label, &ni) in &self.children {
if let Some(ref module) = self.nodes[ni].module
&& let Some(child) = module.as_graph()
{
for (tag, val) in child.latest_metrics() {
metrics.push((format!("{}.{}", label, tag), val));
}
}
}
metrics
}
/// Return the most recent epoch value per tag for this graph alone.
///
/// Tags are reported in metric (first-seen) order; a tag with no epoch
/// history yet is left out. Child subgraphs are not consulted — see
/// [`latest_metrics()`](Self::latest_metrics) for the tree-recursive version.
/// Handy when children report on a different cadence.
pub fn latest_metrics_local(&self) -> Vec<(String, f64)> {
    let history = self.epoch_history.borrow();
    let order = self.metric_order.borrow();

    let mut latest = Vec::new();
    for tag in order.iter() {
        let newest = history.get(tag).and_then(|vals| vals.last());
        if let Some(&value) = newest {
            latest.push((tag.clone(), value));
        }
    }
    latest
}
/// Return a copy of the raw batch buffer for `tag`.
///
/// These are all values recorded since the last flush; an unknown tag
/// yields an empty vec.
pub fn collected(&self, tag: &str) -> Vec<f64> {
    let buffer = self.batch_buffer.borrow();
    match buffer.get(tag) {
        Some(values) => values.clone(),
        None => Vec::new(),
    }
}
/// Compute batch means, append to epoch history, clear batch buffer.
/// Call once per epoch. If tags is empty, flushes all buffered tags.
///
/// **Tree-aware**: automatically recurses into labeled child subgraphs,
/// so a single `parent.flush(&[])` flushes the entire tree. Child buffers
/// that are already empty (e.g. flushed separately) are skipped safely.
///
/// If you need **different flush cadences** per subgraph (e.g. flushing a
/// child every 10 parent epochs), use [`flush_local()`](Self::flush_local)
/// on both the parent and the child to manage them independently:
///
/// ```ignore
/// // Every epoch: flush parent only
/// parent.flush_local(&[]);
/// // Every 10 epochs: flush the child
/// if epoch % 10 == 0 {
/// parent.child_graph("slow_child").unwrap().flush_local(&[]);
/// }
/// ```
pub fn flush(&self, tags: &[&str]) {
self.flush_local(tags);
// Recurse into labeled children
for &ni in self.children.values() {
if let Some(ref module) = self.nodes[ni].module
&& let Some(child) = module.as_graph()
{
child.flush(&[]);
}
}
}
/// Flush only this graph's own batch buffer, without recursing into children.
///
/// Use this when you need independent flush cadences per subgraph.
/// See [`flush()`](Self::flush) for the tree-recursive version.
pub fn flush_local(&self, tags: &[&str]) {
let mut buffer = self.batch_buffer.borrow_mut();
let mut history = self.epoch_history.borrow_mut();
let keys: Vec<String> = if tags.is_empty() {
buffer.keys().cloned().collect()
} else {
tags.iter().map(|t| t.to_string()).collect()
};
let mut flushed_any = false;
for key in &keys {
if let Some(values) = buffer.remove(key)
&& !values.is_empty()
{
let mean = values.iter().sum::<f64>() / values.len() as f64;
history.entry(key.clone()).or_default().push(mean);
flushed_any = true;
}
}
if flushed_any {
let count = self.flush_count.get();
self.flush_count.set(count + 1);
self.flush_times.borrow_mut().push(
super::instant_secs() - self.training_start.get(),
);
}
}
/// Number of flush calls that produced data.
///
/// The counter is advanced by [`flush_local()`](Self::flush_local) only when
/// at least one tag had buffered values to average, so flushing an empty
/// buffer does not count.
pub fn flush_count(&self) -> usize {
self.flush_count.get()
}
/// Build the epoch-level [`Trend`] for one tag.
///
/// The trend is built from a copy of the tag's epoch history; a tag without
/// history yields a trend over an empty series.
pub fn trend(&self, tag: &str) -> Trend {
    let history = self.epoch_history.borrow();
    let series = match history.get(tag) {
        Some(values) => values.clone(),
        None => Vec::new(),
    };
    Trend::new(series)
}
/// Build trends for several tags at once.
///
/// Tag group names are expanded to their member tags first; the resulting
/// [`TrendGroup`] holds one [`Trend`] per expanded tag, in that order. Tags
/// without epoch history contribute a trend over an empty series.
pub fn trends(&self, tags: &[&str]) -> TrendGroup {
    let names = self.expand_groups(tags);
    let history = self.epoch_history.borrow();

    let trend_for = |name: &String| {
        let series = history.get(name).cloned().unwrap_or_default();
        Trend::new(series)
    };

    TrendGroup(names.iter().map(trend_for).collect())
}
/// Drop epoch history. An empty `tags` slice clears every tag.
///
/// Otherwise tag group names are expanded and the history of each resulting
/// tag is removed; tags without history are ignored.
pub fn reset_trend(&self, tags: &[&str]) {
    let mut history = self.epoch_history.borrow_mut();

    if tags.is_empty() {
        history.clear();
        return;
    }

    for name in self.expand_groups(tags) {
        history.remove(&name);
    }
}
/// Get per-iteration trace outputs from loop nodes.
/// Returns the trace buffer for the loop node associated with the given tag.
/// The tag should be set on a node after the loop (the loop output flows to it).
/// Returns None if no loop node with a trace buffer is found.
pub fn traces(&self, tag: &str) -> Option<Vec<Variable>> {
// Look for loop nodes by checking trace_buf
// If a tag is given, find the node it references and walk back to find the loop
if let Some(&(ni, _)) = self.tag_names.get(tag) {
// Check if this node has a trace_buf
if let Some(ref buf) = self.nodes[ni].trace_buf {
let traces = buf.borrow().clone();
if !traces.is_empty() {
return Some(traces);
}
}
}
// Search all nodes for a matching tag in the node id
for node in &self.nodes {
if let Some(ref buf) = node.trace_buf {
let traces = buf.borrow().clone();
if !traces.is_empty() && node.id.contains("loop") {
// If no tag match, return first loop with traces
return Some(traces);
}
}
}
None
}
/// Get trace buffer directly from a loop node by node ID.
pub fn traces_by_node(&self, node_id: &str) -> Option<Vec<Variable>> {
if let Some(&ni) = self.node_index.get(node_id)
&& let Some(ref buf) = self.nodes[ni].trace_buf
{
let traces = buf.borrow().clone();
if !traces.is_empty() {
return Some(traces);
}
}
None
}
/// Replace trace buffer contents for the given tag.
///
/// Used by El Che gathering to set catted traces from all devices/batches.
pub(crate) fn set_traces(&self, tag: &str, traces: Vec<Variable>) {
if let Some(&(ni, _)) = self.tag_names.get(tag) {
if let Some(ref buf) = self.nodes[ni].trace_buf {
*buf.borrow_mut() = traces;
}
}
}
/// Return only the final entry of the traces found for `tag`.
///
/// Convenience wrapper around [`traces()`](Self::traces). Useful for chaining
/// loops where the last output of one (e.g. scan) feeds into the next
/// (e.g. read).
///
/// Returns `None` whenever [`traces()`](Self::traces) returns `None`.
pub fn last_trace(&self, tag: &str) -> Option<Variable> {
    let mut all = self.traces(tag)?;
    all.pop()
}
/// Estimated time remaining based on average flush duration.
///
/// The average is the elapsed time recorded at the most recent counted flush
/// divided by the number of counted flushes; it is multiplied by the number
/// of epochs still missing from `total_epochs` (never negative).
///
/// Returns seconds remaining. Returns 0.0 if no flushes have occurred yet, or
/// if no flush time is recorded for the current flush count.
pub fn eta(&self, total_epochs: usize) -> f64 {
    let count = self.flush_count.get();
    if count == 0 {
        return 0.0;
    }

    let times = self.flush_times.borrow();
    // Checked lookup: the counter and the time log live in separate cells, so
    // do not panic on an out-of-range index if they ever disagree.
    let Some(&elapsed) = times.get(count - 1) else {
        return 0.0;
    };

    // `elapsed` is already relative to training start.
    let per_flush = elapsed / count as f64;
    let remaining = total_epochs.saturating_sub(count);
    per_flush * remaining as f64
}
/// Replace every tag group name in `tags` with the group's member tags.
///
/// Names that are not a known group are passed through unchanged. Input
/// order is preserved, and nothing is de-duplicated.
pub(crate) fn expand_groups(&self, tags: &[&str]) -> Vec<String> {
    let mut out = Vec::new();
    for name in tags.iter().copied() {
        match self.tag_groups.get(name) {
            Some(members) => out.extend(members.iter().cloned()),
            None => out.push(name.to_string()),
        }
    }
    out
}
}