Skip to main content

flodl/graph/
profile.rs

1use std::fmt;
2use std::time::Duration;
3
4use super::trend::{Trend, TrendGroup};
5use super::Graph;
6
/// Execution time of one graph node, measured during a single forward pass.
///
/// Instances are collected in [`Profile::nodes`].
#[derive(Clone, Debug)]
pub struct NodeTiming {
    /// Identifier of the node; printed as the row label when a profile is displayed.
    pub id: String,
    /// Tag attached to the node. An empty string means the node is untagged.
    pub tag: String,
    /// Time spent executing this node.
    pub duration: Duration,
    /// Index of the level this node ran in (matched against `LevelTiming::index`).
    pub level: usize,
}
15
/// Timing of one execution level of the graph.
///
/// A level that holds several nodes could in principle run them in
/// parallel; [`LevelTiming::parallelism`] reports how well that worked.
#[derive(Debug, Clone)]
pub struct LevelTiming {
    /// Index of the level; node timings refer to it through their `level` field.
    pub index: usize,
    /// Elapsed wall-clock time for the whole level.
    pub wall_clock: Duration,
    /// Sum of the individual node durations of this level.
    pub sum_nodes: Duration,
    /// Number of nodes in this level.
    pub num_nodes: usize,
}

impl LevelTiming {
    /// Parallel efficiency of the level: summed node time divided by
    /// wall-clock time.
    ///
    /// A result above 1.0 means the nodes overlapped in time. Levels with
    /// at most one node, or with a zero wall-clock time, report exactly 1.0.
    pub fn parallelism(&self) -> f64 {
        // A ratio is only meaningful with several nodes and a non-zero divisor.
        let measurable = self.num_nodes > 1 && !self.wall_clock.is_zero();
        if measurable {
            self.sum_nodes.as_secs_f64() / self.wall_clock.as_secs_f64()
        } else {
            1.0
        }
    }
}
37
38/// Timing data from a single Forward pass.
39#[derive(Clone, Debug)]
40pub struct Profile {
41    pub total: Duration,
42    pub levels: Vec<LevelTiming>,
43    pub nodes: Vec<NodeTiming>,
44}
45
46impl Profile {
47    /// Duration of a tagged node, or zero if not found.
48    pub fn timing(&self, tag: &str) -> Duration {
49        for n in &self.nodes {
50            if n.tag == tag {
51                return n.duration;
52            }
53        }
54        Duration::ZERO
55    }
56}
57
58impl fmt::Display for Profile {
59    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60        writeln!(
61            f,
62            "Forward: {:?} ({} levels, {} nodes)",
63            self.total,
64            self.levels.len(),
65            self.nodes.len()
66        )?;
67
68        let mut node_idx = 0;
69        for level in &self.levels {
70            write!(f, "\n  Level {}  {:?}", level.index, level.wall_clock)?;
71            if level.num_nodes > 1 {
72                write!(
73                    f,
74                    "  {} nodes  x{:.1}",
75                    level.num_nodes,
76                    level.parallelism()
77                )?;
78            }
79            writeln!(f)?;
80
81            while node_idx < self.nodes.len()
82                && self.nodes[node_idx].level == level.index
83            {
84                let n = &self.nodes[node_idx];
85                let mut label = n.id.clone();
86                if !n.tag.is_empty() {
87                    label += &format!(" {:?}", n.tag);
88                }
89                writeln!(f, "    {:<40} {:?}", label, n.duration)?;
90                node_idx += 1;
91            }
92        }
93
94        Ok(())
95    }
96}
97
98// --- Graph profiling methods ---
99
100impl Graph {
101    /// Turn on per-node and per-level timing for subsequent forward calls.
102    pub fn enable_profiling(&self) {
103        self.profiling.set(true);
104    }
105
106    /// Turn off timing. Subsequent forward calls have zero profiling overhead.
107    pub fn disable_profiling(&self) {
108        self.profiling.set(false);
109        *self.last_profile.borrow_mut() = None;
110    }
111
112    /// Whether profiling is currently enabled.
113    pub fn profiling(&self) -> bool {
114        self.profiling.get()
115    }
116
117    /// Timing data from the most recent forward call, or None.
118    pub fn profile(&self) -> Option<Profile> {
119        self.last_profile.borrow().clone()
120    }
121
122    /// Duration of a tagged node from the most recent forward call.
123    pub fn timing(&self, tag: &str) -> Duration {
124        self.last_profile
125            .borrow()
126            .as_ref()
127            .map(|p| p.timing(tag))
128            .unwrap_or(Duration::ZERO)
129    }
130
131    /// Snapshot tagged node durations into the timing batch buffer.
132    /// If tags is empty, all tagged nodes with timing data are collected.
133    pub fn collect_timings(&self, tags: &[&str]) {
134        let profile = self.last_profile.borrow();
135        let profile = match profile.as_ref() {
136            Some(p) => p,
137            None => return,
138        };
139        let mut buffer = self.timing_buffer.borrow_mut();
140
141        if tags.is_empty() {
142            for n in &profile.nodes {
143                if !n.tag.is_empty() {
144                    buffer
145                        .entry(n.tag.clone())
146                        .or_default()
147                        .push(n.duration.as_secs_f64());
148                }
149            }
150        } else {
151            for &tag in tags {
152                let d = profile.timing(tag);
153                if !d.is_zero() {
154                    buffer
155                        .entry(tag.to_string())
156                        .or_default()
157                        .push(d.as_secs_f64());
158                }
159            }
160        }
161    }
162
163    /// Compute batch mean, append to timing epoch history, clear buffer.
164    /// If tags is empty, flushes all buffered tags.
165    pub fn flush_timings(&self, tags: &[&str]) {
166        let mut buffer = self.timing_buffer.borrow_mut();
167        let mut history = self.timing_history.borrow_mut();
168
169        let keys: Vec<String> = if tags.is_empty() {
170            buffer.keys().cloned().collect()
171        } else {
172            tags.iter().map(|t| t.to_string()).collect()
173        };
174
175        for key in &keys {
176            if let Some(values) = buffer.remove(key)
177                && !values.is_empty()
178            {
179                let mean = values.iter().sum::<f64>() / values.len() as f64;
180                history.entry(key.clone()).or_default().push(mean);
181            }
182        }
183    }
184
185    /// Epoch-level trend over the timing history of a tagged node.
186    /// Values are mean execution times in seconds.
187    pub fn timing_trend(&self, tag: &str) -> Trend {
188        let history = self.timing_history.borrow();
189        Trend::new(history.get(tag).cloned().unwrap_or_default())
190    }
191
192    /// TrendGroup for timing trends of the given tags (expands groups).
193    pub fn timing_trends(&self, tags: &[&str]) -> TrendGroup {
194        let expanded = self.expand_groups(tags);
195        let history = self.timing_history.borrow();
196        let trends = expanded
197            .iter()
198            .map(|tag| Trend::new(history.get(tag).cloned().unwrap_or_default()))
199            .collect();
200        TrendGroup(trends)
201    }
202
203    /// Clear timing epoch history. If tags is empty, clears all.
204    pub fn reset_timing_trend(&self, tags: &[&str]) {
205        let mut history = self.timing_history.borrow_mut();
206        if tags.is_empty() {
207            history.clear();
208        } else {
209            for tag in tags {
210                history.remove(*tag);
211            }
212        }
213    }
214}