1use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
7use std::thread::{self, JoinHandle};
8use tokio::sync::{mpsc, oneshot};
9use tracing::{info, error};
10
11use crate::{
12 shared_state::SharedState,
13 config::Config,
14};
15
/// Thread counts for each subsystem of the application.
///
/// Built by [`ThreadPoolConfig::new`]. In the visible code only
/// `context_engine_threads` is consumed (by `ThreadPool::start`); the other
/// fields are computed and stored but not read in this file.
#[derive(Debug, Clone)]
pub struct ThreadPoolConfig {
    /// Number of `context-worker-N` threads spawned by `ThreadPool::start`.
    pub context_engine_threads: usize,
    /// Thread budget for the cache manager (not read in this file).
    pub cache_manager_threads: usize,
    /// Thread budget for database work; `new` copies it from
    /// `Config::max_concurrent_streams` (not read in this file).
    pub database_threads: usize,
    /// Thread budget for LLM work; `new` fixes it at 1 (not read in this file).
    pub llm_threads: usize,
    /// Thread budget for I/O work (not read in this file).
    pub io_threads: usize,
}
25
26impl ThreadPoolConfig {
27 pub fn new(config: &Config) -> Self {
28 let cpu_cores = num_cpus::get();
30
31 Self {
32 context_engine_threads: (cpu_cores / 4).max(2).min(4),
33 cache_manager_threads: 1.max(cpu_cores / 8).min(2),
34 database_threads: config.max_concurrent_streams as usize,
35 llm_threads: 1, io_threads: (cpu_cores / 2).max(2).min(4),
37 }
38 }
39}
40
/// Unit of work sent to a worker thread over its command channel.
///
/// Every variant except `Shutdown` carries a `sender` callback; the worker
/// invokes it exactly once with the outcome of the operation.
pub enum SystemCommand {
    /// Append `message` to the session identified by `session_id`.
    ProcessMessage {
        /// Session the message belongs to.
        session_id: String,
        /// Message to record.
        message: crate::memory::Message,
        /// Receives the processed message or the failure.
        sender: Box<dyn FnOnce(anyhow::Result<crate::memory::Message>) + Send>,
    },

    /// Produce a textual response for a session given its context messages.
    GenerateResponse {
        /// Session the response is generated for.
        session_id: String,
        /// Messages supplied as generation context.
        context: Vec<crate::memory::Message>,
        /// Receives the generated text or the failure.
        sender: Box<dyn FnOnce(anyhow::Result<String>) + Send>,
    },

    /// Hand a batch of key/value cache entries for a session to the worker.
    UpdateCache {
        /// Session the entries belong to.
        session_id: String,
        /// Cache entries to apply.
        entries: Vec<crate::cache_management::cache_extractor::KVEntry>,
        /// Receives `Ok(())` or the failure.
        sender: Box<dyn FnOnce(anyhow::Result<()>) + Send>,
    },

    /// Request that a session's messages be persisted.
    PersistConversation {
        /// Session being persisted.
        session_id: String,
        /// Messages that make up the conversation.
        messages: Vec<crate::memory::Message>,
        /// Receives `Ok(())` or the failure.
        sender: Box<dyn FnOnce(anyhow::Result<()>) + Send>,
    },

