proxy_sdk/
grpc_stream.rs

1use std::{
2    fmt,
3    ops::{Bound, RangeBounds},
4};
5
6use derive_builder::Builder;
7
8use crate::{
9    downcast_box::DowncastBox,
10    grpc_call::GrpcCode,
11    hostcalls::{self, BufferType},
12    log_concern, RootContext, Status, Upstream,
13};
14
15#[cfg(feature = "stream-metadata")]
16use crate::hostcalls::MapType;
17
18/// Outbound GRPC stream (bidirectional)
19#[derive(Builder)]
20#[builder(setter(into))]
21#[builder(pattern = "owned")]
22#[allow(clippy::type_complexity)]
23pub struct GrpcStream<'a> {
24    /// Upstream cluster to send the request to.
25    pub cluster: Upstream<'a>,
26    /// The GRPC service to call.
27    pub service: &'a str,
28    /// The GRPC service method to call.
29    pub method: &'a str,
30    /// Initial GRPC metadata to send with the request.
31    #[builder(setter(each(name = "metadata")), default)]
32    pub initial_metadata: Vec<(&'a str, &'a [u8])>,
33    /// Callback to call when the server sends initial metadata.
34    #[cfg(feature = "stream-metadata")]
35    #[builder(setter(custom), default)]
36    pub on_initial_metadata: Option<
37        Box<
38            dyn FnMut(
39                &mut DowncastBox<dyn RootContext>,
40                GrpcStreamHandle,
41                &GrpcStreamInitialMetadata,
42            ),
43        >,
44    >,
45    /// Callback to call when the server sends a stream message.
46    #[builder(setter(custom), default)]
47    pub on_message: Option<
48        Box<dyn FnMut(&mut DowncastBox<dyn RootContext>, GrpcStreamHandle, &GrpcStreamMessage)>,
49    >,
50    /// Callback to call when the server sends trailing metadata.
51    #[cfg(feature = "stream-metadata")]
52    #[builder(setter(custom), default)]
53    pub on_trailing_metadata: Option<
54        Box<
55            dyn FnMut(
56                &mut DowncastBox<dyn RootContext>,
57                GrpcStreamHandle,
58                &GrpcStreamTrailingMetadata,
59            ),
60        >,
61    >,
62    /// Callback to call when the stream closes.
63    #[builder(setter(custom), default)]
64    pub on_close: Option<Box<dyn FnOnce(&mut DowncastBox<dyn RootContext>, &GrpcStreamClose)>>,
65}
66
67impl<'a> GrpcStreamBuilder<'a> {
68    /// Set an initial metadata callback
69    #[cfg(feature = "stream-metadata")]
70    pub fn on_initial_metadata<R: RootContext + 'static>(
71        mut self,
72        mut callback: impl FnMut(&mut R, GrpcStreamHandle, &GrpcStreamInitialMetadata) + 'static,
73    ) -> Self {
74        self.on_initial_metadata = Some(Some(Box::new(move |root, handle, metadata| {
75            callback(
76                root.as_any_mut().downcast_mut().expect("invalid root type"),
77                handle,
78                metadata,
79            )
80        })));
81        self
82    }
83
84    /// Set a stream message callback
85    pub fn on_message<R: RootContext + 'static>(
86        mut self,
87        mut callback: impl FnMut(&mut R, GrpcStreamHandle, &GrpcStreamMessage) + 'static,
88    ) -> Self {
89        self.on_message = Some(Some(Box::new(move |root, handle, message| {
90            callback(
91                root.as_any_mut().downcast_mut().expect("invalid root type"),
92                handle,
93                message,
94            )
95        })));
96        self
97    }
98
99    /// Set a trailing metadata callback
100    #[cfg(feature = "stream-metadata")]
101    pub fn on_trailing_metadata<R: RootContext + 'static>(
102        mut self,
103        mut callback: impl FnMut(&mut R, GrpcStreamHandle, &GrpcStreamTrailingMetadata) + 'static,
104    ) -> Self {
105        self.on_trailing_metadata = Some(Some(Box::new(move |root, handle, metadata| {
106            callback(
107                root.as_any_mut().downcast_mut().expect("invalid root type"),
108                handle,
109                metadata,
110            )
111        })));
112        self
113    }
114
115    /// Set a stream close callback
116    pub fn on_close<R: RootContext + 'static>(
117        mut self,
118        callback: impl FnOnce(&mut R, &GrpcStreamClose) + 'static,
119    ) -> Self {
120        self.on_close = Some(Some(Box::new(move |root, close| {
121            callback(
122                root.as_any_mut().downcast_mut().expect("invalid root type"),
123                close,
124            )
125        })));
126        self
127    }
128}
129
/// Handle used to cancel, close, or send a message over an open outbound GRPC stream.
///
/// It wraps the raw `u32` token of the stream and is `Copy`, so it can be stored and passed
/// into callbacks freely. `Debug` and `Hash` are derived so handles can be logged with `{:?}`
/// and used as keys in hash-based collections (consistent with the derived `Eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GrpcStreamHandle(pub(crate) u32);
133
134impl<'a> GrpcStream<'a> {
135    /// Open a new outbound GRPC stream.
136    pub fn open(self) -> Result<GrpcStreamHandle, Status> {
137        let token = hostcalls::open_grpc_stream(
138            &self.cluster.0,
139            self.service,
140            self.method,
141            &self.initial_metadata,
142        )?;
143
144        #[cfg(feature = "stream-metadata")]
145        if let Some(callback) = self.on_initial_metadata {
146            crate::dispatcher::register_grpc_stream_initial_meta(token, callback);
147        }
148        if let Some(callback) = self.on_message {
149            crate::dispatcher::register_grpc_stream_message(token, callback);
150        }
151        #[cfg(feature = "stream-metadata")]
152        if let Some(callback) = self.on_trailing_metadata {
153            crate::dispatcher::register_grpc_stream_trailing_metadata(token, callback);
154        }
155        if let Some(callback) = self.on_close {
156            crate::dispatcher::register_grpc_stream_close(token, callback);
157        }
158
159        Ok(GrpcStreamHandle(token))
160    }
161}
162
163impl GrpcStreamHandle {
164    /// Attempts to cancel the GRPC stream
165    pub fn cancel(&self) {
166        hostcalls::cancel_grpc_stream(self.0).ok();
167    }
168
169    /// Closes the GRPC stream
170    pub fn close(&self) {
171        hostcalls::close_grpc_stream(self.0).ok();
172    }
173
174    /// Sends a message over the GRPC stream
175    pub fn send(&self, message: Option<impl AsRef<[u8]>>, end_stream: bool) -> Result<(), Status> {
176        hostcalls::send_grpc_stream_message(
177            self.0,
178            message.as_ref().map(|x| x.as_ref()),
179            end_stream,
180        )
181    }
182}
183
184impl PartialEq<u32> for GrpcStreamHandle {
185    fn eq(&self, other: &u32) -> bool {
186        self.0 == *other
187    }
188}
189
190impl PartialEq<GrpcStreamHandle> for u32 {
191    fn eq(&self, other: &GrpcStreamHandle) -> bool {
192        other == self
193    }
194}
195
196impl fmt::Display for GrpcStreamHandle {
197    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198        self.0.fmt(f)
199    }
200}
201
/// Argument passed to the [`GrpcStream::on_initial_metadata`] callback.
#[cfg(feature = "stream-metadata")]
pub struct GrpcStreamInitialMetadata {
    /// Number of metadata elements this value was created with.
    num_elements: usize,
}

