1use super::{BridgeResponse, HttpBridgeConfig};
7use axum::response::{IntoResponse, Sse};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::sync::mpsc;
13use tokio_stream::{wrappers::ReceiverStream, StreamExt};
14use tonic::Request;
15use tracing::warn;
16
/// Stateless namespace for the Server-Sent Events streaming entry points of the HTTP bridge.
///
/// All functionality is exposed as associated functions; the type holds no data.
#[derive(Debug, Clone, Copy, Default)]
pub struct StreamHandler;
19
20#[derive(Debug, Serialize, Deserialize)]
22pub struct StreamingMessage {
23 pub event_type: String,
25 pub data: Value,
27 pub metadata: std::collections::HashMap<String, String>,
29}
30
31impl StreamHandler {
32 pub async fn create_sse_stream(
34 _config: HttpBridgeConfig,
35 service_name: String,
36 method_name: String,
37 ) -> impl IntoResponse {
38 let (tx, rx) = mpsc::channel::<Result<axum::response::sse::Event, axum::BoxError>>(32);
39
40 tokio::spawn(async move {
42 let init_msg = StreamingMessage {
44 event_type: "stream_init".to_string(),
45 data: serde_json::json!({
46 "service": service_name,
47 "method": method_name,
48 "message": "Stream initialized for bidirectional communication"
49 }),
50 metadata: std::collections::HashMap::new(),
51 };
52
53 if let Ok(json_str) = serde_json::to_string(&init_msg) {
54 let _ = tx
55 .send(Ok(axum::response::sse::Event::default().event("message").data(json_str)))
56 .await;
57 }
58
59 let mut counter = 0;
61 while counter < 10 {
62 tokio::time::sleep(Duration::from_millis(500)).await;
63
64 let stream_msg = StreamingMessage {
65 event_type: "data".to_string(),
66 data: serde_json::json!({
67 "counter": counter,
68 "message": format!("Streaming message #{}", counter),
69 "timestamp": chrono::Utc::now().to_rfc3339()
70 }),
71 metadata: vec![("sequence".to_string(), counter.to_string())]
72 .into_iter()
73 .collect(),
74 };
75
76 if let Ok(json_str) = serde_json::to_string(&stream_msg) {
77 let event_type = if counter % 3 == 0 {
78 "heartbeat"
79 } else {
80 "data"
81 };
82 let _ = tx
83 .send(Ok(axum::response::sse::Event::default()
84 .event(event_type)
85 .data(json_str)))
86 .await;
87 }
88
89 counter += 1;
90
91 if counter == 7 {
93 let error_msg = StreamingMessage {
94 event_type: "error".to_string(),
95 data: serde_json::json!({
96 "error": "Simulated network error",
97 "code": "NETWORK_ERROR"
98 }),
99 metadata: vec![("error_code".to_string(), "123".to_string())]
100 .into_iter()
101 .collect(),
102 };
103
104 if let Ok(json_str) = serde_json::to_string(&error_msg) {
105 let _ = tx
106 .send(Ok(axum::response::sse::Event::default()
107 .event("error")
108 .data(json_str)))
109 .await;
110 }
111 }
112 }
113
114 let complete_msg = StreamingMessage {
116 event_type: "stream_complete".to_string(),
117 data: serde_json::json!({
118 "message": "Streaming session completed",
119 "total_messages": counter
120 }),
121 metadata: vec![("session_id".to_string(), "demo-123".to_string())]
122 .into_iter()
123 .collect(),
124 };
125
126 if let Ok(json_str) = serde_json::to_string(&complete_msg) {
127 let _ = tx
128 .send(Ok(axum::response::sse::Event::default()
129 .event("complete")
130 .data(json_str)))
131 .await;
132 }
133 });
134
135 let stream = ReceiverStream::new(rx).map(|result: Result<axum::response::sse::Event, axum::BoxError>| -> Result<axum::response::sse::Event, axum::BoxError> {
136 match result {
137 Ok(event) => Ok(event),
138 Err(e) => Ok(axum::response::sse::Event::default()
139 .event("error")
140 .data(format!("Stream error: {}", e))),
141 }
142 });
143
144 Sse::new(stream).keep_alive(
145 axum::response::sse::KeepAlive::new()
146 .interval(Duration::from_secs(30))
147 .text("keep-alive"),
148 )
149 }
150
151 pub async fn create_grpc_stream_stream(
153 proxy: Arc<super::MockReflectionProxy>,
154 service_name: &str,
155 method_name: &str,
156 initial_request: Value,
157 ) -> impl IntoResponse {
158 let (tx, rx) = mpsc::channel(32);
159
160 let service_name = service_name.to_string();
162 let method_name = method_name.to_string();
163
164 let result = Self::handle_grpc_bidirectional_streaming(
165 proxy,
166 &service_name,
167 &method_name,
168 initial_request,
169 tx.clone(),
170 )
171 .await;
172
173 tokio::spawn(async move {
174 match result {
175 Ok(_) => {
176 let _ = tx
177 .send(Ok(axum::response::sse::Event::default()
178 .event("complete")
179 .data("Stream completed successfully")))
180 .await;
181 }
182 Err(e) => {
183 let _ = tx
184 .send(Ok(axum::response::sse::Event::default()
185 .event("error")
186 .data(format!("Stream error: {}", e))))
187 .await;
188 }
189 }
190 });
191
192 let stream = ReceiverStream::new(rx).map(|result: Result<axum::response::sse::Event, axum::BoxError>| -> Result<axum::response::sse::Event, axum::BoxError> {
193 match result {
194 Ok(event) => Ok(event),
195 Err(e) => Ok(axum::response::sse::Event::default()
196 .event("error")
197 .data(format!("Stream error: {}", e))),
198 }
199 });
200
201 Sse::new(stream).keep_alive(
202 axum::response::sse::KeepAlive::new()
203 .interval(Duration::from_secs(30))
204 .text("keep-alive"),
205 )
206 }
207
208 async fn handle_grpc_bidirectional_streaming(
210 proxy: Arc<super::MockReflectionProxy>,
211 service_name: &str,
212 method_name: &str,
213 initial_request: Value,
214 tx: mpsc::Sender<Result<axum::response::sse::Event, axum::BoxError>>,
215 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
216 let registry = proxy.service_registry();
218 let service = registry
219 .get(service_name)
220 .ok_or_else(|| format!("Service '{}' not found", service_name))?;
221
222 let method_info = service
223 .service()
224 .methods
225 .iter()
226 .find(|m| m.name == method_name)
227 .ok_or_else(|| {
228 format!("Method '{}' not found in service '{}'", method_name, service_name)
229 })?;
230 let input_descriptor = registry
231 .descriptor_pool()
232 .get_message_by_name(&method_info.input_type)
233 .ok_or_else(|| format!("Input type '{}' not found", method_info.input_type))?;
234 let _output_descriptor = registry
235 .descriptor_pool()
236 .get_message_by_name(&method_info.output_type)
237 .ok_or_else(|| format!("Output type '{}' not found", method_info.output_type))?;
238
239 let converter =
241 super::converters::ProtobufJsonConverter::new(registry.descriptor_pool().clone());
242
243 let client_messages: Vec<Value> = if let Some(arr) = initial_request.as_array() {
245 arr.clone()
246 } else {
247 vec![initial_request]
248 };
249
250 let mut dynamic_messages = Vec::new();
252 for (i, json_msg) in client_messages.iter().enumerate() {
253 match converter.json_to_protobuf(&input_descriptor, json_msg) {
254 Ok(dynamic_msg) => dynamic_messages.push(dynamic_msg),
255 Err(e) => {
256 warn!("Failed to convert client message {} to protobuf: {}", i, e);
257 let error_msg = StreamingMessage {
259 event_type: "conversion_error".to_string(),
260 data: serde_json::json!({
261 "message": format!("Failed to convert client message {}: {}", i, e),
262 "sequence": i
263 }),
264 metadata: vec![
265 ("error_type".to_string(), "conversion".to_string()),
266 ("sequence".to_string(), i.to_string()),
267 ]
268 .into_iter()
269 .collect(),
270 };
271 if let Ok(json_str) = serde_json::to_string(&error_msg) {
272 let _ = tx
273 .send(Ok(axum::response::sse::Event::default()
274 .event("error")
275 .data(json_str)))
276 .await;
277 }
278 continue;
279 }
280 }
281 }
282
283 if dynamic_messages.is_empty() {
284 return Err("No valid client messages to send".into());
285 }
286
287 let start_msg = StreamingMessage {
289 event_type: "bidirectional_stream_start".to_string(),
290 data: serde_json::json!({
291 "service": service_name,
292 "method": method_name,
293 "client_messages_count": dynamic_messages.len()
294 }),
295 metadata: vec![
296 ("stream_type".to_string(), "bidirectional".to_string()),
297 ("protocol".to_string(), "grpc-web-over-sse".to_string()),
298 ]
299 .into_iter()
300 .collect(),
301 };
302
303 if let Ok(json_str) = serde_json::to_string(&start_msg) {
304 let _ = tx
305 .send(Ok(axum::response::sse::Event::default()
306 .event("stream_start")
307 .data(json_str)))
308 .await;
309 }
310
311 let (client_tx, client_rx) =
313 mpsc::channel::<Result<prost_reflect::DynamicMessage, tonic::Status>>(10);
314
315 let _request = Request::new(ReceiverStream::new(client_rx));
317
318 let client_tx_clone = client_tx.clone();
320 tokio::spawn(async move {
321 for (i, dynamic_msg) in dynamic_messages.into_iter().enumerate() {
322 if client_tx_clone.send(Ok(dynamic_msg)).await.is_err() {
323 warn!("Failed to send client message {} to gRPC stream", i);
324 break;
325 }
326 }
327 drop(client_tx_clone);
329 });
330
331 let method_descriptor = proxy.cache().get_method(service_name, method_name).await?;
333
334 let smart_generator = proxy.smart_generator().clone();
337 let output_descriptor = method_descriptor.output();
338
339 let mock_response = {
341 match smart_generator.lock() {
342 Ok(mut gen) => gen.generate_message(&output_descriptor),
343 Err(e) => {
344 let error_msg = StreamingMessage {
345 event_type: "error".to_string(),
346 data: serde_json::json!({
347 "message": format!("Failed to acquire smart generator lock: {}", e)
348 }),
349 metadata: vec![("error_type".to_string(), "lock".to_string())]
350 .into_iter()
351 .collect(),
352 };
353 if let Ok(json_str) = serde_json::to_string(&error_msg) {
354 let _ = tx
355 .send(Ok(axum::response::sse::Event::default()
356 .event("error")
357 .data(json_str)))
358 .await;
359 }
360 return Ok(());
361 }
362 }
363 };
364
365 match converter.protobuf_to_json(&output_descriptor, &mock_response) {
367 Ok(json_response) => {
368 let response_msg = StreamingMessage {
369 event_type: "grpc_response".to_string(),
370 data: json_response,
371 metadata: vec![
372 ("sequence".to_string(), "1".to_string()),
373 ("message_type".to_string(), "response".to_string()),
374 ]
375 .into_iter()
376 .collect(),
377 };
378
379 if let Ok(json_str) = serde_json::to_string(&response_msg) {
380 let _ = tx
381 .send(Ok(axum::response::sse::Event::default()
382 .event("grpc_response")
383 .data(json_str)))
384 .await;
385 }
386 }
387 Err(e) => {
388 let error_msg = StreamingMessage {
389 event_type: "conversion_error".to_string(),
390 data: serde_json::json!({
391 "message": format!("Failed to convert response to JSON: {}", e)
392 }),
393 metadata: vec![("error_type".to_string(), "conversion".to_string())]
394 .into_iter()
395 .collect(),
396 };
397 if let Ok(json_str) = serde_json::to_string(&error_msg) {
398 let _ = tx
399 .send(Ok(axum::response::sse::Event::default()
400 .event("error")
401 .data(json_str)))
402 .await;
403 }
404 }
405 }
406
407 let end_msg = StreamingMessage {
409 event_type: "bidirectional_stream_end".to_string(),
410 data: serde_json::json!({
411 "message": "Bidirectional streaming session completed",
412 "statistics": {
413 "responses_sent": 1
414 }
415 }),
416 metadata: vec![("session_status".to_string(), "completed".to_string())]
417 .into_iter()
418 .collect(),
419 };
420
421 if let Ok(json_str) = serde_json::to_string(&end_msg) {
422 let _ = tx
423 .send(Ok(axum::response::sse::Event::default().event("stream_end").data(json_str)))
424 .await;
425 }
426
427 Ok(())
428 }
429}
430
/// Stateless namespace for turning error messages into HTTP-level error information.
///
/// All functionality is exposed as associated functions; the type holds no data.
#[derive(Debug, Clone, Copy, Default)]
pub struct ErrorHandler;
433
434impl ErrorHandler {
435 pub fn error_to_status_code(error: &str) -> http::StatusCode {
437 if error.contains("not found") || error.contains("Unknown") {
438 http::StatusCode::NOT_FOUND
439 } else if error.contains("unauthorized") || error.contains("forbidden") {
440 http::StatusCode::FORBIDDEN
441 } else if error.contains("invalid") || error.contains("malformed") {
442 http::StatusCode::BAD_REQUEST
443 } else {
444 http::StatusCode::INTERNAL_SERVER_ERROR
445 }
446 }
447
448 pub fn create_error_response(error: String) -> BridgeResponse<Value> {
450 BridgeResponse {
451 success: false,
452 data: None,
453 error: Some(error),
454 metadata: std::collections::HashMap::new(),
455 }
456 }
457}
458
/// Stateless namespace for validating bridge requests and extracting metadata from them.
///
/// All functionality is exposed as associated functions; the type holds no data.
#[derive(Debug, Clone, Copy, Default)]
pub struct RequestProcessor;
461
462impl RequestProcessor {
463 pub fn validate_request(
465 service_name: &str,
466 method_name: &str,
467 body_size: usize,
468 max_body_size: usize,
469 ) -> Result<(), String> {
470 if service_name.trim().is_empty() {
471 return Err("Service name cannot be empty".to_string());
472 }
473
474 if method_name.trim().is_empty() {
475 return Err("Method name cannot be empty".to_string());
476 }
477
478 if body_size > max_body_size {
479 return Err(format!(
480 "Request body too large: {} bytes (max: {} bytes)",
481 body_size, max_body_size
482 ));
483 }
484
485 Ok(())
487 }
488
489 pub fn extract_metadata_from_headers(
491 headers: &http::HeaderMap,
492 ) -> std::collections::HashMap<String, String> {
493 let mut metadata = std::collections::HashMap::new();
494
495 for (key, value) in headers.iter() {
496 let key_str = key.as_str();
497 if !key_str.starts_with("host")
499 && !key_str.starts_with("content-type")
500 && !key_str.starts_with("content-length")
501 && !key_str.starts_with("user-agent")
502 && !key_str.starts_with("accept")
503 && !key_str.starts_with("authorization")
504 {
505 if let Ok(value_str) = value.to_str() {
506 metadata.insert(key_str.to_string(), value_str.to_string());
507 }
508 }
509 }
510
511 metadata
512 }
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderMap;

    /// Every keyword branch of the status mapping, plus the fallback.
    #[test]
    fn test_error_to_status_code() {
        use axum::http::StatusCode;

        let expectations = [
            ("service not found", StatusCode::NOT_FOUND),
            ("unauthorized", StatusCode::FORBIDDEN),
            ("invalid request", StatusCode::BAD_REQUEST),
            ("internal error", StatusCode::INTERNAL_SERVER_ERROR),
            ("Unknown service", StatusCode::NOT_FOUND),
            ("forbidden access", StatusCode::FORBIDDEN),
            ("malformed JSON", StatusCode::BAD_REQUEST),
            ("random error", StatusCode::INTERNAL_SERVER_ERROR),
        ];

        for (message, expected) in expectations {
            assert_eq!(ErrorHandler::error_to_status_code(message), expected);
        }
    }

    /// Blank names and oversized bodies are rejected; boundary sizes are accepted.
    #[test]
    fn test_validate_request() {
        let long_name = "a".repeat(1000);

        let accepted = [
            ("test", "method", 100, 1000),
            ("valid_service", "valid_method", 0, 1000),
            ("test", "method", 1000, 1000),
            (long_name.as_str(), long_name.as_str(), 100, 1000),
        ];
        for (service, method, size, limit) in accepted {
            assert!(RequestProcessor::validate_request(service, method, size, limit).is_ok());
        }

        let rejected = [
            ("", "method", 100, 1000),
            ("test", "", 100, 1000),
            ("test", "method", 2000, 1000),
            ("test", "method", 1001, 1000),
        ];
        for (service, method, size, limit) in rejected {
            assert!(RequestProcessor::validate_request(service, method, size, limit).is_err());
        }
    }

    /// Transport headers are dropped, custom headers are kept under lowercase names.
    #[test]
    fn test_extract_metadata_from_headers() {
        let mut headers = HeaderMap::new();
        for (name, value) in [
            ("content-type", "application/json"),
            ("authorization", "Bearer token123"),
            ("x-custom-header", "custom-value"),
            ("x-api-key", "key123"),
            ("user-agent", "test-agent"),
        ] {
            headers.insert(name, value.parse().unwrap());
        }

        let metadata = RequestProcessor::extract_metadata_from_headers(&headers);

        for dropped in ["content-type", "authorization", "user-agent"] {
            assert!(!metadata.contains_key(dropped));
        }
        assert_eq!(
            metadata.get("x-custom-header"),
            Some(&"custom-value".to_string())
        );
        assert_eq!(metadata.get("x-api-key"), Some(&"key123".to_string()));

        // No headers in, no metadata out.
        let nothing = RequestProcessor::extract_metadata_from_headers(&HeaderMap::new());
        assert!(nothing.is_empty());

        // Header names are looked up in lowercase regardless of how they were inserted.
        let mut case_headers = HeaderMap::new();
        case_headers.insert("X-CUSTOM-HEADER", "value".parse().unwrap());
        let case_metadata = RequestProcessor::extract_metadata_from_headers(&case_headers);
        assert_eq!(
            case_metadata.get("x-custom-header"),
            Some(&"value".to_string())
        );
    }

    /// An error response carries only the message.
    #[test]
    fn test_create_error_response() {
        let error_message = "Test error message";
        let response = ErrorHandler::create_error_response(error_message.to_string());

        assert!(!response.success);
        assert!(response.data.is_none());
        assert_eq!(response.error, Some(error_message.to_string()));
        assert!(response.metadata.is_empty());
    }

    /// A streaming message survives a JSON round trip unchanged.
    #[tokio::test]
    async fn test_streaming_message_serialization() {
        let message = StreamingMessage {
            event_type: "test_event".to_string(),
            data: serde_json::json!({"key": "value"}),
            metadata: vec![
                ("sequence".to_string(), "1".to_string()),
                ("type".to_string(), "test".to_string()),
            ]
            .into_iter()
            .collect(),
        };

        let json_str = serde_json::to_string(&message).unwrap();
        for fragment in ["test_event", "key", "value", "sequence", "1", "type", "test"] {
            assert!(json_str.contains(fragment));
        }

        let deserialized: StreamingMessage = serde_json::from_str(&json_str).unwrap();
        assert_eq!(deserialized.event_type, message.event_type);
        assert_eq!(deserialized.data, message.data);
        assert_eq!(deserialized.metadata, message.metadata);
    }

    /// The demo stream answers `200 OK` with an event-stream content type.
    #[tokio::test]
    async fn test_create_sse_stream_basic() {
        let config = HttpBridgeConfig {
            enabled: true,
            base_path: "/api".to_string(),
            enable_cors: false,
            max_request_size: 1024,
            timeout_seconds: 30,
            route_pattern: "/{service}/{method}".to_string(),
        };

        let sse_response = StreamHandler::create_sse_stream(
            config,
            "test_service".to_string(),
            "test_method".to_string(),
        )
        .await
        .into_response();

        assert_eq!(sse_response.status(), axum::http::StatusCode::OK);

        let content_type = sse_response
            .headers()
            .get("content-type")
            .and_then(|h| h.to_str().ok())
            .unwrap_or("");
        assert!(content_type.contains("text/event-stream"));
    }

    /// A successful bridge response survives a JSON round trip unchanged.
    #[test]
    fn test_bridge_response_serialization() {
        let response = BridgeResponse::<serde_json::Value> {
            success: true,
            data: Some(serde_json::json!({"result": "success"})),
            error: None,
            metadata: vec![
                ("service".to_string(), "test".to_string()),
                ("method".to_string(), "test".to_string()),
            ]
            .into_iter()
            .collect(),
        };

        let json_str = serde_json::to_string(&response).unwrap();
        for fragment in ["success", "true", "result", "success", "service", "method"] {
            assert!(json_str.contains(fragment));
        }

        let deserialized: BridgeResponse<serde_json::Value> =
            serde_json::from_str(&json_str).unwrap();
        assert_eq!(deserialized.success, response.success);
        assert_eq!(deserialized.data, response.data);
        assert_eq!(deserialized.error, response.error);
        assert_eq!(deserialized.metadata, response.metadata);
    }

    /// Zero limits, whitespace-only names and very large sizes.
    #[test]
    fn test_validate_request_edge_cases() {
        let large_size = usize::MAX / 2;

        let accepted = [
            ("test", "method", 0, 0),
            (" test ", " method ", 100, 1000),
            ("test", "method", large_size, usize::MAX),
        ];
        for (service, method, size, limit) in accepted {
            assert!(RequestProcessor::validate_request(service, method, size, limit).is_ok());
        }

        let rejected = [
            ("test", "method", 1, 0),
            (" ", "method", 100, 1000),
            ("test", " ", 100, 1000),
            ("test", "method", large_size + 1, large_size),
        ];
        for (service, method, size, limit) in rejected {
            assert!(RequestProcessor::validate_request(service, method, size, limit).is_err());
        }
    }

    /// Only the four custom headers remain after extraction.
    #[test]
    fn test_header_extraction_comprehensive() {
        let mut headers = HeaderMap::new();
        for (name, value) in [
            ("host", "localhost:9080"),
            ("content-length", "123"),
            ("accept", "application/json"),
            ("x-forwarded-for", "192.168.1.1"),
            ("x-custom-metadata", "custom-value"),
            ("x-trace-id", "trace-123"),
            ("x-request-id", "req-456"),
        ] {
            headers.insert(name, value.parse().unwrap());
        }

        let metadata = RequestProcessor::extract_metadata_from_headers(&headers);

        for dropped in ["host", "content-length", "accept"] {
            assert!(!metadata.contains_key(dropped));
        }

        for (name, value) in [
            ("x-forwarded-for", "192.168.1.1"),
            ("x-custom-metadata", "custom-value"),
            ("x-trace-id", "trace-123"),
            ("x-request-id", "req-456"),
        ] {
            assert_eq!(metadata.get(name), Some(&value.to_string()));
        }

        assert_eq!(metadata.len(), 4);
    }

    /// The error response shape does not depend on the message text.
    #[test]
    fn test_error_response_comprehensive() {
        let test_errors = [
            "Service not found",
            "Method not found",
            "Invalid request body",
            "Authentication failed",
            "Internal server error",
            "Timeout exceeded",
            "Rate limit exceeded",
            "Database connection failed",
        ];

        for error_msg in test_errors {
            let response = ErrorHandler::create_error_response(error_msg.to_string());
            assert!(!response.success);
            assert!(response.data.is_none());
            assert_eq!(response.error, Some(error_msg.to_string()));
            assert!(response.metadata.is_empty());
        }
    }
}