// mockforge_grpc/dynamic/service_generator.rs

1//! Dynamic gRPC service generation
2//!
3//! This module generates actual gRPC service implementations from parsed proto definitions.
4
5use crate::dynamic::proto_parser::{ProtoMethod, ProtoParser, ProtoService};
6use crate::reflection::smart_mock_generator::{SmartMockConfig, SmartMockGenerator};
7use mockforge_core::latency::LatencyInjector;
8use prost_reflect::DescriptorPool;
9use prost_types::Any;
10use std::collections::HashMap;
11use std::sync::{Arc, Mutex};
12use std::time::Duration;
13use tokio::sync::mpsc;
14use tokio_stream::wrappers::ReceiverStream;
15use tonic::{Request, Response, Status, Streaming};
16use tracing::{debug, info, warn};
17
18/// Service factory for creating enhanced gRPC services from proto files
19pub struct EnhancedServiceFactory;
20
21impl EnhancedServiceFactory {
22    /// Create services from a proto directory with enhanced capabilities
23    pub async fn create_services_from_proto_dir(
24        proto_dir: &str,
25        latency_injector: Option<LatencyInjector>,
26        smart_config: SmartMockConfig,
27    ) -> Result<Vec<DynamicGrpcService>, Box<dyn std::error::Error + Send + Sync>> {
28        info!("Creating enhanced services from proto directory: {}", proto_dir);
29
30        // Parse proto files with full protoc support
31        let mut parser = ProtoParser::new();
32        parser.parse_directory(proto_dir).await?;
33
34        let mut services = Vec::new();
35
36        // Store services info before consuming parser
37        let services_info: Vec<(String, ProtoService)> = parser
38            .services()
39            .iter()
40            .map(|(name, service)| (name.clone(), service.clone()))
41            .collect();
42
43        // Create enhanced services for each parsed service
44        for (service_name, proto_service) in services_info {
45            debug!("Creating enhanced service: {}", service_name);
46
47            // Create a new parser instance for each service (we'll improve this later)
48            let mut service_parser = ProtoParser::new();
49            let _ = service_parser.parse_directory(proto_dir).await; // Re-parse for now
50
51            let service = DynamicGrpcService::new_enhanced(
52                proto_service,
53                latency_injector.clone(),
54                Some(service_parser),
55                smart_config.clone(),
56            );
57
58            services.push(service);
59        }
60
61        info!("Created {} enhanced services", services.len());
62        Ok(services)
63    }
64
65    /// Create a single service from proto service definition
66    pub fn create_service_from_proto(
67        proto_service: ProtoService,
68        latency_injector: Option<LatencyInjector>,
69        proto_parser: Option<ProtoParser>,
70        smart_config: SmartMockConfig,
71    ) -> DynamicGrpcService {
72        if proto_parser.is_some() {
73            info!("Creating enhanced service: {}", proto_service.name);
74            DynamicGrpcService::new_enhanced(
75                proto_service,
76                latency_injector,
77                proto_parser,
78                smart_config,
79            )
80        } else {
81            info!("Creating basic service: {}", proto_service.name);
82            DynamicGrpcService::new(proto_service, latency_injector)
83        }
84    }
85}
86
/// A dynamically generated gRPC service
///
/// Built from a parsed [`ProtoService`] definition; serves pre-computed mock
/// responses for unary and streaming calls.
pub struct DynamicGrpcService {
    /// The service definition
    service: ProtoService,
    /// Latency injector for simulating delays
    latency_injector: Option<LatencyInjector>,
    /// Mock responses for each method, keyed by method name
    mock_responses: HashMap<String, MockResponse>,
    /// Proto parser with descriptor pool for advanced type support
    /// (`None` for basic, non-enhanced services)
    proto_parser: Option<ProtoParser>,
    /// Smart mock generator for intelligent data generation
    /// (shared behind a mutex so response generation can draw sequences)
    smart_generator: Arc<Mutex<SmartMockGenerator>>,
}
100
/// Configuration for mock responses
#[derive(Debug, Clone)]
pub struct MockResponse {
    /// The response message as JSON
    pub response_json: String,
    /// Whether to simulate an error
    pub simulate_error: bool,
    /// Error message if simulating an error
    /// (handlers fall back to a generic message when `None`)
    pub error_message: Option<String>,
    /// Error code if simulating an error
    /// (tonic status code as i32; handlers default to 2 / UNKNOWN when `None`)
    pub error_code: Option<i32>,
}
113
114impl DynamicGrpcService {
115    /// Create a new dynamic gRPC service
116    pub fn new(service: ProtoService, latency_injector: Option<LatencyInjector>) -> Self {
117        let mut mock_responses = HashMap::new();
118
119        // Generate default mock responses for each method
120        for method in &service.methods {
121            let response = Self::generate_mock_response(&method.name, &method.output_type);
122            mock_responses.insert(method.name.clone(), response);
123        }
124
125        Self {
126            service,
127            latency_injector,
128            mock_responses,
129            proto_parser: None,
130            smart_generator: Arc::new(Mutex::new(SmartMockGenerator::new(
131                SmartMockConfig::default(),
132            ))),
133        }
134    }
135
136    /// Create a new enhanced dynamic gRPC service with proto parser and smart generator
137    pub fn new_enhanced(
138        service: ProtoService,
139        latency_injector: Option<LatencyInjector>,
140        proto_parser: Option<ProtoParser>,
141        smart_config: SmartMockConfig,
142    ) -> Self {
143        let mut mock_responses = HashMap::new();
144        let smart_generator = Arc::new(Mutex::new(SmartMockGenerator::new(smart_config)));
145
146        // Generate enhanced mock responses for each method using smart generator
147        for method in &service.methods {
148            let response = if proto_parser.is_some() {
149                Self::generate_enhanced_mock_response(
150                    &method.name,
151                    &method.output_type,
152                    &service.name,
153                    &smart_generator,
154                )
155            } else {
156                Self::generate_mock_response(&method.name, &method.output_type)
157            };
158            mock_responses.insert(method.name.clone(), response);
159        }
160
161        Self {
162            service,
163            latency_injector,
164            mock_responses,
165            proto_parser,
166            smart_generator,
167        }
168    }
169
170    /// Generate a mock response for a method
171    fn generate_mock_response(method_name: &str, output_type: &str) -> MockResponse {
172        // Generate different responses based on method name
173        let response_json = match method_name {
174            "SayHello" | "SayHelloStream" | "SayHelloClientStream" | "Chat" => {
175                r#"{"message": "Hello from MockForge!"}"#.to_string()
176            }
177            _ => {
178                // Generic response for unknown methods
179                format!(
180                    r#"{{"result": "Mock response for {}", "type": "{}"}}"#,
181                    method_name, output_type
182                )
183            }
184        };
185
186        MockResponse {
187            response_json,
188            simulate_error: false,
189            error_message: None,
190            error_code: None,
191        }
192    }
193
194    /// Generate an enhanced mock response using smart generator
195    fn generate_enhanced_mock_response(
196        method_name: &str,
197        output_type: &str,
198        service_name: &str,
199        smart_generator: &Arc<Mutex<SmartMockGenerator>>,
200    ) -> MockResponse {
201        debug!("Generating enhanced mock response for {}.{}", service_name, method_name);
202
203        // Use smart generator for more realistic responses
204        let response_json = if let Ok(mut generator) = smart_generator.lock() {
205            // Create sample fields based on common gRPC response patterns
206            let mut fields = HashMap::new();
207
208            // Add common response fields based on method name
209            match method_name.to_lowercase().as_str() {
210                name if name.contains("hello") || name.contains("greet") => {
211                    fields.insert("message".to_string(), "greeting".to_string());
212                    fields.insert("name".to_string(), "user_name".to_string());
213                    fields.insert("timestamp".to_string(), "timestamp".to_string());
214                }
215                name if name.contains("list") || name.contains("get") => {
216                    fields.insert("id".to_string(), "identifier".to_string());
217                    fields.insert("data".to_string(), "response_data".to_string());
218                    fields.insert("count".to_string(), "total_count".to_string());
219                }
220                name if name.contains("create") || name.contains("add") => {
221                    fields.insert("id".to_string(), "new_id".to_string());
222                    fields.insert("status".to_string(), "status".to_string());
223                    fields.insert("message".to_string(), "success_message".to_string());
224                }
225                name if name.contains("update") || name.contains("modify") => {
226                    fields.insert("updated".to_string(), "updated_fields".to_string());
227                    fields.insert("version".to_string(), "version_number".to_string());
228                    fields.insert("status".to_string(), "status".to_string());
229                }
230                name if name.contains("delete") || name.contains("remove") => {
231                    fields.insert("deleted".to_string(), "deleted_status".to_string());
232                    fields.insert("message".to_string(), "confirmation_message".to_string());
233                }
234                _ => {
235                    // Generic response structure
236                    fields.insert("result".to_string(), "result_data".to_string());
237                    fields.insert("status".to_string(), "status".to_string());
238                    fields.insert("message".to_string(), "response_message".to_string());
239                }
240            }
241
242            // Generate JSON response using field patterns
243            let mut json_parts = Vec::new();
244            for (field_name, field_type) in fields {
245                let mock_value = match field_type.as_str() {
246                    "greeting" => {
247                        format!("\"Hello from enhanced MockForge service {}!\"", service_name)
248                    }
249                    "user_name" => "\"MockForge User\"".to_string(),
250                    "timestamp" => format!(
251                        "\"{}\"",
252                        std::time::SystemTime::now()
253                            .duration_since(std::time::UNIX_EPOCH)
254                            .unwrap_or_default()
255                            .as_secs()
256                    ),
257                    "identifier" | "new_id" => format!("{}", generator.next_sequence()),
258                    "total_count" => "42".to_string(),
259                    "status" => "\"success\"".to_string(),
260                    "success_message" => {
261                        format!("\"Successfully processed {} request\"", method_name)
262                    }
263                    "confirmation_message" => {
264                        format!("\"Operation {} completed successfully\"", method_name)
265                    }
266                    "version_number" => "\"1.0.0\"".to_string(),
267                    "updated_status" | "deleted_status" => "true".to_string(),
268                    _ => format!("\"Enhanced mock data for {}\"", field_type),
269                };
270                json_parts.push(format!("\"{}\": {}", field_name, mock_value));
271            }
272
273            format!("{{{}}}", json_parts.join(", "))
274        } else {
275            // Fallback to basic response if generator lock fails
276            format!(
277                r#"{{"result": "Enhanced mock response for {}", "type": "{}"}}"#,
278                method_name, output_type
279            )
280        };
281
282        MockResponse {
283            response_json,
284            simulate_error: false,
285            error_message: None,
286            error_code: None,
287        }
288    }
289
290    /// Get the descriptor pool if available
291    pub fn descriptor_pool(&self) -> Option<&DescriptorPool> {
292        self.proto_parser.as_ref().map(|parser| parser.pool())
293    }
294
    /// Get the smart generator for external use
    ///
    /// Shared via `Arc<Mutex<_>>`; callers must lock before generating values.
    pub fn smart_generator(&self) -> &Arc<Mutex<SmartMockGenerator>> {
        &self.smart_generator
    }
299
    /// Get the service definition
    pub fn service(&self) -> &ProtoService {
        &self.service
    }
304
305    /// Handle a unary request
306    pub async fn handle_unary(
307        &self,
308        method_name: &str,
309        _request: Request<Any>,
310    ) -> Result<Response<Any>, Status> {
311        debug!("Handling unary request for method: {}", method_name);
312
313        // Inject latency if configured
314        if let Some(ref injector) = self.latency_injector {
315            let _ = injector.inject_latency(&[]).await;
316        }
317
318        // Get mock response for this method
319        let mock_response = self
320            .mock_responses
321            .get(method_name)
322            .ok_or_else(|| Status::not_found(format!("Method {} not found", method_name)))?;
323
324        // Check if we should simulate an error
325        if mock_response.simulate_error {
326            let error_code = mock_response.error_code.unwrap_or(2); // UNKNOWN
327            let error_message = mock_response
328                .error_message
329                .as_deref()
330                .unwrap_or("Simulated error from MockForge");
331            return Err(Status::new(tonic::Code::from_i32(error_code), error_message));
332        }
333
334        // Create response
335        let response = Any {
336            type_url: format!("type.googleapis.com/{}", self.get_output_type(method_name)),
337            value: mock_response.response_json.as_bytes().to_vec(),
338        };
339
340        Ok(Response::new(response))
341    }
342
343    /// Handle a server streaming request
344    pub async fn handle_server_streaming(
345        &self,
346        method_name: &str,
347        request: Request<Any>,
348    ) -> Result<Response<ReceiverStream<Result<Any, Status>>>, Status> {
349        debug!("Handling server streaming request for method: {}", method_name);
350
351        // Inject latency if configured
352        if let Some(ref injector) = self.latency_injector {
353            let _ = injector.inject_latency(&[]).await;
354        }
355
356        // Get mock response for this method
357        let mock_response = self
358            .mock_responses
359            .get(method_name)
360            .ok_or_else(|| Status::not_found(format!("Method {} not found", method_name)))?;
361
362        // Check if we should simulate an error
363        if mock_response.simulate_error {
364            let error_code = mock_response.error_code.unwrap_or(2); // UNKNOWN
365            let error_message = mock_response
366                .error_message
367                .as_deref()
368                .unwrap_or("Simulated error from MockForge");
369            return Err(Status::new(tonic::Code::from_i32(error_code), error_message));
370        }
371
372        // Create a streaming response
373        let stream = self
374            .create_server_stream(method_name, &request.into_inner(), mock_response)
375            .await?;
376        Ok(Response::new(stream))
377    }
378
379    /// Create a server streaming response
380    async fn create_server_stream(
381        &self,
382        method_name: &str,
383        _request: &Any,
384        mock_response: &MockResponse,
385    ) -> Result<ReceiverStream<Result<Any, Status>>, Status> {
386        debug!("Creating server stream for method: {}", method_name);
387
388        let (tx, rx) = mpsc::channel(10);
389        let method_name = method_name.to_string();
390        let output_type = self.get_output_type(&method_name);
391        let response_json = mock_response.response_json.clone();
392
393        // Spawn a task to generate stream messages
394        tokio::spawn(async move {
395            // Generate multiple stream messages (3-5 messages per stream)
396            let message_count = 3 + (method_name.len() % 3); // 3-5 messages based on method name
397
398            for i in 0..message_count {
399                // Create a mock response message
400                let stream_response = Self::create_stream_response_message(
401                    &method_name,
402                    &output_type,
403                    &response_json,
404                    i,
405                    message_count,
406                );
407
408                if tx.send(Ok(stream_response)).await.is_err() {
409                    debug!("Stream receiver dropped for method: {}", method_name);
410                    break; // Receiver dropped
411                }
412
413                // Add delay between messages to simulate realistic streaming
414                let delay = Duration::from_millis(100 + (i as u64 * 50)); // Progressive delay
415                tokio::time::sleep(delay).await;
416            }
417
418            info!(
419                "Completed server streaming for method: {} with {} messages",
420                method_name, message_count
421            );
422        });
423
424        Ok(ReceiverStream::new(rx))
425    }
426
427    /// Create a single stream response message
428    fn create_stream_response_message(
429        method_name: &str,
430        output_type: &str,
431        base_response: &str,
432        index: usize,
433        total: usize,
434    ) -> Any {
435        // Create a streaming-specific response by modifying the base response
436        let stream_response = if base_response.starts_with('{') && base_response.ends_with('}') {
437            // It's JSON, add streaming fields
438            let mut response = base_response.trim_end_matches('}').to_string();
439            response.push_str(&format!(
440                r#", "stream_index": {}, "stream_total": {}, "is_final": {}, "timestamp": "{}""#,
441                index,
442                total,
443                index == total - 1,
444                std::time::SystemTime::now()
445                    .duration_since(std::time::UNIX_EPOCH)
446                    .unwrap_or_default()
447                    .as_secs()
448            ));
449            response.push('}');
450            response
451        } else {
452            // It's a simple string, create a structured response
453            format!(
454                r#"{{"message": "{}", "stream_index": {}, "stream_total": {}, "is_final": {}, "method": "{}"}}"#,
455                base_response.replace('"', r#"\""#), // Escape quotes
456                index,
457                total,
458                index == total - 1,
459                method_name
460            )
461        };
462
463        Any {
464            type_url: format!("type.googleapis.com/{}", output_type),
465            value: stream_response.as_bytes().to_vec(),
466        }
467    }
468
469    /// Handle a client streaming request
470    pub async fn handle_client_streaming(
471        &self,
472        method_name: &str,
473        mut request: Request<Streaming<Any>>,
474    ) -> Result<Response<Any>, Status> {
475        debug!("Handling client streaming request for method: {}", method_name);
476
477        // Inject latency if configured
478        if let Some(ref injector) = self.latency_injector {
479            let _ = injector.inject_latency(&[]).await;
480        }
481
482        // Collect all client messages
483        let mut messages = Vec::new();
484        while let Ok(Some(message)) = request.get_mut().message().await {
485            messages.push(message);
486        }
487
488        debug!("Received {} client messages", messages.len());
489
490        // Get mock response for this method
491        let mock_response = self
492            .mock_responses
493            .get(method_name)
494            .ok_or_else(|| Status::not_found(format!("Method {} not found", method_name)))?;
495
496        // Check if we should simulate an error
497        if mock_response.simulate_error {
498            let error_code = mock_response.error_code.unwrap_or(2); // UNKNOWN
499            let error_message = mock_response
500                .error_message
501                .as_deref()
502                .unwrap_or("Simulated error from MockForge");
503            return Err(Status::new(tonic::Code::from_i32(error_code), error_message));
504        }
505
506        // Create response based on collected messages
507        let response = Any {
508            type_url: format!("type.googleapis.com/{}", self.get_output_type(method_name)),
509            value: format!(
510                r#"{{"message": "Processed {} messages from MockForge!"}}"#,
511                messages.len()
512            )
513            .as_bytes()
514            .to_vec(),
515        };
516
517        Ok(Response::new(response))
518    }
519
520    /// Handle a bidirectional streaming request
521    pub async fn handle_bidirectional_streaming(
522        &self,
523        method_name: &str,
524        request: Request<Streaming<Any>>,
525    ) -> Result<Response<ReceiverStream<Result<Any, Status>>>, Status> {
526        debug!("Handling bidirectional streaming request for method: {}", method_name);
527
528        // Inject latency if configured
529        if let Some(ref injector) = self.latency_injector {
530            let _ = injector.inject_latency(&[]).await;
531        }
532
533        // Get mock response for this method
534        let mock_response = self
535            .mock_responses
536            .get(method_name)
537            .ok_or_else(|| Status::not_found(format!("Method {} not found", method_name)))?;
538
539        // Check if we should simulate an error
540        if mock_response.simulate_error {
541            let error_code = mock_response.error_code.unwrap_or(2); // UNKNOWN
542            let error_message = mock_response
543                .error_message
544                .as_deref()
545                .unwrap_or("Simulated error from MockForge");
546            return Err(Status::new(tonic::Code::from_i32(error_code), error_message));
547        }
548
549        // Create a bidirectional streaming response
550        let stream = self.create_bidirectional_stream(method_name, request, mock_response).await?;
551        Ok(Response::new(stream))
552    }
553
    /// Create a bidirectional streaming response
    ///
    /// Spawns a detached task that reads inbound messages and emits 1-2
    /// responses per input (2 for every third input). Processing stops when the
    /// client stream ends or errors, when the receiver is dropped, or after a
    /// 100-input safety limit.
    async fn create_bidirectional_stream(
        &self,
        method_name: &str,
        mut request: Request<Streaming<Any>>,
        mock_response: &MockResponse,
    ) -> Result<ReceiverStream<Result<Any, Status>>, Status> {
        debug!("Creating bidirectional stream for method: {}", method_name);

        // Owned copies so the spawned task can outlive &self.
        let (tx, rx) = mpsc::channel(10);
        let method_name = method_name.to_string();
        let output_type = self.get_output_type(&method_name);
        let response_json = mock_response.response_json.clone();

        // Spawn a task to handle bidirectional streaming
        tokio::spawn(async move {
            let mut input_count = 0;
            let mut output_count = 0;

            // Read from input stream and respond to each message.
            // Note: a read error terminates the loop the same way end-of-stream does.
            while let Ok(Some(input_message)) = request.get_mut().message().await {
                input_count += 1;
                debug!(
                    "Received bidirectional input message {} for method: {}",
                    input_count, method_name
                );

                // For each input message, generate 1-2 response messages
                let responses_per_input = if input_count % 3 == 0 { 2 } else { 1 };

                for response_idx in 0..responses_per_input {
                    output_count += 1;

                    // Create a bidirectional response message
                    let response_message = Self::create_bidirectional_response_message(
                        &method_name,
                        &output_type,
                        &response_json,
                        &input_message,
                        input_count,
                        output_count,
                        response_idx,
                    );

                    if tx.send(Ok(response_message)).await.is_err() {
                        debug!("Bidirectional stream receiver dropped for method: {}", method_name);
                        return;
                    }

                    // Add small delay between responses
                    tokio::time::sleep(Duration::from_millis(50)).await;
                }

                // Limit the number of messages we process to prevent infinite loops
                if input_count >= 100 {
                    warn!(
                        "Reached maximum input message limit (100) for bidirectional method: {}",
                        method_name
                    );
                    break;
                }
            }

            info!("Bidirectional streaming completed for method: {}: processed {} inputs, sent {} outputs",
                  method_name, input_count, output_count);
        });

        Ok(ReceiverStream::new(rx))
    }
623
624    /// Create a single bidirectional response message
625    fn create_bidirectional_response_message(
626        method_name: &str,
627        output_type: &str,
628        base_response: &str,
629        input_message: &Any,
630        input_sequence: usize,
631        output_sequence: usize,
632        response_index: usize,
633    ) -> Any {
634        // Try to extract some context from the input message
635        let input_context = if let Ok(input_str) = String::from_utf8(input_message.value.clone()) {
636            if input_str.len() < 200 {
637                // Reasonable length limit
638                input_str
639            } else {
640                format!("Large input ({} bytes)", input_message.value.len())
641            }
642        } else {
643            format!("Binary input ({} bytes)", input_message.value.len())
644        };
645
646        // Create a bidirectional response
647        let response_json = if base_response.starts_with('{') && base_response.ends_with('}') {
648            // It's JSON, add bidirectional fields
649            let mut response = base_response.trim_end_matches('}').to_string();
650            response.push_str(&format!(
651                r#", "input_sequence": {}, "output_sequence": {}, "response_index": {}, "input_context": "{}", "is_final": {}, "timestamp": "{}""#,
652                input_sequence,
653                output_sequence,
654                response_index,
655                input_context.replace('"', r#"\""#), // Escape quotes
656                response_index > 0, // Mark as final if this is the second response
657                std::time::SystemTime::now()
658                    .duration_since(std::time::UNIX_EPOCH)
659                    .unwrap_or_default()
660                    .as_secs()
661            ));
662            response.push('}');
663            response
664        } else {
665            // It's a simple string, create a structured response
666            format!(
667                r#"{{"message": "{}", "input_sequence": {}, "output_sequence": {}, "response_index": {}, "input_context": "{}", "method": "{}"}}"#,
668                base_response.replace('"', r#"\""#), // Escape quotes
669                input_sequence,
670                output_sequence,
671                response_index,
672                input_context.replace('"', r#"\""#), // Escape quotes
673                method_name
674            )
675        };
676
677        Any {
678            type_url: format!("type.googleapis.com/{}", output_type),
679            value: response_json.as_bytes().to_vec(),
680        }
681    }
682
683    /// Get the output type for a method
684    fn get_output_type(&self, method_name: &str) -> String {
685        self.service
686            .methods
687            .iter()
688            .find(|m| m.name == method_name)
689            .map(|m| m.output_type.clone())
690            .unwrap_or_else(|| "google.protobuf.Any".to_string())
691    }
692
    /// Get the service name
    pub fn service_name(&self) -> &str {
        &self.service.name
    }
697
    /// Set a custom mock response for a method
    ///
    /// Replaces any previously configured response for `method_name`.
    pub fn set_mock_response(&mut self, method_name: &str, response: MockResponse) {
        self.mock_responses.insert(method_name.to_string(), response);
    }
702
703    /// Set error simulation for a method
704    pub fn set_error_simulation(
705        &mut self,
706        method_name: &str,
707        error_message: &str,
708        error_code: i32,
709    ) {
710        if let Some(mock_response) = self.mock_responses.get_mut(method_name) {
711            mock_response.simulate_error = true;
712            mock_response.error_message = Some(error_message.to_string());
713            mock_response.error_code = Some(error_code);
714        }
715    }
716
    /// Get the service methods
    // NOTE(review): returning `&[ProtoMethod]` would be more idiomatic, but the
    // current `&Vec` return type is kept for interface compatibility.
    pub fn methods(&self) -> &Vec<ProtoMethod> {
        &self.service.methods
    }
721
    /// Get the service package
    pub fn package(&self) -> &str {
        &self.service.package
    }
726}
727
#[cfg(test)]
mod tests {

    // Placeholder: confirms this module keeps compiling; behavior is exercised
    // by higher-level integration tests.
    #[test]
    fn test_module_compiles() {}
}
733}