1use axum::{
4 extract::{Query, State},
5 response::{sse::Event, Sse},
6 routing::get,
7 Router,
8};
9use futures::stream::{self, Stream};
10use serde::{Deserialize, Serialize};
11use std::convert::Infallible;
12use std::time::Duration;
13
14use mockforge_core::templating;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct SSEConfig {
19 pub event_type: Option<String>,
21 pub data_template: String,
23 pub interval_ms: u64,
25 pub max_events: usize,
27 pub initial_delay_ms: u64,
29}
30
31#[derive(Debug, Deserialize)]
33pub struct SSEQueryParams {
34 pub event: Option<String>,
36 pub data: Option<String>,
38 pub interval: Option<u64>,
40 pub max_events: Option<usize>,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct SSEEventData {
47 pub id: Option<String>,
49 pub event: Option<String>,
51 pub data: serde_json::Value,
53 pub retry: Option<u64>,
55 pub timestamp: String,
57}
58
59#[derive(Clone)]
61pub struct SSEStreamManager {
62 config: SSEConfig,
63}
64
65impl SSEStreamManager {
66 pub fn new(config: SSEConfig) -> Self {
68 Self { config }
69 }
70
71 pub fn default_config() -> SSEConfig {
73 SSEConfig {
74 event_type: Some("message".to_string()),
75 data_template: r#"{"message": "{{faker.sentence}}", "timestamp": "{{now}}"}"#
76 .to_string(),
77 interval_ms: 1000,
78 max_events: 0, initial_delay_ms: 0,
80 }
81 }
82
83 pub fn create_stream(
85 &self,
86 query_params: SSEQueryParams,
87 ) -> impl Stream<Item = Result<Event, Infallible>> {
88 let config = self.merge_config_with_params(query_params);
89
90 let event_type = config.event_type.clone();
91 let data_template = config.data_template.clone();
92 let max_events = config.max_events;
93 let interval_duration = Duration::from_millis(config.interval_ms);
94 let initial_delay = config.initial_delay_ms;
95
96 let event_type = event_type.clone();
98 let data_template = data_template.clone();
99
100 stream::unfold(0usize, move |count| {
101 let event_type = event_type.clone();
102 let data_template = data_template.clone();
103 let max_events = max_events;
104 let interval_duration = interval_duration;
105 let initial_delay = initial_delay;
106
107 Box::pin(async move {
108 if max_events > 0 && count >= max_events {
110 return None;
111 }
112
113 if count > 0 || initial_delay > 0 {
115 tokio::time::sleep(interval_duration).await;
116 }
117
118 let event_data = Self::generate_event_data(&data_template, count);
120
121 let mut event = Event::default();
123
124 if let Some(event_type) = &event_type {
125 event = event.event(event_type);
126 }
127
128 let data_json =
130 serde_json::to_string(&event_data).unwrap_or_else(|_| "{}".to_string());
131 event = event.data(data_json);
132
133 event = event.id(count.to_string());
135
136 Some((Ok(event), count + 1))
137 })
138 })
139 }
140
141 fn generate_event_data(template: &str, count: usize) -> SSEEventData {
143 let expanded_data = templating::expand_str(template);
145
146 let data_value = serde_json::from_str(&expanded_data)
148 .unwrap_or(serde_json::Value::String(expanded_data));
149
150 SSEEventData {
151 id: Some(count.to_string()),
152 event: None, data: data_value,
154 retry: None,
155 timestamp: templating::expand_str("{{now}}"),
156 }
157 }
158
159 fn merge_config_with_params(&self, params: SSEQueryParams) -> SSEConfig {
161 let mut config = self.config.clone();
162
163 if let Some(event) = params.event {
164 config.event_type = Some(event);
165 }
166
167 if let Some(data) = params.data {
168 config.data_template = data;
169 }
170
171 if let Some(interval) = params.interval {
172 config.interval_ms = interval;
173 }
174
175 if let Some(max_events) = params.max_events {
176 config.max_events = max_events;
177 }
178
179 config
180 }
181}
182
183pub fn sse_router() -> Router {
185 sse_router_with_config(SSEStreamManager::default_config())
186}
187
188pub fn sse_router_with_config(config: SSEConfig) -> Router {
190 let manager = SSEStreamManager::new(config);
191 Router::new().route("/sse", get(sse_handler)).with_state(manager)
192}
193
194async fn sse_handler(
196 State(manager): State<SSEStreamManager>,
197 Query(params): Query<SSEQueryParams>,
198) -> Sse<impl Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>> {
199 let stream = manager.create_stream(params);
200
201 Sse::new(stream).keep_alive(
202 axum::response::sse::KeepAlive::new()
203 .interval(Duration::from_secs(1))
204 .text("keepalive"),
205 )
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use tokio_stream::StreamExt;

    /// Query parameters that override nothing.
    fn no_overrides() -> SSEQueryParams {
        SSEQueryParams {
            event: None,
            data: None,
            interval: None,
            max_events: None,
        }
    }

    /// Manager wrapping the baseline configuration shared by the merge tests.
    fn merge_fixture() -> SSEStreamManager {
        SSEStreamManager::new(SSEConfig {
            event_type: Some("original".to_string()),
            data_template: "original data".to_string(),
            interval_ms: 1000,
            max_events: 10,
            initial_delay_ms: 0,
        })
    }

    // ---- SSEConfig ----

    #[test]
    fn test_sse_config_default_via_manager() {
        let defaults = SSEStreamManager::default_config();

        assert_eq!(defaults.event_type.as_deref(), Some("message"));
        assert!(defaults.data_template.contains("faker"));
        assert_eq!(defaults.interval_ms, 1000);
        assert_eq!(defaults.max_events, 0);
        assert_eq!(defaults.initial_delay_ms, 0);
    }

    #[test]
    fn test_sse_config_custom() {
        let custom = SSEConfig {
            event_type: Some("custom".to_string()),
            data_template: "test data".to_string(),
            interval_ms: 500,
            max_events: 10,
            initial_delay_ms: 100,
        };

        assert_eq!(custom.event_type.as_deref(), Some("custom"));
        assert_eq!(custom.data_template, "test data");
        assert_eq!(custom.interval_ms, 500);
        assert_eq!(custom.max_events, 10);
        assert_eq!(custom.initial_delay_ms, 100);
    }

    #[test]
    fn test_sse_config_clone() {
        let source = SSEConfig {
            event_type: Some("clone_test".to_string()),
            data_template: "clone data".to_string(),
            interval_ms: 250,
            max_events: 5,
            initial_delay_ms: 50,
        };

        let copy = source.clone();

        assert_eq!(copy.event_type, source.event_type);
        assert_eq!(copy.data_template, source.data_template);
        assert_eq!(copy.interval_ms, source.interval_ms);
        assert_eq!(copy.max_events, source.max_events);
        assert_eq!(copy.initial_delay_ms, source.initial_delay_ms);
    }

    #[test]
    fn test_sse_config_debug() {
        let config = SSEConfig {
            event_type: Some("debug".to_string()),
            data_template: "debug data".to_string(),
            interval_ms: 100,
            max_events: 1,
            initial_delay_ms: 0,
        };

        let rendered = format!("{:?}", config);

        for field in ["event_type", "data_template", "interval_ms"] {
            assert!(rendered.contains(field));
        }
    }

    #[test]
    fn test_sse_config_serialization() {
        let config = SSEConfig {
            event_type: Some("serialize".to_string()),
            data_template: "{\"key\": \"value\"}".to_string(),
            interval_ms: 200,
            max_events: 3,
            initial_delay_ms: 10,
        };

        let encoded = serde_json::to_string(&config).unwrap();
        let decoded: SSEConfig = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.event_type, config.event_type);
        assert_eq!(decoded.interval_ms, config.interval_ms);
    }

    #[test]
    fn test_sse_config_no_event_type() {
        let config = SSEConfig {
            event_type: None,
            data_template: "data".to_string(),
            interval_ms: 100,
            max_events: 1,
            initial_delay_ms: 0,
        };

        assert!(config.event_type.is_none());
    }

    // ---- SSEQueryParams ----

    #[test]
    fn test_sse_query_params_all_none() {
        let params = no_overrides();

        assert!(params.event.is_none());
        assert!(params.data.is_none());
        assert!(params.interval.is_none());
        assert!(params.max_events.is_none());
    }

    #[test]
    fn test_sse_query_params_with_values() {
        let params = SSEQueryParams {
            event: Some("custom_event".to_string()),
            data: Some("{\"custom\": true}".to_string()),
            interval: Some(500),
            max_events: Some(10),
        };

        assert_eq!(params.event.as_deref(), Some("custom_event"));
        assert_eq!(params.data.as_deref(), Some("{\"custom\": true}"));
        assert_eq!(params.interval, Some(500));
        assert_eq!(params.max_events, Some(10));
    }

    #[test]
    fn test_sse_query_params_debug() {
        let params = SSEQueryParams {
            event: Some("test".to_string()),
            interval: Some(100),
            ..no_overrides()
        };

        let rendered = format!("{:?}", params);

        assert!(rendered.contains("event"));
        assert!(rendered.contains("interval"));
    }

    // ---- SSEEventData ----

    #[test]
    fn test_sse_event_data_creation() {
        let payload = SSEEventData {
            id: Some("1".to_string()),
            event: Some("test_event".to_string()),
            data: serde_json::json!({"key": "value"}),
            retry: Some(5000),
            timestamp: "2024-01-01T00:00:00Z".to_string(),
        };

        assert_eq!(payload.id.as_deref(), Some("1"));
        assert_eq!(payload.event.as_deref(), Some("test_event"));
        assert_eq!(payload.data["key"], "value");
        assert_eq!(payload.retry, Some(5000));
    }

    #[test]
    fn test_sse_event_data_clone() {
        let payload = SSEEventData {
            id: Some("2".to_string()),
            event: None,
            data: serde_json::json!({"number": 42}),
            retry: None,
            timestamp: "now".to_string(),
        };

        let copy = payload.clone();

        assert_eq!(copy.id, payload.id);
        assert_eq!(copy.event, payload.event);
        assert_eq!(copy.data, payload.data);
    }

    #[test]
    fn test_sse_event_data_serialization() {
        let payload = SSEEventData {
            id: Some("3".to_string()),
            event: Some("message".to_string()),
            data: serde_json::json!({"text": "hello"}),
            retry: Some(1000),
            timestamp: "test".to_string(),
        };

        let encoded = serde_json::to_string(&payload).unwrap();

        for fragment in ["\"id\":\"3\"", "\"event\":\"message\"", "\"text\":\"hello\""] {
            assert!(encoded.contains(fragment));
        }
    }

    // ---- SSEStreamManager ----

    #[test]
    fn test_sse_stream_manager_creation() {
        let config = SSEConfig {
            event_type: Some("test".to_string()),
            data_template: "data".to_string(),
            interval_ms: 100,
            max_events: 5,
            initial_delay_ms: 0,
        };

        let manager = SSEStreamManager::new(config.clone());

        assert_eq!(manager.config.event_type, config.event_type);
        assert_eq!(manager.config.max_events, config.max_events);
    }

    #[test]
    fn test_merge_config_with_empty_params() {
        let merged = merge_fixture().merge_config_with_params(no_overrides());

        assert_eq!(merged.event_type.as_deref(), Some("original"));
        assert_eq!(merged.data_template, "original data");
        assert_eq!(merged.interval_ms, 1000);
        assert_eq!(merged.max_events, 10);
    }

    #[test]
    fn test_merge_config_with_all_params() {
        let overrides = SSEQueryParams {
            event: Some("overridden".to_string()),
            data: Some("overridden data".to_string()),
            interval: Some(500),
            max_events: Some(5),
        };

        let merged = merge_fixture().merge_config_with_params(overrides);

        assert_eq!(merged.event_type.as_deref(), Some("overridden"));
        assert_eq!(merged.data_template, "overridden data");
        assert_eq!(merged.interval_ms, 500);
        assert_eq!(merged.max_events, 5);
    }

    #[test]
    fn test_merge_config_partial_override() {
        let overrides = SSEQueryParams {
            event: Some("new_event".to_string()),
            interval: Some(2000),
            ..no_overrides()
        };

        let merged = merge_fixture().merge_config_with_params(overrides);

        assert_eq!(merged.event_type.as_deref(), Some("new_event"));
        // Fields without an override keep the base values.
        assert_eq!(merged.data_template, "original data");
        assert_eq!(merged.interval_ms, 2000);
        assert_eq!(merged.max_events, 10);
    }

    // ---- generate_event_data ----

    #[test]
    fn test_generate_event_data_simple_template() {
        let payload = SSEStreamManager::generate_event_data(r#"{"message": "hello"}"#, 0);

        assert_eq!(payload.id.as_deref(), Some("0"));
        assert!(payload.event.is_none());
        assert!(payload.data.is_object());
    }

    #[test]
    fn test_generate_event_data_string_fallback() {
        let payload = SSEStreamManager::generate_event_data("not json at all", 5);

        assert_eq!(payload.id.as_deref(), Some("5"));
        assert!(payload.data.is_string());
    }

    #[test]
    fn test_generate_event_data_incremental_count() {
        let template = r#"{"test": true}"#;

        for count in 0..5 {
            let payload = SSEStreamManager::generate_event_data(template, count);
            assert_eq!(payload.id, Some(count.to_string()));
        }
    }

    #[test]
    fn test_generate_event_data_timestamp_populated() {
        let payload = SSEStreamManager::generate_event_data("{}", 0);

        assert!(!payload.timestamp.is_empty());
    }

    // ---- stream behaviour ----

    #[tokio::test]
    async fn test_sse_stream_generation() {
        let manager = SSEStreamManager::new(SSEConfig {
            event_type: Some("test".to_string()),
            data_template: r#"{"count": {{count}}}"#.to_string(),
            interval_ms: 10,
            max_events: 3,
            initial_delay_ms: 0,
        });

        let mut stream = manager.create_stream(no_overrides());
        let mut received = Vec::new();

        while let Some(Ok(event)) = stream.next().await {
            received.push(event);
            if received.len() >= 3 {
                break;
            }
        }

        assert_eq!(received.len(), 3);
    }

    #[tokio::test]
    async fn test_event_data_generation() {
        let payload = SSEStreamManager::generate_event_data(
            r#"{"message": "test", "timestamp": "{{now}}"}"#,
            1,
        );

        assert_eq!(payload.id.as_deref(), Some("1"));
        assert!(!payload.timestamp.is_empty());
    }

    #[tokio::test]
    async fn test_sse_stream_with_max_events_1() {
        let manager = SSEStreamManager::new(SSEConfig {
            event_type: Some("single".to_string()),
            data_template: r#"{"single": true}"#.to_string(),
            interval_ms: 1,
            max_events: 1,
            initial_delay_ms: 0,
        });

        let mut stream = manager.create_stream(no_overrides());
        let mut seen = 0;

        while let Some(Ok(_)) = stream.next().await {
            seen += 1;
            // Safety valve so a stream that ignores the limit cannot hang the test.
            if seen > 5 {
                break;
            }
        }

        assert_eq!(seen, 1);
    }

    #[tokio::test]
    async fn test_sse_stream_with_query_param_override() {
        let manager = SSEStreamManager::new(SSEConfig {
            event_type: Some("original".to_string()),
            data_template: r#"{"original": true}"#.to_string(),
            interval_ms: 1000,
            max_events: 100,
            initial_delay_ms: 0,
        });
        let overrides = SSEQueryParams {
            interval: Some(1),
            max_events: Some(2),
            ..no_overrides()
        };

        let mut stream = manager.create_stream(overrides);
        let mut seen = 0;

        while let Some(Ok(_)) = stream.next().await {
            seen += 1;
            // Safety valve so a stream that ignores the override cannot hang the test.
            if seen > 10 {
                break;
            }
        }

        assert_eq!(seen, 2);
    }
}