1use crate::error::{SeqError, SeqResult};
4
5#[derive(Debug, Clone)]
11pub struct Mrf {
12 pub n_nodes: usize,
13 pub n_labels: usize,
14 pub edges: Vec<(usize, usize)>,
15 pub unary: Vec<f64>,
16 pub pairwise: Vec<f64>,
17}
18
19impl Mrf {
20 pub fn new(
22 n_nodes: usize,
23 n_labels: usize,
24 edges: Vec<(usize, usize)>,
25 unary: Vec<f64>,
26 pairwise: Vec<f64>,
27 ) -> SeqResult<Self> {
28 if n_nodes == 0 || n_labels == 0 {
29 return Err(SeqError::InvalidConfiguration(
30 "n_nodes and n_labels must be > 0".to_string(),
31 ));
32 }
33 if unary.len() != n_nodes * n_labels {
34 return Err(SeqError::ShapeMismatch {
35 expected: n_nodes * n_labels,
36 got: unary.len(),
37 });
38 }
39 if pairwise.len() != edges.len() * n_labels * n_labels {
40 return Err(SeqError::ShapeMismatch {
41 expected: edges.len() * n_labels * n_labels,
42 got: pairwise.len(),
43 });
44 }
45 for &(u, v) in &edges {
46 if u >= n_nodes || v >= n_nodes || u == v {
47 return Err(SeqError::GraphInvariantViolated(format!(
48 "edge ({u}, {v}) invalid for n_nodes={n_nodes}"
49 )));
50 }
51 }
52 Ok(Self {
53 n_nodes,
54 n_labels,
55 edges,
56 unary,
57 pairwise,
58 })
59 }
60
61 pub fn energy(&self, labels: &[usize]) -> SeqResult<f64> {
63 if labels.len() != self.n_nodes {
64 return Err(SeqError::ShapeMismatch {
65 expected: self.n_nodes,
66 got: labels.len(),
67 });
68 }
69 let mut e = 0.0;
70 for (i, &l) in labels.iter().enumerate() {
71 if l >= self.n_labels {
72 return Err(SeqError::IndexOutOfBounds {
73 index: l,
74 len: self.n_labels,
75 });
76 }
77 e += self.unary[i * self.n_labels + l];
78 }
79 let l2 = self.n_labels * self.n_labels;
80 for (e_idx, &(u, v)) in self.edges.iter().enumerate() {
81 let lu = labels[u];
82 let lv = labels[v];
83 e += self.pairwise[e_idx * l2 + lu * self.n_labels + lv];
84 }
85 Ok(e)
86 }
87}
88
89#[derive(Debug, Clone)]
93pub struct IsingModel {
94 pub n_rows: usize,
95 pub n_cols: usize,
96 pub field: f64,
97 pub coupling: f64,
98 pub beta: f64,
99}
100
101impl IsingModel {
102 pub fn new(
104 n_rows: usize,
105 n_cols: usize,
106 field: f64,
107 coupling: f64,
108 beta: f64,
109 ) -> SeqResult<Self> {
110 if n_rows == 0 || n_cols == 0 {
111 return Err(SeqError::InvalidConfiguration(
112 "grid dims must be > 0".to_string(),
113 ));
114 }
115 if beta <= 0.0 || !beta.is_finite() {
116 return Err(SeqError::InvalidParameter {
117 name: "beta".to_string(),
118 value: beta,
119 });
120 }
121 Ok(Self {
122 n_rows,
123 n_cols,
124 field,
125 coupling,
126 beta,
127 })
128 }
129
130 pub fn energy(&self, spins: &[i32]) -> SeqResult<f64> {
132 if spins.len() != self.n_rows * self.n_cols {
133 return Err(SeqError::ShapeMismatch {
134 expected: self.n_rows * self.n_cols,
135 got: spins.len(),
136 });
137 }
138 let mut e = 0.0;
139 for r in 0..self.n_rows {
140 for c in 0..self.n_cols {
141 let s = spins[r * self.n_cols + c] as f64;
142 e -= self.field * s;
143 if r + 1 < self.n_rows {
144 let s2 = spins[(r + 1) * self.n_cols + c] as f64;
145 e -= self.coupling * s * s2;
146 }
147 if c + 1 < self.n_cols {
148 let s2 = spins[r * self.n_cols + (c + 1)] as f64;
149 e -= self.coupling * s * s2;
150 }
151 }
152 }
153 Ok(e)
154 }
155
156 pub fn magnetisation(&self, spins: &[i32]) -> SeqResult<f64> {
158 if spins.len() != self.n_rows * self.n_cols {
159 return Err(SeqError::ShapeMismatch {
160 expected: self.n_rows * self.n_cols,
161 got: spins.len(),
162 });
163 }
164 let s: i64 = spins.iter().map(|&x| x as i64).sum();
165 Ok(s as f64 / spins.len() as f64)
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two floats agree to within a tight absolute tolerance.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-12,
            "expected {expected}, got {actual}"
        );
    }

    /// Three-node chain `0 - 1 - 2` with two labels, shared by the MRF tests.
    fn chain_mrf() -> Mrf {
        Mrf::new(
            3,
            2,
            vec![(0, 1), (1, 2)],
            vec![0.0, 1.0, 1.0, 0.0, 0.0, 1.0],
            vec![0.0, 0.5, 0.5, 0.0, 0.0, 0.5, 0.5, 0.0],
        )
        .expect("valid chain MRF")
    }

    #[test]
    fn mrf_construct_and_energy() {
        let m = chain_mrf();
        // Unary 0.0 + 1.0 + 0.0; both edges take labels (0, 0) -> 0.0 each.
        assert_close(m.energy(&[0, 0, 0]).expect("ok"), 1.0);
        // Unary 1.0 + 1.0 + 1.0; edges take (1, 0) and (0, 1) -> 0.5 each.
        assert_close(m.energy(&[1, 0, 1]).expect("ok"), 4.0);
    }

    #[test]
    fn mrf_rejects_bad_inputs() {
        let m = chain_mrf();
        assert!(matches!(
            m.energy(&[0, 0]),
            Err(SeqError::ShapeMismatch { expected: 3, got: 2 })
        ));
        assert!(matches!(
            m.energy(&[0, 2, 0]),
            Err(SeqError::IndexOutOfBounds { index: 2, len: 2 })
        ));
        assert!(matches!(
            Mrf::new(2, 2, vec![(0, 0)], vec![0.0; 4], vec![0.0; 4]),
            Err(SeqError::GraphInvariantViolated(_))
        ));
        assert!(matches!(
            Mrf::new(2, 2, vec![], vec![0.0; 3], vec![]),
            Err(SeqError::ShapeMismatch { expected: 4, got: 3 })
        ));
    }

    #[test]
    fn ising_all_up_magnetisation_one() {
        let m = IsingModel::new(3, 3, 0.0, 1.0, 1.0).expect("ok");
        let spins = vec![1i32; 9];
        assert_close(m.magnetisation(&spins).expect("ok"), 1.0);
    }

    #[test]
    fn ising_energy_open_boundaries() {
        // 2x2 all up: 4 sites * (-0.5 field) + 4 bonds * (-1.0 coupling).
        let m = IsingModel::new(2, 2, 0.5, 1.0, 1.0).expect("ok");
        assert_close(m.energy(&[1; 4]).expect("ok"), -6.0);
        // Checkerboard: every bond is anti-aligned.
        let m = IsingModel::new(2, 2, 0.0, 1.0, 1.0).expect("ok");
        assert_close(m.energy(&[1, -1, -1, 1]).expect("ok"), 4.0);
        // A 1x3 chain has only two bonds (no wrap-around).
        let m = IsingModel::new(1, 3, 0.0, 1.0, 1.0).expect("ok");
        assert_close(m.energy(&[1, 1, 1]).expect("ok"), -2.0);
    }

    #[test]
    fn ising_rejects_bad_inputs() {
        assert!(matches!(
            IsingModel::new(0, 3, 0.0, 1.0, 1.0),
            Err(SeqError::InvalidConfiguration(_))
        ));
        assert!(matches!(
            IsingModel::new(3, 3, 0.0, 1.0, 0.0),
            Err(SeqError::InvalidParameter { .. })
        ));
        let m = IsingModel::new(2, 2, 0.0, 1.0, 1.0).expect("ok");
        assert!(matches!(
            m.energy(&[1; 3]),
            Err(SeqError::ShapeMismatch { expected: 4, got: 3 })
        ));
        assert_close(m.magnetisation(&[1, 1, -1, 1]).expect("ok"), 0.5);
    }
}