1use std::collections::HashMap;
8use std::sync::Arc;
9
10use embacle::config::CliRunnerType;
11use embacle::types::{LlmProvider, RunnerError};
12use tokio::sync::{Mutex, RwLock};
13
14use crate::runner::factory;
15
/// Reference-counted handle to [`ServerState`] behind a `tokio` [`RwLock`]:
/// any number of readers, or a single writer, may hold it at a time.
pub type SharedState = Arc<RwLock<ServerState>>;
18
/// Mutable server configuration (active provider, model override, multiplex
/// provider list) plus a cache of runners created so far.
///
/// The selection fields change only through `&mut self` setters, so updating
/// them requires exclusive access (for example a write lock on
/// [`SharedState`]). The runner cache sits behind its own async [`Mutex`],
/// which lets [`ServerState::get_runner`] populate it through `&self`.
pub struct ServerState {
    /// Currently selected provider.
    active_provider: CliRunnerType,
    /// Model override for `active_provider`; `None` until one is set, and
    /// reset to `None` whenever the provider is changed.
    active_model: Option<String>,
    /// Providers selected for multiplexing; empty until explicitly set.
    /// NOTE(review): how this list is consumed is not visible here — confirm.
    multiplex_providers: Vec<CliRunnerType>,
    /// Runners created so far, keyed by provider and reused on later lookups.
    runners: Mutex<HashMap<CliRunnerType, Arc<dyn LlmProvider>>>,
}
29
30impl ServerState {
31 pub fn new(default_provider: CliRunnerType) -> Self {
33 Self {
34 active_provider: default_provider,
35 active_model: None,
36 multiplex_providers: Vec::new(),
37 runners: Mutex::new(HashMap::new()),
38 }
39 }
40
41 pub const fn active_provider(&self) -> CliRunnerType {
43 self.active_provider
44 }
45
46 pub fn set_active_provider(&mut self, provider: CliRunnerType) {
48 self.active_provider = provider;
49 self.active_model = None;
50 }
51
52 pub fn active_model(&self) -> Option<&str> {
54 self.active_model.as_deref()
55 }
56
57 pub fn set_active_model(&mut self, model: Option<String>) {
59 self.active_model = model;
60 }
61
62 pub fn multiplex_providers(&self) -> &[CliRunnerType] {
64 &self.multiplex_providers
65 }
66
67 pub fn set_multiplex_providers(&mut self, providers: Vec<CliRunnerType>) {
69 self.multiplex_providers = providers;
70 }
71
72 pub async fn get_runner(
77 &self,
78 provider: CliRunnerType,
79 ) -> Result<Arc<dyn LlmProvider>, RunnerError> {
80 {
82 let runners = self.runners.lock().await;
83 if let Some(runner) = runners.get(&provider) {
84 return Ok(Arc::clone(runner));
85 }
86 }
87
88 let runner = factory::create_runner(provider).await?;
90 let runner: Arc<dyn LlmProvider> = Arc::from(runner);
91
92 let runner = self
93 .runners
94 .lock()
95 .await
96 .entry(provider)
97 .or_insert_with(|| runner)
98 .clone();
99 Ok(runner)
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly built state reflects the provider it was given and nothing else.
    #[test]
    fn default_state_uses_provided_provider() {
        let fresh = ServerState::new(CliRunnerType::Copilot);

        assert_eq!(fresh.active_provider(), CliRunnerType::Copilot);
        assert_eq!(fresh.active_model(), None);
        assert!(
            fresh.multiplex_providers().is_empty(),
            "new state should not have multiplex providers"
        );
    }

    // Changing provider must discard the model picked for the previous one.
    #[test]
    fn set_provider_resets_model() {
        let mut state = ServerState::new(CliRunnerType::Copilot);

        state.set_active_model(Some(String::from("gpt-4o")));
        assert_eq!(state.active_model(), Some("gpt-4o"));

        state.set_active_provider(CliRunnerType::ClaudeCode);
        assert_eq!(state.active_provider(), CliRunnerType::ClaudeCode);
        assert_eq!(state.active_model(), None);
    }

    // Whatever list is stored is handed back unchanged.
    #[test]
    fn multiplex_providers_round_trip() {
        let expected = [CliRunnerType::ClaudeCode, CliRunnerType::OpenCode];
        let mut state = ServerState::new(CliRunnerType::Copilot);

        state.set_multiplex_providers(expected.to_vec());

        assert_eq!(state.multiplex_providers(), expected.as_slice());
    }
}