Skip to main content

host_extensions/first_party/
crdt.rs

1//! CRDT extension — collaborative document sync over Statement Store.
2//!
3//! Registers as `window.host.ext.crdt` in the SPA WebView.
4//! Uses the core `__hostCall` bridge (shared pending-promise map).
5
6use std::sync::Arc;
7
8use serde::{Deserialize, Serialize};
9
10use crate::HostExtension;
11
12use super::crdt_runtime::{CrdtRuntime, CrdtRuntimeError, InMemoryCrdtRuntime};
13
/// Maximum allowed length for a room ID, in bytes (compared against `str::len`,
/// not a character count).
///
/// Longer IDs are rejected by `check_room_id` before any runtime call.
const MAX_ROOM_ID_LEN: usize = 256;
16
/// Bootstrap script injected into the WebView to expose `window.host.ext.crdt`.
///
/// The script installs nothing if `window.host` or `window.host.ext` is missing.
/// Otherwise it installs a frozen object whose methods each forward to
/// `window.__hostCall('hostBridge', <method>, <params>)`, using the method names
/// that `handle_message` dispatches on. Two methods transform their arguments:
/// - `join` falls back to the `'relay'` transport when `opts.transport` is falsy.
/// - `setAwareness` JSON-encodes the awareness state before sending it.
const HOST_EXT_CRDT_SCRIPT: &str = r#"
(function(){
    if (!window.host || !window.host.ext) return;
    window.host.ext.crdt = Object.freeze({
        join: function(roomId, opts) {
            return window.__hostCall('hostBridge', 'crdtJoin', {
                roomId: roomId,
                transport: (opts && opts.transport) || 'relay'
            });
        },
        applyUpdate: function(roomId, dataBase64) {
            return window.__hostCall('hostBridge', 'crdtApplyUpdate', {
                roomId: roomId,
                dataBase64: dataBase64
            });
        },
        getStateVector: function(roomId) {
            return window.__hostCall('hostBridge', 'crdtGetStateVector', {
                roomId: roomId
            });
        },
        getFullState: function(roomId) {
            return window.__hostCall('hostBridge', 'crdtGetFullState', {
                roomId: roomId
            });
        },
        setAwareness: function(roomId, state) {
            return window.__hostCall('hostBridge', 'crdtSetAwareness', {
                roomId: roomId,
                state: JSON.stringify(state)
            });
        },
        destroy: function(roomId) {
            return window.__hostCall('hostBridge', 'crdtDestroy', {
                roomId: roomId
            });
        }
    });
})();
"#;
57
/// Host extension that exposes CRDT room operations to the WebView as
/// `window.host.ext.crdt`.
///
/// Every bridge call is delegated to a [`CrdtRuntime`] implementation.
pub struct CrdtExtension {
    /// Runtime performing the room operations. It is held behind `Arc` so one
    /// runtime can be shared with other owners (see `with_shared_runtime`).
    runtime: Arc<dyn CrdtRuntime>,
}
61
62impl Default for CrdtExtension {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68impl CrdtExtension {
69    pub fn new() -> Self {
70        Self::with_runtime(InMemoryCrdtRuntime::new())
71    }
72
73    pub fn with_runtime<R>(runtime: R) -> Self
74    where
75        R: CrdtRuntime,
76    {
77        Self::with_shared_runtime(Arc::new(runtime))
78    }
79
80    pub fn with_shared_runtime(runtime: Arc<dyn CrdtRuntime>) -> Self {
81        Self { runtime }
82    }
83}
84
85/// Validate a room ID. Returns `None` if valid, or the bridge response to
86/// return early (either `Some(error_json)` for empty, or `None` for overlength).
87fn check_room_id(room_id: &str) -> Option<Option<String>> {
88    if room_id.is_empty() {
89        return Some(error_json(&CrdtRuntimeError::RoomIdRequired));
90    }
91    if room_id.len() > MAX_ROOM_ID_LEN {
92        return Some(None);
93    }
94    None
95}
96
97impl HostExtension for CrdtExtension {
98    fn namespace(&self) -> &str {
99        "crdt"
100    }
101
102    fn channel(&self) -> &str {
103        "hostBridge"
104    }
105
106    fn inject_script(&self) -> &str {
107        HOST_EXT_CRDT_SCRIPT
108    }
109
110    fn handle_message(&self, method: &str, params: &str) -> Option<String> {
111        match method {
112            "crdtJoin" => {
113                let req: JoinPayload = serde_json::from_str(params).ok()?;
114                if let Some(early) = check_room_id(&req.room_id) {
115                    return early;
116                }
117                match self.runtime.join(&req.room_id, &req.transport) {
118                    Ok(result) => serde_json::to_string(&result).ok(),
119                    Err(e) => error_json(&e),
120                }
121            }
122            "crdtApplyUpdate" => {
123                let req: UpdatePayload = serde_json::from_str(params).ok()?;
124                if let Some(early) = check_room_id(&req.room_id) {
125                    return early;
126                }
127                match self.runtime.apply_update(&req.room_id, &req.data_base64) {
128                    Ok(v) => serde_json::to_string(&v).ok(),
129                    Err(e) => error_json(&e),
130                }
131            }
132            "crdtGetStateVector" => {
133                let req: RoomIdPayload = serde_json::from_str(params).ok()?;
134                if let Some(early) = check_room_id(&req.room_id) {
135                    return early;
136                }
137                match self.runtime.get_state_vector(&req.room_id) {
138                    Ok(v) => serde_json::to_string(&v).ok(),
139                    Err(e) => error_json(&e),
140                }
141            }
142            "crdtGetFullState" => {
143                let req: RoomIdPayload = serde_json::from_str(params).ok()?;
144                if let Some(early) = check_room_id(&req.room_id) {
145                    return early;
146                }
147                match self.runtime.get_full_state(&req.room_id) {
148                    Ok(v) => serde_json::to_string(&v).ok(),
149                    Err(e) => error_json(&e),
150                }
151            }
152            "crdtSetAwareness" => {
153                let req: AwarenessPayload = serde_json::from_str(params).ok()?;
154                if let Some(early) = check_room_id(&req.room_id) {
155                    return early;
156                }
157                match self.runtime.set_awareness(&req.room_id, &req.state) {
158                    Ok(v) => serde_json::to_string(&v).ok(),
159                    Err(e) => error_json(&e),
160                }
161            }
162            "crdtDestroy" => {
163                let req: RoomIdPayload = serde_json::from_str(params).ok()?;
164                if let Some(early) = check_room_id(&req.room_id) {
165                    return early;
166                }
167                match self.runtime.destroy(&req.room_id) {
168                    Ok(v) => serde_json::to_string(&v).ok(),
169                    Err(e) => error_json(&e),
170                }
171            }
172            _ => None,
173        }
174    }
175
176    fn drain_events(&self) -> Vec<crate::HostPushEvent> {
177        self.runtime.drain_events()
178    }
179}
180
181/// JSON-safe error response for CRDT bridge methods.
182///
183/// The host bridge always resolves the JS promise (never rejects for extension
184/// calls), so errors are returned as `{"error":"..."}` in the resolved value.
185/// The guest SDK intercepts this shape and throws a `HostError`.
186#[derive(Debug, Serialize)]
187struct CrdtErrorEnvelope {
188    error: String,
189}
190
191fn error_json(err: &CrdtRuntimeError) -> Option<String> {
192    serde_json::to_string(&CrdtErrorEnvelope {
193        error: err.to_string(),
194    })
195    .ok()
196}
197
/// Params for `crdtJoin` (wire keys are camelCase, e.g. `roomId`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct JoinPayload {
    /// Room to join.
    room_id: String,
    // Falls back to `default_transport()` only when the key is absent; an
    // explicit `null` still fails deserialization.
    #[serde(default = "default_transport")]
    transport: String,
}

