pylon_workers/
durable_object.rs1use std::sync::Arc;
29
30use pylon_realtime::{DynShard, ShardAuth, SnapshotSink, SubscriberId};
31
32pub fn do_websocket_sink(send: impl Fn(&[u8]) + Send + Sync + 'static) -> SnapshotSink {
42 Box::new(move |tick: u64, bytes: &[u8]| {
43 let mut payload = Vec::with_capacity(8 + bytes.len());
44 payload.extend_from_slice(&tick.to_be_bytes());
45 payload.extend_from_slice(bytes);
46 send(&payload);
47 })
48}
49
/// Synchronous, byte-oriented key/value storage used to persist shard state.
///
/// Presumably backed by Durable Object storage through a host binding (not shown here —
/// TODO confirm). Implementations must be `Send + Sync`.
pub trait DoStorage: Send + Sync {
    /// Writes `value` under `key`.
    fn put_bytes(&self, key: &str, value: &[u8]);

    /// Returns the bytes stored under `key`, or `None` when nothing is stored there.
    fn get_bytes(&self, key: &str) -> Option<Vec<u8>>;

    /// Removes whatever is stored under `key`.
    fn delete(&self, key: &str);
}
62pub fn persist_to_do_storage<T: serde::Serialize>(
66 storage: &dyn DoStorage,
67 shard_id: &str,
68 state: &T,
69 tick: u64,
70) {
71 if let Ok(bytes) = serde_json::to_vec(state) {
72 storage.put_bytes(&format!("shard:{shard_id}:state"), &bytes);
73 storage.put_bytes(&format!("shard:{shard_id}:tick"), &tick.to_be_bytes());
74 }
75}
76
77pub fn restore_from_do_storage<T: serde::de::DeserializeOwned>(
79 storage: &dyn DoStorage,
80 shard_id: &str,
81) -> Option<T> {
82 let key = format!("shard:{shard_id}:state");
83 storage
84 .get_bytes(&key)
85 .and_then(|bytes| serde_json::from_slice(&bytes).ok())
86}
87
88pub fn register_do_subscriber(
101 shard: Arc<dyn DynShard>,
102 subscriber_id: SubscriberId,
103 ws_send: impl Fn(&[u8]) + Send + Sync + 'static,
104 auth: ShardAuth,
105) -> Result<DoSubscriberHandle, String> {
106 let sink = do_websocket_sink(ws_send);
107 shard
108 .add_subscriber(subscriber_id.clone(), sink, &auth)
109 .map_err(|e| e.to_string())?;
110 Ok(DoSubscriberHandle {
111 shard,
112 subscriber_id,
113 })
114}
115
116pub struct DoSubscriberHandle {
117 shard: Arc<dyn DynShard>,
118 subscriber_id: SubscriberId,
119}
120
121impl DoSubscriberHandle {
122 pub fn close(self) {
124 self.shard.remove_subscriber(&self.subscriber_id);
125 }
126}
127
/// JavaScript template for the Durable Object class that hosts a shard.
///
/// The generated class:
/// - accepts WebSockets through the hibernation API;
/// - forwards each incoming message to `env.SHARD_IMPORT.pushInput`;
/// - drives ticks with a storage alarm that calls `env.SHARD_IMPORT.runTick` and broadcasts
///   the result to every connected socket.
///
/// Connected sockets are always enumerated with `state.getWebSockets()` rather than an
/// in-memory map, because a hibernated or evicted object re-runs its constructor and would
/// otherwise forget every socket. The alarm is armed on a WebSocket upgrade when none is
/// pending. It stops re-arming once no sockets remain, so an idle object can hibernate.
pub const DURABLE_OBJECT_TEMPLATE_JS: &str = r#"
// Auto-generated. One class per shard type.
// Wires up fetch handling, WebSocket accept/hibernation, and alarm-based ticks.
export class ShardDO {
  constructor(state, env) {
    this.state = state;
    this.env = env;
    // Reject missing, zero, negative, or non-numeric rates: they would produce
    // Infinity or past alarm times.
    const hz = Number(env.TICK_RATE_HZ);
    this.tickRateHz = Number.isFinite(hz) && hz > 0 ? hz : 20;
  }

  tickIntervalMs() {
    return 1000 / this.tickRateHz;
  }

  async ensureAlarm() {
    // Ask the runtime for the real schedule instead of trusting a stored flag, which goes
    // stale if an alarm is ever dropped (e.g. after its retries are exhausted).
    if ((await this.state.storage.getAlarm()) === null) {
      await this.state.storage.setAlarm(Date.now() + this.tickIntervalMs());
    }
  }

