use crate::deps::fluent_uri::{
ParseError, Uri,
pct_enc::{
EString,
encoder::{Data as EncData, Query},
},
};
use crate::error::DecodeError;
use crate::stream::StreamError;
use crate::websocket::{WebSocketClient, WebSocketConnection, WsSink, WsStream};
use crate::{CowStr, Data, IntoStatic, RawData, WsMessage};
use alloc::borrow::ToOwned;
use alloc::string::String;
use alloc::string::ToString;
use alloc::vec::Vec;
use core::error::Error;
use core::future::Future;
use core::marker::PhantomData;
#[cfg(not(target_arch = "wasm32"))]
use n0_future::stream::Boxed;
#[cfg(target_arch = "wasm32")]
use n0_future::stream::BoxedLocal as Boxed;
use serde::{Deserialize, Serialize};
/// Wire encoding used to decode messages on a subscription stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageEncoding {
/// Messages are decoded with `serde_json`.
Json,
/// Messages are decoded with `serde_ipld_dagcbor`.
DagCbor,
}
/// Describes the messages carried by a subscription stream and how to decode them.
pub trait SubscriptionResp {
    /// NSID string associated with this stream.
    const NSID: &'static str;
    /// Encoding selected by [`Self::decode_message`] when decoding a frame.
    const ENCODING: MessageEncoding;
    /// Decoded message type; may borrow from the input buffer.
    type Message<'de>: Deserialize<'de> + IntoStatic;
    /// Error payload type associated with this stream.
    type Error<'de>: Error + Deserialize<'de> + IntoStatic;

    /// Decodes a single message from `bytes` using the decoder chosen by
    /// [`Self::ENCODING`].
    ///
    /// # Errors
    ///
    /// Returns the underlying decoder error converted into a [`DecodeError`].
    fn decode_message<'de>(bytes: &'de [u8]) -> Result<Self::Message<'de>, DecodeError> {
        let message: Self::Message<'de> = match Self::ENCODING {
            MessageEncoding::DagCbor => serde_ipld_dagcbor::from_slice(bytes)?,
            MessageEncoding::Json => serde_json::from_slice(bytes)?,
        };
        Ok(message)
    }
}
/// Parameters of an XRPC subscription; serialized into the query string of the
/// subscription URI.
pub trait XrpcSubscription: Serialize {
    /// NSID of the subscription; used to build the default `/xrpc/<NSID>` path.
    const NSID: &'static str;
    /// Encoding of the messages on the resulting stream.
    const ENCODING: MessageEncoding;
    /// Optional path used in place of `/xrpc/<NSID>` when building the URI.
    const CUSTOM_PATH: Option<&'static str> = None;
    /// Description of the messages produced by this subscription.
    type Stream: SubscriptionResp;

    /// Returns the query parameters for this subscription as *decoded*
    /// key/value pairs.
    ///
    /// The default implementation serializes `self` with `serde_html_form`
    /// and parses the result back into pairs. Parsing (rather than splitting
    /// on `&`/`=`) undoes the form percent-encoding, so that
    /// `build_subscription_uri`, which percent-encodes every key and value,
    /// does not encode them a second time (e.g. `did:plc:x` must not become
    /// `did%253Aplc%253Ax`).
    ///
    /// Returns an empty vector if serialization or parsing fails.
    fn query_params(&self) -> Vec<(String, String)> {
        serde_html_form::to_string(self)
            .ok()
            .and_then(|encoded| {
                serde_html_form::from_str::<Vec<(String, String)>>(&encoded).ok()
            })
            .unwrap_or_default()
    }
}
/// Header decoded from the front of a binary event frame by `parse_event_header`.
#[derive(Debug, serde::Deserialize)]
pub struct EventHeader {
/// `op` field of the header.
/// NOTE(review): presumably the frame operation code — confirm against the protocol spec.
pub op: i64,
/// `t` field of the header.
/// NOTE(review): presumably the message type tag — confirm against the protocol spec.
pub t: smol_str::SmolStr,
}
/// Minimal read cursor over a byte slice, compiled only when the `std`
/// feature is disabled.
#[cfg(not(feature = "std"))]
struct SliceCursor<'a> {
    /// Bytes being read.
    slice: &'a [u8],
    /// Offset of the next unread byte in `slice`.
    position: usize,
}

#[cfg(not(feature = "std"))]
impl<'a> SliceCursor<'a> {
    /// Creates a cursor positioned at the start of `slice`.
    fn new(slice: &'a [u8]) -> Self {
        SliceCursor { position: 0, slice }
    }

