1use crate::dynamic::proto_parser::{ProtoMethod, ProtoParser, ProtoService};
6use crate::reflection::smart_mock_generator::{SmartMockConfig, SmartMockGenerator};
7use mockforge_core::latency::LatencyInjector;
8use prost_reflect::DescriptorPool;
9use prost_types::Any;
10use std::collections::HashMap;
11use std::sync::{Arc, Mutex};
12use std::time::Duration;
13use tokio::sync::mpsc;
14use tokio_stream::wrappers::ReceiverStream;
15use tonic::{Request, Response, Status, Streaming};
16use tracing::{debug, info, warn};
17
18pub struct EnhancedServiceFactory;
20
21impl EnhancedServiceFactory {
22 pub async fn create_services_from_proto_dir(
24 proto_dir: &str,
25 latency_injector: Option<LatencyInjector>,
26 smart_config: SmartMockConfig,
27 ) -> Result<Vec<DynamicGrpcService>, Box<dyn std::error::Error + Send + Sync>> {
28 info!("Creating enhanced services from proto directory: {}", proto_dir);
29
30 let mut parser = ProtoParser::new();
32 parser.parse_directory(proto_dir).await?;
33
34 let mut services = Vec::new();
35
36 let services_info: Vec<(String, ProtoService)> = parser
38 .services()
39 .iter()
40 .map(|(name, service)| (name.clone(), service.clone()))
41 .collect();
42
43 for (service_name, proto_service) in services_info {
45 debug!("Creating enhanced service: {}", service_name);
46
47 let mut service_parser = ProtoParser::new();
49 let _ = service_parser.parse_directory(proto_dir).await; let service = DynamicGrpcService::new_enhanced(
52 proto_service,
53 latency_injector.clone(),
54 Some(service_parser),
55 smart_config.clone(),
56 );
57
58 services.push(service);
59 }
60
61 info!("Created {} enhanced services", services.len());
62 Ok(services)
63 }
64
65 pub fn create_service_from_proto(
67 proto_service: ProtoService,
68 latency_injector: Option<LatencyInjector>,
69 proto_parser: Option<ProtoParser>,
70 smart_config: SmartMockConfig,
71 ) -> DynamicGrpcService {
72 if proto_parser.is_some() {
73 info!("Creating enhanced service: {}", proto_service.name);
74 DynamicGrpcService::new_enhanced(
75 proto_service,
76 latency_injector,
77 proto_parser,
78 smart_config,
79 )
80 } else {
81 info!("Creating basic service: {}", proto_service.name);
82 DynamicGrpcService::new(proto_service, latency_injector)
83 }
84 }
85}
86
87pub struct DynamicGrpcService {
89 service: ProtoService,
91 latency_injector: Option<LatencyInjector>,
93 mock_responses: HashMap<String, MockResponse>,
95 proto_parser: Option<ProtoParser>,
97 smart_generator: Arc<Mutex<SmartMockGenerator>>,
99}
100
101#[derive(Debug, Clone)]
103pub struct MockResponse {
104 pub response_json: String,
106 pub simulate_error: bool,
108 pub error_message: Option<String>,
110 pub error_code: Option<i32>,
112}
113
114impl DynamicGrpcService {
115 pub fn new(service: ProtoService, latency_injector: Option<LatencyInjector>) -> Self {
117 let mut mock_responses = HashMap::new();
118
119 for method in &service.methods {
121 let response = Self::generate_mock_response(&method.name, &method.output_type);
122 mock_responses.insert(method.name.clone(), response);
123 }
124
125 Self {
126 service,
127 latency_injector,
128 mock_responses,
129 proto_parser: None,
130 smart_generator: Arc::new(Mutex::new(SmartMockGenerator::new(
131 SmartMockConfig::default(),
132 ))),
133 }
134 }
135
136 pub fn new_enhanced(
138 service: ProtoService,
139 latency_injector: Option<LatencyInjector>,
140 proto_parser: Option<ProtoParser>,
141 smart_config: SmartMockConfig,
142 ) -> Self {
143 let mut mock_responses = HashMap::new();
144 let smart_generator = Arc::new(Mutex::new(SmartMockGenerator::new(smart_config)));
145
146 for method in &service.methods {
148 let response = if proto_parser.is_some() {
149 Self::generate_enhanced_mock_response(
150 &method.name,
151 &method.output_type,
152 &service.name,
153 &smart_generator,
154 )
155 } else {
156 Self::generate_mock_response(&method.name, &method.output_type)
157 };
158 mock_responses.insert(method.name.clone(), response);
159 }
160
161 Self {
162 service,
163 latency_injector,
164 mock_responses,
165 proto_parser,
166 smart_generator,
167 }
168 }
169
170 fn generate_mock_response(method_name: &str, output_type: &str) -> MockResponse {
172 let response_json = match method_name {
174 "SayHello" | "SayHelloStream" | "SayHelloClientStream" | "Chat" => {
175 r#"{"message": "Hello from MockForge!"}"#.to_string()
176 }
177 _ => {
178 format!(
180 r#"{{"result": "Mock response for {}", "type": "{}"}}"#,
181 method_name, output_type
182 )
183 }
184 };
185
186 MockResponse {
187 response_json,
188 simulate_error: false,
189 error_message: None,
190 error_code: None,
191 }
192 }
193
194 fn generate_enhanced_mock_response(
196 method_name: &str,
197 output_type: &str,
198 service_name: &str,
199 smart_generator: &Arc<Mutex<SmartMockGenerator>>,
200 ) -> MockResponse {
201 debug!("Generating enhanced mock response for {}.{}", service_name, method_name);
202
203 let response_json = if let Ok(mut generator) = smart_generator.lock() {
205 let mut fields = HashMap::new();
207
208 match method_name.to_lowercase().as_str() {
210 name if name.contains("hello") || name.contains("greet") => {
211 fields.insert("message".to_string(), "greeting".to_string());
212 fields.insert("name".to_string(), "user_name".to_string());
213 fields.insert("timestamp".to_string(), "timestamp".to_string());
214 }
215 name if name.contains("list") || name.contains("get") => {
216 fields.insert("id".to_string(), "identifier".to_string());
217 fields.insert("data".to_string(), "response_data".to_string());
218 fields.insert("count".to_string(), "total_count".to_string());
219 }
220 name if name.contains("create") || name.contains("add") => {
221 fields.insert("id".to_string(), "new_id".to_string());
222 fields.insert("status".to_string(), "status".to_string());
223 fields.insert("message".to_string(), "success_message".to_string());
224 }
225 name if name.contains("update") || name.contains("modify") => {
226 fields.insert("updated".to_string(), "updated_fields".to_string());
227 fields.insert("version".to_string(), "version_number".to_string());
228 fields.insert("status".to_string(), "status".to_string());
229 }
230 name if name.contains("delete") || name.contains("remove") => {
231 fields.insert("deleted".to_string(), "deleted_status".to_string());
232 fields.insert("message".to_string(), "confirmation_message".to_string());
233 }
234 _ => {
235 fields.insert("result".to_string(), "result_data".to_string());
237 fields.insert("status".to_string(), "status".to_string());
238 fields.insert("message".to_string(), "response_message".to_string());
239 }
240 }
241
242 let mut json_parts = Vec::new();
244 for (field_name, field_type) in fields {
245 let mock_value = match field_type.as_str() {
246 "greeting" => {
247 format!("\"Hello from enhanced MockForge service {}!\"", service_name)
248 }
249 "user_name" => "\"MockForge User\"".to_string(),
250 "timestamp" => format!(
251 "\"{}\"",
252 std::time::SystemTime::now()
253 .duration_since(std::time::UNIX_EPOCH)
254 .unwrap_or_default()
255 .as_secs()
256 ),
257 "identifier" | "new_id" => format!("{}", generator.next_sequence()),
258 "total_count" => "42".to_string(),
259 "status" => "\"success\"".to_string(),
260 "success_message" => {
261 format!("\"Successfully processed {} request\"", method_name)
262 }
263 "confirmation_message" => {
264 format!("\"Operation {} completed successfully\"", method_name)
265 }
266 "version_number" => "\"1.0.0\"".to_string(),
267 "updated_status" | "deleted_status" => "true".to_string(),
268 _ => format!("\"Enhanced mock data for {}\"", field_type),
269 };
270 json_parts.push(format!("\"{}\": {}", field_name, mock_value));
271 }
272
273 format!("{{{}}}", json_parts.join(", "))
274 } else {
275 format!(
277 r#"{{"result": "Enhanced mock response for {}", "type": "{}"}}"#,
278 method_name, output_type
279 )
280 };
281
282 MockResponse {
283 response_json,
284 simulate_error: false,
285 error_message: None,
286 error_code: None,
287 }
288 }
289
290 pub fn descriptor_pool(&self) -> Option<&DescriptorPool> {
292 self.proto_parser.as_ref().map(|parser| parser.pool())
293 }
294
295 pub fn smart_generator(&self) -> &Arc<Mutex<SmartMockGenerator>> {
297 &self.smart_generator
298 }
299
300 pub fn service(&self) -> &ProtoService {
302 &self.service
303 }
304
305 pub async fn handle_unary(
307 &self,
308 method_name: &str,
309 _request: Request<Any>,
310 ) -> Result<Response<Any>, Status> {
311 debug!("Handling unary request for method: {}", method_name);
312
313 if let Some(ref injector) = self.latency_injector {
315 let _ = injector.inject_latency(&[]).await;
316 }
317
318 let mock_response = self
320 .mock_responses
321 .get(method_name)
322 .ok_or_else(|| Status::not_found(format!("Method {} not found", method_name)))?;
323
324 if mock_response.simulate_error {
326 let error_code = mock_response.error_code.unwrap_or(2); let error_message = mock_response
328 .error_message
329 .as_deref()
330 .unwrap_or("Simulated error from MockForge");
331 return Err(Status::new(tonic::Code::from_i32(error_code), error_message));
332 }
333
334 let response = Any {
336 type_url: format!("type.googleapis.com/{}", self.get_output_type(method_name)),
337 value: mock_response.response_json.as_bytes().to_vec(),
338 };
339
340 Ok(Response::new(response))
341 }
342
343 pub async fn handle_server_streaming(
345 &self,
346 method_name: &str,
347 request: Request<Any>,
348 ) -> Result<Response<ReceiverStream<Result<Any, Status>>>, Status> {
349 debug!("Handling server streaming request for method: {}", method_name);
350
351 if let Some(ref injector) = self.latency_injector {
353 let _ = injector.inject_latency(&[]).await;
354 }
355
356 let mock_response = self
358 .mock_responses
359 .get(method_name)
360 .ok_or_else(|| Status::not_found(format!("Method {} not found", method_name)))?;
361
362 if mock_response.simulate_error {
364 let error_code = mock_response.error_code.unwrap_or(2); let error_message = mock_response
366 .error_message
367 .as_deref()
368 .unwrap_or("Simulated error from MockForge");
369 return Err(Status::new(tonic::Code::from_i32(error_code), error_message));
370 }
371
372 let stream = self
374 .create_server_stream(method_name, &request.into_inner(), mock_response)
375 .await?;
376 Ok(Response::new(stream))
377 }
378
379 async fn create_server_stream(
381 &self,
382 method_name: &str,
383 _request: &Any,
384 mock_response: &MockResponse,
385 ) -> Result<ReceiverStream<Result<Any, Status>>, Status> {
386 debug!("Creating server stream for method: {}", method_name);
387
388 let (tx, rx) = mpsc::channel(10);
389 let method_name = method_name.to_string();
390 let output_type = self.get_output_type(&method_name);
391 let response_json = mock_response.response_json.clone();
392
393 tokio::spawn(async move {
395 let message_count = 3 + (method_name.len() % 3); for i in 0..message_count {
399 let stream_response = Self::create_stream_response_message(
401 &method_name,
402 &output_type,
403 &response_json,
404 i,
405 message_count,
406 );
407
408 if tx.send(Ok(stream_response)).await.is_err() {
409 debug!("Stream receiver dropped for method: {}", method_name);
410 break; }
412
413 let delay = Duration::from_millis(100 + (i as u64 * 50)); tokio::time::sleep(delay).await;
416 }
417
418 info!(
419 "Completed server streaming for method: {} with {} messages",
420 method_name, message_count
421 );
422 });
423
424 Ok(ReceiverStream::new(rx))
425 }
426
427 fn create_stream_response_message(
429 method_name: &str,
430 output_type: &str,
431 base_response: &str,
432 index: usize,
433 total: usize,
434 ) -> Any {
435 let stream_response = if base_response.starts_with('{') && base_response.ends_with('}') {
437 let mut response = base_response.trim_end_matches('}').to_string();
439 response.push_str(&format!(
440 r#", "stream_index": {}, "stream_total": {}, "is_final": {}, "timestamp": "{}""#,
441 index,
442 total,
443 index == total - 1,
444 std::time::SystemTime::now()
445 .duration_since(std::time::UNIX_EPOCH)
446 .unwrap_or_default()
447 .as_secs()
448 ));
449 response.push('}');
450 response
451 } else {
452 format!(
454 r#"{{"message": "{}", "stream_index": {}, "stream_total": {}, "is_final": {}, "method": "{}"}}"#,
455 base_response.replace('"', r#"\""#), index,
457 total,
458 index == total - 1,
459 method_name
460 )
461 };
462
463 Any {
464 type_url: format!("type.googleapis.com/{}", output_type),
465 value: stream_response.as_bytes().to_vec(),
466 }
467 }
468
469 pub async fn handle_client_streaming(
471 &self,
472 method_name: &str,
473 mut request: Request<Streaming<Any>>,
474 ) -> Result<Response<Any>, Status> {
475 debug!("Handling client streaming request for method: {}", method_name);
476
477 if let Some(ref injector) = self.latency_injector {
479 let _ = injector.inject_latency(&[]).await;
480 }
481
482 let mut messages = Vec::new();
484 while let Ok(Some(message)) = request.get_mut().message().await {
485 messages.push(message);
486 }
487
488 debug!("Received {} client messages", messages.len());
489
490 let mock_response = self
492 .mock_responses
493 .get(method_name)
494 .ok_or_else(|| Status::not_found(format!("Method {} not found", method_name)))?;
495
496 if mock_response.simulate_error {
498 let error_code = mock_response.error_code.unwrap_or(2); let error_message = mock_response
500 .error_message
501 .as_deref()
502 .unwrap_or("Simulated error from MockForge");
503 return Err(Status::new(tonic::Code::from_i32(error_code), error_message));
504 }
505
506 let response = Any {
508 type_url: format!("type.googleapis.com/{}", self.get_output_type(method_name)),
509 value: format!(
510 r#"{{"message": "Processed {} messages from MockForge!"}}"#,
511 messages.len()
512 )
513 .as_bytes()
514 .to_vec(),
515 };
516
517 Ok(Response::new(response))
518 }
519
520 pub async fn handle_bidirectional_streaming(
522 &self,
523 method_name: &str,
524 request: Request<Streaming<Any>>,
525 ) -> Result<Response<ReceiverStream<Result<Any, Status>>>, Status> {
526 debug!("Handling bidirectional streaming request for method: {}", method_name);
527
528 if let Some(ref injector) = self.latency_injector {
530 let _ = injector.inject_latency(&[]).await;
531 }
532
533 let mock_response = self
535 .mock_responses
536 .get(method_name)
537 .ok_or_else(|| Status::not_found(format!("Method {} not found", method_name)))?;
538
539 if mock_response.simulate_error {
541 let error_code = mock_response.error_code.unwrap_or(2); let error_message = mock_response
543 .error_message
544 .as_deref()
545 .unwrap_or("Simulated error from MockForge");
546 return Err(Status::new(tonic::Code::from_i32(error_code), error_message));
547 }
548
549 let stream = self.create_bidirectional_stream(method_name, request, mock_response).await?;
551 Ok(Response::new(stream))
552 }
553
554 async fn create_bidirectional_stream(
556 &self,
557 method_name: &str,
558 mut request: Request<Streaming<Any>>,
559 mock_response: &MockResponse,
560 ) -> Result<ReceiverStream<Result<Any, Status>>, Status> {
561 debug!("Creating bidirectional stream for method: {}", method_name);
562
563 let (tx, rx) = mpsc::channel(10);
564 let method_name = method_name.to_string();
565 let output_type = self.get_output_type(&method_name);
566 let response_json = mock_response.response_json.clone();
567
568 tokio::spawn(async move {
570 let mut input_count = 0;
571 let mut output_count = 0;
572
573 while let Ok(Some(input_message)) = request.get_mut().message().await {
575 input_count += 1;
576 debug!(
577 "Received bidirectional input message {} for method: {}",
578 input_count, method_name
579 );
580
581 let responses_per_input = if input_count % 3 == 0 { 2 } else { 1 };
583
584 for response_idx in 0..responses_per_input {
585 output_count += 1;
586
587 let response_message = Self::create_bidirectional_response_message(
589 &method_name,
590 &output_type,
591 &response_json,
592 &input_message,
593 input_count,
594 output_count,
595 response_idx,
596 );
597
598 if tx.send(Ok(response_message)).await.is_err() {
599 debug!("Bidirectional stream receiver dropped for method: {}", method_name);
600 return;
601 }
602
603 tokio::time::sleep(Duration::from_millis(50)).await;
605 }
606
607 if input_count >= 100 {
609 warn!(
610 "Reached maximum input message limit (100) for bidirectional method: {}",
611 method_name
612 );
613 break;
614 }
615 }
616
617 info!("Bidirectional streaming completed for method: {}: processed {} inputs, sent {} outputs",
618 method_name, input_count, output_count);
619 });
620
621 Ok(ReceiverStream::new(rx))
622 }
623
624 fn create_bidirectional_response_message(
626 method_name: &str,
627 output_type: &str,
628 base_response: &str,
629 input_message: &Any,
630 input_sequence: usize,
631 output_sequence: usize,
632 response_index: usize,
633 ) -> Any {
634 let input_context = if let Ok(input_str) = String::from_utf8(input_message.value.clone()) {
636 if input_str.len() < 200 {
637 input_str
639 } else {
640 format!("Large input ({} bytes)", input_message.value.len())
641 }
642 } else {
643 format!("Binary input ({} bytes)", input_message.value.len())
644 };
645
646 let response_json = if base_response.starts_with('{') && base_response.ends_with('}') {
648 let mut response = base_response.trim_end_matches('}').to_string();
650 response.push_str(&format!(
651 r#", "input_sequence": {}, "output_sequence": {}, "response_index": {}, "input_context": "{}", "is_final": {}, "timestamp": "{}""#,
652 input_sequence,
653 output_sequence,
654 response_index,
655 input_context.replace('"', r#"\""#), response_index > 0, std::time::SystemTime::now()
658 .duration_since(std::time::UNIX_EPOCH)
659 .unwrap_or_default()
660 .as_secs()
661 ));
662 response.push('}');
663 response
664 } else {
665 format!(
667 r#"{{"message": "{}", "input_sequence": {}, "output_sequence": {}, "response_index": {}, "input_context": "{}", "method": "{}"}}"#,
668 base_response.replace('"', r#"\""#), input_sequence,
670 output_sequence,
671 response_index,
672 input_context.replace('"', r#"\""#), method_name
674 )
675 };
676
677 Any {
678 type_url: format!("type.googleapis.com/{}", output_type),
679 value: response_json.as_bytes().to_vec(),
680 }
681 }
682
683 fn get_output_type(&self, method_name: &str) -> String {
685 self.service
686 .methods
687 .iter()
688 .find(|m| m.name == method_name)
689 .map(|m| m.output_type.clone())
690 .unwrap_or_else(|| "google.protobuf.Any".to_string())
691 }
692
693 pub fn service_name(&self) -> &str {
695 &self.service.name
696 }
697
698 pub fn set_mock_response(&mut self, method_name: &str, response: MockResponse) {
700 self.mock_responses.insert(method_name.to_string(), response);
701 }
702
703 pub fn set_error_simulation(
705 &mut self,
706 method_name: &str,
707 error_message: &str,
708 error_code: i32,
709 ) {
710 if let Some(mock_response) = self.mock_responses.get_mut(method_name) {
711 mock_response.simulate_error = true;
712 mock_response.error_message = Some(error_message.to_string());
713 mock_response.error_code = Some(error_code);
714 }
715 }
716
717 pub fn methods(&self) -> &Vec<ProtoMethod> {
719 &self.service.methods
720 }
721
722 pub fn package(&self) -> &str {
724 &self.service.package
725 }
726}
727
#[cfg(test)]
mod tests {
    /// Placeholder smoke test: it passes whenever this file builds under
    /// `cargo test`.
    #[test]
    fn test_module_compiles() {
        // Intentionally empty — successful compilation is the whole check.
    }
}