Skip to main content

uff_relax/
optimizer.rs

1use tracing;
2use crate::forcefield::System;
3use crate::forcefield::interactions::get_bond_equilibrium_dist;
4use glam::DVec3;
5use web_time::Instant;
6use std::collections::VecDeque;
7
/// Per-iteration progress callback, invoked as `hook(iteration, f_max, total_energy)`.
///
/// `f_max` is the largest atomic force norm of that iteration (kcal/mol/Å) and
/// `total_energy` the total energy reported by the force routine (kcal/mol).
pub type StepHook = dyn Fn(usize, f64, f64) + Sync + Send;
10
11/// Optimizer for molecular structures using the FIRE (Fast Iterative Relaxation Engine) algorithm.
12pub struct UffOptimizer {
13    /// Maximum number of iterations to perform.
14    pub max_iterations: usize,
15    /// Threshold for the maximum force on any atom (kcal/mol/Å).
16    pub force_threshold: f64,
17    /// Whether to print optimization progress to tracing logs.
18    pub verbose: bool,
19    /// Number of threads to use. 0 means automatic based on system size.
20    pub num_threads: usize,
21    /// Cutoff distance for non-bonded interactions (Å).
22    pub cutoff: f64,
23    /// Number of steps to average for convergence criteria.
24    pub history_size: usize,
25    /// Maximum distance an atom can move in a single step (Å).
26    pub max_displacement: f64,
27    /// Optional hook called after each iteration.
28    pub step_hook: Option<std::sync::Arc<StepHook>>,
29    /// Optional flag to cancel optimization from another thread/context.
30    pub cancel_flag: Option<std::sync::Arc<std::sync::atomic::AtomicBool>>,
31}
32
33impl UffOptimizer {
34    /// Creates a new optimizer with default settings.
35    pub fn new(max_iterations: usize, force_threshold: f64) -> Self {
36        Self {
37            max_iterations,
38            force_threshold,
39            verbose: false,
40            num_threads: 0,
41            cutoff: 6.0,
42            history_size: 10,
43            max_displacement: 0.2,
44            step_hook: None,
45            cancel_flag: None,
46        }
47    }
48
49    pub fn with_max_displacement(mut self, max: f64) -> Self {
50        self.max_displacement = max;
51        self
52    }
53
54    pub fn with_num_threads(mut self, num_threads: usize) -> Self {
55        self.num_threads = num_threads;
56        self
57    }
58
59    pub fn with_cutoff(mut self, cutoff: f64) -> Self {
60        self.cutoff = cutoff;
61        self
62    }
63
64    pub fn with_history_size(mut self, size: usize) -> Self {
65        self.history_size = size;
66        self
67    }
68
69    pub fn with_verbose(mut self, verbose: bool) -> Self {
70        self.verbose = verbose;
71        self
72    }
73
74    pub fn with_step_hook<F>(mut self, hook: F) -> Self 
75    where F: Fn(usize, f64, f64) + Send + Sync + 'static {
76        self.step_hook = Some(std::sync::Arc::new(hook));
77        self
78    }
79
80    pub fn with_cancel_flag(mut self, flag: std::sync::Arc<std::sync::atomic::AtomicBool>) -> Self {
81        self.cancel_flag = Some(flag);
82        self
83    }
84
85    /// Optimized structural geometry using the FIRE algorithm.
86    pub fn optimize(&self, system: &mut System) {
87        let n = system.atoms.len();
88        if n == 0 { return; }
89        
90        // Initial wrap only if periodic boundary conditions exist
91        if !matches!(system.cell.cell_type, crate::cell::CellType::None) {
92            for atom in &mut system.atoms {
93                atom.position = system.cell.wrap_vector(atom.position);
94            }
95        }
96
97        let mut velocities = vec![DVec3::ZERO; n];
98        let mut dt = 0.01;
99        let dt_max = 0.05;
100        let mut n_pos = 0;
101        let mut alpha = 0.15;
102        let alpha_start = 0.15;
103
104        let mut fmax_history = VecDeque::with_capacity(self.history_size);
105        let mut frms_history = VecDeque::with_capacity(self.history_size);
106        let mut ediff_history = VecDeque::with_capacity(self.history_size);
107        let mut last_energy: Option<f64> = None;
108        
109        let start_time = Instant::now();
110
111        if self.verbose {
112            self.print_header(system);
113        }
114
115        let mut final_iter = 0;
116        let mut final_status = "Max-Iter";
117
118        for iter in 0..self.max_iterations {
119            final_iter = iter;
120
121            // Check for cancellation
122            if let Some(ref cancel) = self.cancel_flag {
123                if cancel.load(std::sync::atomic::Ordering::SeqCst) {
124                    final_status = "Cancelled";
125                    break;
126                }
127            }
128
129            #[cfg(target_arch = "wasm32")]
130            let energy = system.compute_forces_with_threads(1, self.cutoff);
131            #[cfg(not(target_arch = "wasm32"))]
132            let energy = system.compute_forces_with_threads(self.num_threads, self.cutoff);
133            
134            let (f_max, f_rms) = self.calculate_force_metrics(system);
135
136            // Update history
137            if fmax_history.len() >= self.history_size { fmax_history.pop_front(); }
138            fmax_history.push_back(f_max);
139            if frms_history.len() >= self.history_size { frms_history.pop_front(); }
140            frms_history.push_back(f_rms);
141            if let Some(prev_e) = last_energy {
142                if ediff_history.len() >= self.history_size { ediff_history.pop_front(); }
143                ediff_history.push_back((energy.total - prev_e).abs() / n as f64);
144            }
145            last_energy = Some(energy.total);
146
147            // Convergence Check
148            let (converged, status) = self.check_convergence(f_max, f_rms, &fmax_history, &frms_history, &ediff_history);
149            
150            if self.verbose && (iter % 10 == 0 || converged) {
151                if energy.total.abs() >= 1e10 {
152                    tracing::info!("{:>6} | {:>14.4} | {:>14.4} | {:>16.4e} | {:<10}", iter, f_max, f_rms, energy.total, status);
153                } else {
154                    tracing::info!("{:>6} | {:>14.4} | {:>14.4} | {:>16.4} | {:<10}", iter, f_max, f_rms, energy.total, status);
155                }
156            }
157
158            if let Some(ref hook) = self.step_hook {
159                hook(iter, f_max, energy.total);
160            }
161
162            if converged {
163                final_status = status;
164                break;
165            }
166
167            self.fire_update(system, &mut velocities, &mut dt, dt_max, &mut n_pos, &mut alpha, alpha_start);
168        }
169
170        if self.verbose {
171            self.print_footer(system, final_status, start_time, final_iter, &fmax_history, &frms_history);
172        }
173    }
174
175    /// Asynchronous version of the optimizer for non-blocking environments (Wasm/UIs).
176    pub async fn optimize_async(&self, system: &mut System) {
177        let n = system.atoms.len();
178        if n == 0 { return; }
179        
180        if !matches!(system.cell.cell_type, crate::cell::CellType::None) {
181            for atom in &mut system.atoms {
182                atom.position = system.cell.wrap_vector(atom.position);
183            }
184        }
185
186        let mut velocities = vec![DVec3::ZERO; n];
187        let mut dt = 0.01;
188        let dt_max = 0.05;
189        let mut n_pos = 0;
190        let mut alpha = 0.15;
191        let alpha_start = 0.15;
192
193        let mut fmax_history = VecDeque::with_capacity(self.history_size);
194        let mut frms_history = VecDeque::with_capacity(self.history_size);
195        let mut ediff_history = VecDeque::with_capacity(self.history_size);
196        let mut last_energy: Option<f64> = None;
197        
198        let start_time = Instant::now();
199
200        if self.verbose {
201            self.print_header(system);
202        }
203
204        let mut final_iter = 0;
205        let mut final_status = "Max-Iter";
206
207        for iter in 0..self.max_iterations {
208            final_iter = iter;
209
210            // Check for cancellation
211            if let Some(ref cancel) = self.cancel_flag {
212                if cancel.load(std::sync::atomic::Ordering::SeqCst) {
213                    final_status = "Cancelled";
214                    break;
215                }
216            }
217            
218            // In Wasm, num_threads should be 1 as Rayon is not supported easily
219            #[cfg(target_arch = "wasm32")]
220            let energy = system.compute_forces_with_threads(1, self.cutoff);
221            #[cfg(not(target_arch = "wasm32"))]
222            let energy = system.compute_forces_with_threads(self.num_threads, self.cutoff);
223            
224            let (f_max, f_rms) = self.calculate_force_metrics(system);
225
226            if fmax_history.len() >= self.history_size { fmax_history.pop_front(); }
227            fmax_history.push_back(f_max);
228            if frms_history.len() >= self.history_size { frms_history.pop_front(); }
229            frms_history.push_back(f_rms);
230            if let Some(prev_e) = last_energy {
231                if ediff_history.len() >= self.history_size { ediff_history.pop_front(); }
232                ediff_history.push_back((energy.total - prev_e).abs() / n as f64);
233            }
234            last_energy = Some(energy.total);
235
236            let (converged, status) = self.check_convergence(f_max, f_rms, &fmax_history, &frms_history, &ediff_history);
237            
238            if self.verbose && (iter % 10 == 0 || converged) {
239                if energy.total.abs() >= 1e10 {
240                    tracing::info!("{:>6} | {:>14.4} | {:>14.4} | {:>16.4e} | {:<10}", iter, f_max, f_rms, energy.total, status);
241                } else {
242                    tracing::info!("{:>6} | {:>14.4} | {:>14.4} | {:>16.4} | {:<10}", iter, f_max, f_rms, energy.total, status);
243                }
244            }
245
246            if let Some(ref hook) = self.step_hook {
247                hook(iter, f_max, energy.total);
248            }
249
250            if converged {
251                final_status = status;
252                break;
253            }
254
255            self.fire_update(system, &mut velocities, &mut dt, dt_max, &mut n_pos, &mut alpha, alpha_start);
256
257            // Yield control back to the environment periodically
258            if iter % 5 == 0 {
259                self.yield_now().await;
260            }
261        }
262
263        if self.verbose {
264            self.print_footer(system, final_status, start_time, final_iter, &fmax_history, &frms_history);
265        }
266    }
267
268    async fn yield_now(&self) {
269        #[cfg(feature = "wasm")]
270        {
271            let promise = js_sys::Promise::new(&mut |resolve, _| {
272                if let Some(window) = web_sys::window() {
273                    window.set_timeout_with_callback_and_timeout_and_arguments_0(&resolve, 0).unwrap();
274                }
275            });
276            let _ = wasm_bindgen_futures::JsFuture::from(promise).await;
277        }
278    }
279
280    fn calculate_force_metrics(&self, system: &System) -> (f64, f64) {
281        let n = system.atoms.len();
282        let mut max_f_sq: f64 = 0.0;
283        let mut sum_f_sq: f64 = 0.0;
284        for atom in &system.atoms {
285            let f_sq = atom.force.length_squared();
286            max_f_sq = f64::max(max_f_sq, f_sq);
287            sum_f_sq += f_sq;
288        }
289        (max_f_sq.sqrt(), (sum_f_sq / (3.0 * n as f64)).sqrt())
290    }
291
292    fn check_convergence(&self, _f_max: f64, _f_rms: f64, fmax_hist: &VecDeque<f64>, frms_hist: &VecDeque<f64>, ediff_hist: &VecDeque<f64>) -> (bool, &'static str) {
293        if fmax_hist.len() < self.history_size {
294            return (false, "");
295        }
296        let avg_fmax: f64 = fmax_hist.iter().sum::<f64>() / self.history_size as f64;
297        let avg_frms: f64 = frms_hist.iter().sum::<f64>() / self.history_size as f64;
298        let avg_ediff: f64 = if ediff_hist.is_empty() { 1.0 } else { ediff_hist.iter().sum::<f64>() / ediff_hist.len() as f64 };
299
300        if avg_fmax < self.force_threshold {
301            (true, "Fmax-Conv")
302        } else if avg_fmax < self.force_threshold * 2.0 && avg_frms < self.force_threshold * 0.5 {
303            (true, "FRMS-Conv")
304        } else if !ediff_hist.is_empty() && avg_ediff < 1e-7 {
305            (true, "E-Stalled")
306        } else {
307            (false, "")
308        }
309    }
310
311    fn fire_update(&self, system: &mut System, velocities: &mut [DVec3], dt: &mut f64, dt_max: f64, n_pos: &mut usize, alpha: &mut f64, alpha_start: f64) {
312        let n = system.atoms.len();
313        let mut p = 0.0;
314        for i in 0..n {
315            p += velocities[i].dot(system.atoms[i].force);
316        }
317
318        for i in 0..n {
319            let f_norm = system.atoms[i].force.length();
320            let v_norm = velocities[i].length();
321            if f_norm > 1e-9 {
322                velocities[i] = (1.0 - *alpha) * velocities[i] + *alpha * (system.atoms[i].force / f_norm) * v_norm;
323            }
324        }
325
326        if p > 0.0 {
327            *n_pos += 1;
328            if *n_pos > 5 {
329                *dt = f64::min(*dt * 1.05, dt_max);
330                *alpha *= 0.99;
331            }
332        } else {
333            *n_pos = 0;
334            *dt *= 0.5;
335            *alpha = alpha_start;
336            for v in velocities.iter_mut() {
337                *v = DVec3::ZERO;
338            }
339        }
340
341        for i in 0..n {
342            velocities[i] += system.atoms[i].force * (*dt);
343            let mut move_vec = velocities[i] * (*dt);
344            let move_len = move_vec.length();
345            if move_len > self.max_displacement {
346                move_vec *= self.max_displacement / move_len;
347                velocities[i] = move_vec / (*dt);
348            }
349            let new_pos = system.atoms[i].position + move_vec;
350            system.atoms[i].position = system.cell.wrap_vector(new_pos);
351        }
352    }
353
354    fn print_header(&self, system: &System) {
355        let n_atoms = system.atoms.len();
356        let n_bonds = system.bonds.len();
357        let has_charges = system.atoms.iter().any(|a| a.charge.abs() > 1e-12);
358        
359        // Determine actual threads used
360        #[cfg(target_arch = "wasm32")]
361        let actual_threads = 1;
362        
363        #[cfg(not(target_arch = "wasm32"))]
364        let actual_threads = if self.num_threads == 1 {
365            1
366        } else if self.num_threads > 1 {
367            self.num_threads
368        } else if n_atoms >= 1000 { // PARALLEL_THRESHOLD
369            std::env::var("RAYON_NUM_THREADS")
370                .ok()
371                .and_then(|s| s.parse().ok())
372                .unwrap_or(4)
373        } else {
374            1
375        };
376
377        let version_str = format!(" uff-relax v{} ", env!("CARGO_PKG_VERSION"));
378        tracing::info!("\n{:=^80}", version_str);
379        tracing::info!("{:<10} {:<10} | {:<10} {:<10}", "Atoms:", n_atoms, "Bonds:", n_bonds);
380        tracing::info!("{:<10} {:<10.1} | {:<10} {:<10.4} kcal/mol/Å", "Cutoff:", self.cutoff, "Threshold:", self.force_threshold);
381        tracing::info!("{:<10} {:<10} | {:<10} {:<10}", 
382            "Threads:", actual_threads, 
383            "Charges:", if has_charges { "Active (Wolf)" } else { "Inactive" }
384        );
385        tracing::info!("{:<10} {:<10} | {:<10} {:<10}", "Max Iter:", self.max_iterations, "", "");
386        tracing::info!("{:-<80}", "");
387        tracing::info!("{:<6} | {:<14} | {:<14} | {:<16} | {:<10}", "", "Fmax", "FRMS", "Total E", "");
388        tracing::info!("{:<6} | {:<14} | {:<14} | {:<16} | {:<10}", "Iter", "(kcal/mol/Å)", "(kcal/mol/Å)", "(kcal/mol)", "Status");
389        tracing::info!("{:-<80}", "");
390    }
391
392    fn print_footer(&self, system: &mut System, final_status: &str, start_time: Instant, final_iter: usize, fmax_hist: &VecDeque<f64>, frms_hist: &VecDeque<f64>) {
393        let n = system.atoms.len();
394        let duration = start_time.elapsed();
395        let final_energy = system.compute_forces_with_threads(self.num_threads, self.cutoff);
396
397        let mut min_dist = f64::MAX;
398        let mut min_pair = (0, 0);
399        for i in 0..n {
400            for j in i + 1..n {
401                let d = system.cell.distance_vector(system.atoms[i].position, system.atoms[j].position).length();
402                if d < min_dist { 
403                    min_dist = d;
404                    min_pair = (i, j);
405                }
406            }
407        }
408
409        // Abnormal bond detection (Mechanical entanglement check)
410        let mut abnormal_bonds = Vec::new();
411        for bond in &system.bonds {
412            let (i, j) = bond.atom_indices;
413            let current_dist = system.cell.distance_vector(system.atoms[i].position, system.atoms[j].position).length();
414            if let Some(r0) = get_bond_equilibrium_dist(&system.atoms[i].uff_type, &system.atoms[j].uff_type, bond.order) {
415                if current_dist > r0 * 1.3 {
416                    abnormal_bonds.push((i, j, current_dist, r0));
417                }
418            }
419        }
420
421        tracing::info!("{:-<80}", "");
422        tracing::info!("=== Optimization Finished ===");
423        tracing::info!("Reason: {:<20}", final_status);
424        tracing::info!("Total Time: {:<10.3?} (Avg: {:.3?} / step)", duration, duration / (final_iter + 1) as u32);
425        if final_energy.total.abs() >= 1e10 {
426            tracing::info!("Final Energy: {:<15.4e} kcal/mol", final_energy.total);
427        } else {
428            tracing::info!("Final Energy: {:<15.4} kcal/mol", final_energy.total);
429        }
430        tracing::info!("Final Fmax:   {:<15.4} kcal/mol/Å", fmax_hist.back().unwrap_or(&0.0));
431        tracing::info!("Final FRMS:   {:<15.4} kcal/mol/Å", frms_hist.back().unwrap_or(&0.0));
432        tracing::info!("Min Distance: {:<15.4} Å (Atoms {} and {})", min_dist, min_pair.0 + 1, min_pair.1 + 1);
433
434        if !abnormal_bonds.is_empty() {
435            tracing::warn!("!!! ABNORMAL BONDS DETECTED ({} total) !!!", abnormal_bonds.len());
436            for (i, j, dist, r0) in abnormal_bonds.iter().take(3) {
437                tracing::warn!("  Bond {}-{} : Length {:.4} Å (Equiv: {:.4} Å, Dev: {:.1}%)", 
438                    i+1, j+1, dist, r0, (dist/r0 - 1.0)*100.0);
439            }
440            if abnormal_bonds.len() > 3 {
441                tracing::warn!("  ... and {} more abnormal bonds.", abnormal_bonds.len() - 3);
442            }
443        }
444
445        tracing::info!("{:>80}", "(c) 2026 Forblaze Project");
446        tracing::info!("{:-<80}\n", "");
447    }
448}