    /// Returns the current offset into the underlying slice.
    fn position(&self) -> usize {
        let offset = self.position;
        offset
    }
}
/// `ciborium_io::Read` adapter so the CBOR decoder can read from a
/// [`SliceCursor`] without `std`.
#[cfg(not(feature = "std"))]
impl ciborium_io::Read for SliceCursor<'_> {
    type Error = core::convert::Infallible;

    /// Copies the next `buf.len()` bytes into `buf` and advances the cursor.
    ///
    /// # Panics
    ///
    /// Panics if fewer than `buf.len()` bytes remain: the slice index below is
    /// out of bounds in that case, and the `Infallible` error type leaves no
    /// way to report a short read.
    fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), Self::Error> {
        let start = self.position;
        let end = start + buf.len();
        buf.copy_from_slice(&self.slice[start..end]);
        self.position = end;
        Ok(())
    }
}
/// Decodes an [`EventHeader`] from the front of `bytes` and returns it
/// together with the remaining, undecoded bytes (`std` build).
///
/// # Errors
///
/// Returns a [`DecodeError`] if the CBOR header cannot be decoded.
#[cfg(feature = "std")]
pub fn parse_event_header<'a>(bytes: &'a [u8]) -> Result<(EventHeader, &'a [u8]), DecodeError> {
    let mut reader = std::io::Cursor::new(bytes);
    let header: EventHeader = ciborium::de::from_reader(&mut reader)?;
    // The cursor position after decoding marks where the header ends.
    let consumed = reader.position() as usize;
    let rest = &bytes[consumed..];
    Ok((header, rest))
}
/// Decodes an [`EventHeader`] from the front of `bytes` and returns it
/// together with the remaining, undecoded bytes (build without `std`).
///
/// # Errors
///
/// Returns a [`DecodeError`] if the CBOR header cannot be decoded.
#[cfg(not(feature = "std"))]
pub fn parse_event_header<'a>(bytes: &'a [u8]) -> Result<(EventHeader, &'a [u8]), DecodeError> {
    let mut reader = SliceCursor::new(bytes);
    let header: EventHeader = ciborium::de::from_reader(&mut reader)?;
    // The cursor position after decoding marks where the header ends.
    let consumed = reader.position();
    let rest = &bytes[consumed..];
    Ok((header, rest))
}
/// Decodes one WebSocket frame of a JSON-encoded subscription into an owned
/// message.
///
/// * Text frames are decoded directly.
/// * Binary frames are decoded directly; with the `zstd` feature enabled, the
///   frame is first passed through `decompress_zstd`, and the raw bytes are
///   used only when decompression fails.
/// * Close frames yield `StreamError::closed()`.
/// * Transport errors are passed through unchanged.
///
/// Always returns `Some`; the `Option` wrapper matches the `filter_map`
/// call sites.
pub fn decode_json_msg<S: SubscriptionResp>(
    msg_result: Result<crate::websocket::WsMessage, StreamError>,
) -> Option<Result<StreamMessage<'static, S>, StreamError>>
where
    for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>,
{
    use crate::websocket::WsMessage;

    /// Decodes `bytes` as a message of `R` and detaches it from the buffer.
    fn decode_owned<R: SubscriptionResp>(
        bytes: &[u8],
    ) -> Result<StreamMessage<'static, R>, StreamError>
    where
        for<'a> StreamMessage<'a, R>: IntoStatic<Output = StreamMessage<'static, R>>,
    {
        R::decode_message(bytes)
            .map(|message| message.into_static())
            .map_err(StreamError::decode)
    }

    let frame = match msg_result {
        Ok(frame) => frame,
        Err(err) => return Some(Err(err)),
    };

    match frame {
        WsMessage::Text(text) => Some(decode_owned::<S>(text.as_ref())),
        WsMessage::Binary(bytes) => {
            #[cfg(feature = "zstd")]
            {
                // Prefer the decompressed payload; fall through to the raw
                // bytes only when decompression itself fails.
                if let Ok(decompressed) = decompress_zstd(&bytes) {
                    return Some(decode_owned::<S>(&decompressed));
                }
            }
            Some(decode_owned::<S>(&bytes))
        }
        WsMessage::Close(_) => Some(Err(StreamError::closed())),
    }
}
/// Decompresses a zstd frame.
///
/// Plain decompression is attempted first; if that fails, decompression is
/// retried with the dictionary embedded from `../../zstd_dictionary`.
///
/// # Errors
///
/// Returns the I/O error from the dictionary-based attempt when both attempts
/// fail.
#[cfg(feature = "zstd")]
fn decompress_zstd(bytes: &[u8]) -> Result<Vec<u8>, std::io::Error> {
    use std::io::{Cursor, Read};
    use std::sync::OnceLock;

    // The dictionary is materialized once and shared by all calls.
    static DICTIONARY: OnceLock<Vec<u8>> = OnceLock::new();
    let dict = DICTIONARY.get_or_init(|| include_bytes!("../../zstd_dictionary").to_vec());

    match zstd::stream::decode_all(Cursor::new(bytes)) {
        Ok(plain) => Ok(plain),
        Err(_) => {
            let mut decoder = zstd::Decoder::with_dictionary(Cursor::new(bytes), dict)?;
            let mut output = Vec::new();
            decoder.read_to_end(&mut output)?;
            Ok(output)
        }
    }
}
/// Decodes one WebSocket frame of a CBOR-encoded subscription into an owned
/// message.
///
/// * Binary frames are decoded with `S::decode_message`.
/// * Text frames are rejected with `StreamError::wrong_message_format`.
/// * Close frames yield `StreamError::closed()`.
/// * Transport errors are passed through unchanged.
///
/// Always returns `Some`; the `Option` wrapper matches the `filter_map`
/// call sites.
pub fn decode_cbor_msg<S: SubscriptionResp>(
    msg_result: Result<crate::websocket::WsMessage, StreamError>,
) -> Option<Result<StreamMessage<'static, S>, StreamError>>
where
    for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>,
{
    use crate::websocket::WsMessage;

    let frame = match msg_result {
        Ok(frame) => frame,
        Err(err) => return Some(Err(err)),
    };

    let decoded = match frame {
        WsMessage::Binary(bytes) => S::decode_message(&bytes)
            .map(|message| message.into_static())
            .map_err(StreamError::decode),
        WsMessage::Text(_) => Err(StreamError::wrong_message_format(
            "expected binary frame for CBOR, got text",
        )),
        WsMessage::Close(_) => Err(StreamError::closed()),
    };
    Some(decoded)
}
/// A control message that can be sent over an open subscription.
pub trait SubscriptionControlMessage: Serialize {
    /// Subscription this control message belongs to.
    type Subscription: XrpcSubscription;

