grpc_ease/
reflection.rs

1use crate::service_info::{MethodInfo, ServiceInfo};
2use prost::Message;
3use std::error::Error;
4use tokio_stream::StreamExt;
5use tonic::{
6    transport::{Channel, Endpoint},
7    Request,
8};
9use tonic_reflection::pb::{
10    server_reflection_client::ServerReflectionClient, server_reflection_request::MessageRequest,
11    server_reflection_response::MessageResponse, ServerReflectionRequest,
12};
13
14pub struct ReflectionClient {
15    client: ServerReflectionClient<Channel>,
16}
17
18impl ReflectionClient {
19    pub async fn new(endpoint: String) -> Result<Self, Box<dyn Error>> {
20        let channel = Channel::from_shared(endpoint)?.connect().await?;
21        Ok(Self {
22            client: ServerReflectionClient::new(channel),
23        })
24    }
25
26    pub async fn connect(addr: &str) -> Result<Self, Box<dyn Error>> {
27        let endpoint = Endpoint::new(addr.to_string())?.connect().await?;
28        Ok(Self {
29            client: ServerReflectionClient::new(endpoint),
30        })
31    }
32
33    async fn make_request(
34        &mut self,
35        request: ServerReflectionRequest,
36    ) -> Result<MessageResponse, Box<dyn Error>> {
37        let request = Request::new(tokio_stream::once(request));
38        let mut inbound = self
39            .client
40            .server_reflection_info(request)
41            .await?
42            .into_inner();
43
44        if let Some(response) = inbound.next().await {
45            return Ok(response?.message_response.expect("some MessageResponse"));
46        }
47
48        Err("No response received".into())
49    }
50
51    pub async fn list_services(&mut self) -> Result<Vec<ServiceInfo>, Box<dyn Error>> {
52        let response = self
53            .make_request(ServerReflectionRequest {
54                host: "".to_string(),
55                message_request: Some(MessageRequest::ListServices(String::new())),
56            })
57            .await?;
58
59        if let MessageResponse::ListServicesResponse(services_response) = response {
60            let mut services_info = Vec::new();
61
62            for service in services_response.service {
63                let descriptors = self.get_file_descriptor(service.name.clone()).await?;
64
65                for file_descriptor in descriptors {
66                    for service in file_descriptor.service {
67                        let methods: Vec<MethodInfo> = service
68                            .method
69                            .into_iter()
70                            .map(|method| {
71                                method
72                                    .name
73                                    .ok_or_else(|| {
74                                        format!(
75                                            "Method name is missing for service {:?}",
76                                            service.name
77                                        )
78                                    })
79                                    .map(|name| MethodInfo { name })
80                            })
81                            .collect::<Result<Vec<MethodInfo>, _>>()?;
82
83                        let package = file_descriptor.package.clone().ok_or_else(|| {
84                            format!("Package name is missing for service {:?}", service.name)
85                        })?;
86
87                        let service_name = service.name.ok_or_else(|| {
88                            format!("Service name is missing for package {}", package)
89                        })?;
90
91                        services_info.push(ServiceInfo {
92                            package,
93                            service: service_name,
94                            methods,
95                        });
96                    }
97                }
98            }
99
100            Ok(services_info)
101        } else {
102            Err("Expected a ListServicesResponse variant".into())
103        }
104    }
105
106    async fn get_file_descriptor(
107        &mut self,
108        symbol: String,
109    ) -> Result<Vec<prost_types::FileDescriptorProto>, Box<dyn Error>> {
110        let response = self
111            .make_request(ServerReflectionRequest {
112                host: "".to_string(),
113                message_request: Some(MessageRequest::FileContainingSymbol(symbol)),
114            })
115            .await?;
116
117        if let MessageResponse::FileDescriptorResponse(descriptor_response) = response {
118            let mut descriptors = Vec::new();
119            for file_descriptor_proto in descriptor_response.file_descriptor_proto {
120                let file_descriptor =
121                    prost_types::FileDescriptorProto::decode(&file_descriptor_proto[..])?;
122                descriptors.push(file_descriptor);
123            }
124            Ok(descriptors)
125        } else {
126            Err("Expected a FileDescriptorResponse variant".into())
127        }
128    }
129}