Skip to main content

atomr_accel_train/
parameter_server.rs

//! `AsyncParameterServer` — central parameter store with asynchronous
//! gradient pushes and asynchronous weight pulls.
//!
//! Workers push gradients (via `PushGradient`); the server applies each
//! accepted gradient as a plain SGD step using the configured optimizer's
//! learning rate, rejecting gradients that fall outside the staleness
//! window. Workers pull the latest weights (via `PullWeights`) on their own
//! schedule. Bounded-staleness training tolerates a few steps of drift in
//! exchange for higher worker utilization.
9
10use std::collections::VecDeque;
11use std::time::Instant;
12
13use async_trait::async_trait;
14use atomr_core::actor::{Actor, Context, Props};
15use tokio::sync::oneshot;
16
17use atomr_accel_cuda::error::GpuError;
18
19use crate::optimizer::OptimizerKind;
20
/// Opaque identifier for a worker talking to the parameter server.
///
/// `Hash` and `Ord` are derived so the ID can key per-worker maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct WorkerId(pub u32);
/// Point-in-time counters reported in reply to a `Stats` request.
///
/// `PartialEq` (not `Eq`, because of the `f32` field) lets callers compare
/// snapshots directly.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct ParameterServerStats {
    /// Server version at the moment the snapshot was taken.
    pub version: u64,
    /// Total number of gradients accepted and applied.
    pub gradients_applied: u64,
    /// Total number of weight pulls served.
    pub weights_pulled: u64,
    /// Mean staleness over the most recently applied gradients
    /// (at most 128 of them); `0.0` when none have been applied.
    pub avg_staleness: f32,
}
31
32pub enum ParameterServerMsg {
33    /// Worker pushes a gradient vector. `worker_version` is the
34    /// version the worker had when it computed this gradient.
35    PushGradient {
36        worker: WorkerId,
37        worker_version: u64,
38        gradient: Vec<f32>,
39        reply: oneshot::Sender<Result<u64, GpuError>>,
40    },
41    /// Worker pulls the latest weights + version.
42    PullWeights {
43        worker: WorkerId,
44        reply: oneshot::Sender<(Vec<f32>, u64)>,
45    },
46    Stats {
47        reply: oneshot::Sender<ParameterServerStats>,
48    },
49}
50
51pub struct AsyncParameterServer {
52    weights: Vec<f32>,
53    version: u64,
54    optimizer: OptimizerKind,
55    /// Maximum allowed staleness — gradients computed against
56    /// versions older than `version - max_staleness` are rejected.
57    max_staleness: u64,
58    gradients_applied: u64,
59    weights_pulled: u64,
60    /// Sliding window of (version - worker_version) for the last N
61    /// applied gradients.
62    staleness_window: VecDeque<u64>,
63    started: Instant,
64}
65
66impl AsyncParameterServer {
67    pub fn props(
68        initial_weights: Vec<f32>,
69        optimizer: OptimizerKind,
70        max_staleness: u64,
71    ) -> Props<Self> {
72        Props::create(move || AsyncParameterServer {
73            weights: initial_weights.clone(),
74            version: 0,
75            optimizer,
76            max_staleness,
77            gradients_applied: 0,
78            weights_pulled: 0,
79            staleness_window: VecDeque::with_capacity(128),
80            started: Instant::now(),
81        })
82    }
83
84    fn apply_gradient(&mut self, grad: &[f32]) {
85        let lr = self.optimizer.lr();
86        // SGD-style: w <- w - lr * grad.
87        let n = self.weights.len().min(grad.len());
88        for (w, g) in self.weights.iter_mut().zip(grad.iter()).take(n) {
89            *w -= lr * g;
90        }
91        self.version += 1;
92        self.gradients_applied += 1;
93    }
94}
95
96#[async_trait]
97impl Actor for AsyncParameterServer {
98    type Msg = ParameterServerMsg;
99
100    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: ParameterServerMsg) {
101        match msg {
102            ParameterServerMsg::PushGradient {
103                worker: _,
104                worker_version,
105                gradient,
106                reply,
107            } => {
108                let staleness = self.version.saturating_sub(worker_version);
109                if staleness > self.max_staleness {
110                    let _ = reply.send(Err(GpuError::Unrecoverable(format!(
111                        "parameter server: staleness {staleness} > max {}",
112                        self.max_staleness
113                    ))));
114                    return;
115                }
116                self.apply_gradient(&gradient);
117                self.staleness_window.push_back(staleness);
118                if self.staleness_window.len() > 128 {
119                    self.staleness_window.pop_front();
120                }
121                let _ = reply.send(Ok(self.version));
122            }
123            ParameterServerMsg::PullWeights { worker: _, reply } => {
124                self.weights_pulled += 1;
125                let _ = reply.send((self.weights.clone(), self.version));
126            }
127            ParameterServerMsg::Stats { reply } => {
128                let avg_stale = if self.staleness_window.is_empty() {
129                    0.0
130                } else {
131                    let sum: u64 = self.staleness_window.iter().sum();
132                    sum as f32 / self.staleness_window.len() as f32
133                };
134                let _ = reply.send(ParameterServerStats {
135                    version: self.version,
136                    gradients_applied: self.gradients_applied,
137                    weights_pulled: self.weights_pulled,
138                    avg_staleness: avg_stale,
139                });
140            }
141        }
142        let _ = self.started; // suppress unused
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;

    /// One push advances the version, a pull sees the SGD-updated weights,
    /// and `Stats` reflects both.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn push_gradient_advances_version() {
        let sys = ActorSystem::create("ps-test", Config::empty())
            .await
            .unwrap();
        let ps = sys
            .actor_of(
                AsyncParameterServer::props(
                    vec![10.0, 20.0],
                    OptimizerKind::Sgd {
                        lr: 0.1,
                        momentum: 0.0,
                        weight_decay: 0.0,
                    },
                    /* max_staleness */ 4,
                ),
                "ps",
            )
            .unwrap();

        let (tx, rx) = oneshot::channel();
        ps.tell(ParameterServerMsg::PushGradient {
            worker: WorkerId(1),
            worker_version: 0,
            gradient: vec![1.0, 2.0],
            reply: tx,
        });
        let v = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .unwrap()
            .unwrap()
            .unwrap();
        assert_eq!(v, 1);

        let (tx, rx) = oneshot::channel();
        ps.tell(ParameterServerMsg::PullWeights {
            worker: WorkerId(1),
            reply: tx,
        });
        let (w, version) = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(version, 1);
        // w[0] = 10 - 0.1 * 1 = 9.9; w[1] = 20 - 0.1 * 2 = 19.8.
        assert!((w[0] - 9.9).abs() < 1e-5);
        assert!((w[1] - 19.8).abs() < 1e-5);

        let (tx, rx) = oneshot::channel();
        ps.tell(ParameterServerMsg::Stats { reply: tx });
        let stats = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(stats.version, 1);
        assert_eq!(stats.gradients_applied, 1);
        assert_eq!(stats.weights_pulled, 1);
        // The only applied gradient had staleness 0.
        assert_eq!(stats.avg_staleness, 0.0);

        sys.terminate().await;
    }

    /// Gradients computed against a version older than `max_staleness`
    /// steps are refused.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn stale_gradient_is_rejected() {
        let sys = ActorSystem::create("ps-stale", Config::empty())
            .await
            .unwrap();
        let ps = sys
            .actor_of(
                AsyncParameterServer::props(
                    vec![1.0],
                    OptimizerKind::Sgd {
                        lr: 0.1,
                        momentum: 0.0,
                        weight_decay: 0.0,
                    },
                    /* max_staleness */ 1,
                ),
                "ps",
            )
            .unwrap();

        // Every push claims worker_version 0 and max_staleness is 1:
        //   push 1: staleness 0 -> accepted, version 1
        //   push 2: staleness 1 -> accepted, version 2
        //   push 3: staleness 2 -> rejected, version stays 2
        let expected: [Option<u64>; 3] = [Some(1), Some(2), None];
        for want in expected.iter().copied() {
            let (tx, rx) = oneshot::channel();
            ps.tell(ParameterServerMsg::PushGradient {
                worker: WorkerId(1),
                worker_version: 0,
                gradient: vec![0.1],
                reply: tx,
            });
            let r = tokio::time::timeout(Duration::from_secs(2), rx)
                .await
                .unwrap()
                .unwrap();
            match want {
                Some(v) => assert!(matches!(r, Ok(got) if got == v)),
                None => assert!(matches!(r, Err(GpuError::Unrecoverable(_)))),
            }
        }

        // The server is at version 2, so a worker_version 0 gradient is
        // still 2 steps stale and must be rejected.
        let (tx, rx) = oneshot::channel();
        ps.tell(ParameterServerMsg::PushGradient {
            worker: WorkerId(1),
            worker_version: 0,
            gradient: vec![0.1],
            reply: tx,
        });
        let r = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .unwrap()
            .unwrap();
        assert!(matches!(r, Err(GpuError::Unrecoverable(_))));

        sys.terminate().await;
    }
}