1#![allow(dead_code)]
6
7use serde::{Deserialize, Serialize};
8use std::path::PathBuf;
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
12#[serde(rename_all = "kebab-case")]
13pub enum ProviderType {
14 Copilot,
19 Cursor,
21
22 Ollama,
27 Vllm,
29 Foundry,
31 LmStudio,
33 #[serde(rename = "localai")]
35 LocalAI,
36 #[serde(rename = "text-gen-webui")]
38 TextGenWebUI,
39 Jan,
41 #[serde(rename = "gpt4all")]
43 Gpt4All,
44 Llamafile,
46
47 #[serde(rename = "m365copilot")]
52 M365Copilot,
53 #[serde(rename = "chatgpt")]
55 ChatGPT,
56 #[serde(rename = "openai")]
58 OpenAI,
59 #[serde(rename = "anthropic")]
61 Anthropic,
62 #[serde(rename = "perplexity")]
64 Perplexity,
65 #[serde(rename = "deepseek")]
67 DeepSeek,
68 #[serde(rename = "qwen")]
70 Qwen,
71 #[serde(rename = "gemini")]
73 Gemini,
74 #[serde(rename = "mistral")]
76 Mistral,
77 #[serde(rename = "cohere")]
79 Cohere,
80 #[serde(rename = "grok")]
82 Grok,
83 #[serde(rename = "groq")]
85 Groq,
86 #[serde(rename = "together")]
88 Together,
89 #[serde(rename = "fireworks")]
91 Fireworks,
92 #[serde(rename = "replicate")]
94 Replicate,
95 #[serde(rename = "huggingface")]
97 HuggingFace,
98
99 Custom,
101}
102
103impl ProviderType {
104 pub fn display_name(&self) -> &'static str {
106 match self {
107 Self::Copilot => "GitHub Copilot",
109 Self::Cursor => "Cursor",
110 Self::Ollama => "Ollama",
112 Self::Vllm => "vLLM",
113 Self::Foundry => "Azure AI Foundry",
114 Self::LmStudio => "LM Studio",
115 Self::LocalAI => "LocalAI",
116 Self::TextGenWebUI => "Text Generation WebUI",
117 Self::Jan => "Jan.ai",
118 Self::Gpt4All => "GPT4All",
119 Self::Llamafile => "Llamafile",
120 Self::M365Copilot => "Microsoft 365 Copilot",
122 Self::ChatGPT => "ChatGPT",
123 Self::OpenAI => "OpenAI API",
124 Self::Anthropic => "Anthropic Claude",
125 Self::Perplexity => "Perplexity AI",
126 Self::DeepSeek => "DeepSeek",
127 Self::Qwen => "Qwen (Alibaba)",
128 Self::Gemini => "Google Gemini",
129 Self::Mistral => "Mistral AI",
130 Self::Cohere => "Cohere",
131 Self::Grok => "xAI Grok",
132 Self::Groq => "Groq",
133 Self::Together => "Together AI",
134 Self::Fireworks => "Fireworks AI",
135 Self::Replicate => "Replicate",
136 Self::HuggingFace => "HuggingFace",
137 Self::Custom => "Custom",
138 }
139 }
140
141 pub fn default_endpoint(&self) -> Option<&'static str> {
143 match self {
144 Self::Copilot => None,
146 Self::Cursor => None,
147 Self::Ollama => Some("http://localhost:11434"),
149 Self::Vllm => Some("http://localhost:8000"),
150 Self::Foundry => Some("http://localhost:5272"),
151 Self::LmStudio => Some("http://localhost:1234/v1"),
152 Self::LocalAI => Some("http://localhost:8080/v1"),
153 Self::TextGenWebUI => Some("http://localhost:5000/v1"),
154 Self::Jan => Some("http://localhost:1337/v1"),
155 Self::Gpt4All => Some("http://localhost:4891/v1"),
156 Self::Llamafile => Some("http://localhost:8080/v1"),
157 Self::M365Copilot => Some("https://graph.microsoft.com/v1.0"),
159 Self::ChatGPT => Some("https://chat.openai.com"),
160 Self::OpenAI => Some("https://api.openai.com/v1"),
161 Self::Anthropic => Some("https://api.anthropic.com/v1"),
162 Self::Perplexity => Some("https://api.perplexity.ai"),
163 Self::DeepSeek => Some("https://api.deepseek.com/v1"),
164 Self::Qwen => Some("https://dashscope.aliyuncs.com/api/v1"),
165 Self::Gemini => Some("https://generativelanguage.googleapis.com/v1beta"),
166 Self::Mistral => Some("https://api.mistral.ai/v1"),
167 Self::Cohere => Some("https://api.cohere.ai/v1"),
168 Self::Grok => Some("https://api.x.ai/v1"),
169 Self::Groq => Some("https://api.groq.com/openai/v1"),
170 Self::Together => Some("https://api.together.xyz/v1"),
171 Self::Fireworks => Some("https://api.fireworks.ai/inference/v1"),
172 Self::Replicate => Some("https://api.replicate.com/v1"),
173 Self::HuggingFace => Some("https://api-inference.huggingface.co"),
174 Self::Custom => None,
175 }
176 }
177
178 pub fn uses_file_storage(&self) -> bool {
180 matches!(self, Self::Copilot | Self::Cursor)
181 }
182
183 pub fn is_cloud_provider(&self) -> bool {
185 matches!(
186 self,
187 Self::M365Copilot
188 | Self::ChatGPT
189 | Self::OpenAI
190 | Self::Anthropic
191 | Self::Perplexity
192 | Self::DeepSeek
193 | Self::Qwen
194 | Self::Gemini
195 | Self::Mistral
196 | Self::Cohere
197 | Self::Grok
198 | Self::Groq
199 | Self::Together
200 | Self::Fireworks
201 | Self::Replicate
202 | Self::HuggingFace
203 )
204 }
205
206 pub fn is_openai_compatible(&self) -> bool {
208 matches!(
209 self,
210 Self::Ollama
211 | Self::Vllm
212 | Self::Foundry
213 | Self::OpenAI
214 | Self::LmStudio
215 | Self::LocalAI
216 | Self::TextGenWebUI
217 | Self::Jan
218 | Self::Gpt4All
219 | Self::Llamafile
220 | Self::DeepSeek | Self::Groq | Self::Together | Self::Fireworks | Self::Custom
225 )
226 }
227
228 pub fn requires_api_key(&self) -> bool {
230 self.is_cloud_provider()
231 }
232}
233
234impl std::fmt::Display for ProviderType {
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 write!(f, "{}", self.display_name())
237 }
238}
239
240#[derive(Debug, Clone, Serialize, Deserialize)]
242pub struct ProviderConfig {
243 pub provider_type: ProviderType,
245
246 #[serde(default = "default_true")]
248 pub enabled: bool,
249
250 pub endpoint: Option<String>,
252
253 pub api_key: Option<String>,
255
256 pub model: Option<String>,
258
259 pub name: Option<String>,
261
262 pub storage_path: Option<PathBuf>,
264
265 #[serde(default)]
267 pub extra: std::collections::HashMap<String, serde_json::Value>,
268}
269
/// Serde default for boolean fields that should be on unless the config file
/// explicitly disables them (used by `enabled` and `auto_discover`).
fn default_true() -> bool {
    true
}
273
274impl ProviderConfig {
275 pub fn new(provider_type: ProviderType) -> Self {
277 Self {
278 provider_type,
279 enabled: true,
280 endpoint: provider_type.default_endpoint().map(String::from),
281 api_key: None,
282 model: None,
283 name: None,
284 storage_path: None,
285 extra: std::collections::HashMap::new(),
286 }
287 }
288
289 pub fn display_name(&self) -> String {
291 self.name
292 .clone()
293 .unwrap_or_else(|| self.provider_type.display_name().to_string())
294 }
295}
296
297#[derive(Debug, Clone, Serialize, Deserialize)]
299pub struct CsmConfig {
300 #[serde(default)]
302 pub providers: Vec<ProviderConfig>,
303
304 pub default_provider: Option<ProviderType>,
306
307 #[serde(default = "default_true")]
309 pub auto_discover: bool,
310}
311
312impl Default for CsmConfig {
313 fn default() -> Self {
314 Self {
315 providers: Vec::new(),
316 default_provider: None,
317 auto_discover: true, }
319 }
320}
321
322impl CsmConfig {
323 pub fn load() -> anyhow::Result<Self> {
325 let config_path = Self::config_path()?;
326
327 if config_path.exists() {
328 let content = std::fs::read_to_string(&config_path)?;
329 let config: Self = serde_json::from_str(&content)?;
330 Ok(config)
331 } else {
332 Ok(Self::default())
333 }
334 }
335
336 pub fn save(&self) -> anyhow::Result<()> {
338 let config_path = Self::config_path()?;
339
340 if let Some(parent) = config_path.parent() {
341 std::fs::create_dir_all(parent)?;
342 }
343
344 let content = serde_json::to_string_pretty(self)?;
345 std::fs::write(&config_path, content)?;
346 Ok(())
347 }
348
349 pub fn config_path() -> anyhow::Result<PathBuf> {
351 let config_dir =
352 dirs::config_dir().ok_or_else(|| anyhow::anyhow!("Could not find config directory"))?;
353 Ok(config_dir.join("csm").join("config.json"))
354 }
355
356 pub fn get_provider(&self, provider_type: ProviderType) -> Option<&ProviderConfig> {
358 self.providers
359 .iter()
360 .find(|p| p.provider_type == provider_type)
361 }
362
363 pub fn set_provider(&mut self, config: ProviderConfig) {
365 if let Some(existing) = self
366 .providers
367 .iter_mut()
368 .find(|p| p.provider_type == config.provider_type)
369 {
370 *existing = config;
371 } else {
372 self.providers.push(config);
373 }
374 }
375}