1use axum::{
4 extract::{Query, State},
5 response::{sse::Event, Sse},
6 routing::get,
7 Router,
8};
9use futures::stream::{self, Stream};
10use serde::{Deserialize, Serialize};
11use std::convert::Infallible;
12use std::time::Duration;
13
14use mockforge_core::templating;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct SSEConfig {
19 pub event_type: Option<String>,
21 pub data_template: String,
23 pub interval_ms: u64,
25 pub max_events: usize,
27 pub initial_delay_ms: u64,
29}
30
31#[derive(Debug, Deserialize)]
33pub struct SSEQueryParams {
34 pub event: Option<String>,
36 pub data: Option<String>,
38 pub interval: Option<u64>,
40 pub max_events: Option<usize>,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct SSEEventData {
47 pub id: Option<String>,
49 pub event: Option<String>,
51 pub data: serde_json::Value,
53 pub retry: Option<u64>,
55 pub timestamp: String,
57}
58
59#[derive(Clone)]
61pub struct SSEStreamManager {
62 config: SSEConfig,
63}
64
65impl SSEStreamManager {
66 pub fn new(config: SSEConfig) -> Self {
68 Self { config }
69 }
70
71 pub fn default_config() -> SSEConfig {
73 SSEConfig {
74 event_type: Some("message".to_string()),
75 data_template: r#"{"message": "{{faker.sentence}}", "timestamp": "{{now}}"}"#
76 .to_string(),
77 interval_ms: 1000,
78 max_events: 0, initial_delay_ms: 0,
80 }
81 }
82
83 pub fn create_stream(
85 &self,
86 query_params: SSEQueryParams,
87 ) -> impl Stream<Item = Result<Event, Infallible>> {
88 let config = self.merge_config_with_params(query_params);
89
90 let event_type = config.event_type.clone();
91 let data_template = config.data_template.clone();
92 let max_events = config.max_events;
93 let interval_duration = Duration::from_millis(config.interval_ms);
94 let initial_delay = config.initial_delay_ms;
95
96 let event_type = event_type.clone();
98 let data_template = data_template.clone();
99
100 stream::unfold(0usize, move |count| {
101 let event_type = event_type.clone();
102 let data_template = data_template.clone();
103 let max_events = max_events;
104 let interval_duration = interval_duration;
105 let initial_delay = initial_delay;
106
107 Box::pin(async move {
108 if max_events > 0 && count >= max_events {
110 return None;
111 }
112
113 if count > 0 || initial_delay > 0 {
115 tokio::time::sleep(interval_duration).await;
116 }
117
118 let event_data = Self::generate_event_data(&data_template, count);
120
121 let mut event = Event::default();
123
124 if let Some(event_type) = &event_type {
125 event = event.event(event_type);
126 }
127
128 let data_json =
130 serde_json::to_string(&event_data).unwrap_or_else(|_| "{}".to_string());
131 event = event.data(data_json);
132
133 event = event.id(count.to_string());
135
136 Some((Ok(event), count + 1))
137 })
138 })
139 }
140
141 fn generate_event_data(template: &str, count: usize) -> SSEEventData {
143 let expanded_data = templating::expand_str(template);
145
146 let data_value = serde_json::from_str(&expanded_data)
148 .unwrap_or(serde_json::Value::String(expanded_data));
149
150 SSEEventData {
151 id: Some(count.to_string()),
152 event: None, data: data_value,
154 retry: None,
155 timestamp: templating::expand_str("{{now}}"),
156 }
157 }
158
159 fn merge_config_with_params(&self, params: SSEQueryParams) -> SSEConfig {
161 let mut config = self.config.clone();
162
163 if let Some(event) = params.event {
164 config.event_type = Some(event);
165 }
166
167 if let Some(data) = params.data {
168 config.data_template = data;
169 }
170
171 if let Some(interval) = params.interval {
172 config.interval_ms = interval;
173 }
174
175 if let Some(max_events) = params.max_events {
176 config.max_events = max_events;
177 }
178
179 config
180 }
181}
182
183pub fn sse_router() -> Router {
185 sse_router_with_config(SSEStreamManager::default_config())
186}
187
188pub fn sse_router_with_config(config: SSEConfig) -> Router {
190 let manager = SSEStreamManager::new(config);
191 Router::new().route("/sse", get(sse_handler)).with_state(manager)
192}
193
194async fn sse_handler(
196 State(manager): State<SSEStreamManager>,
197 Query(params): Query<SSEQueryParams>,
198) -> Sse<impl Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>> {
199 let stream = manager.create_stream(params);
200
201 Sse::new(stream).keep_alive(
202 axum::response::sse::KeepAlive::new()
203 .interval(Duration::from_secs(1))
204 .text("keepalive"),
205 )
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use tokio_stream::StreamExt;

    /// Query parameters that override nothing.
    fn no_overrides() -> SSEQueryParams {
        SSEQueryParams {
            event: None,
            data: None,
            interval: None,
            max_events: None,
        }
    }

    #[tokio::test]
    async fn test_sse_stream_generation() {
        let config = SSEConfig {
            event_type: Some("test".to_string()),
            data_template: r#"{"count": {{count}}}"#.to_string(),
            interval_ms: 10,
            max_events: 3,
            initial_delay_ms: 0,
        };

        let manager = SSEStreamManager::new(config);

        // Drain the stream without breaking early: it must end on its own once
        // `max_events` is reached. The timeout turns a never-ending stream into
        // a failure instead of a hung test.
        let events = tokio::time::timeout(
            Duration::from_secs(5),
            manager.create_stream(no_overrides()).collect::<Vec<_>>(),
        )
        .await
        .expect("stream did not terminate after max_events");

        assert_eq!(events.len(), 3);
        assert!(events.iter().all(Result::is_ok));
    }

    #[test]
    fn test_query_params_override_config() {
        let manager = SSEStreamManager::new(SSEStreamManager::default_config());
        let params = SSEQueryParams {
            event: Some("custom".to_string()),
            data: Some("plain".to_string()),
            interval: Some(5),
            max_events: Some(2),
        };

        let merged = manager.merge_config_with_params(params);

        assert_eq!(merged.event_type.as_deref(), Some("custom"));
        assert_eq!(merged.data_template, "plain");
        assert_eq!(merged.interval_ms, 5);
        assert_eq!(merged.max_events, 2);
        // Fields without a query parameter keep their configured value.
        assert_eq!(merged.initial_delay_ms, 0);
    }

    #[tokio::test]
    async fn test_event_data_generation() {
        let template = r#"{"message": "test", "timestamp": "{{now}}"}"#;
        let event_data = SSEStreamManager::generate_event_data(template, 1);

        assert_eq!(event_data.id, Some("1".to_string()));
        assert!(!event_data.timestamp.is_empty());
    }
}
255}