cognee_core/
thread_pool.rs1use std::future::Future;
2use std::pin::Pin;
3
4use crate::error::CoreError;
5pub trait CpuPool: Send + Sync {
10 fn spawn_raw(
14 &self,
15 task: Box<dyn FnOnce() + Send + 'static>,
16 ) -> Pin<Box<dyn Future<Output = Result<(), CoreError>> + Send + 'static>>;
17}
18pub trait CpuPoolExt: CpuPool {
21 fn spawn<F, R>(
31 &self,
32 f: F,
33 ) -> Pin<Box<dyn Future<Output = Result<R, CoreError>> + Send + 'static>>
34 where
35 F: FnOnce() -> R + Send + 'static,
36 R: Send + 'static,
37 {
38 let (tx, rx) = tokio::sync::oneshot::channel::<R>();
39
40 let task: Box<dyn FnOnce() + Send + 'static> = Box::new(move || {
41 let result = f();
42 let _ = tx.send(result);
44 });
45
46 let raw_fut = self.spawn_raw(task);
47
48 Box::pin(async move {
49 raw_fut.await?;
52 rx.await.map_err(|_| CoreError::TaskAborted {
53 reason: "task result channel dropped (task panicked or pool shut down)".into(),
54 })
55 })
56 }
57}
58
59impl<T: CpuPool + ?Sized> CpuPoolExt for T {}
60pub struct RayonThreadPool {
65 pool: rayon::ThreadPool,
66}
67
68impl RayonThreadPool {
69 pub fn new(num_threads: usize) -> Result<Self, CoreError> {
71 let pool = rayon::ThreadPoolBuilder::new()
72 .num_threads(num_threads)
73 .build()
74 .map_err(|e| CoreError::ThreadPoolBuild(e.to_string()))?;
75 Ok(Self { pool })
76 }
77
78 pub fn with_default_threads() -> Result<Self, CoreError> {
80 let pool = rayon::ThreadPoolBuilder::new()
81 .build()
82 .map_err(|e| CoreError::ThreadPoolBuild(e.to_string()))?;
83 Ok(Self { pool })
84 }
85
86 pub fn rayon_pool(&self) -> &rayon::ThreadPool {
89 &self.pool
90 }
91}
92
93impl CpuPool for RayonThreadPool {
94 fn spawn_raw(
95 &self,
96 task: Box<dyn FnOnce() + Send + 'static>,
97 ) -> Pin<Box<dyn Future<Output = Result<(), CoreError>> + Send + 'static>> {
98 let (tx, rx) = tokio::sync::oneshot::channel::<Result<(), String>>();
99
100 self.pool.spawn(move || {
101 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(task));
102 match result {
103 Ok(()) => {
104 let _ = tx.send(Ok(()));
105 }
106 Err(panic_payload) => {
107 let msg = panic_payload
108 .downcast_ref::<String>()
109 .map(|s| s.as_str())
110 .or_else(|| panic_payload.downcast_ref::<&str>().copied())
111 .unwrap_or("unknown panic")
112 .to_string();
113 let _ = tx.send(Err(msg));
114 }
115 }
116 });
117
118 Box::pin(async move {
119 match rx.await {
120 Ok(Ok(())) => Ok(()),
121 Ok(Err(panic_msg)) => Err(CoreError::TaskAborted {
122 reason: format!("task panicked: {panic_msg}"),
123 }),
124 Err(_) => Err(CoreError::TaskAborted {
125 reason: "pool shut down before task completed".into(),
126 }),
127 }
128 })
129 }
130}