Skip to main content

ant_node/upgrade/
rollout.rs

1//! Staged rollout for network-wide upgrades.
2//!
3//! This module provides deterministic delay calculation to prevent mass node
4//! restarts during upgrades. Each node calculates a unique delay based on its
5//! node ID, distributing upgrades evenly across the configured time window.
6//!
7//! ## Why Staged Rollout?
8//!
9//! When a new version is released, if all nodes upgrade simultaneously:
10//! - Network partitioning may occur
11//! - Data availability could be temporarily reduced
12//! - The network may become unstable
13//!
14//! By spreading upgrades over a 24-hour window (default), we ensure:
15//! - Continuous network availability
16//! - Gradual transition to the new version
17//! - Ability to detect issues before all nodes upgrade
18//!
19//! ## Deterministic Delays
20//!
21//! The delay is calculated deterministically from the node ID hash, so:
22//! - Each node gets a consistent delay (no drift on restarts)
23//! - Nodes are evenly distributed across the rollout window
24//! - The same node always upgrades at the same point in the window
25
26use std::time::Duration;
27use tracing::debug;
28
/// Deterministic per-node delay calculator used to stage upgrades over a time window.
#[derive(Debug, Clone)]
pub struct StagedRollout {
    /// Length of the rollout window in hours; computed delays range from zero up to this bound.
    max_delay_hours: u64,
    /// 32-byte hash of the node ID, kept so the delay can be re-derived deterministically.
    node_id_hash: [u8; 32],
}
37
38impl StagedRollout {
39    /// Create a new staged rollout calculator.
40    ///
41    /// # Arguments
42    ///
43    /// * `node_id` - The node's unique identifier (typically a public key)
44    /// * `max_delay_hours` - Maximum rollout window (default: 24 hours)
45    #[must_use]
46    pub fn new(node_id: &[u8], max_delay_hours: u64) -> Self {
47        let node_id_hash = *blake3::hash(node_id).as_bytes();
48
49        Self {
50            max_delay_hours,
51            node_id_hash,
52        }
53    }
54
55    /// Calculate the delay before this node should apply an upgrade.
56    ///
57    /// The delay is deterministically derived from the node ID, ensuring:
58    /// - Each node gets a consistent delay on every check
59    /// - Nodes are evenly distributed across the rollout window
60    /// - The delay is reproducible (same node ID = same delay)
61    #[must_use]
62    pub fn calculate_delay(&self) -> Duration {
63        if self.max_delay_hours == 0 {
64            return Duration::ZERO;
65        }
66
67        // Use first 8 bytes of hash as a u64 for delay calculation
68        let hash_value = u64::from_le_bytes([
69            self.node_id_hash[0],
70            self.node_id_hash[1],
71            self.node_id_hash[2],
72            self.node_id_hash[3],
73            self.node_id_hash[4],
74            self.node_id_hash[5],
75            self.node_id_hash[6],
76            self.node_id_hash[7],
77        ]);
78
79        // Calculate delay as a fraction of the max window
80        // hash_value / u64::MAX gives a value between 0 and 1
81        let max_delay_secs = self.max_delay_hours * 3600;
82
83        // Avoid division by zero and calculate proportional delay
84        #[allow(clippy::cast_precision_loss)]
85        let delay_fraction = (hash_value as f64) / (u64::MAX as f64);
86
87        #[allow(
88            clippy::cast_possible_truncation,
89            clippy::cast_sign_loss,
90            clippy::cast_precision_loss
91        )]
92        let delay_secs = (delay_fraction * max_delay_secs as f64) as u64;
93
94        let delay = Duration::from_secs(delay_secs);
95
96        debug!(
97            "Calculated staged rollout delay: {}h {}m {}s",
98            delay.as_secs() / 3600,
99            (delay.as_secs() % 3600) / 60,
100            delay.as_secs() % 60
101        );
102
103        delay
104    }
105
106    /// Get the maximum rollout window in hours.
107    #[must_use]
108    pub fn max_delay_hours(&self) -> u64 {
109        self.max_delay_hours
110    }
111
112    /// Check if staged rollout is enabled (`max_delay_hours` > 0).
113    #[must_use]
114    pub fn is_enabled(&self) -> bool {
115        self.max_delay_hours > 0
116    }
117
118    /// Calculate the delay for a specific version upgrade.
119    ///
120    /// This includes the version in the hash to ensure different versions
121    /// get different delays for the same node (useful for critical updates
122    /// that should be spread differently).
123    #[must_use]
124    pub fn calculate_delay_for_version(&self, version: &semver::Version) -> Duration {
125        if self.max_delay_hours == 0 {
126            return Duration::ZERO;
127        }
128
129        // Include version in the hash for version-specific delays
130        let mut hasher = blake3::Hasher::new();
131        hasher.update(&self.node_id_hash);
132        hasher.update(version.to_string().as_bytes());
133        let hash_result = hasher.finalize();
134
135        let hash_value = u64::from_le_bytes([
136            hash_result.as_bytes()[0],
137            hash_result.as_bytes()[1],
138            hash_result.as_bytes()[2],
139            hash_result.as_bytes()[3],
140            hash_result.as_bytes()[4],
141            hash_result.as_bytes()[5],
142            hash_result.as_bytes()[6],
143            hash_result.as_bytes()[7],
144        ]);
145
146        let max_delay_secs = self.max_delay_hours * 3600;
147
148        #[allow(clippy::cast_precision_loss)]
149        let delay_fraction = (hash_value as f64) / (u64::MAX as f64);
150
151        #[allow(
152            clippy::cast_possible_truncation,
153            clippy::cast_sign_loss,
154            clippy::cast_precision_loss
155        )]
156        let delay_secs = (delay_fraction * max_delay_secs as f64) as u64;
157
158        Duration::from_secs(delay_secs)
159    }
160}
161
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Upper bound, in seconds, of the 24-hour window used by most tests.
    const DAY_SECS: u64 = 24 * 3600;

    /// A zero-hour window disables staging and yields no delay.
    #[test]
    fn test_zero_delay_when_disabled() {
        let disabled = StagedRollout::new(b"node-1", 0);

        assert_eq!(disabled.calculate_delay(), Duration::ZERO);
        assert!(!disabled.is_enabled());
    }

    /// The delay never exceeds the configured 24-hour window.
    #[test]
    fn test_delay_within_range() {
        let rollout = StagedRollout::new(b"node-1", 24);

        assert!(rollout.calculate_delay() <= Duration::from_secs(DAY_SECS));
        assert!(rollout.is_enabled());
    }

    /// Two calculators built from the same node ID agree on the delay.
    #[test]
    fn test_deterministic_delay() {
        let first = StagedRollout::new(b"node-1", 24).calculate_delay();
        let second = StagedRollout::new(b"node-1", 24).calculate_delay();

        assert_eq!(first, second);
    }

    /// Distinct node IDs are expected to land on distinct delays.
    #[test]
    fn test_different_nodes_different_delays() {
        let delay_a = StagedRollout::new(b"node-1", 24).calculate_delay();
        let delay_b = StagedRollout::new(b"node-2", 24).calculate_delay();

        // A collision is possible in principle but statistically negligible.
        assert_ne!(delay_a, delay_b);
    }

    /// Doubling the window roughly doubles the delay for a fixed node.
    #[test]
    fn test_delay_scales_with_max_hours() {
        let secs_for = |hours: u64| {
            StagedRollout::new(b"consistent-node", hours)
                .calculate_delay()
                .as_secs()
        };
        let half_window = secs_for(12);
        let full_window = secs_for(24);

        // The ratio is undefined for a zero baseline, so there is nothing to check.
        if half_window == 0 {
            return;
        }

        // Expect a 2:1 ratio, allowing slack for whole-second truncation.
        #[allow(clippy::cast_precision_loss)]
        let ratio = full_window as f64 / half_window as f64;
        assert!(
            (ratio - 2.0).abs() < 0.1,
            "Ratio should be ~2.0, got {ratio}"
        );
    }

    /// The same node gets different delays for different versions.
    #[test]
    fn test_version_specific_delays() {
        let rollout = StagedRollout::new(b"node-1", 24);
        let delay_for =
            |major: u64| rollout.calculate_delay_for_version(&semver::Version::new(major, 0, 0));

        assert_ne!(delay_for(1), delay_for(2));
    }

    /// The getter reports the window passed to the constructor.
    #[test]
    fn test_max_delay_hours_getter() {
        assert_eq!(StagedRollout::new(b"node", 48).max_delay_hours(), 48);
    }

    /// A 1000-byte node ID still produces a delay inside the window.
    #[test]
    fn test_large_node_id() {
        let oversized_id = [0xABu8; 1000];
        let delay = StagedRollout::new(&oversized_id, 24).calculate_delay();

        assert!(delay <= Duration::from_secs(DAY_SECS));
    }

    /// An empty node ID still produces a delay inside the window.
    #[test]
    fn test_empty_node_id() {
        let delay = StagedRollout::new(&[], 24).calculate_delay();

        assert!(delay <= Duration::from_secs(DAY_SECS));
    }

    /// Delays for 100 nodes reach into both the first and last quarter of the window.
    #[test]
    fn test_delay_distribution() {
        let window_hours = 24u64;
        let window_secs = window_hours * 3600;

        let delays: Vec<u64> = (0..100)
            .map(|i| {
                let node_id = format!("node-{i}");
                StagedRollout::new(node_id.as_bytes(), window_hours)
                    .calculate_delay()
                    .as_secs()
            })
            .collect();

        let earliest = delays.iter().copied().min().unwrap();
        let latest = delays.iter().copied().max().unwrap();

        assert!(earliest < window_secs / 4, "Should have some early delays");
        assert!(latest > 3 * window_secs / 4, "Should have some late delays");
    }
}