Skip to main content

trillium_grpc/server/
streaming.rs

1//! Borrowed streaming primitives handed to service methods.
2//!
3//! [`RequestStream`] decodes inbound request messages from a boxed reader — the
4//! request body during `run()` (via [`GrpcServerConn::requests`]), and continues
5//! against the same retained body after a bidi upgrade. [`Channel`] is the
6//! bidirectional read+write surface a bidi responder drives over the upgraded
7//! transport.
8//!
9//! Codec is type-erased into `fn` pointers so these user-facing types carry no
10//! codec parameter.
11//!
12//! [`GrpcServerConn::requests`]: crate::server::GrpcServerConn::requests
13
14use crate::{
15    Encoding, Status,
16    encoding::DEFAULT_MAX_MESSAGE_SIZE,
17    frame::{
18        reader::{ReadState, poll_read_message},
19        writer::encode_payload,
20    },
21};
22use bytes::Bytes;
23use futures_lite::{AsyncRead, AsyncWriteExt, Stream};
24use std::{
25    future::poll_fn,
26    pin::Pin,
27    task::{Context, Poll},
28};
29use trillium::{Headers, Upgrade};
30
31/// A stream of decoded request messages.
32///
33/// Produced by [`GrpcServerConn::requests`](crate::server::GrpcServerConn::requests). Read
34/// it with [`recv`](Self::recv) (or as a [`Stream`]); `recv` yields `Ok(None)`
35/// on clean end-of-stream and `Err` on a decode or transport error.
36pub struct RequestStream<'a, T> {
37    reader: Pin<Box<dyn AsyncRead + Send + 'a>>,
38    state: ReadState,
39    decode: fn(&[u8]) -> Result<T, Status>,
40    encoding: Encoding,
41    max_message_size: usize,
42}
43
44impl<'a, T> RequestStream<'a, T> {
45    pub(crate) fn new(
46        reader: Pin<Box<dyn AsyncRead + Send + 'a>>,
47        decode: fn(&[u8]) -> Result<T, Status>,
48        encoding: Encoding,
49    ) -> Self {
50        Self {
51            reader,
52            state: ReadState::new(),
53            decode,
54            encoding,
55            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
56        }
57    }
58
59    fn poll_message(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<T, Status>>> {
60        poll_read_message(
61            self.reader.as_mut(),
62            &mut self.state,
63            cx,
64            self.decode,
65            self.encoding,
66            self.max_message_size,
67        )
68    }
69
70    /// The next decoded request message: `Ok(None)` on clean end-of-stream,
71    /// `Err` on a per-message decode error or transport failure.
72    pub async fn recv(&mut self) -> Result<Option<T>, Status> {
73        poll_fn(|cx| match self.poll_message(cx) {
74            Poll::Ready(Some(Ok(t))) => Poll::Ready(Ok(Some(t))),
75            Poll::Ready(Some(Err(e))) => Poll::Ready(Err(e)),
76            Poll::Ready(None) => Poll::Ready(Ok(None)),
77            Poll::Pending => Poll::Pending,
78        })
79        .await
80    }
81}
82
83impl<T: 'static> Stream for RequestStream<'_, T> {
84    type Item = Result<T, Status>;
85    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
86        self.get_mut().poll_message(cx)
87    }
88}
89
90/// Bidirectional channel: read decoded requests, write framed responses, over
91/// the upgraded transport.
92///
93/// Handed to a [`BidiResponder`](crate::BidiResponder). Turn-taking by
94/// construction — both [`recv`](Self::recv) and [`send`](Self::send) take
95/// `&mut self`, so the underlying `&mut Upgrade` is never aliased. For producers
96/// that need to run concurrently with the request loop, spawn a task on the
97/// runtime.
98///
99/// The response *initial* metadata was committed before the upgrade (by the
100/// prologue), so there is no way to mutate it here. The *trailing* metadata,
101/// emitted after the loop alongside `grpc-status`, is still open: write it
102/// through [`response_trailers_mut`](Self::response_trailers_mut) (the bag was
103/// seeded with whatever the prologue set).
104pub struct Channel<'a, Req, Resp> {
105    upgrade: &'a mut Upgrade,
106    response_trailers: &'a mut Headers,
107    state: ReadState,
108    decode: fn(&[u8]) -> Result<Req, Status>,
109    encode: fn(&Resp) -> Result<Bytes, Status>,
110    inbound_encoding: Encoding,
111    outbound_encoding: Encoding,
112    max_message_size: usize,
113}
114
115impl<'a, Req, Resp> Channel<'a, Req, Resp> {
116    pub(crate) fn new(
117        upgrade: &'a mut Upgrade,
118        response_trailers: &'a mut Headers,
119        decode: fn(&[u8]) -> Result<Req, Status>,
120        encode: fn(&Resp) -> Result<Bytes, Status>,
121        inbound_encoding: Encoding,
122        outbound_encoding: Encoding,
123    ) -> Self {
124        Self {
125            upgrade,
126            response_trailers,
127            state: ReadState::new(),
128            decode,
129            encode,
130            inbound_encoding,
131            outbound_encoding,
132            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
133        }
134    }
135
136    /// The response's trailing metadata, emitted alongside `grpc-status` once
137    /// the loop ends. Seeded with whatever the prologue set; write to it to add
138    /// trailing metadata (including `grpc-status-details-bin` error details)
139    /// from inside the loop.
140    pub fn response_trailers_mut(&mut self) -> &mut Headers {
141        self.response_trailers
142    }
143
144    /// Read the next decoded request. `None` on clean EOF (client closed the
145    /// request side); `Some(Err(_))` ends the read side.
146    pub async fn recv(&mut self) -> Option<Result<Req, Status>> {
147        let upgrade = &mut *self.upgrade;
148        let state = &mut self.state;
149        let decode = self.decode;
150        let encoding = self.inbound_encoding;
151        let max = self.max_message_size;
152        poll_fn(|cx| poll_read_message(Pin::new(&mut *upgrade), state, cx, decode, encoding, max))
153            .await
154    }
155
156    /// Frame and write one response message.
157    pub async fn send(&mut self, value: Resp) -> Result<(), Status> {
158        let payload = (self.encode)(&value)?;
159        let frame = encode_payload(&payload, self.outbound_encoding)?;
160        self.upgrade
161            .write_all(&frame)
162            .await
163            .map_err(|e| Status::unavailable(format!("write error: {e}")))
164    }
165}