use std::sync::Arc;
use connectrpc::Router;
use crate::reflector::Reflector;
/// Serves the gRPC server-reflection protocol from a shared [`Reflector`].
///
/// The `v1` and `v1alpha` ConnectRPC service traits are implemented for this type, so a
/// single instance can be registered for both protocol versions. Cloning is cheap: clones
/// share the same [`Reflector`] through an [`Arc`].
#[derive(Clone)]
pub struct ReflectionService {
    reflector: Arc<Reflector>,
}

// Hand-written rather than derived so the service is `Debug` without requiring
// `Reflector: Debug`; the descriptor contents are deliberately elided.
impl std::fmt::Debug for ReflectionService {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReflectionService").finish_non_exhaustive()
    }
}

impl ReflectionService {
    /// Creates a service that takes ownership of `reflector`.
    #[must_use]
    pub fn new(reflector: Reflector) -> Self {
        Self::from_arc(Arc::new(reflector))
    }

    /// Creates a service from an already shared [`Reflector`].
    ///
    /// Use this to reuse one `Reflector` across several owners without copying it.
    #[must_use]
    pub fn from_arc(reflector: Arc<Reflector>) -> Self {
        Self { reflector }
    }
}
/// Registers server reflection on `router` for both the `v1` and `v1alpha` protocol
/// versions, answering every query from `reflector`.
///
/// Both versions are backed by one shared [`ReflectionService`]; `v1` is registered
/// first, then `v1alpha`. Returns the router with both services added.
#[must_use]
pub fn install(router: Router, reflector: Reflector) -> Router {
    use crate::connect::grpc::reflection::{v1 as rpc_v1, v1alpha as rpc_v1alpha};

    let shared = Arc::new(ReflectionService::new(reflector));
    let with_v1 = rpc_v1::ServerReflectionExt::register(Arc::clone(&shared), router);
    rpc_v1alpha::ServerReflectionExt::register(shared, with_v1)
}
/// Implements the generated `rpc::ServerReflection` trait for `ReflectionService`.
///
/// The same body is expanded once per protocol version. The expansion site must bring
/// `rpc` (the generated ConnectRPC service module) and `pb` (the matching generated
/// message module) into scope. The expansion also defines a private `respond` helper at
/// that site.
macro_rules! impl_server_reflection {
    () => {
        impl rpc::ServerReflection for $crate::ReflectionService {
            /// Answers each inbound request with exactly one response, in order.
            ///
            /// An inbound stream error, or a request without `message_request`, is
            /// forwarded as an error item on the response stream.
            async fn server_reflection_info(
                &self,
                _ctx: ::connectrpc::RequestContext,
                requests: ::connectrpc::ServiceStream<
                    ::connectrpc::StreamMessage<pb::ServerReflectionRequest>,
                >,
            ) -> ::connectrpc::ServiceResult<
                ::connectrpc::ServiceStream<pb::ServerReflectionResponse>,
            > {
                use futures::StreamExt;
                // The returned stream may outlive `&self`, so it owns its own handle.
                let reflector = ::std::sync::Arc::clone(&self.reflector);
                let responses = requests.map(move |request| {
                    let request = request?.to_owned_message();
                    respond(&reflector, request)
                });
                ::connectrpc::Response::stream_ok(responses)
            }
        }

        /// Builds the reply to a single reflection request.
        ///
        /// An `Answer::NotFound` from the reflector is reported in-band as an
        /// `ErrorResponse` carrying gRPC code `NOT_FOUND`. The reply echoes
        /// `request.host` as `valid_host` and embeds `request` as `original_request`.
        ///
        /// # Errors
        ///
        /// Returns `invalid_argument` when `request.message_request` is not set.
        fn respond(
            reflector: &$crate::reflector::Reflector,
            request: pb::ServerReflectionRequest,
        ) -> Result<pb::ServerReflectionResponse, ::connectrpc::ConnectError> {
            use pb::server_reflection_request::MessageRequest;
            use pb::server_reflection_response::MessageResponse;
            use $crate::reflector::Answer;

            // gRPC status code NOT_FOUND, as carried in `ErrorResponse.error_code`.
            const GRPC_NOT_FOUND: i32 = 5;

            let Some(message_request) = &request.message_request else {
                return Err(::connectrpc::ConnectError::invalid_argument(
                    "ServerReflectionRequest.message_request is not set",
                ));
            };
            let answer = match message_request {
                MessageRequest::FileByFilename(name) => reflector.file_by_filename(name),
                MessageRequest::FileContainingSymbol(symbol) => {
                    reflector.file_containing_symbol(symbol)
                }
                MessageRequest::FileContainingExtension(ext) => {
                    reflector.file_containing_extension(&ext.containing_type, ext.extension_number)
                }
                MessageRequest::AllExtensionNumbersOfType(name) => {
                    reflector.all_extension_numbers_of_type(name)
                }
                MessageRequest::ListServices(_) => reflector.list_services(),
            };
            let message_response = match answer {
                Answer::Files(file_descriptor_proto) => {
                    MessageResponse::from(pb::FileDescriptorResponse {
                        file_descriptor_proto,
                        ..Default::default()
                    })
                }
                Answer::ExtensionNumbers { base_type, numbers } => {
                    MessageResponse::from(pb::ExtensionNumberResponse {
                        base_type_name: base_type,
                        extension_number: numbers,
                        ..Default::default()
                    })
                }
                Answer::Services(names) => MessageResponse::from(pb::ListServiceResponse {
                    service: names
                        .into_iter()
                        .map(|name| pb::ServiceResponse {
                            name,
                            ..Default::default()
                        })
                        .collect(),
                    ..Default::default()
                }),
                Answer::NotFound(message) => MessageResponse::from(pb::ErrorResponse {
                    error_code: GRPC_NOT_FOUND,
                    error_message: message,
                    ..Default::default()
                }),
            };
            // `valid_host` is copied before `request` is moved into `original_request`.
            Ok(pb::ServerReflectionResponse {
                valid_host: request.host.clone(),
                original_request: ::buffa::MessageField::some(request),
                message_response: Some(message_response),
                ..Default::default()
            })
        }
    };
}
/// `grpc.reflection.v1` binding of `crate::ReflectionService`.
mod v1 {
    use crate::{connect::grpc::reflection::v1 as rpc, proto::grpc::reflection::v1 as pb};

