vtubestudio/service/
api.rs

1use crate::data::{RequestEnvelope, RequestId, ResponseEnvelope};
2use crate::error::{BoxError, Error};
3use crate::transport::buffered::BufferedApiTransport;
4use crate::transport::event::{EventStream, EventlessApiTransport};
5
6use futures_core::TryStream;
7use futures_sink::Sink;
8use std::fmt::Write;
9use std::future::Future;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use tokio_tower::multiplex::{Client as MultiplexClient, MultiplexTransport, TagStore};
13use tower::Service;
14
crate::cfg_feature! {
    #![feature = "tokio-tungstenite"]
    // Everything in this block is gated on the "tokio-tungstenite" cargo feature
    // (presumably `cfg_feature!` applies `#[cfg(feature = ...)]` to each item — confirm).
    use crate::transport::TungsteniteApiTransport;

    /// Type alias for an [`ApiService`] wrapping a [`TungsteniteApiTransport`].
    pub type TungsteniteApiService = ApiService<TungsteniteApiTransport>;
}
22
/// Struct describing how to tag [`RequestEnvelope`]s and extract tags from [`ResponseEnvelope`]s.
///
/// Outgoing requests that already carry a `request_id` keep it. Otherwise a new ID is generated
/// by formatting an internal counter as a decimal string. Incoming responses are matched back
/// to their requests by their `request_id`.
#[derive(Debug)]
pub struct IdTagger {
    /// Counter used to generate the next request ID.
    next: usize,
    /// Scratch buffer, reused across calls, into which the counter is formatted.
    buffer: String,
}
29
30impl TagStore<RequestEnvelope, ResponseEnvelope> for IdTagger {
31    type Tag = RequestId;
32
33    fn assign_tag(mut self: Pin<&mut Self>, request: &mut RequestEnvelope) -> Self::Tag {
34        // If request already has an ID, use it. Otherwise generate a new one.
35        if let Some(id) = &request.request_id {
36            return id.clone();
37        }
38
39        let id = self.next;
40        if write!(self.buffer, "{}", id).is_err() {
41            // We don't expect this to happen, but recover just in case
42            self.buffer = id.to_string();
43        }
44
45        let id = RequestId::from(self.buffer.as_str());
46        request.request_id = Some(id.clone());
47
48        self.next += 1;
49        self.buffer.clear();
50        id
51    }
52
53    fn finish_tag(self: Pin<&mut Self>, response: &ResponseEnvelope) -> Self::Tag {
54        response.request_id.clone()
55    }
56}
57
/// The concrete service stack wrapped by [`ApiService`].
///
/// Layers, from outermost to innermost:
/// - a `tokio_tower` multiplexing client that uses [`IdTagger`] to match responses to requests;
/// - a [`BufferedApiTransport`];
/// - an [`EventlessApiTransport`] around the user-supplied transport `T`.
///
/// The client reports errors as [`Error`].
type ServiceInner<T> = MultiplexClient<
    MultiplexTransport<BufferedApiTransport<EventlessApiTransport<T>>, IdTagger>,
    Error,
    RequestEnvelope,
>;
63
/// A [`Service`] that assigns request IDs to [`RequestEnvelope`]s and matches them to incoming
/// [`ResponseEnvelope`]s.
///
/// This uses [`tokio_tower::multiplex`] to wrap an underlying transport.
///
/// Construct one with [`ApiService::new`] or [`ApiService::with_error_handler`]. Both also return
/// an [`EventStream`] alongside the service.
#[derive(Debug)]
pub struct ApiService<T>
where
    // Presumably required because the inner multiplex client and transports declare the same
    // bounds on their type definitions.
    T: Sink<RequestEnvelope> + TryStream<Ok = ResponseEnvelope>,
{
    service: ServiceInner<T>,
}
75
76impl<T> ApiService<T>
77where
78    T: Sink<RequestEnvelope> + TryStream<Ok = ResponseEnvelope> + Send + 'static,
79    <T as TryStream>::Error: Send,
80    BoxError: From<<T as Sink<RequestEnvelope>>::Error>,
81    BoxError: From<<T as TryStream>::Error>,
82{
83    /// Create a new [`ApiService`] and corresponding [`EventStream`].
84    pub fn new(transport: T, buffer_size: usize) -> (Self, EventStream<T>) {
85        Self::with_error_handler(
86            transport,
87            buffer_size,
88            |error| tracing::error!(%error, "Transport error"),
89        )
90    }
91
92    /// Create a new [`ApiService`] with an internal handler for transport errors.
93    pub fn with_error_handler<F>(
94        transport: T,
95        buffer_size: usize,
96        on_service_error: F,
97    ) -> (Self, EventStream<T>)
98    where
99        F: FnOnce(Error) + Send + 'static,
100    {
101        let tagger = IdTagger {
102            next: 0,
103            buffer: String::new(),
104        };
105
106        let (eventless_transport, event_stream) = EventlessApiTransport::new(transport);
107        let buffered_transport = BufferedApiTransport::new(eventless_transport, buffer_size);
108
109        let multiplex_transport = MultiplexTransport::new(buffered_transport, tagger);
110        let service = MultiplexClient::with_error_handler(multiplex_transport, on_service_error);
111
112        (Self { service }, event_stream)
113    }
114}
115
116impl<T> Service<RequestEnvelope> for ApiService<T>
117where
118    T: Sink<RequestEnvelope> + TryStream<Ok = ResponseEnvelope> + 'static,
119    BoxError: From<<T as Sink<RequestEnvelope>>::Error>,
120    BoxError: From<<T as TryStream>::Error>,
121{
122    type Response = ResponseEnvelope;
123    type Error = Error;
124    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
125
126    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
127        self.service.poll_ready(cx)
128    }
129
130    fn call(&mut self, req: RequestEnvelope) -> Self::Future {
131        self.service.call(req)
132    }
133}