    /// Serializes the message to JSON and wraps it in a [`WsMessage`].
    ///
    /// # Errors
    ///
    /// Returns `StreamError::encode` if JSON serialization fails.
    fn encode(&self) -> Result<WsMessage, StreamError> {
        match serde_json::to_string(&self) {
            Ok(json) => Ok(WsMessage::from(json)),
            Err(err) => Err(StreamError::encode(err)),
        }
    }

    /// Parses a control message from a JSON frame.
    ///
    /// # Errors
    ///
    /// Returns `StreamError::decode` if `frame` is not valid JSON for `Self`.
    fn decode<'de>(frame: &'de [u8]) -> Result<Self, StreamError>
    where
        Self: Deserialize<'de>,
    {
        serde_json::from_slice(frame).map_err(StreamError::decode)
    }
}
/// Sends control messages of type `S` over the sink half of a subscription's
/// WebSocket connection.
pub struct SubscriptionController<S: SubscriptionControlMessage> {
/// Sink the encoded control messages are written to.
controller: WsSink,
/// Ties the controller to its message type without storing a value of it.
_marker: PhantomData<fn() -> S>,
}
impl<S: SubscriptionControlMessage> SubscriptionController<S> {
pub fn new(controller: WsSink) -> Self {
Self {
controller,
_marker: PhantomData,
}
}
pub async fn configure(&mut self, params: &S) -> Result<(), StreamError> {
let message = params.encode()?;
n0_future::SinkExt::send(self.controller.get_mut(), message)
.await
.map_err(StreamError::transport)
}
}
/// An open subscription whose frames are decoded as messages of `S`.
pub struct SubscriptionStream<S: SubscriptionResp> {
/// Ties the stream to its response type without storing a value of it.
_marker: PhantomData<fn() -> S>,
/// Underlying WebSocket connection.
connection: WebSocketConnection,
}
impl<S: SubscriptionResp> SubscriptionStream<S> {
    /// Wraps an established WebSocket connection.
    pub fn new(connection: WebSocketConnection) -> Self {
        Self {
            _marker: PhantomData,
            connection,
        }
    }

    /// Returns a shared reference to the underlying connection.
    pub fn connection(&self) -> &WebSocketConnection {
        &self.connection
    }

    /// Returns a mutable reference to the underlying connection.
    pub fn connection_mut(&mut self) -> &mut WebSocketConnection {
        &mut self.connection
    }

