1use std::sync::Arc;
7
8use serde::{Deserialize, Serialize};
9
10use crate::HostExtension;
11
12use super::crdt_runtime::{CrdtRuntime, CrdtRuntimeError, InMemoryCrdtRuntime};
13
/// Maximum accepted room id length, in bytes (compared against `str::len`).
/// Longer ids are rejected by `check_room_id` before reaching the runtime.
const MAX_ROOM_ID_LEN: usize = 256;
16
/// JavaScript injected into the page to expose `window.host.ext.crdt`.
///
/// The script does nothing when `window.host.ext` is missing. Otherwise it installs a
/// frozen object whose methods forward to `window.__hostCall('hostBridge', <method>, <params>)`.
/// The method names and camelCase param keys must stay in sync with the arms of
/// `CrdtExtension::handle_message` and the payload structs in this module.
/// `join` defaults `transport` to `'relay'`. `setAwareness` JSON-encodes `state`, so the
/// Rust side receives it as a string.
const HOST_EXT_CRDT_SCRIPT: &str = r#"
(function(){
 if (!window.host || !window.host.ext) return;
 window.host.ext.crdt = Object.freeze({
 join: function(roomId, opts) {
 return window.__hostCall('hostBridge', 'crdtJoin', {
 roomId: roomId,
 transport: (opts && opts.transport) || 'relay'
 });
 },
 applyUpdate: function(roomId, dataBase64) {
 return window.__hostCall('hostBridge', 'crdtApplyUpdate', {
 roomId: roomId,
 dataBase64: dataBase64
 });
 },
 getStateVector: function(roomId) {
 return window.__hostCall('hostBridge', 'crdtGetStateVector', {
 roomId: roomId
 });
 },
 getFullState: function(roomId) {
 return window.__hostCall('hostBridge', 'crdtGetFullState', {
 roomId: roomId
 });
 },
 setAwareness: function(roomId, state) {
 return window.__hostCall('hostBridge', 'crdtSetAwareness', {
 roomId: roomId,
 state: JSON.stringify(state)
 });
 },
 destroy: function(roomId) {
 return window.__hostCall('hostBridge', 'crdtDestroy', {
 roomId: roomId
 });
 }
 });
})();
"#;
57
58pub struct CrdtExtension {
59 runtime: Arc<dyn CrdtRuntime>,
60}
61
62impl Default for CrdtExtension {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68impl CrdtExtension {
69 pub fn new() -> Self {
70 Self::with_runtime(InMemoryCrdtRuntime::new())
71 }
72
73 pub fn with_runtime<R>(runtime: R) -> Self
74 where
75 R: CrdtRuntime,
76 {
77 Self::with_shared_runtime(Arc::new(runtime))
78 }
79
80 pub fn with_shared_runtime(runtime: Arc<dyn CrdtRuntime>) -> Self {
81 Self { runtime }
82 }
83}
84
85fn check_room_id(room_id: &str) -> Option<Option<String>> {
88 if room_id.is_empty() {
89 return Some(error_json(&CrdtRuntimeError::RoomIdRequired));
90 }
91 if room_id.len() > MAX_ROOM_ID_LEN {
92 return Some(None);
93 }
94 None
95}
96
97impl HostExtension for CrdtExtension {
98 fn namespace(&self) -> &str {
99 "crdt"
100 }
101
102 fn channel(&self) -> &str {
103 "hostBridge"
104 }
105
106 fn inject_script(&self) -> &str {
107 HOST_EXT_CRDT_SCRIPT
108 }
109
110 fn handle_message(&self, method: &str, params: &str) -> Option<String> {
111 match method {
112 "crdtJoin" => {
113 let req: JoinPayload = serde_json::from_str(params).ok()?;
114 if let Some(early) = check_room_id(&req.room_id) {
115 return early;
116 }
117 match self.runtime.join(&req.room_id, &req.transport) {
118 Ok(result) => serde_json::to_string(&result).ok(),
119 Err(e) => error_json(&e),
120 }
121 }
122 "crdtApplyUpdate" => {
123 let req: UpdatePayload = serde_json::from_str(params).ok()?;
124 if let Some(early) = check_room_id(&req.room_id) {
125 return early;
126 }
127 match self.runtime.apply_update(&req.room_id, &req.data_base64) {
128 Ok(v) => serde_json::to_string(&v).ok(),
129 Err(e) => error_json(&e),
130 }
131 }
132 "crdtGetStateVector" => {
133 let req: RoomIdPayload = serde_json::from_str(params).ok()?;
134 if let Some(early) = check_room_id(&req.room_id) {
135 return early;
136 }
137 match self.runtime.get_state_vector(&req.room_id) {
138 Ok(v) => serde_json::to_string(&v).ok(),
139 Err(e) => error_json(&e),
140 }
141 }
142 "crdtGetFullState" => {
143 let req: RoomIdPayload = serde_json::from_str(params).ok()?;
144 if let Some(early) = check_room_id(&req.room_id) {
145 return early;
146 }
147 match self.runtime.get_full_state(&req.room_id) {
148 Ok(v) => serde_json::to_string(&v).ok(),
149 Err(e) => error_json(&e),
150 }
151 }
152 "crdtSetAwareness" => {
153 let req: AwarenessPayload = serde_json::from_str(params).ok()?;
154 if let Some(early) = check_room_id(&req.room_id) {
155 return early;
156 }
157 match self.runtime.set_awareness(&req.room_id, &req.state) {
158 Ok(v) => serde_json::to_string(&v).ok(),
159 Err(e) => error_json(&e),
160 }
161 }
162 "crdtDestroy" => {
163 let req: RoomIdPayload = serde_json::from_str(params).ok()?;
164 if let Some(early) = check_room_id(&req.room_id) {
165 return early;
166 }
167 match self.runtime.destroy(&req.room_id) {
168 Ok(v) => serde_json::to_string(&v).ok(),
169 Err(e) => error_json(&e),
170 }
171 }
172 _ => None,
173 }
174 }
175
176 fn drain_events(&self) -> Vec<crate::HostPushEvent> {
177 self.runtime.drain_events()
178 }
179}
180
/// Wire shape of every error reply: `{"error": "<Display text of the CrdtRuntimeError>"}`.
#[derive(Debug, Serialize)]
struct CrdtErrorEnvelope {
    // Serialized under the key "error"; the JS side and tests read `parsed["error"]`.
    error: String,
}
190
191fn error_json(err: &CrdtRuntimeError) -> Option<String> {
192 serde_json::to_string(&CrdtErrorEnvelope {
193 error: err.to_string(),
194 })
195 .ok()
196}
197
/// Params for `crdtJoin`, deserialized from camelCase JSON (`roomId`, `transport`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct JoinPayload {
    room_id: String,
    // Falls back to "relay" (see `default_transport`) when the field is absent.
    #[serde(default = "default_transport")]
    transport: String,
}
205
/// Serde default for `JoinPayload::transport` when `crdtJoin` params omit it.
///
/// Mirrors the `'relay'` fallback in the injected script.
fn default_transport() -> String {
    String::from("relay")
}
209
/// Params for `crdtApplyUpdate`, deserialized from camelCase JSON (`roomId`, `dataBase64`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct UpdatePayload {
    room_id: String,
    // Update bytes as a string, passed to the runtime unchanged;
    // presumably base64 per the field name — decoding happens in the runtime, if anywhere.
    data_base64: String,
}
216
/// Params carrying only a room id (`{"roomId": ...}`).
///
/// Used by `crdtGetStateVector`, `crdtGetFullState` and `crdtDestroy`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct RoomIdPayload {
    room_id: String,
}
222
/// Params for `crdtSetAwareness`, deserialized from camelCase JSON (`roomId`, `state`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct AwarenessPayload {
    room_id: String,
    // Already a JSON string: the injected script calls `JSON.stringify(state)` before sending.
    state: String,
}
229
#[cfg(test)]
mod tests {
    // Unit tests for `CrdtExtension`: script contents, per-method dispatch against the
    // in-memory runtime, error envelopes, room-id validation, and runtime injection.
    use super::*;
    use crate::executor_contract::CrdtJoinResult;

