Skip to main content

pylon_workers/
durable_object.rs

1//! Durable Object adapter for shards on Cloudflare Workers.
2//!
3//! The idea: one DO class per shard type. Each match/zone/document gets its
4//! own DO instance (addressed by name). The DO owns the shard's state, the
5//! tick loop (via `setAlarm`), and all WebSocket connections.
6//!
7//! Why DOs specifically: they give you exactly the isolation + persistence
8//! semantics that Workers lack otherwise — single-threaded execution per
9//! instance, on-instance storage, WebSocket hibernation (clients stay
10//! connected even while the DO is idle and not billed).
11//!
12//! # Status
13//!
14//! This is a **scaffold**, not a working DO. The real `worker` crate has
15//! a different macro story and tight coupling to wasm-bindgen. A full impl
16//! requires:
17//!
18//! - A concrete `#[durable_object]`-annotated struct (in the user's Workers
19//!   bundle, not here).
20//! - A bridge from incoming WS messages → shard input queue.
21//! - A bridge from shard snapshots → DO WebSocket send.
//! - `state.storage.setAlarm()` for scheduled ticks in event-driven mode.
//! - `state.storage.get()`/`put()` for persistence across hibernation.
24//!
25//! The abstractions below give users a head-start by providing the right
26//! trait shapes for their DO implementation to hook into.
27
28use std::sync::Arc;
29
30use pylon_realtime::{DynShard, ShardAuth, SnapshotSink, SubscriberId};
31
32// ---------------------------------------------------------------------------
33// WorkerDoSink — bridges a shard's broadcast to DO WebSocket sends
34// ---------------------------------------------------------------------------
35
36/// Build a [`SnapshotSink`] that forwards snapshots to a DO's WebSocket send.
37///
38/// The caller provides the send function (which is `worker::WebSocket::send`
39/// or equivalent). This lets the shard's broadcast loop push snapshots to
40/// connected clients over the DO's own WebSocket.
41pub fn do_websocket_sink(send: impl Fn(&[u8]) + Send + Sync + 'static) -> SnapshotSink {
42    Box::new(move |tick: u64, bytes: &[u8]| {
43        let mut payload = Vec::with_capacity(8 + bytes.len());
44        payload.extend_from_slice(&tick.to_be_bytes());
45        payload.extend_from_slice(bytes);
46        send(&payload);
47    })
48}
49
50// ---------------------------------------------------------------------------
51// Persistence hooks for DO storage
52// ---------------------------------------------------------------------------
53
/// Minimal byte-oriented view of Durable Object storage.
///
/// Mirrors the DO `storage.get()` / `storage.put()` / `storage.delete()`
/// calls so this scaffold can express its persistence pattern without
/// depending on the `worker` crate. Implementors must be `Send + Sync`.
pub trait DoStorage: Send + Sync {
    /// Read the bytes stored under `key`; `None` when there is no entry.
    fn get_bytes(&self, key: &str) -> Option<Vec<u8>>;

    /// Write `value` under `key`.
    fn put_bytes(&self, key: &str, value: &[u8]);

