1use super::{BridgeResponse, HttpBridgeConfig};
7use axum::response::{IntoResponse, Sse};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::sync::mpsc;
13use tokio_stream::{wrappers::ReceiverStream, StreamExt};
14use tonic::Request;
15use tracing::warn;
16
/// Stateless namespace for the SSE streaming entry points of the HTTP bridge.
#[derive(Debug, Clone, Copy, Default)]
pub struct StreamHandler;
20#[derive(Debug, Serialize, Deserialize)]
22pub struct StreamingMessage {
23 pub event_type: String,
25 pub data: Value,
27 pub metadata: std::collections::HashMap<String, String>,
29}
30
31impl StreamHandler {
32 pub async fn create_sse_stream(
34 _config: HttpBridgeConfig,
35 service_name: String,
36 method_name: String,
37 ) -> impl IntoResponse {
38 let (tx, rx) =
39 tokio::sync::mpsc::channel::<Result<axum::response::sse::Event, axum::BoxError>>(32);
40
41 tokio::spawn(async move {
43 let init_msg = StreamingMessage {
45 event_type: "stream_init".to_string(),
46 data: serde_json::json!({
47 "service": service_name,
48 "method": method_name,
49 "message": "Stream initialized for bidirectional communication"
50 }),
51 metadata: std::collections::HashMap::new(),
52 };
53
54 if let Ok(json_str) = serde_json::to_string(&init_msg) {
55 let _ = tx
56 .send(Ok(axum::response::sse::Event::default().event("message").data(json_str)))
57 .await;
58 }
59
60 let mut counter = 0;
62 while counter < 10 {
63 tokio::time::sleep(Duration::from_millis(500)).await;
64
65 let stream_msg = StreamingMessage {
66 event_type: "data".to_string(),
67 data: serde_json::json!({
68 "counter": counter,
69 "message": format!("Streaming message #{}", counter),
70 "timestamp": chrono::Utc::now().to_rfc3339()
71 }),
72 metadata: vec![("sequence".to_string(), counter.to_string())]
73 .into_iter()
74 .collect(),
75 };
76
77 if let Ok(json_str) = serde_json::to_string(&stream_msg) {
78 let event_type = if counter % 3 == 0 {
79 "heartbeat"
80 } else {
81 "data"
82 };
83 let _ = tx
84 .send(Ok(axum::response::sse::Event::default()
85 .event(event_type)
86 .data(json_str)))
87 .await;
88 }
89
90 counter += 1;
91
92 if counter == 7 {
94 let error_msg = StreamingMessage {
95 event_type: "error".to_string(),
96 data: serde_json::json!({
97 "error": "Simulated network error",
98 "code": "NETWORK_ERROR"
99 }),
100 metadata: vec![("error_code".to_string(), "123".to_string())]
101 .into_iter()
102 .collect(),
103 };
104
105 if let Ok(json_str) = serde_json::to_string(&error_msg) {
106 let _ = tx
107 .send(Ok(axum::response::sse::Event::default()
108 .event("error")
109 .data(json_str)))
110 .await;
111 }
112 }
113 }
114
115 let complete_msg = StreamingMessage {
117 event_type: "stream_complete".to_string(),
118 data: serde_json::json!({
119 "message": "Streaming session completed",
120 "total_messages": counter
121 }),
122 metadata: vec![("session_id".to_string(), "demo-123".to_string())]
123 .into_iter()
124 .collect(),
125 };
126
127 if let Ok(json_str) = serde_json::to_string(&complete_msg) {
128 let _ = tx
129 .send(Ok(axum::response::sse::Event::default()
130 .event("complete")
131 .data(json_str)))
132 .await;
133 }
134 });
135
136 let stream = ReceiverStream::new(rx).map(|result: Result<axum::response::sse::Event, axum::BoxError>| -> Result<axum::response::sse::Event, axum::BoxError> {
137 match result {
138 Ok(event) => Ok(event),
139 Err(e) => Ok(axum::response::sse::Event::default()
140 .event("error")
141 .data(format!("Stream error: {}", e))),
142 }
143 });
144
145 Sse::new(stream).keep_alive(
146 axum::response::sse::KeepAlive::new()
147 .interval(Duration::from_secs(30))
148 .text("keep-alive"),
149 )
150 }
151
152 pub async fn create_grpc_stream_stream(
154 proxy: Arc<super::MockReflectionProxy>,
155 service_name: &str,
156 method_name: &str,
157 initial_request: Value,
158 ) -> impl IntoResponse {
159 let (tx, rx) = tokio::sync::mpsc::channel(32);
160
161 let service_name = service_name.to_string();
163 let method_name = method_name.to_string();
164
165 let result = Self::handle_grpc_bidirectional_streaming(
166 proxy,
167 &service_name,
168 &method_name,
169 initial_request,
170 tx.clone(),
171 )
172 .await;
173
174 tokio::spawn(async move {
175 match result {
176 Ok(_) => {
177 let _ = tx
178 .send(Ok(axum::response::sse::Event::default()
179 .event("complete")
180 .data("Stream completed successfully")))
181 .await;
182 }
183 Err(e) => {
184 let _ = tx
185 .send(Ok(axum::response::sse::Event::default()
186 .event("error")
187 .data(format!("Stream error: {}", e))))
188 .await;
189 }
190 }
191 });
192
193 let stream = ReceiverStream::new(rx).map(|result: Result<axum::response::sse::Event, axum::BoxError>| -> Result<axum::response::sse::Event, axum::BoxError> {
194 match result {
195 Ok(event) => Ok(event),
196 Err(e) => Ok(axum::response::sse::Event::default()
197 .event("error")
198 .data(format!("Stream error: {}", e))),
199 }
200 });
201
202 Sse::new(stream).keep_alive(
203 axum::response::sse::KeepAlive::new()
204 .interval(Duration::from_secs(30))
205 .text("keep-alive"),
206 )
207 }
208
209 async fn handle_grpc_bidirectional_streaming(
211 proxy: Arc<super::MockReflectionProxy>,
212 service_name: &str,
213 method_name: &str,
214 initial_request: Value,
215 tx: tokio::sync::mpsc::Sender<Result<axum::response::sse::Event, axum::BoxError>>,
216 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
217 let registry = proxy.service_registry();
219 let service = registry
220 .get(service_name)
221 .ok_or_else(|| format!("Service '{}' not found", service_name))?;
222
223 let method_info = service
224 .service()
225 .methods
226 .iter()
227 .find(|m| m.name == method_name)
228 .ok_or_else(|| {
229 format!("Method '{}' not found in service '{}'", method_name, service_name)
230 })?;
231 let input_descriptor = registry
232 .descriptor_pool()
233 .get_message_by_name(&method_info.input_type)
234 .ok_or_else(|| format!("Input type '{}' not found", method_info.input_type))?;
235 let _output_descriptor = registry
236 .descriptor_pool()
237 .get_message_by_name(&method_info.output_type)
238 .ok_or_else(|| format!("Output type '{}' not found", method_info.output_type))?;
239
240 let converter =
242 super::converters::ProtobufJsonConverter::new(registry.descriptor_pool().clone());
243
244 let client_messages: Vec<Value> = if let Some(arr) = initial_request.as_array() {
246 arr.clone()
247 } else {
248 vec![initial_request]
249 };
250
251 let mut dynamic_messages = Vec::new();
253 for (i, json_msg) in client_messages.iter().enumerate() {
254 match converter.json_to_protobuf(&input_descriptor, json_msg) {
255 Ok(dynamic_msg) => dynamic_messages.push(dynamic_msg),
256 Err(e) => {
257 warn!("Failed to convert client message {} to protobuf: {}", i, e);
258 let error_msg = StreamingMessage {
260 event_type: "conversion_error".to_string(),
261 data: serde_json::json!({
262 "message": format!("Failed to convert client message {}: {}", i, e),
263 "sequence": i
264 }),
265 metadata: vec![
266 ("error_type".to_string(), "conversion".to_string()),
267 ("sequence".to_string(), i.to_string()),
268 ]
269 .into_iter()
270 .collect(),
271 };
272 if let Ok(json_str) = serde_json::to_string(&error_msg) {
273 let _ = tx
274 .send(Ok(axum::response::sse::Event::default()
275 .event("error")
276 .data(json_str)))
277 .await;
278 }
279 continue;
280 }
281 }
282 }
283
284 if dynamic_messages.is_empty() {
285 return Err("No valid client messages to send".into());
286 }
287
288 let start_msg = StreamingMessage {
290 event_type: "bidirectional_stream_start".to_string(),
291 data: serde_json::json!({
292 "service": service_name,
293 "method": method_name,
294 "client_messages_count": dynamic_messages.len()
295 }),
296 metadata: vec![
297 ("stream_type".to_string(), "bidirectional".to_string()),
298 ("protocol".to_string(), "grpc-web-over-sse".to_string()),
299 ]
300 .into_iter()
301 .collect(),
302 };
303
304 if let Ok(json_str) = serde_json::to_string(&start_msg) {
305 let _ = tx
306 .send(Ok(axum::response::sse::Event::default()
307 .event("stream_start")
308 .data(json_str)))
309 .await;
310 }
311
312 let (client_tx, client_rx) =
314 mpsc::channel::<Result<prost_reflect::DynamicMessage, tonic::Status>>(10);
315
316 let _request = Request::new(ReceiverStream::new(client_rx));
318
319 let client_tx_clone = client_tx.clone();
321 tokio::spawn(async move {
322 for (i, dynamic_msg) in dynamic_messages.into_iter().enumerate() {
323 if client_tx_clone.send(Ok(dynamic_msg)).await.is_err() {
324 warn!("Failed to send client message {} to gRPC stream", i);
325 break;
326 }
327 }
328 drop(client_tx_clone);
330 });
331
332 let method_descriptor = proxy.cache().get_method(service_name, method_name).await?;
334
335 let smart_generator = proxy.smart_generator().clone();
338 let output_descriptor = method_descriptor.output();
339
340 let mock_response = {
342 match smart_generator.lock() {
343 Ok(mut gen) => gen.generate_message(&output_descriptor),
344 Err(e) => {
345 let error_msg = StreamingMessage {
346 event_type: "error".to_string(),
347 data: serde_json::json!({
348 "message": format!("Failed to acquire smart generator lock: {}", e)
349 }),
350 metadata: vec![("error_type".to_string(), "lock".to_string())]
351 .into_iter()
352 .collect(),
353 };
354 if let Ok(json_str) = serde_json::to_string(&error_msg) {
355 let _ = tx
356 .send(Ok(axum::response::sse::Event::default()
357 .event("error")
358 .data(json_str)))
359 .await;
360 }
361 return Ok(());
362 }
363 }
364 };
365
366 match converter.protobuf_to_json(&output_descriptor, &mock_response) {
368 Ok(json_response) => {
369 let response_msg = StreamingMessage {
370 event_type: "grpc_response".to_string(),
371 data: json_response,
372 metadata: vec![
373 ("sequence".to_string(), "1".to_string()),
374 ("message_type".to_string(), "response".to_string()),
375 ]
376 .into_iter()
377 .collect(),
378 };
379
380 if let Ok(json_str) = serde_json::to_string(&response_msg) {
381 let _ = tx
382 .send(Ok(axum::response::sse::Event::default()
383 .event("grpc_response")
384 .data(json_str)))
385 .await;
386 }
387 }
388 Err(e) => {
389 let error_msg = StreamingMessage {
390 event_type: "conversion_error".to_string(),
391 data: serde_json::json!({
392 "message": format!("Failed to convert response to JSON: {}", e)
393 }),
394 metadata: vec![("error_type".to_string(), "conversion".to_string())]
395 .into_iter()
396 .collect(),
397 };
398 if let Ok(json_str) = serde_json::to_string(&error_msg) {
399 let _ = tx
400 .send(Ok(axum::response::sse::Event::default()
401 .event("error")
402 .data(json_str)))
403 .await;
404 }
405 }
406 }
407
408 let end_msg = StreamingMessage {
410 event_type: "bidirectional_stream_end".to_string(),
411 data: serde_json::json!({
412 "message": "Bidirectional streaming session completed",
413 "statistics": {
414 "responses_sent": 1
415 }
416 }),
417 metadata: vec![("session_status".to_string(), "completed".to_string())]
418 .into_iter()
419 .collect(),
420 };
421
422 if let Ok(json_str) = serde_json::to_string(&end_msg) {
423 let _ = tx
424 .send(Ok(axum::response::sse::Event::default().event("stream_end").data(json_str)))
425 .await;
426 }
427
428 Ok(())
429 }
430}
431
/// Stateless helpers that turn bridge error messages into HTTP-facing responses.
#[derive(Debug, Clone, Copy, Default)]
pub struct ErrorHandler;
435impl ErrorHandler {
436 pub fn error_to_status_code(error: &str) -> axum::http::StatusCode {
438 if error.contains("not found") || error.contains("Unknown") {
439 axum::http::StatusCode::NOT_FOUND
440 } else if error.contains("unauthorized") || error.contains("forbidden") {
441 axum::http::StatusCode::FORBIDDEN
442 } else if error.contains("invalid") || error.contains("malformed") {
443 axum::http::StatusCode::BAD_REQUEST
444 } else {
445 axum::http::StatusCode::INTERNAL_SERVER_ERROR
446 }
447 }
448
449 pub fn create_error_response(error: String) -> BridgeResponse<Value> {
451 BridgeResponse {
452 success: false,
453 data: None,
454 error: Some(error),
455 metadata: std::collections::HashMap::new(),
456 }
457 }
458}
459
/// Stateless helpers for validating incoming bridge requests and extracting metadata.
#[derive(Debug, Clone, Copy, Default)]
pub struct RequestProcessor;
463impl RequestProcessor {
464 pub fn validate_request(
466 service_name: &str,
467 method_name: &str,
468 body_size: usize,
469 max_body_size: usize,
470 ) -> Result<(), String> {
471 if service_name.trim().is_empty() {
472 return Err("Service name cannot be empty".to_string());
473 }
474
475 if method_name.trim().is_empty() {
476 return Err("Method name cannot be empty".to_string());
477 }
478
479 if body_size > max_body_size {
480 return Err(format!(
481 "Request body too large: {} bytes (max: {} bytes)",
482 body_size, max_body_size
483 ));
484 }
485
486 Ok(())
488 }
489
490 pub fn extract_metadata_from_headers(
492 headers: &axum::http::HeaderMap,
493 ) -> std::collections::HashMap<String, String> {
494 let mut metadata = std::collections::HashMap::new();
495
496 for (key, value) in headers.iter() {
497 let key_str = key.as_str();
498 if !key_str.starts_with("host")
500 && !key_str.starts_with("content-type")
501 && !key_str.starts_with("content-length")
502 && !key_str.starts_with("user-agent")
503 && !key_str.starts_with("accept")
504 && !key_str.starts_with("authorization")
505 {
506 if let Ok(value_str) = value.to_str() {
507 metadata.insert(key_str.to_string(), value_str.to_string());
508 }
509 }
510 }
511
512 metadata
513 }
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderMap, StatusCode};

    /// Keyword-based status mapping, one table row per message.
    #[test]
    fn test_error_to_status_code() {
        let cases = [
            ("service not found", StatusCode::NOT_FOUND),
            ("unauthorized", StatusCode::FORBIDDEN),
            ("invalid request", StatusCode::BAD_REQUEST),
            ("internal error", StatusCode::INTERNAL_SERVER_ERROR),
            ("Unknown service", StatusCode::NOT_FOUND),
            ("forbidden access", StatusCode::FORBIDDEN),
            ("malformed JSON", StatusCode::BAD_REQUEST),
            ("random error", StatusCode::INTERNAL_SERVER_ERROR),
        ];

        for (message, expected) in cases {
            assert_eq!(
                ErrorHandler::error_to_status_code(message),
                expected,
                "unexpected status for {:?}",
                message
            );
        }
    }

    /// (service, method, body size, max size, should pass)
    #[test]
    fn test_validate_request() {
        let long_name = "a".repeat(1000);
        let cases: [(&str, &str, usize, usize, bool); 8] = [
            ("test", "method", 100, 1000, true),
            ("", "method", 100, 1000, false),
            ("test", "", 100, 1000, false),
            ("test", "method", 2000, 1000, false),
            ("valid_service", "valid_method", 0, 1000, true),
            ("test", "method", 1000, 1000, true),
            ("test", "method", 1001, 1000, false),
            (&long_name, &long_name, 100, 1000, true),
        ];

        for (service, method, size, max, should_pass) in cases {
            let outcome = RequestProcessor::validate_request(service, method, size, max);
            assert_eq!(
                outcome.is_ok(),
                should_pass,
                "validate_request({:?}, {:?}, {}, {}) -> {:?}",
                service,
                method,
                size,
                max,
                outcome
            );
        }
    }

    #[test]
    fn test_extract_metadata_from_headers() {
        let mut headers = HeaderMap::new();
        for (name, value) in [
            ("content-type", "application/json"),
            ("authorization", "Bearer token123"),
            ("x-custom-header", "custom-value"),
            ("x-api-key", "key123"),
            ("user-agent", "test-agent"),
        ] {
            headers.insert(name, value.parse().unwrap());
        }

        let metadata = RequestProcessor::extract_metadata_from_headers(&headers);
        for stripped in ["content-type", "authorization", "user-agent"] {
            assert!(!metadata.contains_key(stripped), "{} should be stripped", stripped);
        }
        assert_eq!(metadata.get("x-custom-header").map(String::as_str), Some("custom-value"));
        assert_eq!(metadata.get("x-api-key").map(String::as_str), Some("key123"));

        // No headers in, no metadata out.
        assert!(RequestProcessor::extract_metadata_from_headers(&HeaderMap::new()).is_empty());

        // Header names are normalised to lower case by HeaderMap.
        let mut upper_case = HeaderMap::new();
        upper_case.insert("X-CUSTOM-HEADER", "value".parse().unwrap());
        let normalised = RequestProcessor::extract_metadata_from_headers(&upper_case);
        assert_eq!(normalised.get("x-custom-header").map(String::as_str), Some("value"));
    }

    #[test]
    fn test_create_error_response() {
        let message = "Test error message";
        let response = ErrorHandler::create_error_response(message.to_string());

        assert!(!response.success);
        assert!(response.data.is_none());
        assert_eq!(response.error.as_deref(), Some(message));
        assert!(response.metadata.is_empty());
    }

    #[tokio::test]
    async fn test_streaming_message_serialization() {
        let original = StreamingMessage {
            event_type: "test_event".to_string(),
            data: serde_json::json!({"key": "value"}),
            metadata: [("sequence", "1"), ("type", "test")]
                .iter()
                .map(|&(k, v)| (k.to_string(), v.to_string()))
                .collect(),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        for needle in ["test_event", "key", "value", "sequence", "1", "type", "test"] {
            assert!(encoded.contains(needle), "{:?} missing from {}", needle, encoded);
        }

        let decoded: StreamingMessage = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.event_type, original.event_type);
        assert_eq!(decoded.data, original.data);
        assert_eq!(decoded.metadata, original.metadata);
    }

    #[tokio::test]
    async fn test_create_sse_stream_basic() {
        let config = HttpBridgeConfig {
            enabled: true,
            base_path: "/api".to_string(),
            enable_cors: false,
            max_request_size: 1024,
            timeout_seconds: 30,
            route_pattern: "/{service}/{method}".to_string(),
        };

        let response = StreamHandler::create_sse_stream(
            config,
            "test_service".to_string(),
            "test_method".to_string(),
        )
        .await
        .into_response();

        assert_eq!(response.status(), StatusCode::OK);
        let content_type = match response.headers().get("content-type") {
            Some(value) => value.to_str().unwrap_or(""),
            None => "",
        };
        assert!(content_type.contains("text/event-stream"));
    }

    #[test]
    fn test_bridge_response_serialization() {
        let original = BridgeResponse::<serde_json::Value> {
            success: true,
            data: Some(serde_json::json!({"result": "success"})),
            error: None,
            metadata: [("service", "test"), ("method", "test")]
                .iter()
                .map(|&(k, v)| (k.to_string(), v.to_string()))
                .collect(),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        for needle in ["success", "true", "result", "service", "method"] {
            assert!(encoded.contains(needle), "{:?} missing from {}", needle, encoded);
        }

        let decoded: BridgeResponse<serde_json::Value> = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.success, original.success);
        assert_eq!(decoded.data, original.data);
        assert_eq!(decoded.error, original.error);
        assert_eq!(decoded.metadata, original.metadata);
    }

    /// Zero limits, whitespace-only names and very large sizes.
    #[test]
    fn test_validate_request_edge_cases() {
        let large_size = usize::MAX / 2;
        let cases = [
            ("test", "method", 0, 0, true),
            ("test", "method", 1, 0, false),
            (" test ", " method ", 100, 1000, true),
            (" ", "method", 100, 1000, false),
            ("test", " ", 100, 1000, false),
            ("test", "method", large_size, usize::MAX, true),
            ("test", "method", large_size + 1, large_size, false),
        ];

        for (service, method, size, max, should_pass) in cases {
            assert_eq!(
                RequestProcessor::validate_request(service, method, size, max).is_ok(),
                should_pass,
                "validate_request({:?}, {:?}, {}, {})",
                service,
                method,
                size,
                max
            );
        }
    }

    #[test]
    fn test_header_extraction_comprehensive() {
        let mut headers = HeaderMap::new();
        for (name, value) in [
            ("host", "localhost:9080"),
            ("content-length", "123"),
            ("accept", "application/json"),
            ("x-forwarded-for", "192.168.1.1"),
            ("x-custom-metadata", "custom-value"),
            ("x-trace-id", "trace-123"),
            ("x-request-id", "req-456"),
        ] {
            headers.insert(name, value.parse().unwrap());
        }

        let metadata = RequestProcessor::extract_metadata_from_headers(&headers);

        for stripped in ["host", "content-length", "accept"] {
            assert!(!metadata.contains_key(stripped), "{} should be stripped", stripped);
        }

        let kept = [
            ("x-forwarded-for", "192.168.1.1"),
            ("x-custom-metadata", "custom-value"),
            ("x-trace-id", "trace-123"),
            ("x-request-id", "req-456"),
        ];
        for (name, value) in kept {
            assert_eq!(metadata.get(name).map(String::as_str), Some(value));
        }

        assert_eq!(metadata.len(), 4);
    }

    #[test]
    fn test_error_response_comprehensive() {
        for message in [
            "Service not found",
            "Method not found",
            "Invalid request body",
            "Authentication failed",
            "Internal server error",
            "Timeout exceeded",
            "Rate limit exceeded",
            "Database connection failed",
        ] {
            let response = ErrorHandler::create_error_response(message.to_string());
            assert!(!response.success);
            assert!(response.data.is_none());
            assert_eq!(response.error.as_deref(), Some(message));
            assert!(response.metadata.is_empty());
        }
    }
}