Skip to main content

offline_intelligence/
shared_state.rs

1//! Shared state management for thread-based architecture
2//!
3//! This module provides the core shared memory infrastructure that enables
4//! efficient communication between worker threads while maintaining thread safety.
5
6use std::sync::{Arc, RwLock, atomic::{AtomicUsize, AtomicBool, Ordering}};
7use dashmap::DashMap;
8use tracing::info;
9
10use crate::{
11    config::Config,
12    context_engine::ContextOrchestrator,
13    memory_db::MemoryDatabase,
14    cache_management::KVCacheManager,
15    worker_threads::LLMWorker,
16    model_management::ModelManager,
17    model_runtime::RuntimeManager,
18};
19use crate::engine_management::EngineManager;
20
/// Attachment text extracted ahead of time, before the user presses Send.
///
/// Entries are produced by `POST /attachments/preprocess` and picked up by
/// `/generate/stream`.
#[derive(Clone)]
pub struct PreExtracted {
    /// The extracted attachment text.
    pub text: String,
    /// Timestamp recorded for this extraction; compared against a TTL by
    /// [`PreExtracted::is_stale`].
    pub extracted_at: std::time::Instant,
}

impl PreExtracted {
    /// Reports whether at least `ttl_secs` whole seconds have passed since
    /// `extracted_at`.
    ///
    /// A `ttl_secs` of `0` makes every entry stale.
    pub fn is_stale(&self, ttl_secs: u64) -> bool {
        // For an integer TTL, comparing full Durations is equivalent to
        // comparing the truncated seconds count.
        let ttl = std::time::Duration::from_secs(ttl_secs);
        let age = self.extracted_at.elapsed();
        age >= ttl
    }
}
35
36/// Core shared system state container
37pub struct SharedSystemState {
38    /// Conversation data with hierarchical locking
39    pub conversations: Arc<ConversationHierarchy>,
40
41    /// LLM runtime for direct inference
42    pub llm_runtime: Arc<RwLock<Option<LLMRuntime>>>,
43
44    /// Cache management system
45    pub cache_manager: Arc<RwLock<Option<Arc<KVCacheManager>>>>,
46
47    /// Database connection pool
48    pub database_pool: Arc<MemoryDatabase>,
49
50    /// Configuration (read-only after initialization)
51    pub config: Arc<Config>,
52
53    /// Atomic counters for performance tracking
54    pub counters: Arc<AtomicCounters>,
55
56    /// Context orchestrator for memory management (tokio RwLock for async access)
57    pub context_orchestrator: Arc<tokio::sync::RwLock<Option<ContextOrchestrator>>>,
58
59    /// LLM worker for inference operations
60    pub llm_worker: Arc<LLMWorker>,
61
62    /// Model management system
63    pub model_manager: Option<Arc<ModelManager>>,
64
65    /// Runtime management system
66    pub runtime_manager: Arc<std::sync::RwLock<Option<Arc<RuntimeManager>>>>,
67
68    /// Engine management system
69    pub engine_manager: Option<Arc<EngineManager>>,
70
71    /// Engine availability flag - true when an engine binary is ready for use
72    pub engine_available: Arc<AtomicBool>,
73
74    /// Initialization completion flag - true when all components are ready
75    pub initialization_complete: Arc<AtomicBool>,
76    
77    /// HTTP server port - may differ from config if original port was in use
78    pub http_port: Arc<RwLock<u16>>,
79
80    /// Pre-extracted attachment text cache.
81    /// Key: `"inline:{path}"` or `"local_storage:{id}"`.
82    /// Populated by `POST /attachments/preprocess` while the user types;
83    /// consumed (and evicted) by `/generate/stream` at Send time.
84    pub attachment_cache: Arc<DashMap<String, PreExtracted>>,
85
86    /// Limits concurrent binary-file extractions to num_cpus/2 (min 1, max 8)
87    /// so the LLM server is never CPU-starved on low-spec hardware.
88    pub extraction_semaphore: Arc<tokio::sync::Semaphore>,
89}
90
91/// Hierarchical conversation storage for reduced lock contention
92pub struct ConversationHierarchy {
93    /// Coarse-grained session-level locks
94    pub sessions: DashMap<String, Arc<RwLock<SessionData>>>,
95
96    /// Fine-grained message-level queues for hot paths
97    pub message_queues: DashMap<String, Arc<crossbeam_queue::ArrayQueue<PendingMessage>>>,
98
99    /// Lock-free counters for performance metrics
100    pub counters: Arc<AtomicCounters>,
101}
102
103/// Session-level data structure
104#[derive(Debug, Clone)]
105pub struct SessionData {
106    pub session_id: String,
107    pub messages: Vec<crate::memory::Message>,
108    pub last_accessed: std::time::Instant,
109    pub pinned: bool,
110}
111
112/// Pending message for asynchronous processing
113#[derive(Debug, Clone)]
114pub struct PendingMessage {
115    pub message: crate::memory::Message,
116    pub timestamp: std::time::Instant,
117}
118
/// Lock-free counters for system-wide metrics.
///
/// Every field is an atomic, so counters can be bumped through a shared
/// reference from any thread. `Debug` prints the current values.
#[derive(Debug)]
pub struct AtomicCounters {
    /// Total requests seen.
    pub total_requests: AtomicUsize,
    /// Sessions created so far (nothing in this module decrements it).
    pub active_sessions: AtomicUsize,
    /// Messages processed.
    pub processed_messages: AtomicUsize,
    /// Cache lookups that hit.
    pub cache_hits: AtomicUsize,
    /// Cache lookups that missed.
    pub cache_misses: AtomicUsize,
}
127
128impl AtomicCounters {
129    pub fn new() -> Self {
130        Self {
131            total_requests: AtomicUsize::new(0),
132            active_sessions: AtomicUsize::new(0),
133            processed_messages: AtomicUsize::new(0),
134            cache_hits: AtomicUsize::new(0),
135            cache_misses: AtomicUsize::new(0),
136        }
137    }
138
139    pub fn inc_total_requests(&self) -> usize {
140        self.total_requests.fetch_add(1, Ordering::Relaxed) + 1
141    }
142
143    pub fn inc_processed_messages(&self) -> usize {
144        self.processed_messages.fetch_add(1, Ordering::Relaxed) + 1
145    }
146
147    pub fn inc_cache_hit(&self) -> usize {
148        self.cache_hits.fetch_add(1, Ordering::Relaxed) + 1
149    }
150
151    pub fn inc_cache_miss(&self) -> usize {
152        self.cache_misses.fetch_add(1, Ordering::Relaxed) + 1
153    }
154}
155
/// Settings for running an LLM directly in-process.
///
/// For now this only holds configuration: inference still goes through the
/// existing HTTP proxy approach, and direct llama.cpp integration is planned
/// but not wired in.
#[derive(Debug)]
pub struct LLMRuntime {
    /// Path to the model file.
    pub model_path: String,
    /// Context size.
    pub context_size: u32,
    /// Batch size.
    pub batch_size: u32,
    /// Number of threads.
    pub threads: u32,
    /// Number of layers to offload to the GPU.
    pub gpu_layers: u32,
    // Note: Actual llama.cpp integration would go here.
    // For now, we'll maintain the existing HTTP proxy approach
    // but prepare the structure for direct integration.
}
167
168impl SharedSystemState {
169    pub fn new(config: Config, database: Arc<MemoryDatabase>) -> anyhow::Result<Self> {
170        info!("Initializing shared system state");
171
172        let conversations = Arc::new(ConversationHierarchy {
173            sessions: DashMap::new(),
174            message_queues: DashMap::new(),
175            counters: Arc::new(AtomicCounters::new()),
176        });
177
178        // Extract values before moving config into Arc
179        let api_port = config.api_port;
180        let backend_url = config.backend_url.clone();
181        
182        let config = Arc::new(config);
183        let counters = Arc::new(AtomicCounters::new());
184
185        // Create LLM worker with backend URL from config
186        let llm_worker = Arc::new(LLMWorker::new_with_backend(backend_url));
187
188        // Semaphore for concurrent binary-file extractions.
189        // Use half the logical cores so the LLM server is never CPU-starved.
190        let max_concurrent = (num_cpus::get() / 2).max(1).min(8);
191        info!("Attachment extraction semaphore: {} concurrent slots (num_cpus={})", max_concurrent, num_cpus::get());
192
193        Ok(Self {
194            conversations,
195            llm_runtime: Arc::new(RwLock::new(None)),
196            cache_manager: Arc::new(RwLock::new(None)),
197            database_pool: database,
198            config,
199            counters,
200            context_orchestrator: Arc::new(tokio::sync::RwLock::new(None)),
201            llm_worker,
202            model_manager: None,
203            runtime_manager: Arc::new(std::sync::RwLock::new(None)),
204            engine_manager: None,
205            engine_available: Arc::new(AtomicBool::new(false)),
206            initialization_complete: Arc::new(AtomicBool::new(false)),
207            http_port: Arc::new(RwLock::new(api_port)),
208            attachment_cache: Arc::new(DashMap::new()),
209            extraction_semaphore: Arc::new(tokio::sync::Semaphore::new(max_concurrent)),
210        })
211    }
212
213    /// Mark initialization as complete - call this after all components are initialized
214    pub fn mark_initialization_complete(&self) {
215        self.initialization_complete.store(true, Ordering::SeqCst);
216        info!("✅ Backend initialization marked as complete");
217    }
218
219    /// Check if initialization is complete
220    pub fn is_initialization_complete(&self) -> bool {
221        self.initialization_complete.load(Ordering::SeqCst)
222    }
223
224    /// Set LLM worker (replaces the default one created during initialization)
225    pub fn set_llm_worker(&self, _worker: Arc<LLMWorker>) {
226        // The llm_worker is now initialized in new() with the backend URL.
227        // This method is kept for backward compatibility but is a no-op since
228        // we initialize with the correct backend URL during construction.
229        info!("LLM worker already initialized with backend URL");
230    }
231
232    /// Set runtime manager (allows setting after initialization)
233    pub fn set_runtime_manager(&self, runtime_manager: Arc<RuntimeManager>) -> anyhow::Result<()> {
234        // Update the runtime manager in shared state
235        let mut guard = self.runtime_manager
236            .write()
237            .map_err(|e| anyhow::anyhow!("Failed to acquire runtime manager write lock: {}", e))?;
238        *guard = Some(runtime_manager);
239        Ok(())
240    }
241
242    /// Initialize LLM runtime with current configuration
243    pub fn initialize_llm_runtime(&self) -> anyhow::Result<()> {
244        let mut runtime_guard = self.llm_runtime.try_write()
245            .map_err(|_| anyhow::anyhow!("Failed to acquire LLM runtime write lock"))?;
246
247        let runtime = LLMRuntime {
248            model_path: self.config.model_path.clone(),
249            context_size: self.config.ctx_size,
250            batch_size: self.config.batch_size,
251            threads: self.config.threads,
252            gpu_layers: self.config.gpu_layers,
253        };
254
255        *runtime_guard = Some(runtime);
256        info!("LLM runtime initialized");
257        Ok(())
258    }
259
260    /// Get or create session data with proper locking
261    pub async fn get_or_create_session(&self, session_id: &str) -> Arc<RwLock<SessionData>> {
262        // Fast path: try to get existing session
263        if let Some(session) = self.conversations.sessions.get(session_id) {
264            return session.clone();
265        }
266
267        // Slow path: create new session
268        let new_session = Arc::new(RwLock::new(SessionData {
269            session_id: session_id.to_string(),
270            messages: Vec::new(),
271            last_accessed: std::time::Instant::now(),
272            pinned: false,
273        }));
274
275        self.conversations.sessions.insert(session_id.to_string(), new_session.clone());
276        self.counters.active_sessions.fetch_add(1, Ordering::Relaxed);
277
278        new_session
279    }
280
281    /// Queue message for asynchronous processing
282    pub fn queue_message(&self, session_id: &str, message: crate::memory::Message) -> bool {
283        let queue = self.conversations.message_queues
284            .entry(session_id.to_string())
285            .or_insert_with(|| Arc::new(crossbeam_queue::ArrayQueue::new(1000)));
286
287        queue.push(PendingMessage {
288            message,
289            timestamp: std::time::Instant::now(),
290        }).is_ok()
291    }
292
293    /// Process queued messages for a session
294    pub async fn process_queued_messages(&self, session_id: &str) -> Vec<PendingMessage> {
295        let mut messages = Vec::new();
296
297        if let Some(queue) = self.conversations.message_queues.get(session_id) {
298            while let Some(msg) = queue.pop() {
299                messages.push(msg);
300            }
301        }
302
303        messages
304    }
305}
306
307impl ConversationHierarchy {
308    pub fn new() -> Self {
309        Self {
310            sessions: DashMap::new(),
311            message_queues: DashMap::new(),
312            counters: Arc::new(AtomicCounters::new()),
313        }
314    }
315}
316
317/// Unified application state for all API handlers.
318/// This is the single state type used by the Axum router, providing access
319/// to all subsystems through shared memory (Arc) rather than network hops.
320#[derive(Clone)]
321pub struct UnifiedAppState {
322    pub shared_state: Arc<SharedSystemState>,
323    pub context_orchestrator: Arc<tokio::sync::RwLock<Option<ContextOrchestrator>>>,
324    pub llm_worker: Arc<LLMWorker>,
325    pub auth_state: Option<Arc<crate::api::auth_api::AuthState>>,
326    /// Shared HTTP client — one TLS pool reused across all outbound requests
327    /// (OpenRouter, HuggingFace, etc.) instead of creating a new client per call.
328    pub http_client: reqwest::Client,
329}
330
331impl UnifiedAppState {
332    pub fn new(shared_state: Arc<SharedSystemState>) -> Self {
333        let context_orchestrator = shared_state.context_orchestrator.clone();
334        let llm_worker = shared_state.llm_worker.clone();
335        // Build a single shared client with a generous timeout for LLM streaming.
336        // reqwest::Client is cheaply Clone (just bumps an Arc ref-count internally).
337        let http_client = reqwest::Client::builder()
338            .timeout(std::time::Duration::from_secs(300))
339            .build()
340            .unwrap_or_else(|_| reqwest::Client::new());
341        Self {
342            shared_state,
343            context_orchestrator,
344            llm_worker,
345            auth_state: None,
346            http_client,
347        }
348    }
349
350    /// Get API key from all sources in priority order:
351    /// 1. Database stored keys (persisted)
352    /// 2. Environment variables
353    /// 3. Config file
354    /// 
355    /// This ensures ultimate synchronicity - keys stored in DB are available
356    /// to both backend and frontend immediately after saving.
357    pub async fn get_openrouter_api_key(&self) -> Option<String> {
358        // 1. First check database (highest priority - persisted)
359        if let Ok(Some(key)) = self.shared_state.database_pool.api_keys.get_key_plaintext(&crate::memory_db::ApiKeyType::OpenRouter) {
360            if !key.is_empty() {
361                info!("Using OpenRouter API key from database");
362                return Some(key);
363            }
364        }
365        
366        // 2. Check environment variable
367        if let Ok(key) = std::env::var("OPENROUTER_API_KEY") {
368            if !key.is_empty() {
369                info!("Using OpenRouter API key from environment variable");
370                return Some(key);
371            }
372        }
373        
374        // 3. Check config
375        let config = &self.shared_state.config;
376        if !config.openrouter_api_key.is_empty() {
377            info!("Using OpenRouter API key from config");
378            return Some(config.openrouter_api_key.clone());
379        }
380        
381        None
382    }
383
384    /// Get HuggingFace token from all sources in priority order:
385    /// 1. Database stored keys (persisted)
386    /// 2. Environment variables
387    /// 3. Config file
388    pub async fn get_huggingface_token(&self) -> Option<String> {
389        // 1. First check database (highest priority - persisted)
390        if let Ok(Some(token)) = self.shared_state.database_pool.api_keys.get_key_plaintext(&crate::memory_db::ApiKeyType::HuggingFace) {
391            if !token.is_empty() {
392                info!("Using HuggingFace token from database");
393                return Some(token);
394            }
395        }
396        
397        // 2. Check environment variable
398        // HUGGINGFACE_TOKEN or HF_TOKEN are common env var names
399        if let Ok(token) = std::env::var("HUGGINGFACE_TOKEN").or_else(|_| std::env::var("HF_TOKEN")) {
400            if !token.is_empty() {
401                info!("Using HuggingFace token from environment variable");
402                return Some(token);
403            }
404        }
405        
406        None
407    }
408}
409
410// Re-exports for convenience
411pub use self::SharedSystemState as SharedState;