Skip to main content

diskann_benchmark_core/
tokio.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
5
6/// Create a generic multi-threaded runtime with `num_threads`.
7///
8/// No guarantees are made about the returned [`tokio::runtime::Runtime`] except that it
9/// will have `num_threads` workers.
10pub fn runtime(num_threads: usize) -> anyhow::Result<tokio::runtime::Runtime> {
11    Ok(tokio::runtime::Builder::new_multi_thread()
12        .worker_threads(num_threads)
13        .build()?)
14}
15
16/// Create a generic multi-threaded runtime with `num_threads`.
17///
18/// After initial setup, the [`tokio::runtime::Builder`] will be passed to the closure `f`
19/// for customization. Note that the builder provided to the callback will already be
20/// initialized to contain `num_threads` threads.
21pub fn runtime_with<F>(num_threads: usize, f: F) -> anyhow::Result<tokio::runtime::Runtime>
22where
23    F: FnOnce(&mut tokio::runtime::Builder),
24{
25    let mut builder = tokio::runtime::Builder::new_multi_thread();
26    builder.worker_threads(num_threads);
27    f(&mut builder);
28    Ok(builder.build()?)
29}
30
///////////
// Tests //
///////////
34
#[cfg(test)]
mod tests {
    use super::*;

    /// Worker-thread counts exercised by the thread-count tests.
    const THREAD_COUNTS: [usize; 4] = [1, 2, 4, 8];

    #[test]
    fn test_runtimes() {
        for n in THREAD_COUNTS {
            let rt = runtime(n).unwrap();
            assert_eq!(rt.metrics().num_workers(), n);
        }
    }

    #[test]
    fn test_runtime_with_threads() {
        for n in THREAD_COUNTS {
            // A no-op customization must leave the requested worker count intact.
            let rt = runtime_with(n, |_| {}).unwrap();
            assert_eq!(rt.metrics().num_workers(), n);
        }
    }

    #[test]
    fn test_runtime_with_customizes_builder() {
        const PREFIX: &str = "custom-worker";

        let rt = runtime_with(2, |builder| {
            builder.thread_name(PREFIX);
        })
        .unwrap();

        // The worker count requested up front must survive the customization callback.
        assert_eq!(rt.metrics().num_workers(), 2);

        // Spawn a task so it executes on a runtime worker thread, then read back that
        // thread's name to confirm the callback's `thread_name` setting took effect.
        let handle = rt.spawn(async {
            std::thread::current()
                .name()
                .map(str::to_owned)
                .unwrap_or_default()
        });
        let name = rt.block_on(handle).unwrap();

        assert!(
            name.starts_with(PREFIX),
            "expected thread name starting with 'custom-worker', got '{name}'",
        );
    }
}