1use std::collections::HashMap;
8use std::sync::Arc;
9
10use embacle::config::CliRunnerType;
11use embacle::types::{LlmProvider, RunnerError};
12use tokio::sync::Mutex;
13
14use crate::runner::factory;
15
/// Reference-counted handle to a [`ServerState`], so one state instance can
/// be shared by several owners.
pub type SharedState = Arc<ServerState>;
18
/// Server-wide state: the configured default provider plus a cache of
/// runner instances keyed by provider type.
pub struct ServerState {
    /// Provider type supplied to [`ServerState::new`] and returned by
    /// [`ServerState::default_provider`].
    default_provider: CliRunnerType,
    /// Runners created so far, one per provider type. Guarded by an async
    /// mutex; entries are added on demand by [`ServerState::get_runner`].
    runners: Mutex<HashMap<CliRunnerType, Arc<dyn LlmProvider>>>,
}
30
31impl ServerState {
32 pub fn new(default_provider: CliRunnerType) -> Self {
34 Self {
35 default_provider,
36 runners: Mutex::new(HashMap::new()),
37 }
38 }
39
40 pub const fn default_provider(&self) -> CliRunnerType {
42 self.default_provider
43 }
44
45 pub async fn get_runner(
50 &self,
51 provider: CliRunnerType,
52 ) -> Result<Arc<dyn LlmProvider>, RunnerError> {
53 {
55 let runners = self.runners.lock().await;
56 if let Some(runner) = runners.get(&provider) {
57 return Ok(Arc::clone(runner));
58 }
59 }
60
61 let runner = factory::create_runner(provider).await?;
63 let runner: Arc<dyn LlmProvider> = Arc::from(runner);
64
65 let mut runners = self.runners.lock().await;
66 let runner = runners.entry(provider).or_insert_with(|| runner).clone();
68 Ok(runner)
69 }
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    /// The provider handed to `new` is what `default_provider` reports (Copilot).
    #[test]
    fn default_provider_is_stored() {
        let expected = CliRunnerType::Copilot;
        let state = ServerState::new(expected);
        assert_eq!(state.default_provider(), expected);
    }

    /// Same round-trip check with the Claude Code provider.
    #[test]
    fn default_provider_claude() {
        let expected = CliRunnerType::ClaudeCode;
        let state = ServerState::new(expected);
        assert_eq!(state.default_provider(), expected);
    }
}