#[cfg(feature = "stream-metadata")]
impl GrpcStreamInitialMetadata {
    /// Wraps the metadata element count supplied by the crate-internal caller.
    pub(crate) fn new(num_elements: usize) -> Self {
        GrpcStreamInitialMetadata { num_elements }
    }

    /// Number of metadata elements.
    pub fn num_elements(&self) -> usize {
        self.num_elements
    }

    /// Reads every received initial-metadata entry from the host as `(name, value)` pairs.
    ///
    /// Falls back to an empty vector when `log_concern` yields nothing (presumably when the
    /// host lookup failed).
    pub fn all(&self) -> Vec<(String, Vec<u8>)> {
        let entries = hostcalls::get_map(MapType::GrpcReceiveInitialMetadata);
        log_concern("grpc-stream-metadata-all", entries).unwrap_or_default()
    }

    /// Reads the value of the received initial-metadata entry called `name`.
    ///
    /// The result of the host lookup is passed through `log_concern`.
    pub fn value(&self, name: impl AsRef<str>) -> Option<Vec<u8>> {
        let key = name.as_ref();
        let lookup = hostcalls::get_map_value(MapType::GrpcReceiveInitialMetadata, key);
        log_concern("grpc-stream-metadata", lookup)
    }
}
236
237/// Response type for [`GrpcStream::on_message`]
238pub struct GrpcStreamMessage {
239    status_code: GrpcCode,
240    body_size: usize,
241    message: Option<String>,
242}
243
244impl GrpcStreamMessage {
245    pub(crate) fn new(status_code: GrpcCode, message: Option<String>, body_size: usize) -> Self {
246        Self {
247            status_code,
248            body_size,
249            message,
250        }
251    }
252
253    /// GRPC status code of the message
254    pub fn status_code(&self) -> GrpcCode {
255        self.status_code
256    }
257
258    /// Optional GRPC status message of the message
259    pub fn status_message(&self) -> Option<&str> {
260        self.message.as_deref()
261    }
262
263    /// Total size of the message body
264    pub fn body_size(&self) -> usize {
265        self.body_size
266    }
267
268    /// Get a range of the message body
269    pub fn body(&self, range: impl RangeBounds<usize>) -> Option<Vec<u8>> {
270        let start = match range.start_bound() {
271            Bound::Included(x) => *x,
272            Bound::Excluded(x) => x.saturating_sub(1),
273            Bound::Unbounded => 0,
274        };
275        let size = match range.end_bound() {
276            Bound::Included(x) => *x + 1,
277            Bound::Excluded(x) => *x,
278            Bound::Unbounded => self.body_size,
279        }
280        .min(self.body_size)
281        .saturating_sub(start);
282        log_concern(
283            "grpc-stream-message-body",
284            hostcalls::get_buffer(BufferType::GrpcReceiveBuffer, start, size),
285        )
286    }
287
288    /// Get the entire message body
289    pub fn full_body(&self) -> Option<Vec<u8>> {
290        self.body(..self.body_size)
291    }
292}
293
/// Argument passed to the [`GrpcStream::on_trailing_metadata`] callback.
#[cfg(feature = "stream-metadata")]
pub struct GrpcStreamTrailingMetadata {
    /// Number of metadata elements this value was created with.
    num_elements: usize,
}

