mockforge_grpc/reflection/mock_proxy/
handlers.rs

1//! Request/response handling logic
2//!
3//! This module provides handlers for processing gRPC requests and responses,
4//! including mock response generation and request validation.
5
6use crate::reflection::mock_proxy::proxy::MockReflectionProxy;
7use prost_reflect::{DynamicMessage, MessageDescriptor};
8use std::sync::{Arc, Mutex};
9use tokio::sync::mpsc;
10use tokio_stream::wrappers::ReceiverStream;
11use tonic::{Request, Response, Status, Streaming};
12use tracing::{debug, info};
13
14impl MockReflectionProxy {
15    /// Handle a unary gRPC request
16    pub async fn handle_unary_request(
17        &self,
18        request: Request<DynamicMessage>,
19    ) -> Result<Response<DynamicMessage>, Status> {
20        let _guard = self.track_connection();
21        self.total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
22        let (service_name, method_name) = self.extract_service_method_from_request(&request)?;
23
24        debug!("Handling unary request for {}/{}", service_name, method_name);
25
26        // Check if this should be mocked
27        if self.should_mock_service_method(&service_name, &method_name) {
28            return self.generate_mock_response(&service_name, &method_name, request).await;
29        }
30
31        // Forward to real service
32        self.forward_unary_request(request, &service_name, &method_name).await
33    }
34
35    /// Handle a server streaming gRPC request
36    pub async fn handle_server_streaming_request(
37        &self,
38        request: Request<DynamicMessage>,
39    ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
40        let _guard = self.track_connection();
41        self.total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
42        let (service_name, method_name) = self.extract_service_method_from_request(&request)?;
43
44        debug!("Handling server streaming request for {}/{}", service_name, method_name);
45
46        // Check if this should be mocked
47        if self.should_mock_service_method(&service_name, &method_name) {
48            return self.generate_mock_stream_response(&service_name, &method_name).await;
49        }
50
51        // Forward to real service
52        self.forward_server_streaming_request(request, &service_name, &method_name)
53            .await
54    }
55
56    /// Handle a client streaming gRPC request
57    pub async fn handle_client_streaming_request(
58        &self,
59        request: Request<Streaming<DynamicMessage>>,
60    ) -> Result<Response<DynamicMessage>, Status> {
61        let _guard = self.track_connection();
62        self.total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
63        let (service_name, method_name) = self.extract_service_method_from_request(&request)?;
64
65        debug!("Handling client streaming request for {}/{}", service_name, method_name);
66
67        // Check if this should be mocked
68        if self.should_mock_service_method(&service_name, &method_name) {
69            return self
70                .generate_mock_client_stream_response(&service_name, &method_name, request)
71                .await;
72        }
73
74        // Forward to real service
75        self.forward_client_streaming_request(request, &service_name, &method_name)
76            .await
77    }
78
79    /// Handle a bidirectional streaming gRPC request
80    pub async fn handle_bidirectional_streaming_request(
81        &self,
82        request: Request<Streaming<DynamicMessage>>,
83    ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
84        let _guard = self.track_connection();
85        self.total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
86        let (service_name, method_name) = self.extract_service_method_from_request(&request)?;
87
88        debug!("Handling bidirectional streaming request for {}/{}", service_name, method_name);
89
90        // Check if this should be mocked
91        if self.should_mock_service_method(&service_name, &method_name) {
92            return self
93                .generate_mock_bidirectional_stream_response(&service_name, &method_name)
94                .await;
95        }
96
97        // Forward to real service
98        self.forward_bidirectional_streaming_request(request, &service_name, &method_name)
99            .await
100    }
101
102    /// Extract service and method names from a request
103    pub fn extract_service_method_from_request<T>(
104        &self,
105        request: &Request<T>,
106    ) -> Result<(String, String), Status> {
107        // Try to get path from metadata (gRPC path header)
108        let path = request
109            .metadata()
110            .get("path")
111            .or_else(|| request.metadata().get(":path"))
112            .and_then(|v| v.to_str().ok())
113            .ok_or_else(|| Status::invalid_argument("Missing path in request"))?;
114
115        if !path.starts_with('/') {
116            return Err(Status::invalid_argument("Invalid request path"));
117        }
118        let parts: Vec<&str> = path[1..].split('/').collect();
119        if parts.len() != 2 {
120            return Err(Status::invalid_argument(
121                "Invalid gRPC path format, expected /Service/Method",
122            ));
123        }
124        Ok((parts[0].to_string(), parts[1].to_string()))
125    }
126
127    /// Generate a mock response for a unary request
128    async fn generate_mock_response(
129        &self,
130        service_name: &str,
131        method_name: &str,
132        _request: Request<DynamicMessage>,
133    ) -> Result<Response<DynamicMessage>, Status> {
134        info!("Generating mock response for {}/{}", service_name, method_name);
135
136        // Get the method descriptor
137        let method_descriptor = self.cache().get_method(service_name, method_name).await?;
138
139        // Generate a mock response message
140        let response_message = self.generate_mock_message(method_descriptor.output())?;
141
142        let mut response = Response::new(response_message);
143
144        // Apply response postprocessing with body transformations
145        self.postprocess_dynamic_response(&mut response, service_name, method_name)
146            .await?;
147
148        Ok(response)
149    }
150
151    /// Generate a mock streaming response
152    async fn generate_mock_stream_response(
153        &self,
154        service_name: &str,
155        method_name: &str,
156    ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
157        info!("Generating mock stream response for {}/{}", service_name, method_name);
158
159        // Get the method descriptor
160        let method_descriptor = self.cache().get_method(service_name, method_name).await?;
161
162        // Create a channel for streaming responses
163        let (tx, rx) = mpsc::channel(4);
164
165        // Generate mock response messages in a separate task
166        let smart_generator = self.smart_generator().clone();
167        let output_descriptor = method_descriptor.output();
168
169        tokio::spawn(async move {
170            for _i in 0..3 {
171                // Generate a mock response message
172                if let Ok(message) = Self::generate_mock_message_with_generator(
173                    &smart_generator,
174                    output_descriptor.clone(),
175                ) {
176                    if tx.send(Ok(message)).await.is_err() {
177                        break; // Receiver dropped
178                    }
179                }
180
181                // Small delay between messages
182                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
183            }
184        });
185
186        let mut response = Response::new(ReceiverStream::new(rx));
187
188        // Apply response postprocessing for streaming responses
189        self.postprocess_streaming_dynamic_response(&mut response, service_name, method_name)
190            .await?;
191
192        Ok(response)
193    }
194
195    /// Generate a mock client streaming response
196    async fn generate_mock_client_stream_response(
197        &self,
198        service_name: &str,
199        method_name: &str,
200        _request: Request<Streaming<DynamicMessage>>,
201    ) -> Result<Response<DynamicMessage>, Status> {
202        info!("Generating mock client streaming response for {}/{}", service_name, method_name);
203
204        // Get the method descriptor
205        let method_descriptor = self.cache().get_method(service_name, method_name).await?;
206
207        // Generate a mock response message
208        let response_message = self.generate_mock_message(method_descriptor.output())?;
209
210        let mut response = Response::new(response_message);
211
212        // Apply response postprocessing with body transformations
213        self.postprocess_dynamic_response(&mut response, service_name, method_name)
214            .await?;
215
216        Ok(response)
217    }
218
219    /// Generate a mock bidirectional streaming response
220    async fn generate_mock_bidirectional_stream_response(
221        &self,
222        service_name: &str,
223        method_name: &str,
224    ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
225        info!(
226            "Generating mock bidirectional stream response for {}/{}",
227            service_name, method_name
228        );
229
230        // Get the method descriptor
231        let method_descriptor = self.cache().get_method(service_name, method_name).await?;
232
233        // Create a channel for streaming responses
234        let (tx, rx) = mpsc::channel(4);
235
236        // Generate mock response messages in a separate task
237        let smart_generator = self.smart_generator().clone();
238        let output_descriptor = method_descriptor.output();
239
240        tokio::spawn(async move {
241            for _i in 0..5 {
242                // Generate a mock response message
243                if let Ok(message) = Self::generate_mock_message_with_generator(
244                    &smart_generator,
245                    output_descriptor.clone(),
246                ) {
247                    if tx.send(Ok(message)).await.is_err() {
248                        break; // Receiver dropped
249                    }
250                }
251
252                // Small delay between messages
253                tokio::time::sleep(std::time::Duration::from_millis(200)).await;
254            }
255        });
256
257        let mut response = Response::new(ReceiverStream::new(rx));
258
259        // Apply response postprocessing for streaming responses
260        self.postprocess_streaming_dynamic_response(&mut response, service_name, method_name)
261            .await?;
262
263        Ok(response)
264    }
265
266    /// Forward a unary request to the real service
267    async fn forward_unary_request(
268        &self,
269        _request: Request<DynamicMessage>,
270        _service_name: &str,
271        _method_name: &str,
272    ) -> Result<Response<DynamicMessage>, Status> {
273        if let Some(upstream) = &self.config.upstream_endpoint {
274            // Get channel to upstream
275            let _channel = self.connection_pool.get_channel(upstream).await.map_err(|e| {
276                Status::unavailable(format!("Failed to connect to upstream {}: {}", upstream, e))
277            })?;
278
279            // Generic gRPC forwarding requires generated client stubs for the specific service
280            // Since this is a reflection-based proxy for arbitrary services, forwarding is not supported
281            // To forward, the proxy would need to be configured with specific client implementations
282            Err(Status::unimplemented(
283                "Generic gRPC forwarding not supported - requires service-specific client stubs",
284            ))
285        } else {
286            Err(Status::unimplemented("Upstream endpoint not configured for request forwarding"))
287        }
288    }
289
290    /// Forward a server streaming request to the real service
291    async fn forward_server_streaming_request(
292        &self,
293        _request: Request<DynamicMessage>,
294        _service_name: &str,
295        _method_name: &str,
296    ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
297        if let Some(upstream) = &self.config.upstream_endpoint {
298            // Get channel to upstream
299            let _channel = self.connection_pool.get_channel(upstream).await.map_err(|e| {
300                Status::unavailable(format!("Failed to connect to upstream {}: {}", upstream, e))
301            })?;
302
303            // Generic gRPC forwarding requires generated client stubs for the specific service
304            // Since this is a reflection-based proxy for arbitrary services, forwarding is not supported
305            // To forward, the proxy would need to be configured with specific client implementations
306            Err(Status::unimplemented(
307                "Generic gRPC forwarding not supported - requires service-specific client stubs",
308            ))
309        } else {
310            Err(Status::unimplemented("Upstream endpoint not configured for request forwarding"))
311        }
312    }
313
314    /// Forward a client streaming request to the real service
315    async fn forward_client_streaming_request(
316        &self,
317        _request: Request<Streaming<DynamicMessage>>,
318        _service_name: &str,
319        _method_name: &str,
320    ) -> Result<Response<DynamicMessage>, Status> {
321        if let Some(upstream) = &self.config.upstream_endpoint {
322            // Get channel to upstream
323            let _channel = self.connection_pool.get_channel(upstream).await.map_err(|e| {
324                Status::unavailable(format!("Failed to connect to upstream {}: {}", upstream, e))
325            })?;
326
327            // Generic gRPC forwarding requires generated client stubs for the specific service
328            // Since this is a reflection-based proxy for arbitrary services, forwarding is not supported
329            // To forward, the proxy would need to be configured with specific client implementations
330            Err(Status::unimplemented(
331                "Generic gRPC forwarding not supported - requires service-specific client stubs",
332            ))
333        } else {
334            Err(Status::unimplemented("Upstream endpoint not configured for request forwarding"))
335        }
336    }
337
338    /// Forward a bidirectional streaming request to the real service
339    async fn forward_bidirectional_streaming_request(
340        &self,
341        _request: Request<Streaming<DynamicMessage>>,
342        _service_name: &str,
343        _method_name: &str,
344    ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
345        if let Some(upstream) = &self.config.upstream_endpoint {
346            // Get channel to upstream
347            let _channel = self.connection_pool.get_channel(upstream).await.map_err(|e| {
348                Status::unavailable(format!("Failed to connect to upstream {}: {}", upstream, e))
349            })?;
350
351            // Generic gRPC forwarding requires generated client stubs for the specific service
352            // Since this is a reflection-based proxy for arbitrary services, forwarding is not supported
353            // To forward, the proxy would need to be configured with specific client implementations
354            Err(Status::unimplemented(
355                "Generic gRPC forwarding not supported - requires service-specific client stubs",
356            ))
357        } else {
358            Err(Status::unimplemented("Upstream endpoint not configured for request forwarding"))
359        }
360    }
361
362    /// Generate a mock message using the smart generator
363    fn generate_mock_message(
364        &self,
365        descriptor: MessageDescriptor,
366    ) -> Result<DynamicMessage, Status> {
367        let mut smart_generator = self
368            .smart_generator()
369            .lock()
370            .map_err(|_| Status::internal("Failed to acquire lock on smart generator"))?;
371
372        Ok(smart_generator.generate_message(&descriptor))
373    }
374
375    /// Generate a mock message with a specific generator
376    fn generate_mock_message_with_generator(
377        smart_generator: &Arc<Mutex<crate::reflection::smart_mock_generator::SmartMockGenerator>>,
378        descriptor: MessageDescriptor,
379    ) -> Result<DynamicMessage, Status> {
380        let mut smart_generator = smart_generator
381            .lock()
382            .map_err(|_| Status::internal("Failed to acquire lock on smart generator"))?;
383
384        Ok(smart_generator.generate_message(&descriptor))
385    }
386}
387
// Unit tests for this module; compiled only for test builds.
#[cfg(test)]
mod tests {

    // Placeholder with an empty body: it asserts nothing at runtime and passes
    // whenever the crate builds in test configuration.
    #[test]
    fn test_module_compiles() {}
}