1use crate::service_info::{MethodInfo, ServiceInfo};
2use prost::Message;
3use std::error::Error;
4use tokio_stream::StreamExt;
5use tonic::{
6 transport::{Channel, Endpoint},
7 Request,
8};
9use tonic_reflection::pb::{
10 server_reflection_client::ServerReflectionClient, server_reflection_request::MessageRequest,
11 server_reflection_response::MessageResponse, ServerReflectionRequest,
12};
13
14pub struct ReflectionClient {
15 client: ServerReflectionClient<Channel>,
16}
17
18impl ReflectionClient {
19 pub async fn new(endpoint: String) -> Result<Self, Box<dyn Error>> {
20 let channel = Channel::from_shared(endpoint)?.connect().await?;
21 Ok(Self {
22 client: ServerReflectionClient::new(channel),
23 })
24 }
25
26 pub async fn connect(addr: &str) -> Result<Self, Box<dyn Error>> {
27 let endpoint = Endpoint::new(addr.to_string())?.connect().await?;
28 Ok(Self {
29 client: ServerReflectionClient::new(endpoint),
30 })
31 }
32
33 async fn make_request(
34 &mut self,
35 request: ServerReflectionRequest,
36 ) -> Result<MessageResponse, Box<dyn Error>> {
37 let request = Request::new(tokio_stream::once(request));
38 let mut inbound = self
39 .client
40 .server_reflection_info(request)
41 .await?
42 .into_inner();
43
44 if let Some(response) = inbound.next().await {
45 return Ok(response?.message_response.expect("some MessageResponse"));
46 }
47
48 Err("No response received".into())
49 }
50
51 pub async fn list_services(&mut self) -> Result<Vec<ServiceInfo>, Box<dyn Error>> {
52 let response = self
53 .make_request(ServerReflectionRequest {
54 host: "".to_string(),
55 message_request: Some(MessageRequest::ListServices(String::new())),
56 })
57 .await?;
58
59 if let MessageResponse::ListServicesResponse(services_response) = response {
60 let mut services_info = Vec::new();
61
62 for service in services_response.service {
63 let descriptors = self.get_file_descriptor(service.name.clone()).await?;
64
65 for file_descriptor in descriptors {
66 for service in file_descriptor.service {
67 let methods: Vec<MethodInfo> = service
68 .method
69 .into_iter()
70 .map(|method| {
71 method
72 .name
73 .ok_or_else(|| {
74 format!(
75 "Method name is missing for service {:?}",
76 service.name
77 )
78 })
79 .map(|name| MethodInfo { name })
80 })
81 .collect::<Result<Vec<MethodInfo>, _>>()?;
82
83 let package = file_descriptor.package.clone().ok_or_else(|| {
84 format!("Package name is missing for service {:?}", service.name)
85 })?;
86
87 let service_name = service.name.ok_or_else(|| {
88 format!("Service name is missing for package {}", package)
89 })?;
90
91 services_info.push(ServiceInfo {
92 package,
93 service: service_name,
94 methods,
95 });
96 }
97 }
98 }
99
100 Ok(services_info)
101 } else {
102 Err("Expected a ListServicesResponse variant".into())
103 }
104 }
105
106 async fn get_file_descriptor(
107 &mut self,
108 symbol: String,
109 ) -> Result<Vec<prost_types::FileDescriptorProto>, Box<dyn Error>> {
110 let response = self
111 .make_request(ServerReflectionRequest {
112 host: "".to_string(),
113 message_request: Some(MessageRequest::FileContainingSymbol(symbol)),
114 })
115 .await?;
116
117 if let MessageResponse::FileDescriptorResponse(descriptor_response) = response {
118 let mut descriptors = Vec::new();
119 for file_descriptor_proto in descriptor_response.file_descriptor_proto {
120 let file_descriptor =
121 prost_types::FileDescriptorProto::decode(&file_descriptor_proto[..])?;
122 descriptors.push(file_descriptor);
123 }
124 Ok(descriptors)
125 } else {
126 Err("Expected a FileDescriptorResponse variant".into())
127 }
128 }
129}