connectrpc_reflection/
service.rs1use std::sync::Arc;
6
7use connectrpc::Router;
8
9use crate::reflector::Reflector;
10
11#[derive(Clone)]
29pub struct ReflectionService {
30 reflector: Arc<Reflector>,
31}
32
33impl ReflectionService {
34 #[must_use]
36 pub fn new(reflector: Reflector) -> Self {
37 Self {
38 reflector: Arc::new(reflector),
39 }
40 }
41
42 #[must_use]
44 pub fn from_arc(reflector: Arc<Reflector>) -> Self {
45 Self { reflector }
46 }
47}
48
49#[must_use]
59pub fn install(router: Router, reflector: Reflector) -> Router {
60 let service = Arc::new(ReflectionService::new(reflector));
61 let router = crate::connect::grpc::reflection::v1::ServerReflectionExt::register(
62 Arc::clone(&service),
63 router,
64 );
65 crate::connect::grpc::reflection::v1alpha::ServerReflectionExt::register(service, router)
66}
67
/// Expands to the `ServerReflection` trait implementation for
/// `crate::ReflectionService`, plus a private `respond` helper.
///
/// The invoking module must bring two names into scope:
/// - `rpc`: the generated connect module that declares the `ServerReflection` trait;
/// - `pb`: the generated protobuf module holding the request/response messages.
///
/// Invoking it once per protocol version lets `v1` and `v1alpha` share one body.
macro_rules! impl_server_reflection {
    () => {
        impl rpc::ServerReflection for $crate::ReflectionService {
            /// Answers each request on the bidirectional stream, in order.
            ///
            /// Every incoming item maps to exactly one outgoing item. An error
            /// on the request stream, or a request without `message_request`,
            /// becomes an `Err` item in the response stream.
            async fn server_reflection_info(
                &self,
                _ctx: ::connectrpc::RequestContext,
                requests: ::connectrpc::ServiceStream<
                    ::connectrpc::StreamMessage<pb::ServerReflectionRequest>,
                >,
            ) -> ::connectrpc::ServiceResult<
                ::connectrpc::ServiceStream<pb::ServerReflectionResponse>,
            > {
                use futures::StreamExt;
                // Clone the Arc so the `move` closure owns its own handle
                // instead of borrowing `self`.
                let reflector = ::std::sync::Arc::clone(&self.reflector);
                let responses = requests.map(move |request| {
                    let request = request?.to_owned_message();
                    respond(&reflector, request)
                });
                ::connectrpc::Response::stream_ok(responses)
            }
        }

        /// Builds the reply to a single reflection request.
        ///
        /// The reply echoes the request's `host` as `valid_host` and carries the
        /// request itself back in `original_request`. Lookups the reflector
        /// cannot satisfy are reported in-band as an `ErrorResponse` with gRPC
        /// status `NOT_FOUND`, returned as `Ok`.
        ///
        /// # Errors
        ///
        /// Returns an `invalid_argument` error when `message_request` is unset.
        fn respond(
            reflector: &$crate::reflector::Reflector,
            request: pb::ServerReflectionRequest,
        ) -> Result<pb::ServerReflectionResponse, ::connectrpc::ConnectError> {
            use pb::server_reflection_request::MessageRequest;
            use pb::server_reflection_response::MessageResponse;
            use $crate::reflector::Answer;

            // gRPC status code NOT_FOUND, carried in `ErrorResponse.error_code`.
            const GRPC_STATUS_NOT_FOUND: i32 = 5;

            let Some(message_request) = &request.message_request else {
                return Err(::connectrpc::ConnectError::invalid_argument(
                    "ServerReflectionRequest.message_request is not set",
                ));
            };

            let answer = match message_request {
                MessageRequest::FileByFilename(name) => reflector.file_by_filename(name),
                MessageRequest::FileContainingSymbol(symbol) => {
                    reflector.file_containing_symbol(symbol)
                }
                MessageRequest::FileContainingExtension(ext) => {
                    reflector.file_containing_extension(&ext.containing_type, ext.extension_number)
                }
                MessageRequest::AllExtensionNumbersOfType(name) => {
                    reflector.all_extension_numbers_of_type(name)
                }
                MessageRequest::ListServices(_) => reflector.list_services(),
            };

            let message_response = match answer {
                Answer::Files(file_descriptor_proto) => {
                    MessageResponse::from(pb::FileDescriptorResponse {
                        file_descriptor_proto,
                        ..Default::default()
                    })
                }
                Answer::ExtensionNumbers { base_type, numbers } => {
                    MessageResponse::from(pb::ExtensionNumberResponse {
                        base_type_name: base_type,
                        extension_number: numbers,
                        ..Default::default()
                    })
                }
                Answer::Services(names) => MessageResponse::from(pb::ListServiceResponse {
                    service: names
                        .into_iter()
                        .map(|name| pb::ServiceResponse {
                            name,
                            ..Default::default()
                        })
                        .collect(),
                    ..Default::default()
                }),
                Answer::NotFound(message) => MessageResponse::from(pb::ErrorResponse {
                    error_code: GRPC_STATUS_NOT_FOUND,
                    error_message: message,
                    ..Default::default()
                }),
            };

            Ok(pb::ServerReflectionResponse {
                // Cloned because `request` itself is moved into `original_request`.
                valid_host: request.host.clone(),
                original_request: ::buffa::MessageField::some(request),
                message_response: Some(message_response),
                ..Default::default()
            })
        }
    };
}
169
170mod v1 {
171 use crate::connect::grpc::reflection::v1 as rpc;
172 use crate::proto::grpc::reflection::v1 as pb;
173
174 impl_server_reflection!();
175}
176
177mod v1alpha {
178 use crate::connect::grpc::reflection::v1alpha as rpc;
179 use crate::proto::grpc::reflection::v1alpha as pb;
180
181 impl_server_reflection!();
182}
183
#[cfg(test)]
mod tests {
    use buffa::Message;
    use buffa_descriptor::generated::descriptor::{
        FileDescriptorProto, FileDescriptorSet, ServiceDescriptorProto,
    };
    use connectrpc::client::{ClientConfig, HttpClient};
    use tokio::net::TcpListener;

