Skip to main content

cognee_core/
thread_pool.rs

1use std::future::Future;
2use std::pin::Pin;
3
4use crate::error::CoreError;
5/// Dyn-compatible interface for a CPU-bound thread pool.
6///
7/// Use the blanket [`CpuPoolExt::spawn`] method for ergonomic, generic usage.
8/// Implement only [`CpuPool::spawn_raw`] in concrete types.
9pub trait CpuPool: Send + Sync {
10    /// Spawn a task on the pool and return a future that resolves once the task
11    /// finishes.  The future does **not** borrow `self`; the task is enqueued
12    /// immediately when `spawn_raw` is called.
13    fn spawn_raw(
14        &self,
15        task: Box<dyn FnOnce() + Send + 'static>,
16    ) -> Pin<Box<dyn Future<Output = Result<(), CoreError>> + Send + 'static>>;
17}
18/// Ergonomic extension for [`CpuPool`] that adds a generic `spawn` with a
19/// return value.  Auto-implemented for every `T: CpuPool`.
20pub trait CpuPoolExt: CpuPool {
21    /// Spawn a CPU-intensive closure on the thread pool and await its result
22    /// asynchronously.
23    ///
24    /// The closure is executed on a rayon (or other CPU) worker thread while
25    /// the caller's async task yields. Useful for blocking work that would
26    /// otherwise stall the Tokio executor.
27    ///
28    /// Returns `Err(CoreError::TaskAborted)` if the worker panicked or the pool
29    /// was shut down before the result could be delivered.
30    fn spawn<F, R>(
31        &self,
32        f: F,
33    ) -> Pin<Box<dyn Future<Output = Result<R, CoreError>> + Send + 'static>>
34    where
35        F: FnOnce() -> R + Send + 'static,
36        R: Send + 'static,
37    {
38        let (tx, rx) = tokio::sync::oneshot::channel::<R>();
39
40        let task: Box<dyn FnOnce() + Send + 'static> = Box::new(move || {
41            let result = f();
42            // If the receiver was dropped the caller gave up; that's fine.
43            let _ = tx.send(result);
44        });
45
46        let raw_fut = self.spawn_raw(task);
47
48        Box::pin(async move {
49            // Wait for the task to complete (raw_fut resolves after the closure
50            // returns), then retrieve the value from the oneshot channel.
51            raw_fut.await?;
52            rx.await.map_err(|_| CoreError::TaskAborted {
53                reason: "task result channel dropped (task panicked or pool shut down)".into(),
54            })
55        })
56    }
57}
58
59impl<T: CpuPool + ?Sized> CpuPoolExt for T {}
60/// A [`CpuPool`] backed by a dedicated [`rayon::ThreadPool`].
61///
62/// Provides direct access to the underlying pool via [`RayonThreadPool::rayon_pool`]
63/// for callers that want to use rayon's parallel iterators directly.
64pub struct RayonThreadPool {
65    pool: rayon::ThreadPool,
66}
67
68impl RayonThreadPool {
69    /// Create a pool with a specific number of threads.
70    pub fn new(num_threads: usize) -> Result<Self, CoreError> {
71        let pool = rayon::ThreadPoolBuilder::new()
72            .num_threads(num_threads)
73            .build()
74            .map_err(|e| CoreError::ThreadPoolBuild(e.to_string()))?;
75        Ok(Self { pool })
76    }
77
78    /// Create a pool with rayon's default thread count (one per logical CPU).
79    pub fn with_default_threads() -> Result<Self, CoreError> {
80        let pool = rayon::ThreadPoolBuilder::new()
81            .build()
82            .map_err(|e| CoreError::ThreadPoolBuild(e.to_string()))?;
83        Ok(Self { pool })
84    }
85
86    /// Direct access to the underlying [`rayon::ThreadPool`], e.g. for
87    /// `pool.install(|| { ... })` or parallel iterators scoped to this pool.
88    pub fn rayon_pool(&self) -> &rayon::ThreadPool {
89        &self.pool
90    }
91}
92
93impl CpuPool for RayonThreadPool {
94    fn spawn_raw(
95        &self,
96        task: Box<dyn FnOnce() + Send + 'static>,
97    ) -> Pin<Box<dyn Future<Output = Result<(), CoreError>> + Send + 'static>> {
98        let (tx, rx) = tokio::sync::oneshot::channel::<Result<(), String>>();
99
100        self.pool.spawn(move || {
101            let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(task));
102            match result {
103                Ok(()) => {
104                    let _ = tx.send(Ok(()));
105                }
106                Err(panic_payload) => {
107                    let msg = panic_payload
108                        .downcast_ref::<String>()
109                        .map(|s| s.as_str())
110                        .or_else(|| panic_payload.downcast_ref::<&str>().copied())
111                        .unwrap_or("unknown panic")
112                        .to_string();
113                    let _ = tx.send(Err(msg));
114                }
115            }
116        });
117
118        Box::pin(async move {
119            match rx.await {
120                Ok(Ok(())) => Ok(()),
121                Ok(Err(panic_msg)) => Err(CoreError::TaskAborted {
122                    reason: format!("task panicked: {panic_msg}"),
123                }),
124                Err(_) => Err(CoreError::TaskAborted {
125                    reason: "pool shut down before task completed".into(),
126                }),
127            }
128        })
129    }
130}