1use std::collections::VecDeque;
11use std::time::Instant;
12
13use async_trait::async_trait;
14use atomr_core::actor::{Actor, Context, Props};
15use tokio::sync::oneshot;
16
17use atomr_accel_cuda::error::GpuError;
18
19use crate::optimizer::OptimizerKind;
20
/// Identifies a worker that exchanges gradients and weights with the
/// parameter server.
///
/// Thin newtype over `u32`. `Hash` is derived alongside equality so ids can be
/// used directly as keys in `HashMap`/`HashSet` (e.g. per-worker bookkeeping).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WorkerId(pub u32);
23
/// Point-in-time counters reported by the parameter server in response to a
/// `Stats` request.
///
/// `PartialEq` is derived so snapshots can be compared directly (for example
/// with `assert_eq!`); `Eq` is not available because of the `f32` field.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct ParameterServerStats {
    /// Server version at the time of the snapshot.
    pub version: u64,
    /// Number of gradients that were accepted and applied.
    pub gradients_applied: u64,
    /// Number of weight pulls served.
    pub weights_pulled: u64,
    /// Mean staleness over the recently accepted gradients the server still
    /// remembers; `0.0` when none have been recorded.
    pub avg_staleness: f32,
}
31
32pub enum ParameterServerMsg {
33 PushGradient {
36 worker: WorkerId,
37 worker_version: u64,
38 gradient: Vec<f32>,
39 reply: oneshot::Sender<Result<u64, GpuError>>,
40 },
41 PullWeights {
43 worker: WorkerId,
44 reply: oneshot::Sender<(Vec<f32>, u64)>,
45 },
46 Stats {
47 reply: oneshot::Sender<ParameterServerStats>,
48 },
49}
50
51pub struct AsyncParameterServer {
52 weights: Vec<f32>,
53 version: u64,
54 optimizer: OptimizerKind,
55 max_staleness: u64,
58 gradients_applied: u64,
59 weights_pulled: u64,
60 staleness_window: VecDeque<u64>,
63 started: Instant,
64}
65
66impl AsyncParameterServer {
67 pub fn props(
68 initial_weights: Vec<f32>,
69 optimizer: OptimizerKind,
70 max_staleness: u64,
71 ) -> Props<Self> {
72 Props::create(move || AsyncParameterServer {
73 weights: initial_weights.clone(),
74 version: 0,
75 optimizer,
76 max_staleness,
77 gradients_applied: 0,
78 weights_pulled: 0,
79 staleness_window: VecDeque::with_capacity(128),
80 started: Instant::now(),
81 })
82 }
83
84 fn apply_gradient(&mut self, grad: &[f32]) {
85 let lr = self.optimizer.lr();
86 let n = self.weights.len().min(grad.len());
88 for (w, g) in self.weights.iter_mut().zip(grad.iter()).take(n) {
89 *w -= lr * g;
90 }
91 self.version += 1;
92 self.gradients_applied += 1;
93 }
94}
95
96#[async_trait]
97impl Actor for AsyncParameterServer {
98 type Msg = ParameterServerMsg;
99
100 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: ParameterServerMsg) {
101 match msg {
102 ParameterServerMsg::PushGradient {
103 worker: _,
104 worker_version,
105 gradient,
106 reply,
107 } => {
108 let staleness = self.version.saturating_sub(worker_version);
109 if staleness > self.max_staleness {
110 let _ = reply.send(Err(GpuError::Unrecoverable(format!(
111 "parameter server: staleness {staleness} > max {}",
112 self.max_staleness
113 ))));
114 return;
115 }
116 self.apply_gradient(&gradient);
117 self.staleness_window.push_back(staleness);
118 if self.staleness_window.len() > 128 {
119 self.staleness_window.pop_front();
120 }
121 let _ = reply.send(Ok(self.version));
122 }
123 ParameterServerMsg::PullWeights { worker: _, reply } => {
124 self.weights_pulled += 1;
125 let _ = reply.send((self.weights.clone(), self.version));
126 }
127 ParameterServerMsg::Stats { reply } => {
128 let avg_stale = if self.staleness_window.is_empty() {
129 0.0
130 } else {
131 let sum: u64 = self.staleness_window.iter().sum();
132 sum as f32 / self.staleness_window.len() as f32
133 };
134 let _ = reply.send(ParameterServerStats {
135 version: self.version,
136 gradients_applied: self.gradients_applied,
137 weights_pulled: self.weights_pulled,
138 avg_staleness: avg_stale,
139 });
140 }
141 }
142 let _ = self.started; }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;

    /// Longest a test waits for any single reply.
    const REPLY_TIMEOUT: Duration = Duration::from_secs(2);

    /// SGD configuration shared by both tests.
    fn plain_sgd() -> OptimizerKind {
        OptimizerKind::Sgd {
            lr: 0.1,
            momentum: 0.0,
            weight_decay: 0.0,
        }
    }

    /// Builds a `PushGradient` from worker 1 that claims version 0, together
    /// with the receiver for its reply.
    fn push_from_version_zero(
        gradient: Vec<f32>,
    ) -> (
        ParameterServerMsg,
        oneshot::Receiver<Result<u64, GpuError>>,
    ) {
        let (reply, rx) = oneshot::channel();
        let msg = ParameterServerMsg::PushGradient {
            worker: WorkerId(1),
            worker_version: 0,
            gradient,
            reply,
        };
        (msg, rx)
    }

    /// Waits for a reply, panicking on timeout or if the sender was dropped.
    async fn recv<T>(rx: oneshot::Receiver<T>) -> T {
        tokio::time::timeout(REPLY_TIMEOUT, rx)
            .await
            .unwrap()
            .unwrap()
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn push_gradient_advances_version() {
        let sys = ActorSystem::create("ps-test", Config::empty())
            .await
            .unwrap();
        let props = AsyncParameterServer::props(vec![10.0, 20.0], plain_sgd(), 4);
        let ps = sys.actor_of(props, "ps").unwrap();

        // One accepted push moves the server from version 0 to 1.
        let (push, push_rx) = push_from_version_zero(vec![1.0, 2.0]);
        ps.tell(push);
        let new_version = recv(push_rx).await.unwrap();
        assert_eq!(new_version, 1);

        // The pulled weights reflect w -= 0.1 * g.
        let (pull_tx, pull_rx) = oneshot::channel();
        ps.tell(ParameterServerMsg::PullWeights {
            worker: WorkerId(1),
            reply: pull_tx,
        });
        let (weights, version) = recv(pull_rx).await;
        assert_eq!(version, 1);
        assert!((weights[0] - 9.9).abs() < 1e-5);
        assert!((weights[1] - 19.8).abs() < 1e-5);

        sys.terminate().await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn stale_gradient_is_rejected() {
        let sys = ActorSystem::create("ps-stale", Config::empty())
            .await
            .unwrap();
        let props = AsyncParameterServer::props(vec![1.0], plain_sgd(), 1);
        let ps = sys.actor_of(props, "ps").unwrap();

        // Warm-up pushes always claim version 0; their individual outcomes
        // are deliberately ignored. After two are accepted the server is at
        // version 2, which already exceeds the allowed staleness of 1.
        for _ in 0..3 {
            let (push, rx) = push_from_version_zero(vec![0.1]);
            ps.tell(push);
            let _ = recv(rx).await;
        }

        // One more push against version 0 must be refused.
        let (push, rx) = push_from_version_zero(vec![0.1]);
        ps.tell(push);
        let outcome = recv(rx).await;
        assert!(matches!(outcome, Err(GpuError::Unrecoverable(_))));

        sys.terminate().await;
    }
}