mockforge_grpc/reflection/mock_proxy/
handlers.rs1use crate::reflection::mock_proxy::proxy::MockReflectionProxy;
7use prost_reflect::{DynamicMessage, MessageDescriptor};
8use std::sync::{Arc, Mutex};
9use tokio::sync::mpsc;
10use tokio_stream::wrappers::ReceiverStream;
11use tonic::{Request, Response, Status, Streaming};
12use tracing::{debug, info};
13
14impl MockReflectionProxy {
15 pub async fn handle_unary_request(
17 &self,
18 request: Request<DynamicMessage>,
19 ) -> Result<Response<DynamicMessage>, Status> {
20 let _guard = self.track_connection();
21 self.total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
22 let (service_name, method_name) = self.extract_service_method_from_request(&request)?;
23
24 debug!("Handling unary request for {}/{}", service_name, method_name);
25
26 if self.should_mock_service_method(&service_name, &method_name) {
28 return self.generate_mock_response(&service_name, &method_name, request).await;
29 }
30
31 self.forward_unary_request(request, &service_name, &method_name).await
33 }
34
35 pub async fn handle_server_streaming_request(
37 &self,
38 request: Request<DynamicMessage>,
39 ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
40 let _guard = self.track_connection();
41 self.total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
42 let (service_name, method_name) = self.extract_service_method_from_request(&request)?;
43
44 debug!("Handling server streaming request for {}/{}", service_name, method_name);
45
46 if self.should_mock_service_method(&service_name, &method_name) {
48 return self.generate_mock_stream_response(&service_name, &method_name).await;
49 }
50
51 self.forward_server_streaming_request(request, &service_name, &method_name)
53 .await
54 }
55
56 pub async fn handle_client_streaming_request(
58 &self,
59 request: Request<Streaming<DynamicMessage>>,
60 ) -> Result<Response<DynamicMessage>, Status> {
61 let _guard = self.track_connection();
62 self.total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
63 let (service_name, method_name) = self.extract_service_method_from_request(&request)?;
64
65 debug!("Handling client streaming request for {}/{}", service_name, method_name);
66
67 if self.should_mock_service_method(&service_name, &method_name) {
69 return self
70 .generate_mock_client_stream_response(&service_name, &method_name, request)
71 .await;
72 }
73
74 self.forward_client_streaming_request(request, &service_name, &method_name)
76 .await
77 }
78
79 pub async fn handle_bidirectional_streaming_request(
81 &self,
82 request: Request<Streaming<DynamicMessage>>,
83 ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
84 let _guard = self.track_connection();
85 self.total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
86 let (service_name, method_name) = self.extract_service_method_from_request(&request)?;
87
88 debug!("Handling bidirectional streaming request for {}/{}", service_name, method_name);
89
90 if self.should_mock_service_method(&service_name, &method_name) {
92 return self
93 .generate_mock_bidirectional_stream_response(&service_name, &method_name)
94 .await;
95 }
96
97 self.forward_bidirectional_streaming_request(request, &service_name, &method_name)
99 .await
100 }
101
102 pub fn extract_service_method_from_request<T>(
104 &self,
105 request: &Request<T>,
106 ) -> Result<(String, String), Status> {
107 let path = request
109 .metadata()
110 .get("path")
111 .or_else(|| request.metadata().get(":path"))
112 .and_then(|v| v.to_str().ok())
113 .ok_or_else(|| Status::invalid_argument("Missing path in request"))?;
114
115 if !path.starts_with('/') {
116 return Err(Status::invalid_argument("Invalid request path"));
117 }
118 let parts: Vec<&str> = path[1..].split('/').collect();
119 if parts.len() != 2 {
120 return Err(Status::invalid_argument(
121 "Invalid gRPC path format, expected /Service/Method",
122 ));
123 }
124 Ok((parts[0].to_string(), parts[1].to_string()))
125 }
126
127 async fn generate_mock_response(
129 &self,
130 service_name: &str,
131 method_name: &str,
132 _request: Request<DynamicMessage>,
133 ) -> Result<Response<DynamicMessage>, Status> {
134 info!("Generating mock response for {}/{}", service_name, method_name);
135
136 let method_descriptor = self.cache().get_method(service_name, method_name).await?;
138
139 let response_message = self.generate_mock_message(method_descriptor.output())?;
141
142 let mut response = Response::new(response_message);
143
144 self.postprocess_dynamic_response(&mut response, service_name, method_name)
146 .await?;
147
148 Ok(response)
149 }
150
151 async fn generate_mock_stream_response(
153 &self,
154 service_name: &str,
155 method_name: &str,
156 ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
157 info!("Generating mock stream response for {}/{}", service_name, method_name);
158
159 let method_descriptor = self.cache().get_method(service_name, method_name).await?;
161
162 let (tx, rx) = mpsc::channel(4);
164
165 let smart_generator = self.smart_generator().clone();
167 let output_descriptor = method_descriptor.output();
168
169 tokio::spawn(async move {
170 for _i in 0..3 {
171 if let Ok(message) = Self::generate_mock_message_with_generator(
173 &smart_generator,
174 output_descriptor.clone(),
175 ) {
176 if tx.send(Ok(message)).await.is_err() {
177 break; }
179 }
180
181 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
183 }
184 });
185
186 let mut response = Response::new(ReceiverStream::new(rx));
187
188 self.postprocess_streaming_dynamic_response(&mut response, service_name, method_name)
190 .await?;
191
192 Ok(response)
193 }
194
195 async fn generate_mock_client_stream_response(
197 &self,
198 service_name: &str,
199 method_name: &str,
200 _request: Request<Streaming<DynamicMessage>>,
201 ) -> Result<Response<DynamicMessage>, Status> {
202 info!("Generating mock client streaming response for {}/{}", service_name, method_name);
203
204 let method_descriptor = self.cache().get_method(service_name, method_name).await?;
206
207 let response_message = self.generate_mock_message(method_descriptor.output())?;
209
210 let mut response = Response::new(response_message);
211
212 self.postprocess_dynamic_response(&mut response, service_name, method_name)
214 .await?;
215
216 Ok(response)
217 }
218
219 async fn generate_mock_bidirectional_stream_response(
221 &self,
222 service_name: &str,
223 method_name: &str,
224 ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
225 info!(
226 "Generating mock bidirectional stream response for {}/{}",
227 service_name, method_name
228 );
229
230 let method_descriptor = self.cache().get_method(service_name, method_name).await?;
232
233 let (tx, rx) = mpsc::channel(4);
235
236 let smart_generator = self.smart_generator().clone();
238 let output_descriptor = method_descriptor.output();
239
240 tokio::spawn(async move {
241 for _i in 0..5 {
242 if let Ok(message) = Self::generate_mock_message_with_generator(
244 &smart_generator,
245 output_descriptor.clone(),
246 ) {
247 if tx.send(Ok(message)).await.is_err() {
248 break; }
250 }
251
252 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
254 }
255 });
256
257 let mut response = Response::new(ReceiverStream::new(rx));
258
259 self.postprocess_streaming_dynamic_response(&mut response, service_name, method_name)
261 .await?;
262
263 Ok(response)
264 }
265
266 async fn forward_unary_request(
268 &self,
269 _request: Request<DynamicMessage>,
270 _service_name: &str,
271 _method_name: &str,
272 ) -> Result<Response<DynamicMessage>, Status> {
273 if let Some(upstream) = &self.config.upstream_endpoint {
274 let _channel = self.connection_pool.get_channel(upstream).await.map_err(|e| {
276 Status::unavailable(format!("Failed to connect to upstream {}: {}", upstream, e))
277 })?;
278
279 Err(Status::unimplemented(
283 "Generic gRPC forwarding not supported - requires service-specific client stubs",
284 ))
285 } else {
286 Err(Status::unimplemented("Upstream endpoint not configured for request forwarding"))
287 }
288 }
289
290 async fn forward_server_streaming_request(
292 &self,
293 _request: Request<DynamicMessage>,
294 _service_name: &str,
295 _method_name: &str,
296 ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
297 if let Some(upstream) = &self.config.upstream_endpoint {
298 let _channel = self.connection_pool.get_channel(upstream).await.map_err(|e| {
300 Status::unavailable(format!("Failed to connect to upstream {}: {}", upstream, e))
301 })?;
302
303 Err(Status::unimplemented(
307 "Generic gRPC forwarding not supported - requires service-specific client stubs",
308 ))
309 } else {
310 Err(Status::unimplemented("Upstream endpoint not configured for request forwarding"))
311 }
312 }
313
314 async fn forward_client_streaming_request(
316 &self,
317 _request: Request<Streaming<DynamicMessage>>,
318 _service_name: &str,
319 _method_name: &str,
320 ) -> Result<Response<DynamicMessage>, Status> {
321 if let Some(upstream) = &self.config.upstream_endpoint {
322 let _channel = self.connection_pool.get_channel(upstream).await.map_err(|e| {
324 Status::unavailable(format!("Failed to connect to upstream {}: {}", upstream, e))
325 })?;
326
327 Err(Status::unimplemented(
331 "Generic gRPC forwarding not supported - requires service-specific client stubs",
332 ))
333 } else {
334 Err(Status::unimplemented("Upstream endpoint not configured for request forwarding"))
335 }
336 }
337
338 async fn forward_bidirectional_streaming_request(
340 &self,
341 _request: Request<Streaming<DynamicMessage>>,
342 _service_name: &str,
343 _method_name: &str,
344 ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
345 if let Some(upstream) = &self.config.upstream_endpoint {
346 let _channel = self.connection_pool.get_channel(upstream).await.map_err(|e| {
348 Status::unavailable(format!("Failed to connect to upstream {}: {}", upstream, e))
349 })?;
350
351 Err(Status::unimplemented(
355 "Generic gRPC forwarding not supported - requires service-specific client stubs",
356 ))
357 } else {
358 Err(Status::unimplemented("Upstream endpoint not configured for request forwarding"))
359 }
360 }
361
362 fn generate_mock_message(
364 &self,
365 descriptor: MessageDescriptor,
366 ) -> Result<DynamicMessage, Status> {
367 let mut smart_generator = self
368 .smart_generator()
369 .lock()
370 .map_err(|_| Status::internal("Failed to acquire lock on smart generator"))?;
371
372 Ok(smart_generator.generate_message(&descriptor))
373 }
374
375 fn generate_mock_message_with_generator(
377 smart_generator: &Arc<Mutex<crate::reflection::smart_mock_generator::SmartMockGenerator>>,
378 descriptor: MessageDescriptor,
379 ) -> Result<DynamicMessage, Status> {
380 let mut smart_generator = smart_generator
381 .lock()
382 .map_err(|_| Status::internal("Failed to acquire lock on smart generator"))?;
383
384 Ok(smart_generator.generate_message(&descriptor))
385 }
386}
387
#[cfg(test)]
mod tests {
    /// Placeholder test with an empty body; it passes whenever the test
    /// build of this module compiles.
    #[test]
    fn test_module_compiles() {}
}