Skip to main content

embacle_server/
state.rs

1// ABOUTME: Server state holding default provider and lazily-created runner cache
2// ABOUTME: Stateless per-request routing with no mutable active provider or model
3//
4// SPDX-License-Identifier: Apache-2.0
5// Copyright (c) 2026 dravr.ai
6
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::Arc;

use embacle::config::CliRunnerType;
use embacle::types::{LlmProvider, RunnerError};
use tokio::sync::Mutex;

use crate::runner::factory;
15
16/// Shared server state handle
17pub type SharedState = Arc<ServerState>;
18
19/// Server state with immutable default provider and lazy runner cache
20///
21/// Unlike the MCP server, there is no mutable active provider or model.
22/// All provider routing is determined per-request via the model string.
23/// The runner cache avoids recreating providers on every request.
24pub struct ServerState {
25    /// Default provider used when model string has no prefix
26    default_provider: CliRunnerType,
27    /// Lazily-created runners keyed by provider type
28    runners: Mutex<HashMap<CliRunnerType, Arc<dyn LlmProvider>>>,
29}
30
31impl ServerState {
32    /// Create server state with the given default provider
33    pub fn new(default_provider: CliRunnerType) -> Self {
34        Self {
35            default_provider,
36            runners: Mutex::new(HashMap::new()),
37        }
38    }
39
40    /// Get the server's default provider type
41    pub const fn default_provider(&self) -> CliRunnerType {
42        self.default_provider
43    }
44
45    /// Get or lazily create a runner for the given provider type
46    ///
47    /// Created runners are cached for future calls.
48    pub async fn get_runner(
49        &self,
50        provider: CliRunnerType,
51    ) -> Result<Arc<dyn LlmProvider>, RunnerError> {
52        let mut runners = self.runners.lock().await;
53
54        if let Some(runner) = runners.get(&provider) {
55            return Ok(Arc::clone(runner));
56        }
57
58        let runner = factory::create_runner(provider)?;
59        let runner: Arc<dyn LlmProvider> = Arc::from(runner);
60        runners.insert(provider, Arc::clone(&runner));
61        drop(runners);
62        Ok(runner)
63    }
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a state around `kind` and check the getter returns it unchanged.
    fn assert_round_trip(kind: CliRunnerType) {
        let stored = ServerState::new(kind).default_provider();
        assert_eq!(stored, kind);
    }

    #[test]
    fn default_provider_is_stored() {
        assert_round_trip(CliRunnerType::Copilot);
    }

    #[test]
    fn default_provider_claude() {
        assert_round_trip(CliRunnerType::ClaudeCode);
    }
}