    /// Splits the connection and returns its sink together with a boxed stream
    /// of typed, owned messages.
    ///
    /// Frames are decoded with [`decode_json_msg`] or [`decode_cbor_msg`]
    /// depending on `S::ENCODING`.
    pub fn into_stream(
        self,
    ) -> (
        WsSink,
        Boxed<Result<StreamMessage<'static, S>, StreamError>>,
    )
    where
        for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>,
    {
        use n0_future::StreamExt as _;
        let (tx, rx) = self.connection.split();
        // `S::ENCODING` is a constant, so one closure can serve both
        // encodings. A closure (not a fn item) is used so that boxing does
        // not require `S: 'static`.
        let decoded = rx.into_inner().filter_map(|msg| match S::ENCODING {
            MessageEncoding::Json => decode_json_msg::<S>(msg),
            MessageEncoding::DagCbor => decode_cbor_msg::<S>(msg),
        });
        #[cfg(not(target_arch = "wasm32"))]
        let stream = decoded.boxed();
        #[cfg(target_arch = "wasm32")]
        let stream = decoded.boxed_local();
        (tx, stream)
    }

    /// Splits the connection and returns its sink together with a boxed stream
    /// of untyped [`RawData`] values.
    ///
    /// JSON streams accept text and binary frames (binary frames go through
    /// `decompress_zstd` first when the `zstd` feature is enabled, falling
    /// back to the raw bytes if decompression fails). CBOR streams accept
    /// binary frames only. Close frames yield `StreamError::closed()`.
    pub fn into_raw_data_stream(self) -> (WsSink, Boxed<Result<RawData<'static>, StreamError>>) {
        use n0_future::StreamExt as _;

        fn parse_msg<'a>(bytes: &'a [u8]) -> Result<RawData<'a>, serde_json::Error> {
            serde_json::from_slice(bytes)
        }
        fn parse_cbor<'a>(
            bytes: &'a [u8],
        ) -> Result<RawData<'a>, serde_ipld_dagcbor::DecodeError<core::convert::Infallible>>
        {
            serde_ipld_dagcbor::from_slice(bytes)
        }

        /// Parses a JSON payload into an owned value.
        fn json_payload(bytes: &[u8]) -> Result<RawData<'static>, StreamError> {
            parse_msg(bytes)
                .map(|v| v.into_static())
                .map_err(StreamError::decode)
        }

        /// Decodes one frame of a JSON stream.
        fn json_frame(
            frame: Result<WsMessage, StreamError>,
        ) -> Option<Result<RawData<'static>, StreamError>> {
            match frame {
                Ok(WsMessage::Text(text)) => Some(json_payload(text.as_ref())),
                Ok(WsMessage::Binary(bytes)) => {
                    #[cfg(feature = "zstd")]
                    {
                        // Use the raw bytes only when decompression fails.
                        if let Ok(decompressed) = decompress_zstd(&bytes) {
                            return Some(json_payload(&decompressed));
                        }
                    }
                    Some(json_payload(&bytes))
                }
                Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
                Err(e) => Some(Err(e)),
            }
        }

        /// Decodes one frame of a CBOR stream.
        fn cbor_frame(
            frame: Result<WsMessage, StreamError>,
        ) -> Option<Result<RawData<'static>, StreamError>> {
            match frame {
                Ok(WsMessage::Binary(bytes)) => Some(
                    parse_cbor(&bytes)
                        .map(|v| v.into_static())
                        .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))),
                ),
                Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format(
                    "expected binary frame for CBOR, got text",
                ))),
                Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
                Err(e) => Some(Err(e)),
            }
        }

        let (tx, rx) = self.connection.split();
        let decoded = rx.into_inner().filter_map(|frame| match S::ENCODING {
            MessageEncoding::Json => json_frame(frame),
            MessageEncoding::DagCbor => cbor_frame(frame),
        });
        #[cfg(not(target_arch = "wasm32"))]
        let stream = decoded.boxed();
        #[cfg(target_arch = "wasm32")]
        let stream = decoded.boxed_local();
        (tx, stream)
    }

    /// Splits the connection and returns its sink together with a boxed stream
    /// of untyped [`Data`] values.
    ///
    /// Frame handling is the same as in [`Self::into_raw_data_stream`]; only
    /// the decoded value type differs.
    pub fn into_data_stream(self) -> (WsSink, Boxed<Result<Data<'static>, StreamError>>) {
        use n0_future::StreamExt as _;

        fn parse_msg<'a>(bytes: &'a [u8]) -> Result<Data<'a>, serde_json::Error> {
            serde_json::from_slice(bytes)
        }
        fn parse_cbor<'a>(
            bytes: &'a [u8],
        ) -> Result<Data<'a>, serde_ipld_dagcbor::DecodeError<core::convert::Infallible>> {
            serde_ipld_dagcbor::from_slice(bytes)
        }

        /// Parses a JSON payload into an owned value.
        fn json_payload(bytes: &[u8]) -> Result<Data<'static>, StreamError> {
            parse_msg(bytes)
                .map(|v| v.into_static())
                .map_err(StreamError::decode)
        }

        /// Decodes one frame of a JSON stream.
        fn json_frame(
            frame: Result<WsMessage, StreamError>,
        ) -> Option<Result<Data<'static>, StreamError>> {
            match frame {
                Ok(WsMessage::Text(text)) => Some(json_payload(text.as_ref())),
                Ok(WsMessage::Binary(bytes)) => {
                    #[cfg(feature = "zstd")]
                    {
                        // Use the raw bytes only when decompression fails.
                        if let Ok(decompressed) = decompress_zstd(&bytes) {
                            return Some(json_payload(&decompressed));
                        }
                    }
                    Some(json_payload(&bytes))
                }
                Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
                Err(e) => Some(Err(e)),
            }
        }

        /// Decodes one frame of a CBOR stream.
        fn cbor_frame(
            frame: Result<WsMessage, StreamError>,
        ) -> Option<Result<Data<'static>, StreamError>> {
            match frame {
                Ok(WsMessage::Binary(bytes)) => Some(
                    parse_cbor(&bytes)
                        .map(|v| v.into_static())
                        .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))),
                ),
                Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format(
                    "expected binary frame for CBOR, got text",
                ))),
                Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
                Err(e) => Some(Err(e)),
            }
        }

        let (tx, rx) = self.connection.split();
        let decoded = rx.into_inner().filter_map(|frame| match S::ENCODING {
            MessageEncoding::Json => json_frame(frame),
            MessageEncoding::DagCbor => cbor_frame(frame),
        });
        #[cfg(not(target_arch = "wasm32"))]
        let stream = decoded.boxed();
        #[cfg(target_arch = "wasm32")]
        let stream = decoded.boxed_local();
        (tx, stream)
    }

    /// Consumes the stream and returns the underlying connection.
    pub fn into_connection(self) -> WebSocketConnection {
        self.connection
    }

    /// Tees the receiver: one branch is put back into the connection, the
    /// other is returned as a boxed stream of typed, owned messages.
    pub fn tee(&mut self) -> Boxed<Result<StreamMessage<'static, S>, StreamError>>
    where
        for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>,
    {
        use n0_future::StreamExt as _;
        let rx = self.connection.receiver_mut();
        // Temporarily park an empty stream in the connection so the real
        // receiver can be moved out, teed, and one branch put back.
        let (raw_rx, typed_rx_source) =
            core::mem::replace(rx, WsStream::new(n0_future::stream::empty())).tee();
        *rx = raw_rx;
        let decoded = typed_rx_source
            .into_inner()
            .filter_map(|msg| match S::ENCODING {
                MessageEncoding::Json => decode_json_msg::<S>(msg),
                MessageEncoding::DagCbor => decode_cbor_msg::<S>(msg),
            });
        #[cfg(not(target_arch = "wasm32"))]
        let stream = decoded.boxed();
        #[cfg(target_arch = "wasm32")]
        let stream = decoded.boxed_local();
        stream
    }
}
/// Shorthand for the message type of a [`SubscriptionResp`] at lifetime `'a`.
type StreamMessage<'a, R> = <R as SubscriptionResp>::Message<'a>;
/// Static description of a subscription endpoint: its path, encoding,
/// parameter type and stream type.
pub trait SubscriptionEndpoint {
/// Path of the endpoint.
const PATH: &'static str;
/// Encoding of the messages on the endpoint's stream.
const ENCODING: MessageEncoding;
/// Parameter type for the endpoint; may borrow from the buffer it was deserialized from.
type Params<'de>: XrpcSubscription + Deserialize<'de> + IntoStatic;
/// Description of the messages produced by the endpoint.
type Stream: SubscriptionResp;
}
/// Per-connection options for a subscription.
#[derive(Debug, Default, Clone)]
pub struct SubscriptionOptions<'a> {
/// Header name/value pairs passed to `connect_with_headers` when subscribing.
pub headers: Vec<(CowStr<'a>, CowStr<'a>)>,
}
impl IntoStatic for SubscriptionOptions<'_> {
    type Output = SubscriptionOptions<'static>;

    /// Converts every header name and value into its owned, `'static` form.
    fn into_static(self) -> Self::Output {
        let mut headers = Vec::with_capacity(self.headers.len());
        for (name, value) in self.headers {
            headers.push((name.into_static(), value.into_static()));
        }
        SubscriptionOptions { headers }
    }
}
/// Extension trait that adds a subscription builder to every
/// [`WebSocketClient`].
pub trait SubscriptionExt: WebSocketClient {
    /// Starts building a subscription against `base` with default options.
    fn subscription<'a>(&'a self, base: Uri<String>) -> SubscriptionCall<'a, Self>
    where
        Self: Sized,
    {
        let opts = SubscriptionOptions::default();
        SubscriptionCall {
            opts,
            base,
            client: self,
        }
    }
}

