Skip to main content

whisker_dev_server/
server.rs

1//! WebSocket dev server.
2//!
3//! `whisker run` opens a TCP listener, exposes a single
4//! `GET /whisker-dev` route that upgrades to WebSocket, and pushes
5//! patch messages to every connected client. The on-device
6//! `whisker-dev-runtime` is the canonical client.
7//!
8//! ## Wire format
9//!
//! Three frame types:
//!
//! **Patches** — server → device, *binary* frames laid out as:
//!
//! ```text
//! [8 bytes: u64 BE — JSON header length]
//! [N bytes:        JSON header { "kind": "patch", "table": {...} } ]
//! [rest:           raw patch dylib bytes (no encoding) ]
//! ```
//!
//! No base64. The dylib lands on the device with its original byte
//! count — roughly a quarter smaller on the wire than the previous
//! JSON-with-base64-string protocol (base64 inflates payloads by ~33 %).
//!
//! **Hello** — device → server, *text* frame,
//! `{"kind":"hello","aslr_reference":<u64>}`. The device sends this on
//! connect; the server stores the value and the patcher uses it to
//! compute the ASLR slide.
//!
//! **Log** — device → server, *text* frame,
//! `{"kind":"log","stream":<string>,"line":<string>,"ts_micros":<decimal string, optional>}`.
//! Captured device output, forwarded to the `on_event` observer as
//! `Event::DeviceLog`.
27//!
28//! ## Architecture
29//!
30//! A single `tokio::sync::broadcast` channel: every connected socket
31//! has its own subscriber receiver, so one `PatchSender::send` reaches
32//! all clients. New connections see only payloads sent *after* they
33//! subscribe — the receiver is at the tail end of the broadcast
34//! buffer, not replayed.
35
36use anyhow::Result;
37use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
38use axum::extract::State;
39use axum::response::Response;
40use axum::routing::get;
41use axum::Router;
42use std::net::SocketAddr;
43use std::sync::{Arc, Mutex};
44use tokio::sync::broadcast;
45
46use crate::Event;
47
48/// Cheap-to-clone broadcast payload. The dylib bytes are held by
49/// `Arc` so cloning into each subscribed client's receive queue is
50/// just a refcount bump.
51#[derive(Debug, Clone)]
52pub struct Patch {
53    /// The address-map metadata. Serialized as JSON in the binary
54    /// frame's prefix.
55    pub table: subsecond_types::JumpTable,
56    /// Raw patch dylib bytes. Streamed verbatim after the JSON
57    /// prefix; the device writes them to disk and `dlopen`s the
58    /// resulting file.
59    pub dylib_bytes: Arc<Vec<u8>>,
60}
61
62/// JSON header that prefixes the binary patch frame. Mirrors the
63/// shape `whisker-dev-runtime::hot_reload::Header` deserialises.
64///
65/// `table.map` is serialised as a JSON array of `[old, new]` pairs
66/// rather than a JSON object. JSON objects can only have string
67/// keys, so the default `HashMap<u64, u64>` derive would produce
68/// `{ "1234": 5678 }` — and the matching deserialize side, given a
69/// custom hasher like `subsecond_types::BuildAddressHasher`, fails
70/// to convert the string back to `u64`. The pair-array form
71/// sidesteps both.
72#[derive(Debug, Clone, serde::Serialize)]
73#[serde(tag = "kind", rename_all = "snake_case")]
74enum PatchHeader<'a> {
75    Patch {
76        #[serde(serialize_with = "wire_jump_table::serialize")]
77        table: &'a subsecond_types::JumpTable,
78    },
79}
80
81/// Shared serde adapter used by `whisker-dev-runtime::hot_reload` too —
82/// both sides must agree on the JSON shape. Kept inline (not a
83/// shared crate) because the type is tiny and the duplication
84/// burden is one ~30-line module.
85pub mod wire_jump_table {
86    use serde::ser::SerializeStruct;
87    use serde::Serializer;
88    use subsecond_types::JumpTable;
89
90    pub fn serialize<S: Serializer>(t: &JumpTable, s: S) -> Result<S::Ok, S::Error> {
91        let pairs: Vec<(u64, u64)> = t.map.iter().map(|(k, v)| (*k, *v)).collect();
92        let mut st = s.serialize_struct("JumpTable", 5)?;
93        st.serialize_field("lib", &t.lib)?;
94        st.serialize_field("map", &pairs)?;
95        st.serialize_field("aslr_reference", &t.aslr_reference)?;
96        st.serialize_field("new_base_address", &t.new_base_address)?;
97        st.serialize_field("ifunc_count", &t.ifunc_count)?;
98        st.end()
99    }
100}
101
102/// Cheap-to-clone handle for sending patches from the rest of the
103/// dev server (file watcher / builder / etc.) to every connected
104/// client.
105#[derive(Clone)]
106pub struct PatchSender {
107    tx: broadcast::Sender<Patch>,
108    /// Latest `aslr_reference` reported by a connected client via the
109    /// `hello` handshake. Single-slot, last-write-wins: we don't yet
110    /// support targeted-per-client patches, so all connected clients
111    /// must share an ASLR base. For typical single-emulator dev
112    /// sessions that's fine; for multi-device this becomes the
113    /// natural boundary where patches start being per-client.
114    aslr_reference: Arc<Mutex<Option<u64>>>,
115}
116
117impl PatchSender {
118    /// Broadcast `patch` to every currently-connected client.
119    /// Returns the number of clients the message was queued for —
120    /// `0` is fine (no client connected yet) and not an error.
121    pub fn send(&self, patch: Patch) -> usize {
122        self.tx.send(patch).unwrap_or(0)
123    }
124
125    /// Number of clients currently subscribed.
126    pub fn client_count(&self) -> usize {
127        self.tx.receiver_count()
128    }
129
130    /// The runtime address of `main` (= `subsecond::aslr_reference()`)
131    /// most recently reported by a connected client. `None` when no
132    /// client has connected or sent its `hello` yet — the patcher
133    /// should withhold Tier 1 patches in that case (fall back to
134    /// Tier 2 cold rebuild).
135    pub fn latest_aslr_reference(&self) -> Option<u64> {
136        self.aslr_reference.lock().ok().and_then(|g| *g)
137    }
138}
139
140#[derive(Clone)]
141struct AppState {
142    tx: broadcast::Sender<Patch>,
143    on_event: Option<Arc<dyn Fn(Event) + Send + Sync>>,
144    aslr_reference: Arc<Mutex<Option<u64>>>,
145}
146
147/// Bind on `addr`, spawn the axum server on the current tokio
148/// runtime, and return:
149///   - a [`PatchSender`] for the rest of the dev loop to push patches
150///   - the actual bound address (useful when caller asked for port 0)
151///   - the spawned server task's `JoinHandle`
152///
153/// `on_event` is an optional observer hook — `whisker-cli` uses it to
154/// render terminal UI on connect/disconnect events.
155pub async fn serve(
156    addr: SocketAddr,
157    on_event: Option<Arc<dyn Fn(Event) + Send + Sync>>,
158) -> Result<(PatchSender, SocketAddr, tokio::task::JoinHandle<()>)> {
159    let (tx, _rx) = broadcast::channel::<Patch>(16);
160    let aslr_reference: Arc<Mutex<Option<u64>>> = Arc::new(Mutex::new(None));
161    let state = AppState {
162        tx: tx.clone(),
163        on_event,
164        aslr_reference: Arc::clone(&aslr_reference),
165    };
166
167    let app = Router::new()
168        .route("/whisker-dev", get(ws_handler))
169        .with_state(state);
170
171    let listener = tokio::net::TcpListener::bind(addr).await?;
172    let bound = listener.local_addr()?;
173
174    let handle = tokio::spawn(async move {
175        if let Err(e) = axum::serve(listener, app).await {
176            whisker_build::ui::error(format!("axum serve error: {e}"));
177        }
178    });
179
180    Ok((PatchSender { tx, aslr_reference }, bound, handle))
181}
182
183async fn ws_handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response {
184    ws.on_upgrade(move |socket| handle_socket(socket, state))
185}
186
187async fn handle_socket(socket: WebSocket, state: AppState) {
188    use futures_util::{SinkExt, StreamExt};
189
190    let (mut tx_ws, mut rx_ws) = socket.split();
191    let mut bcast_rx = state.tx.subscribe();
192    whisker_build::ui::set_status(format!("{} client(s) connected", state.tx.receiver_count(),));
193    // `aslr_reference` is internal handshake plumbing; emit at debug
194    // grade so the steady-state UI stays clean.
195    if let Some(cb) = &state.on_event {
196        cb(Event::ClientConnected);
197    }
198
199    loop {
200        tokio::select! {
201            // server → client: forward broadcast patches as binary frames.
202            recv = bcast_rx.recv() => {
203                let patch = match recv {
204                    Ok(p) => p,
205                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
206                    Err(broadcast::error::RecvError::Closed) => break,
207                };
208                let frame = match encode_patch_frame(&patch) {
209                    Ok(b) => b,
210                    Err(e) => {
211                        whisker_build::ui::warn(format!("encode patch frame: {e}"));
212                        continue;
213                    }
214                };
215                if tx_ws.send(Message::Binary(frame.into())).await.is_err() {
216                    break;
217                }
218            }
219            // client → server: drain incoming so Pings/Pongs are honoured;
220            // close on Close frame or transport error. Text frames are
221            // parsed for `hello` envelopes carrying the client's
222            // `aslr_reference`.
223            msg = rx_ws.next() => {
224                match msg {
225                    Some(Ok(Message::Close(_))) | None => break,
226                    Some(Err(_)) => break,
227                    Some(Ok(Message::Text(t))) => {
228                        if let Some(aslr) = parse_client_aslr_reference(&t) {
229                            whisker_build::ui::debug(format!(
230                                "client hello · aslr_reference={aslr:#x}"
231                            ));
232                            if let Ok(mut g) = state.aslr_reference.lock() {
233                                *g = Some(aslr);
234                            }
235                        } else if let Some(log) = parse_client_log(&t) {
236                            if let Some(cb) = &state.on_event {
237                                cb(Event::DeviceLog {
238                                    stream: log.stream,
239                                    line: log.line,
240                                    ts_micros: log.ts_micros,
241                                });
242                            }
243                        }
244                    }
245                    _ => {}
246                }
247            }
248        }
249    }
250
251    if let Some(cb) = &state.on_event {
252        cb(Event::ClientDisconnected);
253    }
254}
255
256/// Build the on-the-wire binary frame:
257///   `[u64 BE json_len][json header][raw dylib bytes]`
258fn encode_patch_frame(patch: &Patch) -> Result<Vec<u8>> {
259    let header = PatchHeader::Patch {
260        table: &patch.table,
261    };
262    let json = serde_json::to_vec(&header)?;
263    let json_len = json.len() as u64;
264    let dylib = patch.dylib_bytes.as_slice();
265    let mut frame = Vec::with_capacity(8 + json.len() + dylib.len());
266    frame.extend_from_slice(&json_len.to_be_bytes());
267    frame.extend_from_slice(&json);
268    frame.extend_from_slice(dylib);
269    Ok(frame)
270}
271
272/// Pull the `aslr_reference` field out of a client hello envelope.
273/// Returns `None` for non-hello text frames (or malformed payloads)
274/// — the only thing we actively listen for client→server today is the
275/// initial handshake.
276fn parse_client_aslr_reference(text: &str) -> Option<u64> {
277    #[derive(serde::Deserialize)]
278    struct Hello {
279        kind: String,
280        aslr_reference: u64,
281    }
282    let h: Hello = serde_json::from_str(text).ok()?;
283    if h.kind == "hello" {
284        Some(h.aslr_reference)
285    } else {
286        None
287    }
288}
289
/// A device log line decoded from a `{"kind":"log",…}` text frame, as emitted
/// by the device-side `log_capture` module.
struct ClientLog {
    /// Output stream name sent by the device (e.g. `"stdout"`, `"stderr"`).
    stream: String,
    /// The captured line.
    line: String,
    /// Device timestamp in microseconds; `0` when absent or unparseable.
    ts_micros: u128,
}
297
298/// Parse a client log envelope. Returns `None` for any other text
299/// frame so the caller can fall through to other handlers (the hello
300/// envelope is the only other text frame today).
301///
302/// `ts_micros` arrives as a string on the wire because `u128` doesn't
303/// round-trip through JSON's number type cleanly (>2^53 is lossy in
304/// most decoders). The device serializes via `to_string`; we decode
305/// with `parse`, defaulting to `0` on parse failure rather than
306/// rejecting the whole frame — the line itself is more valuable than
307/// a precise timestamp.
308fn parse_client_log(text: &str) -> Option<ClientLog> {
309    #[derive(serde::Deserialize)]
310    struct Log {
311        kind: String,
312        stream: String,
313        line: String,
314        #[serde(default)]
315        ts_micros: Option<String>,
316    }
317    let h: Log = serde_json::from_str(text).ok()?;
318    if h.kind != "log" {
319        return None;
320    }
321    let ts_micros = h
322        .ts_micros
323        .as_deref()
324        .and_then(|s| s.parse::<u128>().ok())
325        .unwrap_or(0);
326    Some(ClientLog {
327        stream: h.stream,
328        line: h.line,
329        ts_micros,
330    })
331}
332
333// ============================================================================
334// Tests
335// ============================================================================
336
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{SinkExt, StreamExt};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tokio_tungstenite::tungstenite::Message as WsMessage;

    /// Client half of a test WebSocket connection.
    type ClientStream = tokio_tungstenite::WebSocketStream<
        tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
    >;

    fn make_dummy_jump_table() -> subsecond_types::JumpTable {
        // Construct via JSON to avoid pinning ourselves to private
        // field shapes. All fields are public + plain types so this
        // is stable.
        let json = r#"{
            "lib": "/tmp/dummy.dylib",
            "map": {},
            "aslr_reference": 4294967296,
            "new_base_address": 8589934592,
            "ifunc_count": 0
        }"#;
        serde_json::from_str(json).expect("dummy JumpTable")
    }

    /// Spawn the server on an ephemeral port and return its address +
    /// sender so tests don't have to worry about port collisions.
    async fn spawn_test_server(
        on_event: Option<Arc<dyn Fn(Event) + Send + Sync>>,
    ) -> (PatchSender, SocketAddr) {
        let any: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let (sender, addr, _handle) = serve(any, on_event).await.expect("serve");
        (sender, addr)
    }

    async fn connect(addr: SocketAddr) -> ClientStream {
        let url = format!("ws://{addr}/whisker-dev");
        let (ws, _) = tokio_tungstenite::connect_async(&url)
            .await
            .expect("connect");
        ws
    }

    /// Poll `cond` every 5 ms, at most `max_polls` times, returning as soon
    /// as it holds. Callers assert afterwards, so a timeout surfaces as an
    /// ordinary assertion failure showing the real values.
    async fn wait_until(max_polls: usize, mut cond: impl FnMut() -> bool) {
        for _ in 0..max_polls {
            if cond() {
                return;
            }
            tokio::time::sleep(Duration::from_millis(5)).await;
        }
    }

    /// Wait until the server has registered exactly `n` subscriptions, then
    /// assert it. Subscription happens asynchronously after the upgrade, so
    /// tests must wait before sending or the patch may reach nobody.
    async fn wait_for_clients(sender: &PatchSender, n: usize) {
        wait_until(100, || sender.client_count() == n).await;
        assert_eq!(sender.client_count(), n);
    }

    /// Receive the next message, failing the test after 2 s.
    async fn recv_next(client: &mut ClientStream) -> WsMessage {
        tokio::time::timeout(Duration::from_secs(2), client.next())
            .await
            .expect("recv timed out")
            .expect("stream ended")
            .expect("ws error")
    }

    /// Decode a wire frame back into (header JSON value, dylib bytes)
    /// so tests can assert against both halves.
    fn decode_patch_frame(bytes: &[u8]) -> (serde_json::Value, Vec<u8>) {
        assert!(bytes.len() >= 8, "frame too short");
        let json_len = u64::from_be_bytes(bytes[..8].try_into().unwrap()) as usize;
        assert!(bytes.len() >= 8 + json_len, "frame truncated");
        let header: serde_json::Value =
            serde_json::from_slice(&bytes[8..8 + json_len]).expect("parse header");
        let dylib = bytes[8 + json_len..].to_vec();
        (header, dylib)
    }

    #[tokio::test]
    async fn client_can_connect_and_receive_a_broadcast_patch() {
        let (sender, addr) = spawn_test_server(None).await;
        let mut client = connect(addr).await;
        wait_for_clients(&sender, 1).await;

        let n = sender.send(Patch {
            table: make_dummy_jump_table(),
            dylib_bytes: Arc::new(b"FAKE_DYLIB_BYTES".to_vec()),
        });
        assert_eq!(n, 1);

        let bytes = match recv_next(&mut client).await {
            WsMessage::Binary(b) => b,
            other => panic!("expected binary, got {other:?}"),
        };
        let (header, dylib) = decode_patch_frame(&bytes);
        assert_eq!(header["kind"], "patch");
        assert_eq!(header["table"]["lib"], "/tmp/dummy.dylib");
        assert_eq!(header["table"]["aslr_reference"], 4294967296_u64);
        assert_eq!(dylib, b"FAKE_DYLIB_BYTES");
    }

    #[tokio::test]
    async fn send_with_no_clients_returns_zero_and_does_not_error() {
        let (sender, _addr) = spawn_test_server(None).await;
        assert_eq!(sender.client_count(), 0);
        let n = sender.send(Patch {
            table: make_dummy_jump_table(),
            dylib_bytes: Arc::new(Vec::new()),
        });
        assert_eq!(n, 0);
    }

    #[tokio::test]
    async fn multiple_clients_each_receive_the_same_patch() {
        let (sender, addr) = spawn_test_server(None).await;
        let mut a = connect(addr).await;
        let mut b = connect(addr).await;
        wait_for_clients(&sender, 2).await;

        let n = sender.send(Patch {
            table: make_dummy_jump_table(),
            dylib_bytes: Arc::new(b"SHARED".to_vec()),
        });
        assert_eq!(n, 2);

        for client in [&mut a, &mut b] {
            let msg = recv_next(client).await;
            assert!(matches!(msg, WsMessage::Binary(_)));
        }
    }

    #[tokio::test]
    async fn on_event_callback_fires_for_connect_and_disconnect() {
        let connect_count = Arc::new(AtomicUsize::new(0));
        let disconnect_count = Arc::new(AtomicUsize::new(0));

        let cc = connect_count.clone();
        let dc = disconnect_count.clone();
        let on_event: Arc<dyn Fn(Event) + Send + Sync> = Arc::new(move |e| match e {
            Event::ClientConnected => {
                cc.fetch_add(1, Ordering::SeqCst);
            }
            Event::ClientDisconnected => {
                dc.fetch_add(1, Ordering::SeqCst);
            }
            _ => {}
        });

        let (sender, addr) = spawn_test_server(Some(on_event)).await;

        let mut client = connect(addr).await;
        wait_until(100, || connect_count.load(Ordering::SeqCst) == 1).await;
        assert_eq!(connect_count.load(Ordering::SeqCst), 1);

        // Close the client.
        client
            .send(WsMessage::Close(None))
            .await
            .expect("send close");
        drop(client);

        wait_until(200, || disconnect_count.load(Ordering::SeqCst) == 1).await;
        assert_eq!(disconnect_count.load(Ordering::SeqCst), 1);

        // sender stays usable across the whole flow.
        assert_eq!(sender.client_count(), 0);
    }

    #[tokio::test]
    async fn hello_frame_records_the_client_aslr_reference() {
        let (sender, addr) = spawn_test_server(None).await;
        assert_eq!(sender.latest_aslr_reference(), None);

        let mut client = connect(addr).await;
        client
            .send(WsMessage::Text(
                r#"{"kind":"hello","aslr_reference":4096}"#.into(),
            ))
            .await
            .expect("send hello frame");

        wait_until(100, || sender.latest_aslr_reference().is_some()).await;
        assert_eq!(sender.latest_aslr_reference(), Some(4096));
    }

    #[test]
    fn parse_client_aslr_reference_only_accepts_hello_envelopes() {
        assert_eq!(
            parse_client_aslr_reference(r#"{"kind":"hello","aslr_reference":42}"#),
            Some(42)
        );
        assert_eq!(
            parse_client_aslr_reference(r#"{"kind":"goodbye","aslr_reference":42}"#),
            None
        );
        assert_eq!(
            parse_client_aslr_reference(r#"{"kind":"log","stream":"stdout","line":"x"}"#),
            None
        );
        assert_eq!(parse_client_aslr_reference("not json"), None);
    }

    #[test]
    fn parse_client_log_decodes_a_well_formed_frame() {
        let log = parse_client_log(
            r#"{"kind":"log","stream":"stdout","line":"hello world","ts_micros":"12345"}"#,
        )
        .expect("valid log envelope");
        assert_eq!(log.stream, "stdout");
        assert_eq!(log.line, "hello world");
        assert_eq!(log.ts_micros, 12345);
    }

    #[test]
    fn parse_client_log_falls_back_to_zero_ts_when_missing() {
        let log =
            parse_client_log(r#"{"kind":"log","stream":"stderr","line":"oops"}"#).expect("valid");
        assert_eq!(log.stream, "stderr");
        assert_eq!(log.line, "oops");
        assert_eq!(log.ts_micros, 0);
    }

    #[test]
    fn parse_client_log_rejects_other_kinds() {
        assert!(parse_client_log(r#"{"kind":"hello","aslr_reference":42}"#).is_none());
    }

    #[tokio::test]
    async fn on_event_callback_fires_with_device_log_lines() {
        let captured: Arc<Mutex<Vec<(String, String, u128)>>> = Arc::new(Mutex::new(Vec::new()));
        let captured_clone = Arc::clone(&captured);
        let on_event: Arc<dyn Fn(Event) + Send + Sync> = Arc::new(move |e| {
            if let Event::DeviceLog {
                stream,
                line,
                ts_micros,
            } = e
            {
                captured_clone
                    .lock()
                    .unwrap()
                    .push((stream, line, ts_micros));
            }
        });

        let (sender, addr) = spawn_test_server(Some(on_event)).await;
        let mut client = connect(addr).await;
        wait_for_clients(&sender, 1).await;

        client
            .send(WsMessage::Text(
                r#"{"kind":"log","stream":"stdout","line":"hi from device","ts_micros":"42"}"#
                    .into(),
            ))
            .await
            .expect("send log frame");

        // Wait for the server to dispatch the callback.
        wait_until(100, || !captured.lock().unwrap().is_empty()).await;
        let g = captured.lock().unwrap();
        assert_eq!(g.len(), 1);
        assert_eq!(g[0].0, "stdout");
        assert_eq!(g[0].1, "hi from device");
        assert_eq!(g[0].2, 42);
    }
}