differential_equations/solout/
crossing.rs

1//! Crossing detection solout for detecting when state components cross threshold values.
2//!
3//! This module provides functionality for detecting and recording when a specific state
4//! component crosses a defined threshold value during integration.
5
6use super::*;
7
8/// A solout that detects when a component crosses a specified threshold value.
9///
10/// # Overview
11///
12/// `CrossingSolout` monitors a specific component of the state vector and detects when
13/// it crosses a defined threshold value. This is useful for identifying important events
14/// in the system's behavior, such as:
15///
16/// - Zero-crossings (by setting threshold to 0)
17/// - Detecting when a variable exceeds or falls below a critical value
18/// - Generating data for poincare sections or other analyses
19///
20/// The solout records the times and states when crossings occur, making them available
21/// in the solver output.
22///
23/// # Example
24///
25/// ```
26/// use differential_equations::prelude::*;
27/// use differential_equations::solout::CrossingSolout;
28/// use nalgebra::{Vector2, vector};
29///
30/// // Simple harmonic oscillator - position will cross zero periodically
31/// struct HarmonicOscillator;
32///
33/// impl ODE<f64, Vector2<f64>> for HarmonicOscillator {
34///     fn diff(&self, _t: f64, y: &Vector2<f64>, dydt: &mut Vector2<f64>) {
35///         // y[0] = position, y[1] = velocity
36///         dydt[0] = y[1];
37///         dydt[1] = -y[0];
38///     }
39/// }
40///
41/// // Create the system and solver
42/// let system = HarmonicOscillator;
43/// let t0 = 0.0;
44/// let tf = 10.0;
45/// let y0 = vector![1.0, 0.0]; // Start with positive position, zero velocity
46/// let mut solver = ExplicitRungeKutta::dop853().rtol(1e-8).atol(1e-8);
47///
48/// // Detect zero-crossings of the position component (index 0)
49/// let mut crossing_detector = CrossingSolout::new(0, 0.0);
50///
51/// // Solve and get only the crossing points
52/// let problem = ODEProblem::new(system, t0, tf, y0);
53/// let solution = problem.solout(&mut crossing_detector).solve(&mut solver).unwrap();
54///
55/// // solution now contains only the points where position crosses zero
56/// println!("Zero crossings occurred at times: {:?}", solution.t);
57/// ```
58///
59/// # Directional Crossing Detection
60///
61/// You can filter the crossings by direction:
62///
63/// ```
64/// use differential_equations::solout::{CrossingSolout, CrossingDirection};
65///
66/// // Only detect positive crossings (from below to above threshold)
67/// let positive_crossings = CrossingSolout::new(0, 5.0).with_direction(CrossingDirection::Positive);
68///
69/// // Only detect negative crossings (from above to below threshold)
70/// let negative_crossings = CrossingSolout::new(0, 5.0).with_direction(CrossingDirection::Negative);
71/// ```
72pub struct CrossingSolout<T: Real> {
73    /// Index of the component to monitor
74    component_idx: usize,
75    /// Threshold value to detect crossings against
76    threshold: T,
77    /// Last observed value minus threshold (for detecting sign changes)
78    last_offset_value: Option<T>,
79    /// Direction of crossing to detect
80    direction: CrossingDirection,
81}
82
83impl<T: Real> CrossingSolout<T> {
84    /// Creates a new CrossingSolout to detect when the specified component crosses the threshold.
85    ///
86    /// By default, crossings in both directions are detected.
87    ///
88    /// # Arguments
89    /// * `component_idx` - Index of the component in the state vector to monitor
90    /// * `threshold` - The threshold value to detect crossings against
91    ///
92    /// # Example
93    ///
94    /// ```
95    /// use differential_equations::solout::CrossingSolout;
96    ///
97    /// // Detect when the first component (index 0) crosses the value 5.0
98    /// let detector = CrossingSolout::new(0, 5.0);
99    /// ```
100    pub fn new(component_idx: usize, threshold: T) -> Self {
101        CrossingSolout {
102            component_idx,
103            threshold,
104            last_offset_value: None,
105            direction: CrossingDirection::Both,
106        }
107    }
108
109    /// Set the direction of crossings to detect.
110    ///
111    /// # Arguments
112    /// * `direction` - The crossing direction to detect (Both, Positive, or Negative)
113    ///
114    /// # Returns
115    /// * `Self` - The modified CrossingSolout (builder pattern)
116    ///
117    /// # Example
118    ///
119    /// ```
120    /// use differential_equations::solout::{CrossingSolout, CrossingDirection};
121    ///
122    /// // Detect when the position (index 0) crosses zero in any direction
123    /// let any_crossing = CrossingSolout::new(0, 0.0).with_direction(CrossingDirection::Both);
124    ///
125    /// // Detect when the position (index 0) goes from negative to positive
126    /// let zero_up_detector = CrossingSolout::new(0, 0.0).with_direction(CrossingDirection::Positive);
127    ///
128    /// // Detect when the velocity (index 1) changes from positive to negative
129    /// let velocity_sign_change = CrossingSolout::new(1, 0.0).with_direction(CrossingDirection::Negative);
130    /// ```
131    pub fn with_direction(mut self, direction: CrossingDirection) -> Self {
132        self.direction = direction;
133        self
134    }
135
136    /// Set to detect only positive crossings (from below to above threshold).
137    ///
138    /// A positive crossing occurs when the monitored component transitions from
139    /// a value less than the threshold to a value greater than or equal to the threshold.
140    ///
141    /// # Returns
142    /// * `Self` - The modified CrossingSolout (builder pattern)
143    ///
144    /// # Example
145    ///
146    /// ```
147    /// use differential_equations::solout::CrossingSolout;
148    ///
149    /// // Detect when the position (index 0) goes from negative to positive
150    /// let zero_up_detector = CrossingSolout::new(0, 0.0).positive_only();
151    /// ```
152    pub fn positive_only(mut self) -> Self {
153        self.direction = CrossingDirection::Positive;
154        self
155    }
156
157    /// Set to detect only negative crossings (from above to below threshold).
158    ///
159    /// A negative crossing occurs when the monitored component transitions from
160    /// a value greater than the threshold to a value less than or equal to the threshold.
161    ///
162    /// # Returns
163    /// * `Self` - The modified CrossingSolout (builder pattern)
164    ///
165    /// # Example
166    ///
167    /// ```
168    /// use differential_equations::solout::CrossingSolout;
169    ///
170    /// // Detect when the velocity (index 1) changes from positive to negative
171    /// let velocity_sign_change = CrossingSolout::new(1, 0.0).negative_only();
172    /// ```
173    pub fn negative_only(mut self) -> Self {
174        self.direction = CrossingDirection::Negative;
175        self
176    }
177}
178
179impl<T, Y> Solout<T, Y> for CrossingSolout<T>
180where
181    T: Real,
182    Y: State<T>,
183{
184    fn solout<I>(
185        &mut self,
186        t_curr: T,
187        t_prev: T,
188        y_curr: &Y,
189        _y_prev: &Y,
190        interpolator: &mut I,
191        solution: &mut Solution<T, Y>,
192    ) -> ControlFlag<T, Y>
193    where
194        I: Interpolation<T, Y>,
195    {
196        // Calculate the offset from threshold (to detect zero-crossing)
197        let current_value = y_curr.get(self.component_idx);
198        let offset_value = current_value - self.threshold;
199
200        // If we have a previous value, check for crossing
201        if let Some(last_offset) = self.last_offset_value {
202            let zero = T::zero();
203            let is_crossing = last_offset.signum() != offset_value.signum();
204
205            if is_crossing {
206                // Check crossing direction if specified
207                let record_crossing = match self.direction {
208                    CrossingDirection::Positive => last_offset < zero && offset_value >= zero,
209                    CrossingDirection::Negative => last_offset > zero && offset_value <= zero,
210                    CrossingDirection::Both => true, // any crossing
211                };
212
213                if record_crossing {
214                    // Find crossing time using Newton's method
215                    if let Some(t_cross) = self.find_crossing_newton(
216                        interpolator,
217                        t_prev,
218                        t_curr,
219                        last_offset,
220                        offset_value,
221                    ) {
222                        // Use interpolator's interpolation for the full state vector at crossing time
223                        let y_cross = interpolator.interpolate(t_cross).unwrap();
224
225                        // push the crossing time and value
226                        solution.push(t_cross, y_cross);
227                    } else {
228                        // Fallback to linear interpolation if Newton's method fails
229                        let frac = -last_offset / (offset_value - last_offset);
230                        let t_cross = t_prev + frac * (t_curr - t_prev);
231                        let y_cross = interpolator.interpolate(t_cross).unwrap();
232
233                        // push the estimated crossing time and value
234                        solution.push(t_cross, y_cross);
235                    }
236                }
237            }
238        }
239
240        // Update last value for next comparison
241        self.last_offset_value = Some(offset_value);
242
243        // Continue the integration
244        ControlFlag::Continue
245    }
246}
247
248// Add the Newton's method implementation
249impl<T: Real> CrossingSolout<T> {
250    /// Find the crossing time using Newton's method with interpolator interpolation
251    fn find_crossing_newton<I, Y>(
252        &self,
253        interpolator: &mut I,
254        t_lower: T,
255        t_upper: T,
256        offset_lower: T,
257        offset_upper: T,
258    ) -> Option<T>
259    where
260        I: Interpolation<T, Y>,
261        Y: State<T>,
262    {
263        // Start with linear interpolation as initial guess
264        let mut t = t_lower - offset_lower * (t_upper - t_lower) / (offset_upper - offset_lower);
265
266        // Newton's method parameters
267        let max_iterations = 10;
268        let tolerance = T::default_epsilon() * T::from_f64(100.0).unwrap(); // Higher tolerance for numerical stability
269        let mut offset;
270
271        // Newton's method iterations
272        for _ in 0..max_iterations {
273            // Get interpolated state at current time guess
274            let y_t = interpolator.interpolate(t).unwrap();
275
276            // Calculate offset from threshold at this time point
277            offset = y_t.get(self.component_idx) - self.threshold;
278
279            // Check if we're close enough to the crossing
280            if offset.abs() < tolerance {
281                return Some(t);
282            }
283
284            // Calculate numerical derivative of the offset function
285            let delta_t = (t_upper - t_lower) * T::from_f64(1e-6).unwrap();
286            let t_plus = t + delta_t;
287            let y_plus = interpolator.interpolate(t_plus).unwrap();
288            let offset_plus = y_plus.get(self.component_idx) - self.threshold;
289
290            let derivative = (offset_plus - offset) / delta_t;
291
292            // Avoid division by zero or very small derivatives
293            if derivative.abs() < T::default_epsilon() * T::from_f64(10.0).unwrap() {
294                break;
295            }
296
297            // Newton step
298            let t_new = t - offset / derivative;
299
300            // Ensure we stay within bounds
301            if t_new < t_lower || t_new > t_upper {
302                // Bisection fallback
303                t = (t_lower + t_upper) / T::from_f64(2.0).unwrap();
304            } else {
305                // Check if we're making progress
306                let change = (t_new - t).abs();
307                if change < tolerance * T::from_f64(0.1).unwrap() {
308                    // We're barely moving, consider it converged
309                    t = t_new;
310                    break;
311                }
312                t = t_new;
313            }
314        }
315
316        // Final check: Get interpolated value and see if we're close enough
317        let y_t = interpolator.interpolate(t).unwrap();
318        offset = y_t.get(self.component_idx) - self.threshold;
319
320        if offset.abs() < tolerance * T::from_f64(10.0).unwrap() {
321            Some(t)
322        } else {
323            None // Failed to converge
324        }
325    }
326}