grr_plugin/
grpc_broker.rs

1// Because of course something using Golang and gRPC has to be overtly complex in new and innovative ways.
2// The secondary streams brokered by GRPC Broker are JSON-RPC 2.0, wouldn't you know?
3use super::unique_port::UniquePort;
4use super::unix::{incoming_from_path, TempSocket};
5use super::Error;
6use super::ServiceId;
7use super::{ConnInfo, Status};
8use anyhow::anyhow;
9use anyhow::{Context, Result};
10use async_recursion::async_recursion;
11use futures::stream::StreamExt;
12use hyper::{Body, Request, Response};
13use std::collections::HashMap;
14use std::collections::HashSet;
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::net::UnixStream;
18use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
19use tokio::sync::Mutex;
20use tokio::time::sleep;
21use tonic::body::BoxBody;
22use tonic::transport::NamedService;
23use tonic::transport::{Channel, Endpoint, Uri};
24use tonic::Streaming;
25use tower::service_fn as tower_service_fn;
26use tower::Service;
27
/// Brokers connections by service id.
///
/// Servers opened through the broker are advertised to the other side as `ConnInfo`
/// messages, and `ConnInfo` messages received from the host are kept so that they can be
/// dialed later.
///
/// Not necessarily thread-safe as a whole: the instance methods take `&mut self`, so a
/// caller that needs to share a broker should wrap it (e.g. in an `Arc<RwLock<_>>`).
pub struct GRpcBroker {
    // Port allocator backing `get_unused_port`.
    unique_port: UniquePort,
    // Service ids already reserved by servers opened through this broker.
    used_ids: HashSet<u32>,
    // Cursor for the next candidate service id; advanced past anything in `used_ids`.
    next_id: u32,

    // Shutdown signal; a clone is awaited by every spawned gRPC server to stop serving.
    listener: triggered::Listener,

    // Send the client information on new services and their endpoints
    outgoing_conninfo_sender: UnboundedSender<Result<ConnInfo, Status>>,

