1use std::fmt;
8use std::future::Future;
9use std::pin::Pin;
10use std::time::{Duration, Instant};
11
12use crate::{SimulationResult, TokioNetworkProvider, TokioTaskProvider, TokioTimeProvider};
13
14use super::report::SimulationMetrics;
15use super::topology::WorkloadTopology;
16
17type TokioWorkloadFn = Box<
19 dyn Fn(
20 TokioNetworkProvider,
21 TokioTimeProvider,
22 TokioTaskProvider,
23 WorkloadTopology,
24 ) -> Pin<Box<dyn Future<Output = SimulationResult<SimulationMetrics>>>>,
25>;
26
27pub struct TokioWorkload {
29 name: String,
30 ip_address: String,
31 workload: TokioWorkloadFn,
32}
33
34impl fmt::Debug for TokioWorkload {
35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36 f.debug_struct("TokioWorkload")
37 .field("name", &self.name)
38 .field("ip_address", &self.ip_address)
39 .field("workload", &"<closure>")
40 .finish()
41 }
42}
43
44#[derive(Debug, Clone)]
46pub struct TokioReport {
47 pub workload_results: Vec<(String, SimulationResult<SimulationMetrics>)>,
49 pub total_wall_time: Duration,
51 pub successful: usize,
53 pub failed: usize,
55}
56
57impl TokioReport {
58 pub fn success_rate(&self) -> f64 {
60 let total = self.successful + self.failed;
61 if total == 0 {
62 0.0
63 } else {
64 (self.successful as f64 / total as f64) * 100.0
65 }
66 }
67}
68
69impl fmt::Display for TokioReport {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 writeln!(f, "=== Tokio Execution Report ===")?;
72 writeln!(f, "Total Workloads: {}", self.successful + self.failed)?;
73 writeln!(f, "Successful: {}", self.successful)?;
74 writeln!(f, "Failed: {}", self.failed)?;
75 writeln!(f, "Success Rate: {:.2}%", self.success_rate())?;
76 writeln!(f, "Total Wall Time: {:?}", self.total_wall_time)?;
77 writeln!(f)?;
78
79 for (name, result) in &self.workload_results {
80 match result {
81 Ok(_) => writeln!(f, "✅ {}: SUCCESS", name)?,
82 Err(e) => writeln!(f, "❌ {}: FAILED - {:?}", name, e)?,
83 }
84 }
85
86 Ok(())
87 }
88}
89
90#[derive(Debug)]
92pub struct TokioRunner {
93 workloads: Vec<TokioWorkload>,
94 next_port: u16, }
96
97impl Default for TokioRunner {
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
103impl TokioRunner {
104 pub fn new() -> Self {
106 Self {
107 workloads: Vec::new(),
108 next_port: 9001, }
110 }
111
112 pub fn register_workload<S, F, Fut>(mut self, name: S, workload: F) -> Self
118 where
119 S: Into<String>,
120 F: Fn(TokioNetworkProvider, TokioTimeProvider, TokioTaskProvider, WorkloadTopology) -> Fut
121 + 'static,
122 Fut: Future<Output = SimulationResult<SimulationMetrics>> + 'static,
123 {
124 let ip_address = format!("127.0.0.1:{}", self.next_port);
126 self.next_port += 1;
127
128 let boxed_workload = Box::new(move |provider, time_provider, task_provider, topology| {
129 let fut = workload(provider, time_provider, task_provider, topology);
130 Box::pin(fut) as Pin<Box<dyn Future<Output = SimulationResult<SimulationMetrics>>>>
131 });
132
133 self.workloads.push(TokioWorkload {
134 name: name.into(),
135 ip_address,
136 workload: boxed_workload,
137 });
138 self
139 }
140
141 pub async fn run(self) -> TokioReport {
143 if self.workloads.is_empty() {
144 return TokioReport {
145 workload_results: Vec::new(),
146 total_wall_time: Duration::ZERO,
147 successful: 0,
148 failed: 0,
149 };
150 }
151
152 let start_time = Instant::now();
153
154 let shutdown_signal = tokio_util::sync::CancellationToken::new();
156
157 let all_ips: Vec<String> = self
159 .workloads
160 .iter()
161 .map(|w| w.ip_address.clone())
162 .collect();
163
164 let mut workload_results = Vec::new();
165 let mut successful = 0;
166 let mut failed = 0;
167
168 if self.workloads.len() == 1 {
170 let workload = &self.workloads[0];
172 let my_ip = workload.ip_address.clone();
173 let peer_ips = all_ips.iter().filter(|ip| *ip != &my_ip).cloned().collect();
174 let peer_names = self
175 .workloads
176 .iter()
177 .filter(|w| w.ip_address != my_ip)
178 .map(|w| w.name.clone())
179 .collect();
180 let topology = WorkloadTopology {
181 my_ip,
182 peer_ips,
183 peer_names,
184 shutdown_signal: shutdown_signal.clone(),
185 state_registry: crate::StateRegistry::new(),
186 };
187
188 let provider = TokioNetworkProvider::new();
189 let time_provider = TokioTimeProvider::new();
190 let task_provider = TokioTaskProvider;
191
192 let result =
193 (workload.workload)(provider, time_provider, task_provider, topology).await;
194
195 if result.is_ok() {
197 shutdown_signal.cancel();
198 }
199
200 match result {
201 Ok(_) => successful += 1,
202 Err(_) => failed += 1,
203 }
204 workload_results.push((workload.name.clone(), result));
205 } else {
206 let mut handles = Vec::new();
208
209 for workload in &self.workloads {
210 let my_ip = workload.ip_address.clone();
211 let peer_ips = all_ips.iter().filter(|ip| *ip != &my_ip).cloned().collect();
212 let peer_names = self
213 .workloads
214 .iter()
215 .filter(|w| w.ip_address != my_ip)
216 .map(|w| w.name.clone())
217 .collect();
218 let topology = WorkloadTopology {
219 my_ip,
220 peer_ips,
221 peer_names,
222 shutdown_signal: shutdown_signal.clone(),
223 state_registry: crate::StateRegistry::new(),
224 };
225
226 let provider = TokioNetworkProvider::new();
227 let time_provider = TokioTimeProvider::new();
228 let task_provider = TokioTaskProvider;
229
230 let handle = tokio::task::spawn_local((workload.workload)(
231 provider,
232 time_provider,
233 task_provider,
234 topology,
235 ));
236 handles.push((workload.name.clone(), handle));
237 }
238
239 let mut pending_futures: Vec<_> = handles
241 .into_iter()
242 .map(|(name, handle)| {
243 Box::pin(async move {
244 let result = handle.await;
245 (name, result)
246 })
247 })
248 .collect();
249
250 let mut first_success_triggered = false;
251
252 while !pending_futures.is_empty() {
254 let (completed_result, _index, remaining_futures) =
255 futures::future::select_all(pending_futures).await;
256
257 pending_futures = remaining_futures;
258
259 let (name, handle_result) = completed_result;
260 let result = match handle_result {
261 Ok(workload_result) => workload_result,
262 Err(_) => Err(crate::SimulationError::InvalidState(
263 "Task panicked".to_string(),
264 )),
265 };
266
267 if !first_success_triggered && result.is_ok() {
269 tracing::debug!(
270 "TokioRunner: Workload '{}' completed successfully, triggering shutdown",
271 name
272 );
273 shutdown_signal.cancel();
274 first_success_triggered = true;
275 }
276
277 match result {
278 Ok(_) => successful += 1,
279 Err(_) => failed += 1,
280 }
281 workload_results.push((name, result));
282 }
283 }
284
285 let total_wall_time = start_time.elapsed();
286
287 TokioReport {
288 workload_results,
289 total_wall_time,
290 successful,
291 failed,
292 }
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives `future` to completion on a fresh single-threaded local runtime.
    /// A local runtime is required because multi-workload runs use
    /// `tokio::task::spawn_local`.
    fn run_locally<F: std::future::Future>(future: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build_local(Default::default())
            .expect("Failed to build local runtime")
            .block_on(future)
    }

    #[test]
    fn test_tokio_runner_empty() {
        let report = run_locally(async { TokioRunner::new().run().await });

        assert_eq!(report.successful, 0);
        assert_eq!(report.failed, 0);
        assert_eq!(report.success_rate(), 0.0);
    }

    #[test]
    fn test_tokio_runner_single_workload() {
        let report = run_locally(async {
            TokioRunner::new()
                .register_workload(
                    "test_workload",
                    |_provider, _time_provider, _task_provider, _topology| async {
                        Ok(SimulationMetrics::default())
                    },
                )
                .run()
                .await
        });

        assert_eq!(report.successful, 1);
        assert_eq!(report.failed, 0);
        assert_eq!(report.success_rate(), 100.0);
        assert!(report.total_wall_time > Duration::ZERO);
    }

    #[test]
    fn test_tokio_runner_multiple_workloads() {
        let report = run_locally(async {
            TokioRunner::new()
                .register_workload(
                    "workload1",
                    |_provider, _time_provider, _task_provider, _topology| async {
                        Ok(SimulationMetrics::default())
                    },
                )
                .register_workload(
                    "workload2",
                    |_provider, _time_provider, _task_provider, _topology| async {
                        Ok(SimulationMetrics::default())
                    },
                )
                .run()
                .await
        });

        assert_eq!(report.successful, 2);
        assert_eq!(report.failed, 0);
        assert_eq!(report.success_rate(), 100.0);
    }

    /// A failing workload is counted as a failure, does not affect the
    /// successful one, and is rendered as FAILED in the Display output.
    #[test]
    fn test_tokio_runner_mixed_success_and_failure() {
        let report = run_locally(async {
            TokioRunner::new()
                .register_workload(
                    "ok",
                    |_provider, _time_provider, _task_provider, _topology| async {
                        Ok(SimulationMetrics::default())
                    },
                )
                .register_workload(
                    "failing",
                    |_provider, _time_provider, _task_provider, _topology| async {
                        Err(crate::SimulationError::InvalidState("boom".to_string()))
                    },
                )
                .run()
                .await
        });

        assert_eq!(report.successful, 1);
        assert_eq!(report.failed, 1);
        assert_eq!(report.success_rate(), 50.0);
        assert_eq!(report.workload_results.len(), 2);

        let rendered = report.to_string();
        assert!(rendered.contains("✅ ok: SUCCESS"));
        assert!(rendered.contains("❌ failing: FAILED"));
    }
}