1use super::brownian::Rng;
8use glam::Vec2;
9
/// A discrete-time, finite-state Markov chain described by its transition matrix.
///
/// `transition[i][j]` is the probability of moving from state `i` to state `j` in one
/// step, so every row is expected to sum to 1 (see `MarkovChain::is_stochastic`).
#[derive(Debug, Clone, PartialEq)]
pub struct MarkovChain {
    /// Number of states; equal to `transition.len()` when built via the constructors.
    pub states: usize,
    /// Row-stochastic transition matrix, `states` x `states`.
    pub transition: Vec<Vec<f64>>,
}
21
22impl MarkovChain {
23 pub fn new(transition: Vec<Vec<f64>>) -> Self {
25 let states = transition.len();
26 Self { states, transition }
27 }
28
29 pub fn random(states: usize, rng: &mut Rng) -> Self {
31 let mut transition = vec![vec![0.0; states]; states];
32 for row in transition.iter_mut() {
33 let raw: Vec<f64> = (0..states).map(|_| rng.uniform().max(0.01)).collect();
34 let sum: f64 = raw.iter().sum();
35 for (j, val) in raw.iter().enumerate() {
36 row[j] = val / sum;
37 }
38 }
39 Self { states, transition }
40 }
41
42 pub fn is_stochastic(&self) -> bool {
44 for row in &self.transition {
45 if row.len() != self.states {
46 return false;
47 }
48 let sum: f64 = row.iter().sum();
49 if (sum - 1.0).abs() > 1e-6 {
50 return false;
51 }
52 if row.iter().any(|&p| p < -1e-10) {
53 return false;
54 }
55 }
56 true
57 }
58
59 pub fn step(&self, rng: &mut Rng, current_state: usize) -> usize {
61 let u = rng.uniform();
62 let row = &self.transition[current_state];
63 let mut cumsum = 0.0;
64 for (j, &p) in row.iter().enumerate() {
65 cumsum += p;
66 if u < cumsum {
67 return j;
68 }
69 }
70 self.states - 1
71 }
72
73 pub fn simulate(&self, rng: &mut Rng, initial: usize, steps: usize) -> Vec<usize> {
75 let mut path = Vec::with_capacity(steps + 1);
76 path.push(initial);
77 let mut current = initial;
78 for _ in 0..steps {
79 current = self.step(rng, current);
80 path.push(current);
81 }
82 path
83 }
84
85 pub fn stationary_distribution(&self) -> Vec<f64> {
88 let n = self.states;
89 let mut pi = vec![1.0 / n as f64; n];
90 let max_iter = 10_000;
91 let tol = 1e-12;
92
93 for _ in 0..max_iter {
94 let mut next = vec![0.0; n];
95 for j in 0..n {
96 for i in 0..n {
97 next[j] += pi[i] * self.transition[i][j];
98 }
99 }
100 let sum: f64 = next.iter().sum();
102 if sum > 0.0 {
103 for v in next.iter_mut() {
104 *v /= sum;
105 }
106 }
107
108 let diff: f64 = pi.iter().zip(next.iter()).map(|(a, b)| (a - b).abs()).sum();
110 pi = next;
111 if diff < tol {
112 break;
113 }
114 }
115 pi
116 }
117
118 pub fn is_irreducible(&self) -> bool {
120 let n = self.states;
121 for start in 0..n {
123 let mut visited = vec![false; n];
124 let mut queue = std::collections::VecDeque::new();
125 queue.push_back(start);
126 visited[start] = true;
127 while let Some(s) = queue.pop_front() {
128 for (j, &p) in self.transition[s].iter().enumerate() {
129 if p > 0.0 && !visited[j] {
130 visited[j] = true;
131 queue.push_back(j);
132 }
133 }
134 }
135 if visited.iter().any(|&v| !v) {
136 return false;
137 }
138 }
139 true
140 }
141
142 pub fn is_ergodic(&self) -> bool {
146 if !self.is_irreducible() {
147 return false;
148 }
149 if self.transition.iter().enumerate().any(|(i, row)| row[i] > 0.0) {
151 return true;
152 }
153 let p2 = mat_mul(&self.transition, &self.transition);
155 let p3 = mat_mul(&p2, &self.transition);
156 let combined = mat_add(&p2, &p3);
157 combined.iter().all(|row| row.iter().all(|&v| v > 1e-15))
158 }
159
160 pub fn absorbing_states(&self) -> Vec<usize> {
162 (0..self.states)
163 .filter(|&i| (self.transition[i][i] - 1.0).abs() < 1e-10)
164 .collect()
165 }
166
167 pub fn mean_first_passage(&self, from: usize, to: usize) -> f64 {
171 if from == to {
172 return 0.0;
173 }
174 let n = self.states;
175 let mut m = vec![0.0; n];
176 let max_iter = 50_000;
177 let tol = 1e-10;
178
179 for _ in 0..max_iter {
180 let mut new_m = vec![0.0; n];
181 let mut max_diff = 0.0_f64;
182 for i in 0..n {
183 if i == to {
184 new_m[i] = 0.0;
185 continue;
186 }
187 let mut val = 1.0;
188 for j in 0..n {
189 if j != to {
190 val += self.transition[i][j] * m[j];
191 }
192 }
193 new_m[i] = val;
194 max_diff = max_diff.max((new_m[i] - m[i]).abs());
195 }
196 m = new_m;
197 if max_diff < tol {
198 break;
199 }
200 }
201 m[from]
202 }
203
204 pub fn power(&self, n: usize) -> Vec<Vec<f64>> {
206 let mut result = identity(self.states);
207 let mut base = self.transition.clone();
208 let mut exp = n;
209 while exp > 0 {
210 if exp % 2 == 1 {
211 result = mat_mul(&result, &base);
212 }
213 base = mat_mul(&base, &base);
214 exp /= 2;
215 }
216 result
217 }
218
219 pub fn empirical_stationary(&self, rng: &mut Rng, steps: usize) -> Vec<f64> {
221 let path = self.simulate(rng, 0, steps);
222 let mut counts = vec![0usize; self.states];
223 for &s in &path {
224 counts[s] += 1;
225 }
226 let total = path.len() as f64;
227 counts.iter().map(|&c| c as f64 / total).collect()
228 }
229}
230
/// Returns the `n` x `n` identity matrix (ones on the diagonal, zeros elsewhere).
fn identity(n: usize) -> Vec<Vec<f64>> {
    (0..n)
        .map(|row| {
            (0..n)
                .map(|col| if row == col { 1.0 } else { 0.0 })
                .collect()
        })
        .collect()
}
242
/// Multiplies two dense row-major matrices, returning `a * b`.
///
/// `a` is `n x k` and `b` is `k x p`; the result is `n x p`. An empty `b` yields `n` empty
/// rows instead of panicking.
/// The loops run in i-k-j order, so the innermost loop walks rows of `b` and of the result
/// contiguously. Each entry still accumulates its terms in increasing `k` order, so results
/// are bit-identical to the naive i-j-k loop.
///
/// # Panics
/// Panics if a row of `a` is shorter than `b.len()`, or a row of `b` is shorter than `b[0]`.
fn mat_mul(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let p = b.first().map_or(0, Vec::len);
    a.iter()
        .map(|a_row| {
            let mut c_row = vec![0.0; p];
            for (l, b_row) in b.iter().enumerate() {
                let a_il = a_row[l];
                for (c_ij, &b_lj) in c_row.iter_mut().zip(&b_row[..p]) {
                    *c_ij += a_il * b_lj;
                }
            }
            c_row
        })
        .collect()
}
257
/// Element-wise sum of two matrices.
///
/// Rows and columns are paired positionally; any extra rows or columns in the larger operand
/// are ignored.
fn mat_add(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (row_a, row_b) in a.iter().zip(b) {
        let summed: Vec<f64> = row_a.iter().zip(row_b).map(|(&x, &y)| x + y).collect();
        out.push(summed);
    }
    out
}
264
/// A continuous-time, finite-state Markov chain described by its generator (rate) matrix.
///
/// For `i != j`, `generator[i][j]` is the rate of jumping from `i` to `j`, and
/// `-generator[i][i]` is the total exit rate of `i`. A non-positive exit rate makes `i`
/// absorbing.
#[derive(Debug, Clone, PartialEq)]
pub struct ContinuousTimeMarkov {
    /// Number of states; equal to `generator.len()` when built via `new`.
    pub states: usize,
    /// Generator matrix `Q`, `states` x `states`. Each row should sum to 0 so that the
    /// jump probabilities out of every non-absorbing state sum to 1.
    pub generator: Vec<Vec<f64>>,
}
275
276impl ContinuousTimeMarkov {
277 pub fn new(generator: Vec<Vec<f64>>) -> Self {
278 let states = generator.len();
279 Self { states, generator }
280 }
281
282 pub fn holding_time(&self, state: usize, rng: &mut Rng) -> f64 {
284 let rate = -self.generator[state][state];
285 if rate <= 0.0 {
286 return f64::INFINITY; }
288 let u = rng.uniform().max(1e-15);
289 -u.ln() / rate
290 }
291
292 pub fn jump_prob(&self, from: usize, to: usize) -> f64 {
294 if from == to {
295 return 0.0;
296 }
297 let rate = -self.generator[from][from];
298 if rate <= 0.0 {
299 return 0.0;
300 }
301 self.generator[from][to] / rate
302 }
303
304 pub fn simulate(&self, rng: &mut Rng, initial: usize, duration: f64) -> Vec<(f64, usize)> {
306 let mut path = Vec::new();
307 let mut t = 0.0;
308 let mut state = initial;
309 path.push((t, state));
310
311 loop {
312 let hold = self.holding_time(state, rng);
313 t += hold;
314 if t > duration {
315 break;
316 }
317 let u = rng.uniform();
319 let mut cumsum = 0.0;
320 let mut next_state = state;
321 for j in 0..self.states {
322 if j == state {
323 continue;
324 }
325 cumsum += self.jump_prob(state, j);
326 if u < cumsum {
327 next_state = j;
328 break;
329 }
330 }
331 state = next_state;
332 path.push((t, state));
333 }
334 path
335 }
336
337 pub fn embedded_chain(&self) -> MarkovChain {
339 let n = self.states;
340 let mut p = vec![vec![0.0; n]; n];
341 for i in 0..n {
342 let rate = -self.generator[i][i];
343 if rate <= 0.0 {
344 p[i][i] = 1.0; } else {
346 for j in 0..n {
347 if j != i {
348 p[i][j] = self.generator[i][j] / rate;
349 }
350 }
351 }
352 }
353 MarkovChain::new(p)
354 }
355
356 pub fn stationary_distribution(&self) -> Vec<f64> {
359 let embedded = self.embedded_chain();
360 let pi_embedded = embedded.stationary_distribution();
361 let n = self.states;
362
363 let mut weighted = vec![0.0; n];
365 for i in 0..n {
366 let rate = -self.generator[i][i];
367 if rate > 0.0 {
368 weighted[i] = pi_embedded[i] / rate;
369 }
370 }
371 let sum: f64 = weighted.iter().sum();
372 if sum > 0.0 {
373 for w in weighted.iter_mut() {
374 *w /= sum;
375 }
376 }
377 weighted
378 }
379}
380
/// Turns a [`MarkovChain`] into positioned, coloured glyphs for display.
#[derive(Debug, Clone, PartialEq)]
pub struct MarkovChainRenderer {
    /// Glyph drawn at each state's position.
    pub node_character: char,
    /// Glyph repeated along each drawn transition.
    pub edge_character: char,
    /// RGBA colour of node glyphs.
    pub node_color: [f32; 4],
    /// RGBA colour of edge glyphs; the alpha is further scaled by the transition probability.
    pub edge_color: [f32; 4],
    /// Radius of the circle on which the states are laid out.
    pub radius: f32,
}
393
394impl MarkovChainRenderer {
395 pub fn new() -> Self {
396 Self {
397 node_character: '●',
398 edge_character: '→',
399 node_color: [1.0, 0.8, 0.2, 1.0],
400 edge_color: [0.5, 0.5, 0.8, 0.6],
401 radius: 5.0,
402 }
403 }
404
405 pub fn render(&self, chain: &MarkovChain) -> Vec<(Vec2, char, [f32; 4])> {
407 let n = chain.states;
408 let mut glyphs = Vec::new();
409
410 let positions: Vec<Vec2> = (0..n)
412 .map(|i| {
413 let angle = 2.0 * std::f32::consts::PI * i as f32 / n as f32;
414 Vec2::new(self.radius * angle.cos(), self.radius * angle.sin())
415 })
416 .collect();
417
418 for &pos in &positions {
420 glyphs.push((pos, self.node_character, self.node_color));
421 }
422
423 let threshold = 0.05;
425 for i in 0..n {
426 for j in 0..n {
427 let p = chain.transition[i][j];
428 if p > threshold && i != j {
429 let from = positions[i];
430 let to = positions[j];
431 let edge_steps = 5;
432 let alpha = (p as f32).min(1.0) * self.edge_color[3];
433 let color = [
434 self.edge_color[0],
435 self.edge_color[1],
436 self.edge_color[2],
437 alpha,
438 ];
439 for k in 1..edge_steps {
440 let t = k as f32 / edge_steps as f32;
441 let pos = from.lerp(to, t);
442 glyphs.push((pos, self.edge_character, color));
443 }
444 }
445 }
446 }
447
448 glyphs
449 }
450
451 pub fn render_trajectory(
453 &self,
454 chain: &MarkovChain,
455 trajectory: &[usize],
456 ) -> Vec<(Vec2, char, [f32; 4])> {
457 let mut glyphs = Vec::new();
458 let n = chain.states;
459 let positions: Vec<Vec2> = (0..n)
460 .map(|i| {
461 let angle = 2.0 * std::f32::consts::PI * i as f32 / n as f32;
462 Vec2::new(self.radius * angle.cos(), self.radius * angle.sin())
463 })
464 .collect();
465
466 for (step, &state) in trajectory.iter().enumerate() {
467 let alpha = (step as f32 / trajectory.len() as f32).max(0.1);
468 let color = [1.0, 0.3, 0.3, alpha];
469 let offset = Vec2::new(step as f32 * 0.02, 0.0);
470 glyphs.push((positions[state] + offset, '◆', color));
471 }
472 glyphs
473 }
474}
475
476impl Default for MarkovChainRenderer {
477 fn default() -> Self {
478 Self::new()
479 }
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-state chain whose stationary distribution is (4/7, 3/7).
    fn simple_chain() -> MarkovChain {
        let rows = vec![vec![0.7, 0.3], vec![0.4, 0.6]];
        MarkovChain::new(rows)
    }

    /// Two-state generator whose stationary distribution is (0.6, 0.4).
    fn two_state_generator() -> Vec<Vec<f64>> {
        vec![vec![-2.0, 2.0], vec![3.0, -3.0]]
    }

    #[test]
    fn test_is_stochastic() {
        assert!(simple_chain().is_stochastic());
    }

    #[test]
    fn test_simulate_length() {
        let mut rng = Rng::new(42);
        let path = simple_chain().simulate(&mut rng, 0, 100);
        // The initial state is included, so 100 steps give 101 entries.
        assert_eq!(path.len(), 101);
    }

    #[test]
    fn test_stationary_distribution_sums_to_one() {
        let total: f64 = simple_chain().stationary_distribution().iter().sum();
        assert!(
            (total - 1.0).abs() < 1e-6,
            "stationary distribution should sum to 1, got {}",
            total
        );
    }

    #[test]
    fn test_stationary_distribution_values() {
        // Balance: 0.3 * pi0 = 0.4 * pi1, hence pi = (4/7, 3/7).
        let pi = simple_chain().stationary_distribution();
        let (expected0, expected1) = (4.0 / 7.0, 3.0 / 7.0);
        assert!(
            (pi[0] - expected0).abs() < 1e-4,
            "pi[0] should be ~4/7, got {}",
            pi[0]
        );
        assert!(
            (pi[1] - expected1).abs() < 1e-4,
            "pi[1] should be ~3/7, got {}",
            pi[1]
        );
    }

    #[test]
    fn test_irreducible() {
        assert!(simple_chain().is_irreducible());
        // State 1 is absorbing, so state 0 cannot be reached from it.
        let reducible = MarkovChain::new(vec![vec![0.5, 0.5], vec![0.0, 1.0]]);
        assert!(!reducible.is_irreducible());
    }

    #[test]
    fn test_ergodic() {
        assert!(simple_chain().is_ergodic());
    }

    #[test]
    fn test_absorbing_states() {
        let rows = vec![
            vec![0.5, 0.5, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.3, 0.0, 0.7],
        ];
        assert_eq!(MarkovChain::new(rows).absorbing_states(), vec![1]);
    }

    #[test]
    fn test_mean_first_passage() {
        // Leaving state 0 is a geometric trial with success probability 0.3.
        let mfp = simple_chain().mean_first_passage(0, 1);
        assert!(
            (mfp - 1.0 / 0.3).abs() < 0.1,
            "mean first passage should be ~3.33, got {}",
            mfp
        );
    }

    #[test]
    fn test_power_matrix() {
        let mc = simple_chain();
        assert!((mc.power(1)[0][0] - 0.7).abs() < 1e-10);
        // (P^2)[0][0] = 0.7 * 0.7 + 0.3 * 0.4 = 0.61.
        assert!((mc.power(2)[0][0] - 0.61).abs() < 1e-10);
    }

    #[test]
    fn test_ctmc_simulation() {
        let ctmc = ContinuousTimeMarkov::new(two_state_generator());
        let mut rng = Rng::new(42);
        let path = ctmc.simulate(&mut rng, 0, 10.0);
        assert!(!path.is_empty());
        assert_eq!(path[0], (0.0, 0));
    }

    #[test]
    fn test_ctmc_stationary() {
        // Balance: 2 * pi0 = 3 * pi1, hence pi = (0.6, 0.4).
        let pi = ContinuousTimeMarkov::new(two_state_generator()).stationary_distribution();
        assert!(
            (pi[0] - 0.6).abs() < 0.05,
            "CTMC pi[0] should be ~0.6, got {}",
            pi[0]
        );
        assert!(
            (pi[1] - 0.4).abs() < 0.05,
            "CTMC pi[1] should be ~0.4, got {}",
            pi[1]
        );
    }

    #[test]
    fn test_random_chain_is_stochastic() {
        let mut rng = Rng::new(42);
        assert!(MarkovChain::random(5, &mut rng).is_stochastic());
    }

    #[test]
    fn test_renderer() {
        let glyphs = MarkovChainRenderer::new().render(&simple_chain());
        assert!(!glyphs.is_empty());
    }
}