// offline_intelligence/thread_pool.rs

//! Thread pool management for worker threads
//!
//! This module provides the infrastructure for managing dedicated worker threads
//! for different system components, enabling efficient parallel processing.

use std::sync::{Arc, atomic::{AtomicBool, AtomicUsize, Ordering}};
use std::thread::{self, JoinHandle};

use tokio::sync::{mpsc, oneshot};
use tracing::{error, info};

use crate::{
    config::Config,
    shared_state::SharedState,
};
15
/// Configuration for thread pool sizing.
///
/// Per-component worker counts. All counts are derived from the host CPU
/// count in [`ThreadPoolConfig::new`], except `database_threads`, which
/// tracks the configured stream concurrency.
#[derive(Debug, Clone)]
pub struct ThreadPoolConfig {
    // Workers for context-engine commands (2..=4, scaled by CPU count).
    pub context_engine_threads: usize,
    // Workers for cache maintenance (1..=2, scaled by CPU count).
    pub cache_manager_threads: usize,
    // Workers for database persistence; mirrors `Config::max_concurrent_streams`.
    pub database_threads: usize,
    // Workers for LLM inference; fixed at 1 (single-threaded per model).
    pub llm_threads: usize,
    // Workers for general I/O (2..=4, scaled by CPU count).
    pub io_threads: usize,
}
25
26impl ThreadPoolConfig {
27    pub fn new(config: &Config) -> Self {
28        // Scale thread counts based on system resources
29        let cpu_cores = num_cpus::get();
30        
31        Self {
32            context_engine_threads: (cpu_cores / 4).max(2).min(4),
33            cache_manager_threads: 1.max(cpu_cores / 8).min(2),
34            database_threads: config.max_concurrent_streams as usize,
35            llm_threads: 1, // LLM inference is typically single-threaded per model
36            io_threads: (cpu_cores / 2).max(2).min(4),
37        }
38    }
39}
40
/// System-wide command types for thread communication.
///
/// Each data-carrying variant bundles a one-shot callback (`sender`) that the
/// worker invokes with the operation's result; callers supply the closure and
/// receive their answer through it.
pub enum SystemCommand {
    // Conversation operations
    /// Append a message to the session's conversation and echo it back.
    ProcessMessage {
        session_id: String,
        message: crate::memory::Message,
        // Invoked with the stored message on success.
        sender: Box<dyn FnOnce(anyhow::Result<crate::memory::Message>) + Send>,
    },
    
    // LLM operations
    /// Generate a model response for the given conversation context.
    /// (Currently a placeholder — see `WorkerThread::generate_response`.)
    GenerateResponse {
        session_id: String,
        context: Vec<crate::memory::Message>,
        sender: Box<dyn FnOnce(anyhow::Result<String>) + Send>,
    },
    
    // Cache operations
    /// Push new KV-cache entries for a session into the cache manager.
    UpdateCache {
        session_id: String,
        entries: Vec<crate::cache_management::cache_extractor::KVEntry>,
        sender: Box<dyn FnOnce(anyhow::Result<()>) + Send>,
    },
    
    // Database operations
    /// Persist a session's messages to durable storage.
    PersistConversation {
        session_id: String,
        messages: Vec<crate::memory::Message>,
        sender: Box<dyn FnOnce(anyhow::Result<()>) + Send>,
    },
    