    impl_server_reflection!();
}

/// `grpc.reflection.v1alpha` binding of `crate::ReflectionService`.
mod v1alpha {
    use crate::{
        connect::grpc::reflection::v1alpha as rpc, proto::grpc::reflection::v1alpha as pb,
    };

    impl_server_reflection!();
}
#[cfg(test)]
mod tests {
    use buffa::Message;
    use buffa_descriptor::generated::descriptor::{
        FileDescriptorProto, FileDescriptorSet, ServiceDescriptorProto,
    };
    use connectrpc::client::{ClientConfig, HttpClient};
    use tokio::net::TcpListener;

    use super::*;
    use crate::ServerReflectionClient;
    use crate::wire::v1::ServerReflectionRequest;
    use crate::wire::v1::server_reflection_request::MessageRequest;
    use crate::wire::v1::server_reflection_response::MessageResponse;

    /// Encodes a minimal descriptor set: one file `acme/api.proto` in package `acme.api`
    /// declaring a single service `Search`.
    fn test_set_bytes() -> Vec<u8> {
        FileDescriptorSet {
            file: vec![FileDescriptorProto {
                name: Some("acme/api.proto".into()),
                package: Some("acme.api".into()),
                service: vec![ServiceDescriptorProto {
                    name: Some("Search".into()),
                    ..Default::default()
                }],
                ..Default::default()
            }],
            ..Default::default()
        }
        .encode_to_vec()
    }

    /// Serves reflection for [`test_set_bytes`] on an ephemeral loopback port and returns
    /// a client config pointing at it.
    ///
    /// The listener is bound before this returns, so clients may connect immediately.
    /// Shared by the v1 and v1alpha tests so both exercise the same server setup.
    async fn spawn_reflection_app() -> ClientConfig {
        let reflector = Reflector::from_descriptor_set_bytes(&test_set_bytes()).unwrap();
        let app = install(Router::new(), reflector).into_axum_router();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        ClientConfig::new(format!("http://{addr}").parse().unwrap())
    }

    /// Starts a reflection server and returns a `v1` client connected to it.
    async fn spawn_reflection_server() -> ServerReflectionClient<HttpClient> {
        ServerReflectionClient::new(HttpClient::plaintext(), spawn_reflection_app().await)
    }

    /// Wraps `message_request` in a request addressed to host `test-host`.
    fn request(message_request: MessageRequest) -> ServerReflectionRequest {
        ServerReflectionRequest {
            host: "test-host".into(),
            message_request: Some(message_request),
            ..Default::default()
        }
    }

