// offline_intelligence/shared_state.rs

//! Shared state management for the thread-based architecture.
//!
//! This module provides the core shared-memory infrastructure that lets worker
//! threads communicate efficiently while remaining thread-safe.
5
6use std::sync::{Arc, RwLock, atomic::{AtomicUsize, Ordering}};
7use dashmap::DashMap;
8use tracing::info;
9
10use crate::{
11    config::Config,
12    context_engine::ContextOrchestrator,
13    memory_db::MemoryDatabase,
14    cache_management::KVCacheManager,
15    worker_threads::LLMWorker,
16};
17
18/// Core shared system state container
19pub struct SharedSystemState {
20    /// Conversation data with hierarchical locking
21    pub conversations: Arc<ConversationHierarchy>,
22
23    /// LLM runtime for direct inference
24    pub llm_runtime: Arc<RwLock<Option<LLMRuntime>>>,
25
26    /// Cache management system
27    pub cache_manager: Arc<RwLock<Option<Arc<KVCacheManager>>>>,
28
29    /// Database connection pool
30    pub database_pool: Arc<MemoryDatabase>,
31
32    /// Configuration (read-only after initialization)
33    pub config: Arc<Config>,
34
35    /// Atomic counters for performance tracking
36    pub counters: Arc<AtomicCounters>,
37
38    /// Context orchestrator for memory management (tokio RwLock for async access)
39    pub context_orchestrator: Arc<tokio::sync::RwLock<Option<ContextOrchestrator>>>,
40
41    /// LLM worker for inference operations
42    pub llm_worker: Arc<LLMWorker>,
43}
44
45/// Hierarchical conversation storage for reduced lock contention
46pub struct ConversationHierarchy {
47    /// Coarse-grained session-level locks
48    pub sessions: DashMap<String, Arc<RwLock<SessionData>>>,
49
50    /// Fine-grained message-level queues for hot paths
51    pub message_queues: DashMap<String, Arc<crossbeam_queue::ArrayQueue<PendingMessage>>>,
52
53    /// Lock-free counters for performance metrics
54    pub counters: Arc<AtomicCounters>,
55}
56
57/// Session-level data structure
58#[derive(Debug, Clone)]
59pub struct SessionData {
60    pub session_id: String,
61    pub messages: Vec<crate::memory::Message>,
62    pub last_accessed: std::time::Instant,
63    pub pinned: bool,
64}
65
66/// Pending message for asynchronous processing
67#[derive(Debug, Clone)]
68pub struct PendingMessage {
69    pub message: crate::memory::Message,
70    pub timestamp: std::time::Instant,
71}
72
/// Lock-free counters for system-wide metrics.
///
/// Each counter is an independent `AtomicUsize` updated with
/// `Ordering::Relaxed`. The values are plain statistics, so no ordering between
/// counters (or with other memory) is needed.
#[derive(Debug, Default)]
pub struct AtomicCounters {
    /// Running request total, bumped by [`AtomicCounters::inc_total_requests`].
    pub total_requests: AtomicUsize,
    /// Number of sessions created (incremented on session creation).
    pub active_sessions: AtomicUsize,
    /// Running total bumped by [`AtomicCounters::inc_processed_messages`].
    pub processed_messages: AtomicUsize,
    /// Running total bumped by [`AtomicCounters::inc_cache_hit`].
    pub cache_hits: AtomicUsize,
    /// Running total bumped by [`AtomicCounters::inc_cache_miss`].
    pub cache_misses: AtomicUsize,
}

impl AtomicCounters {
    /// Creates a counter set with every value at zero (same as `Default`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Increments `total_requests` and returns the value after the increment.
    pub fn inc_total_requests(&self) -> usize {
        Self::bump(&self.total_requests)
    }

    /// Increments `processed_messages` and returns the value after the increment.
    pub fn inc_processed_messages(&self) -> usize {
        Self::bump(&self.processed_messages)
    }

    /// Increments `cache_hits` and returns the value after the increment.
    pub fn inc_cache_hit(&self) -> usize {
        Self::bump(&self.cache_hits)
    }

    /// Increments `cache_misses` and returns the value after the increment.
    pub fn inc_cache_miss(&self) -> usize {
        Self::bump(&self.cache_misses)
    }