    // Administrative operations
    /// Ask the receiving worker to stop its processing loop.
    Shutdown,
}
74
/// Worker thread implementation.
///
/// Owns a dedicated OS thread running its own single-threaded tokio runtime
/// (spawned in `WorkerThread::new`); the thread is signalled to stop and
/// joined when this struct is dropped.
pub struct WorkerThread {
    // NOTE(review): the live receiver is moved into the spawned thread in
    // `new`; this field holds a disconnected placeholder receiver and is
    // never read. Candidate for removal from the struct.
    command_receiver: mpsc::UnboundedReceiver<SystemCommand>,
    // Shared system state handed to every command handler.
    shared_state: Arc<SharedState>,
    // `Option` so `Drop` can `take()` and join the handle.
    thread_handle: Option<JoinHandle<()>>,
    // Cleared to request the worker loop to exit.
    running: Arc<AtomicBool>,
}
82
83impl WorkerThread {
84    pub fn new(
85        name: String,
86        command_receiver: mpsc::UnboundedReceiver<SystemCommand>,
87        shared_state: Arc<SharedState>,
88    ) -> Self {
89        let running = Arc::new(AtomicBool::new(true));
90        let running_clone = running.clone();
91        let shared_state_clone = shared_state.clone();
92        
93        let thread_handle = thread::Builder::new()
94            .name(name.clone())
95            .spawn({
96                let receiver = command_receiver; // Move receiver into closure
97                move || {
98                    let rt = tokio::runtime::Builder::new_current_thread()
99                        .enable_all()
100                        .build()
101                        .expect("Failed to create worker thread runtime");
102                    
103                    rt.block_on(async move {
104                        Self::run_worker_loop(receiver, shared_state_clone, running_clone).await;
105                    });
106                }
107            })
108            .expect("Failed to spawn worker thread");
109        
110        info!("Spawned worker thread: {}", name);
111        
112        Self {
113            command_receiver: mpsc::unbounded_channel().1, // Create dummy receiver
114            shared_state,
115            thread_handle: Some(thread_handle),
116            running,
117        }
118    }
119    
120    async fn run_worker_loop(
121        mut receiver: mpsc::UnboundedReceiver<SystemCommand>,
122        shared_state: Arc<SharedState>,
123        running: Arc<AtomicBool>,
124    ) {
125        while running.load(Ordering::Relaxed) {
126            tokio::select! {
127                command = receiver.recv() => {
128                    match command {
129                        Some(cmd) => {
130                            if let Err(e) = Self::handle_command(cmd, &shared_state).await {
131                                error!("Worker thread command failed: {}", e);
132                            }
133                        }
134                        None => {
135                            info!("Worker thread command channel closed");
136                            break;
137                        }
138                    }
139                }
140                _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
141                    // Periodic maintenance tasks could go here
142                }
143            }
144        }
145        
146        info!("Worker thread shutting down");
147    }
148    
149    async fn handle_command(
150        command: SystemCommand,
151        shared_state: &Arc<SharedState>,
152    ) -> anyhow::Result<()> {
153        match command {
154            SystemCommand::ProcessMessage { session_id, message, sender } => {
155                let result = Self::process_message(shared_state, session_id, message).await;
156                sender(result);
157            }
158            SystemCommand::GenerateResponse { session_id, context, sender } => {
159                let result = Self::generate_response(shared_state, session_id, context).await;
160                sender(result);
161            }
162            SystemCommand::UpdateCache { session_id, entries, sender } => {
163                let result = Self::update_cache(shared_state, session_id, entries).await;
164                sender(result);
165            }
166            SystemCommand::PersistConversation { session_id, messages, sender } => {
167                let result = Self::persist_conversation(shared_state, session_id, messages).await;
168                sender(result);
169            }
170            SystemCommand::Shutdown => {
171                // Graceful shutdown handled by running flag
172            }
173        }
174        Ok(())
175    }
176    
177    async fn process_message(
178        shared_state: &Arc<SharedState>,
179        session_id: String,
180        message: crate::memory::Message,
181    ) -> anyhow::Result<crate::memory::Message> {
182        // Get or create session
183        let session = shared_state.get_or_create_session(&session_id).await;
184        let mut session_guard = session.write()
185            .map_err(|_| anyhow::anyhow!("Failed to acquire session write lock"))?;
186        
187        // Add message to session
188        session_guard.messages.push(message.clone());
189        session_guard.last_accessed = std::time::Instant::now();
190        
191        // Update metrics
192        shared_state.counters.inc_processed_messages();
193        
194        Ok(message)
195    }
196    
197    async fn generate_response(
198        _shared_state: &Arc<SharedState>,
199        _session_id: String,
200        _context: Vec<crate::memory::Message>,
201    ) -> anyhow::Result<String> {
202        // Placeholder for LLM integration
203        // In full implementation, this would call the direct llama.cpp interface
204        Ok("Generated response placeholder".to_string())
205    }
206    
207    async fn update_cache(
208        shared_state: &Arc<SharedState>,
209        session_id: String,
210        entries: Vec<crate::cache_management::cache_extractor::KVEntry>,
211    ) -> anyhow::Result<()> {
212        let cache_guard = shared_state.cache_manager.read()
213            .map_err(|_| anyhow::anyhow!("Failed to acquire cache manager read lock"))?;
214        if let Some(cache_manager) = &*cache_guard {
215            // Update cache with new entries
216            // Implementation would depend on the specific cache manager API
217            info!("Updating cache for session {} with {} entries", session_id, entries.len());
218        }
219        Ok(())
220    }
221    
222    async fn persist_conversation(
223        shared_state: &Arc<SharedState>,
224        session_id: String,
225        messages: Vec<crate::memory::Message>,
226    ) -> anyhow::Result<()> {
227        // Persist to database using shared connection pool
228        info!("Persisting conversation {} with {} messages", session_id, messages.len());
229        // Actual implementation would use shared_state.database_pool
230        Ok(())
231    }
232}
233
234impl Drop for WorkerThread {
235    fn drop(&mut self) {
236        self.running.store(false, Ordering::Relaxed);
237        if let Some(handle) = self.thread_handle.take() {
238            let _ = handle.join();
239        }
240    }
241}
242
/// Thread pool manager for coordinating worker threads.
///
/// Created empty via `ThreadPool::new`; workers and their command channels
/// are populated by `start`. `workers[i]` is paired with `command_senders[i]`.
pub struct ThreadPool {
    // Sizing parameters used by `start`.
    config: ThreadPoolConfig,
    // Shared state handed to each spawned worker.
    shared_state: Arc<SharedState>,
    // Owned workers; dropping them (see `shutdown`) stops and joins the threads.
    workers: Vec<WorkerThread>,
    // One unbounded sender per worker, used for command dispatch.
    command_senders: Vec<mpsc::UnboundedSender<SystemCommand>>,
}
250
251impl ThreadPool {
252    pub fn new(config: ThreadPoolConfig, shared_state: Arc<SharedState>) -> Self {
253        Self {
254            config,
255            shared_state,
256            workers: Vec::new(),
257            command_senders: Vec::new(),
258        }
259    }
260    
261    pub async fn start(&mut self) -> anyhow::Result<()> {
262        info!("Starting thread pool with config: {:?}", self.config);
263        
264        // Create command channels
265        let mut channels = Vec::new();
266        for i in 0..self.config.context_engine_threads {
267            let (tx, rx) = mpsc::unbounded_channel();
268            channels.push((format!("context-worker-{}", i), tx, rx));
269        }
270        
271        // Spawn worker threads
272        for (name, tx, rx) in channels {
273            let worker = WorkerThread::new(
274                name,
275                rx,
276                self.shared_state.clone(),
277            );
278            self.workers.push(worker);
279            self.command_senders.push(tx);
280        }
281        
282        info!("Thread pool started with {} workers", self.workers.len());
283        Ok(())
284    }
285    
286    pub async fn send_command(&self, command: SystemCommand) -> anyhow::Result<()> {
287        // Simple round-robin distribution for now
288        static NEXT_WORKER: AtomicBool = AtomicBool::new(false);
289        let worker_index = if NEXT_WORKER.fetch_xor(true, Ordering::Relaxed) { 0 } else { 1 };
290        let sender_index = worker_index % self.command_senders.len();
291        
292        self.command_senders[sender_index]
293            .send(command)
294            .map_err(|_| anyhow::anyhow!("Failed to send command to worker thread"))
295    }
296    
297    pub async fn shutdown(&mut self) -> anyhow::Result<()> {
298        info!("Shutting down thread pool");
299        
300        // Send shutdown commands
301        for sender in &self.command_senders {
302            let _ = sender.send(SystemCommand::Shutdown);
303        }
304        
305        // Drop workers to trigger cleanup
306        self.workers.clear();
307        self.command_senders.clear();
308        
309        info!("Thread pool shutdown complete");
310        Ok(())
311    }
312}
313
// Convenience re-export: lets callers refer to `thread_pool::Command`
// instead of the longer `SystemCommand`.
pub use self::SystemCommand as Command;