Skip to main content

causal_triangulations/cdt/
metropolis.rs

1//! Metropolis-Hastings algorithm for Causal Dynamical Triangulations.
2//!
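//! This module implements the Monte Carlo sampling algorithm used to sample
//! triangulation configurations according to the CDT path integral measure.
//!
//! Note: ergodic moves are not yet applied for the trait-based geometry
//! backends, so `MetropolisAlgorithm::run` currently records the proposed
//! move types without modifying the triangulation.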
5
6use crate::cdt::action::ActionConfig;
7use crate::cdt::ergodic_moves::{ErgodicsSystem, MoveType};
8use crate::geometry::traits::TriangulationQuery;
9use num_traits::cast::NumCast;
10use std::time::Instant;
11
// Test triangulations are built with the backend-agnostic `CdtTriangulation`
// constructors (see the `tests` module at the bottom of this file).
13
/// Tunable parameters for a Metropolis-Hastings run.
///
/// All fields are public, so a configuration can be built with a struct
/// literal, with [`MetropolisConfig::new`], or starting from [`Default`].
#[derive(Debug, Clone)]
pub struct MetropolisConfig {
    /// Temperature `T`; the inverse temperature is β = 1/T.
    pub temperature: f64,
    /// Total number of Monte Carlo steps to run.
    pub steps: u32,
    /// Number of leading steps treated as thermalization before measurements
    /// are considered to be in equilibrium.
    pub thermalization_steps: u32,
    /// Measurement interval: one measurement every `measurement_frequency` steps.
    pub measurement_frequency: u32,
}
26
27impl Default for MetropolisConfig {
28    /// Default Metropolis configuration for 2D CDT.
29    fn default() -> Self {
30        Self {
31            temperature: 1.0,
32            steps: 1000,
33            thermalization_steps: 100,
34            measurement_frequency: 10,
35        }
36    }
37}
38
39impl MetropolisConfig {
40    /// Creates a new Metropolis configuration.
41    #[must_use]
42    pub const fn new(
43        temperature: f64,
44        steps: u32,
45        thermalization_steps: u32,
46        measurement_frequency: u32,
47    ) -> Self {
48        Self {
49            temperature,
50            steps,
51            thermalization_steps,
52            measurement_frequency,
53        }
54    }
55
56    /// Returns the inverse temperature (β = 1/T).
57    #[must_use]
58    pub fn beta(&self) -> f64 {
59        1.0 / self.temperature
60    }
61}
62
63/// Result of a Monte Carlo step.
64#[derive(Debug, Clone)]
65pub struct MonteCarloStep {
66    /// Step number
67    pub step: u32,
68    /// Move type attempted
69    pub move_type: MoveType,
70    /// Whether the move was accepted
71    pub accepted: bool,
72    /// Action before the move
73    pub action_before: f64,
74    /// Action after the move (if accepted)
75    pub action_after: Option<f64>,
76    /// Change in action (ΔS)
77    pub delta_action: Option<f64>,
78}
79
/// Snapshot of observables recorded during a simulation.
#[derive(Debug, Clone)]
pub struct Measurement {
    /// Monte Carlo step at which the snapshot was taken.
    pub step: u32,
    /// Action value at that step.
    pub action: f64,
    /// Vertex count of the triangulation.
    pub vertices: u32,
    /// Edge count of the triangulation.
    pub edges: u32,
    /// Triangle (face) count of the triangulation.
    pub triangles: u32,
}
94
95/// Metropolis-Hastings algorithm implementation for CDT.
96///
97/// This implementation works with both the legacy Tds-based approach
98/// and the new trait-based geometry backends.
99pub struct MetropolisAlgorithm {
100    /// Algorithm configuration
101    config: MetropolisConfig,
102    /// Action calculation configuration
103    action_config: ActionConfig,
104    /// Ergodic moves system
105    ergodics: ErgodicsSystem,
106}
107
108impl MetropolisAlgorithm {
109    /// Creates a new Metropolis algorithm instance.
110    #[must_use]
111    pub fn new(config: MetropolisConfig, action_config: ActionConfig) -> Self {
112        Self {
113            config,
114            action_config,
115            ergodics: ErgodicsSystem::new(),
116        }
117    }
118
119    /// Run the Monte Carlo simulation.
120    ///
121    /// This runs the Metropolis-Hastings algorithm on the given triangulation.
122    pub fn run(
123        &mut self,
124        triangulation: crate::geometry::CdtTriangulation2D,
125    ) -> SimulationResultsBackend {
126        let start_time = Instant::now();
127        let mut steps = Vec::new();
128        let mut measurements = Vec::new();
129
130        log::info!("Starting Metropolis-Hastings simulation with new backend...");
131        log::info!("Temperature: {}", self.config.temperature);
132        log::info!("Total steps: {}", self.config.steps);
133        log::info!("Thermalization steps: {}", self.config.thermalization_steps);
134
135        // Calculate initial action
136        let geometry = triangulation.geometry();
137        let current_action = self.action_config.calculate_action(
138            u32::try_from(geometry.vertex_count()).unwrap_or_default(),
139            u32::try_from(geometry.edge_count()).unwrap_or_default(),
140            u32::try_from(geometry.face_count()).unwrap_or_default(),
141        );
142
143        for step_num in 0..self.config.steps {
144            // For now, just simulate the step without actual moves
145            // TODO: Implement ergodic moves for trait-based backends
146            let move_type = self.ergodics.select_random_move();
147
148            let mc_step = MonteCarloStep {
149                step: step_num,
150                move_type,
151                accepted: false,
152                action_before: current_action,
153                action_after: None,
154                delta_action: None,
155            };
156
157            steps.push(mc_step);
158
159            // Take measurement if needed
160            if step_num % self.config.measurement_frequency == 0 {
161                let measurement = Measurement {
162                    step: step_num,
163                    action: current_action,
164                    vertices: u32::try_from(geometry.vertex_count()).unwrap_or_default(),
165                    edges: u32::try_from(geometry.edge_count()).unwrap_or_default(),
166                    triangles: u32::try_from(geometry.face_count()).unwrap_or_default(),
167                };
168                measurements.push(measurement);
169            }
170
171            // Progress reporting
172            if step_num % 100 == 0 {
173                log::debug!(
174                    "Step {}/{}, Action: {:.3}",
175                    step_num,
176                    self.config.steps,
177                    current_action
178                );
179            }
180        }
181
182        let elapsed_time = start_time.elapsed();
183        log::info!("Simulation completed in {elapsed_time:.2?}");
184
185        SimulationResultsBackend {
186            config: self.config.clone(),
187            action_config: self.action_config.clone(),
188            steps,
189            measurements,
190            elapsed_time,
191            triangulation,
192        }
193    }
194}
195
196/// Results from a simulation using the new backend system.
197#[derive(Debug)]
198pub struct SimulationResultsBackend {
199    /// Configuration used for the simulation
200    pub config: MetropolisConfig,
201    /// Action configuration used
202    pub action_config: ActionConfig,
203    /// All Monte Carlo steps performed
204    pub steps: Vec<MonteCarloStep>,
205    /// Measurements taken during simulation
206    pub measurements: Vec<Measurement>,
207    /// Total simulation time
208    pub elapsed_time: std::time::Duration,
209    /// Final triangulation state
210    pub triangulation: crate::geometry::CdtTriangulation2D,
211}
212
213impl SimulationResultsBackend {
214    /// Calculates the acceptance rate for the simulation.
215    #[must_use]
216    pub fn acceptance_rate(&self) -> f64 {
217        if self.steps.is_empty() {
218            return 0.0;
219        }
220
221        let accepted_count = self.steps.iter().filter(|step| step.accepted).count();
222        let total_count = self.steps.len();
223
224        let accepted_f64 = NumCast::from(accepted_count).unwrap_or(0.0);
225        let total_f64 = NumCast::from(total_count).unwrap_or(1.0);
226
227        accepted_f64 / total_f64
228    }
229
230    /// Calculates the average action over all measurements.
231    #[must_use]
232    pub fn average_action(&self) -> f64 {
233        if self.measurements.is_empty() {
234            return 0.0;
235        }
236
237        let sum: f64 = self.measurements.iter().map(|m| m.action).sum();
238        let count = self.measurements.len();
239
240        let count_f64 = NumCast::from(count).unwrap_or(1.0);
241
242        sum / count_f64
243    }
244
245    /// Returns measurements after thermalization.
246    #[must_use]
247    pub fn equilibrium_measurements(&self) -> Vec<&Measurement> {
248        self.measurements
249            .iter()
250            .filter(|m| m.step >= self.config.thermalization_steps)
251            .collect()
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cdt::triangulation::CdtTriangulation;
    use crate::geometry::traits::TriangulationQuery;
    use approx::assert_relative_eq;

    /// Fixed seed that yields a valid 5-vertex, single-slice 2D triangulation, so
    /// tests using it are reproducible.
    const TRIANGULATION_SEED: u64 = 53;

    /// Builds a measurement whose simplex counts are irrelevant to the test.
    fn measurement(step: u32, action: f64) -> Measurement {
        Measurement {
            step,
            action,
            vertices: 3,
            edges: 3,
            triangles: 1,
        }
    }

    /// Wraps `measurements` in results using the default configuration
    /// (100 thermalization steps) and no recorded Monte Carlo steps.
    fn results_with_measurements(measurements: Vec<Measurement>) -> SimulationResultsBackend {
        let triangulation =
            CdtTriangulation::from_random_points(3, 1, 2).expect("Failed to create triangulation");

        SimulationResultsBackend {
            config: MetropolisConfig::default(),
            action_config: ActionConfig::default(),
            steps: vec![],
            measurements,
            elapsed_time: std::time::Duration::from_millis(100),
            triangulation,
        }
    }

    #[test]
    fn test_metropolis_config() {
        let config = MetropolisConfig::new(2.0, 500, 50, 5);
        assert_relative_eq!(config.temperature, 2.0);
        assert_relative_eq!(config.beta(), 0.5);
        assert_eq!(config.steps, 500);
    }

    #[test]
    fn test_metropolis_config_default() {
        let config = MetropolisConfig::default();
        assert_relative_eq!(config.temperature, 1.0);
        assert_relative_eq!(config.beta(), 1.0);
        assert_eq!(config.steps, 1000);
        assert_eq!(config.thermalization_steps, 100);
        assert_eq!(config.measurement_frequency, 10);
    }

    #[test]
    fn test_backend_vertex_and_edge_counting() {
        let triangulation = CdtTriangulation::from_seeded_points(5, 1, 2, TRIANGULATION_SEED)
            .expect("Failed to create triangulation with fixed seed");
        let geometry = triangulation.geometry();

        // We intentionally do NOT rely on the upstream deep validation here, since it can be flaky
        // for some generated point sets. Backend-level validity means the triangulation is
        // structurally usable by this crate (counts and iterators behave as expected).
        assert!(
            geometry.is_valid(),
            "Triangulation should be structurally valid for backend queries"
        );

        // Ensure the backend exposes the expected simplex counts.
        assert_eq!(
            geometry.vertex_count(),
            5,
            "Vertex count should match requested seeded generation"
        );
        assert!(geometry.edge_count() > 0, "Should have edges");
        assert!(geometry.face_count() > 0, "Should have faces");
    }

    #[test]
    fn test_action_calculation() {
        // Seeded so a failure is reproducible instead of depending on random points.
        let triangulation = CdtTriangulation::from_seeded_points(5, 1, 2, TRIANGULATION_SEED)
            .expect("Failed to create triangulation with fixed seed");

        let config = MetropolisConfig::default();
        let action_config = ActionConfig::default();
        let _algorithm = MetropolisAlgorithm::new(config, action_config.clone());

        let geometry = triangulation.geometry();
        let action = action_config.calculate_action(
            u32::try_from(geometry.vertex_count()).unwrap_or_default(),
            u32::try_from(geometry.edge_count()).unwrap_or_default(),
            u32::try_from(geometry.face_count()).unwrap_or_default(),
        );

        assert!(action.is_finite());
    }

    #[test]
    fn test_simulation_results() {
        let results = results_with_measurements(vec![measurement(0, 1.0), measurement(10, 2.0)]);

        assert_relative_eq!(results.average_action(), 1.5);
    }

    #[test]
    fn test_empty_results_report_zero() {
        let results = results_with_measurements(vec![]);

        assert_relative_eq!(results.acceptance_rate(), 0.0);
        assert_relative_eq!(results.average_action(), 0.0);
        assert!(results.equilibrium_measurements().is_empty());
    }

    #[test]
    fn test_equilibrium_measurements_skip_thermalization() {
        // Default configuration thermalizes for 100 steps; step 100 is the first kept.
        let results = results_with_measurements(vec![
            measurement(0, 1.0),
            measurement(99, 2.0),
            measurement(100, 3.0),
            measurement(150, 4.0),
        ]);

        let kept: Vec<u32> = results
            .equilibrium_measurements()
            .iter()
            .map(|m| m.step)
            .collect();
        assert_eq!(kept, vec![100, 150]);
    }
}