1use std::sync::Arc;
2
3#[cfg(feature = "quickwit")]
4use futures_util::{future::Either, FutureExt};
5
6use crate::TantivyError;
7
8#[derive(Clone)]
11pub enum Executor {
12 SingleThread,
14 ThreadPool(Arc<rayon::ThreadPool>),
16}
17
#[cfg(feature = "quickwit")]
impl From<Arc<rayon::ThreadPool>> for Executor {
    /// Reuses an already-built, shared rayon pool as a multi-threaded executor.
    fn from(pool: Arc<rayon::ThreadPool>) -> Self {
        Self::ThreadPool(pool)
    }
}
24
25impl Executor {
26 pub fn single_thread() -> Executor {
28 Executor::SingleThread
29 }
30
31 pub fn multi_thread(num_threads: usize, prefix: &'static str) -> crate::Result<Executor> {
33 let pool = rayon::ThreadPoolBuilder::new()
34 .num_threads(num_threads)
35 .thread_name(move |num| format!("{prefix}{num}"))
36 .build()?;
37 Ok(Executor::ThreadPool(Arc::new(pool)))
38 }
39
40 pub fn map<A, R, F>(&self, f: F, args: impl Iterator<Item = A>) -> crate::Result<Vec<R>>
45 where
46 A: Send,
47 R: Send,
48 F: Sized + Sync + Fn(A) -> crate::Result<R>,
49 {
50 match self {
51 Executor::SingleThread => args.map(f).collect::<crate::Result<_>>(),
52 Executor::ThreadPool(pool) => {
53 let args: Vec<A> = args.collect();
54 let num_fruits = args.len();
55 let fruit_receiver = {
56 let (fruit_sender, fruit_receiver) = crossbeam_channel::unbounded();
57 pool.scope(|scope| {
58 for (idx, arg) in args.into_iter().enumerate() {
59 let f_ref = &f;
62 let fruit_sender_ref = &fruit_sender;
63 scope.spawn(move |_| {
64 let fruit = f_ref(arg);
65 if let Err(err) = fruit_sender_ref.send((idx, fruit)) {
66 error!(
67 "Failed to send search task. It probably means all search \
68 threads have panicked. {err:?}"
69 );
70 }
71 });
72 }
73 });
74 fruit_receiver
75 };
79 let mut result_placeholders: Vec<Option<R>> =
80 std::iter::repeat_with(|| None).take(num_fruits).collect();
81 for (pos, fruit_res) in fruit_receiver {
82 let fruit = fruit_res?;
83 result_placeholders[pos] = Some(fruit);
84 }
85 let results: Vec<R> = result_placeholders.into_iter().flatten().collect();
86 if results.len() != num_fruits {
87 return Err(TantivyError::InternalError(
88 "One of the mapped execution failed.".to_string(),
89 ));
90 }
91 Ok(results)
92 }
93 }
94 }
95
96 #[cfg(feature = "quickwit")]
100 pub fn spawn_blocking<T: Send + 'static>(
101 &self,
102 cpu_intensive_task: impl FnOnce() -> T + Send + 'static,
103 ) -> impl std::future::Future<Output = Result<T, ()>> {
104 match self {
105 Executor::SingleThread => Either::Left(std::future::ready(Ok(cpu_intensive_task()))),
106 Executor::ThreadPool(pool) => {
107 let (sender, receiver) = oneshot::channel();
108 pool.spawn(|| {
109 if sender.is_closed() {
110 return;
111 }
112 let task_result = cpu_intensive_task();
113 let _ = sender.send(task_result);
114 });
115
116 let res = receiver.map(|res| res.map_err(|_| ()));
117 Either::Right(res)
118 }
119 }
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::Executor;

    #[test]
    #[should_panic(expected = "panic should propagate")]
    fn test_panic_propagates_single_thread() {
        let executor = Executor::single_thread();
        let _result: Vec<usize> = executor
            .map(|_| panic!("panic should propagate"), std::iter::once(0))
            .unwrap();
    }

    #[test]
    #[should_panic]
    fn test_panic_propagates_multi_thread() {
        let executor = Executor::multi_thread(1, "search-test").unwrap();
        let _result: Vec<usize> = executor
            .map(|_| panic!("panic should propagate"), std::iter::once(0))
            .unwrap();
    }

    #[test]
    fn test_map_singlethread() {
        let doubled: Vec<usize> = Executor::single_thread()
            .map(|i| Ok(i * 2), 0..1_000)
            .unwrap();
        let expected: Vec<usize> = (0..1_000).map(|i| i * 2).collect();
        assert_eq!(doubled, expected);
    }

    #[test]
    fn test_map_multithread() {
        let executor = Executor::multi_thread(3, "search-test").unwrap();
        let doubled: Vec<usize> = executor.map(|i| Ok(i * 2), 0..10).unwrap();
        let expected: Vec<usize> = (0..10).map(|i| i * 2).collect();
        assert_eq!(doubled, expected);
    }

    #[cfg(feature = "quickwit")]
    #[test]
    fn test_cancel_cpu_intensive_tasks() {
        use std::sync::atomic::{AtomicU64, Ordering};
        use std::sync::Arc;

        let kept_counter: Arc<AtomicU64> = Default::default();
        let dropped_counter: Arc<AtomicU64> = Default::default();

        let mut kept_futures = Vec::new();
        let mut dropped_futures = Vec::new();

        // Zero-capacity (rendezvous) channel: each `send` blocks until one task
        // receives it, so every send unblocks exactly one running task.
        let (unblock_tx, unblock_rx) = crossbeam_channel::bounded::<()>(0);
        let unblock_rx = Arc::new(unblock_rx);
        let executor = Executor::multi_thread(3, "search-test").unwrap();

        // Spawns a task that bumps `counter`, then waits to be unblocked.
        let spawn_counting_task = |counter: &Arc<AtomicU64>| {
            let counter = Arc::clone(counter);
            let rx = Arc::clone(&unblock_rx);
            executor.spawn_blocking(move || {
                counter.fetch_add(1, Ordering::SeqCst);
                let _ = rx.recv();
            })
        };

        // Alternate the two kinds of task so they are interleaved in the pool queue.
        for _ in 0..1000 {
            kept_futures.push(spawn_counting_task(&kept_counter));
            dropped_futures.push(spawn_counting_task(&dropped_counter));
        }

        // Let 100 tasks complete.
        for _ in 0..100 {
            unblock_tx.send(()).unwrap();
        }

        let kept_before = kept_counter.load(Ordering::SeqCst);
        let dropped_before = dropped_counter.load(Ordering::SeqCst);
        assert!(kept_before >= 30);
        assert!(dropped_before >= 30);

        // Dropping these futures closes their result channels; queued tasks for them
        // should now be skipped instead of executed.
        drop(dropped_futures);

        // Let another 100 tasks complete.
        for _ in 0..100 {
            unblock_tx.send(()).unwrap();
        }

        let kept_after = kept_counter.load(Ordering::SeqCst);
        assert!(kept_after >= kept_before + 100 - 6);

        let dropped_after = dropped_counter.load(Ordering::SeqCst);
        assert!(dropped_after <= dropped_before + 6);
    }
}