mockforge_grpc/reflection/
client.rs

1//! gRPC reflection client for dynamically discovering services and methods.
2
3use prost_reflect::{prost::Message, prost_types, DescriptorPool, ServiceDescriptor};
4use tonic::{
5    transport::{Channel, Endpoint},
6    Status,
7};
8use tonic_reflection::pb::v1::{
9    server_reflection_client::ServerReflectionClient, server_reflection_response::MessageResponse,
10    ServerReflectionRequest,
11};
12use tracing::{debug, error, trace};
13
14/// A client that uses gRPC reflection to discover services and methods
15pub struct ReflectionClient {
16    /// The gRPC channel to the target server
17    channel: Channel,
18    /// The descriptor pool containing all discovered services
19    pool: DescriptorPool,
20}
21
22impl ReflectionClient {
23    /// Create a new reflection client
24    pub async fn new(endpoint: Endpoint) -> Result<Self, Status> {
25        let channel = endpoint.connect().await.map_err(|e| {
26            error!("Failed to connect to endpoint: {}", e);
27            Status::unavailable(format!("Failed to connect to endpoint: {}", e))
28        })?;
29
30        let mut pool = DescriptorPool::new();
31
32        // Create a reflection client
33        let mut client = ServerReflectionClient::new(channel.clone());
34
35        // Get the list of services
36        let request = tonic::Request::new(futures_util::stream::iter(vec![
37            ServerReflectionRequest {
38                host: "".to_string(),
39                message_request: Some(
40                    tonic_reflection::pb::v1::server_reflection_request::MessageRequest::ListServices(
41                        "*".to_string(),
42                    ),
43                ),
44            }
45        ]));
46
47        let mut service_names = Vec::new();
48
49        match client.server_reflection_info(request).await {
50            Ok(response) => {
51                let mut stream = response.into_inner();
52                while let Some(reply) = stream.message().await.map_err(|e| {
53                    error!("Failed to read reflection response: {}", e);
54                    Status::internal(format!("Failed to read reflection response: {}", e))
55                })? {
56                    if let Some(MessageResponse::ListServicesResponse(services)) =
57                        reply.message_response
58                    {
59                        trace!("Found {} services", services.service.len());
60                        for service in services.service {
61                            debug!("Found service: {}", service.name);
62                            service_names.push(service.name.clone());
63                        }
64                    }
65                }
66            }
67            Err(e) => {
68                error!("Failed to get service list: {}", e);
69                return Err(Status::internal(format!("Failed to get service list: {}", e)));
70            }
71        }
72
73        // For each service, get its file descriptor
74        for service_name in &service_names {
75            Self::get_file_descriptor_for_service(&mut client, &mut pool, service_name).await?;
76        }
77
78        debug!(
79            "Created reflection client for endpoint with {} services",
80            pool.services().count()
81        );
82
83        Ok(Self { channel, pool })
84    }
85
86    /// Get file descriptor for a service
87    async fn get_file_descriptor_for_service(
88        client: &mut ServerReflectionClient<Channel>,
89        pool: &mut DescriptorPool,
90        service_name: &str,
91    ) -> Result<(), Status> {
92        trace!("Getting file descriptor for service: {}", service_name);
93
94        let request = tonic::Request::new(futures_util::stream::iter(vec![
95            ServerReflectionRequest {
96                host: "".to_string(),
97                message_request: Some(
98                    tonic_reflection::pb::v1::server_reflection_request::MessageRequest::FileContainingSymbol(
99                        service_name.to_string(),
100                    ),
101                ),
102            }
103        ]));
104
105        match client.server_reflection_info(request).await {
106            Ok(response) => {
107                let mut stream = response.into_inner();
108                while let Some(reply) = stream.message().await.map_err(|e| {
109                    error!("Failed to read reflection response: {}", e);
110                    Status::internal(format!("Failed to read reflection response: {}", e))
111                })? {
112                    if let Some(MessageResponse::FileDescriptorResponse(descriptor_response)) =
113                        reply.message_response
114                    {
115                        trace!(
116                            "Found {} file descriptors for service {}",
117                            descriptor_response.file_descriptor_proto.len(),
118                            service_name
119                        );
120                        for file_descriptor_proto in descriptor_response.file_descriptor_proto {
121                            match prost_types::FileDescriptorProto::decode(&*file_descriptor_proto)
122                            {
123                                Ok(file_descriptor) => {
124                                    if let Err(e) = pool.add_file_descriptor_proto(file_descriptor)
125                                    {
126                                        error!(
127                                            "Failed to register file descriptor for service {}: {}",
128                                            service_name, e
129                                        );
130                                        return Err(Status::internal(format!(
131                                            "Failed to register file descriptor for service {}: {}",
132                                            service_name, e
133                                        )));
134                                    } else {
135                                        debug!(
136                                            "Registered file descriptor for service: {}",
137                                            service_name
138                                        );
139                                    }
140                                }
141                                Err(e) => {
142                                    error!(
143                                        "Failed to decode file descriptor for service {}: {}",
144                                        service_name, e
145                                    );
146                                    return Err(Status::data_loss(format!(
147                                        "Failed to decode file descriptor for service {}: {}",
148                                        service_name, e
149                                    )));
150                                }
151                            }
152                        }
153                    }
154                }
155            }
156            Err(e) => {
157                error!("Failed to get file descriptor for service {}: {}", service_name, e);
158                return Err(Status::internal(format!(
159                    "Failed to get file descriptor for service {}: {}",
160                    service_name, e
161                )));
162            }
163        }
164
165        Ok(())
166    }
167
168    /// Get a service descriptor by name
169    pub fn get_service(&self, service_name: &str) -> Option<ServiceDescriptor> {
170        self.pool.get_service_by_name(service_name)
171    }
172
173    /// Get the underlying channel
174    pub fn channel(&self) -> Channel {
175        self.channel.clone()
176    }
177
178    /// Get a reference to the descriptor pool
179    pub fn pool(&self) -> &DescriptorPool {
180        &self.pool
181    }
182}
183
#[cfg(test)]
mod tests {
    /// Smoke test with an intentionally empty body: it passes as long as the
    /// surrounding module builds under the test configuration.
    #[test]
    fn test_module_compiles() {}
}