// Blanket implementation: every WebSocket client gets the builder.
impl<T: WebSocketClient> SubscriptionExt for T {}
/// Builds the full subscription URI from a base URI, an endpoint path and
/// query parameters.
///
/// The path is the base path (trailing slashes removed) followed by
/// `custom_path` if given, otherwise by `/xrpc/<nsid>`. Each key and value in
/// `query_params` is percent-encoded before being joined with `=` and `&`.
/// Scheme and authority are copied from `base`.
///
/// # Errors
///
/// Returns the [`ParseError`] from parsing the assembled string.
fn build_subscription_uri(
    base: &Uri<String>,
    nsid: &str,
    custom_path: Option<&str>,
    query_params: &[(String, String)],
) -> Result<Uri<String>, ParseError> {
    use core::fmt::Write as _;

    // Path: base path without trailing slashes, then the endpoint path.
    let base_path = base.path().as_str().trim_end_matches('/');
    let mut path = String::with_capacity(base_path.len() + 50);
    path.push_str(base_path);
    match custom_path {
        Some(custom) => path.push_str(custom),
        None => {
            path.push_str("/xrpc/");
            path.push_str(nsid);
        }
    }

    // Query: `key=value` pairs, each side percent-encoded, joined by `&`.
    let mut query_str = String::new();
    for (index, (key, value)) in query_params.iter().enumerate() {
        if index > 0 {
            query_str.push('&');
        }
        let mut enc_key = EString::<Query>::new();
        enc_key.encode_str::<EncData>(key.as_str());
        let mut enc_value = EString::<Query>::new();
        enc_value.encode_str::<EncData>(value.as_str());
        // Writing to a `String` cannot fail.
        let _ = core::write!(query_str, "{}={}", enc_key, enc_value);
    }
    let has_query = !query_str.is_empty();

    let scheme_len = base.scheme().as_str().len();
    let authority_len = base.authority().map(|a| a.as_str().len()).unwrap_or(0);
    let capacity =
        scheme_len + 3 + authority_len + path.len() + query_str.len() + usize::from(has_query);

    let mut uri_str = String::with_capacity(capacity);
    uri_str.push_str(base.scheme().as_str());
    uri_str.push_str("://");
    if let Some(authority) = base.authority() {
        uri_str.push_str(authority.as_str());
    }
    uri_str.push_str(&path);
    if has_query {
        uri_str.push('?');
        uri_str.push_str(&query_str);
    }

    match Uri::parse(uri_str) {
        Ok(uri) => Ok(uri.to_owned()),
        Err((err, _)) => Err(err),
    }
}
/// Builder for a single subscription request against a base URI.
pub struct SubscriptionCall<'a, C: WebSocketClient> {
/// Client used to open the WebSocket connection.
pub(crate) client: &'a C,
/// Base URI the subscription path and query are appended to.
pub(crate) base: Uri<String>,
/// Options (headers) applied when connecting.
pub(crate) opts: SubscriptionOptions<'a>,
}
impl<'a, C: WebSocketClient> SubscriptionCall<'a, C> {
    /// Appends one header to the request options.
    pub fn header(mut self, name: impl Into<CowStr<'a>>, value: impl Into<CowStr<'a>>) -> Self {
        let entry = (name.into(), value.into());
        self.opts.headers.push(entry);
        self
    }

