jacquard_common/xrpc/
subscription.rs

1//! WebSocket subscription support for XRPC
2//!
3//! This module defines traits and types for typed WebSocket subscriptions,
4//! mirroring the request/response pattern used for HTTP XRPC endpoints.
5
6#[cfg(not(target_arch = "wasm32"))]
7use n0_future::stream::Boxed;
8#[cfg(target_arch = "wasm32")]
9use n0_future::stream::BoxedLocal as Boxed;
10use serde::{Deserialize, Serialize};
11use std::error::Error;
12use std::future::Future;
13use std::marker::PhantomData;
14use url::Url;
15
16use crate::cowstr::ToCowStr;
17use crate::error::DecodeError;
18use crate::stream::StreamError;
19use crate::websocket::{WebSocketClient, WebSocketConnection, WsSink, WsStream};
20use crate::{CowStr, Data, IntoStatic, RawData, WsMessage};
21
/// Wire encoding used for the frames of a subscription stream.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MessageEncoding {
    /// JSON payloads (normally delivered as text frames).
    Json,
    /// DAG-CBOR payloads delivered as binary frames.
    DagCbor,
}
30
31/// XRPC subscription stream response trait
32///
33/// Analogous to `XrpcResp` but for WebSocket subscriptions.
34/// Defines the message and error types for a subscription stream.
35///
36/// This trait is implemented on a marker struct to keep it lifetime-free
37/// while using GATs for the message/error types.
38pub trait SubscriptionResp {
39    /// The NSID for this subscription
40    const NSID: &'static str;
41
42    /// Message encoding (JSON or DAG-CBOR)
43    const ENCODING: MessageEncoding;
44
45    /// Message union type
46    type Message<'de>: Deserialize<'de> + IntoStatic;
47
48    /// Error union type
49    type Error<'de>: Error + Deserialize<'de> + IntoStatic;
50
51    /// Decode a message from bytes.
52    ///
53    /// Default implementation uses simple deserialization via serde.
54    /// Subscriptions that use framed encoding (header + body) can override
55    /// this to do two-stage deserialization.
56    fn decode_message<'de>(bytes: &'de [u8]) -> Result<Self::Message<'de>, DecodeError> {
57        match Self::ENCODING {
58            MessageEncoding::Json => serde_json::from_slice(bytes).map_err(DecodeError::from),
59            MessageEncoding::DagCbor => {
60                serde_ipld_dagcbor::from_slice(bytes).map_err(DecodeError::from)
61            }
62        }
63    }
64}
65
66/// XRPC subscription (WebSocket)
67///
68/// This trait is analogous to `XrpcRequest` but for WebSocket subscriptions.
69/// It defines the NSID and associated stream response type.
70///
71/// The trait is implemented on the subscription parameters type.
72pub trait XrpcSubscription: Serialize {
73    /// The NSID for this XRPC subscription
74    const NSID: &'static str;
75
76    /// Message encoding (JSON or DAG-CBOR)
77    const ENCODING: MessageEncoding;
78
79    /// Custom path override (e.g., "/subscribe" for Jetstream).
80    /// If None, defaults to "/xrpc/{NSID}"
81    const CUSTOM_PATH: Option<&'static str> = None;
82
83    /// Stream response type (marker struct)
84    type Stream: SubscriptionResp;
85
86    /// Encode query params for WebSocket URL
87    ///
88    /// Default implementation uses serde_html_form to encode the struct as query parameters.
89    fn query_params(&self) -> Vec<(String, String)> {
90        // Default: use serde_html_form to encode self
91        serde_html_form::to_string(self)
92            .ok()
93            .map(|s| {
94                s.split('&')
95                    .filter_map(|pair| {
96                        let mut parts = pair.splitn(2, '=');
97                        Some((parts.next()?.to_string(), parts.next()?.to_string()))
98                    })
99                    .collect()
100            })
101            .unwrap_or_default()
102    }
103}
104
105/// Header for framed DAG-CBOR subscription messages.
106///
107/// Used in ATProto subscription streams where each message has a CBOR-encoded header
108/// followed by the message body.
109#[derive(Debug, serde::Deserialize)]
110pub struct EventHeader {
111    /// Operation code
112    pub op: i64,
113    /// Event type discriminator (e.g., "#commit", "#identity")
114    pub t: smol_str::SmolStr,
115}
116
117/// Parse a framed DAG-CBOR message header and return the header plus remaining body bytes.
118///
119/// Used for two-stage deserialization of subscription messages in formats like
120/// `com.atproto.sync.subscribeRepos`.
121pub fn parse_event_header<'a>(bytes: &'a [u8]) -> Result<(EventHeader, &'a [u8]), DecodeError> {
122    let mut cursor = std::io::Cursor::new(bytes);
123    let header: EventHeader = ciborium::de::from_reader(&mut cursor)?;
124    let position = cursor.position() as usize;
125    drop(cursor); // explicit drop before reborrowing bytes
126
127    Ok((header, &bytes[position..]))
128}
129
130/// Decode JSON messages from a WebSocket stream
131pub fn decode_json_msg<S: SubscriptionResp>(
132    msg_result: Result<crate::websocket::WsMessage, StreamError>,
133) -> Option<Result<StreamMessage<'static, S>, StreamError>>
134where
135    for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>,
136{
137    use crate::websocket::WsMessage;
138
139    match msg_result {
140        Ok(WsMessage::Text(text)) => Some(
141            S::decode_message(text.as_ref())
142                .map(|v| v.into_static())
143                .map_err(StreamError::decode),
144        ),
145        Ok(WsMessage::Binary(bytes)) => {
146            #[cfg(feature = "zstd")]
147            {
148                // Try to decompress with zstd first (Jetstream uses zstd compression)
149                match decompress_zstd(&bytes) {
150                    Ok(decompressed) => Some(
151                        S::decode_message(&decompressed)
152                            .map(|v| v.into_static())
153                            .map_err(StreamError::decode),
154                    ),
155                    Err(_) => {
156                        // Not zstd-compressed, try direct decode
157                        Some(
158                            S::decode_message(&bytes)
159                                .map(|v| v.into_static())
160                                .map_err(StreamError::decode),
161                        )
162                    }
163                }
164            }
165            #[cfg(not(feature = "zstd"))]
166            {
167                Some(
168                    S::decode_message(&bytes)
169                        .map(|v| v.into_static())
170                        .map_err(StreamError::decode),
171                )
172            }
173        }
174        Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
175        Err(e) => Some(Err(e)),
176    }
177}
178
/// Decompress a zstd-compressed frame.
///
/// A plain zstd decode is tried first. If that fails, the frame is decoded again using the
/// dictionary embedded from `zstd_dictionary` (presumably Jetstream's dictionary; confirm).
///
/// # Errors
///
/// Returns the I/O error from the dictionary-based attempt when neither attempt succeeds.
/// The JSON decoders treat any error as "not zstd-compressed" and fall back to the raw
/// bytes.
#[cfg(feature = "zstd")]
fn decompress_zstd(bytes: &[u8]) -> Result<Vec<u8>, std::io::Error> {
    use std::io::Read as _;
    use zstd::stream::decode_all;

    // `include_bytes!` already yields `'static` data, so there is no need to copy it into a
    // lazily initialised heap buffer.
    static DICTIONARY: &[u8] = include_bytes!("../../zstd_dictionary");

    decode_all(std::io::Cursor::new(bytes)).or_else(|_| {
        // NOTE(review): a fresh decoder is built (and the dictionary reloaded) for every
        // frame that reaches this path. Caching a prepared dictionary may be worthwhile if
        // this shows up in profiles.
        let mut decoder =
            zstd::Decoder::with_dictionary(std::io::Cursor::new(bytes), DICTIONARY)?;
        let mut inflated = Vec::new();
        decoder.read_to_end(&mut inflated)?;
        Ok(inflated)
    })
}
196
197/// Decode CBOR messages from a WebSocket stream
198pub fn decode_cbor_msg<S: SubscriptionResp>(
199    msg_result: Result<crate::websocket::WsMessage, StreamError>,
200) -> Option<Result<StreamMessage<'static, S>, StreamError>>
201where
202    for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>,
203{
204    use crate::websocket::WsMessage;
205
206    match msg_result {
207        Ok(WsMessage::Binary(bytes)) => Some(
208            S::decode_message(&bytes)
209                .map(|v| v.into_static())
210                .map_err(StreamError::decode),
211        ),
212        Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format(
213            "expected binary frame for CBOR, got text",
214        ))),
215        Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
216        Err(e) => Some(Err(e)),
217    }
218}
219
220/// Websocket subscriber-sent control message
221///
222/// Note: this is not meaningful for atproto event stream endpoints as
223/// those do not support control after the fact. Jetstream does, however.
224///
225/// If you wish to control an ongoing Jetstream connection, wrap the [`WsSink`]
226/// returned from one of the `into_*` methods of the [`SubscriptionStream`]
227/// in a [`SubscriptionController`] with the corresponding message implementing
228/// this trait as a generic parameter.
229pub trait SubscriptionControlMessage: Serialize {
230    /// The subscription this is associated with
231    type Subscription: XrpcSubscription;
232
233    /// Encode the control message for transmission
234    ///
235    /// Defaults to json text (matches Jetstream)
236    fn encode(&self) -> Result<WsMessage, StreamError> {
237        Ok(WsMessage::from(
238            serde_json::to_string(&self).map_err(StreamError::encode)?,
239        ))
240    }
241
242    /// Decode the control message
243    fn decode<'de>(frame: &'de [u8]) -> Result<Self, StreamError>
244    where
245        Self: Deserialize<'de>,
246    {
247        Ok(serde_json::from_slice(frame).map_err(StreamError::decode)?)
248    }
249}
250
251/// Control a websocket stream with a given subscription control message
252pub struct SubscriptionController<S: SubscriptionControlMessage> {
253    controller: WsSink,
254    _marker: PhantomData<fn() -> S>,
255}
256
257impl<S: SubscriptionControlMessage> SubscriptionController<S> {
258    /// Create a new subscription controller from a WebSocket sink.
259    pub fn new(controller: WsSink) -> Self {
260        Self {
261            controller,
262            _marker: PhantomData,
263        }
264    }
265
266    /// Configure the upstream connection via the websocket
267    pub async fn configure(&mut self, params: &S) -> Result<(), StreamError> {
268        let message = params.encode()?;
269
270        n0_future::SinkExt::send(self.controller.get_mut(), message)
271            .await
272            .map_err(StreamError::transport)
273    }
274}
275
276/// Typed subscription stream wrapping a WebSocket connection.
277///
278/// Analogous to `Response<R>` for XRPC but for subscription streams.
279/// Automatically decodes messages based on the subscription's encoding format.
280pub struct SubscriptionStream<S: SubscriptionResp> {
281    _marker: PhantomData<fn() -> S>,
282    connection: WebSocketConnection,
283}
284
285impl<S: SubscriptionResp> SubscriptionStream<S> {
286    /// Create a new subscription stream from a WebSocket connection.
287    pub fn new(connection: WebSocketConnection) -> Self {
288        Self {
289            _marker: PhantomData,
290            connection,
291        }
292    }
293
294    /// Get a reference to the underlying WebSocket connection.
295    pub fn connection(&self) -> &WebSocketConnection {
296        &self.connection
297    }
298
299    /// Get a mutable reference to the underlying WebSocket connection.
300    pub fn connection_mut(&mut self) -> &mut WebSocketConnection {
301        &mut self.connection
302    }
303
304    /// Split the connection and decode messages into a typed stream.
305    ///
306    /// Returns a tuple of (sender, typed message stream).
307    /// Messages are decoded according to the subscription's ENCODING.
308    pub fn into_stream(
309        self,
310    ) -> (
311        WsSink,
312        Boxed<Result<StreamMessage<'static, S>, StreamError>>,
313    )
314    where
315        for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>,
316    {
317        use n0_future::StreamExt as _;
318
319        let (tx, rx) = self.connection.split();
320
321        #[cfg(not(target_arch = "wasm32"))]
322        let stream = match S::ENCODING {
323            MessageEncoding::Json => rx
324                .into_inner()
325                .filter_map(|msg| decode_json_msg::<S>(msg))
326                .boxed(),
327            MessageEncoding::DagCbor => rx
328                .into_inner()
329                .filter_map(|msg| decode_cbor_msg::<S>(msg))
330                .boxed(),
331        };
332
333        #[cfg(target_arch = "wasm32")]
334        let stream = match S::ENCODING {
335            MessageEncoding::Json => rx
336                .into_inner()
337                .filter_map(|msg| decode_json_msg::<S>(msg))
338                .boxed_local(),
339            MessageEncoding::DagCbor => rx
340                .into_inner()
341                .filter_map(|msg| decode_cbor_msg::<S>(msg))
342                .boxed_local(),
343        };
344
345        (tx, stream)
346    }
347
348    /// Converts the subscription into a stream of raw atproto data.
349    pub fn into_raw_data_stream(self) -> (WsSink, Boxed<Result<RawData<'static>, StreamError>>) {
350        use n0_future::StreamExt as _;
351
352        let (tx, rx) = self.connection.split();
353
354        fn parse_msg<'a>(bytes: &'a [u8]) -> Result<RawData<'a>, serde_json::Error> {
355            serde_json::from_slice(bytes)
356        }
357        fn parse_cbor<'a>(
358            bytes: &'a [u8],
359        ) -> Result<RawData<'a>, serde_ipld_dagcbor::DecodeError<std::convert::Infallible>>
360        {
361            serde_ipld_dagcbor::from_slice(bytes)
362        }
363
364        #[cfg(not(target_arch = "wasm32"))]
365        let stream = match S::ENCODING {
366            MessageEncoding::Json => rx
367                .into_inner()
368                .filter_map(|msg_result| match msg_result {
369                    Ok(WsMessage::Text(text)) => Some(
370                        parse_msg(text.as_ref())
371                            .map(|v| v.into_static())
372                            .map_err(StreamError::decode),
373                    ),
374                    Ok(WsMessage::Binary(bytes)) => {
375                        #[cfg(feature = "zstd")]
376                        {
377                            match decompress_zstd(&bytes) {
378                                Ok(decompressed) => Some(
379                                    parse_msg(&decompressed)
380                                        .map(|v| v.into_static())
381                                        .map_err(StreamError::decode),
382                                ),
383                                Err(_) => Some(
384                                    parse_msg(&bytes)
385                                        .map(|v| v.into_static())
386                                        .map_err(StreamError::decode),
387                                ),
388                            }
389                        }
390                        #[cfg(not(feature = "zstd"))]
391                        {
392                            Some(
393                                parse_msg(&bytes)
394                                    .map(|v| v.into_static())
395                                    .map_err(StreamError::decode),
396                            )
397                        }
398                    }
399                    Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
400                    Err(e) => Some(Err(e)),
401                })
402                .boxed(),
403            MessageEncoding::DagCbor => rx
404                .into_inner()
405                .filter_map(|msg_result| match msg_result {
406                    Ok(WsMessage::Binary(bytes)) => Some(
407                        parse_cbor(&bytes)
408                            .map(|v| v.into_static())
409                            .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))),
410                    ),
411                    Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format(
412                        "expected binary frame for CBOR, got text",
413                    ))),
414                    Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
415                    Err(e) => Some(Err(e)),
416                })
417                .boxed(),
418        };
419
420        #[cfg(target_arch = "wasm32")]
421        let stream = match S::ENCODING {
422            MessageEncoding::Json => rx
423                .into_inner()
424                .filter_map(|msg_result| match msg_result {
425                    Ok(WsMessage::Text(text)) => Some(
426                        parse_msg(text.as_ref())
427                            .map(|v| v.into_static())
428                            .map_err(StreamError::decode),
429                    ),
430                    Ok(WsMessage::Binary(bytes)) => {
431                        #[cfg(feature = "zstd")]
432                        {
433                            match decompress_zstd(&bytes) {
434                                Ok(decompressed) => Some(
435                                    parse_msg(&decompressed)
436                                        .map(|v| v.into_static())
437                                        .map_err(StreamError::decode),
438                                ),
439                                Err(_) => Some(
440                                    parse_msg(&bytes)
441                                        .map(|v| v.into_static())
442                                        .map_err(StreamError::decode),
443                                ),
444                            }
445                        }
446                        #[cfg(not(feature = "zstd"))]
447                        {
448                            Some(
449                                parse_msg(&bytes)
450                                    .map(|v| v.into_static())
451                                    .map_err(StreamError::decode),
452                            )
453                        }
454                    }
455                    Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
456                    Err(e) => Some(Err(e)),
457                })
458                .boxed_local(),
459            MessageEncoding::DagCbor => rx
460                .into_inner()
461                .filter_map(|msg_result| match msg_result {
462                    Ok(WsMessage::Binary(bytes)) => Some(
463                        parse_cbor(&bytes)
464                            .map(|v| v.into_static())
465                            .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))),
466                    ),
467                    Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format(
468                        "expected binary frame for CBOR, got text",
469                    ))),
470                    Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
471                    Err(e) => Some(Err(e)),
472                })
473                .boxed_local(),
474        };
475
476        (tx, stream)
477    }
478
479    /// Converts the subscription into a stream of loosely-typed atproto data.
480    pub fn into_data_stream(self) -> (WsSink, Boxed<Result<Data<'static>, StreamError>>) {
481        use n0_future::StreamExt as _;
482
483        let (tx, rx) = self.connection.split();
484
485        fn parse_msg<'a>(bytes: &'a [u8]) -> Result<Data<'a>, serde_json::Error> {
486            serde_json::from_slice(bytes)
487        }
488        fn parse_cbor<'a>(
489            bytes: &'a [u8],
490        ) -> Result<Data<'a>, serde_ipld_dagcbor::DecodeError<std::convert::Infallible>> {
491            serde_ipld_dagcbor::from_slice(bytes)
492        }
493
494        #[cfg(not(target_arch = "wasm32"))]
495        let stream = match S::ENCODING {
496            MessageEncoding::Json => rx
497                .into_inner()
498                .filter_map(|msg_result| match msg_result {
499                    Ok(WsMessage::Text(text)) => Some(
500                        parse_msg(text.as_ref())
501                            .map(|v| v.into_static())
502                            .map_err(StreamError::decode),
503                    ),
504                    Ok(WsMessage::Binary(bytes)) => {
505                        #[cfg(feature = "zstd")]
506                        {
507                            match decompress_zstd(&bytes) {
508                                Ok(decompressed) => Some(
509                                    parse_msg(&decompressed)
510                                        .map(|v| v.into_static())
511                                        .map_err(StreamError::decode),
512                                ),
513                                Err(_) => Some(
514                                    parse_msg(&bytes)
515                                        .map(|v| v.into_static())
516                                        .map_err(StreamError::decode),
517                                ),
518                            }
519                        }
520                        #[cfg(not(feature = "zstd"))]
521                        {
522                            Some(
523                                parse_msg(&bytes)
524                                    .map(|v| v.into_static())
525                                    .map_err(StreamError::decode),
526                            )
527                        }
528                    }
529                    Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
530                    Err(e) => Some(Err(e)),
531                })
532                .boxed(),
533            MessageEncoding::DagCbor => rx
534                .into_inner()
535                .filter_map(|msg_result| match msg_result {
536                    Ok(WsMessage::Binary(bytes)) => Some(
537                        parse_cbor(&bytes)
538                            .map(|v| v.into_static())
539                            .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))),
540                    ),
541                    Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format(
542                        "expected binary frame for CBOR, got text",
543                    ))),
544                    Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
545                    Err(e) => Some(Err(e)),
546                })
547                .boxed(),
548        };
549
550        #[cfg(target_arch = "wasm32")]
551        let stream = match S::ENCODING {
552            MessageEncoding::Json => rx
553                .into_inner()
554                .filter_map(|msg_result| match msg_result {
555                    Ok(WsMessage::Text(text)) => Some(
556                        parse_msg(text.as_ref())
557                            .map(|v| v.into_static())
558                            .map_err(StreamError::decode),
559                    ),
560                    Ok(WsMessage::Binary(bytes)) => {
561                        #[cfg(feature = "zstd")]
562                        {
563                            match decompress_zstd(&bytes) {
564                                Ok(decompressed) => Some(
565                                    parse_msg(&decompressed)
566                                        .map(|v| v.into_static())
567                                        .map_err(StreamError::decode),
568                                ),
569                                Err(_) => Some(
570                                    parse_msg(&bytes)
571                                        .map(|v| v.into_static())
572                                        .map_err(StreamError::decode),
573                                ),
574                            }
575                        }
576                        #[cfg(not(feature = "zstd"))]
577                        {
578                            Some(
579                                parse_msg(&bytes)
580                                    .map(|v| v.into_static())
581                                    .map_err(StreamError::decode),
582                            )
583                        }
584                    }
585                    Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
586                    Err(e) => Some(Err(e)),
587                })
588                .boxed_local(),
589            MessageEncoding::DagCbor => rx
590                .into_inner()
591                .filter_map(|msg_result| match msg_result {
592                    Ok(WsMessage::Binary(bytes)) => Some(
593                        parse_cbor(&bytes)
594                            .map(|v| v.into_static())
595                            .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))),
596                    ),
597                    Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format(
598                        "expected binary frame for CBOR, got text",
599                    ))),
600                    Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
601                    Err(e) => Some(Err(e)),
602                })
603                .boxed_local(),
604        };
605
606        (tx, stream)
607    }
608
609    /// Consume the stream and return the underlying connection.
610    pub fn into_connection(self) -> WebSocketConnection {
611        self.connection
612    }
613
614    /// Tee the stream, keeping the raw stream in self and returning a typed stream.
615    ///
616    /// Replaces the internal WebSocket stream with one copy and returns a typed decoded
617    /// stream. Both streams receive all messages. Useful for observing raw messages
618    /// while also processing typed messages.
619    pub fn tee(&mut self) -> Boxed<Result<StreamMessage<'static, S>, StreamError>>
620    where
621        for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>,
622    {
623        use n0_future::StreamExt as _;
624
625        let rx = self.connection.receiver_mut();
626        let (raw_rx, typed_rx_source) =
627            std::mem::replace(rx, WsStream::new(n0_future::stream::empty())).tee();
628
629        // Put the raw stream back
630        *rx = raw_rx;
631
632        #[cfg(not(target_arch = "wasm32"))]
633        let stream = match S::ENCODING {
634            MessageEncoding::Json => typed_rx_source
635                .into_inner()
636                .filter_map(|msg| decode_json_msg::<S>(msg))
637                .boxed(),
638            MessageEncoding::DagCbor => typed_rx_source
639                .into_inner()
640                .filter_map(|msg| decode_cbor_msg::<S>(msg))
641                .boxed(),
642        };
643
644        #[cfg(target_arch = "wasm32")]
645        let stream = match S::ENCODING {
646            MessageEncoding::Json => typed_rx_source
647                .into_inner()
648                .filter_map(|msg| decode_json_msg::<S>(msg))
649                .boxed_local(),
650            MessageEncoding::DagCbor => typed_rx_source
651                .into_inner()
652                .filter_map(|msg| decode_cbor_msg::<S>(msg))
653                .boxed_local(),
654        };
655        stream
656    }
657}
658
659type StreamMessage<'a, R> = <R as SubscriptionResp>::Message<'a>;
660
661/// XRPC subscription endpoint trait (server-side)
662///
663/// Analogous to `XrpcEndpoint` but for WebSocket subscriptions.
664/// Defines the fully-qualified path and associated parameter/stream types.
665///
666/// This exists primarily for server-side frameworks (like Axum) to extract
667/// typed subscription parameters without lifetime issues.
668pub trait SubscriptionEndpoint {
669    /// Fully-qualified path ('/xrpc/[nsid]') where this subscription endpoint lives
670    const PATH: &'static str;
671
672    /// Message encoding (JSON or DAG-CBOR)
673    const ENCODING: MessageEncoding;
674
675    /// Subscription parameters type
676    type Params<'de>: XrpcSubscription + Deserialize<'de> + IntoStatic;
677
678    /// Stream response type
679    type Stream: SubscriptionResp;
680}
681
682/// Per-subscription options for WebSocket subscriptions.
683#[derive(Debug, Default, Clone)]
684pub struct SubscriptionOptions<'a> {
685    /// Extra headers to attach to this subscription (e.g., Authorization).
686    pub headers: Vec<(CowStr<'a>, CowStr<'a>)>,
687}
688
689impl IntoStatic for SubscriptionOptions<'_> {
690    type Output = SubscriptionOptions<'static>;
691
692    fn into_static(self) -> Self::Output {
693        SubscriptionOptions {
694            headers: self
695                .headers
696                .into_iter()
697                .map(|(k, v)| (k.into_static(), v.into_static()))
698                .collect(),
699        }
700    }
701}
702
703/// Extension for stateless subscription calls on any `WebSocketClient`.
704///
705/// Provides a builder pattern for establishing WebSocket subscriptions with custom options.
706pub trait SubscriptionExt: WebSocketClient {
707    /// Start building a subscription call for the given base URL.
708    fn subscription<'a>(&'a self, base: Url) -> SubscriptionCall<'a, Self>
709    where
710        Self: Sized,
711    {
712        SubscriptionCall {
713            client: self,
714            base,
715            opts: SubscriptionOptions::default(),
716        }
717    }
718}
719
720impl<T: WebSocketClient> SubscriptionExt for T {}
721
722/// Stateless subscription call builder.
723///
724/// Provides methods for adding headers and establishing typed subscriptions.
725pub struct SubscriptionCall<'a, C: WebSocketClient> {
726    pub(crate) client: &'a C,
727    pub(crate) base: Url,
728    pub(crate) opts: SubscriptionOptions<'a>,
729}
730
731impl<'a, C: WebSocketClient> SubscriptionCall<'a, C> {
732    /// Add an extra header.
733    pub fn header(mut self, name: impl Into<CowStr<'a>>, value: impl Into<CowStr<'a>>) -> Self {
734        self.opts.headers.push((name.into(), value.into()));
735        self
736    }
737
738    /// Replace the builder's options entirely.
739    pub fn with_options(mut self, opts: SubscriptionOptions<'a>) -> Self {
740        self.opts = opts;
741        self
742    }
743
744    /// Subscribe to the given XRPC subscription endpoint.
745    ///
746    /// Builds a WebSocket URL from the base, appends the NSID path,
747    /// encodes query parameters from the subscription type, and connects.
748    /// Returns a typed SubscriptionStream that automatically decodes messages.
749    pub async fn subscribe<Sub>(
750        self,
751        params: &Sub,
752    ) -> Result<SubscriptionStream<Sub::Stream>, C::Error>
753    where
754        Sub: XrpcSubscription,
755    {
756        let mut url = self.base.clone();
757
758        // Use custom path if provided, otherwise construct from NSID
759        let mut path = url.path().trim_end_matches('/').to_owned();
760        if let Some(custom_path) = Sub::CUSTOM_PATH {
761            path.push_str(custom_path);
762        } else {
763            path.push_str("/xrpc/");
764            path.push_str(Sub::NSID);
765        }
766        url.set_path(&path);
767
768        let query_params = params.query_params();
769        if !query_params.is_empty() {
770            let qs = query_params
771                .iter()
772                .map(|(k, v)| format!("{}={}", k, v))
773                .collect::<Vec<_>>()
774                .join("&");
775            url.set_query(Some(&qs));
776        } else {
777            url.set_query(None);
778        }
779
780        let connection = self
781            .client
782            .connect_with_headers(url, self.opts.headers)
783            .await?;
784
785        Ok(SubscriptionStream::new(connection))
786    }
787}
788
/// Stateful subscription client trait.
///
/// Analogous to `XrpcClient` but for WebSocket subscriptions.
/// Provides a stateful interface for subscribing with a configured base URI and options.
///
/// Implementors are also `WebSocketClient`s, so `Self::Error` is the WebSocket
/// client's error type. On non-wasm targets the trait is passed through
/// `trait_variant::make(Send)` (intended to add `Send` bounds to the returned
/// futures); on `wasm32` that attribute is not applied. The non-wasm variants
/// of the subscribe methods additionally require `Self: Sync`.
#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait SubscriptionClient: WebSocketClient {
    /// Get the base URI for the client.
    ///
    /// Subscription URLs are built relative to this URI.
    fn base_uri(&self) -> impl Future<Output = CowStr<'static>>;

    /// Get the subscription options for the client.
    ///
    /// The default implementation returns `SubscriptionOptions::default()`.
    fn subscription_opts(&self) -> impl Future<Output = SubscriptionOptions<'_>> {
        async { SubscriptionOptions::default() }
    }

    /// Subscribe to an XRPC subscription endpoint using the client's base URI and options.
    ///
    /// Non-wasm variant: additionally requires `Self: Sync`.
    #[cfg(not(target_arch = "wasm32"))]
    fn subscribe<Sub>(
        &self,
        params: &Sub,
    ) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>>
    where
        Sub: XrpcSubscription + Send + Sync,
        Self: Sync;

    /// Subscribe to an XRPC subscription endpoint using the client's base URI and options.
    ///
    /// `wasm32` variant: no `Self: Sync` bound.
    #[cfg(target_arch = "wasm32")]
    fn subscribe<Sub>(
        &self,
        params: &Sub,
    ) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>>
    where
        Sub: XrpcSubscription + Send + Sync;

    /// Subscribe with custom options.
    ///
    /// Uses the client's base URI together with the explicitly supplied `opts`
    /// rather than the options from `subscription_opts`.
    /// Non-wasm variant: additionally requires `Self: Sync`.
    #[cfg(not(target_arch = "wasm32"))]
    fn subscribe_with_opts<Sub>(
        &self,
        params: &Sub,
        opts: SubscriptionOptions<'_>,
    ) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>>
    where
        Sub: XrpcSubscription + Send + Sync,
        Self: Sync;

    /// Subscribe with custom options.
    ///
    /// Uses the client's base URI together with the explicitly supplied `opts`
    /// rather than the options from `subscription_opts`.
    /// `wasm32` variant: no `Self: Sync` bound.
    #[cfg(target_arch = "wasm32")]
    fn subscribe_with_opts<Sub>(
        &self,
        params: &Sub,
        opts: SubscriptionOptions<'_>,
    ) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>>
    where
        Sub: XrpcSubscription + Send + Sync;
}
843
/// Simple stateless subscription client wrapping a WebSocketClient.
///
/// Analogous to a basic HTTP client but for WebSocket subscriptions.
/// Does not manage sessions or authentication - useful for public subscriptions
/// or when you want to handle auth manually via headers.
///
/// It stores a base URI and a set of default `SubscriptionOptions`; every
/// connection is delegated to the wrapped client `W`.
pub struct BasicSubscriptionClient<W: WebSocketClient> {
    /// Wrapped WebSocket client that actually opens connections.
    client: W,
    /// Base URI that subscription URLs are built from (`new` strips trailing `/`).
    base_uri: CowStr<'static>,
    /// Options used by `SubscriptionClient::subscribe` when none are passed explicitly.
    opts: SubscriptionOptions<'static>,
}
854
855impl<W: WebSocketClient> BasicSubscriptionClient<W> {
856    /// Create a new basic subscription client with the given WebSocket client and base URI.
857    pub fn new(client: W, base_uri: Url) -> Self {
858        let base_uri = base_uri.as_str().trim_end_matches("/");
859        Self {
860            client,
861            base_uri: base_uri.to_cowstr().into_static(),
862            opts: SubscriptionOptions::default(),
863        }
864    }
865
866    /// Create with default options.
867    pub fn with_options(mut self, opts: SubscriptionOptions<'_>) -> Self {
868        self.opts = opts.into_static();
869        self
870    }
871
872    /// Get a reference to the inner WebSocket client.
873    pub fn inner(&self) -> &W {
874        &self.client
875    }
876}
877
878impl<W: WebSocketClient> WebSocketClient for BasicSubscriptionClient<W> {
879    type Error = W::Error;
880
881    async fn connect(&self, url: Url) -> Result<WebSocketConnection, Self::Error> {
882        self.client.connect(url).await
883    }
884
885    async fn connect_with_headers(
886        &self,
887        url: Url,
888        headers: Vec<(CowStr<'_>, CowStr<'_>)>,
889    ) -> Result<WebSocketConnection, Self::Error> {
890        self.client.connect_with_headers(url, headers).await
891    }
892}
893
894impl<W: WebSocketClient> SubscriptionClient for BasicSubscriptionClient<W> {
895    async fn base_uri(&self) -> CowStr<'static> {
896        self.base_uri.clone()
897    }
898
899    async fn subscription_opts(&self) -> SubscriptionOptions<'_> {
900        self.opts.clone()
901    }
902
903    #[cfg(not(target_arch = "wasm32"))]
904    async fn subscribe<Sub>(
905        &self,
906        params: &Sub,
907    ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error>
908    where
909        Sub: XrpcSubscription + Send + Sync,
910        Self: Sync,
911    {
912        let opts = self.subscription_opts().await;
913        self.subscribe_with_opts(params, opts).await
914    }
915
916    #[cfg(target_arch = "wasm32")]
917    async fn subscribe<Sub>(
918        &self,
919        params: &Sub,
920    ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error>
921    where
922        Sub: XrpcSubscription + Send + Sync,
923    {
924        let opts = self.subscription_opts().await;
925        self.subscribe_with_opts(params, opts).await
926    }
927
928    #[cfg(not(target_arch = "wasm32"))]
929    async fn subscribe_with_opts<Sub>(
930        &self,
931        params: &Sub,
932        opts: SubscriptionOptions<'_>,
933    ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error>
934    where
935        Sub: XrpcSubscription + Send + Sync,
936        Self: Sync,
937    {
938        let base = self.base_uri().await;
939        let base = Url::parse(&base).expect("Failed to parse base URL");
940        self.subscription(base)
941            .with_options(opts)
942            .subscribe(params)
943            .await
944    }
945
946    #[cfg(target_arch = "wasm32")]
947    async fn subscribe_with_opts<Sub>(
948        &self,
949        params: &Sub,
950        opts: SubscriptionOptions<'_>,
951    ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error>
952    where
953        Sub: XrpcSubscription + Send + Sync,
954    {
955        let base = self.base_uri().await;
956        let base = Url::parse(&base).expect("Failed to parse base URL");
957        self.subscription(base)
958            .with_options(opts)
959            .subscribe(params)
960            .await
961    }
962}
963
/// Type alias for a basic subscription client using the default TungsteniteClient.
///
/// Provides a simple, stateless WebSocket subscription client without session management.
/// Useful for public subscriptions or when handling authentication manually.
///
/// Construct one with `TungsteniteSubscriptionClient::from_base_uri`, then call
/// `SubscriptionClient::subscribe` with a value whose type implements
/// `XrpcSubscription`.
///
/// # Example
///
/// ```no_run
/// # use jacquard_common::xrpc::{TungsteniteSubscriptionClient, SubscriptionClient};
/// # use url::Url;
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let base = Url::parse("wss://bsky.network")?;
/// let client = TungsteniteSubscriptionClient::from_base_uri(base);
/// // `params` stands for a value of any type implementing `XrpcSubscription`:
/// // let conn = client.subscribe(&params).await?;
/// # Ok(())
/// # }
/// ```
pub type TungsteniteSubscriptionClient =
    BasicSubscriptionClient<crate::websocket::tungstenite_client::TungsteniteClient>;
984
985impl TungsteniteSubscriptionClient {
986    /// Create a new Tungstenite-backed subscription client with the given base URI.
987    pub fn from_base_uri(base_uri: Url) -> Self {
988        let client = crate::websocket::tungstenite_client::TungsteniteClient::new();
989        BasicSubscriptionClient::new(client, base_uri)
990    }
991}