    // Services on the host-side that we've been informed of.
    // The optional ConnInfo entry allows us to park a None
    // against a ServiceId that was previously used, so it won't get
    // reused.
    // Shared with the background task spawned in `new`, which fills it in.
    host_services: Arc<Mutex<HashMap<ServiceId, Option<ConnInfo>>>>,
}
47
48impl GRpcBroker {
49    pub fn new(
50        unique_port: UniquePort,
51        outgoing_conninfo_sender: UnboundedSender<Result<ConnInfo, Status>>,
52        mut incoming_conninfo_stream_receiver_receiver: UnboundedReceiver<Streaming<ConnInfo>>,
53        listener: triggered::Listener,
54    ) -> Self {
55        log::info!("Creating new GrpcBroker");
56        let host_services = Arc::new(Mutex::new(HashMap::new()));
57
58        log::trace!("spawning a process to receive the stream of incoming ConnInfo's, and then the ConnInfo's themselves from host side...");
59        let host_services_for_closure = host_services.clone();
60        tokio::spawn(async move {
61            log::trace!(
62                "Inside spawn'd process. Waiting for the stream of ConnInfo's to be available...."
63            );
64            let incoming_conninfo_stream = match incoming_conninfo_stream_receiver_receiver
65                .recv()
66                .await
67            {
68                Some(incoming_conninfo_stream) => incoming_conninfo_stream,
69                None => {
70                    log::error!("inside spawn'd process to wait for a Stream of ConnInfo's, the stream was None, which is unexpected, since it is expected instead to block indefinitely until such a stream is available.");
71                    return;
72                }
73            };
74
75            Self::blocking_incoming_conn(incoming_conninfo_stream, host_services_for_closure).await
76        });
77
78        Self {
79            next_id: 1, // start next id at a number where it won't conflict with other services
80            used_ids: HashSet::new(),
81            unique_port,
82            outgoing_conninfo_sender,
83            host_services,
84            listener,
85        }
86    }
87
88    pub async fn new_grpc_server<S>(&mut self, plugin: S) -> Result<ServiceId, Error>
89    where
90        S: Service<Request<Body>, Response = Response<BoxBody>>
91            + NamedService
92            + Clone
93            + Send
94            + 'static,
95        <S as Service<http::Request<hyper::Body>>>::Future: Send + 'static,
96        <S as Service<http::Request<hyper::Body>>>::Error:
97            Into<Box<dyn std::error::Error + Send + Sync>> + Send,
98    {
99        log::info!("called");
100
101        // get next service_id, increment the underlying value, and release lock in the block
102        let service_id = self.get_unused_service_id();
103        log::info!("newServer - obtained an unused service_id: {}", service_id);
104
105        self.new_grpc_server_with_service_id(service_id, plugin)
106            .await
107    }
108
109    pub async fn new_grpc_server_with_service_id<S>(
110        &mut self,
111        service_id: ServiceId,
112        plugin: S,
113    ) -> Result<ServiceId, Error>
114    where
115        S: Service<Request<Body>, Response = Response<BoxBody>>
116            + NamedService
117            + Clone
118            + Send
119            + 'static,
120        <S as Service<http::Request<hyper::Body>>>::Future: Send + 'static,
121        <S as Service<http::Request<hyper::Body>>>::Error:
122            Into<Box<dyn std::error::Error + Send + Sync>> + Send,
123    {
124        log::info!("called");
125
126        if self.used_ids.contains(&service_id) {
127            return Err(Error::Other(anyhow!("In GrpcBroker, the service_id {} was provided to open a new server with, but it was found to exist already in the used set.", service_id)));
128        }
129
130        // reserve current service_id
131        self.used_ids.insert(service_id);
132
133        let temp_socket = TempSocket::new()
134        .with_context(|| format!("newServer({}) Failed to create a new TempSocket for opening a new JSON-RPC 2.0 server", service_id))?;
135        let socket_path = temp_socket.socket_filename()
136            .with_context(|| format!("newServer({}) Failed to get a temporary socket filename from the temp socket for opening a new JSON-RPC 2.0 server", service_id))?;
137        log::info!(
138            "newServer({}) Created a temp socket path: {}",
139            service_id,
140            socket_path
141        );
142
143        let listener = self.listener.clone();
144
145        tokio::spawn(async move {
146            log::debug!(
147                "newServer({}) - spawned into separate task to wait for this server to complete...",
148                service_id
149            );
150
151            let socket_path = temp_socket.socket_filename()
152            .with_context(|| format!("newServer({}) Inside spawned grpc server, failed to get a temporary socket filename from the temp socket for opening a new JSON-RPC 2.0 server", service_id)).unwrap();
153
154            // create incoming stream from unix socket above...
155            let incoming_stream_from_socket = incoming_from_path(socket_path.as_str()).await
156                .with_context(|| format!("newServer({}) Inside spawned grpc server, unable to open incoming UnixStream from socket {}", service_id, socket_path.as_str())).unwrap();
157            log::trace!("newServer({}) Inside spawned grpc server, created Incoming unix stream from the socket", service_id);
158
159            log::info!(
160                "newServer({}) Inside spawned grpc server, starting a new grpc service...",
161                service_id
162            );
163            let grpc_service_future = tonic::transport::Server::builder()
164                .add_service(plugin)
165                .serve_with_incoming_shutdown(incoming_stream_from_socket, async {
166                    listener.await
167                });
168
169            if let Err(err) = grpc_service_future.await.with_context(|| {
170                format!(
171                    "newServer({}) Inside spawned grpc server, service future failed",
172                    service_id
173                )
174            }) {
175                log::error!(
176                    "newServer({}) Inside spawned grpc server, it errored: {}",
177                    service_id,
178                    err
179                );
180            }
181
182            log::info!(
183                "newServer({}) Inside spawned grpc server, exiting task. Service has ended.",
184                service_id
185            );
186        });
187
188        log::debug!(
189            "newServer({}) - Creating ConnInfo for this service to send to the client-side broker.",
190            service_id
191        );
192        let conn_info = ConnInfo {
193            network: "unix".to_string(),
194            address: socket_path,
195            service_id,
196        };
197
198        log::debug!(
199            "newServer({}) - Created ConnInfo for this service: {:?}",
200            service_id,
201            conn_info
202        );
203
204        self.outgoing_conninfo_sender
205            .send(Ok(conn_info.clone()))
206            .with_context(|| {
207                format!(
208                    "Failed to send ConnInfo {:?} to the client/host/consumer of this plugin.",
209                    conn_info
210                )
211            })?;
212        log::info!(
213            "newServer({}) - Sent ConnInfo to client-side broker",
214            service_id
215        );
216
217        Ok(service_id)
218    }
219
220    pub fn get_unused_service_id(&mut self) -> u32 {
221        // keep incrementing next_id so long as it has already been used.
222        while self.used_ids.contains(&self.next_id) {
223            self.next_id += 1;
224        }
225
226        // return service_id that is not yet used.
227        self.next_id
228    }
229
    /// Asks the underlying `UniquePort` for an unused port and returns its answer as is
    /// (`None` when it has none to offer).
    pub fn get_unused_port(&mut self) -> Option<u16> {
        self.unique_port.get_unused_port()
    }
233
234    pub async fn dial_to_host_service(&mut self, service_id: ServiceId) -> Result<Channel, Error> {
235        let conn_info = self.get_incoming_conninfo_retry(service_id, 5).await?;
236
237        let channel = match conn_info.network.as_str() {
238            "tcp" => Endpoint::try_from(conn_info.address)?.connect().await?,
239            "unix" => {
240                // Copied from: https://github.com/hyperium/tonic/blob/master/examples/src/uds/client.rs
241                Endpoint::try_from("http://[::]:50051")?
242                    .connect_with_connector(tower_service_fn(move |_: Uri| {
243                        // Connect to a Uds socket
244                        // The clone ensures this closure doesn't consume the environment.
245                        UnixStream::connect(conn_info.address.clone())
246                    }))
247                    .await?
248            }
249            s => return Err(Error::NetworkTypeUnknown(s.to_string())),
250        };
251
252        Ok(channel)
253    }
254
255    #[async_recursion]
256    async fn get_incoming_conninfo_retry(
257        &mut self,
258        service_id: ServiceId,
259        retry_count: usize,
260    ) -> Result<ConnInfo, Error> {
261        match self.get_incoming_conninfo(service_id).await {
262            None => match retry_count {
263                0 => Err(Error::ServiceIdDoesNotExist(service_id)),
264                _c => {
265                    sleep(Duration::from_secs(1)).await;
266                    self.get_incoming_conninfo_retry(service_id, retry_count - 1)
267                        .await
268                }
269            },
270            Some(conn_info) => Ok(conn_info),
271        }
272    }
273
274    //https://github.com/hashicorp/go-plugin/blob/master/grpc_broker.go#L371
275    async fn get_incoming_conninfo(&mut self, service_id: ServiceId) -> Option<ConnInfo> {
276        // hold lock for duration of this function, so we can atomically park a None
277        // in case we pulled a ConnInfo out.
278        let mut hs = self.host_services.lock().await;
279
280        match hs.remove(&service_id) {
281            None | Some(None) => None,
282            Some(Some(conn_info)) => {
283                // if some conn_info existed, replace it with None before exiting
284                hs.insert(service_id, None);
285                Some(conn_info)
286            }
287        }
288    }
289
290    // This function will run forever. tokio::spawn this!
291    async fn blocking_incoming_conn(
292        mut stream: Streaming<ConnInfo>,
293        host_services: Arc<Mutex<HashMap<ServiceId, Option<ConnInfo>>>>,
294    ) {
295        log::info!("blocking_incoming_conn - perpetually listening for incoming ConnInfo's",);
296        while let Some(conn_info_result) = stream.next().await {
297            match conn_info_result {
298                Err(e) => {
299                    log::error!(
300                        "blocking_incoming_conn - an error occurred reading from the stream: {:?}",
301                        e
302                    );
303                    break; //out of the while loop
304                }
305                Ok(conn_info) => {
306                    log::info!("Received conn_info: {:?}", conn_info);
307
308                    let mut hs = host_services.lock().await;
309                    log::trace!("Write-locked the host services to add the new ConnInfo",);
310
311                    log::trace!(
312                        "Only creating a new entry if one doesn't exist for this ServiceId: {}",
313                        conn_info.service_id
314                    );
315                    hs.entry(conn_info.service_id)
316                        .or_insert_with(|| Some(conn_info));
317                }
318            }
319        }
320        log::info!("blocking_incoming_conn - exiting due to stream returning None or an error",);
321    }
322}
323
#[cfg(test)]
mod test {
    use super::*;
    use crate::unique_port;
    use tokio::sync::mpsc::unbounded_channel;

    /// `get_unused_service_id` repeats its answer until the id is marked used, counts up
    /// from 1, and skips ids that are already in the used set.
    #[tokio::test]
    async fn test_service_id_increment() {
        // The underscore-named halves are kept alive on purpose for the test's duration.
        let (_trigger, listener) = triggered::trigger();
        let (conninfo_tx, _conninfo_rx) = unbounded_channel::<Result<ConnInfo, Status>>();
        let (_stream_tx, stream_rx) = unbounded_channel::<Streaming<ConnInfo>>();
        let mut broker = GRpcBroker::new(
            unique_port::UniquePort::new(),
            conninfo_tx,
            stream_rx,
            listener,
        );

        // Pre-occupy an id further up so the skip can be observed later.
        broker.used_ids.insert(5);

        assert_eq!(1, broker.get_unused_service_id());
        // still unused, so the same id comes back
        assert_eq!(1, broker.get_unused_service_id());
        broker.used_ids.insert(1);

        for id in 2..=4u32 {
            assert_eq!(id, broker.get_unused_service_id());
            broker.used_ids.insert(id);
        }

        // skip 5 which was pre-inserted
        assert_eq!(6, broker.get_unused_service_id());
    }
}