    /// Adds one to `counter` and returns the post-increment value.
    ///
    /// `fetch_add` wraps on overflow, so the returned value uses `wrapping_add`
    /// to match what was stored. A plain `+ 1` would panic in debug builds once
    /// the counter reached `usize::MAX`.
    fn bump(counter: &AtomicUsize) -> usize {
        counter.fetch_add(1, Ordering::Relaxed).wrapping_add(1)
    }
}
109
/// Parameters for a directly integrated LLM runtime.
///
/// Currently this only holds parameters copied from the configuration.
/// Inference still goes through the existing HTTP proxy approach;
/// direct llama.cpp integration is meant to be added to this struct later.
#[derive(Debug, Clone)]
pub struct LLMRuntime {
    /// Path to the model file (copied from `Config::model_path`).
    pub model_path: String,
    /// Context window size (copied from `Config::ctx_size`).
    pub context_size: u32,
    /// Batch size (copied from `Config::batch_size`).
    pub batch_size: u32,
    /// Thread count (copied from `Config::threads`).
    pub threads: u32,
    /// Number of layers to offload to the GPU (copied from `Config::gpu_layers`).
    pub gpu_layers: u32,
}
121
122impl SharedSystemState {
123    pub fn new(config: Config, database: Arc<MemoryDatabase>) -> anyhow::Result<Self> {
124        info!("Initializing shared system state");
125
126        let conversations = Arc::new(ConversationHierarchy {
127            sessions: DashMap::new(),
128            message_queues: DashMap::new(),
129            counters: Arc::new(AtomicCounters::new()),
130        });
131
132        let config = Arc::new(config);
133        let counters = Arc::new(AtomicCounters::new());
134
135        // Create LLM worker with backend URL from config
136        let backend_url = config.backend_url.clone();
137        let llm_worker = Arc::new(LLMWorker::new_with_backend(backend_url));
138
139        Ok(Self {
140            conversations,
141            llm_runtime: Arc::new(RwLock::new(None)),
142            cache_manager: Arc::new(RwLock::new(None)),
143            database_pool: database,
144            config,
145            counters,
146            context_orchestrator: Arc::new(tokio::sync::RwLock::new(None)),
147            llm_worker,
148        })
149    }
150
151    /// Set LLM worker (replaces the default one created during initialization)
152    pub fn set_llm_worker(&self, _worker: Arc<LLMWorker>) {
153        // The llm_worker is now initialized in new() with the backend URL.
154        // This method is kept for backward compatibility but is a no-op since
155        // we initialize with the correct backend URL during construction.
156        info!("LLM worker already initialized with backend URL");
157    }
158
159    /// Initialize LLM runtime with current configuration
160    pub fn initialize_llm_runtime(&self) -> anyhow::Result<()> {
161        let mut runtime_guard = self.llm_runtime.try_write()
162            .map_err(|_| anyhow::anyhow!("Failed to acquire LLM runtime write lock"))?;
163
164        let runtime = LLMRuntime {
165            model_path: self.config.model_path.clone(),
166            context_size: self.config.ctx_size,
167            batch_size: self.config.batch_size,
168            threads: self.config.threads,
169            gpu_layers: self.config.gpu_layers,
170        };
171
172        *runtime_guard = Some(runtime);
173        info!("LLM runtime initialized");
174        Ok(())
175    }
176
177    /// Get or create session data with proper locking
178    pub async fn get_or_create_session(&self, session_id: &str) -> Arc<RwLock<SessionData>> {
179        // Fast path: try to get existing session
180        if let Some(session) = self.conversations.sessions.get(session_id) {
181            return session.clone();
182        }
183
184        // Slow path: create new session
185        let new_session = Arc::new(RwLock::new(SessionData {
186            session_id: session_id.to_string(),
187            messages: Vec::new(),
188            last_accessed: std::time::Instant::now(),
189            pinned: false,
190        }));
191
192        self.conversations.sessions.insert(session_id.to_string(), new_session.clone());
193        self.counters.active_sessions.fetch_add(1, Ordering::Relaxed);
194
195        new_session
196    }
197
198    /// Queue message for asynchronous processing
199    pub fn queue_message(&self, session_id: &str, message: crate::memory::Message) -> bool {
200        let queue = self.conversations.message_queues
201            .entry(session_id.to_string())
202            .or_insert_with(|| Arc::new(crossbeam_queue::ArrayQueue::new(1000)));
203
204        queue.push(PendingMessage {
205            message,
206            timestamp: std::time::Instant::now(),
207        }).is_ok()
208    }
209
210    /// Process queued messages for a session
211    pub async fn process_queued_messages(&self, session_id: &str) -> Vec<PendingMessage> {
212        let mut messages = Vec::new();
213
214        if let Some(queue) = self.conversations.message_queues.get(session_id) {
215            while let Some(msg) = queue.pop() {
216                messages.push(msg);
217            }
218        }
219
220        messages
221    }
222}
223
224impl ConversationHierarchy {
225    pub fn new() -> Self {
226        Self {
227            sessions: DashMap::new(),
228            message_queues: DashMap::new(),
229            counters: Arc::new(AtomicCounters::new()),
230        }
231    }
232}
233
234/// Unified application state for all API handlers.
235/// This is the single state type used by the Axum router, providing access
236/// to all subsystems through shared memory (Arc) rather than network hops.
237#[derive(Clone)]
238pub struct UnifiedAppState {
239    pub shared_state: Arc<SharedSystemState>,
240    pub context_orchestrator: Arc<tokio::sync::RwLock<Option<ContextOrchestrator>>>,
241    pub llm_worker: Arc<LLMWorker>,
242}
243
244impl UnifiedAppState {
245    pub fn new(shared_state: Arc<SharedSystemState>) -> Self {
246        let context_orchestrator = shared_state.context_orchestrator.clone();
247        let llm_worker = shared_state.llm_worker.clone();
248        Self {
249            shared_state,
250            context_orchestrator,
251            llm_worker,
252        }
253    }
254}
255
256// Re-exports for convenience
257pub use self::SharedSystemState as SharedState;