pdk_classy/grpc/
call.rs

1// Copyright (c) 2025, Salesforce, Inc.,
2// All rights reserved.
3// For full license text, see the LICENSE.txt file
4
5use std::future::Future;
6use std::pin::Pin;
7use std::rc::Rc;
8use std::task::{Context, Poll, Waker};
9
10use crate::grpc::error::{GrpcCallError, GrpcClientError};
11use crate::grpc::transport::GrpcCallId;
12use crate::host::grpc::GrpcHost;
13use crate::reactor::root::RootReactor;
14use crate::types::Cid;
15
16use super::codec::Decoder;
17use super::{GrpcResponse, GrpcStatus, GrpcStatusCode};
18
19/// Represents an async oneshot gRPC request to the upstream.
20pub struct GrpcCall<D> {
21    reactor: Rc<RootReactor>,
22    host: Rc<dyn GrpcHost>,
23    cid_and_waker: Option<(Cid, Waker)>,
24    call_id: Option<GrpcCallId>,
25    error: Option<GrpcCallError>,
26    ready: bool,
27    decoder: D,
28}
29
30impl<D> GrpcCall<D> {
31    pub(super) fn new(
32        reactor: Rc<RootReactor>,
33        host: Rc<dyn GrpcHost>,
34        decoder: D,
35        result: Result<GrpcCallId, GrpcCallError>,
36    ) -> Self {
37        let (call_id, error) = match result {
38            Ok(call_id) => (Some(call_id), None),
39            Err(error) => (None, Some(error)),
40        };
41        Self {
42            reactor,
43            host,
44            decoder,
45            cid_and_waker: None,
46            call_id,
47            error,
48            ready: false,
49        }
50    }
51
52    /// Turns this request into a result to detect call errors before awaiting the remote response.
53    pub fn into_result(mut self) -> Result<Self, GrpcCallError> {
54        if self.error.is_some() {
55            // It is infallible to unwrap here
56            return Err(self.error.take().unwrap());
57        }
58        Ok(self)
59    }
60}
61
62impl<R> Drop for GrpcCall<R> {
63    fn drop(&mut self) {
64        if let Some(call_id) = self.call_id.as_ref() {
65            let call_id = *call_id;
66
67            // Just cancel if the call is not ready
68            if !self.ready {
69                let _ = self.host.cancel_grpc_call(call_id.token());
70            }
71
72            // Ensure that all related objects were removed
73            self.reactor.remove_grpc_response(call_id);
74            self.reactor.remove_grpc_client(call_id);
75        }
76    }
77}
78
79impl<D> Future for GrpcCall<D>
80where
81    D: Decoder + Unpin,
82{
83    type Output = Result<GrpcResponse<D::Output>, GrpcClientError>;
84
85    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
86        let Some(call_id) = self.call_id else {
87            self.ready = true;
88
89            // It is infallible to unwrap here
90            let error = self.error.take().unwrap();
91
92            return Poll::Ready(Err(GrpcClientError::Call(error)));
93        };
94
95        if let Some(response_parts) = self.reactor.remove_grpc_response(call_id) {
96            self.ready = true;
97            let status = GrpcStatusCode::from_u32(response_parts.event.status_code);
98            if status != GrpcStatusCode::Ok {
99                return Poll::Ready(Err(GrpcClientError::Status(GrpcStatus::new(
100                    status,
101                    response_parts.status,
102                ))));
103            }
104            let content = response_parts.content.unwrap_or_default();
105
106            self.decoder
107                .decode(content)
108                .map(GrpcResponse::new)
109                .map_err(|e| GrpcClientError::Decode(e.into()))
110                .into()
111        } else {
112            let this = &mut *self.as_mut();
113            match this.cid_and_waker.as_ref() {
114                None => {
115                    let cid = this.reactor.active_cid();
116
117                    // Register the waker in the reactor.
118                    this.reactor.insert_grpc_client(call_id, cx.waker().clone());
119                    this.reactor.set_paused(cid, true);
120                    this.cid_and_waker = Some((cid, cx.waker().clone()));
121                }
122                Some((cid, waker)) if !waker.will_wake(cx.waker()) => {
123                    // Deregister the waker from the reactor to remove the old waker.
124                    let _ = this
125                        .reactor
126                        .remove_grpc_client(call_id)
127                        // It should be infallible to unwrap here
128                        .expect("stored client");
129
130                    // Register the waker in the reactor with the new waker.
131                    this.reactor.insert_grpc_client(call_id, cx.waker().clone());
132                    this.cid_and_waker = Some((*cid, cx.waker().clone()));
133                }
134                Some(_) => {}
135            }
136            Poll::Pending
137        }
138    }
139}