    use super::*;
    use crate::ServerReflectionClient;
    use crate::wire::v1::ServerReflectionRequest;
    use crate::wire::v1::server_reflection_request::MessageRequest;
    use crate::wire::v1::server_reflection_response::MessageResponse;

    /// Service names `ListServices` is expected to report for a server built
    /// from the `test_set_bytes` fixture, in response order.
    const EXPECTED_SERVICES: [&str; 3] = [
        "acme.api.Search",
        "grpc.reflection.v1.ServerReflection",
        "grpc.reflection.v1alpha.ServerReflection",
    ];

    /// Encodes a one-file descriptor set declaring the `acme.api.Search` service
    /// in `acme/api.proto`.
    fn test_set_bytes() -> Vec<u8> {
        FileDescriptorSet {
            file: vec![FileDescriptorProto {
                name: Some("acme/api.proto".into()),
                package: Some("acme.api".into()),
                service: vec![ServiceDescriptorProto {
                    name: Some("Search".into()),
                    ..Default::default()
                }],
                ..Default::default()
            }],
            ..Default::default()
        }
        .encode_to_vec()
    }

    /// Starts a reflection server for `test_set_bytes` on an ephemeral
    /// localhost port and returns a client config pointing at it.
    ///
    /// Shared by every protocol-version test so they all exercise the same
    /// server setup. The server task is detached (its `JoinHandle` is dropped).
    async fn serve_reflection() -> ClientConfig {
        let reflector = Reflector::from_descriptor_set_bytes(&test_set_bytes()).unwrap();
        let app = install(Router::new(), reflector).into_axum_router();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        ClientConfig::new(format!("http://{addr}").parse().unwrap())
    }

    /// Starts a test server and returns a v1 reflection client connected to it.
    async fn spawn_reflection_server() -> ServerReflectionClient<HttpClient> {
        ServerReflectionClient::new(HttpClient::plaintext(), serve_reflection().await)
    }

