Skip to main content

embacle_server/
state.rs

1// ABOUTME: Server state holding default provider and lazily-created runner cache
2// ABOUTME: Stateless per-request routing with no mutable active provider or model
3//
4// SPDX-License-Identifier: Apache-2.0
5// Copyright (c) 2026 dravr.ai
6
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use embacle::config::CliRunnerType;
11use embacle::types::{LlmProvider, RunnerError};
12use tokio::sync::Mutex;
13
14use crate::runner::factory;
15
16/// Shared server state handle
17pub type SharedState = Arc<ServerState>;
18
19/// Server state with immutable default provider and lazy runner cache
20///
21/// Unlike the MCP server, there is no mutable active provider or model.
22/// All provider routing is determined per-request via the model string.
23/// The runner cache avoids recreating providers on every request.
24pub struct ServerState {
25    /// Default provider used when model string has no prefix
26    default_provider: CliRunnerType,
27    /// Lazily-created runners keyed by provider type
28    runners: Mutex<HashMap<CliRunnerType, Arc<dyn LlmProvider>>>,
29}
30
31impl ServerState {
32    /// Create server state with the given default provider
33    pub fn new(default_provider: CliRunnerType) -> Self {
34        Self {
35            default_provider,
36            runners: Mutex::new(HashMap::new()),
37        }
38    }
39
40    /// Get the server's default provider type
41    pub const fn default_provider(&self) -> CliRunnerType {
42        self.default_provider
43    }
44
45    /// Get or lazily create a runner for the given provider type
46    ///
47    /// Created runners are cached for future calls. The lock is released
48    /// during runner creation to avoid blocking concurrent requests.
49    pub async fn get_runner(
50        &self,
51        provider: CliRunnerType,
52    ) -> Result<Arc<dyn LlmProvider>, RunnerError> {
53        // Fast path: check cache under lock
54        {
55            let runners = self.runners.lock().await;
56            if let Some(runner) = runners.get(&provider) {
57                return Ok(Arc::clone(runner));
58            }
59        }
60
61        // Slow path: create runner without holding the lock
62        let runner = factory::create_runner(provider).await?;
63        let runner: Arc<dyn LlmProvider> = Arc::from(runner);
64
65        let runner = self
66            .runners
67            .lock()
68            .await
69            .entry(provider)
70            .or_insert_with(|| runner)
71            .clone();
72        Ok(runner)
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds state with `expected` as the default and checks it reads back unchanged.
    fn assert_default_roundtrip(expected: CliRunnerType) {
        let built = ServerState::new(expected);
        assert_eq!(built.default_provider(), expected);
    }

    #[test]
    fn default_provider_is_stored() {
        assert_default_roundtrip(CliRunnerType::Copilot);
    }

    #[test]
    fn default_provider_claude() {
        assert_default_roundtrip(CliRunnerType::ClaudeCode);
    }
}