#[cfg(feature = "stream-metadata")]
impl GrpcStreamTrailingMetadata {
    /// Wraps the metadata element count supplied by the crate-internal caller.
    pub(crate) fn new(num_elements: usize) -> Self {
        GrpcStreamTrailingMetadata { num_elements }
    }

    /// Number of metadata elements.
    pub fn num_elements(&self) -> usize {
        self.num_elements
    }

    /// Reads every received trailing-metadata entry from the host as `(name, value)` pairs.
    ///
    /// Falls back to an empty vector when `log_concern` yields nothing (presumably when the
    /// host lookup failed).
    pub fn all(&self) -> Vec<(String, Vec<u8>)> {
        let entries = hostcalls::get_map(MapType::GrpcReceiveTrailingMetadata);
        log_concern("grpc-stream-trailing-metadata-all", entries).unwrap_or_default()
    }

    /// Reads the value of the received trailing-metadata entry called `name`.
    ///
    /// The result of the host lookup is passed through `log_concern`.
    pub fn value(&self, name: impl AsRef<str>) -> Option<Vec<u8>> {
        let key = name.as_ref();
        let lookup = hostcalls::get_map_value(MapType::GrpcReceiveTrailingMetadata, key);
        log_concern("grpc-stream-trailing-metadata", lookup)
    }
}
328
329/// Response type for [`GrpcStream::on_close`]
330pub struct GrpcStreamClose {
331    handle_id: u32,
332    status_code: GrpcCode,
333    message: Option<String>,
334}
335
336impl GrpcStreamClose {
337    pub(crate) fn new(token_id: u32, status_code: GrpcCode, message: Option<String>) -> Self {
338        Self {
339            handle_id: token_id,
340            status_code,
341            message,
342        }
343    }
344
345    /// GRPC handle ID of the message
346    pub fn handle_id(&self) -> u32 {
347        self.handle_id
348    }
349
350    /// GRPC status code of the message
351    pub fn status_code(&self) -> GrpcCode {
352        self.status_code
353    }
354
355    /// Optional GRPC status message of the message
356    pub fn status_message(&self) -> Option<&str> {
357        self.message.as_deref()
358    }
359}