Skip to main content

embacle_mcp/
state.rs

1// ABOUTME: Shared server state holding active provider, model, and multiplex configuration
2// ABOUTME: Thread-safe via Arc<RwLock> with lazy runner creation on first use
3//
4// SPDX-License-Identifier: Apache-2.0
5// Copyright (c) 2026 dravr.ai
6
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use embacle::config::CliRunnerType;
11use embacle::types::{LlmProvider, RunnerError};
12use tokio::sync::{Mutex, RwLock};
13
14use crate::runner::factory;
15
/// Type alias for the shared state handle used across the server
///
/// Wraps [`ServerState`] in an `Arc` for shared ownership and a Tokio `RwLock`,
/// so holders acquire a read or write guard before touching the state.
pub type SharedState = Arc<RwLock<ServerState>>;
18
/// Central server state tracking provider configuration and cached runners
///
/// Runners are created lazily on first access and cached for reuse.
/// The active provider and model determine how prompt dispatch behaves.
pub struct ServerState {
    /// Provider currently selected; initialised from the constructor argument
    active_provider: CliRunnerType,
    /// Explicitly selected model; `None` means use the provider default
    active_model: Option<String>,
    /// Providers configured for multiplex dispatch (empty on construction)
    multiplex_providers: Vec<CliRunnerType>,
    /// Runner cache keyed by provider type; the async mutex lets it be
    /// filled through `&self`
    runners: Mutex<HashMap<CliRunnerType, Arc<dyn LlmProvider>>>,
}
29
30impl ServerState {
31    /// Create server state with the given default provider
32    pub fn new(default_provider: CliRunnerType) -> Self {
33        Self {
34            active_provider: default_provider,
35            active_model: None,
36            multiplex_providers: Vec::new(),
37            runners: Mutex::new(HashMap::new()),
38        }
39    }
40
41    /// Get the currently active provider type
42    pub const fn active_provider(&self) -> CliRunnerType {
43        self.active_provider
44    }
45
46    /// Switch the active provider (resets the active model)
47    pub fn set_active_provider(&mut self, provider: CliRunnerType) {
48        self.active_provider = provider;
49        self.active_model = None;
50    }
51
52    /// Get the currently selected model (None means use provider default)
53    pub fn active_model(&self) -> Option<&str> {
54        self.active_model.as_deref()
55    }
56
57    /// Set the model to use for subsequent requests
58    pub fn set_active_model(&mut self, model: Option<String>) {
59        self.active_model = model;
60    }
61
62    /// Get the list of providers configured for multiplex dispatch
63    pub fn multiplex_providers(&self) -> &[CliRunnerType] {
64        &self.multiplex_providers
65    }
66
67    /// Set the providers used when multiplexing prompts
68    pub fn set_multiplex_providers(&mut self, providers: Vec<CliRunnerType>) {
69        self.multiplex_providers = providers;
70    }
71
72    /// Get or lazily create a runner for the given provider type
73    ///
74    /// Created runners are cached for future calls. The runner cache uses
75    /// interior mutability so callers only need `&self`.
76    pub async fn get_runner(
77        &self,
78        provider: CliRunnerType,
79    ) -> Result<Arc<dyn LlmProvider>, RunnerError> {
80        // Fast path: check cache under lock
81        {
82            let runners = self.runners.lock().await;
83            if let Some(runner) = runners.get(&provider) {
84                return Ok(Arc::clone(runner));
85            }
86        }
87
88        // Slow path: create runner without holding the lock
89        let runner = factory::create_runner(provider).await?;
90        let runner: Arc<dyn LlmProvider> = Arc::from(runner);
91
92        let runner = self
93            .runners
94            .lock()
95            .await
96            .entry(provider)
97            .or_insert_with(|| runner)
98            .clone();
99        Ok(runner)
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// A new state reports the constructor's provider and nothing else set.
    #[test]
    fn default_state_uses_provided_provider() {
        let fresh = ServerState::new(CliRunnerType::Copilot);

        assert_eq!(fresh.active_provider(), CliRunnerType::Copilot);
        assert_eq!(fresh.active_model(), None);
        assert_eq!(fresh.multiplex_providers().len(), 0);
    }

    /// Switching provider must discard a previously selected model.
    #[test]
    fn set_provider_resets_model() {
        let mut server = ServerState::new(CliRunnerType::Copilot);

        server.set_active_model(Some(String::from("gpt-4o")));
        assert_eq!(server.active_model(), Some("gpt-4o"));

        server.set_active_provider(CliRunnerType::ClaudeCode);
        assert_eq!(server.active_provider(), CliRunnerType::ClaudeCode);
        assert_eq!(server.active_model(), None);
    }

    /// The multiplex list handed to the setter is returned by the getter.
    #[test]
    fn multiplex_providers_round_trip() {
        let mut server = ServerState::new(CliRunnerType::Copilot);
        let expected = vec![CliRunnerType::ClaudeCode, CliRunnerType::OpenCode];

        server.set_multiplex_providers(expected.clone());

        assert_eq!(server.multiplex_providers(), expected.as_slice());
    }
}