1use self::ffi::{VTable, SELF_VTABLE};
4use crate::{Benchmark, ErasedSampler, Error};
5use anyhow::Context;
6use libloading::{Library, Symbol};
7use std::{
8 ffi::{c_char, c_ulonglong},
9 path::Path,
10 ptr::{addr_of, null},
11 slice, str,
12 sync::mpsc::{channel, Receiver, Sender},
13 thread::{self, JoinHandle},
14};
15
/// Position of a benchmark function in the vtable's enumeration order — the index that is
/// handed to the vtable's `select` to make that function current.
pub type FunctionIdx = usize;
17
/// A benchmark function discovered through the SPI: its reported name and its index.
#[derive(Debug, Clone)]
pub struct NamedFunction {
    /// Name as reported by the vtable's `get_test_name` while this function was selected.
    pub name: String,

    /// Index to pass to `select` to make this function the current one.
    pub idx: FunctionIdx,
}
25
/// Handle over a benchmark vtable (either the one compiled into this executable or one
/// loaded from a library), driven synchronously or through a worker thread.
pub(crate) struct Spi {
    /// Benchmarks enumerated from the vtable when the handle was created.
    tests: Vec<NamedFunction>,
    /// Function index last recorded by `Spi::select`; `None` until one is recorded.
    selected_function: Option<FunctionIdx>,
    /// Execution state for the chosen `SpiModeKind`.
    mode: SpiMode,
}
31
/// Selects how an SPI handle drives the benchmark vtable.
///
/// `Debug` is derived so the mode can be printed in diagnostics and compared with
/// `assert_eq!`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum SpiModeKind {
    /// Every vtable call is made directly on the calling thread.
    Synchronous,

    /// The vtable is moved to a dedicated worker thread and calls are forwarded to it over
    /// channels; a measurement can be started (`measure`) and collected later
    /// (`read_sample`) without the start call waiting for the result.
    Asynchronous,
}
44
/// Runtime state of an `Spi` handle for the chosen `SpiModeKind`.
enum SpiMode {
    /// The handle owns the vtable and calls it on the current thread.
    Synchronous {
        vt: Box<dyn VTable>,
        /// Value the vtable's `run` returned during the most recent `Spi::measure`;
        /// handed back by `Spi::read_sample` (0 before any measurement).
        last_measurement: u64,
    },
    /// The vtable lives on a worker thread running `spi_worker`; the handle talks to it
    /// through a request/reply channel pair.
    Asynchronous {
        /// Worker thread handle; taken and joined when the `Spi` is dropped.
        worker: Option<JoinHandle<()>>,
        /// Requests to the worker.
        tx: Sender<SpiRequest>,
        /// Replies from the worker, in request order.
        rx: Receiver<SpiReply>,
    },
}
56
57impl Spi {
58 pub(crate) fn for_library(path: impl AsRef<Path>, mode: SpiModeKind) -> Spi {
59 let lib = unsafe { Library::new(path.as_ref()) }
60 .with_context(|| format!("Unable to open library: {}", path.as_ref().display()))
61 .unwrap();
62 spi_handle_for_vtable(ffi::LibraryVTable::new(lib).unwrap(), mode)
63 }
64
65 pub(crate) fn for_self(mode: SpiModeKind) -> Option<Spi> {
66 unsafe { SELF_VTABLE.take() }.map(|vt| spi_handle_for_vtable(vt, mode))
67 }
68
69 pub(crate) fn tests(&self) -> &[NamedFunction] {
70 &self.tests
71 }
72
73 pub(crate) fn lookup(&self, name: &str) -> Option<&NamedFunction> {
74 self.tests.iter().find(|f| f.name == name)
75 }
76
77 pub(crate) fn run(&mut self, iterations: usize) -> u64 {
78 match &self.mode {
79 SpiMode::Synchronous { vt, .. } => vt.run(iterations as c_ulonglong),
80 SpiMode::Asynchronous { worker: _, tx, rx } => {
81 tx.send(SpiRequest::Run { iterations }).unwrap();
82 match rx.recv().unwrap() {
83 SpiReply::Run(time) => time,
84 r => panic!("Unexpected response: {:?}", r),
85 }
86 }
87 }
88 }
89
90 pub(crate) fn measure(&mut self, iterations: usize) {
91 match &mut self.mode {
92 SpiMode::Synchronous {
93 vt,
94 last_measurement,
95 } => {
96 *last_measurement = vt.run(iterations as c_ulonglong);
97 }
98 SpiMode::Asynchronous { tx, .. } => {
99 tx.send(SpiRequest::Measure { iterations }).unwrap();
100 }
101 }
102 }
103
104 pub(crate) fn read_sample(&mut self) -> u64 {
105 match &self.mode {
106 SpiMode::Synchronous {
107 last_measurement, ..
108 } => *last_measurement,
109 SpiMode::Asynchronous { rx, .. } => match rx.recv().unwrap() {
110 SpiReply::Measure(time) => time,
111 r => panic!("Unexpected response: {:?}", r),
112 },
113 }
114 }
115
116 pub(crate) fn estimate_iterations(&mut self, time_ms: u32) -> usize {
117 match &self.mode {
118 SpiMode::Synchronous { vt, .. } => vt.estimate_iterations(time_ms) as usize,
119 SpiMode::Asynchronous { tx, rx, .. } => {
120 tx.send(SpiRequest::EstimateIterations { time_ms }).unwrap();
121 match rx.recv().unwrap() {
122 SpiReply::EstimateIterations(iters) => iters,
123 r => panic!("Unexpected response: {:?}", r),
124 }
125 }
126 }
127 }
128
129 pub(crate) fn prepare_state(&mut self, seed: u64) {
130 match &self.mode {
131 SpiMode::Synchronous { vt, .. } => vt.prepare_state(seed),
132 SpiMode::Asynchronous { tx, rx, .. } => {
133 tx.send(SpiRequest::PrepareState { seed }).unwrap();
134 match rx.recv().unwrap() {
135 SpiReply::PrepareState => {}
136 r => panic!("Unexpected response: {:?}", r),
137 }
138 }
139 }
140 }
141
142 pub(crate) fn select(&mut self, idx: usize) {
143 match &self.mode {
144 SpiMode::Synchronous { vt, .. } => vt.select(idx as c_ulonglong),
145 SpiMode::Asynchronous { tx, rx, .. } => {
146 tx.send(SpiRequest::Select { idx }).unwrap();
147 match rx.recv().unwrap() {
148 SpiReply::Select => self.selected_function = Some(idx),
149 r => panic!("Unexpected response: {:?}", r),
150 }
151 }
152 }
153 }
154}
155
156impl Drop for Spi {
157 fn drop(&mut self) {
158 if let SpiMode::Asynchronous { worker, tx, .. } = &mut self.mode {
159 if let Some(worker) = worker.take() {
160 tx.send(SpiRequest::Shutdown).unwrap();
161 worker.join().unwrap();
162 }
163 }
164 }
165}
166
167fn spi_worker(vt: &dyn VTable, rx: Receiver<SpiRequest>, tx: Sender<SpiReply>) {
168 use SpiReply as Rp;
169 use SpiRequest as Rq;
170
171 while let Ok(req) = rx.recv() {
172 let reply = match req {
173 Rq::EstimateIterations { time_ms } => {
174 Rp::EstimateIterations(vt.estimate_iterations(time_ms) as usize)
175 }
176 Rq::PrepareState { seed } => {
177 vt.prepare_state(seed);
178 Rp::PrepareState
179 }
180 Rq::Select { idx } => {
181 vt.select(idx as c_ulonglong);
182 Rp::Select
183 }
184 Rq::Run { iterations } => Rp::Run(vt.run(iterations as c_ulonglong)),
185 Rq::Measure { iterations } => Rp::Measure(vt.run(iterations as c_ulonglong)),
186 Rq::Shutdown => break,
187 };
188 tx.send(reply).unwrap();
189 }
190}
191
192fn spi_handle_for_vtable(vtable: impl VTable + Send + 'static, mode: SpiModeKind) -> Spi {
193 vtable.init();
194 let tests = enumerate_tests(&vtable).unwrap();
195
196 match mode {
197 SpiModeKind::Asynchronous => {
198 let (request_tx, request_rx) = channel();
199 let (reply_tx, reply_rx) = channel();
200 let worker = thread::spawn(move || {
201 spi_worker(&vtable, request_rx, reply_tx);
202 });
203
204 Spi {
205 tests,
206 selected_function: None,
207 mode: SpiMode::Asynchronous {
208 worker: Some(worker),
209 tx: request_tx,
210 rx: reply_rx,
211 },
212 }
213 }
214 SpiModeKind::Synchronous => Spi {
215 tests,
216 selected_function: None,
217 mode: SpiMode::Synchronous {
218 vt: Box::new(vtable),
219 last_measurement: 0,
220 },
221 },
222 }
223}
224
225fn enumerate_tests(vt: &dyn VTable) -> Result<Vec<NamedFunction>, Error> {
226 let mut tests = vec![];
227 for idx in 0..vt.count() {
228 vt.select(idx);
229
230 let mut length = 0;
231 let name_ptr: *const c_char = null();
232 vt.get_test_name(addr_of!(name_ptr) as _, &mut length);
233 if length == 0 {
234 continue;
235 }
236 let slice = unsafe { slice::from_raw_parts(name_ptr as *const u8, length as usize) };
237 let name = str::from_utf8(slice)
238 .map_err(Error::InvalidFFIString)?
239 .to_string();
240 let idx = idx as usize;
241 tests.push(NamedFunction { name, idx });
242 }
243 Ok(tests)
244}
245
/// A request sent from an asynchronous-mode `Spi` handle to its worker thread.
///
/// Every variant except `Shutdown` is answered by the `SpiReply` variant of the same name.
/// `Debug` is derived to match `SpiReply`, so requests can be printed when diagnosing
/// protocol mismatches.
#[derive(Debug)]
enum SpiRequest {
    /// Forward to the vtable's `estimate_iterations(time_ms)`.
    EstimateIterations { time_ms: u32 },
    /// Forward to the vtable's `prepare_state(seed)`.
    PrepareState { seed: u64 },
    /// Forward to the vtable's `select(idx)`.
    Select { idx: usize },
    /// Run the selected function; the result is returned right away by `Spi::run`.
    Run { iterations: usize },
    /// Run the selected function; the result is collected later by `Spi::read_sample`.
    Measure { iterations: usize },
    /// Stop the worker loop; no reply is sent.
    Shutdown,
}
254
/// A reply sent from the worker thread back to an asynchronous-mode `Spi` handle; each
/// variant answers the `SpiRequest` variant of the same name.
#[derive(Debug)]
enum SpiReply {
    /// The vtable's `estimate_iterations` result.
    EstimateIterations(usize),
    /// Acknowledges a `PrepareState` request.
    PrepareState,
    /// Acknowledges a `Select` request.
    Select,
    /// The vtable's `run` result for a `Run` request.
    Run(u64),
    /// The vtable's `run` result for a `Measure` request; consumed by `Spi::read_sample`.
    Measure(u64),
}
263
/// Benchmark registry backing the exported `tango_*` functions.
struct State {
    /// All registered benchmarks, indexed by the value passed to `tango_select`.
    benchmarks: Vec<Benchmark>,
    /// Selected benchmark index plus the sampler prepared for it by `tango_prepare_state`
    /// (the inner `Option` is `None` until a sampler has been prepared).
    selected_function: Option<(usize, Option<Box<dyn ErasedSampler>>)>,
}
270
271impl State {
272 fn selected(&self) -> &Benchmark {
273 &self.benchmarks[self.ensure_selected()]
274 }
275
276 fn ensure_selected(&self) -> usize {
277 self.selected_function
278 .as_ref()
279 .map(|(idx, _)| *idx)
280 .expect("No function was selected. Call tango_select() first")
281 }
282
283 fn selected_state_mut(&mut self) -> Option<&mut Box<dyn ErasedSampler>> {
284 self.selected_function
285 .as_mut()
286 .and_then(|(_, state)| state.as_mut())
287 }
288}
289
/// Process-wide benchmark registry read and written by the exported `tango_*` functions.
///
/// Populated by `__tango_init` and cleared by `tango_free`. It is a `static mut` with no
/// synchronization, so every access must be externally serialized.
/// NOTE(review): the SPI appears to rely on a single owner (one `Spi` or its worker thread)
/// issuing all calls; confirm.
static mut STATE: Option<State> = None;
292
293pub fn __tango_init(benchmarks: Vec<Benchmark>) {
298 unsafe {
299 if STATE.is_none() {
300 STATE = Some(State {
301 benchmarks,
302 selected_function: None,
303 });
304 }
305 }
306}
307
308pub mod ffi {
315 use super::*;
316 use std::{
317 ffi::{c_uint, c_ulonglong},
318 mem,
319 os::raw::c_char,
320 ptr::null,
321 };
322
323 pub type InitFn = unsafe extern "C" fn();
325 type CountFn = unsafe extern "C" fn() -> c_ulonglong;
326 type GetTestNameFn = unsafe extern "C" fn(*mut *const c_char, *mut c_ulonglong);
327 type SelectFn = unsafe extern "C" fn(c_ulonglong);
328 type RunFn = unsafe extern "C" fn(c_ulonglong) -> u64;
329 type EstimateIterationsFn = unsafe extern "C" fn(c_uint) -> c_ulonglong;
330 type PrepareStateFn = unsafe extern "C" fn(c_ulonglong);
331 type FreeFn = unsafe extern "C" fn();
332
333 #[allow(unused)]
336 mod type_check {
337 use super::*;
338
339 const TANGO_COUNT: CountFn = tango_count;
340 const TANGO_SELECT: SelectFn = tango_select;
341 const TANGO_GET_TEST_NAME: GetTestNameFn = tango_get_test_name;
342 const TANGO_RUN: RunFn = tango_run;
343 const TANGO_ESTIMATE_ITERATIONS: EstimateIterationsFn = tango_estimate_iterations;
344 const TANGO_FREE: FreeFn = tango_free;
345 }
346
347 #[no_mangle]
348 unsafe extern "C" fn tango_count() -> c_ulonglong {
349 STATE
350 .as_ref()
351 .map(|s| s.benchmarks.len() as c_ulonglong)
352 .unwrap_or(0)
353 }
354
355 #[no_mangle]
356 unsafe extern "C" fn tango_select(idx: c_ulonglong) {
357 if let Some(s) = STATE.as_mut() {
358 let idx = idx as usize;
359 assert!(idx < s.benchmarks.len());
360
361 s.selected_function = Some(match s.selected_function.take() {
362 Some((selected, state)) if selected == idx => (selected, state),
364 _ => (idx, None),
365 });
366 }
367 }
368
369 #[no_mangle]
370 unsafe extern "C" fn tango_get_test_name(name: *mut *const c_char, length: *mut c_ulonglong) {
371 if let Some(s) = STATE.as_ref() {
372 let n = s.selected().name();
373 *name = n.as_ptr() as _;
374 *length = n.len() as c_ulonglong;
375 } else {
376 *name = null();
377 *length = 0;
378 }
379 }
380
381 #[no_mangle]
382 unsafe extern "C" fn tango_run(iterations: c_ulonglong) -> u64 {
383 if let Some(s) = STATE.as_mut() {
384 s.selected_state_mut()
385 .expect("no tango_prepare_state() was called")
386 .measure(iterations as usize)
387 } else {
388 0
389 }
390 }
391
392 #[no_mangle]
393 unsafe extern "C" fn tango_estimate_iterations(time_ms: c_uint) -> c_ulonglong {
394 if let Some(s) = STATE.as_mut() {
395 s.selected_state_mut()
396 .expect("no tango_prepare_state() was called")
397 .as_mut()
398 .estimate_iterations(time_ms) as c_ulonglong
399 } else {
400 0
401 }
402 }
403
404 #[no_mangle]
405 unsafe extern "C" fn tango_prepare_state(seed: c_ulonglong) {
406 if let Some(s) = STATE.as_mut() {
407 let Some((idx, state)) = &mut s.selected_function else {
408 panic!("No tango_select() was called")
409 };
410 *state = Some(s.benchmarks[*idx].prepare_state(seed));
411 }
412 }
413
414 #[no_mangle]
415 unsafe extern "C" fn tango_free() {
416 STATE.take();
417 }
418
419 pub(super) trait VTable {
420 fn init(&self);
421 fn count(&self) -> c_ulonglong;
422 fn select(&self, func_idx: c_ulonglong);
423 fn get_test_name(&self, ptr: *mut *const c_char, len: *mut c_ulonglong);
424 fn run(&self, iterations: c_ulonglong) -> c_ulonglong;
425 fn estimate_iterations(&self, time_ms: c_uint) -> c_ulonglong;
426 fn prepare_state(&self, seed: c_ulonglong);
427 }
428
429 pub(super) static mut SELF_VTABLE: Option<SelfVTable> = Some(SelfVTable);
430
431 pub(super) struct SelfVTable;
437
438 impl VTable for SelfVTable {
439 fn init(&self) {
440 }
442
443 fn count(&self) -> c_ulonglong {
444 unsafe { tango_count() }
445 }
446
447 fn select(&self, func_idx: c_ulonglong) {
448 unsafe { tango_select(func_idx) }
449 }
450
451 fn get_test_name(&self, ptr: *mut *const c_char, len: *mut c_ulonglong) {
452 unsafe { tango_get_test_name(ptr, len) }
453 }
454
455 fn run(&self, iterations: c_ulonglong) -> u64 {
456 unsafe { tango_run(iterations) }
457 }
458
459 fn estimate_iterations(&self, time_ms: c_uint) -> c_ulonglong {
460 unsafe { tango_estimate_iterations(time_ms) }
461 }
462
463 fn prepare_state(&self, seed: u64) {
464 unsafe { tango_prepare_state(seed) }
465 }
466 }
467
468 impl Drop for SelfVTable {
469 fn drop(&mut self) {
470 unsafe {
471 tango_free();
472 }
473 }
474 }
475
476 pub(super) struct LibraryVTable {
477 init_fn: Symbol<'static, InitFn>,
481 count_fn: Symbol<'static, CountFn>,
482 select_fn: Symbol<'static, SelectFn>,
483 get_test_name_fn: Symbol<'static, GetTestNameFn>,
484 run_fn: Symbol<'static, RunFn>,
485 estimate_iterations_fn: Symbol<'static, EstimateIterationsFn>,
486 prepare_state_fn: Symbol<'static, PrepareStateFn>,
487 free_fn: Symbol<'static, FreeFn>,
488
489 _library: Box<Library>,
491 }
492
493 impl LibraryVTable {
494 pub(super) fn new(library: Library) -> Result<Self, Error> {
495 let library = Box::new(library);
498 let init_fn = lookup_symbol::<InitFn>(&library, "tango_init")?;
499 let count_fn = lookup_symbol::<CountFn>(&library, "tango_count")?;
500 let select_fn = lookup_symbol::<SelectFn>(&library, "tango_select")?;
501 let get_test_name_fn = lookup_symbol::<GetTestNameFn>(&library, "tango_get_test_name")?;
502 let run_fn = lookup_symbol::<RunFn>(&library, "tango_run")?;
503 let estimate_iterations_fn =
504 lookup_symbol::<EstimateIterationsFn>(&library, "tango_estimate_iterations")?;
505 let prepare_state_fn =
506 lookup_symbol::<PrepareStateFn>(&library, "tango_prepare_state")?;
507 let free_fn = lookup_symbol::<FreeFn>(&library, "tango_free")?;
508 Ok(Self {
509 _library: library,
510 init_fn,
511 count_fn,
512 select_fn,
513 get_test_name_fn,
514 run_fn,
515 estimate_iterations_fn,
516 prepare_state_fn,
517 free_fn,
518 })
519 }
520 }
521
522 impl VTable for LibraryVTable {
523 fn init(&self) {
524 unsafe { (self.init_fn)() }
525 }
526
527 fn count(&self) -> c_ulonglong {
528 unsafe { (self.count_fn)() }
529 }
530
531 fn select(&self, func_idx: c_ulonglong) {
532 unsafe { (self.select_fn)(func_idx) }
533 }
534
535 fn get_test_name(&self, ptr: *mut *const c_char, len: *mut c_ulonglong) {
536 unsafe { (self.get_test_name_fn)(ptr, len) }
537 }
538
539 fn run(&self, iterations: c_ulonglong) -> u64 {
540 unsafe { (self.run_fn)(iterations) }
541 }
542
543 fn estimate_iterations(&self, time_ms: c_uint) -> c_ulonglong {
544 unsafe { (self.estimate_iterations_fn)(time_ms) }
545 }
546
547 fn prepare_state(&self, seed: c_ulonglong) {
548 unsafe { (self.prepare_state_fn)(seed) }
549 }
550 }
551
552 impl Drop for LibraryVTable {
553 fn drop(&mut self) {
554 unsafe { (self.free_fn)() }
555 }
556 }
557
558 fn lookup_symbol<'l, T>(
559 library: &'l Library,
560 name: &'static str,
561 ) -> Result<Symbol<'static, T>, Error> {
562 unsafe {
563 let symbol = library
564 .get(name.as_bytes())
565 .map_err(Error::UnableToLoadSymbol)?;
566 Ok(mem::transmute::<Symbol<'l, T>, Symbol<'static, T>>(symbol))
567 }
568 }
569}