1#[cfg(not(target_arch = "wasm32"))]
7use n0_future::stream::Boxed;
8#[cfg(target_arch = "wasm32")]
9use n0_future::stream::BoxedLocal as Boxed;
10use serde::{Deserialize, Serialize};
11use std::error::Error;
12use std::future::Future;
13use std::marker::PhantomData;
14use url::Url;
15
16use crate::cowstr::ToCowStr;
17use crate::error::DecodeError;
18use crate::stream::StreamError;
19use crate::websocket::{WebSocketClient, WebSocketConnection, WsSink, WsStream};
20use crate::{CowStr, Data, IntoStatic, RawData, WsMessage};
21
/// Wire encoding used for the frames of a subscription stream.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MessageEncoding {
    /// Frames are JSON documents (sent as text, or binary — possibly zstd-compressed).
    Json,
    /// Frames are DAG-CBOR encoded binary payloads.
    DagCbor,
}
30
31pub trait SubscriptionResp {
39 const NSID: &'static str;
41
42 const ENCODING: MessageEncoding;
44
45 type Message<'de>: Deserialize<'de> + IntoStatic;
47
48 type Error<'de>: Error + Deserialize<'de> + IntoStatic;
50
51 fn decode_message<'de>(bytes: &'de [u8]) -> Result<Self::Message<'de>, DecodeError> {
57 match Self::ENCODING {
58 MessageEncoding::Json => serde_json::from_slice(bytes).map_err(DecodeError::from),
59 MessageEncoding::DagCbor => {
60 serde_ipld_dagcbor::from_slice(bytes).map_err(DecodeError::from)
61 }
62 }
63 }
64}
65
66pub trait XrpcSubscription: Serialize {
73 const NSID: &'static str;
75
76 const ENCODING: MessageEncoding;
78
79 const CUSTOM_PATH: Option<&'static str> = None;
82
83 type Stream: SubscriptionResp;
85
86 fn query_params(&self) -> Vec<(String, String)> {
90 serde_html_form::to_string(self)
92 .ok()
93 .map(|s| {
94 s.split('&')
95 .filter_map(|pair| {
96 let mut parts = pair.splitn(2, '=');
97 Some((parts.next()?.to_string(), parts.next()?.to_string()))
98 })
99 .collect()
100 })
101 .unwrap_or_default()
102 }
103}
104
105#[derive(Debug, serde::Deserialize)]
110pub struct EventHeader {
111 pub op: i64,
113 pub t: smol_str::SmolStr,
115}
116
117pub fn parse_event_header<'a>(bytes: &'a [u8]) -> Result<(EventHeader, &'a [u8]), DecodeError> {
122 let mut cursor = std::io::Cursor::new(bytes);
123 let header: EventHeader = ciborium::de::from_reader(&mut cursor)?;
124 let position = cursor.position() as usize;
125 drop(cursor); Ok((header, &bytes[position..]))
128}
129
130pub fn decode_json_msg<S: SubscriptionResp>(
132 msg_result: Result<crate::websocket::WsMessage, StreamError>,
133) -> Option<Result<StreamMessage<'static, S>, StreamError>>
134where
135 for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>,
136{
137 use crate::websocket::WsMessage;
138
139 match msg_result {
140 Ok(WsMessage::Text(text)) => Some(
141 S::decode_message(text.as_ref())
142 .map(|v| v.into_static())
143 .map_err(StreamError::decode),
144 ),
145 Ok(WsMessage::Binary(bytes)) => {
146 #[cfg(feature = "zstd")]
147 {
148 match decompress_zstd(&bytes) {
150 Ok(decompressed) => Some(
151 S::decode_message(&decompressed)
152 .map(|v| v.into_static())
153 .map_err(StreamError::decode),
154 ),
155 Err(_) => {
156 Some(
158 S::decode_message(&bytes)
159 .map(|v| v.into_static())
160 .map_err(StreamError::decode),
161 )
162 }
163 }
164 }
165 #[cfg(not(feature = "zstd"))]
166 {
167 Some(
168 S::decode_message(&bytes)
169 .map(|v| v.into_static())
170 .map_err(StreamError::decode),
171 )
172 }
173 }
174 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
175 Err(e) => Some(Err(e)),
176 }
177}
178
/// Decompresses a zstd-compressed frame.
///
/// The frame is first decoded without a dictionary. If that fails, it is decoded
/// again using the bundled dictionary asset.
///
/// # Errors
/// Returns the I/O error from the dictionary-based attempt if both attempts fail.
#[cfg(feature = "zstd")]
fn decompress_zstd(bytes: &[u8]) -> Result<Vec<u8>, std::io::Error> {
    use zstd::stream::decode_all;

    // `include_bytes!` already yields a `&'static [u8]` embedded in the binary.
    // Referencing it directly avoids copying the whole dictionary onto the heap
    // (which the previous `OnceLock<Vec<u8>>` did on first use).
    const DICTIONARY: &[u8] = include_bytes!("../../zstd_dictionary");

    decode_all(std::io::Cursor::new(bytes)).or_else(|_| {
        let mut decoder =
            zstd::Decoder::with_dictionary(std::io::Cursor::new(bytes), DICTIONARY)?;
        let mut result = Vec::new();
        std::io::Read::read_to_end(&mut decoder, &mut result)?;
        Ok(result)
    })
}
196
197pub fn decode_cbor_msg<S: SubscriptionResp>(
199 msg_result: Result<crate::websocket::WsMessage, StreamError>,
200) -> Option<Result<StreamMessage<'static, S>, StreamError>>
201where
202 for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>,
203{
204 use crate::websocket::WsMessage;
205
206 match msg_result {
207 Ok(WsMessage::Binary(bytes)) => Some(
208 S::decode_message(&bytes)
209 .map(|v| v.into_static())
210 .map_err(StreamError::decode),
211 ),
212 Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format(
213 "expected binary frame for CBOR, got text",
214 ))),
215 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
216 Err(e) => Some(Err(e)),
217 }
218}
219
220pub trait SubscriptionControlMessage: Serialize {
230 type Subscription: XrpcSubscription;
232
233 fn encode(&self) -> Result<WsMessage, StreamError> {
237 Ok(WsMessage::from(
238 serde_json::to_string(&self).map_err(StreamError::encode)?,
239 ))
240 }
241
242 fn decode<'de>(frame: &'de [u8]) -> Result<Self, StreamError>
244 where
245 Self: Deserialize<'de>,
246 {
247 Ok(serde_json::from_slice(frame).map_err(StreamError::decode)?)
248 }
249}
250
251pub struct SubscriptionController<S: SubscriptionControlMessage> {
253 controller: WsSink,
254 _marker: PhantomData<fn() -> S>,
255}
256
257impl<S: SubscriptionControlMessage> SubscriptionController<S> {
258 pub fn new(controller: WsSink) -> Self {
260 Self {
261 controller,
262 _marker: PhantomData,
263 }
264 }
265
266 pub async fn configure(&mut self, params: &S) -> Result<(), StreamError> {
268 let message = params.encode()?;
269
270 n0_future::SinkExt::send(self.controller.get_mut(), message)
271 .await
272 .map_err(StreamError::transport)
273 }
274}
275
276pub struct SubscriptionStream<S: SubscriptionResp> {
281 _marker: PhantomData<fn() -> S>,
282 connection: WebSocketConnection,
283}
284
285impl<S: SubscriptionResp> SubscriptionStream<S> {
286 pub fn new(connection: WebSocketConnection) -> Self {
288 Self {
289 _marker: PhantomData,
290 connection,
291 }
292 }
293
294 pub fn connection(&self) -> &WebSocketConnection {
296 &self.connection
297 }
298
299 pub fn connection_mut(&mut self) -> &mut WebSocketConnection {
301 &mut self.connection
302 }
303
304 pub fn into_stream(
309 self,
310 ) -> (
311 WsSink,
312 Boxed<Result<StreamMessage<'static, S>, StreamError>>,
313 )
314 where
315 for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>,
316 {
317 use n0_future::StreamExt as _;
318
319 let (tx, rx) = self.connection.split();
320
321 #[cfg(not(target_arch = "wasm32"))]
322 let stream = match S::ENCODING {
323 MessageEncoding::Json => rx
324 .into_inner()
325 .filter_map(|msg| decode_json_msg::<S>(msg))
326 .boxed(),
327 MessageEncoding::DagCbor => rx
328 .into_inner()
329 .filter_map(|msg| decode_cbor_msg::<S>(msg))
330 .boxed(),
331 };
332
333 #[cfg(target_arch = "wasm32")]
334 let stream = match S::ENCODING {
335 MessageEncoding::Json => rx
336 .into_inner()
337 .filter_map(|msg| decode_json_msg::<S>(msg))
338 .boxed_local(),
339 MessageEncoding::DagCbor => rx
340 .into_inner()
341 .filter_map(|msg| decode_cbor_msg::<S>(msg))
342 .boxed_local(),
343 };
344
345 (tx, stream)
346 }
347
348 pub fn into_raw_data_stream(self) -> (WsSink, Boxed<Result<RawData<'static>, StreamError>>) {
350 use n0_future::StreamExt as _;
351
352 let (tx, rx) = self.connection.split();
353
354 fn parse_msg<'a>(bytes: &'a [u8]) -> Result<RawData<'a>, serde_json::Error> {
355 serde_json::from_slice(bytes)
356 }
357 fn parse_cbor<'a>(
358 bytes: &'a [u8],
359 ) -> Result<RawData<'a>, serde_ipld_dagcbor::DecodeError<std::convert::Infallible>>
360 {
361 serde_ipld_dagcbor::from_slice(bytes)
362 }
363
364 #[cfg(not(target_arch = "wasm32"))]
365 let stream = match S::ENCODING {
366 MessageEncoding::Json => rx
367 .into_inner()
368 .filter_map(|msg_result| match msg_result {
369 Ok(WsMessage::Text(text)) => Some(
370 parse_msg(text.as_ref())
371 .map(|v| v.into_static())
372 .map_err(StreamError::decode),
373 ),
374 Ok(WsMessage::Binary(bytes)) => {
375 #[cfg(feature = "zstd")]
376 {
377 match decompress_zstd(&bytes) {
378 Ok(decompressed) => Some(
379 parse_msg(&decompressed)
380 .map(|v| v.into_static())
381 .map_err(StreamError::decode),
382 ),
383 Err(_) => Some(
384 parse_msg(&bytes)
385 .map(|v| v.into_static())
386 .map_err(StreamError::decode),
387 ),
388 }
389 }
390 #[cfg(not(feature = "zstd"))]
391 {
392 Some(
393 parse_msg(&bytes)
394 .map(|v| v.into_static())
395 .map_err(StreamError::decode),
396 )
397 }
398 }
399 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
400 Err(e) => Some(Err(e)),
401 })
402 .boxed(),
403 MessageEncoding::DagCbor => rx
404 .into_inner()
405 .filter_map(|msg_result| match msg_result {
406 Ok(WsMessage::Binary(bytes)) => Some(
407 parse_cbor(&bytes)
408 .map(|v| v.into_static())
409 .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))),
410 ),
411 Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format(
412 "expected binary frame for CBOR, got text",
413 ))),
414 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
415 Err(e) => Some(Err(e)),
416 })
417 .boxed(),
418 };
419
420 #[cfg(target_arch = "wasm32")]
421 let stream = match S::ENCODING {
422 MessageEncoding::Json => rx
423 .into_inner()
424 .filter_map(|msg_result| match msg_result {
425 Ok(WsMessage::Text(text)) => Some(
426 parse_msg(text.as_ref())
427 .map(|v| v.into_static())
428 .map_err(StreamError::decode),
429 ),
430 Ok(WsMessage::Binary(bytes)) => {
431 #[cfg(feature = "zstd")]
432 {
433 match decompress_zstd(&bytes) {
434 Ok(decompressed) => Some(
435 parse_msg(&decompressed)
436 .map(|v| v.into_static())
437 .map_err(StreamError::decode),
438 ),
439 Err(_) => Some(
440 parse_msg(&bytes)
441 .map(|v| v.into_static())
442 .map_err(StreamError::decode),
443 ),
444 }
445 }
446 #[cfg(not(feature = "zstd"))]
447 {
448 Some(
449 parse_msg(&bytes)
450 .map(|v| v.into_static())
451 .map_err(StreamError::decode),
452 )
453 }
454 }
455 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
456 Err(e) => Some(Err(e)),
457 })
458 .boxed_local(),
459 MessageEncoding::DagCbor => rx
460 .into_inner()
461 .filter_map(|msg_result| match msg_result {
462 Ok(WsMessage::Binary(bytes)) => Some(
463 parse_cbor(&bytes)
464 .map(|v| v.into_static())
465 .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))),
466 ),
467 Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format(
468 "expected binary frame for CBOR, got text",
469 ))),
470 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
471 Err(e) => Some(Err(e)),
472 })
473 .boxed_local(),
474 };
475
476 (tx, stream)
477 }
478
479 pub fn into_data_stream(self) -> (WsSink, Boxed<Result<Data<'static>, StreamError>>) {
481 use n0_future::StreamExt as _;
482
483 let (tx, rx) = self.connection.split();
484
485 fn parse_msg<'a>(bytes: &'a [u8]) -> Result<Data<'a>, serde_json::Error> {
486 serde_json::from_slice(bytes)
487 }
488 fn parse_cbor<'a>(
489 bytes: &'a [u8],
490 ) -> Result<Data<'a>, serde_ipld_dagcbor::DecodeError<std::convert::Infallible>> {
491 serde_ipld_dagcbor::from_slice(bytes)
492 }
493
494 #[cfg(not(target_arch = "wasm32"))]
495 let stream = match S::ENCODING {
496 MessageEncoding::Json => rx
497 .into_inner()
498 .filter_map(|msg_result| match msg_result {
499 Ok(WsMessage::Text(text)) => Some(
500 parse_msg(text.as_ref())
501 .map(|v| v.into_static())
502 .map_err(StreamError::decode),
503 ),
504 Ok(WsMessage::Binary(bytes)) => {
505 #[cfg(feature = "zstd")]
506 {
507 match decompress_zstd(&bytes) {
508 Ok(decompressed) => Some(
509 parse_msg(&decompressed)
510 .map(|v| v.into_static())
511 .map_err(StreamError::decode),
512 ),
513 Err(_) => Some(
514 parse_msg(&bytes)
515 .map(|v| v.into_static())
516 .map_err(StreamError::decode),
517 ),
518 }
519 }
520 #[cfg(not(feature = "zstd"))]
521 {
522 Some(
523 parse_msg(&bytes)
524 .map(|v| v.into_static())
525 .map_err(StreamError::decode),
526 )
527 }
528 }
529 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
530 Err(e) => Some(Err(e)),
531 })
532 .boxed(),
533 MessageEncoding::DagCbor => rx
534 .into_inner()
535 .filter_map(|msg_result| match msg_result {
536 Ok(WsMessage::Binary(bytes)) => Some(
537 parse_cbor(&bytes)
538 .map(|v| v.into_static())
539 .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))),
540 ),
541 Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format(
542 "expected binary frame for CBOR, got text",
543 ))),
544 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
545 Err(e) => Some(Err(e)),
546 })
547 .boxed(),
548 };
549
550 #[cfg(target_arch = "wasm32")]
551 let stream = match S::ENCODING {
552 MessageEncoding::Json => rx
553 .into_inner()
554 .filter_map(|msg_result| match msg_result {
555 Ok(WsMessage::Text(text)) => Some(
556 parse_msg(text.as_ref())
557 .map(|v| v.into_static())
558 .map_err(StreamError::decode),
559 ),
560 Ok(WsMessage::Binary(bytes)) => {
561 #[cfg(feature = "zstd")]
562 {
563 match decompress_zstd(&bytes) {
564 Ok(decompressed) => Some(
565 parse_msg(&decompressed)
566 .map(|v| v.into_static())
567 .map_err(StreamError::decode),
568 ),
569 Err(_) => Some(
570 parse_msg(&bytes)
571 .map(|v| v.into_static())
572 .map_err(StreamError::decode),
573 ),
574 }
575 }
576 #[cfg(not(feature = "zstd"))]
577 {
578 Some(
579 parse_msg(&bytes)
580 .map(|v| v.into_static())
581 .map_err(StreamError::decode),
582 )
583 }
584 }
585 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
586 Err(e) => Some(Err(e)),
587 })
588 .boxed_local(),
589 MessageEncoding::DagCbor => rx
590 .into_inner()
591 .filter_map(|msg_result| match msg_result {
592 Ok(WsMessage::Binary(bytes)) => Some(
593 parse_cbor(&bytes)
594 .map(|v| v.into_static())
595 .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))),
596 ),
597 Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format(
598 "expected binary frame for CBOR, got text",
599 ))),
600 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
601 Err(e) => Some(Err(e)),
602 })
603 .boxed_local(),
604 };
605
606 (tx, stream)
607 }
608
609 pub fn into_connection(self) -> WebSocketConnection {
611 self.connection
612 }
613
614 pub fn tee(&mut self) -> Boxed<Result<StreamMessage<'static, S>, StreamError>>
620 where
621 for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>,
622 {
623 use n0_future::StreamExt as _;
624
625 let rx = self.connection.receiver_mut();
626 let (raw_rx, typed_rx_source) =
627 std::mem::replace(rx, WsStream::new(n0_future::stream::empty())).tee();
628
629 *rx = raw_rx;
631
632 #[cfg(not(target_arch = "wasm32"))]
633 let stream = match S::ENCODING {
634 MessageEncoding::Json => typed_rx_source
635 .into_inner()
636 .filter_map(|msg| decode_json_msg::<S>(msg))
637 .boxed(),
638 MessageEncoding::DagCbor => typed_rx_source
639 .into_inner()
640 .filter_map(|msg| decode_cbor_msg::<S>(msg))
641 .boxed(),
642 };
643
644 #[cfg(target_arch = "wasm32")]
645 let stream = match S::ENCODING {
646 MessageEncoding::Json => typed_rx_source
647 .into_inner()
648 .filter_map(|msg| decode_json_msg::<S>(msg))
649 .boxed_local(),
650 MessageEncoding::DagCbor => typed_rx_source
651 .into_inner()
652 .filter_map(|msg| decode_cbor_msg::<S>(msg))
653 .boxed_local(),
654 };
655 stream
656 }
657}
658
659type StreamMessage<'a, R> = <R as SubscriptionResp>::Message<'a>;
660
661pub trait SubscriptionEndpoint {
669 const PATH: &'static str;
671
672 const ENCODING: MessageEncoding;
674
675 type Params<'de>: XrpcSubscription + Deserialize<'de> + IntoStatic;
677
678 type Stream: SubscriptionResp;
680}
681
682#[derive(Debug, Default, Clone)]
684pub struct SubscriptionOptions<'a> {
685 pub headers: Vec<(CowStr<'a>, CowStr<'a>)>,
687}
688
689impl IntoStatic for SubscriptionOptions<'_> {
690 type Output = SubscriptionOptions<'static>;
691
692 fn into_static(self) -> Self::Output {
693 SubscriptionOptions {
694 headers: self
695 .headers
696 .into_iter()
697 .map(|(k, v)| (k.into_static(), v.into_static()))
698 .collect(),
699 }
700 }
701}
702
703pub trait SubscriptionExt: WebSocketClient {
707 fn subscription<'a>(&'a self, base: Url) -> SubscriptionCall<'a, Self>
709 where
710 Self: Sized,
711 {
712 SubscriptionCall {
713 client: self,
714 base,
715 opts: SubscriptionOptions::default(),
716 }
717 }
718}
719
720impl<T: WebSocketClient> SubscriptionExt for T {}
721
722pub struct SubscriptionCall<'a, C: WebSocketClient> {
726 pub(crate) client: &'a C,
727 pub(crate) base: Url,
728 pub(crate) opts: SubscriptionOptions<'a>,
729}
730
731impl<'a, C: WebSocketClient> SubscriptionCall<'a, C> {
732 pub fn header(mut self, name: impl Into<CowStr<'a>>, value: impl Into<CowStr<'a>>) -> Self {
734 self.opts.headers.push((name.into(), value.into()));
735 self
736 }
737
738 pub fn with_options(mut self, opts: SubscriptionOptions<'a>) -> Self {
740 self.opts = opts;
741 self
742 }
743
744 pub async fn subscribe<Sub>(
750 self,
751 params: &Sub,
752 ) -> Result<SubscriptionStream<Sub::Stream>, C::Error>
753 where
754 Sub: XrpcSubscription,
755 {
756 let mut url = self.base.clone();
757
758 let mut path = url.path().trim_end_matches('/').to_owned();
760 if let Some(custom_path) = Sub::CUSTOM_PATH {
761 path.push_str(custom_path);
762 } else {
763 path.push_str("/xrpc/");
764 path.push_str(Sub::NSID);
765 }
766 url.set_path(&path);
767
768 let query_params = params.query_params();
769 if !query_params.is_empty() {
770 let qs = query_params
771 .iter()
772 .map(|(k, v)| format!("{}={}", k, v))
773 .collect::<Vec<_>>()
774 .join("&");
775 url.set_query(Some(&qs));
776 } else {
777 url.set_query(None);
778 }
779
780 let connection = self
781 .client
782 .connect_with_headers(url, self.opts.headers)
783 .await?;
784
785 Ok(SubscriptionStream::new(connection))
786 }
787}
788
789#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
794pub trait SubscriptionClient: WebSocketClient {
795 fn base_uri(&self) -> impl Future<Output = CowStr<'static>>;
797
798 fn subscription_opts(&self) -> impl Future<Output = SubscriptionOptions<'_>> {
800 async { SubscriptionOptions::default() }
801 }
802
803 #[cfg(not(target_arch = "wasm32"))]
805 fn subscribe<Sub>(
806 &self,
807 params: &Sub,
808 ) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>>
809 where
810 Sub: XrpcSubscription + Send + Sync,
811 Self: Sync;
812
813 #[cfg(target_arch = "wasm32")]
815 fn subscribe<Sub>(
816 &self,
817 params: &Sub,
818 ) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>>
819 where
820 Sub: XrpcSubscription + Send + Sync;
821
822 #[cfg(not(target_arch = "wasm32"))]
824 fn subscribe_with_opts<Sub>(
825 &self,
826 params: &Sub,
827 opts: SubscriptionOptions<'_>,
828 ) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>>
829 where
830 Sub: XrpcSubscription + Send + Sync,
831 Self: Sync;
832
833 #[cfg(target_arch = "wasm32")]
835 fn subscribe_with_opts<Sub>(
836 &self,
837 params: &Sub,
838 opts: SubscriptionOptions<'_>,
839 ) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>>
840 where
841 Sub: XrpcSubscription + Send + Sync;
842}
843
844pub struct BasicSubscriptionClient<W: WebSocketClient> {
850 client: W,
851 base_uri: CowStr<'static>,
852 opts: SubscriptionOptions<'static>,
853}
854
855impl<W: WebSocketClient> BasicSubscriptionClient<W> {
856 pub fn new(client: W, base_uri: Url) -> Self {
858 let base_uri = base_uri.as_str().trim_end_matches("/");
859 Self {
860 client,
861 base_uri: base_uri.to_cowstr().into_static(),
862 opts: SubscriptionOptions::default(),
863 }
864 }
865
866 pub fn with_options(mut self, opts: SubscriptionOptions<'_>) -> Self {
868 self.opts = opts.into_static();
869 self
870 }
871
872 pub fn inner(&self) -> &W {
874 &self.client
875 }
876}
877
878impl<W: WebSocketClient> WebSocketClient for BasicSubscriptionClient<W> {
879 type Error = W::Error;
880
881 async fn connect(&self, url: Url) -> Result<WebSocketConnection, Self::Error> {
882 self.client.connect(url).await
883 }
884
885 async fn connect_with_headers(
886 &self,
887 url: Url,
888 headers: Vec<(CowStr<'_>, CowStr<'_>)>,
889 ) -> Result<WebSocketConnection, Self::Error> {
890 self.client.connect_with_headers(url, headers).await
891 }
892}
893
894impl<W: WebSocketClient> SubscriptionClient for BasicSubscriptionClient<W> {
895 async fn base_uri(&self) -> CowStr<'static> {
896 self.base_uri.clone()
897 }
898
899 async fn subscription_opts(&self) -> SubscriptionOptions<'_> {
900 self.opts.clone()
901 }
902
903 #[cfg(not(target_arch = "wasm32"))]
904 async fn subscribe<Sub>(
905 &self,
906 params: &Sub,
907 ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error>
908 where
909 Sub: XrpcSubscription + Send + Sync,
910 Self: Sync,
911 {
912 let opts = self.subscription_opts().await;
913 self.subscribe_with_opts(params, opts).await
914 }
915
916 #[cfg(target_arch = "wasm32")]
917 async fn subscribe<Sub>(
918 &self,
919 params: &Sub,
920 ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error>
921 where
922 Sub: XrpcSubscription + Send + Sync,
923 {
924 let opts = self.subscription_opts().await;
925 self.subscribe_with_opts(params, opts).await
926 }
927
928 #[cfg(not(target_arch = "wasm32"))]
929 async fn subscribe_with_opts<Sub>(
930 &self,
931 params: &Sub,
932 opts: SubscriptionOptions<'_>,
933 ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error>
934 where
935 Sub: XrpcSubscription + Send + Sync,
936 Self: Sync,
937 {
938 let base = self.base_uri().await;
939 let base = Url::parse(&base).expect("Failed to parse base URL");
940 self.subscription(base)
941 .with_options(opts)
942 .subscribe(params)
943 .await
944 }
945
946 #[cfg(target_arch = "wasm32")]
947 async fn subscribe_with_opts<Sub>(
948 &self,
949 params: &Sub,
950 opts: SubscriptionOptions<'_>,
951 ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error>
952 where
953 Sub: XrpcSubscription + Send + Sync,
954 {
955 let base = self.base_uri().await;
956 let base = Url::parse(&base).expect("Failed to parse base URL");
957 self.subscription(base)
958 .with_options(opts)
959 .subscribe(params)
960 .await
961 }
962}
963
964pub type TungsteniteSubscriptionClient =
983 BasicSubscriptionClient<crate::websocket::tungstenite_client::TungsteniteClient>;
984
985impl TungsteniteSubscriptionClient {
986 pub fn from_base_uri(base_uri: Url) -> Self {
988 let client = crate::websocket::tungstenite_client::TungsteniteClient::new();
989 BasicSubscriptionClient::new(client, base_uri)
990 }
991}