Skip to main content

datacortex_core/mixer/
logistic.rs

//! Logistic transforms — squash and stretch for probability <-> log-odds.
//!
//! stretch(p) maps a 12-bit probability in [1, 4095] to log-odds (a scaled integer).
//! squash(d)  maps log-odds back to a 12-bit probability in [1, 4095].
//!
//! Both use lookup tables computed at compile time.
//!
//! Formula pair:
//!   squash(d)  = 2048 + d * 2047 / (K + |d|)
//!   stretch(p) = c * K / (2047 - |c|)  where c = p - 2048
//!
//! In exact arithmetic these are inverses of each other. The tables use truncating
//! integer division, so a round trip through both is only approximate.
//! The steepness constant K (defined below) trades resolution near p = 0.5 against
//! the probability range the squash table can reach.
14
/// Steepness parameter shared by `squash` and `stretch`.
///
/// The slope of squash at d = 0 is 2047 / K probability units per log-odds unit.
/// A *larger* K therefore flattens the sigmoid: each log-odds step near p = 0.5
/// moves the probability less, which gives the mixer finer resolution there. The
/// cost is that the probability range reachable inside the ±8192 squash table
/// shrinks.
///
/// With K = 128:
/// - the squash table spans [33, 4063];
/// - stretch saturates at -8191 for p <= 32 and at +8191 for p >= 4064.
const K: i32 = 128;
19
20/// Squash table. 16384 entries covering d in [-8192, 8191].
21/// Formula: p = 2048 + d * 2047 / (K + |d|)
22const SQUASH_SIZE: usize = 16384;
23const SQUASH_OFFSET: i32 = 8192;
24
25static SQUASH_TABLE: [u16; SQUASH_SIZE] = {
26    let mut table = [0u16; SQUASH_SIZE];
27    let mut i = 0usize;
28    while i < SQUASH_SIZE {
29        let d = i as i32 - SQUASH_OFFSET;
30        let abs_d = if d < 0 { -d } else { d };
31        let p = 2048 + (d * 2047) / (K + abs_d);
32        table[i] = if p < 1 {
33            1
34        } else if p > 4095 {
35            4095
36        } else {
37            p as u16
38        };
39        i += 1;
40    }
41    table
42};
43
44/// Stretch table: 4097 entries for p in [0, 4096].
45/// Formula: d = c * K / (2047 - |c|) where c = p - 2048.
46/// This is the analytical inverse of squash.
47static STRETCH_TABLE: [i16; 4097] = {
48    let mut table = [0i16; 4097];
49    let mut p = 0usize;
50    while p <= 4096 {
51        let c = p as i32 - 2048;
52        let abs_c = if c < 0 { -c } else { c };
53
54        let d = if abs_c >= 2047 {
55            if c >= 0 { 8191i32 } else { -8191i32 }
56        } else {
57            (c * K) / (2047 - abs_c)
58        };
59
60        table[p] = if d > 8191 {
61            8191
62        } else if d < -8191 {
63            -8191
64        } else {
65            d as i16
66        };
67        p += 1;
68    }
69    table
70};
71
72/// Convert 12-bit probability to log-odds.
73/// Input: p in [1, 4095].
74/// Output: log-odds as scaled integer.
75#[inline(always)]
76pub fn stretch(p: u32) -> i32 {
77    STRETCH_TABLE[p.min(4096) as usize] as i32
78}
79
80/// Convert log-odds to 12-bit probability.
81/// Input: d as scaled integer (any range, clamped internally).
82/// Output: probability in [1, 4095].
83#[inline(always)]
84pub fn squash(d: i32) -> u32 {
85    let idx = (d + SQUASH_OFFSET).clamp(0, (SQUASH_SIZE - 1) as i32) as usize;
86    SQUASH_TABLE[idx] as u32
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn squash_at_zero_is_half() {
        let p = squash(0);
        assert_eq!(p, 2048, "squash(0) should be exactly 2048, got {p}");
    }

    #[test]
    fn stretch_at_half_is_zero() {
        let d = stretch(2048);
        assert_eq!(d, 0, "stretch(2048) should be 0, got {d}");
    }

    #[test]
    fn squash_output_in_range() {
        for d in -10000..=10000 {
            let p = squash(d);
            assert!(
                (1..=4095).contains(&p),
                "squash({d}) = {p}, out of [1, 4095]"
            );
        }
    }

    #[test]
    fn stretch_output_bounded() {
        for p in 1..=4095u32 {
            let d = stretch(p);
            assert!(
                (-8191..=8191).contains(&d),
                "stretch({p}) = {d}, out of bounds"
            );
        }
    }

    #[test]
    fn squash_is_monotonic() {
        let mut prev = squash(-10000);
        for d in -9999..=10000 {
            let p = squash(d);
            assert!(p >= prev, "squash not monotonic at d={d}: {prev} > {p}");
            prev = p;
        }
    }

    #[test]
    fn stretch_is_monotonic() {
        let mut prev = stretch(1);
        for p in 2..=4095u32 {
            let d = stretch(p);
            assert!(d >= prev, "stretch not monotonic at p={p}: {prev} > {d}");
            prev = d;
        }
    }

    #[test]
    fn roundtrip_squash_stretch() {
        // squash(stretch(p)) should be approximately p.
        // Both tables truncate, and stretch maps runs of neighbouring probabilities
        // onto the same log-odds value. The runs are widest around p = 2048, where
        // squash is steepest (slope 2047 / K).
        // This is acceptable — these transforms are used in logistic mixing (Phase 3),
        // where 1-2% error in probability space is fine.
        let max_diff = (100..=3996u32)
            .map(|p| (squash(stretch(p)) as i32 - p as i32).unsigned_abs())
            .max()
            .unwrap_or(0);
        // Error comes from integer division in both tables.
        assert!(
            max_diff <= 35,
            "max roundtrip error {max_diff} in range [100, 3996]"
        );
    }

    #[test]
    fn roundtrip_stretch_squash() {
        // stretch(squash(d)) should be approximately d.
        // Near extremes the sigmoid flattens, amplifying quantization error.
        for d in -1500..=1500 {
            let p = squash(d);
            let d2 = stretch(p);
            let diff = (d2 - d).unsigned_abs();
            assert!(
                diff <= 30,
                "roundtrip error: d={d}, squash={p}, stretch(squash)={d2}, diff={diff}"
            );
        }
    }

    #[test]
    fn symmetry() {
        // stretch(p) should equal -stretch(4096 - p)
        for p in 1..=4095u32 {
            let d1 = stretch(p);
            let d2 = stretch(4096 - p);
            assert_eq!(
                d1,
                -d2,
                "asymmetry at p={p}: stretch({p})={d1}, stretch({})={d2}",
                4096 - p,
            );
        }
    }

    #[test]
    fn squash_is_point_symmetric() {
        // Truncating division rounds toward zero, so squash(-d) == 4096 - squash(d),
        // including where the table saturates.
        for d in -10000..=10000 {
            assert_eq!(
                squash(-d),
                4096 - squash(d),
                "squash asymmetric at d={d}"
            );
        }
    }

    #[test]
    fn squash_extremes() {
        assert!(squash(-10000) <= 100, "squash(-10000) = {}", squash(-10000));
        assert!(squash(10000) >= 3996, "squash(10000) = {}", squash(10000));
    }

    #[test]
    fn stretch_extremes() {
        assert!(stretch(1) < -60, "stretch(1) = {}", stretch(1));
        assert!(stretch(4095) > 60, "stretch(4095) = {}", stretch(4095));
    }

    #[test]
    fn stretch_saturates_outside_domain() {
        // For |p - 2048| >= 2047 the stretch denominator is <= 0, so these entries are
        // pinned to ±8191 regardless of K; inputs above 4096 reuse the last entry.
        assert_eq!(stretch(0), -8191);
        assert_eq!(stretch(1), -8191);
        assert_eq!(stretch(4095), 8191);
        assert_eq!(stretch(4096), 8191);
        assert_eq!(stretch(u32::MAX), 8191, "inputs above 4096 must saturate");
    }
}