    /// Replaces the request options wholesale.
    pub fn with_options(mut self, opts: SubscriptionOptions<'a>) -> Self {
        self.opts = opts;
        self
    }

    /// Builds the subscription URI from `params` and opens the connection.
    ///
    /// # Errors
    ///
    /// Returns the client's error if the connection cannot be established.
    ///
    /// # Panics
    ///
    /// Panics if the assembled URI fails to parse.
    pub async fn subscribe<Sub>(
        self,
        params: &Sub,
    ) -> Result<SubscriptionStream<Sub::Stream>, C::Error>
    where
        Sub: XrpcSubscription,
    {
        let query_params = params.query_params();
        let built =
            build_subscription_uri(&self.base, Sub::NSID, Sub::CUSTOM_PATH, &query_params);
        let uri = built
            .expect("subscription URI must be valid (base_uri + path always yields a valid URI)");
        let headers = self.opts.headers;
        let connection = self
            .client
            .connect_with_headers(uri.borrow(), headers)
            .await?;
        Ok(SubscriptionStream::new(connection))
    }
}
/// A WebSocket client that knows its own base URI and default subscription
/// options, and can open subscriptions from parameters alone.
///
/// Each `subscribe*` method is declared twice: the non-wasm variant
/// additionally requires `Self: Sync`, the wasm32 variant does not.
#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait SubscriptionClient: WebSocketClient {
/// Returns the base URI subscriptions are opened against.
fn base_uri(&self) -> impl Future<Output = Uri<String>>;
/// Returns the options applied by default; the provided implementation returns empty options.
fn subscription_opts(&self) -> impl Future<Output = SubscriptionOptions<'_>> {
async { SubscriptionOptions::default() }
}
/// Opens a subscription for `params` (non-wasm variant).
#[cfg(not(target_arch = "wasm32"))]
fn subscribe<Sub>(
&self,
params: &Sub,
) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>>
where
Sub: XrpcSubscription + Send + Sync,
Self: Sync;
/// Opens a subscription for `params` (wasm32 variant).
#[cfg(target_arch = "wasm32")]
fn subscribe<Sub>(
&self,
params: &Sub,
) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>>
where
Sub: XrpcSubscription + Send + Sync;
/// Opens a subscription for `params` with explicit `opts` (non-wasm variant).
#[cfg(not(target_arch = "wasm32"))]
fn subscribe_with_opts<Sub>(
&self,
params: &Sub,
opts: SubscriptionOptions<'_>,
) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>>
where
Sub: XrpcSubscription + Send + Sync,
Self: Sync;
/// Opens a subscription for `params` with explicit `opts` (wasm32 variant).
#[cfg(target_arch = "wasm32")]
fn subscribe_with_opts<Sub>(
&self,
params: &Sub,
opts: SubscriptionOptions<'_>,
) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>>
where
Sub: XrpcSubscription + Send + Sync;
}
/// A [`SubscriptionClient`] built from any [`WebSocketClient`] plus a fixed
/// base URI and default options.
pub struct BasicSubscriptionClient<W: WebSocketClient> {
/// Wrapped WebSocket client; all connections are delegated to it.
client: W,
/// Base URI returned by `base_uri()`.
base_uri: Uri<String>,
/// Default options returned by `subscription_opts()`.
opts: SubscriptionOptions<'static>,
}
impl<W: WebSocketClient> BasicSubscriptionClient<W> {
pub fn new(client: W, base_uri: Uri<String>) -> Self {
Self {
client,
base_uri,
opts: SubscriptionOptions::default(),
}
}
pub fn with_options(mut self, opts: SubscriptionOptions<'_>) -> Self {
self.opts = opts.into_static();
self
}
pub fn inner(&self) -> &W {
&self.client
}
}
/// Delegates all connection handling to the wrapped client.
impl<W: WebSocketClient> WebSocketClient for BasicSubscriptionClient<W> {
    type Error = W::Error;

