1use super::{BridgeResponse, HttpBridgeConfig};
7use axum::response::{IntoResponse, Sse};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::sync::mpsc;
13use tokio_stream::{wrappers::ReceiverStream, StreamExt};
14use tonic::Request;
15use tracing::warn;
16
/// Stateless namespace type: the streaming handlers are exposed as associated
/// functions on it and no instance data is stored.
pub struct StreamHandler;
19
/// JSON envelope that the stream handlers serialize into the `data` field of
/// each Server-Sent Event.
#[derive(Debug, Serialize, Deserialize)]
pub struct StreamingMessage {
    /// Kind of message, e.g. `"stream_init"`, `"data"` or `"error"`.
    pub event_type: String,
    /// Arbitrary JSON payload carried by the message.
    pub data: Value,
    /// Free-form string annotations attached to the message (for example a
    /// sequence number or an error type).
    pub metadata: std::collections::HashMap<String, String>,
}
30
31impl StreamHandler {
32 pub async fn create_sse_stream(
34 _config: HttpBridgeConfig,
35 service_name: String,
36 method_name: String,
37 ) -> impl IntoResponse {
38 let (tx, rx) = mpsc::channel::<Result<axum::response::sse::Event, axum::BoxError>>(32);
39
40 tokio::spawn(async move {
42 let init_msg = StreamingMessage {
44 event_type: "stream_init".to_string(),
45 data: serde_json::json!({
46 "service": service_name,
47 "method": method_name,
48 "message": "Stream initialized for bidirectional communication"
49 }),
50 metadata: std::collections::HashMap::new(),
51 };
52
53 if let Ok(json_str) = serde_json::to_string(&init_msg) {
54 let _ = tx
55 .send(Ok(axum::response::sse::Event::default().event("message").data(json_str)))
56 .await;
57 }
58
59 let mut counter = 0;
61 while counter < 10 {
62 tokio::time::sleep(Duration::from_millis(500)).await;
63
64 let stream_msg = StreamingMessage {
65 event_type: "data".to_string(),
66 data: serde_json::json!({
67 "counter": counter,
68 "message": format!("Streaming message #{}", counter),
69 "timestamp": chrono::Utc::now().to_rfc3339()
70 }),
71 metadata: vec![("sequence".to_string(), counter.to_string())]
72 .into_iter()
73 .collect(),
74 };
75
76 if let Ok(json_str) = serde_json::to_string(&stream_msg) {
77 let event_type = if counter % 3 == 0 {
78 "heartbeat"
79 } else {
80 "data"
81 };
82 let _ = tx
83 .send(Ok(axum::response::sse::Event::default()
84 .event(event_type)
85 .data(json_str)))
86 .await;
87 }
88
89 counter += 1;
90
91 if counter == 7 {
93 let error_msg = StreamingMessage {
94 event_type: "error".to_string(),
95 data: serde_json::json!({
96 "error": "Simulated network error",
97 "code": "NETWORK_ERROR"
98 }),
99 metadata: vec![("error_code".to_string(), "123".to_string())]
100 .into_iter()
101 .collect(),
102 };
103
104 if let Ok(json_str) = serde_json::to_string(&error_msg) {
105 let _ = tx
106 .send(Ok(axum::response::sse::Event::default()
107 .event("error")
108 .data(json_str)))
109 .await;
110 }
111 }
112 }
113
114 let complete_msg = StreamingMessage {
116 event_type: "stream_complete".to_string(),
117 data: serde_json::json!({
118 "message": "Streaming session completed",
119 "total_messages": counter
120 }),
121 metadata: vec![("session_id".to_string(), "demo-123".to_string())]
122 .into_iter()
123 .collect(),
124 };
125
126 if let Ok(json_str) = serde_json::to_string(&complete_msg) {
127 let _ = tx
128 .send(Ok(axum::response::sse::Event::default()
129 .event("complete")
130 .data(json_str)))
131 .await;
132 }
133 });
134
135 let stream = ReceiverStream::new(rx).map(|result: Result<axum::response::sse::Event, axum::BoxError>| -> Result<axum::response::sse::Event, axum::BoxError> {
136 match result {
137 Ok(event) => Ok(event),
138 Err(e) => Ok(axum::response::sse::Event::default()
139 .event("error")
140 .data(format!("Stream error: {}", e))),
141 }
142 });
143
144 Sse::new(stream).keep_alive(
145 axum::response::sse::KeepAlive::new()
146 .interval(Duration::from_secs(30))
147 .text("keep-alive"),
148 )
149 }
150
151 pub async fn create_grpc_stream_stream(
153 proxy: Arc<super::MockReflectionProxy>,
154 service_name: &str,
155 method_name: &str,
156 initial_request: Value,
157 ) -> impl IntoResponse {
158 let (tx, rx) = mpsc::channel(32);
159
160 let service_name = service_name.to_string();
162 let method_name = method_name.to_string();
163
164 let result = Self::handle_grpc_bidirectional_streaming(
165 proxy,
166 &service_name,
167 &method_name,
168 initial_request,
169 tx.clone(),
170 )
171 .await;
172
173 tokio::spawn(async move {
174 match result {
175 Ok(_) => {
176 let _ = tx
177 .send(Ok(axum::response::sse::Event::default()
178 .event("complete")
179 .data("Stream completed successfully")))
180 .await;
181 }
182 Err(e) => {
183 let _ = tx
184 .send(Ok(axum::response::sse::Event::default()
185 .event("error")
186 .data(format!("Stream error: {}", e))))
187 .await;
188 }
189 }
190 });
191
192 let stream = ReceiverStream::new(rx).map(|result: Result<axum::response::sse::Event, axum::BoxError>| -> Result<axum::response::sse::Event, axum::BoxError> {
193 match result {
194 Ok(event) => Ok(event),
195 Err(e) => Ok(axum::response::sse::Event::default()
196 .event("error")
197 .data(format!("Stream error: {}", e))),
198 }
199 });
200
201 Sse::new(stream).keep_alive(
202 axum::response::sse::KeepAlive::new()
203 .interval(Duration::from_secs(30))
204 .text("keep-alive"),
205 )
206 }
207
208 async fn handle_grpc_bidirectional_streaming(
210 proxy: Arc<super::MockReflectionProxy>,
211 service_name: &str,
212 method_name: &str,
213 initial_request: Value,
214 tx: mpsc::Sender<Result<axum::response::sse::Event, axum::BoxError>>,
215 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
216 let registry = proxy.service_registry();
218 let service = registry
219 .get(service_name)
220 .ok_or_else(|| format!("Service '{}' not found", service_name))?;
221
222 let method_info = service
223 .service()
224 .methods
225 .iter()
226 .find(|m| m.name == method_name)
227 .ok_or_else(|| {
228 format!("Method '{}' not found in service '{}'", method_name, service_name)
229 })?;
230 let input_descriptor = registry
231 .descriptor_pool()
232 .get_message_by_name(&method_info.input_type)
233 .ok_or_else(|| format!("Input type '{}' not found", method_info.input_type))?;
234 let _output_descriptor = registry
235 .descriptor_pool()
236 .get_message_by_name(&method_info.output_type)
237 .ok_or_else(|| format!("Output type '{}' not found", method_info.output_type))?;
238
239 let converter =
241 super::converters::ProtobufJsonConverter::new(registry.descriptor_pool().clone());
242
243 let client_messages: Vec<Value> = if let Some(arr) = initial_request.as_array() {
245 arr.clone()
246 } else {
247 vec![initial_request]
248 };
249
250 let mut dynamic_messages = Vec::new();
252 for (i, json_msg) in client_messages.iter().enumerate() {
253 match converter.json_to_protobuf(&input_descriptor, json_msg) {
254 Ok(dynamic_msg) => dynamic_messages.push(dynamic_msg),
255 Err(e) => {
256 warn!("Failed to convert client message {} to protobuf: {}", i, e);
257 let error_msg = StreamingMessage {
259 event_type: "conversion_error".to_string(),
260 data: serde_json::json!({
261 "message": format!("Failed to convert client message {}: {}", i, e),
262 "sequence": i
263 }),
264 metadata: vec![
265 ("error_type".to_string(), "conversion".to_string()),
266 ("sequence".to_string(), i.to_string()),
267 ]
268 .into_iter()
269 .collect(),
270 };
271 if let Ok(json_str) = serde_json::to_string(&error_msg) {
272 let _ = tx
273 .send(Ok(axum::response::sse::Event::default()
274 .event("error")
275 .data(json_str)))
276 .await;
277 }
278 continue;
279 }
280 }
281 }
282
283 if dynamic_messages.is_empty() {
284 return Err("No valid client messages to send".into());
285 }
286
287 let start_msg = StreamingMessage {
289 event_type: "bidirectional_stream_start".to_string(),
290 data: serde_json::json!({
291 "service": service_name,
292 "method": method_name,
293 "client_messages_count": dynamic_messages.len()
294 }),
295 metadata: vec![
296 ("stream_type".to_string(), "bidirectional".to_string()),
297 ("protocol".to_string(), "grpc-web-over-sse".to_string()),
298 ]
299 .into_iter()
300 .collect(),
301 };
302
303 if let Ok(json_str) = serde_json::to_string(&start_msg) {
304 let _ = tx
305 .send(Ok(axum::response::sse::Event::default()
306 .event("stream_start")
307 .data(json_str)))
308 .await;
309 }
310
311 let (client_tx, client_rx) =
313 mpsc::channel::<Result<prost_reflect::DynamicMessage, tonic::Status>>(10);
314
315 let _request = Request::new(ReceiverStream::new(client_rx));
317
318 let client_tx_clone = client_tx.clone();
320 tokio::spawn(async move {
321 for (i, dynamic_msg) in dynamic_messages.into_iter().enumerate() {
322 if client_tx_clone.send(Ok(dynamic_msg)).await.is_err() {
323 warn!("Failed to send client message {} to gRPC stream", i);
324 break;
325 }
326 }
327 drop(client_tx_clone);
329 });
330
331 let method_descriptor = proxy.cache().get_method(service_name, method_name).await?;
333
334 let smart_generator = proxy.smart_generator().clone();
337 let output_descriptor = method_descriptor.output();
338
339 let mock_response = {
341 match smart_generator.lock() {
342 Ok(mut gen) => gen.generate_message(&output_descriptor),
343 Err(e) => {
344 let error_msg = StreamingMessage {
345 event_type: "error".to_string(),
346 data: serde_json::json!({
347 "message": format!("Failed to acquire smart generator lock: {}", e)
348 }),
349 metadata: vec![("error_type".to_string(), "lock".to_string())]
350 .into_iter()
351 .collect(),
352 };
353 if let Ok(json_str) = serde_json::to_string(&error_msg) {
354 let _ = tx
355 .send(Ok(axum::response::sse::Event::default()
356 .event("error")
357 .data(json_str)))
358 .await;
359 }
360 return Ok(());
361 }
362 }
363 };
364
365 match converter.protobuf_to_json(&output_descriptor, &mock_response) {
367 Ok(json_response) => {
368 let response_msg = StreamingMessage {
369 event_type: "grpc_response".to_string(),
370 data: json_response,
371 metadata: vec![
372 ("sequence".to_string(), "1".to_string()),
373 ("message_type".to_string(), "response".to_string()),
374 ]
375 .into_iter()
376 .collect(),
377 };
378
379 if let Ok(json_str) = serde_json::to_string(&response_msg) {
380 let _ = tx
381 .send(Ok(axum::response::sse::Event::default()
382 .event("grpc_response")
383 .data(json_str)))
384 .await;
385 }
386 }
387 Err(e) => {
388 let error_msg = StreamingMessage {
389 event_type: "conversion_error".to_string(),
390 data: serde_json::json!({
391 "message": format!("Failed to convert response to JSON: {}", e)
392 }),
393 metadata: vec![("error_type".to_string(), "conversion".to_string())]
394 .into_iter()
395 .collect(),
396 };
397 if let Ok(json_str) = serde_json::to_string(&error_msg) {
398 let _ = tx
399 .send(Ok(axum::response::sse::Event::default()
400 .event("error")
401 .data(json_str)))
402 .await;
403 }
404 }
405 }
406
407 let end_msg = StreamingMessage {
409 event_type: "bidirectional_stream_end".to_string(),
410 data: serde_json::json!({
411 "message": "Bidirectional streaming session completed",
412 "statistics": {
413 "responses_sent": 1
414 }
415 }),
416 metadata: vec![("session_status".to_string(), "completed".to_string())]
417 .into_iter()
418 .collect(),
419 };
420
421 if let Ok(json_str) = serde_json::to_string(&end_msg) {
422 let _ = tx
423 .send(Ok(axum::response::sse::Event::default().event("stream_end").data(json_str)))
424 .await;
425 }
426
427 Ok(())
428 }
429}
430
/// Stateless namespace type for the error helpers: mapping error text to an HTTP
/// status code and building failure responses.
pub struct ErrorHandler;
433
434impl ErrorHandler {
435 pub fn error_to_status_code(error: &str) -> http::StatusCode {
437 if error.contains("not found") || error.contains("Unknown") {
438 http::StatusCode::NOT_FOUND
439 } else if error.contains("unauthorized") || error.contains("forbidden") {
440 http::StatusCode::FORBIDDEN
441 } else if error.contains("invalid") || error.contains("malformed") {
442 http::StatusCode::BAD_REQUEST
443 } else {
444 http::StatusCode::INTERNAL_SERVER_ERROR
445 }
446 }
447
448 pub fn create_error_response(error: String) -> BridgeResponse<Value> {
450 BridgeResponse {
451 success: false,
452 data: None,
453 error: Some(error),
454 metadata: std::collections::HashMap::new(),
455 }
456 }
457}
458
/// Stateless namespace type for the request helpers: validating incoming requests
/// and extracting metadata from HTTP headers.
pub struct RequestProcessor;
461
462impl RequestProcessor {
463 pub fn validate_request(
465 service_name: &str,
466 method_name: &str,
467 body_size: usize,
468 max_body_size: usize,
469 ) -> Result<(), String> {
470 if service_name.trim().is_empty() {
471 return Err("Service name cannot be empty".to_string());
472 }
473
474 if method_name.trim().is_empty() {
475 return Err("Method name cannot be empty".to_string());
476 }
477
478 if body_size > max_body_size {
479 return Err(format!(
480 "Request body too large: {} bytes (max: {} bytes)",
481 body_size, max_body_size
482 ));
483 }
484
485 Ok(())
487 }
488
489 pub fn extract_metadata_from_headers(
491 headers: &http::HeaderMap,
492 ) -> std::collections::HashMap<String, String> {
493 let mut metadata = std::collections::HashMap::new();
494
495 for (key, value) in headers.iter() {
496 let key_str = key.as_str();
497 if !key_str.starts_with("host")
499 && !key_str.starts_with("content-type")
500 && !key_str.starts_with("content-length")
501 && !key_str.starts_with("user-agent")
502 && !key_str.starts_with("accept")
503 && !key_str.starts_with("authorization")
504 {
505 if let Ok(value_str) = value.to_str() {
506 metadata.insert(key_str.to_string(), value_str.to_string());
507 }
508 }
509 }
510
511 metadata
512 }
513}
514
515#[cfg(test)]
516mod tests {
517 use super::*;
518 use axum::http::HeaderMap;
519
520 #[test]
521 fn test_error_to_status_code() {
522 assert_eq!(
523 ErrorHandler::error_to_status_code("service not found"),
524 http::StatusCode::NOT_FOUND
525 );
526 assert_eq!(ErrorHandler::error_to_status_code("unauthorized"), http::StatusCode::FORBIDDEN);
527 assert_eq!(
528 ErrorHandler::error_to_status_code("invalid request"),
529 http::StatusCode::BAD_REQUEST
530 );
531 assert_eq!(
532 ErrorHandler::error_to_status_code("internal error"),
533 http::StatusCode::INTERNAL_SERVER_ERROR
534 );
535
536 assert_eq!(
538 ErrorHandler::error_to_status_code("Unknown service"),
539 http::StatusCode::NOT_FOUND
540 );
541 assert_eq!(
542 ErrorHandler::error_to_status_code("forbidden access"),
543 http::StatusCode::FORBIDDEN
544 );
545 assert_eq!(
546 ErrorHandler::error_to_status_code("malformed JSON"),
547 http::StatusCode::BAD_REQUEST
548 );
549 assert_eq!(
550 ErrorHandler::error_to_status_code("random error"),
551 http::StatusCode::INTERNAL_SERVER_ERROR
552 );
553 }
554
555 #[test]
556 fn test_validate_request() {
557 assert!(RequestProcessor::validate_request("test", "method", 100, 1000).is_ok());
558 assert!(RequestProcessor::validate_request("", "method", 100, 1000).is_err());
559 assert!(RequestProcessor::validate_request("test", "", 100, 1000).is_err());
560 assert!(RequestProcessor::validate_request("test", "method", 2000, 1000).is_err());
561
562 assert!(
564 RequestProcessor::validate_request("valid_service", "valid_method", 0, 1000).is_ok()
565 );
566 assert!(RequestProcessor::validate_request("test", "method", 1000, 1000).is_ok());
567 assert!(RequestProcessor::validate_request("test", "method", 1001, 1000).is_err());
568
569 let long_name = "a".repeat(1000);
571 assert!(RequestProcessor::validate_request(&long_name, &long_name, 100, 1000).is_ok());
572 }
573
574 #[test]
575 fn test_extract_metadata_from_headers() {
576 let mut headers = HeaderMap::new();
577
578 headers.insert("content-type", "application/json".parse().unwrap());
580 headers.insert("authorization", "Bearer token123".parse().unwrap());
581 headers.insert("x-custom-header", "custom-value".parse().unwrap());
582 headers.insert("x-api-key", "key123".parse().unwrap());
583 headers.insert("user-agent", "test-agent".parse().unwrap());
584
585 let metadata = RequestProcessor::extract_metadata_from_headers(&headers);
586
587 assert!(!metadata.contains_key("content-type"));
589 assert!(!metadata.contains_key("authorization")); assert!(!metadata.contains_key("user-agent"));
591
592 assert_eq!(metadata.get("x-custom-header"), Some(&"custom-value".to_string()));
594 assert_eq!(metadata.get("x-api-key"), Some(&"key123".to_string()));
595
596 let empty_headers = HeaderMap::new();
598 let empty_metadata = RequestProcessor::extract_metadata_from_headers(&empty_headers);
599 assert!(empty_metadata.is_empty());
600
601 let mut case_headers = HeaderMap::new();
603 case_headers.insert("X-CUSTOM-HEADER", "value".parse().unwrap());
604 let case_metadata = RequestProcessor::extract_metadata_from_headers(&case_headers);
605 assert_eq!(case_metadata.get("x-custom-header"), Some(&"value".to_string()));
606 }
607
608 #[test]
609 fn test_create_error_response() {
610 let error_message = "Test error message";
611 let response = ErrorHandler::create_error_response(error_message.to_string());
612
613 assert!(!response.success);
614 assert!(response.data.is_none());
615 assert_eq!(response.error, Some(error_message.to_string()));
616 assert!(response.metadata.is_empty());
617 }
618
619 #[tokio::test]
620 async fn test_streaming_message_serialization() {
621 let message = StreamingMessage {
622 event_type: "test_event".to_string(),
623 data: serde_json::json!({"key": "value"}),
624 metadata: vec![
625 ("sequence".to_string(), "1".to_string()),
626 ("type".to_string(), "test".to_string()),
627 ]
628 .into_iter()
629 .collect(),
630 };
631
632 let json_str = serde_json::to_string(&message).unwrap();
634 assert!(json_str.contains("test_event"));
635 assert!(json_str.contains("key"));
636 assert!(json_str.contains("value"));
637 assert!(json_str.contains("sequence"));
638 assert!(json_str.contains("1"));
639 assert!(json_str.contains("type"));
640 assert!(json_str.contains("test"));
641
642 let deserialized: StreamingMessage = serde_json::from_str(&json_str).unwrap();
644 assert_eq!(deserialized.event_type, message.event_type);
645 assert_eq!(deserialized.data, message.data);
646 assert_eq!(deserialized.metadata, message.metadata);
647 }
648
649 #[tokio::test]
650 async fn test_create_sse_stream_basic() {
651 let config = HttpBridgeConfig {
652 enabled: true,
653 base_path: "/api".to_string(),
654 enable_cors: false,
655 max_request_size: 1024,
656 timeout_seconds: 30,
657 route_pattern: "/{service}/{method}".to_string(),
658 };
659
660 let stream_response = StreamHandler::create_sse_stream(
661 config,
662 "test_service".to_string(),
663 "test_method".to_string(),
664 )
665 .await;
666
667 let sse_response = stream_response.into_response();
669 assert_eq!(sse_response.status(), http::StatusCode::OK);
670
671 let content_type = sse_response
673 .headers()
674 .get("content-type")
675 .and_then(|h| h.to_str().ok())
676 .unwrap_or("");
677 assert!(content_type.contains("text/event-stream"));
678 }
679
680 #[test]
681 fn test_bridge_response_serialization() {
682 let response = BridgeResponse::<Value> {
683 success: true,
684 data: Some(serde_json::json!({"result": "success"})),
685 error: None,
686 metadata: vec![
687 ("service".to_string(), "test".to_string()),
688 ("method".to_string(), "test".to_string()),
689 ]
690 .into_iter()
691 .collect(),
692 };
693
694 let json_str = serde_json::to_string(&response).unwrap();
695 assert!(json_str.contains("success"));
696 assert!(json_str.contains("true"));
697 assert!(json_str.contains("result"));
698 assert!(json_str.contains("success"));
699 assert!(json_str.contains("service"));
700 assert!(json_str.contains("method"));
701
702 let deserialized: BridgeResponse<Value> = serde_json::from_str(&json_str).unwrap();
703 assert_eq!(deserialized.success, response.success);
704 assert_eq!(deserialized.data, response.data);
705 assert_eq!(deserialized.error, response.error);
706 assert_eq!(deserialized.metadata, response.metadata);
707 }
708
709 #[test]
710 fn test_validate_request_edge_cases() {
711 assert!(RequestProcessor::validate_request("test", "method", 0, 0).is_ok());
713 assert!(RequestProcessor::validate_request("test", "method", 1, 0).is_err());
714
715 assert!(RequestProcessor::validate_request(" test ", " method ", 100, 1000).is_ok());
717 assert!(RequestProcessor::validate_request(" ", "method", 100, 1000).is_err());
718 assert!(RequestProcessor::validate_request("test", " ", 100, 1000).is_err());
719
720 let large_size = usize::MAX / 2;
722 assert!(
723 RequestProcessor::validate_request("test", "method", large_size, usize::MAX).is_ok()
724 );
725 assert!(RequestProcessor::validate_request("test", "method", large_size + 1, large_size)
726 .is_err());
727 }
728
729 #[test]
730 fn test_header_extraction_comprehensive() {
731 let mut headers = HeaderMap::new();
732
733 headers.insert("host", "localhost:9080".parse().unwrap());
735 headers.insert("content-length", "123".parse().unwrap());
736 headers.insert("accept", "application/json".parse().unwrap());
737 headers.insert("x-forwarded-for", "192.168.1.1".parse().unwrap());
738 headers.insert("x-custom-metadata", "custom-value".parse().unwrap());
739 headers.insert("x-trace-id", "trace-123".parse().unwrap());
740 headers.insert("x-request-id", "req-456".parse().unwrap());
741
742 let metadata = RequestProcessor::extract_metadata_from_headers(&headers);
743
744 assert!(!metadata.contains_key("host"));
746 assert!(!metadata.contains_key("content-length"));
747 assert!(!metadata.contains_key("accept"));
748
749 assert_eq!(metadata.get("x-forwarded-for"), Some(&"192.168.1.1".to_string()));
751 assert_eq!(metadata.get("x-custom-metadata"), Some(&"custom-value".to_string()));
752 assert_eq!(metadata.get("x-trace-id"), Some(&"trace-123".to_string()));
753 assert_eq!(metadata.get("x-request-id"), Some(&"req-456".to_string()));
754
755 assert_eq!(metadata.len(), 4);
757 }
758
759 #[test]
760 fn test_error_response_comprehensive() {
761 let test_errors = vec![
763 "Service not found",
764 "Method not found",
765 "Invalid request body",
766 "Authentication failed",
767 "Internal server error",
768 "Timeout exceeded",
769 "Rate limit exceeded",
770 "Database connection failed",
771 ];
772
773 for error_msg in test_errors {
774 let response = ErrorHandler::create_error_response(error_msg.to_string());
775 assert!(!response.success);
776 assert!(response.data.is_none());
777 assert_eq!(response.error, Some(error_msg.to_string()));
778 assert!(response.metadata.is_empty());
779 }
780 }
781}