use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::Arc;

use embacle::config::CliRunnerType;
use embacle::types::{LlmProvider, RunnerError};
use tokio::sync::Mutex;

use crate::runner::factory;
15
16pub type SharedState = Arc<ServerState>;
18
19pub struct ServerState {
25 default_provider: CliRunnerType,
27 runners: Mutex<HashMap<CliRunnerType, Arc<dyn LlmProvider>>>,
29}
30
31impl ServerState {
32 pub fn new(default_provider: CliRunnerType) -> Self {
34 Self {
35 default_provider,
36 runners: Mutex::new(HashMap::new()),
37 }
38 }
39
40 pub const fn default_provider(&self) -> CliRunnerType {
42 self.default_provider
43 }
44
45 pub async fn get_runner(
49 &self,
50 provider: CliRunnerType,
51 ) -> Result<Arc<dyn LlmProvider>, RunnerError> {
52 let mut runners = self.runners.lock().await;
53
54 if let Some(runner) = runners.get(&provider) {
55 return Ok(Arc::clone(runner));
56 }
57
58 let runner = factory::create_runner(provider)?;
59 let runner: Arc<dyn LlmProvider> = Arc::from(runner);
60 runners.insert(provider, Arc::clone(&runner));
61 drop(runners);
62 Ok(runner)
63 }
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructs a state with `provider` and checks the getter hands it back.
    fn assert_round_trip(provider: CliRunnerType) {
        let built = ServerState::new(provider);
        assert_eq!(built.default_provider(), provider);
    }

    #[test]
    fn default_provider_is_stored() {
        assert_round_trip(CliRunnerType::Copilot);
    }

    #[test]
    fn default_provider_claude() {
        assert_round_trip(CliRunnerType::ClaudeCode);
    }
}