// Source file: mockforge_grpc/reflection/mock_proxy/handlers.rs

//! Request/response handling logic
//!
//! This module provides handlers for processing unary and streaming gRPC
//! requests and responses, including mock response generation and request
//! validation.

use crate::reflection::mock_proxy::proxy::MockReflectionProxy;
use prost_reflect::{DynamicMessage, MessageDescriptor};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tonic::{Request, Response, Status, Streaming};
use tracing::{debug, info};

14impl MockReflectionProxy {
15    /// Handle a unary gRPC request
16    pub async fn handle_unary_request(
17        &self,
18        request: Request<DynamicMessage>,
19    ) -> Result<Response<DynamicMessage>, Status> {
20        let _guard = self.track_connection();
21        self.total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
22        let (service_name, method_name) = self.extract_service_method_from_request(&request)?;
23
24        debug!("Handling unary request for {}/{}", service_name, method_name);
25
26        // Check if this should be mocked
27        if self.should_mock_service_method(&service_name, &method_name) {
28            return self.generate_mock_response(&service_name, &method_name, request).await;
29        }
30
31        // Forward to real service
32        self.forward_unary_request(request, &service_name, &method_name).await
33    }
34
35    /// Handle a server streaming gRPC request
36    pub async fn handle_server_streaming_request(
37        &self,
38        request: Request<DynamicMessage>,
39    ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
40        let _guard = self.track_connection();
41        self.total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
42        let (service_name, method_name) = self.extract_service_method_from_request(&request)?;
43
44        debug!("Handling server streaming request for {}/{}", service_name, method_name);
45
46        // Check if this should be mocked
47        if self.should_mock_service_method(&service_name, &method_name) {
48            return self.generate_mock_stream_response(&service_name, &method_name).await;
49        }
50
51        // Forward to real service
52        self.forward_server_streaming_request(request, &service_name, &method_name)
53            .await
54    }
55
56    /// Handle a client streaming gRPC request
57    pub async fn handle_client_streaming_request(
58        &self,
59        request: Request<Streaming<DynamicMessage>>,
60    ) -> Result<Response<DynamicMessage>, Status> {
61        let _guard = self.track_connection();
62        self.total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
63        let (service_name, method_name) = self.extract_service_method_from_request(&request)?;
64
65        debug!("Handling client streaming request for {}/{}", service_name, method_name);
66
67        // Check if this should be mocked
68        if self.should_mock_service_method(&service_name, &method_name) {
69            return self
70                .generate_mock_client_stream_response(&service_name, &method_name, request)
71                .await;
72        }
73
74        // Forward to real service
75        self.forward_client_streaming_request(request, &service_name, &method_name)
76            .await
77    }
78
79    /// Handle a bidirectional streaming gRPC request
80    pub async fn handle_bidirectional_streaming_request(
81        &self,
82        request: Request<Streaming<DynamicMessage>>,
83    ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
84        let _guard = self.track_connection();
85        self.total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
86        let (service_name, method_name) = self.extract_service_method_from_request(&request)?;
87
88        debug!("Handling bidirectional streaming request for {}/{}", service_name, method_name);
89
90        // Check if this should be mocked
91        if self.should_mock_service_method(&service_name, &method_name) {
92            return self
93                .generate_mock_bidirectional_stream_response(&service_name, &method_name)
94                .await;
95        }
96
97        // Forward to real service
98        self.forward_bidirectional_streaming_request(request, &service_name, &method_name)
99            .await
100    }
101
102    /// Extract service and method names from a request
103    pub fn extract_service_method_from_request<T>(
104        &self,
105        request: &Request<T>,
106    ) -> Result<(String, String), Status> {
107        // Try to get path from metadata (gRPC path header)
108        let path = request
109            .metadata()
110            .get("path")
111            .or_else(|| request.metadata().get(":path"))
112            .and_then(|v| v.to_str().ok())
113            .ok_or_else(|| Status::invalid_argument("Missing path in request"))?;
114
115        if !path.starts_with('/') {
116            return Err(Status::invalid_argument("Invalid request path"));
117        }
118        let parts: Vec<&str> = path[1..].split('/').collect();
119        if parts.len() != 2 {
120            return Err(Status::invalid_argument(
121                "Invalid gRPC path format, expected /Service/Method",
122            ));
123        }
124        Ok((parts[0].to_string(), parts[1].to_string()))
125    }
126
127    /// Generate a mock response for a unary request
128    async fn generate_mock_response(
129        &self,
130        service_name: &str,
131        method_name: &str,
132        _request: Request<DynamicMessage>,
133    ) -> Result<Response<DynamicMessage>, Status> {
134        info!("Generating mock response for {}/{}", service_name, method_name);
135
136        // Get the method descriptor
137        let method_descriptor = self.cache().get_method(service_name, method_name).await?;
138
139        // Generate a mock response message
140        let response_message = self.generate_mock_message(method_descriptor.output())?;
141
142        let mut response = Response::new(response_message);
143
144        // Apply response postprocessing with body transformations
145        self.postprocess_dynamic_response(&mut response, service_name, method_name)
146            .await?;
147
148        Ok(response)
149    }
150
151    /// Generate a mock streaming response
152    async fn generate_mock_stream_response(
153        &self,
154        service_name: &str,
155        method_name: &str,
156    ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
157        info!("Generating mock stream response for {}/{}", service_name, method_name);
158
159        // Get the method descriptor
160        let method_descriptor = self.cache().get_method(service_name, method_name).await?;
161
162        // Create a channel for streaming responses
163        let (tx, rx) = mpsc::channel(4);
164
165        // Generate mock response messages in a separate task
166        let smart_generator = self.smart_generator().clone();
167        let output_descriptor = method_descriptor.output();
168
169        tokio::spawn(async move {
170            for _i in 0..3 {
171                // Generate a mock response message
172                if let Ok(message) = Self::generate_mock_message_with_generator(
173                    &smart_generator,
174                    output_descriptor.clone(),
175                ) {
176                    if tx.send(Ok(message)).await.is_err() {
177                        break; // Receiver dropped
178                    }
179                }
180
181                // Small delay between messages
182                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
183            }
184        });
185
186        let mut response = Response::new(ReceiverStream::new(rx));
187
188        // Apply response postprocessing for streaming responses
189        self.postprocess_streaming_dynamic_response(&mut response, service_name, method_name)
190            .await?;
191
192        Ok(response)
193    }
194
195    /// Generate a mock client streaming response
196    async fn generate_mock_client_stream_response(
197        &self,
198        service_name: &str,
199        method_name: &str,
200        _request: Request<Streaming<DynamicMessage>>,
201    ) -> Result<Response<DynamicMessage>, Status> {
202        info!("Generating mock client streaming response for {}/{}", service_name, method_name);
203
204        // Get the method descriptor
205        let method_descriptor = self.cache().get_method(service_name, method_name).await?;
206
207        // Generate a mock response message
208        let response_message = self.generate_mock_message(method_descriptor.output())?;
209
210        let mut response = Response::new(response_message);
211
212        // Apply response postprocessing with body transformations
213        self.postprocess_dynamic_response(&mut response, service_name, method_name)
214            .await?;
215
216        Ok(response)
217    }
218
219    /// Generate a mock bidirectional streaming response
220    async fn generate_mock_bidirectional_stream_response(
221        &self,
222        service_name: &str,
223        method_name: &str,
224    ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
225        info!(
226            "Generating mock bidirectional stream response for {}/{}",
227            service_name, method_name
228        );
229
230        // Get the method descriptor
231        let method_descriptor = self.cache().get_method(service_name, method_name).await?;
232
233        // Create a channel for streaming responses
234        let (tx, rx) = mpsc::channel(4);
235
236        // Generate mock response messages in a separate task
237        let smart_generator = self.smart_generator().clone();
238        let output_descriptor = method_descriptor.output();
239
240        tokio::spawn(async move {
241            for _i in 0..5 {
242                // Generate a mock response message
243                if let Ok(message) = Self::generate_mock_message_with_generator(
244                    &smart_generator,
245                    output_descriptor.clone(),
246                ) {
247                    if tx.send(Ok(message)).await.is_err() {
248                        break; // Receiver dropped
249                    }
250                }
251
252                // Small delay between messages
253                tokio::time::sleep(std::time::Duration::from_millis(200)).await;
254            }
255        });
256
257        let mut response = Response::new(ReceiverStream::new(rx));
258
259        // Apply response postprocessing for streaming responses
260        self.postprocess_streaming_dynamic_response(&mut response, service_name, method_name)
261            .await?;
262
263        Ok(response)
264    }
265
266    /// Forward a unary request to the real service
267    async fn forward_unary_request(
268        &self,
269        request: Request<DynamicMessage>,
270        service_name: &str,
271        method_name: &str,
272    ) -> Result<Response<DynamicMessage>, Status> {
273        if let Some(upstream) = &self.config.upstream_endpoint {
274            // Get channel to upstream
275            let _channel = self.connection_pool.get_channel(upstream).await.map_err(|e| {
276                Status::unavailable(format!("Failed to connect to upstream {}: {}", upstream, e))
277            })?;
278
279            debug!(
280                "Generic upstream forwarding is unavailable for {}/{}, falling back to local mock response",
281                service_name, method_name
282            );
283            self.generate_mock_response(service_name, method_name, request).await
284        } else {
285            debug!(
286                "No upstream endpoint configured for {}/{}, using local mock fallback",
287                service_name, method_name
288            );
289            self.generate_mock_response(service_name, method_name, request).await
290        }
291    }
292
293    /// Forward a server streaming request to the real service
294    async fn forward_server_streaming_request(
295        &self,
296        _request: Request<DynamicMessage>,
297        service_name: &str,
298        method_name: &str,
299    ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
300        if let Some(upstream) = &self.config.upstream_endpoint {
301            // Get channel to upstream
302            let _channel = self.connection_pool.get_channel(upstream).await.map_err(|e| {
303                Status::unavailable(format!("Failed to connect to upstream {}: {}", upstream, e))
304            })?;
305
306            debug!(
307                "Generic upstream streaming forwarding is unavailable for {}/{}, falling back to local mock stream",
308                service_name, method_name
309            );
310            self.generate_mock_stream_response(service_name, method_name).await
311        } else {
312            debug!(
313                "No upstream endpoint configured for {}/{}, using local mock stream fallback",
314                service_name, method_name
315            );
316            self.generate_mock_stream_response(service_name, method_name).await
317        }
318    }
319
320    /// Forward a client streaming request to the real service
321    async fn forward_client_streaming_request(
322        &self,
323        request: Request<Streaming<DynamicMessage>>,
324        service_name: &str,
325        method_name: &str,
326    ) -> Result<Response<DynamicMessage>, Status> {
327        if let Some(upstream) = &self.config.upstream_endpoint {
328            // Get channel to upstream
329            let _channel = self.connection_pool.get_channel(upstream).await.map_err(|e| {
330                Status::unavailable(format!("Failed to connect to upstream {}: {}", upstream, e))
331            })?;
332
333            debug!(
334                "Generic upstream client-stream forwarding is unavailable for {}/{}, falling back to local mock response",
335                service_name, method_name
336            );
337            self.generate_mock_client_stream_response(service_name, method_name, request)
338                .await
339        } else {
340            debug!(
341                "No upstream endpoint configured for {}/{}, using local mock client-stream fallback",
342                service_name, method_name
343            );
344            self.generate_mock_client_stream_response(service_name, method_name, request)
345                .await
346        }
347    }
348
349    /// Forward a bidirectional streaming request to the real service
350    async fn forward_bidirectional_streaming_request(
351        &self,
352        request: Request<Streaming<DynamicMessage>>,
353        service_name: &str,
354        method_name: &str,
355    ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
356        if let Some(upstream) = &self.config.upstream_endpoint {
357            // Get channel to upstream
358            let _channel = self.connection_pool.get_channel(upstream).await.map_err(|e| {
359                Status::unavailable(format!("Failed to connect to upstream {}: {}", upstream, e))
360            })?;
361
362            debug!(
363                "Generic upstream bidi-stream forwarding is unavailable for {}/{}, falling back to local mock stream",
364                service_name, method_name
365            );
366            let _ = request;
367            self.generate_mock_bidirectional_stream_response(service_name, method_name)
368                .await
369        } else {
370            debug!(
371                "No upstream endpoint configured for {}/{}, using local mock bidi-stream fallback",
372                service_name, method_name
373            );
374            let _ = request;
375            self.generate_mock_bidirectional_stream_response(service_name, method_name)
376                .await
377        }
378    }
379
380    /// Generate a mock message using the smart generator
381    fn generate_mock_message(
382        &self,
383        descriptor: MessageDescriptor,
384    ) -> Result<DynamicMessage, Status> {
385        let mut smart_generator = self
386            .smart_generator()
387            .lock()
388            .map_err(|_| Status::internal("Failed to acquire lock on smart generator"))?;
389
390        Ok(smart_generator.generate_message(&descriptor))
391    }
392
393    /// Generate a mock message with a specific generator
394    fn generate_mock_message_with_generator(
395        smart_generator: &Arc<Mutex<crate::reflection::smart_mock_generator::SmartMockGenerator>>,
396        descriptor: MessageDescriptor,
397    ) -> Result<DynamicMessage, Status> {
398        let mut smart_generator = smart_generator
399            .lock()
400            .map_err(|_| Status::internal("Failed to acquire lock on smart generator"))?;
401
402        Ok(smart_generator.generate_message(&descriptor))
403    }
404}
405
#[cfg(test)]
mod tests {
    /// Smoke test: if the test binary builds, this module's imports and its
    /// handler impl type-checked.
    #[test]
    fn test_module_compiles() {}
}