    #[test]
    fn crdt_extension_basics() {
        let ext = CrdtExtension::new();
        assert_eq!(ext.namespace(), "crdt");
        assert_eq!(ext.channel(), "hostBridge");
        assert!(ext.inject_script().contains("window.host.ext.crdt"));
        assert!(ext.inject_script().contains("crdtJoin"));
        assert!(ext.inject_script().contains("crdtApplyUpdate"));
        assert!(ext.inject_script().contains("crdtGetStateVector"));
        assert!(ext.inject_script().contains("crdtGetFullState"));
        assert!(ext.inject_script().contains("crdtSetAwareness"));
        assert!(ext.inject_script().contains("crdtDestroy"));
    }

    #[test]
    fn crdt_join_returns_result() {
        let ext = CrdtExtension::new();
        let result_json = ext
            .handle_message("crdtJoin", r#"{"roomId":"doc-1"}"#)
            .expect("join result");
        let result: CrdtJoinResult = serde_json::from_str(&result_json).expect("parse join result");
        assert_eq!(result.room_id, "doc-1");
        // Omitted `transport` falls back to `default_transport()`.
        assert_eq!(result.transport, "relay");
    }

    #[test]
    fn crdt_join_with_transport() {
        let ext = CrdtExtension::new();
        let result_json = ext
            .handle_message("crdtJoin", r#"{"roomId":"doc-1","transport":"p2p"}"#)
            .expect("join result");
        let result: CrdtJoinResult = serde_json::from_str(&result_json).expect("parse join result");
        assert_eq!(result.transport, "p2p");
    }

    #[test]
    fn crdt_apply_update_succeeds() {
        let ext = CrdtExtension::new();
        ext.handle_message("crdtJoin", r#"{"roomId":"doc-1"}"#);
        let result = ext
            .handle_message(
                "crdtApplyUpdate",
                r#"{"roomId":"doc-1","dataBase64":"AQID"}"#,
            )
            .expect("apply result");
        assert_eq!(result, "true");
    }

    #[test]
    fn crdt_get_state_vector_succeeds() {
        let ext = CrdtExtension::new();
        ext.handle_message("crdtJoin", r#"{"roomId":"doc-1"}"#);
        let result = ext
            .handle_message("crdtGetStateVector", r#"{"roomId":"doc-1"}"#)
            .expect("state vector");
        assert_eq!(result, r#""0""#);
    }

    #[test]
    fn crdt_get_full_state_succeeds() {
        let ext = CrdtExtension::new();
        ext.handle_message("crdtJoin", r#"{"roomId":"doc-1"}"#);
        ext.handle_message(
            "crdtApplyUpdate",
            r#"{"roomId":"doc-1","dataBase64":"AQID"}"#,
        );
        let result = ext
            .handle_message("crdtGetFullState", r#"{"roomId":"doc-1"}"#)
            .expect("full state");
        assert_eq!(result, r#""AQID""#);
    }

    #[test]
    fn crdt_set_awareness_succeeds() {
        let ext = CrdtExtension::new();
        ext.handle_message("crdtJoin", r#"{"roomId":"doc-1"}"#);
        let result = ext
            .handle_message(
                "crdtSetAwareness",
                r#"{"roomId":"doc-1","state":"{\"cursor\":5}"}"#,
            )
            .expect("awareness result");
        assert_eq!(result, "true");
    }

    #[test]
    fn crdt_destroy_succeeds() {
        let ext = CrdtExtension::new();
        ext.handle_message("crdtJoin", r#"{"roomId":"doc-1"}"#);
        let result = ext
            .handle_message("crdtDestroy", r#"{"roomId":"doc-1"}"#)
            .expect("destroy result");
        assert_eq!(result, "true");
    }

    #[test]
    fn crdt_errors_are_json_safe() {
        // Every room-scoped method on an unknown room must reply with a parseable
        // `{"error": ...}` envelope rather than `None`.
        let ext = CrdtExtension::new();
        for (method, params) in [
            (
                "crdtApplyUpdate",
                r#"{"roomId":"missing","dataBase64":"AA"}"#,
            ),
            ("crdtGetStateVector", r#"{"roomId":"missing"}"#),
            ("crdtGetFullState", r#"{"roomId":"missing"}"#),
            ("crdtSetAwareness", r#"{"roomId":"missing","state":"{}"}"#),
            ("crdtDestroy", r#"{"roomId":"missing"}"#),
        ] {
            let result = ext
                .handle_message(method, params)
                .unwrap_or_else(|| panic!("{method} should return Some"));
            let parsed: serde_json::Value =
                serde_json::from_str(&result).unwrap_or_else(|_| panic!("{method} not valid JSON"));
            assert_eq!(
                parsed["error"].as_str(),
                Some("crdt room not found"),
                "{method} error mismatch"
            );
        }
    }

    #[test]
    fn crdt_unknown_method_returns_none() {
        let ext = CrdtExtension::new();
        assert!(ext.handle_message("crdtUnknown", "{}").is_none());
    }

    #[test]
    fn crdt_malformed_params_returns_none() {
        let ext = CrdtExtension::new();
        assert!(ext.handle_message("crdtJoin", "not json").is_none());
    }

    #[test]
    fn crdt_rejects_empty_room_id() {
        let ext = CrdtExtension::new();
        for (method, params) in [
            ("crdtJoin", r#"{"roomId":""}"#),
            ("crdtApplyUpdate", r#"{"roomId":"","dataBase64":"AA"}"#),
            ("crdtGetStateVector", r#"{"roomId":""}"#),
            ("crdtGetFullState", r#"{"roomId":""}"#),
            ("crdtSetAwareness", r#"{"roomId":"","state":"{}"}"#),
            ("crdtDestroy", r#"{"roomId":""}"#),
        ] {
            let result = ext
                .handle_message(method, params)
                .unwrap_or_else(|| panic!("{method} should return Some for empty roomId"));
            let parsed: serde_json::Value =
                serde_json::from_str(&result).unwrap_or_else(|_| panic!("{method} not valid JSON"));
            assert_eq!(
                parsed["error"].as_str(),
                Some("roomId is required"),
                "{method} empty roomId error mismatch"
            );
        }
    }

    #[test]
    fn crdt_rejects_overlength_room_id() {
        let ext = CrdtExtension::new();
        let long_id = "x".repeat(MAX_ROOM_ID_LEN + 1);
        let params = format!(r#"{{"roomId":"{}"}}"#, long_id);
        assert!(
            ext.handle_message("crdtJoin", &params).is_none(),
            "overlength roomId should return None"
        );
    }

    #[test]
    fn crdt_accepts_room_id_at_max_length() {
        // Boundary: exactly MAX_ROOM_ID_LEN bytes must still reach the runtime.
        // StaticRuntime accepts any id, so only `check_room_id` is under test here.
        let ext = CrdtExtension::with_runtime(StaticRuntime);
        let max_id = "x".repeat(MAX_ROOM_ID_LEN);
        let params = format!(r#"{{"roomId":"{}"}}"#, max_id);
        let result_json = ext
            .handle_message("crdtJoin", &params)
            .expect("roomId at the length limit should be accepted");
        let result: CrdtJoinResult = serde_json::from_str(&result_json).expect("parse join result");
        assert_eq!(result.room_id, max_id);
    }

    /// Runtime stub whose only interesting behavior is emitting one push event.
    struct EventfulRuntime;

    impl CrdtRuntime for EventfulRuntime {
        fn join(
            &self,
            room_id: &str,
            _transport: &str,
        ) -> Result<CrdtJoinResult, CrdtRuntimeError> {
            Ok(CrdtJoinResult {
                room_id: room_id.to_string(),
                transport: "relay".to_string(),
            })
        }
        fn apply_update(&self, _room_id: &str, _data: &str) -> Result<bool, CrdtRuntimeError> {
            Ok(true)
        }
        fn get_state_vector(&self, _room_id: &str) -> Result<String, CrdtRuntimeError> {
            Ok("0".into())
        }
        fn get_full_state(&self, _room_id: &str) -> Result<String, CrdtRuntimeError> {
            Ok("".into())
        }
        fn set_awareness(&self, _room_id: &str, _state: &str) -> Result<bool, CrdtRuntimeError> {
            Ok(true)
        }
        fn destroy(&self, _room_id: &str) -> Result<bool, CrdtRuntimeError> {
            Ok(true)
        }
        fn drain_events(&self) -> Vec<crate::HostPushEvent> {
            vec![crate::HostPushEvent {
                event: "crdtRemoteUpdate".into(),
                payload_json: r#"{"roomId":"doc-1","updateBase64":"AQID"}"#.into(),
            }]
        }
    }

    #[test]
    fn crdt_drain_events_delegates_to_runtime() {
        let ext = CrdtExtension::with_runtime(EventfulRuntime);
        let events = ext.drain_events();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].event, "crdtRemoteUpdate");
        assert!(events[0].payload_json.contains("doc-1"));
    }

    /// Runtime stub that accepts every call and reports a distinctive "custom" transport.
    struct StaticRuntime;

    impl CrdtRuntime for StaticRuntime {
        fn join(
            &self,
            room_id: &str,
            _transport: &str,
        ) -> Result<CrdtJoinResult, CrdtRuntimeError> {
            Ok(CrdtJoinResult {
                room_id: room_id.to_string(),
                transport: "custom".to_string(),
            })
        }
        fn apply_update(&self, _room_id: &str, _data: &str) -> Result<bool, CrdtRuntimeError> {
            Ok(true)
        }
        fn get_state_vector(&self, _room_id: &str) -> Result<String, CrdtRuntimeError> {
            Ok("static-vector".into())
        }
        fn get_full_state(&self, _room_id: &str) -> Result<String, CrdtRuntimeError> {
            Ok("static-state".into())
        }
        fn set_awareness(&self, _room_id: &str, _state: &str) -> Result<bool, CrdtRuntimeError> {
            Ok(true)
        }
        fn destroy(&self, _room_id: &str) -> Result<bool, CrdtRuntimeError> {
            Ok(true)
        }
    }

    #[test]
    fn crdt_extension_supports_injected_runtime() {
        let ext = CrdtExtension::with_runtime(StaticRuntime);
        let result_json = ext
            .handle_message("crdtJoin", r#"{"roomId":"doc-1"}"#)
            .expect("join result");
        let result: CrdtJoinResult = serde_json::from_str(&result_json).expect("parse join result");
        assert_eq!(result.transport, "custom");
    }
}