    /// Remove the entry stored under `key`.
    fn delete(&self, key: &str);
}
61
62/// Save a shard's serialized state to DO storage.
63///
64/// Called from the shard's `on_tick` hook via `persist_every_ticks`.
65pub fn persist_to_do_storage<T: serde::Serialize>(
66    storage: &dyn DoStorage,
67    shard_id: &str,
68    state: &T,
69    tick: u64,
70) {
71    if let Ok(bytes) = serde_json::to_vec(state) {
72        storage.put_bytes(&format!("shard:{shard_id}:state"), &bytes);
73        storage.put_bytes(&format!("shard:{shard_id}:tick"), &tick.to_be_bytes());
74    }
75}
76
77/// Restore a shard's serialized state from DO storage.
78pub fn restore_from_do_storage<T: serde::de::DeserializeOwned>(
79    storage: &dyn DoStorage,
80    shard_id: &str,
81) -> Option<T> {
82    let key = format!("shard:{shard_id}:state");
83    storage
84        .get_bytes(&key)
85        .and_then(|bytes| serde_json::from_slice(&bytes).ok())
86}
87
88// ---------------------------------------------------------------------------
89// Subscriber entry — helper for DO fetch handlers
90// ---------------------------------------------------------------------------
91
92/// Registers a WebSocket-connected player with a shard running inside a DO.
93///
94/// The caller's DO fetch handler:
95/// 1. Accepts the WebSocket upgrade.
96/// 2. Gets a reference to the shard (held in DO-scoped state).
97/// 3. Calls this function with the WS send closure and the resolved auth.
98///
99/// Returns a close handler — call it in the DO's webSocketClose event.
100pub fn register_do_subscriber(
101    shard: Arc<dyn DynShard>,
102    subscriber_id: SubscriberId,
103    ws_send: impl Fn(&[u8]) + Send + Sync + 'static,
104    auth: ShardAuth,
105) -> Result<DoSubscriberHandle, String> {
106    let sink = do_websocket_sink(ws_send);
107    shard
108        .add_subscriber(subscriber_id.clone(), sink, &auth)
109        .map_err(|e| e.to_string())?;
110    Ok(DoSubscriberHandle {
111        shard,
112        subscriber_id,
113    })
114}
115
116pub struct DoSubscriberHandle {
117    shard: Arc<dyn DynShard>,
118    subscriber_id: SubscriberId,
119}
120
121impl DoSubscriberHandle {
122    /// Call on DO webSocketClose / webSocketError events.
123    pub fn close(self) {
124        self.shard.remove_subscriber(&self.subscriber_id);
125    }
126}
127
128// ---------------------------------------------------------------------------
129// Template: the JavaScript side of a Durable Object
130// ---------------------------------------------------------------------------
131
/// The boilerplate a user adds to their Workers bundle's JS entry file.
///
/// This can't be generated from Rust alone (the DO class must be exported
/// from JS so the Workers runtime can instantiate it), so it ships as a
/// string constant that the `pylon deploy --target workers` command can drop
/// into the generated bundle.
///
/// The template is written for WebSocket hibernation:
/// - Sockets are owned by the runtime (`state.acceptWebSocket` /
///   `state.getWebSockets`) rather than an instance field. Instance fields
///   are lost whenever the DO is evicted and re-constructed while its
///   sockets stay connected.
/// - The tick alarm is armed on connect only when `storage.getAlarm()`
///   reports none pending.
/// - The alarm is only rescheduled while at least one socket is connected,
///   so an idle DO stops ticking.
pub const DURABLE_OBJECT_TEMPLATE_JS: &str = r#"
// Auto-generated. One class per shard type.
// Wires up fetch handling, WebSocket accept/hibernation, and alarm-based ticks.
export class ShardDO {
  constructor(state, env) {
    this.state = state;
    this.env = env;
    // Env vars arrive as strings; fall back to 20 Hz if unset, zero or non-numeric.
    this.tickRateHz = Number(env.TICK_RATE_HZ) || 20;
  }

  nextTickAt() {
    return Date.now() + 1000 / this.tickRateHz;
  }

  async fetch(req) {
    const url = new URL(req.url);
    const sid = url.searchParams.get('sid') || 'anon';

    if (req.headers.get('Upgrade') === 'websocket') {
      const pair = new WebSocketPair();
      const [client, server] = Object.values(pair);
      // Let the runtime own the socket so it survives hibernation/eviction;
      // an in-memory map would be empty after the DO is re-constructed.
      this.state.acceptWebSocket(server);
      server.serializeAttachment({ sid });
      // Arm the tick loop unless an alarm is already pending.
      if ((await this.state.storage.getAlarm()) === null) {
        await this.state.storage.setAlarm(this.nextTickAt());
      }
      return new Response(null, { status: 101, webSocket: client });
    }
    return new Response('not found', { status: 404 });
  }

  async webSocketMessage(ws, message) {
    // Forward input JSON into the shard's input queue (via bound Wasm fn).
    // The sender's sid is available as ws.deserializeAttachment().sid.
    this.env.SHARD_IMPORT.pushInput(this.state.id.toString(), message);
  }

  async webSocketClose(ws) {
    // Nothing to clean up: closed sockets drop out of state.getWebSockets().
  }

