// poem_grpc/reflection.rs

1use std::{collections::HashMap, sync::Arc};
2
3use futures_util::StreamExt;
4use poem::{IntoEndpoint, endpoint::BoxEndpoint};
5use prost::Message;
6use prost_types::{DescriptorProto, EnumDescriptorProto, FileDescriptorProto, FileDescriptorSet};
7use proto::{
8    server_reflection_request::MessageRequest, server_reflection_response::MessageResponse,
9};
10
11use crate::{Code, Request, Response, Service, Status, Streaming};
12
13#[allow(unreachable_pub)]
14#[allow(clippy::enum_variant_names)]
15#[allow(clippy::derive_partial_eq_without_eq)]
16mod proto {
17    include!(concat!(env!("OUT_DIR"), "/grpc.reflection.v1alpha.rs"));
18}
19
20pub(crate) const FILE_DESCRIPTOR_SET: &[u8] = include_file_descriptor_set!("grpc-reflection.bin");
21
22struct State {
23    service_names: Vec<proto::ServiceResponse>,
24    files: HashMap<String, Arc<FileDescriptorProto>>,
25    symbols: HashMap<String, Arc<FileDescriptorProto>>,
26}
27
28impl State {
29    #[allow(clippy::result_large_err)]
30    fn file_by_filename(&self, filename: &str) -> Result<MessageResponse, Status> {
31        match self.files.get(filename) {
32            None => {
33                Err(Status::new(Code::NotFound)
34                    .with_message(format!("file '{filename}' not found")))
35            }
36            Some(fd) => {
37                let mut encoded_fd = Vec::new();
38                if fd.clone().encode(&mut encoded_fd).is_err() {
39                    return Err(Status::new(Code::Internal).with_message("encoding error"));
40                }
41
42                Ok(MessageResponse::FileDescriptorResponse(
43                    proto::FileDescriptorResponse {
44                        file_descriptor_proto: vec![encoded_fd],
45                    },
46                ))
47            }
48        }
49    }
50
51    #[allow(clippy::result_large_err)]
52    fn symbol_by_name(&self, symbol: &str) -> Result<MessageResponse, Status> {
53        match self.symbols.get(symbol) {
54            None => {
55                Err(Status::new(Code::NotFound)
56                    .with_message(format!("symbol '{symbol}' not found")))
57            }
58            Some(fd) => {
59                let mut encoded_fd = Vec::new();
60                if fd.clone().encode(&mut encoded_fd).is_err() {
61                    return Err(Status::new(Code::Internal).with_message("encoding error"));
62                };
63
64                Ok(MessageResponse::FileDescriptorResponse(
65                    proto::FileDescriptorResponse {
66                        file_descriptor_proto: vec![encoded_fd],
67                    },
68                ))
69            }
70        }
71    }
72
73    fn list_services(&self) -> MessageResponse {
74        MessageResponse::ListServicesResponse(proto::ListServiceResponse {
75            service: self.service_names.clone(),
76        })
77    }
78}
79
80/// A service that serve for reflection
81struct ServerReflectionService {
82    state: Arc<State>,
83}
84
85impl proto::ServerReflection for ServerReflectionService {
86    async fn server_reflection_info(
87        &self,
88        request: Request<Streaming<proto::ServerReflectionRequest>>,
89    ) -> Result<Response<Streaming<proto::ServerReflectionResponse>>, Status> {
90        let mut request_stream = request.into_inner();
91        let state = self.state.clone();
92
93        Ok(Response::new(Streaming::new(async_stream::try_stream! {
94            while let Some(req) = request_stream.next().await.transpose()? {
95                let resp = match &req.message_request {
96                    Some(MessageRequest::FileByFilename(filename)) => state.file_by_filename(filename),
97                    Some(MessageRequest::FileContainingSymbol(symbol)) => state.symbol_by_name(symbol),
98                    Some(MessageRequest::FileContainingExtension(_) | MessageRequest::AllExtensionNumbersOfType(_)) => Err(Status::new(Code::Unimplemented)),
99                    Some(MessageRequest::ListServices(_)) => Ok(state.list_services()),
100                    None => Err(Status::new(Code::InvalidArgument)),
101                }?;
102
103                yield proto::ServerReflectionResponse {
104                    valid_host: req.host.clone(),
105                    original_request: Some(req.clone()),
106                    message_response: Some(resp),
107                };
108            }
109        })))
110    }
111}
112
113/// A builder for creating reflection service
114#[derive(Debug, Default)]
115pub struct Reflection {
116    file_descriptor_sets: Vec<FileDescriptorSet>,
117    service_names: Vec<String>,
118    symbols: HashMap<String, Arc<FileDescriptorProto>>,
119}
120
121impl Reflection {
122    /// Create a `ReflectionBuilder`
123    pub fn new() -> Self {
124        Default::default()
125    }
126
127    /// Add a file descriptor set
128    pub fn add_file_descriptor_set(mut self, data: &[u8]) -> Self {
129        self.file_descriptor_sets
130            .push(FileDescriptorSet::decode(data).expect("valid file descriptor sets"));
131        self
132    }
133
134    /// Build a reflection service
135    pub fn build(
136        self,
137    ) -> impl IntoEndpoint<Endpoint = BoxEndpoint<'static, poem::Response>> + Service {
138        let mut this = self.add_file_descriptor_set(FILE_DESCRIPTOR_SET);
139
140        let fd_iter = std::mem::take(&mut this.file_descriptor_sets)
141            .into_iter()
142            .flat_map(|fds| fds.file.into_iter());
143        let mut files = HashMap::with_capacity(fd_iter.size_hint().0);
144
145        for fd in fd_iter {
146            let fd = Arc::new(fd);
147
148            match fd.name.clone() {
149                Some(filename) => {
150                    files.insert(filename, fd.clone());
151                }
152                None => panic!("missing file name"),
153            }
154
155            let prefix = fd.package.as_deref().unwrap_or_default();
156
157            for proto in &fd.message_type {
158                this.process_message(fd.clone(), prefix, proto);
159            }
160
161            for proto in &fd.enum_type {
162                this.process_enum(fd.clone(), prefix, proto);
163            }
164
165            for service in &fd.service {
166                let service_name = qualified_name(prefix, "service", service.name.as_deref());
167                this.service_names.push(service_name.clone());
168                this.symbols.insert(service_name.clone(), fd.clone());
169
170                for method in &service.method {
171                    let method_name =
172                        qualified_name(&service_name, "method", method.name.as_deref());
173                    this.symbols.insert(method_name, fd.clone());
174                }
175            }
176        }
177
178        proto::ServerReflectionServer::new(ServerReflectionService {
179            state: Arc::new(State {
180                service_names: this
181                    .service_names
182                    .into_iter()
183                    .map(|name| proto::ServiceResponse { name })
184                    .collect(),
185                files,
186                symbols: this.symbols,
187            }),
188        })
189    }
190
191    fn process_message(
192        &mut self,
193        fd: Arc<FileDescriptorProto>,
194        prefix: &str,
195        msg: &DescriptorProto,
196    ) {
197        let message_name = qualified_name(prefix, "message", msg.name.as_deref());
198        self.symbols.insert(message_name.clone(), fd.clone());
199
200        for nested in &msg.nested_type {
201            self.process_message(fd.clone(), &message_name, nested);
202        }
203
204        for e in &msg.enum_type {
205            self.process_enum(fd.clone(), &message_name, e);
206        }
207
208        for field in &msg.field {
209            let field_name = qualified_name(prefix, "field", field.name.as_deref());
210            self.symbols.insert(field_name, fd.clone());
211        }
212
213        for oneof in &msg.oneof_decl {
214            let oneof_name = qualified_name(prefix, "oneof", oneof.name.as_deref());
215            self.symbols.insert(oneof_name, fd.clone());
216        }
217    }
218
219    fn process_enum(
220        &mut self,
221        fd: Arc<FileDescriptorProto>,
222        prefix: &str,
223        e: &EnumDescriptorProto,
224    ) {
225        let enum_name = qualified_name(prefix, "enum", e.name.as_deref());
226        self.symbols.insert(enum_name.clone(), fd.clone());
227
228        for value in &e.value {
229            let value_name = qualified_name(&enum_name, "enum value", value.name.as_deref());
230            self.symbols.insert(value_name, fd.clone());
231        }
232    }
233}
234
/// Joins `prefix` and `name` with a `.` to form a fully-qualified protobuf name.
///
/// An empty `prefix` yields `name` unchanged. `ty` describes the kind of element and is used
/// only in the panic message.
///
/// # Panics
///
/// Panics with `missing {ty} name` when `name` is `None`.
fn qualified_name(prefix: &str, ty: &str, name: Option<&str>) -> String {
    let Some(name) = name else {
        panic!("missing {ty} name");
    };

    if prefix.is_empty() {
        name.to_string()
    } else {
        format!("{prefix}.{name}")
    }
}