// proof_engine/stochastic/markov.rs
//! Markov chains: discrete-time and continuous-time.
//!
//! Provides transition matrix operations, stationary distribution computation
//! via power iteration, ergodicity/irreducibility checks, mean first passage
//! times, and a glyph-based renderer.

use super::brownian::Rng;
use glam::Vec2;

// ---------------------------------------------------------------------------
// MarkovChain (discrete-time)
// ---------------------------------------------------------------------------

/// Discrete-time Markov chain with a finite state space `{0, .., states - 1}`.
pub struct MarkovChain {
    /// Number of states (= number of rows/columns in `transition`).
    pub states: usize,
    /// Row-stochastic transition matrix: transition[i][j] = P(X_{n+1}=j | X_n=i).
    /// Each row is expected to sum to 1; see `is_stochastic`.
    pub transition: Vec<Vec<f64>>,
}

22impl MarkovChain {
23    /// Create from a transition matrix. Rows must sum to 1.
24    pub fn new(transition: Vec<Vec<f64>>) -> Self {
25        let states = transition.len();
26        Self { states, transition }
27    }
28
29    /// Create a random transition matrix.
30    pub fn random(states: usize, rng: &mut Rng) -> Self {
31        let mut transition = vec![vec![0.0; states]; states];
32        for row in transition.iter_mut() {
33            let raw: Vec<f64> = (0..states).map(|_| rng.uniform().max(0.01)).collect();
34            let sum: f64 = raw.iter().sum();
35            for (j, val) in raw.iter().enumerate() {
36                row[j] = val / sum;
37            }
38        }
39        Self { states, transition }
40    }
41
42    /// Validate that this is a proper stochastic matrix.
43    pub fn is_stochastic(&self) -> bool {
44        for row in &self.transition {
45            if row.len() != self.states {
46                return false;
47            }
48            let sum: f64 = row.iter().sum();
49            if (sum - 1.0).abs() > 1e-6 {
50                return false;
51            }
52            if row.iter().any(|&p| p < -1e-10) {
53                return false;
54            }
55        }
56        true
57    }
58
59    /// Single step: sample next state from current.
60    pub fn step(&self, rng: &mut Rng, current_state: usize) -> usize {
61        let u = rng.uniform();
62        let row = &self.transition[current_state];
63        let mut cumsum = 0.0;
64        for (j, &p) in row.iter().enumerate() {
65            cumsum += p;
66            if u < cumsum {
67                return j;
68            }
69        }
70        self.states - 1
71    }
72
73    /// Simulate a trajectory of `steps` states starting from `initial`.
74    pub fn simulate(&self, rng: &mut Rng, initial: usize, steps: usize) -> Vec<usize> {
75        let mut path = Vec::with_capacity(steps + 1);
76        path.push(initial);
77        let mut current = initial;
78        for _ in 0..steps {
79            current = self.step(rng, current);
80            path.push(current);
81        }
82        path
83    }
84
85    /// Compute stationary distribution via power iteration.
86    /// Finds pi such that pi * P = pi.
87    pub fn stationary_distribution(&self) -> Vec<f64> {
88        let n = self.states;
89        let mut pi = vec![1.0 / n as f64; n];
90        let max_iter = 10_000;
91        let tol = 1e-12;
92
93        for _ in 0..max_iter {
94            let mut next = vec![0.0; n];
95            for j in 0..n {
96                for i in 0..n {
97                    next[j] += pi[i] * self.transition[i][j];
98                }
99            }
100            // Normalize
101            let sum: f64 = next.iter().sum();
102            if sum > 0.0 {
103                for v in next.iter_mut() {
104                    *v /= sum;
105                }
106            }
107
108            // Check convergence
109            let diff: f64 = pi.iter().zip(next.iter()).map(|(a, b)| (a - b).abs()).sum();
110            pi = next;
111            if diff < tol {
112                break;
113            }
114        }
115        pi
116    }
117
118    /// Check if the chain is irreducible (all states reachable from all states).
119    pub fn is_irreducible(&self) -> bool {
120        let n = self.states;
121        // Build adjacency and do BFS from each state
122        for start in 0..n {
123            let mut visited = vec![false; n];
124            let mut queue = std::collections::VecDeque::new();
125            queue.push_back(start);
126            visited[start] = true;
127            while let Some(s) = queue.pop_front() {
128                for (j, &p) in self.transition[s].iter().enumerate() {
129                    if p > 0.0 && !visited[j] {
130                        visited[j] = true;
131                        queue.push_back(j);
132                    }
133                }
134            }
135            if visited.iter().any(|&v| !v) {
136                return false;
137            }
138        }
139        true
140    }
141
142    /// Check if the chain is ergodic (irreducible and aperiodic).
143    /// Aperiodicity is checked by verifying gcd of return times = 1,
144    /// which we approximate by checking if P^n has all positive entries for some n.
145    pub fn is_ergodic(&self) -> bool {
146        if !self.is_irreducible() {
147            return false;
148        }
149        // Check aperiodicity: if any diagonal entry > 0, then aperiodic
150        if self.transition.iter().enumerate().any(|(i, row)| row[i] > 0.0) {
151            return true;
152        }
153        // Otherwise compute P^2 + P^3 and check if all entries > 0
154        let p2 = mat_mul(&self.transition, &self.transition);
155        let p3 = mat_mul(&p2, &self.transition);
156        let combined = mat_add(&p2, &p3);
157        combined.iter().all(|row| row.iter().all(|&v| v > 1e-15))
158    }
159
160    /// Find absorbing states (states i where P(i,i) = 1).
161    pub fn absorbing_states(&self) -> Vec<usize> {
162        (0..self.states)
163            .filter(|&i| (self.transition[i][i] - 1.0).abs() < 1e-10)
164            .collect()
165    }
166
167    /// Mean first passage time from state `from` to state `to`.
168    /// Computed by solving the system: m_i = 1 + sum_{j != to} P(i,j) * m_j
169    /// using iterative method.
170    pub fn mean_first_passage(&self, from: usize, to: usize) -> f64 {
171        if from == to {
172            return 0.0;
173        }
174        let n = self.states;
175        let mut m = vec![0.0; n];
176        let max_iter = 50_000;
177        let tol = 1e-10;
178
179        for _ in 0..max_iter {
180            let mut new_m = vec![0.0; n];
181            let mut max_diff = 0.0_f64;
182            for i in 0..n {
183                if i == to {
184                    new_m[i] = 0.0;
185                    continue;
186                }
187                let mut val = 1.0;
188                for j in 0..n {
189                    if j != to {
190                        val += self.transition[i][j] * m[j];
191                    }
192                }
193                new_m[i] = val;
194                max_diff = max_diff.max((new_m[i] - m[i]).abs());
195            }
196            m = new_m;
197            if max_diff < tol {
198                break;
199            }
200        }
201        m[from]
202    }
203
204    /// Compute the n-step transition matrix P^n.
205    pub fn power(&self, n: usize) -> Vec<Vec<f64>> {
206        let mut result = identity(self.states);
207        let mut base = self.transition.clone();
208        let mut exp = n;
209        while exp > 0 {
210            if exp % 2 == 1 {
211                result = mat_mul(&result, &base);
212            }
213            base = mat_mul(&base, &base);
214            exp /= 2;
215        }
216        result
217    }
218
219    /// Empirical stationary distribution from a long simulation.
220    pub fn empirical_stationary(&self, rng: &mut Rng, steps: usize) -> Vec<f64> {
221        let path = self.simulate(rng, 0, steps);
222        let mut counts = vec![0usize; self.states];
223        for &s in &path {
224            counts[s] += 1;
225        }
226        let total = path.len() as f64;
227        counts.iter().map(|&c| c as f64 / total).collect()
228    }
229}
230
// ---------------------------------------------------------------------------
// Matrix helpers
// ---------------------------------------------------------------------------

