// offline_intelligence/thread_pool.rs

//! Thread pool management for worker threads
//!
//! This module provides the infrastructure for managing dedicated worker threads
//! for different system components, enabling efficient parallel processing.

use std::sync::{
    Arc,
    atomic::{AtomicBool, AtomicUsize, Ordering},
};
use std::thread::{self, JoinHandle};
use tokio::sync::mpsc;
use tracing::{error, info};

use crate::{
    config::Config,
    shared_state::SharedState,
};

/// Configuration for thread pool sizing.
///
/// Counts are derived from the host CPU count in [`ThreadPoolConfig::new`];
/// see that constructor for the clamping ranges applied to each field.
#[derive(Debug, Clone)]
pub struct ThreadPoolConfig {
    /// Worker threads for the context engine (clamped to 2..=4 in `new`).
    pub context_engine_threads: usize,
    /// Worker threads for the cache manager (clamped to 1..=2 in `new`).
    pub cache_manager_threads: usize,
    /// Database worker threads; taken from `Config::max_concurrent_streams`.
    pub database_threads: usize,
    /// LLM inference threads (fixed at 1 in `new`).
    pub llm_threads: usize,
    /// General I/O worker threads (clamped to 2..=4 in `new`).
    pub io_threads: usize,
}

26impl ThreadPoolConfig {
27    pub fn new(config: &Config) -> Self {
28        // Scale thread counts based on system resources
29        let cpu_cores = num_cpus::get();
30        
31        Self {
32            context_engine_threads: (cpu_cores / 4).max(2).min(4),
33            cache_manager_threads: 1.max(cpu_cores / 8).min(2),
34            database_threads: config.max_concurrent_streams as usize,
35            llm_threads: 1, // LLM inference is typically single-threaded per model
36            io_threads: (cpu_cores / 2).max(2).min(4),
37        }
38    }
39}
40
/// System-wide command types for thread communication.
///
/// Each data-carrying variant bundles a `sender` completion callback that the
/// worker invokes exactly once with the operation's result; callbacks are
/// boxed `FnOnce + Send` so they can cross the thread boundary.
pub enum SystemCommand {
    /// Conversation operation: record a message in the session identified by
    /// `session_id` and report the stored message back through `sender`.
    ProcessMessage {
        session_id: String,
        message: crate::memory::Message,
        sender: Box<dyn FnOnce(anyhow::Result<crate::memory::Message>) + Send>,
    },

    /// LLM operation: generate a response for `session_id` from the supplied
    /// conversation `context`.
    GenerateResponse {
        session_id: String,
        context: Vec<crate::memory::Message>,
        sender: Box<dyn FnOnce(anyhow::Result<String>) + Send>,
    },

    /// Cache operation: apply new KV-cache `entries` for a session.
    UpdateCache {
        session_id: String,
        entries: Vec<crate::cache_management::cache_extractor::KVEntry>,
        sender: Box<dyn FnOnce(anyhow::Result<()>) + Send>,
    },

    /// Database operation: persist a session's `messages`.
    PersistConversation {
        session_id: String,
        messages: Vec<crate::memory::Message>,
        sender: Box<dyn FnOnce(anyhow::Result<()>) + Send>,
    },

    /// Administrative operation: request worker shutdown (no callback).
    Shutdown,
}

/// Worker thread implementation.
///
/// Owns one OS thread running a command-processing loop. The handle is kept
/// in an `Option` so `Drop` can `take()` it and join exactly once.
pub struct WorkerThread {
    /// Join handle for the spawned thread; `None` once joined.
    thread_handle: Option<JoinHandle<()>>,
    /// Shared run flag; clearing it asks the worker loop to exit.
    running: Arc<AtomicBool>,
}