  async fetch(req) {
    const url = new URL(req.url);
    const sid = url.searchParams.get('sid') || 'anon';

    if (req.headers.get('Upgrade') === 'websocket') {
      const pair = new WebSocketPair();
      const [client, server] = Object.values(pair);
      // The sid is kept as a tag (length-limited by the runtime), so several sockets may
      // share a sid without evicting each other.
      this.state.acceptWebSocket(server, [sid.slice(0, 256)]); // hibernation-compatible
      await this.ensureAlarm();
      return new Response(null, { status: 101, webSocket: client });
    }
    return new Response('not found', { status: 404 });
  }

  async webSocketMessage(ws, message) {
    // Forward input JSON into the shard's input queue (via bound Wasm fn).
    this.env.SHARD_IMPORT.pushInput(this.state.id.toString(), message);
  }

  async webSocketClose(ws, code, reason, wasClean) {
    // Nothing to clean up: disconnected sockets drop out of state.getWebSockets().
  }

  async alarm() {
    const sockets = this.state.getWebSockets();
    if (sockets.length === 0) {
      // Nobody is listening: stop ticking so the object can hibernate.
      // The next WebSocket upgrade re-arms the alarm.
      return;
    }
    // Run one tick and broadcast to all connected sockets.
    const snapshot = this.env.SHARD_IMPORT.runTick(this.state.id.toString());
    for (const ws of sockets) {
      try { ws.send(snapshot); } catch {}
    }
    // Reschedule.
    await this.state.storage.setAlarm(Date.now() + this.tickIntervalMs());
  }
}
"#;
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    /// Minimal `DoStorage` backed by a mutex-guarded map.
    #[derive(Default)]
    struct InMemoryStorage {
        map: Mutex<HashMap<String, Vec<u8>>>,
    }

    impl DoStorage for InMemoryStorage {
        fn get_bytes(&self, key: &str) -> Option<Vec<u8>> {
            self.map.lock().unwrap().get(key).cloned()
        }
        fn put_bytes(&self, key: &str, value: &[u8]) {
            self.map
                .lock()
                .unwrap()
                .insert(key.to_string(), value.to_vec());
        }
        fn delete(&self, key: &str) {
            self.map.lock().unwrap().remove(key);
        }
    }

    #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug)]
    struct State {
        score: u64,
        players: Vec<String>,
    }

    fn sample_state() -> State {
        State {
            score: 42,
            players: vec!["alice".into(), "bob".into()],
        }
    }

    #[test]
    fn persist_and_restore_roundtrip() {
        let storage = InMemoryStorage::default();
        let original = sample_state();
        persist_to_do_storage(&storage, "match1", &original, 100);

        let restored: State = restore_from_do_storage(&storage, "match1").unwrap();
        assert_eq!(restored, original);
    }

    #[test]
    fn persist_writes_big_endian_tick() {
        let storage = InMemoryStorage::default();
        persist_to_do_storage(&storage, "match1", &sample_state(), 100);
        assert_eq!(
            storage.get_bytes("shard:match1:tick"),
            Some(100u64.to_be_bytes().to_vec())
        );
    }

    #[test]
    fn restore_missing_or_corrupt_state_is_none() {
        let storage = InMemoryStorage::default();
        assert_eq!(restore_from_do_storage::<State>(&storage, "absent"), None);

        storage.put_bytes("shard:bad:state", b"not json");
        assert_eq!(restore_from_do_storage::<State>(&storage, "bad"), None);
    }

    #[test]
    fn do_websocket_sink_prepends_tick() {
        let captured = Arc::new(Mutex::new(Vec::<Vec<u8>>::new()));
        let captured_clone = Arc::clone(&captured);
        let sink = do_websocket_sink(move |bytes| {
            captured_clone.lock().unwrap().push(bytes.to_vec());
        });

        sink(42u64, b"hello");
        let all = captured.lock().unwrap();
        assert_eq!(all.len(), 1);
        // First 8 bytes are the big-endian tick, the rest is the untouched payload.
        assert_eq!(&all[0][..8], &42u64.to_be_bytes());
        assert_eq!(&all[0][8..], b"hello");
    }

    #[test]
    fn template_js_nonempty() {
        assert!(DURABLE_OBJECT_TEMPLATE_JS.contains("export class ShardDO"));
    }
}