1use std::collections::HashMap;
8use std::sync::Arc;
9
10use embacle::config::CliRunnerType;
11use embacle::types::{LlmProvider, RunnerError};
12use tokio::sync::Mutex;
13
14use crate::runner::factory;
15
16pub type SharedState = Arc<ServerState>;
18
19pub struct ServerState {
25 default_provider: CliRunnerType,
27 runners: Mutex<HashMap<CliRunnerType, Arc<dyn LlmProvider>>>,
29}
30
31impl ServerState {
32 pub fn new(default_provider: CliRunnerType) -> Self {
34 Self {
35 default_provider,
36 runners: Mutex::new(HashMap::new()),
37 }
38 }
39
40 pub const fn default_provider(&self) -> CliRunnerType {
42 self.default_provider
43 }
44
45 pub async fn get_runner(
50 &self,
51 provider: CliRunnerType,
52 ) -> Result<Arc<dyn LlmProvider>, RunnerError> {
53 {
55 let runners = self.runners.lock().await;
56 if let Some(runner) = runners.get(&provider) {
57 return Ok(Arc::clone(runner));
58 }
59 }
60
61 let runner = factory::create_runner(provider).await?;
63 let runner: Arc<dyn LlmProvider> = Arc::from(runner);
64
65 let runner = self
66 .runners
67 .lock()
68 .await
69 .entry(provider)
70 .or_insert_with(|| runner)
71 .clone();
72 Ok(runner)
73 }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a state whose default is `provider` and checks the accessor echoes it back.
    #[track_caller]
    fn assert_default_round_trips(provider: CliRunnerType) {
        let state = ServerState::new(provider);
        assert_eq!(state.default_provider(), provider);
    }

    #[test]
    fn default_provider_is_stored() {
        assert_default_round_trips(CliRunnerType::Copilot);
    }

    #[test]
    fn default_provider_claude() {
        assert_default_round_trips(CliRunnerType::ClaudeCode);
    }
}