1use futures_util::Sink;
2use futures_util::stream::{SplitSink, SplitStream};
3use rocket_ws::Message;
4use rocket_ws::stream::DuplexStream;
5use yrs_tokio::signaling::Message as SignalingMessage;
6use yrs_tokio::{
7 YrsExchange, YrsSink, YrsStream, impl_yrs_signal_stream, to_signaling_message, yrs_common_sink,
8};
9
/// Read half of a Rocket WebSocket connection (`SplitStream<DuplexStream>`), wrapped so it
/// can serve as the incoming stream of a yrs broadcast subscription.
///
/// The `YrsStream` derive from `yrs_tokio` generates this wrapper's behaviour. The test
/// module builds it with `YrsStream::from(..)` on the read half returned by
/// `DuplexStream::split()`; presumably the derive supplies that `From` impl (confirm in
/// `yrs_tokio`). Keep the single-field tuple shape: the derive relies on it.
#[derive(YrsStream)]
pub struct YrsStream(SplitStream<DuplexStream>);
/// Read half of a Rocket WebSocket connection used for the signaling protocol.
///
/// Its behaviour comes from `#[derive(YrsExchange)]` together with the
/// `impl_yrs_signal_stream!` invocation that follows, which converts incoming
/// `rocket_ws::Message`s into `yrs_tokio` signaling messages. Keep the single-field
/// tuple shape: the generated code presumably accesses field `.0` (verify in `yrs_tokio`).
#[derive(YrsExchange)]
pub struct YrsSignalStream(SplitStream<DuplexStream>);
14
// Generates the signaling-stream implementation for `YrsSignalStream`. Each incoming
// `rocket_ws::Message` (bound as `item`) is converted with `to_signaling_message!`.
// The `custom` arm maps a raw `Message::Frame` to `SignalingMessage::Binary` carrying
// the frame's payload (`frame.into_data()`). Other message variants are presumably
// handled by the macro's built-in arms (NOTE(review): confirm in yrs_tokio).
impl_yrs_signal_stream!(YrsSignalStream, item => to_signaling_message!(item, custom Message::Frame(frame) => SignalingMessage::Binary(frame.into_data())));
/// Write half of a Rocket WebSocket connection (`SplitSink<DuplexStream, Message>`),
/// wrapped so it can be handed to a yrs broadcast group as the outgoing sink.
///
/// The `YrsSink` derive from `yrs_tokio` generates this wrapper's behaviour. The test
/// module builds it with `YrsSink::from(..)` on the write half returned by
/// `DuplexStream::split()`; presumably the derive supplies that `From` impl (confirm in
/// `yrs_tokio`). Keep the single-field tuple shape: the derive relies on it.
#[derive(YrsSink)]
pub struct YrsSink(SplitSink<DuplexStream, Message>);
// Lets `YrsSink` accept `yrs_tokio` signaling messages (`SignalingMessage`).
//
// The impl body is intentionally empty. `Sink`'s required items (`Error`, `poll_ready`,
// `start_send`, `poll_flush`, `poll_close`) are not written here and are expected to be
// generated by the `#[yrs_common_sink]` attribute (NOTE(review): confirm in yrs_tokio).
#[yrs_common_sink]
impl Sink<SignalingMessage> for YrsSink {}
20
#[cfg(test)]
mod test {
    use crate::{YrsSink, YrsStream};
    use futures_util::{SinkExt, ready};
    use rocket::{State, get, routes};
    use rocket_ws::stream::DuplexStream;
    use rocket_ws::{Channel, WebSocket};
    use std::net::SocketAddr;
    use std::str::FromStr;
    use std::sync::Arc;
    use tokio::sync::Mutex;
    use tokio::task::JoinHandle;
    use yrs::updates::encoder::Encode;
    use yrs::{GetString, Text, Transact};
    use yrs_tokio::broadcast::BroadcastGroup;
    use yrs_tokio::yrs_common_test;

    // NOTE(review): several imports above (e.g. `SinkExt`, `ready`, `Encode`, the `yrs`
    // traits) are not referenced by the hand-written code in this module. Presumably
    // they are consumed by the tests that `#[yrs_common_test]` generates, so do not
    // remove them without checking that macro's expansion.

    /// WebSocket route at `/my-room`.
    ///
    /// Upgrades the request and passes the resulting duplex stream, plus the
    /// `Arc<BroadcastGroup>` registered with `Rocket::manage`, to [`peer`].
    #[get("/my-room")]
    fn ws_handler(ws: WebSocket, bcast: &State<Arc<BroadcastGroup>>) -> Channel<'_> {
        let bcast = bcast.inner();

        ws.channel(move |stream| {
            Box::pin(async move {
                peer(stream, bcast).await;

                Ok(())
            })
        })
    }

    /// Subscribes one WebSocket connection to `bcast` and waits for the subscription to
    /// complete. Prints to stdout on success and to stderr on error; never returns an error.
    async fn peer(stream: DuplexStream, bcast: &Arc<BroadcastGroup>) {
        use rocket::futures::StreamExt;

        // Split the socket into independent write/read halves and wrap each in this
        // crate's adapters.
        let (sink, stream) = stream.split();
        // `subscribe` is handed the sink as an `Arc<tokio::sync::Mutex<_>>`.
        let sink = Arc::new(Mutex::new(YrsSink::from(sink)));
        let stream = YrsStream::from(stream);

        let sub = bcast.subscribe(sink, stream);
        match sub.completed().await {
            Ok(_) => println!("broadcasting for channel finished successfully"),
            Err(e) => eprintln!("broadcasting for channel finished abruptly: {}", e),
        }
    }

    /// Spawns a Rocket server bound to `addr` that serves [`ws_handler`] and shares
    /// `bcast` as managed state.
    ///
    /// Returns the handle of the spawned server task. The task is only spawned here, so
    /// the server may not be listening yet when this returns.
    ///
    /// # Errors
    /// Returns an error if `addr` is not a valid socket address.
    ///
    /// # Panics
    /// The spawned task panics if Rocket fails to launch or run (for example, when the
    /// port is already in use). The panic surfaces through the returned `JoinHandle`.
    #[yrs_common_test]
    async fn start_server(
        addr: &str,
        bcast: Arc<BroadcastGroup>,
    ) -> Result<JoinHandle<()>, Box<dyn std::error::Error>> {
        let addr = SocketAddr::from_str(addr)?;

        let rocket_handle = tokio::spawn(async move {
            let launched = rocket::build()
                .configure(
                    rocket::config::Config::figment()
                        .merge(("address", addr.ip().to_string()))
                        .merge(("port", addr.port())),
                )
                // `bcast` is already owned by this task, so move it in; no clone needed.
                .manage(bcast)
                .mount("/", routes![ws_handler])
                .launch()
                .await;

            // Report launch failures explicitly, with the address, instead of silently
            // dropping the `Result`.
            if let Err(e) = launched {
                panic!("rocket server on {addr} failed: {e}");
            }
        });

        Ok(rocket_handle)
    }
}