1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
use crate::autograd::Variable;
use crate::tensor::{Result, TensorError};
use super::trend::{Trend, TrendGroup};
use super::Graph;
/// Reduction strategy for non-scalar tagged outputs in `collect_with()`.
///
/// The enum carries no data, so it is `Copy`: one value can be reused across
/// several `collect_with()` calls, compared with `==`, and printed with
/// `{:?}`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Reduce {
    /// Arithmetic mean over all elements.
    Mean,
    /// Sum of all elements.
    Sum,
    /// Maximum element value.
    Max,
    /// Minimum element value.
    Min,
    /// L2 norm (Frobenius norm for matrices).
    Norm,
}
impl Reduce {
    /// Collapse the data behind `var` into a single `f64` using this strategy.
    ///
    /// # Errors
    ///
    /// Returns "cannot reduce empty tensor" when the data has zero elements,
    /// and forwards any error raised by the selected reduction or by the
    /// final `item()` extraction.
    fn apply(&self, var: &Variable) -> Result<f64> {
        let tensor = var.data();
        match tensor.numel() {
            // Reject empty data up front, before any reduction is attempted.
            0 => Err(TensorError::new("cannot reduce empty tensor")),
            _ => {
                let reduced = match *self {
                    Reduce::Mean => tensor.mean()?,
                    Reduce::Sum => tensor.sum()?,
                    Reduce::Max => tensor.max()?,
                    Reduce::Min => tensor.min()?,
                    Reduce::Norm => tensor.norm()?,
                };
                reduced.item()
            }
        }
    }
}
impl Graph {
/// Fetch the output of a tagged node from the last forward pass.
///
/// Returns a clone of the stored `Variable`, or `None` when nothing is
/// stored under `tag`.
pub fn tagged(&self, tag: &str) -> Option<Variable> {
    let outputs = self.tagged_outputs.borrow();
    outputs.get(tag).cloned()
}
/// List every tag name defined in this graph.
///
/// The names are cloned out of the graph's tag map, in that map's
/// iteration order.
pub fn tag_names(&self) -> Vec<String> {
    let mut names = Vec::with_capacity(self.tag_names.len());
    for name in self.tag_names.keys() {
        names.push(name.clone());
    }
    names
}
/// Append the current scalar value of each listed tag to the batch buffer.
///
/// Tags with no stored output are skipped. A tag seen for the first time is
/// also appended to the metric order.
///
/// # Errors
///
/// Fails on the first tag whose output cannot be read as a scalar; use
/// `collect_with()` with an explicit reduction for such tags. Values pushed
/// for earlier tags in the same call stay in the buffer.
pub fn collect(&self, tags: &[&str]) -> Result<()> {
    let outputs = self.tagged_outputs.borrow();
    let mut pending = self.batch_buffer.borrow_mut();
    let mut names = self.metric_order.borrow_mut();
    for &tag in tags {
        let Some(var) = outputs.get(tag) else {
            continue;
        };
        let val = var.item().map_err(|_| {
            TensorError::new(&format!(
                "tag {:?} has shape {:?} (not scalar); use collect_with() to specify a reduction",
                tag, var.shape()
            ))
        })?;
        // Register the tag in the display order only once.
        if !(pending.contains_key(tag) || names.iter().any(|n| n == tag)) {
            names.push(tag.to_string());
        }
        pending.entry(tag.to_string()).or_default().push(val);
    }
    Ok(())
}
/// Append one scalar per tag to the batch buffer, reducing where needed.
///
/// Tag group names are expanded to their members first. An output that
/// already reads as a scalar is recorded directly; any other output is
/// reduced with `reduce`. Tags with no stored output are skipped, and a tag
/// seen for the first time is also appended to the metric order.
///
/// # Errors
///
/// Propagates the error from `reduce` when a non-scalar output cannot be
/// reduced. Values pushed for earlier tags in the same call stay in the
/// buffer.
pub fn collect_with(&self, tags: &[&str], reduce: Reduce) -> Result<()> {
    let members = self.expand_groups(tags);
    let outputs = self.tagged_outputs.borrow();
    let mut pending = self.batch_buffer.borrow_mut();
    let mut names = self.metric_order.borrow_mut();
    for tag in &members {
        let Some(var) = outputs.get(tag) else {
            continue;
        };
        // Try the direct scalar read first; fall back to the reduction.
        let val = var.item().or_else(|_| reduce.apply(var))?;
        if !(pending.contains_key(tag.as_str()) || names.iter().any(|n| n == tag)) {
            names.push(tag.clone());
        }
        pending.entry(tag.clone()).or_default().push(val);
    }
    Ok(())
}
/// Push externally computed scalar values into the batch buffer under `tag`.
///
/// Values accumulate until [`flush()`](Self::flush), which averages them
/// into one epoch value. Use [`trend()`](Self::trend) to read the epoch
/// history for training decisions (early stopping, LR scheduling).
///
/// For human-facing output (terminal, live dashboard), use
/// [`Monitor::log()`](crate::monitor::Monitor::log) instead.
pub fn record(&self, tag: &str, values: &[f64]) {
    let mut pending = self.batch_buffer.borrow_mut();
    // The metric order is only consulted when the buffer has no entry yet.
    if !pending.contains_key(tag) {
        let mut names = self.metric_order.borrow_mut();
        if names.iter().all(|n| n != tag) {
            names.push(tag.to_string());
        }
    }
    pending
        .entry(tag.to_string())
        .or_default()
        .extend(values.iter().copied());
}
/// Record one scalar value under `tag`.
///
/// Shorthand for calling [`record`](Self::record) with a one-element slice.
pub fn record_scalar(&self, tag: &str, value: f64) {
    let single = [value];
    self.record(tag, &single);
}
/// Return the most recent epoch value of every tag that has epoch history.
///
/// Entries follow the metric order; tags with no history are left out.
/// Useful for bridging graph observation into
/// [`Monitor::log()`](crate::monitor::Monitor::log). The result is empty
/// if no epochs have been flushed yet.
pub fn latest_metrics(&self) -> Vec<(String, f64)> {
    let history = self.epoch_history.borrow();
    let names = self.metric_order.borrow();
    let mut latest = Vec::new();
    for tag in names.iter() {
        if let Some(&value) = history.get(tag).and_then(|vals| vals.last()) {
            latest.push((tag.clone(), value));
        }
    }
    latest
}
/// Copy out the raw batch buffer for `tag` (all values since the last flush).
///
/// Returns an empty vector when nothing is buffered under that tag.
pub fn collected(&self, tag: &str) -> Vec<f64> {
    match self.batch_buffer.borrow().get(tag) {
        Some(values) => values.clone(),
        None => Vec::new(),
    }
}
/// Compute batch means, append to epoch history, clear batch buffer.
/// Call once per epoch. If tags is empty, flushes all buffered tags.
pub fn flush(&self, tags: &[&str]) {
let mut buffer = self.batch_buffer.borrow_mut();
let mut history = self.epoch_history.borrow_mut();
let keys: Vec<String> = if tags.is_empty() {
buffer.keys().cloned().collect()
} else {
tags.iter().map(|t| t.to_string()).collect()
};
let mut flushed_any = false;
for key in &keys {
if let Some(values) = buffer.remove(key)
&& !values.is_empty()
{
let mean = values.iter().sum::<f64>() / values.len() as f64;
history.entry(key.clone()).or_default().push(mean);
flushed_any = true;
}
}
if flushed_any {
let count = self.flush_count.get();
self.flush_count.set(count + 1);
self.flush_times.borrow_mut().push(
super::instant_secs() - self.training_start.get(),
);
}
}
/// Number of flush calls that produced data.
///
/// Reads the counter that `flush()` increments only when at least one tag
/// contributed a mean to the epoch history.
pub fn flush_count(&self) -> usize {
self.flush_count.get()
}
/// Build the epoch-level trend for `tag`.
///
/// The trend is built from a copy of the tag's epoch history, or from an
/// empty series when the tag has none.
pub fn trend(&self, tag: &str) -> Trend {
    let history = self.epoch_history.borrow();
    let series = match history.get(tag) {
        Some(values) => values.clone(),
        None => Vec::new(),
    };
    Trend::new(series)
}
/// Build trends for several tags at once.
///
/// Tag group names are expanded to their member tags first. Each resulting
/// tag yields one trend, built from an empty series if it has no history.
pub fn trends(&self, tags: &[&str]) -> TrendGroup {
    let members = self.expand_groups(tags);
    let history = self.epoch_history.borrow();
    let trend_for = |tag: &String| Trend::new(history.get(tag).cloned().unwrap_or_default());
    TrendGroup(members.iter().map(trend_for).collect())
}
/// Drop epoch history. An empty `tags` slice clears everything.
///
/// Otherwise tag group names are expanded and only the resulting tags are
/// removed from the history.
pub fn reset_trend(&self, tags: &[&str]) {
    let mut history = self.epoch_history.borrow_mut();
    if tags.is_empty() {
        history.clear();
        return;
    }
    for tag in self.expand_groups(tags) {
        history.remove(&tag);
    }
}
/// Get per-iteration trace outputs from loop nodes.
///
/// Lookup happens in two steps:
/// 1. If `tag` is a known tag and the node it points at has a non-empty
///    trace buffer, a copy of that buffer is returned.
/// 2. Otherwise the graph's nodes are scanned in order and the first
///    non-empty trace buffer on a node whose id contains `"loop"` is
///    returned. This fallback does not take `tag` into account.
///
/// Returns `None` when neither step finds a non-empty trace buffer.
pub fn traces(&self, tag: &str) -> Option<Vec<Variable>> {
    // Step 1: the node the tag is attached to.
    if let Some(&(ni, _)) = self.tag_names.get(tag) {
        if let Some(ref buf) = self.nodes[ni].trace_buf {
            // Inspect through the borrow; copy only if it will be returned.
            let traces = buf.borrow();
            if !traces.is_empty() {
                return Some(traces.to_vec());
            }
        }
    }
    // Step 2: first loop node that holds any traces, regardless of tag.
    for node in &self.nodes {
        if let Some(ref buf) = node.trace_buf {
            let traces = buf.borrow();
            if !traces.is_empty() && node.id.contains("loop") {
                return Some(traces.to_vec());
            }
        }
    }
    None
}
/// Get trace buffer directly from a loop node by node ID.
pub fn traces_by_node(&self, node_id: &str) -> Option<Vec<Variable>> {
if let Some(&ni) = self.node_index.get(node_id)
&& let Some(ref buf) = self.nodes[ni].trace_buf
{
let traces = buf.borrow().clone();
if !traces.is_empty() {
return Some(traces);
}
}
None
}
/// Return only the final entry of the trace list for `tag`.
///
/// Convenience wrapper around [`traces()`](Self::traces). Useful for
/// chaining loops where the last output of one (e.g. scan) feeds into the
/// next (e.g. read).
///
/// Returns `None` whenever [`traces()`](Self::traces) returns `None`.
pub fn last_trace(&self, tag: &str) -> Option<Variable> {
    let mut all = self.traces(tag)?;
    all.pop()
}
/// Estimate the time remaining from the average duration per flush.
///
/// The average is the elapsed time recorded at the latest counted flush
/// divided by the flush count; it is multiplied by the number of epochs
/// still to go (`total_epochs` minus the flush count, floored at zero).
///
/// Returns seconds remaining, or 0.0 if no flushes have occurred yet.
pub fn eta(&self, total_epochs: usize) -> f64 {
    match self.flush_count.get() {
        0 => 0.0,
        done => {
            let times = self.flush_times.borrow();
            // Stored times are already relative to training_start.
            let elapsed = times[done - 1];
            let left = total_epochs.saturating_sub(done);
            (elapsed / done as f64) * left as f64
        }
    }
}
/// Replace tag group names with their member tags.
///
/// Input order is preserved. A name that is not a group is passed through
/// unchanged as a single entry.
pub(crate) fn expand_groups(&self, tags: &[&str]) -> Vec<String> {
    let mut out = Vec::new();
    for tag in tags.iter().copied() {
        match self.tag_groups.get(tag) {
            Some(members) => out.extend(members.iter().cloned()),
            None => out.push(tag.to_string()),
        }
    }
    out
}
}