Skip to main content

tonic_reflection/server/v1alpha.rs

use std::{fmt, sync::Arc};

use tokio::sync::mpsc;
use tokio_stream::{Stream, StreamExt};
use tonic::{Request, Response, Status, Streaming};

use super::ReflectionServiceState;
use crate::pb::v1alpha::server_reflection_request::MessageRequest;
use crate::pb::v1alpha::server_reflection_response::MessageResponse;
pub use crate::pb::v1alpha::server_reflection_server::{ServerReflection, ServerReflectionServer};
use crate::pb::v1alpha::{
    ErrorResponse, ExtensionNumberResponse, FileDescriptorResponse, ListServiceResponse,
    ServerReflectionRequest, ServerReflectionResponse, ServiceResponse,
};
15
16/// An implementation for `ServerReflection`.
17#[derive(Debug)]
18pub struct ReflectionService {
19    state: Arc<ReflectionServiceState>,
20}
21
22#[tonic::async_trait]
23impl ServerReflection for ReflectionService {
24    type ServerReflectionInfoStream = ServerReflectionInfoStream;
25
26    async fn server_reflection_info(
27        &self,
28        req: Request<Streaming<ServerReflectionRequest>>,
29    ) -> Result<Response<Self::ServerReflectionInfoStream>, Status> {
30        let mut req_rx = req.into_inner();
31        let (resp_tx, resp_rx) = mpsc::channel::<Result<ServerReflectionResponse, Status>>(1);
32
33        let state = self.state.clone();
34
35        tokio::spawn(async move {
36            while let Some(req) = req_rx.next().await {
37                let Ok(req) = req else {
38                    return;
39                };
40
41                let resp_msg = match req.message_request.clone() {
42                    None => Err(Status::invalid_argument("invalid MessageRequest")),
43                    Some(msg) => match msg {
44                        MessageRequest::FileByFilename(s) => state.file_by_filename(&s).map(|fd| {
45                            MessageResponse::FileDescriptorResponse(FileDescriptorResponse {
46                                file_descriptor_proto: vec![fd],
47                            })
48                        }),
49                        MessageRequest::FileContainingSymbol(s) => {
50                            state.symbol_by_name(&s).map(|fd| {
51                                MessageResponse::FileDescriptorResponse(FileDescriptorResponse {
52                                    file_descriptor_proto: vec![fd],
53                                })
54                            })
55                        }
56                        MessageRequest::FileContainingExtension(_) => {
57                            Err(Status::not_found("extensions are not supported"))
58                        }
59                        MessageRequest::AllExtensionNumbersOfType(_) => {
60                            // NOTE: Workaround. Some grpc clients (e.g. grpcurl) expect this method not to fail.
61                            // https://github.com/hyperium/tonic/issues/1077
62                            Ok(MessageResponse::AllExtensionNumbersResponse(
63                                ExtensionNumberResponse::default(),
64                            ))
65                        }
66                        MessageRequest::ListServices(_) => {
67                            Ok(MessageResponse::ListServicesResponse(ListServiceResponse {
68                                service: state
69                                    .list_services()
70                                    .iter()
71                                    .map(|s| ServiceResponse { name: s.clone() })
72                                    .collect(),
73                            }))
74                        }
75                    },
76                };
77
78                match resp_msg {
79                    Ok(resp_msg) => {
80                        let resp = ServerReflectionResponse {
81                            valid_host: req.host.clone(),
82                            original_request: Some(req.clone()),
83                            message_response: Some(resp_msg),
84                        };
85                        if resp_tx.send(Ok(resp)).await.is_err() {
86                            return;
87                        }
88                    }
89                    Err(status) => {
90                        let _ = resp_tx.send(Err(status)).await;
91                        return;
92                    }
93                }
94            }
95        });
96
97        Ok(Response::new(ServerReflectionInfoStream::new(resp_rx)))
98    }
99}
100
101impl From<ReflectionServiceState> for ReflectionService {
102    fn from(state: ReflectionServiceState) -> Self {
103        Self {
104            state: Arc::new(state),
105        }
106    }
107}
108
109/// A response stream.
110pub struct ServerReflectionInfoStream {
111    inner: tokio_stream::wrappers::ReceiverStream<Result<ServerReflectionResponse, Status>>,
112}
113
114impl ServerReflectionInfoStream {
115    fn new(resp_rx: mpsc::Receiver<Result<ServerReflectionResponse, Status>>) -> Self {
116        let inner = tokio_stream::wrappers::ReceiverStream::new(resp_rx);
117        Self { inner }
118    }
119}
120
121impl Stream for ServerReflectionInfoStream {
122    type Item = Result<ServerReflectionResponse, Status>;
123
124    fn poll_next(
125        mut self: std::pin::Pin<&mut Self>,
126        cx: &mut std::task::Context<'_>,
127    ) -> std::task::Poll<Option<Self::Item>> {
128        std::pin::Pin::new(&mut self.inner).poll_next(cx)
129    }
130
131    fn size_hint(&self) -> (usize, Option<usize>) {
132        self.inner.size_hint()
133    }
134}
135
136impl fmt::Debug for ServerReflectionInfoStream {
137    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138        f.debug_tuple("ServerReflectionInfoStream").finish()
139    }
140}