1use crate::reflection::mock_proxy::proxy::MockReflectionProxy;
7use prost_reflect::{DynamicMessage, MessageDescriptor};
8use std::sync::{Arc, Mutex};
9use tokio::sync::mpsc;
10use tokio_stream::wrappers::ReceiverStream;
11use tonic::{Request, Response, Status, Streaming};
12use tracing::{debug, info};
13
14impl MockReflectionProxy {
15 pub async fn handle_unary_request(
17 &self,
18 request: Request<DynamicMessage>,
19 ) -> Result<Response<DynamicMessage>, Status> {
20 let _guard = self.track_connection();
21 self.total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
22 let (service_name, method_name) = self.extract_service_method_from_request(&request)?;
23
24 debug!("Handling unary request for {}/{}", service_name, method_name);
25
26 if self.should_mock_service_method(&service_name, &method_name) {
28 return self.generate_mock_response(&service_name, &method_name, request).await;
29 }
30
31 self.forward_unary_request(request, &service_name, &method_name).await
33 }
34
35 pub async fn handle_server_streaming_request(
37 &self,
38 request: Request<DynamicMessage>,
39 ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
40 let _guard = self.track_connection();
41 self.total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
42 let (service_name, method_name) = self.extract_service_method_from_request(&request)?;
43
44 debug!("Handling server streaming request for {}/{}", service_name, method_name);
45
46 if self.should_mock_service_method(&service_name, &method_name) {
48 return self.generate_mock_stream_response(&service_name, &method_name).await;
49 }
50
51 self.forward_server_streaming_request(request, &service_name, &method_name)
53 .await
54 }
55
56 pub async fn handle_client_streaming_request(
58 &self,
59 request: Request<Streaming<DynamicMessage>>,
60 ) -> Result<Response<DynamicMessage>, Status> {
61 let _guard = self.track_connection();
62 self.total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
63 let (service_name, method_name) = self.extract_service_method_from_request(&request)?;
64
65 debug!("Handling client streaming request for {}/{}", service_name, method_name);
66
67 if self.should_mock_service_method(&service_name, &method_name) {
69 return self
70 .generate_mock_client_stream_response(&service_name, &method_name, request)
71 .await;
72 }
73
74 self.forward_client_streaming_request(request, &service_name, &method_name)
76 .await
77 }
78
79 pub async fn handle_bidirectional_streaming_request(
81 &self,
82 request: Request<Streaming<DynamicMessage>>,
83 ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
84 let _guard = self.track_connection();
85 self.total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
86 let (service_name, method_name) = self.extract_service_method_from_request(&request)?;
87
88 debug!("Handling bidirectional streaming request for {}/{}", service_name, method_name);
89
90 if self.should_mock_service_method(&service_name, &method_name) {
92 return self
93 .generate_mock_bidirectional_stream_response(&service_name, &method_name)
94 .await;
95 }
96
97 self.forward_bidirectional_streaming_request(request, &service_name, &method_name)
99 .await
100 }
101
102 pub fn extract_service_method_from_request<T>(
104 &self,
105 request: &Request<T>,
106 ) -> Result<(String, String), Status> {
107 let path = request
109 .metadata()
110 .get("path")
111 .or_else(|| request.metadata().get(":path"))
112 .and_then(|v| v.to_str().ok())
113 .ok_or_else(|| Status::invalid_argument("Missing path in request"))?;
114
115 if !path.starts_with('/') {
116 return Err(Status::invalid_argument("Invalid request path"));
117 }
118 let parts: Vec<&str> = path[1..].split('/').collect();
119 if parts.len() != 2 {
120 return Err(Status::invalid_argument(
121 "Invalid gRPC path format, expected /Service/Method",
122 ));
123 }
124 Ok((parts[0].to_string(), parts[1].to_string()))
125 }
126
127 async fn generate_mock_response(
129 &self,
130 service_name: &str,
131 method_name: &str,
132 _request: Request<DynamicMessage>,
133 ) -> Result<Response<DynamicMessage>, Status> {
134 info!("Generating mock response for {}/{}", service_name, method_name);
135
136 let method_descriptor = self.cache().get_method(service_name, method_name).await?;
138
139 let response_message = self.generate_mock_message(method_descriptor.output())?;
141
142 let mut response = Response::new(response_message);
143
144 self.postprocess_dynamic_response(&mut response, service_name, method_name)
146 .await?;
147
148 Ok(response)
149 }
150
151 async fn generate_mock_stream_response(
153 &self,
154 service_name: &str,
155 method_name: &str,
156 ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
157 info!("Generating mock stream response for {}/{}", service_name, method_name);
158
159 let method_descriptor = self.cache().get_method(service_name, method_name).await?;
161
162 let (tx, rx) = mpsc::channel(4);
164
165 let smart_generator = self.smart_generator().clone();
167 let output_descriptor = method_descriptor.output();
168
169 tokio::spawn(async move {
170 for _i in 0..3 {
171 if let Ok(message) = Self::generate_mock_message_with_generator(
173 &smart_generator,
174 output_descriptor.clone(),
175 ) {
176 if tx.send(Ok(message)).await.is_err() {
177 break; }
179 }
180
181 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
183 }
184 });
185
186 let mut response = Response::new(ReceiverStream::new(rx));
187
188 self.postprocess_streaming_dynamic_response(&mut response, service_name, method_name)
190 .await?;
191
192 Ok(response)
193 }
194
195 async fn generate_mock_client_stream_response(
197 &self,
198 service_name: &str,
199 method_name: &str,
200 _request: Request<Streaming<DynamicMessage>>,
201 ) -> Result<Response<DynamicMessage>, Status> {
202 info!("Generating mock client streaming response for {}/{}", service_name, method_name);
203
204 let method_descriptor = self.cache().get_method(service_name, method_name).await?;
206
207 let response_message = self.generate_mock_message(method_descriptor.output())?;
209
210 let mut response = Response::new(response_message);
211
212 self.postprocess_dynamic_response(&mut response, service_name, method_name)
214 .await?;
215
216 Ok(response)
217 }
218
219 async fn generate_mock_bidirectional_stream_response(
221 &self,
222 service_name: &str,
223 method_name: &str,
224 ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
225 info!(
226 "Generating mock bidirectional stream response for {}/{}",
227 service_name, method_name
228 );
229
230 let method_descriptor = self.cache().get_method(service_name, method_name).await?;
232
233 let (tx, rx) = mpsc::channel(4);
235
236 let smart_generator = self.smart_generator().clone();
238 let output_descriptor = method_descriptor.output();
239
240 tokio::spawn(async move {
241 for _i in 0..5 {
242 if let Ok(message) = Self::generate_mock_message_with_generator(
244 &smart_generator,
245 output_descriptor.clone(),
246 ) {
247 if tx.send(Ok(message)).await.is_err() {
248 break; }
250 }
251
252 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
254 }
255 });
256
257 let mut response = Response::new(ReceiverStream::new(rx));
258
259 self.postprocess_streaming_dynamic_response(&mut response, service_name, method_name)
261 .await?;
262
263 Ok(response)
264 }
265
266 async fn forward_unary_request(
268 &self,
269 request: Request<DynamicMessage>,
270 service_name: &str,
271 method_name: &str,
272 ) -> Result<Response<DynamicMessage>, Status> {
273 if let Some(upstream) = &self.config.upstream_endpoint {
274 let _channel = self.connection_pool.get_channel(upstream).await.map_err(|e| {
276 Status::unavailable(format!("Failed to connect to upstream {}: {}", upstream, e))
277 })?;
278
279 debug!(
280 "Generic upstream forwarding is unavailable for {}/{}, falling back to local mock response",
281 service_name, method_name
282 );
283 self.generate_mock_response(service_name, method_name, request).await
284 } else {
285 debug!(
286 "No upstream endpoint configured for {}/{}, using local mock fallback",
287 service_name, method_name
288 );
289 self.generate_mock_response(service_name, method_name, request).await
290 }
291 }
292
293 async fn forward_server_streaming_request(
295 &self,
296 _request: Request<DynamicMessage>,
297 service_name: &str,
298 method_name: &str,
299 ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
300 if let Some(upstream) = &self.config.upstream_endpoint {
301 let _channel = self.connection_pool.get_channel(upstream).await.map_err(|e| {
303 Status::unavailable(format!("Failed to connect to upstream {}: {}", upstream, e))
304 })?;
305
306 debug!(
307 "Generic upstream streaming forwarding is unavailable for {}/{}, falling back to local mock stream",
308 service_name, method_name
309 );
310 self.generate_mock_stream_response(service_name, method_name).await
311 } else {
312 debug!(
313 "No upstream endpoint configured for {}/{}, using local mock stream fallback",
314 service_name, method_name
315 );
316 self.generate_mock_stream_response(service_name, method_name).await
317 }
318 }
319
320 async fn forward_client_streaming_request(
322 &self,
323 request: Request<Streaming<DynamicMessage>>,
324 service_name: &str,
325 method_name: &str,
326 ) -> Result<Response<DynamicMessage>, Status> {
327 if let Some(upstream) = &self.config.upstream_endpoint {
328 let _channel = self.connection_pool.get_channel(upstream).await.map_err(|e| {
330 Status::unavailable(format!("Failed to connect to upstream {}: {}", upstream, e))
331 })?;
332
333 debug!(
334 "Generic upstream client-stream forwarding is unavailable for {}/{}, falling back to local mock response",
335 service_name, method_name
336 );
337 self.generate_mock_client_stream_response(service_name, method_name, request)
338 .await
339 } else {
340 debug!(
341 "No upstream endpoint configured for {}/{}, using local mock client-stream fallback",
342 service_name, method_name
343 );
344 self.generate_mock_client_stream_response(service_name, method_name, request)
345 .await
346 }
347 }
348
349 async fn forward_bidirectional_streaming_request(
351 &self,
352 request: Request<Streaming<DynamicMessage>>,
353 service_name: &str,
354 method_name: &str,
355 ) -> Result<Response<ReceiverStream<Result<DynamicMessage, Status>>>, Status> {
356 if let Some(upstream) = &self.config.upstream_endpoint {
357 let _channel = self.connection_pool.get_channel(upstream).await.map_err(|e| {
359 Status::unavailable(format!("Failed to connect to upstream {}: {}", upstream, e))
360 })?;
361
362 debug!(
363 "Generic upstream bidi-stream forwarding is unavailable for {}/{}, falling back to local mock stream",
364 service_name, method_name
365 );
366 let _ = request;
367 self.generate_mock_bidirectional_stream_response(service_name, method_name)
368 .await
369 } else {
370 debug!(
371 "No upstream endpoint configured for {}/{}, using local mock bidi-stream fallback",
372 service_name, method_name
373 );
374 let _ = request;
375 self.generate_mock_bidirectional_stream_response(service_name, method_name)
376 .await
377 }
378 }
379
380 fn generate_mock_message(
382 &self,
383 descriptor: MessageDescriptor,
384 ) -> Result<DynamicMessage, Status> {
385 let mut smart_generator = self
386 .smart_generator()
387 .lock()
388 .map_err(|_| Status::internal("Failed to acquire lock on smart generator"))?;
389
390 Ok(smart_generator.generate_message(&descriptor))
391 }
392
393 fn generate_mock_message_with_generator(
395 smart_generator: &Arc<Mutex<crate::reflection::smart_mock_generator::SmartMockGenerator>>,
396 descriptor: MessageDescriptor,
397 ) -> Result<DynamicMessage, Status> {
398 let mut smart_generator = smart_generator
399 .lock()
400 .map_err(|_| Status::internal("Failed to acquire lock on smart generator"))?;
401
402 Ok(smart_generator.generate_message(&descriptor))
403 }
404}
405
#[cfg(test)]
mod tests {
    /// Smoke test: passes as long as this module builds under `cfg(test)`.
    #[test]
    fn test_module_compiles() {}
}