// mockforge_grpc/reflection/client.rs
use prost_reflect::{prost::Message, prost_types, DescriptorPool, ServiceDescriptor};
use tonic::{
    transport::{Channel, Endpoint},
    Status,
};
use tonic_reflection::pb::v1::{
    server_reflection_client::ServerReflectionClient, server_reflection_response::MessageResponse,
    ServerReflectionRequest,
};
use tracing::{debug, error, trace};

14pub struct ReflectionClient {
16 channel: Channel,
18 pool: DescriptorPool,
20}
21
22impl ReflectionClient {
23 pub async fn new(endpoint: Endpoint) -> Result<Self, Status> {
25 let channel = endpoint.connect().await.map_err(|e| {
26 error!("Failed to connect to endpoint: {}", e);
27 Status::unavailable(format!("Failed to connect to endpoint: {}", e))
28 })?;
29
30 let mut pool = DescriptorPool::new();
31
32 let mut client = ServerReflectionClient::new(channel.clone());
34
35 let request = tonic::Request::new(futures_util::stream::iter(vec![
37 ServerReflectionRequest {
38 host: "".to_string(),
39 message_request: Some(
40 tonic_reflection::pb::v1::server_reflection_request::MessageRequest::ListServices(
41 "*".to_string(),
42 ),
43 ),
44 }
45 ]));
46
47 let mut service_names = Vec::new();
48
49 match client.server_reflection_info(request).await {
50 Ok(response) => {
51 let mut stream = response.into_inner();
52 while let Some(reply) = stream.message().await.map_err(|e| {
53 error!("Failed to read reflection response: {}", e);
54 Status::internal(format!("Failed to read reflection response: {}", e))
55 })? {
56 if let Some(MessageResponse::ListServicesResponse(services)) =
57 reply.message_response
58 {
59 trace!("Found {} services", services.service.len());
60 for service in services.service {
61 debug!("Found service: {}", service.name);
62 service_names.push(service.name.clone());
63 }
64 }
65 }
66 }
67 Err(e) => {
68 error!("Failed to get service list: {}", e);
69 return Err(Status::internal(format!("Failed to get service list: {}", e)));
70 }
71 }
72
73 for service_name in &service_names {
75 Self::get_file_descriptor_for_service(&mut client, &mut pool, service_name).await?;
76 }
77
78 debug!(
79 "Created reflection client for endpoint with {} services",
80 pool.services().count()
81 );
82
83 Ok(Self { channel, pool })
84 }
85
86 async fn get_file_descriptor_for_service(
88 client: &mut ServerReflectionClient<Channel>,
89 pool: &mut DescriptorPool,
90 service_name: &str,
91 ) -> Result<(), Status> {
92 trace!("Getting file descriptor for service: {}", service_name);
93
94 let request = tonic::Request::new(futures_util::stream::iter(vec![
95 ServerReflectionRequest {
96 host: "".to_string(),
97 message_request: Some(
98 tonic_reflection::pb::v1::server_reflection_request::MessageRequest::FileContainingSymbol(
99 service_name.to_string(),
100 ),
101 ),
102 }
103 ]));
104
105 match client.server_reflection_info(request).await {
106 Ok(response) => {
107 let mut stream = response.into_inner();
108 while let Some(reply) = stream.message().await.map_err(|e| {
109 error!("Failed to read reflection response: {}", e);
110 Status::internal(format!("Failed to read reflection response: {}", e))
111 })? {
112 if let Some(MessageResponse::FileDescriptorResponse(descriptor_response)) =
113 reply.message_response
114 {
115 trace!(
116 "Found {} file descriptors for service {}",
117 descriptor_response.file_descriptor_proto.len(),
118 service_name
119 );
120 for file_descriptor_proto in descriptor_response.file_descriptor_proto {
121 match prost_types::FileDescriptorProto::decode(&*file_descriptor_proto)
122 {
123 Ok(file_descriptor) => {
124 if let Err(e) = pool.add_file_descriptor_proto(file_descriptor)
125 {
126 error!(
127 "Failed to register file descriptor for service {}: {}",
128 service_name, e
129 );
130 return Err(Status::internal(format!(
131 "Failed to register file descriptor for service {}: {}",
132 service_name, e
133 )));
134 } else {
135 debug!(
136 "Registered file descriptor for service: {}",
137 service_name
138 );
139 }
140 }
141 Err(e) => {
142 error!(
143 "Failed to decode file descriptor for service {}: {}",
144 service_name, e
145 );
146 return Err(Status::data_loss(format!(
147 "Failed to decode file descriptor for service {}: {}",
148 service_name, e
149 )));
150 }
151 }
152 }
153 }
154 }
155 }
156 Err(e) => {
157 error!("Failed to get file descriptor for service {}: {}", service_name, e);
158 return Err(Status::internal(format!(
159 "Failed to get file descriptor for service {}: {}",
160 service_name, e
161 )));
162 }
163 }
164
165 Ok(())
166 }
167
168 pub fn get_service(&self, service_name: &str) -> Option<ServiceDescriptor> {
170 self.pool.get_service_by_name(service_name)
171 }
172
173 pub fn channel(&self) -> Channel {
175 self.channel.clone()
176 }
177
178 pub fn pool(&self) -> &DescriptorPool {
180 &self.pool
181 }
182}
183
#[cfg(test)]
mod tests {
    /// Smoke test: succeeds whenever this module, and therefore its parent, compiles.
    #[test]
    fn test_module_compiles() {
        // Intentionally empty. Exercising the client needs a live reflection server.
    }
}