Skip to main content

cjc_runtime/
profile.rs

1//! Deterministic profile counters for Tier 2 of the Chess RL v2.3 upgrade.
2//!
3//! This module provides a minimal, write-only profiling sink that the
4//! CJC-Lang program can use to time named zones inside a hot loop. It is
5//! the smallest possible surface that makes the v2.2 bottleneck measurable
6//! without perturbing program state.
7//!
8//! # Builtins
9//!
10//! - `profile_zone_start(name: String) -> i64`
11//! - `profile_zone_stop(handle: i64) -> f64`
12//! - `profile_dump(path: String) -> i64`
13//!
14//! All three dispatch arms live in [`crate::builtins`]; this module owns
15//! only the thread-local state and the pure helper functions that operate
16//! on it.
17//!
18//! # Determinism story
19//!
20//! The counter state lives in a thread-local `RefCell<ProfileState>`. The
21//! program can **observe** the identity of a zone handle (to pair start/
22//! stop calls), but the *integer value* of the handle must not feed into
23//! program logic, RNG draws, tensor math, or control flow. The Chess RL
24//! v2.3 parity test asserts that an instrumented rollout produces a
25//! weight hash identical to an uninstrumented rollout.
26//!
27//! No floating-point math is done on the counters until `profile_dump`
28//! renders the CSV, by which point the nanosecond counters are committed.
29//! All iteration over zones uses [`BTreeMap`] ordering so the CSV layout
30//! is reproducible across runs.
31//!
32//! No RNG is touched. No tensor math is touched. No cross-thread state
33//! is accessed. No external crates are pulled in — only
34//! `std::time::Instant`, `std::collections::BTreeMap`, and
35//! `std::cell::RefCell`.
36
37use std::cell::RefCell;
38use std::collections::BTreeMap;
39use std::time::Instant;
40
/// Per-zone aggregated statistics.
///
/// Every counter is a plain integer. Floating point only appears in
/// [`ZoneStats::stddev_ns`], which exists for reporting.
#[derive(Clone, Debug, Default)]
pub struct ZoneStats {
    /// Number of times the zone has been stopped.
    pub count: u64,
    /// Total nanoseconds spent in the zone (saturating).
    pub total_ns: u128,
    /// Minimum single-call nanoseconds. Only meaningful when `count > 0`.
    pub min_ns: u128,
    /// Maximum single-call nanoseconds. Only meaningful when `count > 0`.
    pub max_ns: u128,
    /// Sum of `ns^2` across calls (saturating), used to derive the stddev.
    pub sum_sq_ns: u128,
}

impl ZoneStats {
    /// Fold one elapsed-time sample, in nanoseconds, into the aggregate.
    fn update(&mut self, ns: u128) {
        if self.count == 0 {
            // The first sample seeds both extremes: the zeros produced by
            // `Default` are placeholders, not observations.
            self.min_ns = ns;
            self.max_ns = ns;
        } else {
            self.min_ns = self.min_ns.min(ns);
            self.max_ns = self.max_ns.max(ns);
        }
        self.count += 1;
        self.total_ns = self.total_ns.saturating_add(ns);
        // `ns` is already u128; the square saturates instead of wrapping.
        self.sum_sq_ns = self.sum_sq_ns.saturating_add(ns.saturating_mul(ns));
    }

    /// Integer mean nanoseconds (`total_ns / count`, truncated). Returns 0
    /// when `count` is 0.
    pub fn mean_ns(&self) -> u128 {
        if self.count == 0 {
            0
        } else {
            self.total_ns / u128::from(self.count)
        }
    }