81impl WorkerThread {
82    pub fn new(
83        name: String,
84        command_receiver: mpsc::UnboundedReceiver<SystemCommand>,
85        shared_state: Arc<SharedState>,
86    ) -> Self {
87        let running = Arc::new(AtomicBool::new(true));
88        let running_clone = running.clone();
89
90        let thread_handle = thread::Builder::new()
91            .name(name.clone())
92            .spawn(move || {
93                let rt = tokio::runtime::Builder::new_current_thread()
94                    .enable_all()
95                    .build()
96                    .expect("Failed to create worker thread runtime");
97
98                rt.block_on(async move {
99                    Self::run_worker_loop(command_receiver, shared_state, running_clone).await;
100                });
101            })
102            .expect("Failed to spawn worker thread");
103
104        info!("Spawned worker thread: {}", name);
105
106        Self {
107            thread_handle: Some(thread_handle),
108            running,
109        }
110    }
111    
112    async fn run_worker_loop(
113        mut receiver: mpsc::UnboundedReceiver<SystemCommand>,
114        shared_state: Arc<SharedState>,
115        running: Arc<AtomicBool>,
116    ) {
117        while running.load(Ordering::Relaxed) {
118            tokio::select! {
119                command = receiver.recv() => {
120                    match command {
121                        Some(cmd) => {
122                            if let Err(e) = Self::handle_command(cmd, &shared_state).await {
123                                error!("Worker thread command failed: {}", e);
124                            }
125                        }
126                        None => {
127                            info!("Worker thread command channel closed");
128                            break;
129                        }
130                    }
131                }
132                _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
133                    // Periodic maintenance tasks could go here
134                }
135            }
136        }
137        
138        info!("Worker thread shutting down");
139    }
140    
141    async fn handle_command(
142        command: SystemCommand,
143        shared_state: &Arc<SharedState>,
144    ) -> anyhow::Result<()> {
145        match command {
146            SystemCommand::ProcessMessage { session_id, message, sender } => {
147                let result = Self::process_message(shared_state, session_id, message).await;
148                sender(result);
149            }
150            SystemCommand::GenerateResponse { session_id, context, sender } => {
151                let result = Self::generate_response(shared_state, session_id, context).await;
152                sender(result);
153            }
154            SystemCommand::UpdateCache { session_id, entries, sender } => {
155                let result = Self::update_cache(shared_state, session_id, entries).await;
156                sender(result);
157            }
158            SystemCommand::PersistConversation { session_id, messages, sender } => {
159                let result = Self::persist_conversation(shared_state, session_id, messages).await;
160                sender(result);
161            }
162            SystemCommand::Shutdown => {
163                // Graceful shutdown handled by running flag
164            }
165        }
166        Ok(())
167    }
168    
169    async fn process_message(
170        shared_state: &Arc<SharedState>,
171        session_id: String,
172        message: crate::memory::Message,
173    ) -> anyhow::Result<crate::memory::Message> {
174        // Get or create session
175        let session = shared_state.get_or_create_session(&session_id).await;
176        let mut session_guard = session.write()
177            .map_err(|_| anyhow::anyhow!("Failed to acquire session write lock"))?;
178        
179        // Add message to session
180        session_guard.messages.push(message.clone());
181        session_guard.last_accessed = std::time::Instant::now();
182        
183        // Update metrics
184        shared_state.counters.inc_processed_messages();
185        
186        Ok(message)
187    }
188    
189    async fn generate_response(
190        shared_state: &Arc<SharedState>,
191        session_id: String,
192        context: Vec<crate::memory::Message>,
193    ) -> anyhow::Result<String> {
194        // Use the LLM worker for actual response generation
195        let response = shared_state
196            .llm_worker
197            .generate_response(session_id, context.clone())
198            .await?;
199        
200        Ok(response)
201    }
202    
203    async fn update_cache(
204        shared_state: &Arc<SharedState>,
205        session_id: String,
206        entries: Vec<crate::cache_management::cache_extractor::KVEntry>,
207    ) -> anyhow::Result<()> {
208        let cache_guard = shared_state.cache_manager.read()
209            .map_err(|_| anyhow::anyhow!("Failed to acquire cache manager read lock"))?;
210        if let Some(_cache_manager) = &*cache_guard {
211            // Update cache with new entries
212            // Implementation would depend on the specific cache manager API
213            info!("Updating cache for session {} with {} entries", session_id, entries.len());
214        }
215        Ok(())
216    }
217    
218    async fn persist_conversation(
219        _shared_state: &Arc<SharedState>,
220        session_id: String,
221        messages: Vec<crate::memory::Message>,
222    ) -> anyhow::Result<()> {
223        // Persist to database using shared connection pool
224        info!("Persisting conversation {} with {} messages", session_id, messages.len());
225        // Actual implementation would use shared_state.database_pool
226        Ok(())
227    }
228}
229
230impl Drop for WorkerThread {
231    fn drop(&mut self) {
232        self.running.store(false, Ordering::Relaxed);
233        if let Some(handle) = self.thread_handle.take() {
234            let _ = handle.join();
235        }
236    }
237}
238
/// Thread pool manager for coordinating worker threads.
///
/// Construct with [`ThreadPool::new`], then call `start` to spawn workers;
/// `workers` and `command_senders` stay empty until then.
pub struct ThreadPool {
    /// Sizing configuration used when spawning workers in `start`.
    config: ThreadPoolConfig,
    /// Shared state handed to every spawned worker.
    shared_state: Arc<SharedState>,
    /// Live workers; index-aligned with `command_senders`.
    workers: Vec<WorkerThread>,
    /// One command channel per worker, in spawn order.
    command_senders: Vec<mpsc::UnboundedSender<SystemCommand>>,
}

