Skip to main content

embacle_server/
state.rs

1// ABOUTME: Server state holding default provider and lazily-created runner cache
2// ABOUTME: Stateless per-request routing with no mutable active provider or model
3//
4// SPDX-License-Identifier: Apache-2.0
5// Copyright (c) 2026 dravr.ai
6
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use embacle::config::CliRunnerType;
11use embacle::types::{LlmProvider, RunnerError};
12use tokio::sync::Mutex;
13
14use crate::runner::factory;
15
16/// Shared server state handle
17pub type SharedState = Arc<ServerState>;
18
19/// Server state with immutable default provider and lazy runner cache
20///
21/// Unlike the MCP server, there is no mutable active provider or model.
22/// All provider routing is determined per-request via the model string.
23/// The runner cache avoids recreating providers on every request.
24pub struct ServerState {
25    /// Default provider used when model string has no prefix
26    default_provider: CliRunnerType,
27    /// Lazily-created runners keyed by provider type
28    runners: Mutex<HashMap<CliRunnerType, Arc<dyn LlmProvider>>>,
29}
30
31impl ServerState {
32    /// Create server state with the given default provider
33    pub fn new(default_provider: CliRunnerType) -> Self {
34        Self {
35            default_provider,
36            runners: Mutex::new(HashMap::new()),
37        }
38    }
39
40    /// Get the server's default provider type
41    pub const fn default_provider(&self) -> CliRunnerType {
42        self.default_provider
43    }
44
45    /// Get or lazily create a runner for the given provider type
46    ///
47    /// Created runners are cached for future calls. The lock is released
48    /// during runner creation to avoid blocking concurrent requests.
49    pub async fn get_runner(
50        &self,
51        provider: CliRunnerType,
52    ) -> Result<Arc<dyn LlmProvider>, RunnerError> {
53        // Fast path: check cache under lock
54        {
55            let runners = self.runners.lock().await;
56            if let Some(runner) = runners.get(&provider) {
57                return Ok(Arc::clone(runner));
58            }
59        }
60
61        // Slow path: create runner without holding the lock
62        let runner = factory::create_runner(provider).await?;
63        let runner: Arc<dyn LlmProvider> = Arc::from(runner);
64
65        let mut runners = self.runners.lock().await;
66        // Another request may have created the runner while we were waiting
67        let runner = runners.entry(provider).or_insert_with(|| runner).clone();
68        Ok(runner)
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    /// Construct state with `provider` and check the getter returns it unchanged.
    fn assert_default_round_trips(provider: CliRunnerType) {
        let state = ServerState::new(provider);
        let stored = state.default_provider();
        assert_eq!(stored, provider);
    }

    #[test]
    fn default_provider_is_stored() {
        assert_default_round_trips(CliRunnerType::Copilot);
    }

    #[test]
    fn default_provider_claude() {
        assert_default_round_trips(CliRunnerType::ClaudeCode);
    }
}