// cjc_runtime/idx.rs

//! Phase 2b — typed-ID newtype(s) for ML metadata in `cjc-runtime`.
//!
//! Currently provides only [`ParamIdx`], used by `crate::ml::AdamState`
//! and `crate::ml::SgdState` to type optimizer-state buffers.
//!
//! ## Why a separate file from `cjc-ad/src/idx.rs`?
//!
//! Phase 2a (PR #8) introduced `NodeIdx`, `ParamIdx`, `LayerIdx` in
//! `cjc-ad/src/idx.rs`. Architecturally those newtypes belong in
//! `cjc-runtime` (the foundation crate, upstream of `cjc-ad`), but
//! moving them at this stage would conflict with Phase 2a's open PR
//! diff. Instead, this Phase 2b PR defines `ParamIdx` *locally* in
//! `cjc-runtime` for the optimizer-state migration; a future
//! "Phase 2 cleanup" PR will consolidate the typed-ID newtypes in
//! one canonical home.
//!
//! Until then, `cjc_runtime::idx::ParamIdx` and
//! `cjc_ad::idx::ParamIdx` are distinct types with identical shape.
//! They are not used in the same code today, so the duplication is
//! harmless.
//!
//! ## Repr
//!
//! `repr(transparent)` over `u32`: ABI-identical to `u32`, zero
//! runtime overhead, FFI-stable bit pattern.

/// Index into a per-parameter optimizer-state buffer.
///
/// The optimizer enumerates trainable parameters in registration order.
/// `ParamIdx(0)` always refers to the first registered parameter,
/// `ParamIdx(1)` to the second, etc. This is *not* a graph-node index
/// (those are `cjc_ad::NodeIdx`); the two integer spaces are unrelated
/// even though both number from zero.
#[repr(transparent)]
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct ParamIdx(pub u32);

impl ParamIdx {
    /// Builds a `ParamIdx` from a `usize` position (e.g. an `enumerate()` counter).
    ///
    /// # Panics
    ///
    /// Panics if `i > u32::MAX`. The check is active in every build profile:
    /// a bare `i as u32` would silently wrap in release builds and make the
    /// index alias an unrelated parameter's optimizer state.
    #[inline]
    #[track_caller]
    pub fn from_usize(i: usize) -> Self {
        assert!(
            i <= u32::MAX as usize,
            "ParamIdx overflow: {} does not fit in u32",
            i
        );
        // Lossless: the assert above guarantees `i` fits in a `u32`.
        Self(i as u32)
    }

    /// Returns the index widened to `usize`, ready for slice indexing.
    #[inline]
    pub fn index(self) -> usize {
        self.0 as usize
    }
}

impl From<ParamIdx> for usize {
    /// Widens the index to `usize`; equivalent to [`ParamIdx::index`].
    #[inline]
    fn from(idx: ParamIdx) -> Self {
        idx.index()
    }
}

impl std::fmt::Display for ParamIdx {
    /// Formats the raw `u32` value (e.g. `ParamIdx(42)` displays as `42`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Delegate to `u32`'s impl instead of `write!(f, "{}", self.0)` so the
        // caller's width / fill / alignment / zero-pad flags are honoured.
        std::fmt::Display::fmt(&self.0, f)
    }
}

#[cfg(test)]
mod tests {
    use super::ParamIdx;

    /// `from_usize` followed by either widening path returns the input.
    #[test]
    fn from_usize_roundtrip() {
        let samples: [usize; 5] = [0, 1, 7, 64, 1024];
        samples.iter().copied().for_each(|raw| {
            let idx = ParamIdx::from_usize(raw);
            assert_eq!(raw, idx.index());
            assert_eq!(raw, usize::from(idx));
        });
    }

    /// The derived `Ord` follows the wrapped integer.
    #[test]
    fn ord_matches_inner() {
        let ascending_pairs = [(0u32, 1u32), (7, 8)];
        for &(lo, hi) in ascending_pairs.iter() {
            assert!(ParamIdx(lo) < ParamIdx(hi));
        }
    }

    /// `repr(transparent)` keeps the newtype exactly `u32`-sized.
    #[test]
    fn repr_transparent_means_size_eq_u32() {
        use std::mem::size_of;
        assert_eq!(size_of::<u32>(), size_of::<ParamIdx>());
    }

    /// `Display` prints the bare number.
    #[test]
    fn display_format() {
        let rendered = ParamIdx(42).to_string();
        assert_eq!(rendered, "42");
    }
}