/// Build the n×n identity matrix as nested `Vec`s.
fn identity(n: usize) -> Vec<Vec<f64>> {
    (0..n)
        .map(|i| (0..n).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
        .collect()
}

/// Dense matrix product `a * b` (a is n×k, b is k×p; rows as `Vec`s).
///
/// Returns an empty matrix if either operand is empty; the previous version
/// unconditionally indexed `b[0]` and panicked on empty input. The inner
/// loops iterate with `zip` so no per-element bounds checks are paid.
fn mat_mul(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    if a.is_empty() || b.is_empty() {
        return Vec::new();
    }
    let p = b[0].len();
    a.iter()
        .map(|row_a| {
            let mut row_c = vec![0.0; p];
            // c[i][j] = sum_l a[i][l] * b[l][j], accumulated row-by-row.
            for (aval, row_b) in row_a.iter().zip(b.iter()) {
                for (cval, bval) in row_c.iter_mut().zip(row_b.iter()) {
                    *cval += aval * bval;
                }
            }
            row_c
        })
        .collect()
}

/// Element-wise matrix sum `a + b` (matrices of equal shape).
fn mat_add(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let mut sum = Vec::with_capacity(a.len());
    for (ra, rb) in a.iter().zip(b) {
        let mut row = Vec::with_capacity(ra.len());
        for (x, y) in ra.iter().zip(rb) {
            row.push(x + y);
        }
        sum.push(row);
    }
    sum
}

// ---------------------------------------------------------------------------
// ContinuousTimeMarkov
// ---------------------------------------------------------------------------

/// Continuous-time Markov chain defined by a generator matrix Q.
/// Q[i][j] >= 0 for i != j, Q[i][i] = -sum_{j!=i} Q[i][j].
pub struct ContinuousTimeMarkov {
    /// Number of states (= dimension of `generator`).
    pub states: usize,
    /// Generator (rate) matrix Q; diagonal entries are non-positive and each
    /// row sums to 0. A zero diagonal entry marks an absorbing state.
    pub generator: Vec<Vec<f64>>,
}

276impl ContinuousTimeMarkov {
277    pub fn new(generator: Vec<Vec<f64>>) -> Self {
278        let states = generator.len();
279        Self { states, generator }
280    }
281
282    /// Holding time in state i: Exp(-Q[i][i]).
283    pub fn holding_time(&self, state: usize, rng: &mut Rng) -> f64 {
284        let rate = -self.generator[state][state];
285        if rate <= 0.0 {
286            return f64::INFINITY; // absorbing state
287        }
288        let u = rng.uniform().max(1e-15);
289        -u.ln() / rate
290    }
291
292    /// Jump probability from state i to state j (given that a jump occurs).
293    pub fn jump_prob(&self, from: usize, to: usize) -> f64 {
294        if from == to {
295            return 0.0;
296        }
297        let rate = -self.generator[from][from];
298        if rate <= 0.0 {
299            return 0.0;
300        }
301        self.generator[from][to] / rate
302    }
303
304    /// Simulate the CTMC: returns Vec of (time, state).
305    pub fn simulate(&self, rng: &mut Rng, initial: usize, duration: f64) -> Vec<(f64, usize)> {
306        let mut path = Vec::new();
307        let mut t = 0.0;
308        let mut state = initial;
309        path.push((t, state));
310
311        loop {
312            let hold = self.holding_time(state, rng);
313            t += hold;
314            if t > duration {
315                break;
316            }
317            // Jump
318            let u = rng.uniform();
319            let mut cumsum = 0.0;
320            let mut next_state = state;
321            for j in 0..self.states {
322                if j == state {
323                    continue;
324                }
325                cumsum += self.jump_prob(state, j);
326                if u < cumsum {
327                    next_state = j;
328                    break;
329                }
330            }
331            state = next_state;
332            path.push((t, state));
333        }
334        path
335    }
336
337    /// Compute the embedded discrete-time chain's transition matrix.
338    pub fn embedded_chain(&self) -> MarkovChain {
339        let n = self.states;
340        let mut p = vec![vec![0.0; n]; n];
341        for i in 0..n {
342            let rate = -self.generator[i][i];
343            if rate <= 0.0 {
344                p[i][i] = 1.0; // absorbing
345            } else {
346                for j in 0..n {
347                    if j != i {
348                        p[i][j] = self.generator[i][j] / rate;
349                    }
350                }
351            }
352        }
353        MarkovChain::new(p)
354    }
355
356    /// Stationary distribution via solving pi * Q = 0 using power iteration
357    /// on the embedded chain weighted by holding times.
358    pub fn stationary_distribution(&self) -> Vec<f64> {
359        let embedded = self.embedded_chain();
360        let pi_embedded = embedded.stationary_distribution();
361        let n = self.states;
362
363        // Weight by mean holding time (1 / -Q[i][i])
364        let mut weighted = vec![0.0; n];
365        for i in 0..n {
366            let rate = -self.generator[i][i];
367            if rate > 0.0 {
368                weighted[i] = pi_embedded[i] / rate;
369            }
370        }
371        let sum: f64 = weighted.iter().sum();
372        if sum > 0.0 {
373            for w in weighted.iter_mut() {
374                *w /= sum;
375            }
376        }
377        weighted
378    }
379}
380
// ---------------------------------------------------------------------------
// MarkovChainRenderer
// ---------------------------------------------------------------------------

/// Render Markov chain states as nodes and transitions as weighted edges.
pub struct MarkovChainRenderer {
    /// Glyph drawn at each state's position.
    pub node_character: char,
    /// Glyph repeated along each transition edge.
    pub edge_character: char,
    /// RGBA color for node glyphs.
    pub node_color: [f32; 4],
    /// RGBA color for edge glyphs; the alpha channel scales with transition probability.
    pub edge_color: [f32; 4],
    /// Radius of the circle on which states are laid out.
    pub radius: f32,
}

394impl MarkovChainRenderer {
395    pub fn new() -> Self {
396        Self {
397            node_character: '●',
398            edge_character: '→',
399            node_color: [1.0, 0.8, 0.2, 1.0],
400            edge_color: [0.5, 0.5, 0.8, 0.6],
401            radius: 5.0,
402        }
403    }
404
405    /// Arrange states in a circle and generate glyphs.
406    pub fn render(&self, chain: &MarkovChain) -> Vec<(Vec2, char, [f32; 4])> {
407        let n = chain.states;
408        let mut glyphs = Vec::new();
409
410        // Place nodes in a circle
411        let positions: Vec<Vec2> = (0..n)
412            .map(|i| {
413                let angle = 2.0 * std::f32::consts::PI * i as f32 / n as f32;
414                Vec2::new(self.radius * angle.cos(), self.radius * angle.sin())
415            })
416            .collect();
417
418        // Draw nodes
419        for &pos in &positions {
420            glyphs.push((pos, self.node_character, self.node_color));
421        }
422
423        // Draw edges (sample points along lines for transitions with p > threshold)
424        let threshold = 0.05;
425        for i in 0..n {
426            for j in 0..n {
427                let p = chain.transition[i][j];
428                if p > threshold && i != j {
429                    let from = positions[i];
430                    let to = positions[j];
431                    let edge_steps = 5;
432                    let alpha = (p as f32).min(1.0) * self.edge_color[3];
433                    let color = [
434                        self.edge_color[0],
435                        self.edge_color[1],
436                        self.edge_color[2],
437                        alpha,
438                    ];
439                    for k in 1..edge_steps {
440                        let t = k as f32 / edge_steps as f32;
441                        let pos = from.lerp(to, t);
442                        glyphs.push((pos, self.edge_character, color));
443                    }
444                }
445            }
446        }
447
448        glyphs
449    }
450
451    /// Render a trajectory as a sequence of highlighted states over time.
452    pub fn render_trajectory(
453        &self,
454        chain: &MarkovChain,
455        trajectory: &[usize],
456    ) -> Vec<(Vec2, char, [f32; 4])> {
457        let mut glyphs = Vec::new();
458        let n = chain.states;
459        let positions: Vec<Vec2> = (0..n)
460            .map(|i| {
461                let angle = 2.0 * std::f32::consts::PI * i as f32 / n as f32;
462                Vec2::new(self.radius * angle.cos(), self.radius * angle.sin())
463            })
464            .collect();
465
466        for (step, &state) in trajectory.iter().enumerate() {
467            let alpha = (step as f32 / trajectory.len() as f32).max(0.1);
468            let color = [1.0, 0.3, 0.3, alpha];
469            let offset = Vec2::new(step as f32 * 0.02, 0.0);
470            glyphs.push((positions[state] + offset, '◆', color));
471        }
472        glyphs
473    }
474}
475
/// `Default` delegates to [`MarkovChainRenderer::new`].
impl Default for MarkovChainRenderer {
    fn default() -> Self {
        Self::new()
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    /// Two-state chain with known closed-form stationary distribution
    /// pi = [4/7, 3/7]; reused by most tests below.
    fn simple_chain() -> MarkovChain {
        // Two-state chain
        MarkovChain::new(vec![vec![0.7, 0.3], vec![0.4, 0.6]])
    }

    #[test]
    fn test_is_stochastic() {
        let mc = simple_chain();
        assert!(mc.is_stochastic());
    }

    #[test]
    fn test_simulate_length() {
        let mc = simple_chain();
        let mut rng = Rng::new(42);
        let path = mc.simulate(&mut rng, 0, 100);
        // steps + 1 entries: the initial state plus one per transition.
        assert_eq!(path.len(), 101);
    }

    #[test]
    fn test_stationary_distribution_sums_to_one() {
        let mc = simple_chain();
        let pi = mc.stationary_distribution();
        let sum: f64 = pi.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-6,
            "stationary distribution should sum to 1, got {}",
            sum
        );
    }

    #[test]
    fn test_stationary_distribution_values() {
        // For [[0.7, 0.3], [0.4, 0.6]]:
        // pi = [4/7, 3/7] ≈ [0.5714, 0.4286]
        let mc = simple_chain();
        let pi = mc.stationary_distribution();
        assert!(
            (pi[0] - 4.0 / 7.0).abs() < 1e-4,
            "pi[0] should be ~4/7, got {}",
            pi[0]
        );
        assert!(
            (pi[1] - 3.0 / 7.0).abs() < 1e-4,
            "pi[1] should be ~3/7, got {}",
            pi[1]
        );
    }

    #[test]
    fn test_irreducible() {
        let mc = simple_chain();
        assert!(mc.is_irreducible());

        // Reducible: state 1 is absorbing
        let reducible = MarkovChain::new(vec![vec![0.5, 0.5], vec![0.0, 1.0]]);
        assert!(!reducible.is_irreducible());
    }

    #[test]
    fn test_ergodic() {
        // simple_chain has positive diagonal entries, hence period 1.
        let mc = simple_chain();
        assert!(mc.is_ergodic());
    }

    #[test]
    fn test_absorbing_states() {
        let mc = MarkovChain::new(vec![
            vec![0.5, 0.5, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.3, 0.0, 0.7],
        ]);
        let abs = mc.absorbing_states();
        assert_eq!(abs, vec![1]);
    }

    #[test]
    fn test_mean_first_passage() {
        let mc = simple_chain();
        let mfp = mc.mean_first_passage(0, 1);
        // Analytical: m_{0->1} = 1/0.3 = 3.333...
        assert!(
            (mfp - 1.0 / 0.3).abs() < 0.1,
            "mean first passage should be ~3.33, got {}",
            mfp
        );
    }

    #[test]
    fn test_power_matrix() {
        let mc = simple_chain();
        let p1 = mc.power(1);
        assert!((p1[0][0] - 0.7).abs() < 1e-10);

        let p2 = mc.power(2);
        // P^2[0][0] = 0.7*0.7 + 0.3*0.4 = 0.49 + 0.12 = 0.61
        assert!((p2[0][0] - 0.61).abs() < 1e-10);
    }

    #[test]
    fn test_ctmc_simulation() {
        let gen = vec![vec![-2.0, 2.0], vec![3.0, -3.0]];
        let ctmc = ContinuousTimeMarkov::new(gen);
        let mut rng = Rng::new(42);
        let path = ctmc.simulate(&mut rng, 0, 10.0);
        assert!(!path.is_empty());
        // Path always begins at (t=0, initial state).
        assert_eq!(path[0], (0.0, 0));
    }

    #[test]
    fn test_ctmc_stationary() {
        // Q = [[-2, 2], [3, -3]]
        // pi = [3/5, 2/5]
        let gen = vec![vec![-2.0, 2.0], vec![3.0, -3.0]];
        let ctmc = ContinuousTimeMarkov::new(gen);
        let pi = ctmc.stationary_distribution();
        assert!(
            (pi[0] - 0.6).abs() < 0.05,
            "CTMC pi[0] should be ~0.6, got {}",
            pi[0]
        );
        assert!(
            (pi[1] - 0.4).abs() < 0.05,
            "CTMC pi[1] should be ~0.4, got {}",
            pi[1]
        );
    }

    #[test]
    fn test_random_chain_is_stochastic() {
        let mut rng = Rng::new(42);
        let mc = MarkovChain::random(5, &mut rng);
        assert!(mc.is_stochastic());
    }

    #[test]
    fn test_renderer() {
        let mc = simple_chain();
        let renderer = MarkovChainRenderer::new();
        let glyphs = renderer.render(&mc);
        assert!(!glyphs.is_empty());
    }
}