    /// Builds a v1 request for `message_request` addressed to `test-host`.
    fn request(message_request: MessageRequest) -> ServerReflectionRequest {
        ServerReflectionRequest {
            host: "test-host".into(),
            message_request: Some(message_request),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn full_stream_round_trip() {
        let client = spawn_reflection_server().await;
        let mut stream = client.server_reflection_info().await.unwrap();

        stream
            .send(request(MessageRequest::ListServices(String::new())))
            .await
            .unwrap();
        stream
            .send(request(MessageRequest::FileContainingSymbol(
                "acme.api.Search".into(),
            )))
            .await
            .unwrap();
        stream
            .send(request(MessageRequest::FileByFilename("nope.proto".into())))
            .await
            .unwrap();
        stream.close_send();

        // Responses arrive in request order: ListServices first.
        let resp = stream.message().await.unwrap().unwrap().to_owned_message();
        assert_eq!(resp.valid_host, "test-host");
        assert!(matches!(
            resp.original_request
                .as_option()
                .and_then(|r| r.message_request.as_ref()),
            Some(MessageRequest::ListServices(_))
        ));
        match resp.message_response.unwrap() {
            MessageResponse::ListServicesResponse(list) => {
                let names: Vec<_> = list.service.iter().map(|s| s.name.as_str()).collect();
                assert_eq!(names, EXPECTED_SERVICES);
            }
            other => panic!("expected list_services_response, got {other:?}"),
        }

        // Then the file declaring the requested symbol.
        let resp = stream.message().await.unwrap().unwrap().to_owned_message();
        match resp.message_response.unwrap() {
            MessageResponse::FileDescriptorResponse(fd) => {
                assert_eq!(fd.file_descriptor_proto.len(), 1);
                let file =
                    FileDescriptorProto::decode_from_slice(&fd.file_descriptor_proto[0]).unwrap();
                assert_eq!(file.name.as_deref(), Some("acme/api.proto"));
            }
            other => panic!("expected file_descriptor_response, got {other:?}"),
        }

        // Unknown files come back as an in-band NOT_FOUND error response.
        let resp = stream.message().await.unwrap().unwrap().to_owned_message();
        match resp.message_response.unwrap() {
            MessageResponse::ErrorResponse(err) => {
                assert_eq!(err.error_code, 5);
                assert!(err.error_message.contains("nope.proto"));
            }
            other => panic!("expected error_response, got {other:?}"),
        }

        assert!(stream.message().await.unwrap().is_none());
    }

    #[test]
    fn crate_descriptor_set_makes_reflection_self_describing() {
        let reflector = Reflector::from_descriptor_set_bytes(crate::FILE_DESCRIPTOR_SET).unwrap();
        assert_eq!(
            reflector.service_names(),
            [
                crate::SERVER_REFLECTION_SERVICE_NAME,
                crate::SERVER_REFLECTION_V1ALPHA_SERVICE_NAME,
            ]
        );
        assert!(matches!(
            reflector
                .file_containing_symbol("grpc.reflection.v1.ServerReflection.ServerReflectionInfo"),
            crate::reflector::Answer::Files(_)
        ));
    }

    #[tokio::test]
    async fn v1alpha_route_is_served() {
        use crate::connect::grpc::reflection::v1alpha::ServerReflectionClient as AlphaClient;
        use crate::proto::grpc::reflection::v1alpha::ServerReflectionRequest;
        use crate::proto::grpc::reflection::v1alpha::server_reflection_request::MessageRequest as AlphaRequest;
        use crate::proto::grpc::reflection::v1alpha::server_reflection_response::MessageResponse as AlphaResponse;

        let client = AlphaClient::new(HttpClient::plaintext(), serve_reflection().await);

        let mut stream = client.server_reflection_info().await.unwrap();
        stream
            .send(ServerReflectionRequest {
                message_request: Some(AlphaRequest::ListServices(String::new())),
                ..Default::default()
            })
            .await
            .unwrap();
        stream.close_send();

        let resp = stream.message().await.unwrap().unwrap().to_owned_message();
        match resp.message_response.unwrap() {
            AlphaResponse::ListServicesResponse(list) => {
                // Check the actual names, not just the count, so the legacy
                // route is held to the same contract as v1.
                let names: Vec<_> = list.service.iter().map(|s| s.name.as_str()).collect();
                assert_eq!(names, EXPECTED_SERVICES);
            }
            other => panic!("expected list_services_response, got {other:?}"),
        }

        assert!(stream.message().await.unwrap().is_none());
    }
}