247impl ThreadPool {
248    pub fn new(config: ThreadPoolConfig, shared_state: Arc<SharedState>) -> Self {
249        Self {
250            config,
251            shared_state,
252            workers: Vec::new(),
253            command_senders: Vec::new(),
254        }
255    }
256    
257    pub async fn start(&mut self) -> anyhow::Result<()> {
258        info!("Starting thread pool with config: {:?}", self.config);
259        
260        // Create command channels
261        let mut channels = Vec::new();
262        for i in 0..self.config.context_engine_threads {
263            let (tx, rx) = mpsc::unbounded_channel();
264            channels.push((format!("context-worker-{}", i), tx, rx));
265        }
266        
267        // Spawn worker threads
268        for (name, tx, rx) in channels {
269            let worker = WorkerThread::new(
270                name,
271                rx,
272                self.shared_state.clone(),
273            );
274            self.workers.push(worker);
275            self.command_senders.push(tx);
276        }
277        
278        info!("Thread pool started with {} workers", self.workers.len());
279        Ok(())
280    }
281    
282    pub async fn send_command(&self, command: SystemCommand) -> anyhow::Result<()> {
283        // Simple round-robin distribution for now
284        static NEXT_WORKER: AtomicBool = AtomicBool::new(false);
285        let worker_index = if NEXT_WORKER.fetch_xor(true, Ordering::Relaxed) { 0 } else { 1 };
286        let sender_index = worker_index % self.command_senders.len();
287        
288        self.command_senders[sender_index]
289            .send(command)
290            .map_err(|_| anyhow::anyhow!("Failed to send command to worker thread"))
291    }
292    
293    pub async fn shutdown(&mut self) -> anyhow::Result<()> {
294        info!("Shutting down thread pool");
295        
296        // Send shutdown commands
297        for sender in &self.command_senders {
298            let _ = sender.send(SystemCommand::Shutdown);
299        }
300        
301        // Drop workers to trigger cleanup
302        self.workers.clear();
303        self.command_senders.clear();
304        
305        info!("Thread pool shutdown complete");
306        Ok(())
307    }
308}
309
// Convenience re-export so callers can refer to `thread_pool::Command`.
pub use self::SystemCommand as Command;