1use std::{collections::HashMap, sync::Arc};
2
3use futures_util::StreamExt;
4use poem::{IntoEndpoint, endpoint::BoxEndpoint};
5use prost::Message;
6use prost_types::{DescriptorProto, EnumDescriptorProto, FileDescriptorProto, FileDescriptorSet};
7use proto::{
8 server_reflection_request::MessageRequest, server_reflection_response::MessageResponse,
9};
10
11use crate::{Code, Request, Response, Service, Status, Streaming};
12
/// Message and service types for the `grpc.reflection.v1alpha` protocol,
/// included from a file the build script writes to `OUT_DIR`.
// The lints below are relaxed because this module's contents are generated
// code rather than hand-written source.
#[allow(unreachable_pub)]
#[allow(clippy::enum_variant_names)]
#[allow(clippy::derive_partial_eq_without_eq)]
mod proto {
    include!(concat!(env!("OUT_DIR"), "/grpc.reflection.v1alpha.rs"));
}
19
/// Encoded `FileDescriptorSet` embedded from `grpc-reflection.bin` (presumably
/// the descriptors of the reflection protocol itself — confirm against the
/// build script). `Reflection::build` always adds it to the registered sets.
pub(crate) const FILE_DESCRIPTOR_SET: &[u8] = include_file_descriptor_set!("grpc-reflection.bin");
21
/// Lookup tables the reflection service answers from. Built once by
/// `Reflection::build` and shared behind an `Arc`; only read afterwards.
struct State {
    /// Every registered service, returned as-is for `ListServices` requests.
    service_names: Vec<proto::ServiceResponse>,
    /// File descriptors keyed by their file name.
    files: HashMap<String, Arc<FileDescriptorProto>>,
    /// Fully-qualified symbol name -> descriptor of the file that defines it.
    symbols: HashMap<String, Arc<FileDescriptorProto>>,
}
27
28impl State {
29 #[allow(clippy::result_large_err)]
30 fn file_by_filename(&self, filename: &str) -> Result<MessageResponse, Status> {
31 match self.files.get(filename) {
32 None => {
33 Err(Status::new(Code::NotFound)
34 .with_message(format!("file '{filename}' not found")))
35 }
36 Some(fd) => {
37 let mut encoded_fd = Vec::new();
38 if fd.clone().encode(&mut encoded_fd).is_err() {
39 return Err(Status::new(Code::Internal).with_message("encoding error"));
40 }
41
42 Ok(MessageResponse::FileDescriptorResponse(
43 proto::FileDescriptorResponse {
44 file_descriptor_proto: vec![encoded_fd],
45 },
46 ))
47 }
48 }
49 }
50
51 #[allow(clippy::result_large_err)]
52 fn symbol_by_name(&self, symbol: &str) -> Result<MessageResponse, Status> {
53 match self.symbols.get(symbol) {
54 None => {
55 Err(Status::new(Code::NotFound)
56 .with_message(format!("symbol '{symbol}' not found")))
57 }
58 Some(fd) => {
59 let mut encoded_fd = Vec::new();
60 if fd.clone().encode(&mut encoded_fd).is_err() {
61 return Err(Status::new(Code::Internal).with_message("encoding error"));
62 };
63
64 Ok(MessageResponse::FileDescriptorResponse(
65 proto::FileDescriptorResponse {
66 file_descriptor_proto: vec![encoded_fd],
67 },
68 ))
69 }
70 }
71 }
72
73 fn list_services(&self) -> MessageResponse {
74 MessageResponse::ListServicesResponse(proto::ListServiceResponse {
75 service: self.service_names.clone(),
76 })
77 }
78}
79
/// The `ServerReflection` gRPC service implementation; every request is
/// answered from the shared [`State`] lookup tables.
struct ServerReflectionService {
    /// Tables built by `Reflection::build`.
    state: Arc<State>,
}
84
85impl proto::ServerReflection for ServerReflectionService {
86 async fn server_reflection_info(
87 &self,
88 request: Request<Streaming<proto::ServerReflectionRequest>>,
89 ) -> Result<Response<Streaming<proto::ServerReflectionResponse>>, Status> {
90 let mut request_stream = request.into_inner();
91 let state = self.state.clone();
92
93 Ok(Response::new(Streaming::new(async_stream::try_stream! {
94 while let Some(req) = request_stream.next().await.transpose()? {
95 let resp = match &req.message_request {
96 Some(MessageRequest::FileByFilename(filename)) => state.file_by_filename(filename),
97 Some(MessageRequest::FileContainingSymbol(symbol)) => state.symbol_by_name(symbol),
98 Some(MessageRequest::FileContainingExtension(_) | MessageRequest::AllExtensionNumbersOfType(_)) => Err(Status::new(Code::Unimplemented)),
99 Some(MessageRequest::ListServices(_)) => Ok(state.list_services()),
100 None => Err(Status::new(Code::InvalidArgument)),
101 }?;
102
103 yield proto::ServerReflectionResponse {
104 valid_host: req.host.clone(),
105 original_request: Some(req.clone()),
106 message_response: Some(resp),
107 };
108 }
109 })))
110 }
111}
112
/// Builder for the gRPC server reflection service.
///
/// Register encoded `FileDescriptorSet`s with
/// [`add_file_descriptor_set`](Reflection::add_file_descriptor_set), then call
/// [`build`](Reflection::build) to obtain the service.
#[derive(Debug, Default)]
pub struct Reflection {
    /// Descriptor sets registered so far; drained by `build`.
    file_descriptor_sets: Vec<FileDescriptorSet>,
    /// Fully-qualified names of the services collected during `build`.
    service_names: Vec<String>,
    /// Fully-qualified symbol name -> descriptor of the file that defines it.
    symbols: HashMap<String, Arc<FileDescriptorProto>>,
}
120
121impl Reflection {
122 pub fn new() -> Self {
124 Default::default()
125 }
126
127 pub fn add_file_descriptor_set(mut self, data: &[u8]) -> Self {
129 self.file_descriptor_sets
130 .push(FileDescriptorSet::decode(data).expect("valid file descriptor sets"));
131 self
132 }
133
134 pub fn build(
136 self,
137 ) -> impl IntoEndpoint<Endpoint = BoxEndpoint<'static, poem::Response>> + Service {
138 let mut this = self.add_file_descriptor_set(FILE_DESCRIPTOR_SET);
139
140 let fd_iter = std::mem::take(&mut this.file_descriptor_sets)
141 .into_iter()
142 .flat_map(|fds| fds.file.into_iter());
143 let mut files = HashMap::with_capacity(fd_iter.size_hint().0);
144
145 for fd in fd_iter {
146 let fd = Arc::new(fd);
147
148 match fd.name.clone() {
149 Some(filename) => {
150 files.insert(filename, fd.clone());
151 }
152 None => panic!("missing file name"),
153 }
154
155 let prefix = fd.package.as_deref().unwrap_or_default();
156
157 for proto in &fd.message_type {
158 this.process_message(fd.clone(), prefix, proto);
159 }
160
161 for proto in &fd.enum_type {
162 this.process_enum(fd.clone(), prefix, proto);
163 }
164
165 for service in &fd.service {
166 let service_name = qualified_name(prefix, "service", service.name.as_deref());
167 this.service_names.push(service_name.clone());
168 this.symbols.insert(service_name.clone(), fd.clone());
169
170 for method in &service.method {
171 let method_name =
172 qualified_name(&service_name, "method", method.name.as_deref());
173 this.symbols.insert(method_name, fd.clone());
174 }
175 }
176 }
177
178 proto::ServerReflectionServer::new(ServerReflectionService {
179 state: Arc::new(State {
180 service_names: this
181 .service_names
182 .into_iter()
183 .map(|name| proto::ServiceResponse { name })
184 .collect(),
185 files,
186 symbols: this.symbols,
187 }),
188 })
189 }
190
191 fn process_message(
192 &mut self,
193 fd: Arc<FileDescriptorProto>,
194 prefix: &str,
195 msg: &DescriptorProto,
196 ) {
197 let message_name = qualified_name(prefix, "message", msg.name.as_deref());
198 self.symbols.insert(message_name.clone(), fd.clone());
199
200 for nested in &msg.nested_type {
201 self.process_message(fd.clone(), &message_name, nested);
202 }
203
204 for e in &msg.enum_type {
205 self.process_enum(fd.clone(), &message_name, e);
206 }
207
208 for field in &msg.field {
209 let field_name = qualified_name(prefix, "field", field.name.as_deref());
210 self.symbols.insert(field_name, fd.clone());
211 }
212
213 for oneof in &msg.oneof_decl {
214 let oneof_name = qualified_name(prefix, "oneof", oneof.name.as_deref());
215 self.symbols.insert(oneof_name, fd.clone());
216 }
217 }
218
219 fn process_enum(
220 &mut self,
221 fd: Arc<FileDescriptorProto>,
222 prefix: &str,
223 e: &EnumDescriptorProto,
224 ) {
225 let enum_name = qualified_name(prefix, "enum", e.name.as_deref());
226 self.symbols.insert(enum_name.clone(), fd.clone());
227
228 for value in &e.value {
229 let value_name = qualified_name(&enum_name, "enum value", value.name.as_deref());
230 self.symbols.insert(value_name, fd.clone());
231 }
232 }
233}
234
/// Forms a fully-qualified protobuf name by joining `prefix` and `name` with a
/// `.`. An empty `prefix` yields `name` unchanged.
///
/// # Panics
///
/// Panics with `missing {ty} name` when `name` is `None`. `ty` is used only to
/// label the kind of descriptor in that message.
fn qualified_name(prefix: &str, ty: &str, name: Option<&str>) -> String {
    let Some(short) = name else {
        panic!("missing {ty} name");
    };

    if prefix.is_empty() {
        short.to_owned()
    } else {
        format!("{prefix}.{short}")
    }
}