    /// Population standard deviation in nanoseconds.
    ///
    /// The variance is `(n * sum_sq - total^2) / n^2`. Its numerator is
    /// evaluated exactly in `u128`, so neither the truncated integer mean
    /// nor a truncated `sum_sq / n` can bias the result. f64 is used only
    /// for the final division and square root. If the exact numerator would
    /// overflow `u128`, the same formula is evaluated in f64 instead.
    ///
    /// Returns `0.0` when `count` is 0 or when the variance is not positive.
    pub fn stddev_ns(&self) -> f64 {
        if self.count == 0 {
            return 0.0;
        }
        let n = u128::from(self.count);
        let n_f = n as f64;

        let exact_numerator = n
            .checked_mul(self.sum_sq_ns)
            .zip(self.total_ns.checked_mul(self.total_ns))
            // Saturated counters could make the difference negative; clamp
            // at zero rather than underflow.
            .map(|(n_sum_sq, total_sq)| n_sum_sq.saturating_sub(total_sq));

        let variance = match exact_numerator {
            Some(numerator) => numerator as f64 / (n_f * n_f),
            None => {
                let mean = self.total_ns as f64 / n_f;
                self.sum_sq_ns as f64 / n_f - mean * mean
            }
        };

        if variance > 0.0 {
            variance.sqrt()
        } else {
            0.0
        }
    }
}
102
103/// Internal profiler state — one instance per thread.
104pub struct ProfileState {
105    /// Stopped zones, keyed by name for deterministic iteration.
106    pub zones: BTreeMap<String, ZoneStats>,
107    /// Currently-open zones, keyed by handle.
108    pub active: BTreeMap<i64, (String, Instant)>,
109    /// Monotonically increasing handle counter.
110    pub next_handle: i64,
111}
112
113impl ProfileState {
114    /// Fresh empty state.
115    pub fn new() -> Self {
116        Self {
117            zones: BTreeMap::new(),
118            active: BTreeMap::new(),
119            next_handle: 0,
120        }
121    }
122
123    /// Clear all state, as if the profiler had just been created.
124    pub fn reset(&mut self) {
125        self.zones.clear();
126        self.active.clear();
127        self.next_handle = 0;
128    }
129}
130
131impl Default for ProfileState {
132    fn default() -> Self {
133        Self::new()
134    }
135}
136
137thread_local! {
138    /// Thread-local profiler state. Each thread keeps its own counters;
139    /// no cross-thread coordination is done or needed.
140    pub(crate) static PROFILE: RefCell<ProfileState> = RefCell::new(ProfileState::new());
141}
142
143/// Start a zone and return an opaque handle. The handle is monotonically
144/// increasing within the current thread.
145pub fn zone_start(name: &str) -> i64 {
146    PROFILE.with(|cell| {
147        let mut state = cell.borrow_mut();
148        let handle = state.next_handle;
149        state.next_handle = state.next_handle.wrapping_add(1);
150        state
151            .active
152            .insert(handle, (name.to_string(), Instant::now()));
153        handle
154    })
155}
156
157/// Stop a zone previously started by [`zone_start`]. Updates the aggregated
158/// [`ZoneStats`] and returns the elapsed seconds as f64. Returns `-1.0` for
159/// an unknown handle (never panics).
160pub fn zone_stop(handle: i64) -> f64 {
161    PROFILE.with(|cell| {
162        let mut state = cell.borrow_mut();
163        let Some((name, start)) = state.active.remove(&handle) else {
164            return -1.0;
165        };
166        let elapsed = start.elapsed();
167        let ns = elapsed.as_nanos();
168        let entry = state.zones.entry(name).or_default();
169        entry.update(ns);
170        // Report elapsed seconds for convenience; the v2.3 parity test
171        // asserts that ignoring this value yields a bit-identical weight
172        // hash, which is the determinism contract.
173        ns as f64 / 1.0e9
174    })
175}
176
177/// Serialize the aggregated zone statistics to CSV, sorted by `total_ns`
178/// descending so the hot zones appear at the top. Returns the number of
179/// data rows written (not including the header). Resets the profiler
180/// state after writing.
181///
182/// CSV schema:
183/// ```text
184/// zone_name,count,total_ns,min_ns,max_ns,mean_ns,stddev_ns
185/// ```
186pub fn dump_to_path(path: &str) -> Result<i64, String> {
187    let csv = PROFILE.with(|cell| {
188        let mut state = cell.borrow_mut();
189
190        // Collect rows from the BTreeMap (ordered by name).
191        let mut rows: Vec<(String, ZoneStats)> = state
192            .zones
193            .iter()
194            .map(|(k, v)| (k.clone(), v.clone()))
195            .collect();
196
197        // Sort by total_ns descending. Ties fall back to the BTreeMap name
198        // ordering (which is what `stable_sort_by` inherits from `iter`).
199        rows.sort_by(|a, b| b.1.total_ns.cmp(&a.1.total_ns));
200
201        let mut out = String::new();
202        out.push_str("zone_name,count,total_ns,min_ns,max_ns,mean_ns,stddev_ns\n");
203        for (name, stats) in &rows {
204            // Round stddev to the nearest integer ns to keep CSV integer-clean.
205            let stddev_ns = stats.stddev_ns().round() as u128;
206            out.push_str(&format!(
207                "{},{},{},{},{},{},{}\n",
208                name,
209                stats.count,
210                stats.total_ns,
211                stats.min_ns,
212                stats.max_ns,
213                stats.mean_ns(),
214                stddev_ns,
215            ));
216        }
217
218        // Clear the state so subsequent runs start fresh.
219        let row_count = rows.len() as i64;
220        state.reset();
221        (out, row_count)
222    });
223
224    std::fs::write(path, &csv.0).map_err(|e| format!("profile_dump error: {e}"))?;
225    Ok(csv.1)
226}
227
228/// Test-only helper: snapshot the current zone stats without clearing
229/// state. Used by unit tests that need to inspect counters mid-run.
230#[doc(hidden)]
231pub fn snapshot_zones() -> BTreeMap<String, ZoneStats> {
232    PROFILE.with(|cell| cell.borrow().zones.clone())
233}
234
235/// Test-only helper: clear all state without writing a file.
236#[doc(hidden)]
237pub fn reset_for_test() {
238    PROFILE.with(|cell| cell.borrow_mut().reset());
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Header row that every dumped CSV starts with.
    const HEADER: &str = "zone_name,count,total_ns,min_ns,max_ns,mean_ns,stddev_ns\n";

    /// Path of a scratch CSV inside the OS temp directory.
    fn scratch_csv(file_name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(file_name)
    }

    /// Dump the profiler to `target` and return the reported row count.
    fn dump_to(target: &std::path::Path) -> i64 {
        dump_to_path(target.to_str().unwrap()).unwrap()
    }

    /// Start and immediately stop the zone `name`, `times` times over.
    fn pulse(name: &str, times: usize) {
        for _ in 0..times {
            zone_stop(zone_start(name));
        }
    }

    #[test]
    fn start_stop_round_trip() {
        reset_for_test();
        let handle = zone_start("zone_a");
        assert_eq!(handle, 0);
        let secs = zone_stop(handle);
        assert!(secs >= 0.0);
        let zones = snapshot_zones();
        assert_eq!(zones.len(), 1);
        assert_eq!(zones["zone_a"].count, 1);
    }

    #[test]
    fn handle_is_monotonic() {
        reset_for_test();
        let handles: Vec<i64> = ["a", "b", "c"].iter().map(|name| zone_start(name)).collect();
        assert_eq!(handles, vec![0, 1, 2]);
        for handle in handles {
            zone_stop(handle);
        }
    }

    #[test]
    fn nested_zones_accumulate_independently() {
        reset_for_test();
        let outer = zone_start("outer");
        let inner = zone_start("inner");
        zone_stop(inner);
        zone_stop(outer);
        let zones = snapshot_zones();
        assert_eq!(zones.len(), 2);
        for name in ["outer", "inner"] {
            assert_eq!(zones[name].count, 1);
        }
    }

    #[test]
    fn repeated_zone_accumulates_count() {
        reset_for_test();
        pulse("hot", 10);
        assert_eq!(snapshot_zones()["hot"].count, 10);
    }

    #[test]
    fn unknown_handle_returns_negative_one() {
        reset_for_test();
        let result = zone_stop(9999);
        assert!(result < 0.0);
    }

    #[test]
    fn dump_resets_state() {
        reset_for_test();
        pulse("zone_x", 1);
        let target = scratch_csv("cjc_profile_dump_resets_state.csv");
        assert_eq!(dump_to(&target), 1);
        assert!(snapshot_zones().is_empty());
        let text = std::fs::read_to_string(&target).unwrap();
        assert!(text.starts_with(HEADER));
        let _ = std::fs::remove_file(&target);
    }

    #[test]
    fn dump_csv_format_integer_columns() {
        reset_for_test();
        pulse("z", 3);
        let target = scratch_csv("cjc_profile_dump_csv_format.csv");
        assert_eq!(dump_to(&target), 1);
        let text = std::fs::read_to_string(&target).unwrap();
        let lines: Vec<&str> = text.lines().collect();
        assert_eq!(lines.len(), 2);
        let fields: Vec<&str> = lines[1].split(',').collect();
        assert_eq!(fields.len(), 7);
        assert_eq!(fields[0], "z");
        // Every column after the name must parse as an integer.
        for f in fields.iter().skip(1) {
            assert!(
                f.parse::<u128>().is_ok(),
                "column {f} is not an integer in v2.3 CSV"
            );
        }
        let _ = std::fs::remove_file(&target);
    }

    #[test]
    fn dump_sort_order_hot_first() {
        reset_for_test();
        // "cold" gets one real (tiny) call.
        pulse("cold", 1);
        // "hot" is injected straight into the state with a synthetic 10 s
        // sample, which keeps the test under 1 ms.
        PROFILE.with(|cell| {
            cell.borrow_mut()
                .zones
                .entry("hot".to_string())
                .or_default()
                .update(10_000_000_000);
        });
        let target = scratch_csv("cjc_profile_dump_sort_order.csv");
        dump_to(&target);
        let text = std::fs::read_to_string(&target).unwrap();
        let lines: Vec<&str> = text.lines().collect();
        assert_eq!(lines.len(), 3);
        assert!(lines[1].starts_with("hot,"), "hot zone should be first");
        assert!(lines[2].starts_with("cold,"), "cold zone should be second");
        let _ = std::fs::remove_file(&target);
    }

    #[test]
    fn empty_dump_writes_header_only() {
        reset_for_test();
        let target = scratch_csv("cjc_profile_empty_dump.csv");
        assert_eq!(dump_to(&target), 0);
        let text = std::fs::read_to_string(&target).unwrap();
        assert_eq!(text, HEADER);
        let _ = std::fs::remove_file(&target);
    }

    #[test]
    fn zone_stats_update_tracks_min_max() {
        let mut stats = ZoneStats::default();
        for ns in [100u128, 50, 200] {
            stats.update(ns);
        }
        assert_eq!(stats.count, 3);
        assert_eq!(stats.min_ns, 50);
        assert_eq!(stats.max_ns, 200);
        assert_eq!(stats.total_ns, 350);
    }

    #[test]
    fn mean_and_stddev_sane() {
        let mut stats = ZoneStats::default();
        for ns in [100u128, 200, 300] {
            stats.update(ns);
        }
        assert_eq!(stats.mean_ns(), 200);
        // The variance of {100, 200, 300} is about 6666, so the stddev is
        // roughly 81.6; the window below is deliberately generous.
        let sd = stats.stddev_ns();
        assert!(sd > 70.0 && sd < 100.0, "unexpected stddev: {sd}");
    }
}