    /// Connects to `uri` through the wrapped client.
    async fn connect(&self, uri: Uri<&str>) -> Result<WebSocketConnection, Self::Error> {
        let connection = self.client.connect(uri).await?;
        Ok(connection)
    }

    /// Connects to `uri` through the wrapped client, sending `headers`.
    async fn connect_with_headers(
        &self,
        uri: Uri<&str>,
        headers: Vec<(CowStr<'_>, CowStr<'_>)>,
    ) -> Result<WebSocketConnection, Self::Error> {
        let connection = self.client.connect_with_headers(uri, headers).await?;
        Ok(connection)
    }
}
impl<W: WebSocketClient> SubscriptionClient for BasicSubscriptionClient<W> {
    /// Returns a clone of the configured base URI.
    async fn base_uri(&self) -> Uri<String> {
        self.base_uri.clone()
    }

    /// Returns a clone of the configured default options.
    async fn subscription_opts(&self) -> SubscriptionOptions<'_> {
        self.opts.clone()
    }

    /// Subscribes using the configured default options (non-wasm variant).
    #[cfg(not(target_arch = "wasm32"))]
    async fn subscribe<Sub>(
        &self,
        params: &Sub,
    ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error>
    where
        Sub: XrpcSubscription + Send + Sync,
        Self: Sync,
    {
        let defaults = self.subscription_opts().await;
        self.subscribe_with_opts(params, defaults).await
    }

    /// Subscribes using the configured default options (wasm32 variant).
    #[cfg(target_arch = "wasm32")]
    async fn subscribe<Sub>(
        &self,
        params: &Sub,
    ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error>
    where
        Sub: XrpcSubscription + Send + Sync,
    {
        let defaults = self.subscription_opts().await;
        self.subscribe_with_opts(params, defaults).await
    }

    /// Subscribes against the configured base URI with explicit `opts`
    /// (non-wasm variant).
    #[cfg(not(target_arch = "wasm32"))]
    async fn subscribe_with_opts<Sub>(
        &self,
        params: &Sub,
        opts: SubscriptionOptions<'_>,
    ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error>
    where
        Sub: XrpcSubscription + Send + Sync,
        Self: Sync,
    {
        let base = self.base_uri().await;
        let call = self.subscription(base).with_options(opts);
        call.subscribe(params).await
    }

    /// Subscribes against the configured base URI with explicit `opts`
    /// (wasm32 variant).
    #[cfg(target_arch = "wasm32")]
    async fn subscribe_with_opts<Sub>(
        &self,
        params: &Sub,
        opts: SubscriptionOptions<'_>,
    ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error>
    where
        Sub: XrpcSubscription + Send + Sync,
    {
        let base = self.base_uri().await;
        let call = self.subscription(base).with_options(opts);
        call.subscribe(params).await
    }
}
/// [`BasicSubscriptionClient`] specialised to the crate's `TungsteniteClient`.
pub type TungsteniteSubscriptionClient =
BasicSubscriptionClient<crate::websocket::tungstenite_client::TungsteniteClient>;
impl TungsteniteSubscriptionClient {
    /// Creates a subscription client for `base_uri` backed by a new
    /// `TungsteniteClient`.
    pub fn from_base_uri(base_uri: Uri<String>) -> Self {
        Self::new(
            crate::websocket::tungstenite_client::TungsteniteClient::new(),
            base_uri,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Without a custom path, the URI uses `/xrpc/<nsid>` and carries the
    /// query parameters.
    #[test]
    fn test_subscription_uri_with_nsid_path() {
        let base_uri = Uri::parse("wss://bsky.social/xrpc").unwrap().to_owned();
        let nsid = "com.example.subscribe";
        let query_params = vec![
            ("cursor".to_string(), "abc123".to_string()),
            ("filter".to_string(), "like".to_string()),
        ];
        let uri = build_subscription_uri(&base_uri, nsid, None, &query_params)
            .expect("valid base uri and path should produce valid uri");
        let uri_str = uri.as_str();
        assert!(uri_str.contains("/xrpc/com.example.subscribe"));
        assert!(uri_str.contains("cursor=abc123"));
        assert!(uri_str.contains("filter=like"));
        assert!(!uri_str.contains("//xrpc"));
    }

    /// A custom path replaces the default `/xrpc/<nsid>` path.
    #[test]
    fn test_subscription_uri_with_custom_path() {
        let base_uri = Uri::parse("wss://jetstream.example.com")
            .unwrap()
            .to_owned();
        let custom_path = "/subscribe";
        let uri = build_subscription_uri(&base_uri, "com.example.sub", Some(custom_path), &[])
            .expect("valid base uri and path should produce valid uri");
        let uri_str = uri.as_str();
        assert!(uri_str.contains("/subscribe"));
        assert!(!uri_str.contains("/xrpc/"));
    }

    /// Scheme, authority (including port) and base path are preserved.
    #[test]
    fn test_subscription_uri_scheme_and_authority() {
        let base_uri = Uri::parse("wss://example.com:8080/path")
            .unwrap()
            .to_owned();
        let nsid = "com.example.test";
        let uri = build_subscription_uri(&base_uri, nsid, None, &[])
            .expect("valid base uri and path should produce valid uri");
        let uri_str = uri.as_str();
        assert!(uri_str.starts_with("wss://example.com:8080"));
        assert!(uri_str.contains("/path/xrpc/com.example.test"));
    }

    /// Multiple parameters are joined with `&` after a `?`.
    #[test]
    fn test_query_parameters_encoding() {
        let base_uri = Uri::parse("wss://example.com").unwrap().to_owned();
        let params = vec![
            ("cursor".to_string(), "abc123".to_string()),
            ("filter".to_string(), "like".to_string()),
        ];
        // Restored from a garbled token: the slice argument is `&params`.
        let uri = build_subscription_uri(&base_uri, "com.test", None, &params)
            .expect("valid base uri and path should produce valid uri");
        let uri_str = uri.as_str();
        assert!(uri_str.contains("?"));
        assert!(uri_str.contains("cursor=abc123"));
        assert!(uri_str.contains("filter=like"));
        assert!(uri_str.contains("&"));
    }

    /// A trailing slash on the base path does not produce a double slash.
    #[test]
    fn test_uri_trailing_slash_handling() {
        let base_uri = Uri::parse("wss://example.com/xrpc/").unwrap().to_owned();
        let uri = build_subscription_uri(&base_uri, "com.example.test", None, &[])
            .expect("valid base uri and path should produce valid uri");
        let uri_str = uri.as_str();
        assert!(!uri_str.contains("//xrpc"));
        assert!(uri_str.contains("/xrpc/com.example.test"));
    }

    /// With no parameters, no `?` is appended.
    #[test]
    fn test_empty_query_parameters() {
        let base_uri = Uri::parse("wss://example.com").unwrap().to_owned();
        let uri = build_subscription_uri(&base_uri, "com.example.test", None, &[])
            .expect("valid base uri and path should produce valid uri");
        let uri_str = uri.as_str();
        assert!(!uri_str.contains("?"));
        assert!(uri_str.ends_with("com.example.test"));
    }
}