Skip to main content

mockforge_grpc/reflection/
proxy.rs

1//! Main reflection proxy implementation
2
3use crate::reflection::{
4    cache::DescriptorCache, client::ReflectionClient, config::ProxyConfig,
5    connection_pool::ConnectionPool,
6};
7use futures_util::Stream;
8#[cfg(feature = "data-faker")]
9use mockforge_data::{DataConfig, DataGenerator, SchemaDefinition};
10use prost_reflect::{DynamicMessage, ReflectMessage};
11use std::pin::Pin;
12use std::time::Duration;
13use tokio::sync::mpsc;
14use tokio_stream::wrappers::ReceiverStream;
15use tokio_stream::StreamExt;
16use tonic::{transport::Endpoint, Request, Response, Status, Streaming};
17use tracing::{debug, warn};
18
19/// A reflection-based gRPC proxy that can forward requests to arbitrary services
20pub struct ReflectionProxy {
21    /// The reflection client for discovering services
22    _client: ReflectionClient,
23    /// Cache of service and method descriptors
24    cache: DescriptorCache,
25    /// Proxy configuration
26    config: ProxyConfig,
27    /// Timeout for requests
28    timeout_duration: Duration,
29    /// Connection pool for gRPC channels
30    #[allow(dead_code)]
31    connection_pool: ConnectionPool,
32}
33
34impl ReflectionProxy {
35    /// Create a new reflection proxy
36    pub async fn new(endpoint: Endpoint, config: ProxyConfig) -> Result<Self, Status> {
37        debug!("Creating reflection proxy for endpoint: {:?}", endpoint.uri());
38
39        let client = ReflectionClient::new(endpoint).await?;
40        let cache = DescriptorCache::new();
41
42        // Populate cache from the client's descriptor pool
43        cache.populate_from_pool(Some(client.pool())).await;
44
45        Ok(Self {
46            _client: client,
47            cache,
48            config,
49            timeout_duration: Duration::from_secs(30),
50            connection_pool: ConnectionPool::new(),
51        })
52    }
53
54    /// Set the request timeout
55    pub fn with_timeout(mut self, timeout: Duration) -> Self {
56        self.timeout_duration = timeout;
57        self
58    }
59
60    /// Forward a unary request to the target service
61    pub async fn forward_unary(
62        &self,
63        service_name: &str,
64        method_name: &str,
65        request: Request<DynamicMessage>,
66    ) -> Result<Response<DynamicMessage>, Status> {
67        // Check if service is allowed
68        if !self.config.is_service_allowed(service_name) {
69            return Err(Status::permission_denied(format!(
70                "Service {} is not allowed",
71                service_name
72            )));
73        }
74
75        // Get the method descriptor
76        let method = self.cache.get_method(service_name, method_name).await?;
77
78        // Check if it's actually a unary method
79        if !method.is_server_streaming() && !method.is_client_streaming() {
80            self.forward_unary_impl(method, request).await
81        } else {
82            Err(Status::invalid_argument(format!(
83                "Method {}::{} is not a unary method",
84                service_name, method_name
85            )))
86        }
87    }
88
89    /// Forward a server-streaming request to the target service
90    pub async fn forward_server_streaming(
91        &self,
92        service_name: &str,
93        method_name: &str,
94        request: Request<DynamicMessage>,
95    ) -> Result<Response<Pin<Box<dyn Stream<Item = Result<DynamicMessage, Status>> + Send>>>, Status>
96    {
97        // Check if service is allowed
98        if !self.config.is_service_allowed(service_name) {
99            return Err(Status::permission_denied(format!(
100                "Service {} is not allowed",
101                service_name
102            )));
103        }
104
105        // Get the method descriptor
106        let method = self.cache.get_method(service_name, method_name).await?;
107
108        // Check if it's actually a server streaming method
109        if method.is_server_streaming() && !method.is_client_streaming() {
110            self.forward_server_streaming_impl(method, request).await
111        } else {
112            Err(Status::invalid_argument(format!(
113                "Method {}::{} is not a server streaming method",
114                service_name, method_name
115            )))
116        }
117    }
118
119    /// Forward a client-streaming request to the target service
120    pub async fn forward_client_streaming(
121        &self,
122        service_name: &str,
123        method_name: &str,
124        request: Request<Streaming<DynamicMessage>>,
125    ) -> Result<Response<DynamicMessage>, Status> {
126        // Check if service is allowed
127        if !self.config.is_service_allowed(service_name) {
128            return Err(Status::permission_denied(format!(
129                "Service {} is not allowed",
130                service_name
131            )));
132        }
133
134        // Get the method descriptor
135        let method = self.cache.get_method(service_name, method_name).await?;
136
137        // Check if it's actually a client streaming method
138        if method.is_client_streaming() && !method.is_server_streaming() {
139            self.forward_client_streaming_impl(method, request).await
140        } else {
141            Err(Status::invalid_argument(format!(
142                "Method {}::{} is not a client streaming method",
143                service_name, method_name
144            )))
145        }
146    }
147
148    /// Forward a bidirectional streaming request to the target service
149    pub async fn forward_bidirectional_streaming(
150        &self,
151        service_name: &str,
152        method_name: &str,
153        request: Request<Streaming<DynamicMessage>>,
154    ) -> Result<Response<Pin<Box<dyn Stream<Item = Result<DynamicMessage, Status>> + Send>>>, Status>
155    {
156        // Check if service is allowed
157        if !self.config.is_service_allowed(service_name) {
158            return Err(Status::permission_denied(format!(
159                "Service {} is not allowed",
160                service_name
161            )));
162        }
163
164        // Get the method descriptor
165        let method = self.cache.get_method(service_name, method_name).await?;
166
167        // Check if it's actually a bidirectional streaming method
168        if method.is_client_streaming() && method.is_server_streaming() {
169            self.forward_bidirectional_streaming_impl(method, request).await
170        } else {
171            Err(Status::invalid_argument(format!(
172                "Method {}::{} is not a bidirectional streaming method",
173                service_name, method_name
174            )))
175        }
176    }
177
178    /// Implementation for forwarding unary requests
179    async fn forward_unary_impl(
180        &self,
181        method: prost_reflect::MethodDescriptor,
182        request: Request<DynamicMessage>,
183    ) -> Result<Response<DynamicMessage>, Status> {
184        // Real implementation for mock server:
185        // 1. Look up mock responses based on the service/method
186        // 2. Apply any configured latency or error simulation
187        // 3. Return the appropriate mock response with preserved metadata
188        // 4. Preserve all metadata from the original request in the response
189
190        debug!("Generating mock response for method: {}", method.name());
191
192        // Extract service name from method descriptor
193        let service_name = method.parent_service().name();
194        let method_name = method.name();
195
196        // Create a mock response based on the method
197        let mock_response = self.generate_mock_response(service_name, method_name, &method).await?;
198
199        // Create response with mock data and preserve metadata
200        let mut response = Response::new(mock_response);
201
202        // Preserve original request metadata in the response (ASCII only for simplicity)
203        let request_metadata = request.metadata();
204        for entry in request_metadata.iter() {
205            if let tonic::metadata::KeyAndValueRef::Ascii(key, value) = entry {
206                // Only preserve certain metadata keys, avoiding system headers
207                if !key.as_str().starts_with(':')
208                    && !key.as_str().starts_with("grpc-")
209                    && !key.as_str().starts_with("te")
210                    && !key.as_str().starts_with("content-type")
211                {
212                    response.metadata_mut().insert(key.clone(), value.clone());
213                }
214            }
215        }
216
217        // Add mock-specific metadata
218        response
219            .metadata_mut()
220            .insert("x-mockforge-service", service_name.parse().unwrap());
221        response
222            .metadata_mut()
223            .insert("x-mockforge-method", method_name.parse().unwrap());
224        response
225            .metadata_mut()
226            .insert("x-mockforge-timestamp", chrono::Utc::now().to_rfc3339().parse().unwrap());
227
228        Ok(response)
229    }
230
231    /// Generate a mock response for a given service and method
232    async fn generate_mock_response(
233        &self,
234        service_name: &str,
235        method_name: &str,
236        method_descriptor: &prost_reflect::MethodDescriptor,
237    ) -> Result<DynamicMessage, Status> {
238        debug!("Generating mock response for {}.{}", service_name, method_name);
239
240        // Get the output message descriptor
241        let output_descriptor = method_descriptor.output();
242
243        // Create a new dynamic message with the output descriptor
244        let mut response = DynamicMessage::new(output_descriptor.clone());
245
246        // Generate mock data dynamically based on the proto structure
247        self.populate_dynamic_mock_response(
248            &mut response,
249            service_name,
250            method_name,
251            &output_descriptor,
252        )?;
253
254        Ok(response)
255    }
256
257    /// Populate a dynamic mock response based on the proto structure
258    fn populate_dynamic_mock_response(
259        &self,
260        response: &mut DynamicMessage,
261        service_name: &str,
262        method_name: &str,
263        output_descriptor: &prost_reflect::MessageDescriptor,
264    ) -> Result<(), Status> {
265        debug!("Generating dynamic mock response for {}.{}", service_name, method_name);
266
267        // Get all fields from the output message descriptor
268        for field in output_descriptor.fields() {
269            let field_name = field.name();
270            let field_type = field.kind();
271
272            debug!("Processing field: {} of type: {:?}", field_name, field_type);
273
274            // Generate appropriate mock values based on field type
275            let mock_value = self.generate_mock_value_for_field(&field, service_name, method_name);
276
277            // Try to set the field (ignore errors if field doesn't exist or is wrong type)
278            response.set_field(&field, mock_value);
279        }
280
281        // Always try to add some common metadata fields if they don't exist
282        let metadata_fields = vec![
283            ("mockforge_service", prost_reflect::Value::String(service_name.to_string())),
284            ("mockforge_method", prost_reflect::Value::String(method_name.to_string())),
285            (
286                "mockforge_timestamp",
287                prost_reflect::Value::String(chrono::Utc::now().to_rfc3339()),
288            ),
289            (
290                "mockforge_source",
291                prost_reflect::Value::String("MockForge Reflection Proxy".to_string()),
292            ),
293        ];
294
295        for (field_name, value) in metadata_fields {
296            response.set_field_by_name(field_name, value);
297        }
298
299        Ok(())
300    }
301
302    /// Generate a mock value for a specific field based on its type
303    fn generate_mock_value_for_field(
304        &self,
305        field: &prost_reflect::FieldDescriptor,
306        service_name: &str,
307        method_name: &str,
308    ) -> prost_reflect::Value {
309        self.generate_mock_value_for_field_with_depth(field, service_name, method_name, 0)
310    }
311
312    /// Generate a mock value for a specific field with recursion depth limit
313    fn generate_mock_value_for_field_with_depth(
314        &self,
315        field: &prost_reflect::FieldDescriptor,
316        service_name: &str,
317        method_name: &str,
318        depth: usize,
319    ) -> prost_reflect::Value {
320        // Prevent infinite recursion with a reasonable depth limit
321        const MAX_DEPTH: usize = 5;
322        if depth >= MAX_DEPTH {
323            return prost_reflect::Value::String(format!("max_depth_reached_{}", field.name()));
324        }
325
326        // Handle repeated fields (arrays)
327        if field.is_list() {
328            let mut list_values = Vec::new();
329            // Generate 1-3 mock values for the list
330            let field_name_lower = field.name().to_lowercase();
331            let num_items =
332                if field_name_lower.contains("list") || field_name_lower.contains("items") {
333                    3
334                } else {
335                    1
336                };
337
338            for _ in 0..num_items {
339                let item_value =
340                    self.generate_single_field_value(field, service_name, method_name, depth);
341                list_values.push(item_value);
342            }
343
344            return prost_reflect::Value::List(list_values);
345        }
346
347        self.generate_single_field_value(field, service_name, method_name, depth)
348    }
349
350    /// Generate a mock value for a single (non-repeated) field
351    fn generate_single_field_value(
352        &self,
353        field: &prost_reflect::FieldDescriptor,
354        service_name: &str,
355        method_name: &str,
356        depth: usize,
357    ) -> prost_reflect::Value {
358        let field_name = field.name().to_lowercase();
359        let field_type = field.kind();
360
361        // Generate contextual mock data based on field name patterns
362        if field_name.contains("message")
363            || field_name.contains("text")
364            || field_name.contains("content")
365        {
366            return prost_reflect::Value::String(format!(
367                "Mock response from {} for method {} at {}",
368                service_name,
369                method_name,
370                chrono::Utc::now().format("%Y-%m-%d %H:%M:%S UTC")
371            ));
372        }
373
374        if field_name.contains("id") {
375            return prost_reflect::Value::String(format!(
376                "mock_{}",
377                chrono::Utc::now().timestamp()
378            ));
379        }
380
381        if field_name.contains("status") || field_name.contains("state") {
382            return prost_reflect::Value::String("success".to_string());
383        }
384
385        if field_name.contains("count") || field_name.contains("number") {
386            return prost_reflect::Value::I64(42);
387        }
388
389        if field_name.contains("timestamp") || field_name.contains("time") {
390            return prost_reflect::Value::String(chrono::Utc::now().to_rfc3339());
391        }
392
393        if field_name.contains("enabled") || field_name.contains("active") {
394            return prost_reflect::Value::Bool(true);
395        }
396
397        // Default mock values based on field type
398        match field_type {
399            prost_reflect::Kind::String => {
400                prost_reflect::Value::String(format!("mock_{}_{}", service_name, method_name))
401            }
402            prost_reflect::Kind::Int32 => prost_reflect::Value::I32(42),
403            prost_reflect::Kind::Int64 => prost_reflect::Value::I64(42),
404            prost_reflect::Kind::Float => prost_reflect::Value::F32(std::f32::consts::PI),
405            prost_reflect::Kind::Double => prost_reflect::Value::F64(std::f64::consts::PI),
406            prost_reflect::Kind::Bool => prost_reflect::Value::Bool(true),
407            prost_reflect::Kind::Bytes => prost_reflect::Value::Bytes(b"mock_data".to_vec().into()),
408            prost_reflect::Kind::Enum(enum_descriptor) => {
409                // Try to get the first enum value, or use a default
410                if let Some(first_value) = enum_descriptor.values().next() {
411                    // Use the first enum value as the default
412                    prost_reflect::Value::EnumNumber(first_value.number())
413                } else {
414                    // Fallback if no enum values are defined
415                    prost_reflect::Value::EnumNumber(0)
416                }
417            }
418            prost_reflect::Kind::Message(message_descriptor) => {
419                // Recursively generate a mock message for nested types
420                let mut nested_message = DynamicMessage::new(message_descriptor.clone());
421
422                // Populate the nested message with mock values
423                for nested_field in message_descriptor.fields() {
424                    let mock_value = self.generate_mock_value_for_field_with_depth(
425                        &nested_field,
426                        service_name,
427                        method_name,
428                        depth + 1,
429                    );
430                    nested_message.set_field(&nested_field, mock_value);
431                }
432
433                prost_reflect::Value::Message(nested_message)
434            }
435            _ => prost_reflect::Value::String("mock_value".to_string()),
436        }
437    }
438
439    /// Implementation for forwarding server streaming requests
440    async fn forward_server_streaming_impl(
441        &self,
442        method: prost_reflect::MethodDescriptor,
443        request: Request<DynamicMessage>,
444    ) -> Result<Response<Pin<Box<dyn Stream<Item = Result<DynamicMessage, Status>> + Send>>>, Status>
445    {
446        // Extract metadata from the original request
447        let metadata = request.metadata();
448        debug!(
449            "Forwarding server streaming request for method: {} with {} metadata entries",
450            method.name(),
451            metadata.len()
452        );
453
454        #[cfg(feature = "data-faker")]
455        {
456            // Generate mock streaming responses
457            let output_descriptor = method.output();
458            let messages = self.generate_mock_stream_messages(&output_descriptor, 5).await?;
459
460            // Create a proper streaming response using ReceiverStream
461            let (tx, rx) = mpsc::channel(32);
462            let stream = Box::pin(ReceiverStream::new(rx))
463                as Pin<Box<dyn Stream<Item = Result<DynamicMessage, Status>> + Send>>;
464
465            // Spawn a task to send messages
466            tokio::spawn(async move {
467                for message in messages {
468                    if tx.send(Ok(message)).await.is_err() {
469                        break;
470                    }
471                }
472            });
473
474            // Preserve original request metadata in the response
475            let mut response = Response::new(stream);
476
477            // Copy relevant metadata from the original request to the response (ASCII only for simplicity)
478            for entry in metadata.iter() {
479                if let tonic::metadata::KeyAndValueRef::Ascii(key, value) = entry {
480                    // Only preserve certain metadata keys, avoiding system headers
481                    if !key.as_str().starts_with(':')
482                        && !key.as_str().starts_with("grpc-")
483                        && !key.as_str().starts_with("te")
484                        && !key.as_str().starts_with("content-type")
485                    {
486                        response.metadata_mut().insert(key.clone(), value.clone());
487                    }
488                }
489            }
490
491            // Add mock-specific metadata
492            response
493                .metadata_mut()
494                .insert("x-mockforge-service", method.parent_service().name().parse().unwrap());
495            response
496                .metadata_mut()
497                .insert("x-mockforge-method", method.name().parse().unwrap());
498            response
499                .metadata_mut()
500                .insert("x-mockforge-timestamp", chrono::Utc::now().to_rfc3339().parse().unwrap());
501            response.metadata_mut().insert("x-mockforge-stream-count", "5".parse().unwrap());
502
503            debug!("Generated server streaming response with {} messages", 5);
504            Ok(response)
505        }
506
507        #[cfg(not(feature = "data-faker"))]
508        {
509            debug!("Data faker feature not enabled, using built-in mock stream generation");
510
511            let service_name = method.parent_service().name().to_string();
512            let method_name = method.name().to_string();
513
514            let (tx, rx) = mpsc::channel(32);
515            let stream = Box::pin(ReceiverStream::new(rx))
516                as Pin<Box<dyn Stream<Item = Result<DynamicMessage, Status>> + Send>>;
517
518            let proxy = self;
519            let method_for_task = method.clone();
520            tokio::spawn(async move {
521                for _ in 0..5 {
522                    let message_result = proxy
523                        .generate_mock_response(&service_name, &method_name, &method_for_task)
524                        .await;
525                    if tx.send(message_result).await.is_err() {
526                        break;
527                    }
528                }
529            });
530
531            let mut response = Response::new(stream);
532            for entry in metadata.iter() {
533                if let tonic::metadata::KeyAndValueRef::Ascii(key, value) = entry {
534                    if !key.as_str().starts_with(':')
535                        && !key.as_str().starts_with("grpc-")
536                        && !key.as_str().starts_with("te")
537                        && !key.as_str().starts_with("content-type")
538                    {
539                        response.metadata_mut().insert(key.clone(), value.clone());
540                    }
541                }
542            }
543
544            response
545                .metadata_mut()
546                .insert("x-mockforge-service", method.parent_service().name().parse().unwrap());
547            response
548                .metadata_mut()
549                .insert("x-mockforge-method", method.name().parse().unwrap());
550            response
551                .metadata_mut()
552                .insert("x-mockforge-timestamp", chrono::Utc::now().to_rfc3339().parse().unwrap());
553            response.metadata_mut().insert("x-mockforge-stream-count", "5".parse().unwrap());
554
555            Ok(response)
556        }
557    }
558
559    /// Implementation for forwarding client streaming requests
560    async fn forward_client_streaming_impl(
561        &self,
562        method: prost_reflect::MethodDescriptor,
563        request: Request<Streaming<DynamicMessage>>,
564    ) -> Result<Response<DynamicMessage>, Status> {
565        debug!("Forwarding client streaming request for method: {}", method.name());
566
567        #[cfg(feature = "data-faker")]
568        {
569            // Extract metadata from the original request before consuming it
570            let request_metadata = request.metadata().clone();
571
572            // Process the streaming request and extract message data
573            let mut stream = request.into_inner();
574            let mut message_count = 0;
575            let mut processed_names = Vec::new();
576            let mut user_ids = Vec::new();
577            let mut all_tags = Vec::new();
578
579            while let Some(message_result) = stream.next().await {
580                match message_result {
581                    Ok(message) => {
582                        message_count += 1;
583                        debug!(
584                            "Processing client streaming message {} for method: {}",
585                            message_count,
586                            method.name()
587                        );
588
589                        // Extract data from the HelloRequest message
590                        let input_descriptor = method.input();
591
592                        // Extract the 'name' field
593                        if let Some(name_field) = input_descriptor.get_field_by_name("name") {
594                            let field_value = message.get_field(&name_field);
595                            if let prost_reflect::Value::String(name) = field_value.into_owned() {
596                                processed_names.push(name.clone());
597                                debug!("  - Name: {}", name);
598                            }
599                        }
600
601                        // Extract the 'user_info' field (nested message)
602                        if let Some(user_info_field) =
603                            input_descriptor.get_field_by_name("user_info")
604                        {
605                            let field_value = message.get_field(&user_info_field);
606                            if let prost_reflect::Value::Message(user_info_msg) =
607                                field_value.into_owned()
608                            {
609                                // Extract user_id from user_info
610                                if let Some(user_id_field) =
611                                    user_info_msg.descriptor().get_field_by_name("user_id")
612                                {
613                                    let user_id_value = user_info_msg.get_field(&user_id_field);
614                                    if let prost_reflect::Value::String(user_id) =
615                                        user_id_value.into_owned()
616                                    {
617                                        user_ids.push(user_id.clone());
618                                        debug!("  - User ID: {}", user_id);
619                                    }
620                                }
621                            }
622                        }
623
624                        // Extract the 'tags' field (repeated string)
625                        if let Some(tags_field) = input_descriptor.get_field_by_name("tags") {
626                            let field_value = message.get_field(&tags_field);
627                            if let prost_reflect::Value::List(tags_list) = field_value.into_owned()
628                            {
629                                for tag_value in tags_list {
630                                    if let prost_reflect::Value::String(tag) = tag_value {
631                                        all_tags.push(tag.clone());
632                                        debug!("  - Tag: {}", tag);
633                                    }
634                                }
635                            }
636                        }
637                    }
638                    Err(e) => {
639                        warn!("Error receiving client streaming message: {}", e);
640                        return Err(Status::internal(format!(
641                            "Error processing streaming request: {}",
642                            e
643                        )));
644                    }
645                }
646            }
647
648            debug!("Processed {} messages in client streaming request", message_count);
649            debug!(
650                "Collected data - Names: {:?}, User IDs: {:?}, Tags: {:?}",
651                processed_names, user_ids, all_tags
652            );
653
654            // Generate a mock response based on the output descriptor, but enhance it with processed data
655            let output_descriptor = method.output();
656            let mut mock_response = self.generate_mock_message(&output_descriptor).await?;
657
658            // Enhance the response message with aggregated data from the stream
659            if let Some(message_field) = output_descriptor.get_field_by_name("message") {
660                // Create a personalized message based on the processed data
661                let personalized_message = if !processed_names.is_empty() {
662                    format!("Hello to all {} senders! Processed names: {}, with {} unique tags from {} users",
663                           message_count, processed_names.join(", "), all_tags.len(), user_ids.len())
664                } else {
665                    format!(
666                        "Hello! Processed {} messages with {} tags",
667                        message_count,
668                        all_tags.len()
669                    )
670                };
671
672                // Update the message field in the response
673                mock_response
674                    .set_field(&message_field, prost_reflect::Value::String(personalized_message));
675            }
676
677            // Preserve original request metadata in the response
678            let mut response = Response::new(mock_response);
679
680            // Copy relevant metadata from the original request to the response (ASCII only for simplicity)
681            for entry in request_metadata.iter() {
682                if let tonic::metadata::KeyAndValueRef::Ascii(key, value) = entry {
683                    // Only preserve certain metadata keys, avoiding system headers
684                    if !key.as_str().starts_with(':')
685                        && !key.as_str().starts_with("grpc-")
686                        && !key.as_str().starts_with("te")
687                        && !key.as_str().starts_with("content-type")
688                    {
689                        response.metadata_mut().insert(key.clone(), value.clone());
690                    }
691                }
692            }
693
694            // Add mock-specific metadata
695            response
696                .metadata_mut()
697                .insert("x-mockforge-service", method.parent_service().name().parse().unwrap());
698            response
699                .metadata_mut()
700                .insert("x-mockforge-method", method.name().parse().unwrap());
701            response
702                .metadata_mut()
703                .insert("x-mockforge-timestamp", chrono::Utc::now().to_rfc3339().parse().unwrap());
704            response
705                .metadata_mut()
706                .insert("x-mockforge-message-count", message_count.to_string().parse().unwrap());
707
708            let response = response;
709
710            debug!(
711                "Generated enhanced client streaming response with {} processed messages",
712                message_count
713            );
714            Ok(response)
715        }
716
717        #[cfg(not(feature = "data-faker"))]
718        {
719            debug!("Data faker feature not enabled, using built-in mock client-stream response");
720
721            let request_metadata = request.metadata().clone();
722            let mut stream = request.into_inner();
723            let mut message_count = 0usize;
724
725            while let Some(message_result) = stream.next().await {
726                match message_result {
727                    Ok(_) => {
728                        message_count += 1;
729                    }
730                    Err(e) => {
731                        warn!("Error receiving client streaming message: {}", e);
732                        return Err(Status::internal(format!(
733                            "Error processing streaming request: {}",
734                            e
735                        )));
736                    }
737                }
738            }
739
740            let service_name = method.parent_service().name().to_string();
741            let method_name = method.name().to_string();
742            let mock_response =
743                self.generate_mock_response(&service_name, &method_name, &method).await?;
744            let mut response = Response::new(mock_response);
745
746            for entry in request_metadata.iter() {
747                if let tonic::metadata::KeyAndValueRef::Ascii(key, value) = entry {
748                    if !key.as_str().starts_with(':')
749                        && !key.as_str().starts_with("grpc-")
750                        && !key.as_str().starts_with("te")
751                        && !key.as_str().starts_with("content-type")
752                    {
753                        response.metadata_mut().insert(key.clone(), value.clone());
754                    }
755                }
756            }
757
758            response
759                .metadata_mut()
760                .insert("x-mockforge-service", method.parent_service().name().parse().unwrap());
761            response
762                .metadata_mut()
763                .insert("x-mockforge-method", method.name().parse().unwrap());
764            response
765                .metadata_mut()
766                .insert("x-mockforge-timestamp", chrono::Utc::now().to_rfc3339().parse().unwrap());
767            response
768                .metadata_mut()
769                .insert("x-mockforge-message-count", message_count.to_string().parse().unwrap());
770
771            Ok(response)
772        }
773    }
774
775    /// Implementation for forwarding bidirectional streaming requests
776    async fn forward_bidirectional_streaming_impl(
777        &self,
778        method: prost_reflect::MethodDescriptor,
779        request: Request<Streaming<DynamicMessage>>,
780    ) -> Result<Response<Pin<Box<dyn Stream<Item = Result<DynamicMessage, Status>> + Send>>>, Status>
781    {
782        debug!("Forwarding bidirectional streaming request for method: {}", method.name());
783
784        #[cfg(feature = "data-faker")]
785        {
786            // Extract metadata from the original request before consuming it
787            let metadata = request.metadata();
788            debug!("Forwarding bidirectional streaming request for method: {} with {} metadata entries",
789                   method.name(), metadata.len());
790
791            // Generate mock bidirectional streaming responses
792            let output_descriptor = method.output();
793            let messages = self.generate_mock_stream_messages(&output_descriptor, 10).await?;
794
795            // Create streaming response using ReceiverStream
796            let (tx, rx) = mpsc::channel(32);
797            let stream = Box::pin(ReceiverStream::new(rx))
798                as Pin<Box<dyn Stream<Item = Result<DynamicMessage, Status>> + Send>>;
799
800            // Spawn a task to send messages
801            tokio::spawn(async move {
802                for message in messages {
803                    if tx.send(Ok(message)).await.is_err() {
804                        break;
805                    }
806                }
807            });
808
809            // Preserve original request metadata in the response
810            let mut response = Response::new(stream);
811
812            // Copy relevant metadata from the original request to the response (ASCII only for simplicity)
813            for entry in metadata.iter() {
814                if let tonic::metadata::KeyAndValueRef::Ascii(key, value) = entry {
815                    // Only preserve certain metadata keys, avoiding system headers
816                    if !key.as_str().starts_with(':')
817                        && !key.as_str().starts_with("grpc-")
818                        && !key.as_str().starts_with("te")
819                        && !key.as_str().starts_with("content-type")
820                    {
821                        response.metadata_mut().insert(key.clone(), value.clone());
822                    }
823                }
824            }
825
826            // Add mock-specific metadata
827            response
828                .metadata_mut()
829                .insert("x-mockforge-service", method.parent_service().name().parse().unwrap());
830            response
831                .metadata_mut()
832                .insert("x-mockforge-method", method.name().parse().unwrap());
833            response
834                .metadata_mut()
835                .insert("x-mockforge-timestamp", chrono::Utc::now().to_rfc3339().parse().unwrap());
836            response
837                .metadata_mut()
838                .insert("x-mockforge-stream-count", "10".parse().unwrap());
839
840            // Process incoming stream concurrently
841            let mut incoming_stream = request.into_inner();
842            tokio::spawn(async move {
843                let mut count = 0;
844                while let Some(message_result) = incoming_stream.next().await {
845                    match message_result {
846                        Ok(_) => {
847                            count += 1;
848                            debug!(
849                                "Processed bidirectional message {} for method: {}",
850                                count,
851                                method.name()
852                            );
853                        }
854                        Err(e) => {
855                            warn!("Error processing bidirectional message: {}", e);
856                            break;
857                        }
858                    }
859                }
860                debug!("Finished processing {} bidirectional messages", count);
861            });
862
863            debug!("Generated bidirectional streaming response with {} messages", 10);
864            Ok(response)
865        }
866
867        #[cfg(not(feature = "data-faker"))]
868        {
869            debug!(
870                "Data faker feature not enabled, using built-in mock bidirectional stream generation"
871            );
872
873            let metadata = request.metadata().clone();
874            let service_name = method.parent_service().name().to_string();
875            let method_name = method.name().to_string();
876            let method_for_task = method.clone();
877
878            let (tx, rx) = mpsc::channel(32);
879            let stream = Box::pin(ReceiverStream::new(rx))
880                as Pin<Box<dyn Stream<Item = Result<DynamicMessage, Status>> + Send>>;
881
882            let proxy = self;
883            tokio::spawn(async move {
884                for _ in 0..10 {
885                    let message_result = proxy
886                        .generate_mock_response(&service_name, &method_name, &method_for_task)
887                        .await;
888                    if tx.send(message_result).await.is_err() {
889                        break;
890                    }
891                }
892            });
893
894            let mut response = Response::new(stream);
895            for entry in metadata.iter() {
896                if let tonic::metadata::KeyAndValueRef::Ascii(key, value) = entry {
897                    if !key.as_str().starts_with(':')
898                        && !key.as_str().starts_with("grpc-")
899                        && !key.as_str().starts_with("te")
900                        && !key.as_str().starts_with("content-type")
901                    {
902                        response.metadata_mut().insert(key.clone(), value.clone());
903                    }
904                }
905            }
906
907            response
908                .metadata_mut()
909                .insert("x-mockforge-service", method.parent_service().name().parse().unwrap());
910            response
911                .metadata_mut()
912                .insert("x-mockforge-method", method.name().parse().unwrap());
913            response
914                .metadata_mut()
915                .insert("x-mockforge-timestamp", chrono::Utc::now().to_rfc3339().parse().unwrap());
916            response
917                .metadata_mut()
918                .insert("x-mockforge-stream-count", "10".parse().unwrap());
919
920            let mut incoming_stream = request.into_inner();
921            let method_name_for_log = method.name().to_string();
922            tokio::spawn(async move {
923                let mut count = 0usize;
924                while let Some(message_result) = incoming_stream.next().await {
925                    match message_result {
926                        Ok(_) => count += 1,
927                        Err(e) => {
928                            warn!("Error processing bidirectional message: {}", e);
929                            break;
930                        }
931                    }
932                }
933                debug!(
934                    "Finished processing {} bidirectional request messages for method {}",
935                    count, method_name_for_log
936                );
937            });
938
939            Ok(response)
940        }
941    }
942
943    /// Generate a single mock message for the given descriptor
944    #[cfg(feature = "data-faker")]
945    async fn generate_mock_message(
946        &self,
947        descriptor: &prost_reflect::MessageDescriptor,
948    ) -> Result<DynamicMessage, Status> {
949        // Create a basic schema from the descriptor for mock generation
950        let schema_def = self.create_schema_from_protobuf_descriptor(descriptor);
951
952        let config = DataConfig {
953            rows: 1,
954            ..Default::default()
955        };
956
957        let mut generator = DataGenerator::new(schema_def, config)
958            .map_err(|e| Status::internal(format!("Failed to create data generator: {}", e)))?;
959
960        let result = generator
961            .generate()
962            .await
963            .map_err(|e| Status::internal(format!("Failed to generate mock data: {}", e)))?;
964
965        if let Some(data) = result.data.first() {
966            // Convert the generated JSON to a DynamicMessage
967            self.json_to_dynamic_message(descriptor, data)
968        } else {
969            Err(Status::internal("No mock data generated"))
970        }
971    }
972
973    /// Generate multiple mock messages for streaming
974    #[cfg(feature = "data-faker")]
975    async fn generate_mock_stream_messages(
976        &self,
977        descriptor: &prost_reflect::MessageDescriptor,
978        count: usize,
979    ) -> Result<Vec<DynamicMessage>, Status> {
980        let schema_def = self.create_schema_from_protobuf_descriptor(descriptor);
981
982        let config = DataConfig {
983            rows: count,
984            ..Default::default()
985        };
986
987        let mut generator = DataGenerator::new(schema_def, config)
988            .map_err(|e| Status::internal(format!("Failed to create data generator: {}", e)))?;
989
990        let result = generator
991            .generate()
992            .await
993            .map_err(|e| Status::internal(format!("Failed to generate mock data: {}", e)))?;
994
995        result
996            .data
997            .iter()
998            .map(|data| self.json_to_dynamic_message(descriptor, data))
999            .collect()
1000    }
1001
1002    /// Convert JSON data to a DynamicMessage
1003    #[cfg(feature = "data-faker")]
1004    fn json_to_dynamic_message(
1005        &self,
1006        descriptor: &prost_reflect::MessageDescriptor,
1007        json_data: &serde_json::Value,
1008    ) -> Result<DynamicMessage, Status> {
1009        let mut message = DynamicMessage::new(descriptor.clone());
1010
1011        if let serde_json::Value::Object(obj) = json_data {
1012            for (key, value) in obj {
1013                if let Some(field) = descriptor.get_field_by_name(key) {
1014                    let field_value = self.convert_json_value_to_protobuf_value(&field, value)?;
1015                    message.set_field(&field, field_value);
1016                }
1017            }
1018        }
1019
1020        Ok(message)
1021    }
1022
1023    /// Convert a JSON value to a protobuf Value based on the field descriptor
1024    #[cfg(feature = "data-faker")]
1025    fn convert_json_value_to_protobuf_value(
1026        &self,
1027        field: &prost_reflect::FieldDescriptor,
1028        json_value: &serde_json::Value,
1029    ) -> Result<prost_reflect::Value, Status> {
1030        use prost_reflect::Kind;
1031
1032        match json_value {
1033            serde_json::Value::Null => {
1034                // Return default value for the field type
1035                match field.kind() {
1036                    Kind::Message(message_descriptor) => Ok(prost_reflect::Value::Message(
1037                        DynamicMessage::new(message_descriptor.clone()),
1038                    )),
1039                    Kind::Enum(enum_descriptor) => {
1040                        // Try to get the first enum value, or use 0 as default
1041                        if let Some(first_value) = enum_descriptor.values().next() {
1042                            Ok(prost_reflect::Value::EnumNumber(first_value.number()))
1043                        } else {
1044                            Ok(prost_reflect::Value::EnumNumber(0))
1045                        }
1046                    }
1047                    Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => Ok(prost_reflect::Value::I32(0)),
1048                    Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => Ok(prost_reflect::Value::I64(0)),
1049                    Kind::Uint32 | Kind::Fixed32 => Ok(prost_reflect::Value::U32(0)),
1050                    Kind::Uint64 | Kind::Fixed64 => Ok(prost_reflect::Value::U64(0)),
1051                    Kind::Float => Ok(prost_reflect::Value::F32(0.0)),
1052                    Kind::Double => Ok(prost_reflect::Value::F64(0.0)),
1053                    Kind::Bool => Ok(prost_reflect::Value::Bool(false)),
1054                    Kind::String => Ok(prost_reflect::Value::String(String::new())),
1055                    Kind::Bytes => Ok(prost_reflect::Value::Bytes(b"".to_vec().into())),
1056                }
1057            }
1058            serde_json::Value::Bool(b) => Ok(prost_reflect::Value::Bool(*b)),
1059            serde_json::Value::Number(n) => {
1060                match field.kind() {
1061                    Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => {
1062                        if let Some(i) = n.as_i64() {
1063                            Ok(prost_reflect::Value::I32(i as i32))
1064                        } else {
1065                            Err(Status::invalid_argument(format!(
1066                                "Cannot convert number {} to int32",
1067                                n
1068                            )))
1069                        }
1070                    }
1071                    Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => {
1072                        if let Some(i) = n.as_i64() {
1073                            Ok(prost_reflect::Value::I64(i))
1074                        } else {
1075                            Err(Status::invalid_argument(format!(
1076                                "Cannot convert number {} to int64",
1077                                n
1078                            )))
1079                        }
1080                    }
1081                    Kind::Uint32 | Kind::Fixed32 => {
1082                        if let Some(i) = n.as_u64() {
1083                            Ok(prost_reflect::Value::U32(i as u32))
1084                        } else {
1085                            Err(Status::invalid_argument(format!(
1086                                "Cannot convert number {} to uint32",
1087                                n
1088                            )))
1089                        }
1090                    }
1091                    Kind::Uint64 | Kind::Fixed64 => {
1092                        if let Some(i) = n.as_u64() {
1093                            Ok(prost_reflect::Value::U64(i))
1094                        } else {
1095                            Err(Status::invalid_argument(format!(
1096                                "Cannot convert number {} to uint64",
1097                                n
1098                            )))
1099                        }
1100                    }
1101                    Kind::Float => {
1102                        if let Some(f) = n.as_f64() {
1103                            Ok(prost_reflect::Value::F32(f as f32))
1104                        } else {
1105                            Err(Status::invalid_argument(format!(
1106                                "Cannot convert number {} to float",
1107                                n
1108                            )))
1109                        }
1110                    }
1111                    Kind::Double => {
1112                        if let Some(f) = n.as_f64() {
1113                            Ok(prost_reflect::Value::F64(f))
1114                        } else {
1115                            Err(Status::invalid_argument(format!(
1116                                "Cannot convert number {} to double",
1117                                n
1118                            )))
1119                        }
1120                    }
1121                    _ => {
1122                        // Fallback to int64 for unknown numeric types
1123                        if let Some(i) = n.as_i64() {
1124                            Ok(prost_reflect::Value::I64(i))
1125                        } else {
1126                            Err(Status::invalid_argument(format!(
1127                                "Cannot convert number {} to numeric type",
1128                                n
1129                            )))
1130                        }
1131                    }
1132                }
1133            }
1134            serde_json::Value::String(s) => {
1135                match field.kind() {
1136                    Kind::String => Ok(prost_reflect::Value::String(s.clone())),
1137                    Kind::Bytes => Ok(prost_reflect::Value::Bytes(s.as_bytes().to_vec().into())),
1138                    Kind::Enum(enum_descriptor) => {
1139                        // Try to convert string to enum value
1140                        if let Some(enum_value) = enum_descriptor.get_value_by_name(s) {
1141                            Ok(prost_reflect::Value::EnumNumber(enum_value.number()))
1142                        } else {
1143                            // Try to parse as number
1144                            if let Ok(num) = s.parse::<i32>() {
1145                                Ok(prost_reflect::Value::EnumNumber(num))
1146                            } else {
1147                                warn!(
1148                                    "Unknown enum value '{}' for field '{}', using default",
1149                                    s,
1150                                    field.name()
1151                                );
1152                                Ok(prost_reflect::Value::EnumNumber(0))
1153                            }
1154                        }
1155                    }
1156                    _ => {
1157                        // For other types, treat string as string
1158                        Ok(prost_reflect::Value::String(s.clone()))
1159                    }
1160                }
1161            }
1162            serde_json::Value::Array(arr) => {
1163                let mut list_values = Vec::new();
1164
1165                for item in arr {
1166                    let item_value = self.convert_json_value_to_protobuf_value(field, item)?;
1167                    list_values.push(item_value);
1168                }
1169
1170                Ok(prost_reflect::Value::List(list_values))
1171            }
1172            serde_json::Value::Object(_obj) => match field.kind() {
1173                Kind::Message(message_descriptor) => self
1174                    .json_to_dynamic_message(&message_descriptor, json_value)
1175                    .map(prost_reflect::Value::Message),
1176                _ => Err(Status::invalid_argument(format!(
1177                    "Cannot convert object to field {} of type {:?}",
1178                    field.name(),
1179                    field.kind()
1180                ))),
1181            },
1182        }
1183    }
1184
1185    /// Create a basic schema definition from a protobuf message descriptor
1186    #[cfg(feature = "data-faker")]
1187    fn create_schema_from_protobuf_descriptor(
1188        &self,
1189        descriptor: &prost_reflect::MessageDescriptor,
1190    ) -> SchemaDefinition {
1191        use mockforge_data::schema::FieldDefinition;
1192
1193        let mut schema = SchemaDefinition::new(descriptor.name().to_string());
1194
1195        for field in descriptor.fields() {
1196            let field_name = field.name().to_string();
1197            let field_type = match field.kind() {
1198                prost_reflect::Kind::Message(_) => {
1199                    // For nested messages, use a generic object type
1200                    "object".to_string()
1201                }
1202                prost_reflect::Kind::Enum(_) => "string".to_string(),
1203                prost_reflect::Kind::Bool => "boolean".to_string(),
1204                prost_reflect::Kind::Int32
1205                | prost_reflect::Kind::Sint32
1206                | prost_reflect::Kind::Sfixed32
1207                | prost_reflect::Kind::Uint32
1208                | prost_reflect::Kind::Fixed32
1209                | prost_reflect::Kind::Int64
1210                | prost_reflect::Kind::Sint64
1211                | prost_reflect::Kind::Sfixed64
1212                | prost_reflect::Kind::Uint64
1213                | prost_reflect::Kind::Fixed64 => "integer".to_string(),
1214                prost_reflect::Kind::Float | prost_reflect::Kind::Double => "number".to_string(),
1215                prost_reflect::Kind::String => "string".to_string(),
1216                prost_reflect::Kind::Bytes => "string".to_string(),
1217            };
1218
1219            let mut field_def = FieldDefinition::new(field_name, field_type);
1220
1221            // Check if field is optional based on protobuf field properties
1222            // In proto3, all non-repeated fields are effectively optional
1223            // In proto2, only explicitly optional or required fields exist
1224            if field.supports_presence() && !field.is_list() {
1225                // Field supports presence detection and is not repeated, so it's optional
1226                field_def = field_def.optional();
1227            }
1228
1229            schema = schema.with_field(field_def);
1230        }
1231
1232        schema
1233    }
1234}
1235
#[cfg(test)]
mod tests {
    /// Compile-only smoke test: the body is intentionally empty, so it passes whenever
    /// this module's types and imports build under the test configuration.
    #[test]
    fn test_module_compiles() {}
}