  async alarm() {
    // Run one tick and broadcast to every connected socket.
    const sockets = this.state.getWebSockets();
    const snapshot = this.env.SHARD_IMPORT.runTick(this.state.id.toString());
    for (const ws of sockets) {
      try { ws.send(snapshot); } catch {}
    }
    // Keep ticking only while someone is connected so an idle DO can rest;
    // the next connection re-arms the alarm in fetch().
    if (sockets.length > 0) {
      await this.state.storage.setAlarm(this.nextTickAt());
    }
  }
}
"#;
189
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// [`DoStorage`] over a mutex-guarded `HashMap`, standing in for DO storage.
    #[derive(Default)]
    struct InMemoryStorage {
        map: Mutex<HashMap<String, Vec<u8>>>,
    }

    impl DoStorage for InMemoryStorage {
        fn get_bytes(&self, key: &str) -> Option<Vec<u8>> {
            self.map.lock().unwrap().get(key).cloned()
        }
        fn put_bytes(&self, key: &str, value: &[u8]) {
            self.map
                .lock()
                .unwrap()
                .insert(key.to_string(), value.to_vec());
        }
        fn delete(&self, key: &str) {
            self.map.lock().unwrap().remove(key);
        }
    }

    #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug)]
    struct State {
        score: u64,
        players: Vec<String>,
    }

    fn sample_state() -> State {
        State {
            score: 42,
            players: vec!["alice".into(), "bob".into()],
        }
    }

    #[test]
    fn persist_and_restore_roundtrip() {
        let storage = InMemoryStorage::default();
        let original = sample_state();
        persist_to_do_storage(&storage, "match1", &original, 100);

        let restored: State = restore_from_do_storage(&storage, "match1").unwrap();
        assert_eq!(restored, original);
    }

    #[test]
    fn persist_writes_tick_as_big_endian_bytes() {
        let storage = InMemoryStorage::default();
        persist_to_do_storage(&storage, "match1", &sample_state(), 100);
        assert_eq!(
            storage.get_bytes("shard:match1:tick"),
            Some(100u64.to_be_bytes().to_vec())
        );
    }

    #[test]
    fn restore_missing_state_is_none() {
        let storage = InMemoryStorage::default();
        assert_eq!(restore_from_do_storage::<State>(&storage, "absent"), None);
    }

    #[test]
    fn restore_undecodable_state_is_none() {
        let storage = InMemoryStorage::default();
        storage.put_bytes("shard:broken:state", b"not json");
        assert_eq!(restore_from_do_storage::<State>(&storage, "broken"), None);
    }

    #[test]
    fn shard_ids_use_separate_keys() {
        let storage = InMemoryStorage::default();
        let a = State {
            score: 1,
            players: vec![],
        };
        let b = State {
            score: 2,
            players: vec!["carol".into()],
        };
        persist_to_do_storage(&storage, "a", &a, 1);
        persist_to_do_storage(&storage, "b", &b, 2);
        assert_eq!(restore_from_do_storage::<State>(&storage, "a"), Some(a));
        assert_eq!(restore_from_do_storage::<State>(&storage, "b"), Some(b));
    }

    #[test]
    fn do_websocket_sink_prepends_tick() {
        let captured = std::sync::Arc::new(Mutex::new(Vec::<Vec<u8>>::new()));
        let captured_clone = std::sync::Arc::clone(&captured);
        let sink = do_websocket_sink(move |bytes| {
            captured_clone.lock().unwrap().push(bytes.to_vec());
        });

        sink(42u64, b"hello");
        let all = captured.lock().unwrap();
        assert_eq!(all.len(), 1);
        // First 8 bytes: big-endian u64 tick number.
        assert_eq!(&all[0][..8], &42u64.to_be_bytes());
        assert_eq!(&all[0][8..], b"hello");
    }

    #[test]
    fn template_js_declares_shard_do() {
        assert!(!DURABLE_OBJECT_TEMPLATE_JS.trim().is_empty());
        assert!(DURABLE_OBJECT_TEMPLATE_JS.contains("export class ShardDO"));
    }
}