    /// Pipelines three requests on one stream. It checks one response per request in
    /// order, then checks that the stream ends after the client half-closes.
    #[tokio::test]
    async fn full_stream_round_trip() {
        let client = spawn_reflection_server().await;
        let mut stream = client.server_reflection_info().await.unwrap();
        stream
            .send(request(MessageRequest::ListServices(String::new())))
            .await
            .unwrap();
        stream
            .send(request(MessageRequest::FileContainingSymbol(
                "acme.api.Search".into(),
            )))
            .await
            .unwrap();
        stream
            .send(request(MessageRequest::FileByFilename("nope.proto".into())))
            .await
            .unwrap();
        stream.close_send();

        let resp = stream.message().await.unwrap().unwrap().to_owned_message();
        assert_eq!(resp.valid_host, "test-host");
        assert!(matches!(
            resp.original_request
                .as_option()
                .and_then(|r| r.message_request.as_ref()),
            Some(MessageRequest::ListServices(_))
        ));
        match resp.message_response.unwrap() {
            MessageResponse::ListServicesResponse(list) => {
                let names: Vec<_> = list.service.iter().map(|s| s.name.as_str()).collect();
                assert_eq!(
                    names,
                    [
                        "acme.api.Search",
                        "grpc.reflection.v1.ServerReflection",
                        "grpc.reflection.v1alpha.ServerReflection",
                    ]
                );
            }
            other => panic!("expected list_services_response, got {other:?}"),
        }

        let resp = stream.message().await.unwrap().unwrap().to_owned_message();
        match resp.message_response.unwrap() {
            MessageResponse::FileDescriptorResponse(fd) => {
                assert_eq!(fd.file_descriptor_proto.len(), 1);
                let file =
                    FileDescriptorProto::decode_from_slice(&fd.file_descriptor_proto[0]).unwrap();
                assert_eq!(file.name.as_deref(), Some("acme/api.proto"));
            }
            other => panic!("expected file_descriptor_response, got {other:?}"),
        }

        let resp = stream.message().await.unwrap().unwrap().to_owned_message();
        match resp.message_response.unwrap() {
            MessageResponse::ErrorResponse(err) => {
                assert_eq!(err.error_code, 5);
                assert!(err.error_message.contains("nope.proto"));
            }
            other => panic!("expected error_response, got {other:?}"),
        }

        assert!(stream.message().await.unwrap().is_none());
    }

    /// The crate's own embedded descriptor set describes both reflection services.
    #[test]
    fn crate_descriptor_set_makes_reflection_self_describing() {
        let reflector = Reflector::from_descriptor_set_bytes(crate::FILE_DESCRIPTOR_SET).unwrap();
        assert_eq!(
            reflector.service_names(),
            [
                crate::SERVER_REFLECTION_SERVICE_NAME,
                crate::SERVER_REFLECTION_V1ALPHA_SERVICE_NAME,
            ]
        );
        assert!(matches!(
            reflector
                .file_containing_symbol("grpc.reflection.v1.ServerReflection.ServerReflectionInfo"),
            crate::reflector::Answer::Files(_)
        ));
    }

    /// `install` also registers the legacy `v1alpha` route on the same router.
    #[tokio::test]
    async fn v1alpha_route_is_served() {
        use crate::connect::grpc::reflection::v1alpha::ServerReflectionClient as AlphaClient;
        use crate::proto::grpc::reflection::v1alpha::ServerReflectionRequest;
        use crate::proto::grpc::reflection::v1alpha::server_reflection_request::MessageRequest as AlphaRequest;
        use crate::proto::grpc::reflection::v1alpha::server_reflection_response::MessageResponse as AlphaResponse;

        let client = AlphaClient::new(HttpClient::plaintext(), spawn_reflection_app().await);
        let mut stream = client.server_reflection_info().await.unwrap();
        stream
            .send(ServerReflectionRequest {
                message_request: Some(AlphaRequest::ListServices(String::new())),
                ..Default::default()
            })
            .await
            .unwrap();
        stream.close_send();

        let resp = stream.message().await.unwrap().unwrap().to_owned_message();
        match resp.message_response.unwrap() {
            AlphaResponse::ListServicesResponse(list) => {
                assert_eq!(list.service.len(), 3);
            }
            other => panic!("expected list_services_response, got {other:?}"),
        }
    }
}