1use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
7use std::thread::{self, JoinHandle};
8use tokio::sync::mpsc;
9use tracing::{info, error};
10
11use crate::{
12 shared_state::SharedState,
13 config::Config,
14};
15
/// Thread-count sizing for the components managed by [`ThreadPool`].
///
/// Only `context_engine_threads` is consumed by `ThreadPool::start` in this
/// module; the remaining counts are carried along (and logged) for other users.
#[derive(Clone, Debug)]
pub struct ThreadPoolConfig {
    /// Number of command-processing worker threads spawned by the pool.
    pub context_engine_threads: usize,
    /// Sizing hint for cache-management work (not consumed in this module).
    pub cache_manager_threads: usize,
    /// Sizing hint for database work (not consumed in this module).
    pub database_threads: usize,
    /// Sizing hint for LLM work (not consumed in this module).
    pub llm_threads: usize,
    /// Sizing hint for I/O work (not consumed in this module).
    pub io_threads: usize,
}
25
26impl ThreadPoolConfig {
27 pub fn new(config: &Config) -> Self {
28 let cpu_cores = num_cpus::get();
30
31 Self {
32 context_engine_threads: (cpu_cores / 4).max(2).min(4),
33 cache_manager_threads: 1.max(cpu_cores / 8).min(2),
34 database_threads: config.max_concurrent_streams as usize,
35 llm_threads: 1, io_threads: (cpu_cores / 2).max(2).min(4),
37 }
38 }
39}
40
41pub enum SystemCommand {
43 ProcessMessage {
45 session_id: String,
46 message: crate::memory::Message,
47 sender: Box<dyn FnOnce(anyhow::Result<crate::memory::Message>) + Send>,
48 },
49
50 GenerateResponse {
52 session_id: String,
53 context: Vec<crate::memory::Message>,
54 sender: Box<dyn FnOnce(anyhow::Result<String>) + Send>,
55 },
56
57 UpdateCache {
59 session_id: String,
60 entries: Vec<crate::cache_management::cache_extractor::KVEntry>,
61 sender: Box<dyn FnOnce(anyhow::Result<()>) + Send>,
62 },
63
64 PersistConversation {
66 session_id: String,
67 messages: Vec<crate::memory::Message>,
68 sender: Box<dyn FnOnce(anyhow::Result<()>) + Send>,
69 },
70
71 Shutdown,
73}
74
/// Handle to a dedicated OS thread that runs a command-processing loop.
///
/// Dropping the handle clears `running` and joins the thread.
pub struct WorkerThread {
    /// Join handle of the spawned thread; `None` once it has been joined.
    thread_handle: Option<thread::JoinHandle<()>>,
    /// Cooperative stop flag shared with the worker loop.
    running: Arc<AtomicBool>,
}
80
81impl WorkerThread {
82 pub fn new(
83 name: String,
84 command_receiver: mpsc::UnboundedReceiver<SystemCommand>,
85 shared_state: Arc<SharedState>,
86 ) -> Self {
87 let running = Arc::new(AtomicBool::new(true));
88 let running_clone = running.clone();
89
90 let thread_handle = thread::Builder::new()
91 .name(name.clone())
92 .spawn(move || {
93 let rt = tokio::runtime::Builder::new_current_thread()
94 .enable_all()
95 .build()
96 .expect("Failed to create worker thread runtime");
97
98 rt.block_on(async move {
99 Self::run_worker_loop(command_receiver, shared_state, running_clone).await;
100 });
101 })
102 .expect("Failed to spawn worker thread");
103
104 info!("Spawned worker thread: {}", name);
105
106 Self {
107 thread_handle: Some(thread_handle),
108 running,
109 }
110 }
111
112 async fn run_worker_loop(
113 mut receiver: mpsc::UnboundedReceiver<SystemCommand>,
114 shared_state: Arc<SharedState>,
115 running: Arc<AtomicBool>,
116 ) {
117 while running.load(Ordering::Relaxed) {
118 tokio::select! {
119 command = receiver.recv() => {
120 match command {
121 Some(cmd) => {
122 if let Err(e) = Self::handle_command(cmd, &shared_state).await {
123 error!("Worker thread command failed: {}", e);
124 }
125 }
126 None => {
127 info!("Worker thread command channel closed");
128 break;
129 }
130 }
131 }
132 _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
133 }
135 }
136 }
137
138 info!("Worker thread shutting down");
139 }
140
141 async fn handle_command(
142 command: SystemCommand,
143 shared_state: &Arc<SharedState>,
144 ) -> anyhow::Result<()> {
145 match command {
146 SystemCommand::ProcessMessage { session_id, message, sender } => {
147 let result = Self::process_message(shared_state, session_id, message).await;
148 sender(result);
149 }
150 SystemCommand::GenerateResponse { session_id, context, sender } => {
151 let result = Self::generate_response(shared_state, session_id, context).await;
152 sender(result);
153 }
154 SystemCommand::UpdateCache { session_id, entries, sender } => {
155 let result = Self::update_cache(shared_state, session_id, entries).await;
156 sender(result);
157 }
158 SystemCommand::PersistConversation { session_id, messages, sender } => {
159 let result = Self::persist_conversation(shared_state, session_id, messages).await;
160 sender(result);
161 }
162 SystemCommand::Shutdown => {
163 }
165 }
166 Ok(())
167 }
168
169 async fn process_message(
170 shared_state: &Arc<SharedState>,
171 session_id: String,
172 message: crate::memory::Message,
173 ) -> anyhow::Result<crate::memory::Message> {
174 let session = shared_state.get_or_create_session(&session_id).await;
176 let mut session_guard = session.write()
177 .map_err(|_| anyhow::anyhow!("Failed to acquire session write lock"))?;
178
179 session_guard.messages.push(message.clone());
181 session_guard.last_accessed = std::time::Instant::now();
182
183 shared_state.counters.inc_processed_messages();
185
186 Ok(message)
187 }
188
189 async fn generate_response(
190 shared_state: &Arc<SharedState>,
191 session_id: String,
192 context: Vec<crate::memory::Message>,
193 ) -> anyhow::Result<String> {
194 let response = shared_state
196 .llm_worker
197 .generate_response(session_id, context.clone())
198 .await?;
199
200 Ok(response)
201 }
202
203 async fn update_cache(
204 shared_state: &Arc<SharedState>,
205 session_id: String,
206 entries: Vec<crate::cache_management::cache_extractor::KVEntry>,
207 ) -> anyhow::Result<()> {
208 let cache_guard = shared_state.cache_manager.read()
209 .map_err(|_| anyhow::anyhow!("Failed to acquire cache manager read lock"))?;
210 if let Some(_cache_manager) = &*cache_guard {
211 info!("Updating cache for session {} with {} entries", session_id, entries.len());
214 }
215 Ok(())
216 }
217
218 async fn persist_conversation(
219 _shared_state: &Arc<SharedState>,
220 session_id: String,
221 messages: Vec<crate::memory::Message>,
222 ) -> anyhow::Result<()> {
223 info!("Persisting conversation {} with {} messages", session_id, messages.len());
225 Ok(())
227 }
228}
229
230impl Drop for WorkerThread {
231 fn drop(&mut self) {
232 self.running.store(false, Ordering::Relaxed);
233 if let Some(handle) = self.thread_handle.take() {
234 let _ = handle.join();
235 }
236 }
237}
238
239pub struct ThreadPool {
241 config: ThreadPoolConfig,
242 shared_state: Arc<SharedState>,
243 workers: Vec<WorkerThread>,
244 command_senders: Vec<mpsc::UnboundedSender<SystemCommand>>,
245}
246
247impl ThreadPool {
248 pub fn new(config: ThreadPoolConfig, shared_state: Arc<SharedState>) -> Self {
249 Self {
250 config,
251 shared_state,
252 workers: Vec::new(),
253 command_senders: Vec::new(),
254 }
255 }
256
257 pub async fn start(&mut self) -> anyhow::Result<()> {
258 info!("Starting thread pool with config: {:?}", self.config);
259
260 let mut channels = Vec::new();
262 for i in 0..self.config.context_engine_threads {
263 let (tx, rx) = mpsc::unbounded_channel();
264 channels.push((format!("context-worker-{}", i), tx, rx));
265 }
266
267 for (name, tx, rx) in channels {
269 let worker = WorkerThread::new(
270 name,
271 rx,
272 self.shared_state.clone(),
273 );
274 self.workers.push(worker);
275 self.command_senders.push(tx);
276 }
277
278 info!("Thread pool started with {} workers", self.workers.len());
279 Ok(())
280 }
281
282 pub async fn send_command(&self, command: SystemCommand) -> anyhow::Result<()> {
283 static NEXT_WORKER: AtomicBool = AtomicBool::new(false);
285 let worker_index = if NEXT_WORKER.fetch_xor(true, Ordering::Relaxed) { 0 } else { 1 };
286 let sender_index = worker_index % self.command_senders.len();
287
288 self.command_senders[sender_index]
289 .send(command)
290 .map_err(|_| anyhow::anyhow!("Failed to send command to worker thread"))
291 }
292
293 pub async fn shutdown(&mut self) -> anyhow::Result<()> {
294 info!("Shutting down thread pool");
295
296 for sender in &self.command_senders {
298 let _ = sender.send(SystemCommand::Shutdown);
299 }
300
301 self.workers.clear();
303 self.command_senders.clear();
304
305 info!("Thread pool shutdown complete");
306 Ok(())
307 }
308}
309
310pub use self::SystemCommand as Command;