Skip to main content

oxicuda_seq/mrf/
mrf.rs

1//! Pairwise MRF on a generic graph plus an Ising-model specialisation.
2
3use crate::error::{SeqError, SeqResult};
4
/// Pairwise Markov Random Field on an undirected graph.
///
/// The energy of a labelling is the sum of every node's unary cost plus every
/// edge's pairwise cost (see [`Mrf::energy`]). Both tables are flat, row-major
/// `f64` buffers:
///
/// * `unary[i * n_labels + l]` — local cost of assigning label `l` to node `i`.
/// * `pairwise[edge_idx * n_labels * n_labels + l_u * n_labels + l_v]` — cost of
///   the label pair `(l_u, l_v)` on `edges[edge_idx] = (u, v)`, where `l_u` is the
///   label of the *first* endpoint `u` and `l_v` the label of the second endpoint
///   `v`. The orientation of each stored edge therefore matters for indexing.
///
/// The fields are public for read access. Mutating them directly bypasses the
/// shape checks performed by [`Mrf::new`] and can make [`Mrf::energy`] panic.
#[derive(Debug, Clone, PartialEq)]
pub struct Mrf {
    /// Number of nodes in the graph.
    pub n_nodes: usize,
    /// Number of discrete labels each node can take.
    pub n_labels: usize,
    /// Edge list; each `(u, v)` is a pair of distinct node indices in `0..n_nodes`.
    pub edges: Vec<(usize, usize)>,
    /// Unary cost table of length `n_nodes * n_labels`.
    pub unary: Vec<f64>,
    /// Pairwise cost table of length `edges.len() * n_labels * n_labels`.
    pub pairwise: Vec<f64>,
}
18
19impl Mrf {
20    /// Build a pairwise MRF, validating shapes.
21    pub fn new(
22        n_nodes: usize,
23        n_labels: usize,
24        edges: Vec<(usize, usize)>,
25        unary: Vec<f64>,
26        pairwise: Vec<f64>,
27    ) -> SeqResult<Self> {
28        if n_nodes == 0 || n_labels == 0 {
29            return Err(SeqError::InvalidConfiguration(
30                "n_nodes and n_labels must be > 0".to_string(),
31            ));
32        }
33        if unary.len() != n_nodes * n_labels {
34            return Err(SeqError::ShapeMismatch {
35                expected: n_nodes * n_labels,
36                got: unary.len(),
37            });
38        }
39        if pairwise.len() != edges.len() * n_labels * n_labels {
40            return Err(SeqError::ShapeMismatch {
41                expected: edges.len() * n_labels * n_labels,
42                got: pairwise.len(),
43            });
44        }
45        for &(u, v) in &edges {
46            if u >= n_nodes || v >= n_nodes || u == v {
47                return Err(SeqError::GraphInvariantViolated(format!(
48                    "edge ({u}, {v}) invalid for n_nodes={n_nodes}"
49                )));
50            }
51        }
52        Ok(Self {
53            n_nodes,
54            n_labels,
55            edges,
56            unary,
57            pairwise,
58        })
59    }
60
61    /// Energy of an assignment.
62    pub fn energy(&self, labels: &[usize]) -> SeqResult<f64> {
63        if labels.len() != self.n_nodes {
64            return Err(SeqError::ShapeMismatch {
65                expected: self.n_nodes,
66                got: labels.len(),
67            });
68        }
69        let mut e = 0.0;
70        for (i, &l) in labels.iter().enumerate() {
71            if l >= self.n_labels {
72                return Err(SeqError::IndexOutOfBounds {
73                    index: l,
74                    len: self.n_labels,
75                });
76            }
77            e += self.unary[i * self.n_labels + l];
78        }
79        let l2 = self.n_labels * self.n_labels;
80        for (e_idx, &(u, v)) in self.edges.iter().enumerate() {
81            let lu = labels[u];
82            let lv = labels[v];
83            e += self.pairwise[e_idx * l2 + lu * self.n_labels + lv];
84        }
85        Ok(e)
86    }
87}
88
/// Ising model on a regular 2-D grid (4-connected, open boundaries).
///
/// Energy: `E(s) = − h Σ_i s_i − J Σ_⟨i,j⟩ s_i s_j`, with `s_i ∈ {−1, +1}`,
/// where `⟨i,j⟩` ranges over horizontally and vertically adjacent sites and each
/// bond is counted once. The grid does not wrap around, so border sites have
/// fewer neighbours.
///
/// Spin configurations are row-major: site `(r, c)` is at index `r * n_cols + c`.
#[derive(Debug, Clone, PartialEq)]
pub struct IsingModel {
    /// Number of grid rows.
    pub n_rows: usize,
    /// Number of grid columns.
    pub n_cols: usize,
    /// Uniform external field `h`.
    pub field: f64,
    /// Uniform nearest-neighbour coupling `J`; with `J > 0`, aligned neighbours lower the energy.
    pub coupling: f64,
    /// Validated by `IsingModel::new` to be finite and strictly positive. It does
    /// not enter the energy; presumably it is the inverse temperature used by
    /// samplers (TODO confirm).
    pub beta: f64,
}
100
101impl IsingModel {
102    /// Construct an Ising model.
103    pub fn new(
104        n_rows: usize,
105        n_cols: usize,
106        field: f64,
107        coupling: f64,
108        beta: f64,
109    ) -> SeqResult<Self> {
110        if n_rows == 0 || n_cols == 0 {
111            return Err(SeqError::InvalidConfiguration(
112                "grid dims must be > 0".to_string(),
113            ));
114        }
115        if beta <= 0.0 || !beta.is_finite() {
116            return Err(SeqError::InvalidParameter {
117                name: "beta".to_string(),
118                value: beta,
119            });
120        }
121        Ok(Self {
122            n_rows,
123            n_cols,
124            field,
125            coupling,
126            beta,
127        })
128    }
129
130    /// Total energy of an Ising spin configuration (`±1` values).
131    pub fn energy(&self, spins: &[i32]) -> SeqResult<f64> {
132        if spins.len() != self.n_rows * self.n_cols {
133            return Err(SeqError::ShapeMismatch {
134                expected: self.n_rows * self.n_cols,
135                got: spins.len(),
136            });
137        }
138        let mut e = 0.0;
139        for r in 0..self.n_rows {
140            for c in 0..self.n_cols {
141                let s = spins[r * self.n_cols + c] as f64;
142                e -= self.field * s;
143                if r + 1 < self.n_rows {
144                    let s2 = spins[(r + 1) * self.n_cols + c] as f64;
145                    e -= self.coupling * s * s2;
146                }
147                if c + 1 < self.n_cols {
148                    let s2 = spins[r * self.n_cols + (c + 1)] as f64;
149                    e -= self.coupling * s * s2;
150                }
151            }
152        }
153        Ok(e)
154    }
155
156    /// Mean magnetisation `(1/N) Σ s_i`.
157    pub fn magnetisation(&self, spins: &[i32]) -> SeqResult<f64> {
158        if spins.len() != self.n_rows * self.n_cols {
159            return Err(SeqError::ShapeMismatch {
160                expected: self.n_rows * self.n_cols,
161                got: spins.len(),
162            });
163        }
164        let s: i64 = spins.iter().map(|&x| x as i64).sum();
165        Ok(s as f64 / spins.len() as f64)
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Small 3-node chain MRF shared by the MRF tests.
    fn chain_mrf() -> Mrf {
        Mrf::new(
            3,
            2,
            vec![(0, 1), (1, 2)],
            vec![0.0, 1.0, 1.0, 0.0, 0.0, 1.0],
            vec![0.0, 0.5, 0.5, 0.0, 0.0, 0.5, 0.5, 0.0],
        )
        .expect("ok")
    }

    #[test]
    fn mrf_construct_and_energy() {
        let m = chain_mrf();
        // Unary 0.0 + 1.0 + 0.0; pairwise entry (0, 0) of both edges is 0.0.
        assert_eq!(m.energy(&[0, 0, 0]).expect("ok"), 1.0);
        // Unary 0.0 + 0.0 + 1.0; edge (0,1) with labels (0,1) -> 0.5,
        // edge (1,2) with labels (1,1) -> 0.0.
        assert_eq!(m.energy(&[0, 1, 1]).expect("ok"), 1.5);
    }

    #[test]
    fn mrf_rejects_bad_inputs() {
        assert!(matches!(
            Mrf::new(2, 2, vec![(1, 1)], vec![0.0; 4], vec![0.0; 4]),
            Err(SeqError::GraphInvariantViolated(_))
        ));
        assert!(matches!(
            Mrf::new(2, 2, vec![(0, 2)], vec![0.0; 4], vec![0.0; 4]),
            Err(SeqError::GraphInvariantViolated(_))
        ));
        assert!(matches!(
            Mrf::new(2, 2, vec![], vec![0.0; 3], vec![]),
            Err(SeqError::ShapeMismatch { expected: 4, got: 3 })
        ));
        let m = chain_mrf();
        assert!(matches!(
            m.energy(&[0, 0]),
            Err(SeqError::ShapeMismatch { expected: 3, got: 2 })
        ));
        assert!(matches!(
            m.energy(&[0, 2, 0]),
            Err(SeqError::IndexOutOfBounds { index: 2, len: 2 })
        ));
    }

    #[test]
    fn ising_all_up_magnetisation_one() {
        let m = IsingModel::new(3, 3, 0.0, 1.0, 1.0).expect("ok");
        let spins = vec![1i32; 9];
        assert!((m.magnetisation(&spins).expect("ok") - 1.0).abs() < 1e-12);
    }

    #[test]
    fn ising_all_up_energy_counts_each_open_bond_once() {
        // A 3x3 open grid has 2 * 3 * 2 = 12 nearest-neighbour bonds.
        let m = IsingModel::new(3, 3, 0.0, 1.0, 1.0).expect("ok");
        assert_eq!(m.energy(&[1; 9]).expect("ok"), -12.0);
    }

    #[test]
    fn ising_mixed_magnetisation_and_rejections() {
        let m = IsingModel::new(2, 2, 0.0, 1.0, 1.0).expect("ok");
        assert_eq!(m.magnetisation(&[1, 1, -1, 1]).expect("ok"), 0.5);
        assert!(m.energy(&[1; 3]).is_err());
        assert!(m.magnetisation(&[1; 5]).is_err());
        assert!(IsingModel::new(0, 3, 0.0, 1.0, 1.0).is_err());
        assert!(IsingModel::new(3, 3, 0.0, 1.0, 0.0).is_err());
        assert!(IsingModel::new(3, 3, 0.0, 1.0, f64::NAN).is_err());
    }
}