1use crate::measure::pin_to_cpu;
10use crate::{Bencher, BenchmarkDef, run_benchmark_loop};
11use fluxbench_ipc::{
12 BenchmarkConfig, FailureKind, FrameReader, FrameWriter, SampleRingBuffer, SupervisorCommand,
13 WorkerCapabilities, WorkerMessage,
14};
15use std::sync::atomic::{AtomicBool, Ordering};
16
17#[cfg(unix)]
18use std::os::unix::io::FromRawFd;
19
/// Raised by the SIGTERM handler to ask the worker loop to stop.
static SHUTDOWN_REQUESTED: AtomicBool = AtomicBool::new(false);

/// Reports whether a shutdown has been requested via SIGTERM.
///
/// A `Relaxed` load is enough: the flag is a standalone signal and does not
/// publish any other data that would need ordering.
pub fn shutdown_requested() -> bool {
    AtomicBool::load(&SHUTDOWN_REQUESTED, Ordering::Relaxed)
}
27
28#[cfg(unix)]
31fn install_sigterm_handler() {
32 unsafe {
33 let mut sa: libc::sigaction = std::mem::zeroed();
34 sa.sa_sigaction = sigterm_handler as *const () as usize;
35 sa.sa_flags = libc::SA_RESTART;
36 libc::sigemptyset(&mut sa.sa_mask);
37 libc::sigaction(libc::SIGTERM, &sa, std::ptr::null_mut());
38 }
39}
40
41#[cfg(unix)]
42extern "C" fn sigterm_handler(_sig: libc::c_int) {
43 SHUTDOWN_REQUESTED.store(true, Ordering::Relaxed);
44}
45
/// Non-unix fallback: no SIGTERM hook is installed, so only supervisor commands
/// (or channel errors) end the worker loop.
#[cfg(not(unix))]
fn install_sigterm_handler() {
    // Intentionally empty.
}
49
/// Channel the worker uses to talk to its supervisor.
enum IpcTransport {
    /// Raw descriptors taken from the `FLUX_IPC_FD` environment variable.
    #[cfg(unix)]
    Fds {
        /// Descriptor supervisor commands are read from.
        read_fd: i32,
        /// Descriptor worker messages are written to.
        write_fd: i32,
    },
    /// Commands on stdin, messages on stdout.
    Stdio,
}

/// Parses a `FLUX_IPC_FD` value of the form `<read_fd>,<write_fd>`.
///
/// Whitespace around each number is ignored. Returns `None` when the value is
/// malformed (wrong number of parts, non-numeric) or when either descriptor is
/// negative: a negative number never names an open descriptor, and handing one to
/// `File::from_raw_fd` would violate its ownership contract.
#[cfg(unix)]
fn parse_ipc_fds(val: &str) -> Option<(i32, i32)> {
    let (read, write) = val.split_once(',')?;
    let read_fd = read.trim().parse::<i32>().ok()?;
    let write_fd = write.trim().parse::<i32>().ok()?;
    (read_fd >= 0 && write_fd >= 0).then_some((read_fd, write_fd))
}

