1use std::fmt;
8use std::future::Future;
9use std::pin::Pin;
10use std::time::{Duration, Instant};
11
12use crate::{SimulationResult, TokioNetworkProvider, TokioTaskProvider, TokioTimeProvider};
13
14use super::report::SimulationMetrics;
15use super::topology::WorkloadTopology;
16
17type TokioWorkloadFn = Box<
19 dyn Fn(
20 TokioNetworkProvider,
21 TokioTimeProvider,
22 TokioTaskProvider,
23 WorkloadTopology,
24 ) -> Pin<Box<dyn Future<Output = SimulationResult<SimulationMetrics>>>>,
25>;
26
27pub struct TokioWorkload {
29 name: String,
30 ip_address: String,
31 workload: TokioWorkloadFn,
32}
33
34impl fmt::Debug for TokioWorkload {
35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36 f.debug_struct("TokioWorkload")
37 .field("name", &self.name)
38 .field("ip_address", &self.ip_address)
39 .field("workload", &"<closure>")
40 .finish()
41 }
42}
43
44#[derive(Debug, Clone)]
46pub struct TokioReport {
47 pub workload_results: Vec<(String, SimulationResult<SimulationMetrics>)>,
49 pub total_wall_time: Duration,
51 pub successful: usize,
53 pub failed: usize,
55}
56
57impl TokioReport {
58 pub fn success_rate(&self) -> f64 {
60 let total = self.successful + self.failed;
61 if total == 0 {
62 0.0
63 } else {
64 (self.successful as f64 / total as f64) * 100.0
65 }
66 }
67}
68
69impl fmt::Display for TokioReport {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 writeln!(f, "=== Tokio Execution Report ===")?;
72 writeln!(f, "Total Workloads: {}", self.successful + self.failed)?;
73 writeln!(f, "Successful: {}", self.successful)?;
74 writeln!(f, "Failed: {}", self.failed)?;
75 writeln!(f, "Success Rate: {:.2}%", self.success_rate())?;
76 writeln!(f, "Total Wall Time: {:?}", self.total_wall_time)?;
77 writeln!(f)?;
78
79 for (name, result) in &self.workload_results {
80 match result {
81 Ok(_) => writeln!(f, "✅ {}: SUCCESS", name)?,
82 Err(e) => writeln!(f, "❌ {}: FAILED - {:?}", name, e)?,
83 }
84 }
85
86 Ok(())
87 }
88}
89
90#[derive(Debug)]
92pub struct TokioRunner {
93 workloads: Vec<TokioWorkload>,
94 next_port: u16, }
96
97impl Default for TokioRunner {
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
103impl TokioRunner {
104 pub fn new() -> Self {
106 Self {
107 workloads: Vec::new(),
108 next_port: 9001, }
110 }
111
112 pub fn register_workload<S, F, Fut>(mut self, name: S, workload: F) -> Self
118 where
119 S: Into<String>,
120 F: Fn(TokioNetworkProvider, TokioTimeProvider, TokioTaskProvider, WorkloadTopology) -> Fut
121 + 'static,
122 Fut: Future<Output = SimulationResult<SimulationMetrics>> + 'static,
123 {
124 let ip_address = format!("127.0.0.1:{}", self.next_port);
126 self.next_port += 1;
127
128 let boxed_workload = Box::new(move |provider, time_provider, task_provider, topology| {
129 let fut = workload(provider, time_provider, task_provider, topology);
130 Box::pin(fut) as Pin<Box<dyn Future<Output = SimulationResult<SimulationMetrics>>>>
131 });
132
133 self.workloads.push(TokioWorkload {
134 name: name.into(),
135 ip_address,
136 workload: boxed_workload,
137 });
138 self
139 }
140
141 pub async fn run(self) -> TokioReport {
143 if self.workloads.is_empty() {
144 return TokioReport {
145 workload_results: Vec::new(),
146 total_wall_time: Duration::ZERO,
147 successful: 0,
148 failed: 0,
149 };
150 }
151
152 let start_time = Instant::now();
153
154 let shutdown_signal = tokio_util::sync::CancellationToken::new();
156
157 let all_ips: Vec<String> = self
159 .workloads
160 .iter()
161 .map(|w| w.ip_address.clone())
162 .collect();
163
164 let mut workload_results = Vec::new();
165 let mut successful = 0;
166 let mut failed = 0;
167
168 if self.workloads.len() == 1 {
170 let workload = &self.workloads[0];
172 let my_ip = workload.ip_address.clone();
173 let peer_ips = all_ips.iter().filter(|ip| *ip != &my_ip).cloned().collect();
174 let peer_names = self
175 .workloads
176 .iter()
177 .filter(|w| w.ip_address != my_ip)
178 .map(|w| w.name.clone())
179 .collect();
180 let topology = WorkloadTopology {
181 my_ip,
182 peer_ips,
183 peer_names,
184 shutdown_signal: shutdown_signal.clone(),
185 };
186
187 let provider = TokioNetworkProvider::new();
188 let time_provider = TokioTimeProvider::new();
189 let task_provider = TokioTaskProvider;
190
191 let result =
192 (workload.workload)(provider, time_provider, task_provider, topology).await;
193
194 if result.is_ok() {
196 shutdown_signal.cancel();
197 }
198
199 match result {
200 Ok(_) => successful += 1,
201 Err(_) => failed += 1,
202 }
203 workload_results.push((workload.name.clone(), result));
204 } else {
205 let mut handles = Vec::new();
207
208 for workload in &self.workloads {
209 let my_ip = workload.ip_address.clone();
210 let peer_ips = all_ips.iter().filter(|ip| *ip != &my_ip).cloned().collect();
211 let peer_names = self
212 .workloads
213 .iter()
214 .filter(|w| w.ip_address != my_ip)
215 .map(|w| w.name.clone())
216 .collect();
217 let topology = WorkloadTopology {
218 my_ip,
219 peer_ips,
220 peer_names,
221 shutdown_signal: shutdown_signal.clone(),
222 };
223
224 let provider = TokioNetworkProvider::new();
225 let time_provider = TokioTimeProvider::new();
226 let task_provider = TokioTaskProvider;
227
228 let handle = tokio::task::spawn_local((workload.workload)(
229 provider,
230 time_provider,
231 task_provider,
232 topology,
233 ));
234 handles.push((workload.name.clone(), handle));
235 }
236
237 let mut pending_futures: Vec<_> = handles
239 .into_iter()
240 .map(|(name, handle)| {
241 Box::pin(async move {
242 let result = handle.await;
243 (name, result)
244 })
245 })
246 .collect();
247
248 let mut first_success_triggered = false;
249
250 while !pending_futures.is_empty() {
252 let (completed_result, _index, remaining_futures) =
253 futures::future::select_all(pending_futures).await;
254
255 pending_futures = remaining_futures;
256
257 let (name, handle_result) = completed_result;
258 let result = match handle_result {
259 Ok(workload_result) => workload_result,
260 Err(_) => Err(crate::SimulationError::InvalidState(
261 "Task panicked".to_string(),
262 )),
263 };
264
265 if !first_success_triggered && result.is_ok() {
267 tracing::debug!(
268 "TokioRunner: Workload '{}' completed successfully, triggering shutdown",
269 name
270 );
271 shutdown_signal.cancel();
272 first_success_triggered = true;
273 }
274
275 match result {
276 Ok(_) => successful += 1,
277 Err(_) => failed += 1,
278 }
279 workload_results.push((name, result));
280 }
281 }
282
283 let total_wall_time = start_time.elapsed();
284
285 TokioReport {
286 workload_results,
287 total_wall_time,
288 successful,
289 failed,
290 }
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives `future` to completion on a fresh single-threaded local runtime.
    ///
    /// A local runtime is needed because `TokioRunner::run` uses `spawn_local` when more than
    /// one workload is registered.
    fn block_on_local<F: std::future::Future>(future: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build_local(Default::default())
            .expect("Failed to build local runtime")
            .block_on(future)
    }

    #[test]
    fn test_tokio_runner_empty() {
        let report = block_on_local(async { TokioRunner::new().run().await });

        assert_eq!(report.successful, 0);
        assert_eq!(report.failed, 0);
        assert_eq!(report.success_rate(), 0.0);
    }

    #[test]
    fn test_tokio_runner_single_workload() {
        let report = block_on_local(async {
            TokioRunner::new()
                .register_workload(
                    "test_workload",
                    |_provider, _time_provider, _task_provider, _topology| async {
                        Ok(SimulationMetrics::default())
                    },
                )
                .run()
                .await
        });

        assert_eq!(report.successful, 1);
        assert_eq!(report.failed, 0);
        assert_eq!(report.success_rate(), 100.0);
        assert!(report.total_wall_time > Duration::ZERO);
    }

    #[test]
    fn test_tokio_runner_multiple_workloads() {
        let report = block_on_local(async {
            TokioRunner::new()
                .register_workload(
                    "workload1",
                    |_provider, _time_provider, _task_provider, _topology| async {
                        Ok(SimulationMetrics::default())
                    },
                )
                .register_workload(
                    "workload2",
                    |_provider, _time_provider, _task_provider, _topology| async {
                        Ok(SimulationMetrics::default())
                    },
                )
                .run()
                .await
        });

        assert_eq!(report.successful, 2);
        assert_eq!(report.failed, 0);
        assert_eq!(report.success_rate(), 100.0);
    }

    #[test]
    fn test_tokio_runner_mixed_outcomes() {
        let report = block_on_local(async {
            TokioRunner::new()
                .register_workload(
                    "ok",
                    |_provider, _time_provider, _task_provider, _topology| async {
                        Ok(SimulationMetrics::default())
                    },
                )
                .register_workload(
                    "boom",
                    |_provider, _time_provider, _task_provider, _topology| async {
                        Err(crate::SimulationError::InvalidState("boom".to_string()))
                    },
                )
                .run()
                .await
        });

        assert_eq!(report.successful, 1);
        assert_eq!(report.failed, 1);
        assert_eq!(report.workload_results.len(), 2);
        assert_eq!(report.success_rate(), 50.0);

        let rendered = report.to_string();
        assert!(rendered.contains("Successful: 1"));
        assert!(rendered.contains("Failed: 1"));
        assert!(rendered.contains("Success Rate: 50.00%"));
    }

    #[test]
    fn test_tokio_runner_assigns_sequential_addresses() {
        let report = block_on_local(async {
            TokioRunner::new()
                .register_workload(
                    "first",
                    |_provider, _time_provider, _task_provider, topology| async move {
                        if topology.my_ip == "127.0.0.1:9001" {
                            Ok(SimulationMetrics::default())
                        } else {
                            Err(crate::SimulationError::InvalidState(topology.my_ip.clone()))
                        }
                    },
                )
                .register_workload(
                    "second",
                    |_provider, _time_provider, _task_provider, topology| async move {
                        if topology.my_ip == "127.0.0.1:9002" {
                            Ok(SimulationMetrics::default())
                        } else {
                            Err(crate::SimulationError::InvalidState(topology.my_ip.clone()))
                        }
                    },
                )
                .run()
                .await
        });

        assert_eq!(report.successful, 2);
        assert_eq!(report.failed, 0);
    }
}