    /// Sent to every worker by `ThreadPool::shutdown` before the workers are
    /// dropped. Carries no payload and no callback.
    Shutdown,
}
74
/// Handle to one spawned worker thread.
///
/// Dropping the handle clears `running` and joins the thread (see the `Drop`
/// impl).
pub struct WorkerThread {
    /// Placeholder receiver. `WorkerThread::new` moves the real receiver into
    /// the spawned thread and stores a fresh, disconnected one here.
    command_receiver: mpsc::UnboundedReceiver<SystemCommand>,
    /// Application state shared with the worker loop.
    shared_state: Arc<SharedState>,
    /// Join handle of the OS thread; an `Option` so `Drop` can take it.
    thread_handle: Option<JoinHandle<()>>,
    /// Flag the worker loop checks on every iteration; `false` asks it to exit.
    running: Arc<AtomicBool>,
}
82
83impl WorkerThread {
84 pub fn new(
85 name: String,
86 command_receiver: mpsc::UnboundedReceiver<SystemCommand>,
87 shared_state: Arc<SharedState>,
88 ) -> Self {
89 let running = Arc::new(AtomicBool::new(true));
90 let running_clone = running.clone();
91 let shared_state_clone = shared_state.clone();
92
93 let thread_handle = thread::Builder::new()
94 .name(name.clone())
95 .spawn({
96 let receiver = command_receiver; move || {
98 let rt = tokio::runtime::Builder::new_current_thread()
99 .enable_all()
100 .build()
101 .expect("Failed to create worker thread runtime");
102
103 rt.block_on(async move {
104 Self::run_worker_loop(receiver, shared_state_clone, running_clone).await;
105 });
106 }
107 })
108 .expect("Failed to spawn worker thread");
109
110 info!("Spawned worker thread: {}", name);
111
112 Self {
113 command_receiver: mpsc::unbounded_channel().1, shared_state,
115 thread_handle: Some(thread_handle),
116 running,
117 }
118 }
119
120 async fn run_worker_loop(
121 mut receiver: mpsc::UnboundedReceiver<SystemCommand>,
122 shared_state: Arc<SharedState>,
123 running: Arc<AtomicBool>,
124 ) {
125 while running.load(Ordering::Relaxed) {
126 tokio::select! {
127 command = receiver.recv() => {
128 match command {
129 Some(cmd) => {
130 if let Err(e) = Self::handle_command(cmd, &shared_state).await {
131 error!("Worker thread command failed: {}", e);
132 }
133 }
134 None => {
135 info!("Worker thread command channel closed");
136 break;
137 }
138 }
139 }
140 _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
141 }
143 }
144 }
145
146 info!("Worker thread shutting down");
147 }
148
149 async fn handle_command(
150 command: SystemCommand,
151 shared_state: &Arc<SharedState>,
152 ) -> anyhow::Result<()> {
153 match command {
154 SystemCommand::ProcessMessage { session_id, message, sender } => {
155 let result = Self::process_message(shared_state, session_id, message).await;
156 sender(result);
157 }
158 SystemCommand::GenerateResponse { session_id, context, sender } => {
159 let result = Self::generate_response(shared_state, session_id, context).await;
160 sender(result);
161 }
162 SystemCommand::UpdateCache { session_id, entries, sender } => {
163 let result = Self::update_cache(shared_state, session_id, entries).await;
164 sender(result);
165 }
166 SystemCommand::PersistConversation { session_id, messages, sender } => {
167 let result = Self::persist_conversation(shared_state, session_id, messages).await;
168 sender(result);
169 }
170 SystemCommand::Shutdown => {
171 }
173 }
174 Ok(())
175 }
176
177 async fn process_message(
178 shared_state: &Arc<SharedState>,
179 session_id: String,
180 message: crate::memory::Message,
181 ) -> anyhow::Result<crate::memory::Message> {
182 let session = shared_state.get_or_create_session(&session_id).await;
184 let mut session_guard = session.write()
185 .map_err(|_| anyhow::anyhow!("Failed to acquire session write lock"))?;
186
187 session_guard.messages.push(message.clone());
189 session_guard.last_accessed = std::time::Instant::now();
190
191 shared_state.counters.inc_processed_messages();
193
194 Ok(message)
195 }
196
197 async fn generate_response(
198 _shared_state: &Arc<SharedState>,
199 _session_id: String,
200 _context: Vec<crate::memory::Message>,
201 ) -> anyhow::Result<String> {
202 Ok("Generated response placeholder".to_string())
205 }
206
207 async fn update_cache(
208 shared_state: &Arc<SharedState>,
209 session_id: String,
210 entries: Vec<crate::cache_management::cache_extractor::KVEntry>,
211 ) -> anyhow::Result<()> {
212 let cache_guard = shared_state.cache_manager.read()
213 .map_err(|_| anyhow::anyhow!("Failed to acquire cache manager read lock"))?;
214 if let Some(cache_manager) = &*cache_guard {
215 info!("Updating cache for session {} with {} entries", session_id, entries.len());
218 }
219 Ok(())
220 }
221
222 async fn persist_conversation(
223 shared_state: &Arc<SharedState>,
224 session_id: String,
225 messages: Vec<crate::memory::Message>,
226 ) -> anyhow::Result<()> {
227 info!("Persisting conversation {} with {} messages", session_id, messages.len());
229 Ok(())
231 }
232}
233
234impl Drop for WorkerThread {
235 fn drop(&mut self) {
236 self.running.store(false, Ordering::Relaxed);
237 if let Some(handle) = self.thread_handle.take() {
238 let _ = handle.join();
239 }
240 }
241}
242
/// Owns the worker threads and the sending halves of their command channels.
pub struct ThreadPool {
    /// Thread counts used when the pool is started.
    config: ThreadPoolConfig,
    /// Application state handed to every worker.
    shared_state: Arc<SharedState>,
    /// Handles of the spawned workers; empty until `start` is called.
    workers: Vec<WorkerThread>,
    /// One sender per worker, pushed in the same order as `workers`.
    command_senders: Vec<mpsc::UnboundedSender<SystemCommand>>,
}
250
251impl ThreadPool {
252 pub fn new(config: ThreadPoolConfig, shared_state: Arc<SharedState>) -> Self {
253 Self {
254 config,
255 shared_state,
256 workers: Vec::new(),
257 command_senders: Vec::new(),
258 }
259 }
260
261 pub async fn start(&mut self) -> anyhow::Result<()> {
262 info!("Starting thread pool with config: {:?}", self.config);
263
264 let mut channels = Vec::new();
266 for i in 0..self.config.context_engine_threads {
267 let (tx, rx) = mpsc::unbounded_channel();
268 channels.push((format!("context-worker-{}", i), tx, rx));
269 }
270
271 for (name, tx, rx) in channels {
273 let worker = WorkerThread::new(
274 name,
275 rx,
276 self.shared_state.clone(),
277 );
278 self.workers.push(worker);
279 self.command_senders.push(tx);
280 }
281
282 info!("Thread pool started with {} workers", self.workers.len());
283 Ok(())
284 }
285
286 pub async fn send_command(&self, command: SystemCommand) -> anyhow::Result<()> {
287 static NEXT_WORKER: AtomicBool = AtomicBool::new(false);
289 let worker_index = if NEXT_WORKER.fetch_xor(true, Ordering::Relaxed) { 0 } else { 1 };
290 let sender_index = worker_index % self.command_senders.len();
291
292 self.command_senders[sender_index]
293 .send(command)
294 .map_err(|_| anyhow::anyhow!("Failed to send command to worker thread"))
295 }
296
297 pub async fn shutdown(&mut self) -> anyhow::Result<()> {
298 info!("Shutting down thread pool");
299
300 for sender in &self.command_senders {
302 let _ = sender.send(SystemCommand::Shutdown);
303 }
304
305 self.workers.clear();
307 self.command_senders.clear();
308
309 info!("Thread pool shutdown complete");
310 Ok(())
311 }
312}
313
314pub use self::SystemCommand as Command;