/// Chooses the IPC transport from the environment.
///
/// On unix, a valid `FLUX_IPC_FD` selects [`IpcTransport::Fds`]; an invalid value
/// prints a warning and falls back to stdio. Without the variable (or on non-unix
/// targets) the transport is [`IpcTransport::Stdio`].
fn detect_transport() -> IpcTransport {
    #[cfg(unix)]
    if let Ok(val) = std::env::var("FLUX_IPC_FD") {
        if let Some((read_fd, write_fd)) = parse_ipc_fds(&val) {
            return IpcTransport::Fds { read_fd, write_fd };
        }
        eprintln!(
            "fluxbench: warning: invalid FLUX_IPC_FD={val:?} (expected format: <read_fd>,<write_fd>), falling back to stdio"
        );
    }
    IpcTransport::Stdio
}
78
79pub struct WorkerMain {
81 reader: FrameReader<Box<dyn std::io::Read>>,
82 writer: FrameWriter<Box<dyn std::io::Write>>,
83}
84
85impl WorkerMain {
86 pub fn new() -> Self {
88 match detect_transport() {
89 #[cfg(unix)]
90 IpcTransport::Fds { read_fd, write_fd } => {
91 let read_file = unsafe { std::fs::File::from_raw_fd(read_fd) };
92 let write_file = unsafe { std::fs::File::from_raw_fd(write_fd) };
93 Self {
94 reader: FrameReader::new(Box::new(read_file) as Box<dyn std::io::Read>),
95 writer: FrameWriter::new(Box::new(write_file) as Box<dyn std::io::Write>),
96 }
97 }
98 IpcTransport::Stdio => Self {
99 reader: FrameReader::new(Box::new(std::io::stdin()) as Box<dyn std::io::Read>),
100 writer: FrameWriter::new(Box::new(std::io::stdout()) as Box<dyn std::io::Write>),
101 },
102 }
103 }
104
105 pub fn run(&mut self) -> Result<(), Box<dyn std::error::Error>> {
107 install_sigterm_handler();
108
109 self.writer
111 .write(&WorkerMessage::Hello(WorkerCapabilities::default()))?;
112
113 let _ = pin_to_cpu(0);
115
116 loop {
118 if shutdown_requested() {
119 break;
120 }
121
122 let command: SupervisorCommand = self.reader.read()?;
123
124 match command {
125 SupervisorCommand::Run { bench_id, config } => {
126 self.run_benchmark(&bench_id, &config)?;
127 if shutdown_requested() {
128 break;
129 }
130 }
131 SupervisorCommand::Abort => {
132 break;
133 }
134 SupervisorCommand::Shutdown => {
135 break;
136 }
137 SupervisorCommand::Ping => {}
138 }
139 }
140
141 Ok(())
142 }
143
144 fn run_benchmark(
146 &mut self,
147 bench_id: &str,
148 config: &BenchmarkConfig,
149 ) -> Result<(), Box<dyn std::error::Error>> {
150 let bench = inventory::iter::<BenchmarkDef>
152 .into_iter()
153 .find(|b| b.id == bench_id);
154
155 let bench = match bench {
156 Some(b) => b,
157 None => {
158 self.writer.write(&WorkerMessage::Failure {
159 kind: FailureKind::Unknown,
160 message: format!("Benchmark not found: {}", bench_id),
161 backtrace: None,
162 })?;
163 return Ok(());
164 }
165 };
166
167 let mut ring_buffer = SampleRingBuffer::new(bench_id);
169
170 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
172 let bencher = Bencher::new(config.track_allocations);
173
174 run_benchmark_loop(
175 bencher,
176 |b| (bench.runner_fn)(b),
177 config.warmup_time_ns,
178 config.measurement_time_ns,
179 config.min_iterations,
180 config.max_iterations,
181 )
182 }));
183
184 match result {
185 Ok(bench_result) => {
186 for sample in bench_result.samples {
188 if let Some(batch) = ring_buffer.push(sample) {
189 self.writer.write(&WorkerMessage::SampleBatch(batch))?;
190 }
191 }
192
193 if let Some(batch) = ring_buffer.flush_final() {
195 self.writer.write(&WorkerMessage::SampleBatch(batch))?;
196 }
197
198 self.writer.write(&WorkerMessage::Complete {
200 total_iterations: bench_result.iterations,
201 total_duration_nanos: bench_result.total_time_ns,
202 })?;
203 }
204 Err(panic) => {
205 let message = if let Some(s) = panic.downcast_ref::<&str>() {
206 s.to_string()
207 } else if let Some(s) = panic.downcast_ref::<String>() {
208 s.clone()
209 } else {
210 "Unknown panic".to_string()
211 };
212
213 if let Some(batch) = ring_buffer.flush_final() {
215 let _ = self.writer.write(&WorkerMessage::SampleBatch(batch));
216 }
217
218 let backtrace = std::backtrace::Backtrace::capture();
219 let backtrace_str = match backtrace.status() {
220 std::backtrace::BacktraceStatus::Captured => Some(backtrace.to_string()),
221 _ => None,
222 };
223
224 self.writer.write(&WorkerMessage::Failure {
225 kind: FailureKind::Panic,
226 message,
227 backtrace: backtrace_str,
228 })?;
229 }
230 }
231
232 Ok(())
233 }
234}
235
236impl Default for WorkerMain {
237 fn default() -> Self {
238 Self::new()
239 }
240}