Skip to main content

uff_relax/
optimizer.rs

1use crate::forcefield::System;
2use glam::DVec3;
3
/// Optimizer for molecular structures using the FIRE (Fast Inertial Relaxation Engine) algorithm.
///
/// All fields are public configuration values; none of them is mutated by the optimizer itself.
#[derive(Debug, Clone, PartialEq)]
pub struct UffOptimizer {
    /// Maximum number of iterations to perform.
    pub max_iterations: usize,
    /// Threshold for the maximum force on any atom (kcal/mol/Å).
    pub force_threshold: f64,
    /// Whether to print optimization progress to stdout.
    pub verbose: bool,
    /// Number of threads to use. 0 means automatic based on system size.
    pub num_threads: usize,
    /// Cutoff distance for non-bonded interactions (Å).
    pub cutoff: f64,
    /// Number of steps to average for convergence criteria.
    pub history_size: usize,
}
19
20impl UffOptimizer {
21    /// Creates a new optimizer with default settings.
22    ///
23    /// # Arguments
24    /// * `max_iterations` - Maximum number of steps.
25    /// * `force_threshold` - Convergence threshold for forces.
26    pub fn new(max_iterations: usize, force_threshold: f64) -> Self {
27        Self {
28            max_iterations,
29            force_threshold,
30            verbose: false,
31            num_threads: 0,
32            cutoff: 6.0,
33            history_size: 10,
34        }
35    }
36
37    pub fn with_num_threads(mut self, num_threads: usize) -> Self {
38        self.num_threads = num_threads;
39        self
40    }
41
42    pub fn with_cutoff(mut self, cutoff: f64) -> Self {
43        self.cutoff = cutoff;
44        self
45    }
46
47    pub fn with_history_size(mut self, size: usize) -> Self {
48        self.history_size = size;
49        self
50    }
51
52    pub fn with_verbose(mut self, verbose: bool) -> Self {
53        self.verbose = verbose;
54        self
55    }
56
57    /// Optimized structural geometry using the FIRE algorithm.
58    pub fn optimize(&self, system: &mut System) {
59        let n = system.atoms.len();
60        let mut velocities = vec![DVec3::ZERO; n];
61        
62        let mut dt = 0.02;
63        let dt_max = 0.2;
64        let mut n_pos = 0;
65        let mut alpha = 0.1;
66        let alpha_start = 0.1;
67
68        // Convergence history
69        let mut fmax_history = std::collections::VecDeque::with_capacity(self.history_size);
70        let mut frms_history = std::collections::VecDeque::with_capacity(self.history_size);
71        let mut ediff_history = std::collections::VecDeque::with_capacity(self.history_size);
72        let mut last_energy: Option<f64> = None;
73        
74        let start_time = std::time::Instant::now();
75
76        if self.verbose {
77            let version_str = format!(" uff-relax v{} ", env!("CARGO_PKG_VERSION"));
78            println!("\n{:=^80}", version_str);
79            println!("{:<10} {:<10} | {:<10} {:<10}", "Atoms:", n, "Bonds:", system.bonds.len());
80            println!("{:<10} {:<10.1} | {:<10} {:<10.4} kcal/mol/Å", "Cutoff:", self.cutoff, "Threshold:", self.force_threshold);
81            println!("{:<10} {:<10} | {:<10} {:<10}", "Max Iter:", self.max_iterations, "Threads:", if self.num_threads == 0 { "Auto".to_string() } else { self.num_threads.to_string() });
82            println!("{:-<80}", "");
83            println!("{:<6} | {:<14} | {:<14} | {:<16} | {:<10}", "", "Fmax", "FRMS", "Total E", "");
84            println!("{:<6} | {:<14} | {:<14} | {:<16} | {:<10}", "Iter", "(kcal/mol/Å)", "(kcal/mol/Å)", "(kcal/mol)", "Status");
85            println!("{:-<80}", "");
86        }
87
88        let mut final_iter = 0;
89        let mut final_status = "Max-Iter";
90
91        for iter in 0..self.max_iterations {
92            final_iter = iter;
93            let energy = system.compute_forces_with_threads(self.num_threads, self.cutoff);
94            
95            // Calculate Fmax and FRMS
96            let mut max_f_sq: f64 = 0.0;
97            let mut sum_f_sq: f64 = 0.0;
98            for atom in &system.atoms {
99                let f_sq = atom.force.length_squared();
100                max_f_sq = f64::max(max_f_sq, f_sq);
101                sum_f_sq += f_sq;
102            }
103            let f_max = max_f_sq.sqrt();
104            let f_rms = (sum_f_sq / (3.0 * n as f64)).sqrt();
105
106            // Update history
107            if fmax_history.len() >= self.history_size { fmax_history.pop_front(); }
108            fmax_history.push_back(f_max);
109            
110            if frms_history.len() >= self.history_size { frms_history.pop_front(); }
111            frms_history.push_back(f_rms);
112
113            if let Some(prev_e) = last_energy {
114                if ediff_history.len() >= self.history_size { ediff_history.pop_front(); }
115                ediff_history.push_back((energy.total - prev_e).abs() / n as f64);
116            }
117            last_energy = Some(energy.total);
118
119            // Convergence Check
120            let mut converged = false;
121            let mut status = "";
122            if fmax_history.len() >= self.history_size {
123                let avg_fmax: f64 = fmax_history.iter().sum::<f64>() / self.history_size as f64;
124                let avg_frms: f64 = frms_history.iter().sum::<f64>() / self.history_size as f64;
125                let avg_ediff: f64 = if ediff_history.is_empty() { 1.0 } else { ediff_history.iter().sum::<f64>() / ediff_history.len() as f64 };
126
127                if avg_fmax < self.force_threshold {
128                    converged = true;
129                    status = "Fmax-Conv";
130                } else if avg_fmax < self.force_threshold * 2.0 && avg_frms < self.force_threshold * 0.5 {
131                    converged = true;
132                    status = "FRMS-Conv";
133                } else if !ediff_history.is_empty() && avg_ediff < 1e-7 {
134                    converged = true;
135                    status = "E-Stalled";
136                }
137            }
138            
139            if self.verbose && (iter % 10 == 0 || converged) {
140                println!("{:>6} | {:>14.4} | {:>14.4} | {:>16.4} | {:<10}", iter, f_max, f_rms, energy.total, status);
141            }
142
143            if converged {
144                final_status = status;
145                break;
146            }
147
148            // FIRE logic
149            let mut p = 0.0;
150            for i in 0..n {
151                p += velocities[i].dot(system.atoms[i].force);
152            }
153
154            for i in 0..n {
155                let f_norm = system.atoms[i].force.length();
156                if f_norm > 1e-9 {
157                    velocities[i] = (1.0 - alpha) * velocities[i] + alpha * (system.atoms[i].force / f_norm) * velocities[i].length();
158                }
159            }
160
161            if p > 0.0 {
162                n_pos += 1;
163                if n_pos > 5 {
164                    dt = f64::min(dt * 1.1, dt_max);
165                    alpha *= 0.99;
166                }
167            } else {
168                n_pos = 0;
169                dt *= 0.5;
170                alpha = alpha_start;
171                for v in &mut velocities {
172                    *v = DVec3::ZERO;
173                }
174            }
175
176            // Verlet integration (Simplified)
177            for i in 0..n {
178                velocities[i] += system.atoms[i].force * dt;
179                system.atoms[i].position += velocities[i] * dt;
180            }
181        }
182
183        if self.verbose {
184            let duration = start_time.elapsed();
185            let final_energy = system.compute_forces_with_threads(self.num_threads, self.cutoff);
186            println!("{:-<80}", "");
187            println!("=== Optimization Finished ===");
188            println!("Reason: {:<20}", final_status);
189            println!("Total Time: {:<10.3?} (Avg: {:.3?} / step)", duration, duration / (final_iter + 1) as u32);
190            println!("Final Energy: {:<15.4} kcal/mol", final_energy.total);
191            println!("Final Fmax:   {:<15.4} kcal/mol/Å", fmax_history.back().unwrap_or(&0.0));
192            println!("Final FRMS:   {:<15.4} kcal/mol/Å", frms_history.back().unwrap_or(&0.0));
193            println!("{:>80}", "(c) 2026 Forblaze Project");
194            println!("{:-<80}\n", "");
195        }
196    }
197}