use prost_reflect::{prost::Message, prost_types, DescriptorPool, ServiceDescriptor};
use tonic::{
transport::{Channel, Endpoint},
Status,
};
use tonic_reflection::pb::v1::{
server_reflection_client::ServerReflectionClient, server_reflection_response::MessageResponse,
ServerReflectionRequest,
};
use tracing::{debug, error, trace};
/// Client-side helper that discovers a gRPC server's services through the
/// server reflection protocol and keeps their descriptors in a [`DescriptorPool`].
#[derive(Debug, Clone)]
pub struct ReflectionClient {
    /// Connection to the server; a clone is handed out by `channel()`.
    channel: Channel,
    /// Descriptors fetched from the server while the client was constructed.
    pool: DescriptorPool,
}
impl ReflectionClient {
pub async fn new(endpoint: Endpoint) -> Result<Self, Status> {
let channel = endpoint.connect().await.map_err(|e| {
error!("Failed to connect to endpoint: {}", e);
Status::unavailable(format!("Failed to connect to endpoint: {}", e))
})?;
let mut pool = DescriptorPool::new();
let mut client = ServerReflectionClient::new(channel.clone());
let request = tonic::Request::new(futures_util::stream::iter(vec![
ServerReflectionRequest {
host: "".to_string(),
message_request: Some(
tonic_reflection::pb::v1::server_reflection_request::MessageRequest::ListServices(
"*".to_string(),
),
),
}
]));
let mut service_names = Vec::new();
match client.server_reflection_info(request).await {
Ok(response) => {
let mut stream = response.into_inner();
while let Some(reply) = stream.message().await.map_err(|e| {
error!("Failed to read reflection response: {}", e);
Status::internal(format!("Failed to read reflection response: {}", e))
})? {
if let Some(MessageResponse::ListServicesResponse(services)) =
reply.message_response
{
trace!("Found {} services", services.service.len());
for service in services.service {
debug!("Found service: {}", service.name);
service_names.push(service.name.clone());
}
}
}
}
Err(e) => {
error!("Failed to get service list: {}", e);
return Err(Status::internal(format!("Failed to get service list: {}", e)));
}
}
for service_name in &service_names {
Self::get_file_descriptor_for_service(&mut client, &mut pool, service_name).await?;
}
debug!(
"Created reflection client for endpoint with {} services",
pool.services().count()
);
Ok(Self { channel, pool })
}
async fn get_file_descriptor_for_service(
client: &mut ServerReflectionClient<Channel>,
pool: &mut DescriptorPool,
service_name: &str,
) -> Result<(), Status> {
trace!("Getting file descriptor for service: {}", service_name);
let request = tonic::Request::new(futures_util::stream::iter(vec![
ServerReflectionRequest {
host: "".to_string(),
message_request: Some(
tonic_reflection::pb::v1::server_reflection_request::MessageRequest::FileContainingSymbol(
service_name.to_string(),
),
),
}
]));
match client.server_reflection_info(request).await {
Ok(response) => {
let mut stream = response.into_inner();
while let Some(reply) = stream.message().await.map_err(|e| {
error!("Failed to read reflection response: {}", e);
Status::internal(format!("Failed to read reflection response: {}", e))
})? {
if let Some(MessageResponse::FileDescriptorResponse(descriptor_response)) =
reply.message_response
{
trace!(
"Found {} file descriptors for service {}",
descriptor_response.file_descriptor_proto.len(),
service_name
);
for file_descriptor_proto in descriptor_response.file_descriptor_proto {
match prost_types::FileDescriptorProto::decode(&*file_descriptor_proto)
{
Ok(file_descriptor) => {
if let Err(e) = pool.add_file_descriptor_proto(file_descriptor)
{
error!(
"Failed to register file descriptor for service {}: {}",
service_name, e
);
return Err(Status::internal(format!(
"Failed to register file descriptor for service {}: {}",
service_name, e
)));
} else {
debug!(
"Registered file descriptor for service: {}",
service_name
);
}
}
Err(e) => {
error!(
"Failed to decode file descriptor for service {}: {}",
service_name, e
);
return Err(Status::data_loss(format!(
"Failed to decode file descriptor for service {}: {}",
service_name, e
)));
}
}
}
}
}
}
Err(e) => {
error!("Failed to get file descriptor for service {}: {}", service_name, e);
return Err(Status::internal(format!(
"Failed to get file descriptor for service {}: {}",
service_name, e
)));
}
}
Ok(())
}
pub fn get_service(&self, service_name: &str) -> Option<ServiceDescriptor> {
self.pool.get_service_by_name(service_name)
}
pub fn channel(&self) -> Channel {
self.channel.clone()
}
pub fn pool(&self) -> &DescriptorPool {
&self.pool
}
}
#[cfg(test)]
mod tests {
    /// Compile-only smoke test: passes whenever this module builds in test
    /// configuration. It performs no runtime checks (talking to a live
    /// reflection server is outside the scope of a unit test).
    #[test]
    fn test_module_compiles() {}
}