/// Transport used when `crdtJoin` params omit `transport`. It mirrors the
/// `'relay'` fallback in the injected script.
fn default_transport() -> String {
    "relay".to_string()
}

/// Params for `crdtApplyUpdate`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct UpdatePayload {
    /// Target room.
    room_id: String,
    /// Update payload as sent by the guest (wire key `dataBase64`). It is
    /// forwarded to the runtime as-is; presumably base64 per the name, but not
    /// decoded or validated here.
    data_base64: String,
}

/// Params for methods that take only a room ID: `crdtGetStateVector`,
/// `crdtGetFullState` and `crdtDestroy`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct RoomIdPayload {
    /// Target room.
    room_id: String,
}

/// Params for `crdtSetAwareness`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct AwarenessPayload {
    /// Target room.
    room_id: String,
    /// Awareness state as a string. The injected script produces it with
    /// `JSON.stringify`, and it is passed to the runtime unparsed.
    state: String,
}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor_contract::CrdtJoinResult;

    #[test]
    fn crdt_extension_basics() {
        let ext = CrdtExtension::new();
        assert_eq!(ext.namespace(), "crdt");
        assert_eq!(ext.channel(), "hostBridge");
        assert!(ext.inject_script().contains("window.host.ext.crdt"));
        assert!(ext.inject_script().contains("crdtJoin"));
        assert!(ext.inject_script().contains("crdtApplyUpdate"));
        assert!(ext.inject_script().contains("crdtGetStateVector"));
        assert!(ext.inject_script().contains("crdtGetFullState"));
        assert!(ext.inject_script().contains("crdtSetAwareness"));
        assert!(ext.inject_script().contains("crdtDestroy"));
    }

    #[test]
    fn crdt_join_returns_result() {
        let ext = CrdtExtension::new();
        let result_json = ext
            .handle_message("crdtJoin", r#"{"roomId":"doc-1"}"#)
            .expect("join result");
        let result: CrdtJoinResult = serde_json::from_str(&result_json).expect("parse join result");
        assert_eq!(result.room_id, "doc-1");
        assert_eq!(result.transport, "relay");
    }

    #[test]
    fn crdt_join_with_transport() {
        let ext = CrdtExtension::new();
        let result_json = ext
            .handle_message("crdtJoin", r#"{"roomId":"doc-1","transport":"p2p"}"#)
            .expect("join result");
        let result: CrdtJoinResult = serde_json::from_str(&result_json).expect("parse join result");
        assert_eq!(result.transport, "p2p");
    }

    #[test]
    fn crdt_apply_update_succeeds() {
        let ext = CrdtExtension::new();
        ext.handle_message("crdtJoin", r#"{"roomId":"doc-1"}"#);
        let result = ext
            .handle_message(
                "crdtApplyUpdate",
                r#"{"roomId":"doc-1","dataBase64":"AQID"}"#,
            )
            .expect("apply result");
        assert_eq!(result, "true");
    }

    #[test]
    fn crdt_get_state_vector_succeeds() {
        let ext = CrdtExtension::new();
        ext.handle_message("crdtJoin", r#"{"roomId":"doc-1"}"#);
        let result = ext
            .handle_message("crdtGetStateVector", r#"{"roomId":"doc-1"}"#)
            .expect("state vector");
        assert_eq!(result, r#""0""#);
    }

    #[test]
    fn crdt_get_full_state_succeeds() {
        let ext = CrdtExtension::new();
        ext.handle_message("crdtJoin", r#"{"roomId":"doc-1"}"#);
        ext.handle_message(
            "crdtApplyUpdate",
            r#"{"roomId":"doc-1","dataBase64":"AQID"}"#,
        );
        let result = ext
            .handle_message("crdtGetFullState", r#"{"roomId":"doc-1"}"#)
            .expect("full state");
        assert_eq!(result, r#""AQID""#);
    }

    #[test]
    fn crdt_set_awareness_succeeds() {
        let ext = CrdtExtension::new();
        ext.handle_message("crdtJoin", r#"{"roomId":"doc-1"}"#);
        let result = ext
            .handle_message(
                "crdtSetAwareness",
                r#"{"roomId":"doc-1","state":"{\"cursor\":5}"}"#,
            )
            .expect("awareness result");
        assert_eq!(result, "true");
    }

    #[test]
    fn crdt_destroy_succeeds() {
        let ext = CrdtExtension::new();
        ext.handle_message("crdtJoin", r#"{"roomId":"doc-1"}"#);
        let result = ext
            .handle_message("crdtDestroy", r#"{"roomId":"doc-1"}"#)
            .expect("destroy result");
        assert_eq!(result, "true");
    }

    #[test]
    fn crdt_errors_are_json_safe() {
        let ext = CrdtExtension::new();
        // Every method that takes roomId should return a valid JSON error for missing rooms.
        for (method, params) in [
            (
                "crdtApplyUpdate",
                r#"{"roomId":"missing","dataBase64":"AA"}"#,
            ),
            ("crdtGetStateVector", r#"{"roomId":"missing"}"#),
            ("crdtGetFullState", r#"{"roomId":"missing"}"#),
            ("crdtSetAwareness", r#"{"roomId":"missing","state":"{}"}"#),
            ("crdtDestroy", r#"{"roomId":"missing"}"#),
        ] {
            let result = ext
                .handle_message(method, params)
                .unwrap_or_else(|| panic!("{method} should return Some"));
            let parsed: serde_json::Value =
                serde_json::from_str(&result).unwrap_or_else(|_| panic!("{method} not valid JSON"));
            assert_eq!(
                parsed["error"].as_str(),
                Some("crdt room not found"),
                "{method} error mismatch"
            );
        }
    }

    #[test]
    fn crdt_unknown_method_returns_none() {
        let ext = CrdtExtension::new();
        assert!(ext.handle_message("crdtUnknown", "{}").is_none());
    }

    #[test]
    fn crdt_malformed_params_returns_none() {
        let ext = CrdtExtension::new();
        assert!(ext.handle_message("crdtJoin", "not json").is_none());
    }

    #[test]
    fn crdt_rejects_empty_room_id() {
        let ext = CrdtExtension::new();
        for (method, params) in [
            ("crdtJoin", r#"{"roomId":""}"#),
            ("crdtApplyUpdate", r#"{"roomId":"","dataBase64":"AA"}"#),
            ("crdtGetStateVector", r#"{"roomId":""}"#),
            ("crdtGetFullState", r#"{"roomId":""}"#),
            ("crdtSetAwareness", r#"{"roomId":"","state":"{}"}"#),
            ("crdtDestroy", r#"{"roomId":""}"#),
        ] {
            let result = ext
                .handle_message(method, params)
                .unwrap_or_else(|| panic!("{method} should return Some for empty roomId"));
            let parsed: serde_json::Value =
                serde_json::from_str(&result).unwrap_or_else(|_| panic!("{method} not valid JSON"));
            assert_eq!(
                parsed["error"].as_str(),
                Some("roomId is required"),
                "{method} empty roomId error mismatch"
            );
        }
    }

    #[test]
    fn crdt_rejects_overlength_room_id() {
        let ext = CrdtExtension::new();
        let long_id = "x".repeat(MAX_ROOM_ID_LEN + 1);
        let params = format!(r#"{{"roomId":"{}"}}"#, long_id);
        assert!(
            ext.handle_message("crdtJoin", &params).is_none(),
            "overlength roomId should return None"
        );
    }

    #[test]
    fn crdt_rejects_overlength_room_id_for_every_method() {
        // The length guard runs before every runtime call, not just join.
        let ext = CrdtExtension::new();
        let long_id = "x".repeat(MAX_ROOM_ID_LEN + 1);
        for (method, extra_fields) in [
            ("crdtJoin", ""),
            ("crdtApplyUpdate", r#","dataBase64":"AA""#),
            ("crdtGetStateVector", ""),
            ("crdtGetFullState", ""),
            ("crdtSetAwareness", r#","state":"{}""#),
            ("crdtDestroy", ""),
        ] {
            let params = format!(r#"{{"roomId":"{long_id}"{extra_fields}}}"#);
            assert!(
                ext.handle_message(method, &params).is_none(),
                "{method} should return None for overlength roomId"
            );
        }
    }

    #[test]
    fn crdt_accepts_room_id_at_max_length() {
        // The limit is inclusive: an ID of exactly MAX_ROOM_ID_LEN bytes must
        // reach the runtime. StaticRuntime echoes the room ID back on join.
        let ext = CrdtExtension::with_runtime(StaticRuntime);
        let max_id = "x".repeat(MAX_ROOM_ID_LEN);
        let params = format!(r#"{{"roomId":"{max_id}"}}"#);
        let result_json = ext
            .handle_message("crdtJoin", &params)
            .expect("max-length roomId should be accepted");
        let result: CrdtJoinResult = serde_json::from_str(&result_json).expect("parse join result");
        assert_eq!(result.room_id, max_id);
    }

    /// Test runtime whose `drain_events` always yields one canned
    /// `crdtRemoteUpdate` event.
    struct EventfulRuntime;

    impl CrdtRuntime for EventfulRuntime {
        fn join(
            &self,
            room_id: &str,
            _transport: &str,
        ) -> Result<CrdtJoinResult, CrdtRuntimeError> {
            Ok(CrdtJoinResult {
                room_id: room_id.to_string(),
                transport: "relay".to_string(),
            })
        }
        fn apply_update(&self, _room_id: &str, _data: &str) -> Result<bool, CrdtRuntimeError> {
            Ok(true)
        }
        fn get_state_vector(&self, _room_id: &str) -> Result<String, CrdtRuntimeError> {
            Ok("0".into())
        }
        fn get_full_state(&self, _room_id: &str) -> Result<String, CrdtRuntimeError> {
            Ok("".into())
        }
        fn set_awareness(&self, _room_id: &str, _state: &str) -> Result<bool, CrdtRuntimeError> {
            Ok(true)
        }
        fn destroy(&self, _room_id: &str) -> Result<bool, CrdtRuntimeError> {
            Ok(true)
        }
        fn drain_events(&self) -> Vec<crate::HostPushEvent> {
            vec![crate::HostPushEvent {
                event: "crdtRemoteUpdate".into(),
                payload_json: r#"{"roomId":"doc-1","updateBase64":"AQID"}"#.into(),
            }]
        }
    }

    #[test]
    fn crdt_drain_events_delegates_to_runtime() {
        let ext = CrdtExtension::with_runtime(EventfulRuntime);
        let events = ext.drain_events();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].event, "crdtRemoteUpdate");
        assert!(events[0].payload_json.contains("doc-1"));
    }

    /// Test runtime returning fixed values. `join` echoes the room ID and
    /// reports a `"custom"` transport.
    struct StaticRuntime;

    impl CrdtRuntime for StaticRuntime {
        fn join(
            &self,
            room_id: &str,
            _transport: &str,
        ) -> Result<CrdtJoinResult, CrdtRuntimeError> {
            Ok(CrdtJoinResult {
                room_id: room_id.to_string(),
                transport: "custom".to_string(),
            })
        }
        fn apply_update(&self, _room_id: &str, _data: &str) -> Result<bool, CrdtRuntimeError> {
            Ok(true)
        }
        fn get_state_vector(&self, _room_id: &str) -> Result<String, CrdtRuntimeError> {
            Ok("static-vector".into())
        }
        fn get_full_state(&self, _room_id: &str) -> Result<String, CrdtRuntimeError> {
            Ok("static-state".into())
        }
        fn set_awareness(&self, _room_id: &str, _state: &str) -> Result<bool, CrdtRuntimeError> {
            Ok(true)
        }
        fn destroy(&self, _room_id: &str) -> Result<bool, CrdtRuntimeError> {
            Ok(true)
        }
    }

    #[test]
    fn crdt_extension_supports_injected_runtime() {
        let ext = CrdtExtension::with_runtime(StaticRuntime);
        let result_json = ext
            .handle_message("crdtJoin", r#"{"roomId":"doc-1"}"#)
            .expect("join result");
        let result: CrdtJoinResult = serde_json::from_str(&result_json).expect("parse join result");
        assert_eq!(result.transport, "custom");
    }
}