1#![allow(dead_code)]
6
7use serde::{Deserialize, Serialize};
8use std::path::PathBuf;
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
12#[serde(rename_all = "kebab-case")]
13pub enum ProviderType {
14 Copilot,
19 Cursor,
21 #[serde(rename = "continuedev")]
23 ContinueDev,
24
25 Ollama,
30 Vllm,
32 Foundry,
34 LmStudio,
36 #[serde(rename = "localai")]
38 LocalAI,
39 #[serde(rename = "text-gen-webui")]
41 TextGenWebUI,
42 Jan,
44 #[serde(rename = "gpt4all")]
46 Gpt4All,
47 Llamafile,
49
50 #[serde(rename = "m365copilot")]
55 M365Copilot,
56 #[serde(rename = "chatgpt")]
58 ChatGPT,
59 #[serde(rename = "openai")]
61 OpenAI,
62 #[serde(rename = "anthropic")]
64 Anthropic,
65 #[serde(rename = "perplexity")]
67 Perplexity,
68 #[serde(rename = "deepseek")]
70 DeepSeek,
71 #[serde(rename = "qwen")]
73 Qwen,
74 #[serde(rename = "gemini")]
76 Gemini,
77 #[serde(rename = "mistral")]
79 Mistral,
80 #[serde(rename = "cohere")]
82 Cohere,
83 #[serde(rename = "grok")]
85 Grok,
86 #[serde(rename = "groq")]
88 Groq,
89 #[serde(rename = "together")]
91 Together,
92 #[serde(rename = "fireworks")]
94 Fireworks,
95 #[serde(rename = "replicate")]
97 Replicate,
98 #[serde(rename = "huggingface")]
100 HuggingFace,
101
102 Custom,
104}
105
106impl ProviderType {
107 pub fn display_name(&self) -> &'static str {
109 match self {
110 Self::Copilot => "GitHub Copilot",
112 Self::Cursor => "Cursor",
113 Self::ContinueDev => "Continue.dev",
114 Self::Ollama => "Ollama",
116 Self::Vllm => "vLLM",
117 Self::Foundry => "Azure AI Foundry",
118 Self::LmStudio => "LM Studio",
119 Self::LocalAI => "LocalAI",
120 Self::TextGenWebUI => "Text Generation WebUI",
121 Self::Jan => "Jan.ai",
122 Self::Gpt4All => "GPT4All",
123 Self::Llamafile => "Llamafile",
124 Self::M365Copilot => "Microsoft 365 Copilot",
126 Self::ChatGPT => "ChatGPT",
127 Self::OpenAI => "OpenAI API",
128 Self::Anthropic => "Anthropic Claude",
129 Self::Perplexity => "Perplexity AI",
130 Self::DeepSeek => "DeepSeek",
131 Self::Qwen => "Qwen (Alibaba)",
132 Self::Gemini => "Google Gemini",
133 Self::Mistral => "Mistral AI",
134 Self::Cohere => "Cohere",
135 Self::Grok => "xAI Grok",
136 Self::Groq => "Groq",
137 Self::Together => "Together AI",
138 Self::Fireworks => "Fireworks AI",
139 Self::Replicate => "Replicate",
140 Self::HuggingFace => "HuggingFace",
141 Self::Custom => "Custom",
142 }
143 }
144
145 pub fn default_endpoint(&self) -> Option<&'static str> {
147 match self {
148 Self::Copilot => None,
150 Self::Cursor => None,
151 Self::ContinueDev => None,
152 Self::Ollama => Some("http://localhost:11434"),
154 Self::Vllm => Some("http://localhost:8000"),
155 Self::Foundry => Some("http://localhost:5272"),
156 Self::LmStudio => Some("http://localhost:1234/v1"),
157 Self::LocalAI => Some("http://localhost:8080/v1"),
158 Self::TextGenWebUI => Some("http://localhost:5000/v1"),
159 Self::Jan => Some("http://localhost:1337/v1"),
160 Self::Gpt4All => Some("http://localhost:4891/v1"),
161 Self::Llamafile => Some("http://localhost:8080/v1"),
162 Self::M365Copilot => Some("https://graph.microsoft.com/v1.0"),
164 Self::ChatGPT => Some("https://chat.openai.com"),
165 Self::OpenAI => Some("https://api.openai.com/v1"),
166 Self::Anthropic => Some("https://api.anthropic.com/v1"),
167 Self::Perplexity => Some("https://api.perplexity.ai"),
168 Self::DeepSeek => Some("https://api.deepseek.com/v1"),
169 Self::Qwen => Some("https://dashscope.aliyuncs.com/api/v1"),
170 Self::Gemini => Some("https://generativelanguage.googleapis.com/v1beta"),
171 Self::Mistral => Some("https://api.mistral.ai/v1"),
172 Self::Cohere => Some("https://api.cohere.ai/v1"),
173 Self::Grok => Some("https://api.x.ai/v1"),
174 Self::Groq => Some("https://api.groq.com/openai/v1"),
175 Self::Together => Some("https://api.together.xyz/v1"),
176 Self::Fireworks => Some("https://api.fireworks.ai/inference/v1"),
177 Self::Replicate => Some("https://api.replicate.com/v1"),
178 Self::HuggingFace => Some("https://api-inference.huggingface.co"),
179 Self::Custom => None,
180 }
181 }
182
183 pub fn uses_file_storage(&self) -> bool {
185 matches!(self, Self::Copilot | Self::Cursor | Self::ContinueDev)
186 }
187
188 pub fn is_cloud_provider(&self) -> bool {
190 matches!(
191 self,
192 Self::M365Copilot
193 | Self::ChatGPT
194 | Self::OpenAI
195 | Self::Anthropic
196 | Self::Perplexity
197 | Self::DeepSeek
198 | Self::Qwen
199 | Self::Gemini
200 | Self::Mistral
201 | Self::Cohere
202 | Self::Grok
203 | Self::Groq
204 | Self::Together
205 | Self::Fireworks
206 | Self::Replicate
207 | Self::HuggingFace
208 )
209 }
210
211 pub fn is_openai_compatible(&self) -> bool {
213 matches!(
214 self,
215 Self::Ollama
216 | Self::Vllm
217 | Self::Foundry
218 | Self::OpenAI
219 | Self::LmStudio
220 | Self::LocalAI
221 | Self::TextGenWebUI
222 | Self::Jan
223 | Self::Gpt4All
224 | Self::Llamafile
225 | Self::DeepSeek | Self::Groq | Self::Together | Self::Fireworks | Self::Custom
230 )
231 }
232
233 pub fn requires_api_key(&self) -> bool {
235 self.is_cloud_provider()
236 }
237}
238
239impl std::fmt::Display for ProviderType {
240 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241 write!(f, "{}", self.display_name())
242 }
243}
244
245#[derive(Debug, Clone, Serialize, Deserialize)]
247pub struct ProviderConfig {
248 pub provider_type: ProviderType,
250
251 #[serde(default = "default_true")]
253 pub enabled: bool,
254
255 pub endpoint: Option<String>,
257
258 pub api_key: Option<String>,
260
261 pub model: Option<String>,
263
264 pub name: Option<String>,
266
267 pub storage_path: Option<PathBuf>,
269
270 #[serde(default)]
272 pub extra: std::collections::HashMap<String, serde_json::Value>,
273}
274
/// Serde default for boolean fields that should be on unless stated otherwise.
///
/// Referenced by name from `#[serde(default = "default_true")]` attributes,
/// so it must keep this exact name and zero-argument signature.
const fn default_true() -> bool {
    true
}
278
279impl ProviderConfig {
280 pub fn new(provider_type: ProviderType) -> Self {
282 Self {
283 provider_type,
284 enabled: true,
285 endpoint: provider_type.default_endpoint().map(String::from),
286 api_key: None,
287 model: None,
288 name: None,
289 storage_path: None,
290 extra: std::collections::HashMap::new(),
291 }
292 }
293
294 pub fn display_name(&self) -> String {
296 self.name
297 .clone()
298 .unwrap_or_else(|| self.provider_type.display_name().to_string())
299 }
300}
301
/// Top-level configuration, persisted as pretty-printed JSON by
/// `CsmConfig::save` and read back by `CsmConfig::load`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CsmConfig {
    /// Configured providers; empty when the key is absent from the file.
    #[serde(default)]
    pub providers: Vec<ProviderConfig>,

    /// Preferred provider, if any. NOTE(review): presumably used when the
    /// caller does not pick one explicitly — confirm with callers.
    pub default_provider: Option<ProviderType>,

    /// `true` when the key is absent from the file. NOTE(review): presumably
    /// toggles automatic provider discovery — confirm with callers.
    #[serde(default = "default_true")]
    pub auto_discover: bool,
}
316
317impl Default for CsmConfig {
318 fn default() -> Self {
319 Self {
320 providers: Vec::new(),
321 default_provider: None,
322 auto_discover: true, }
324 }
325}
326
327impl CsmConfig {
328 pub fn load() -> anyhow::Result<Self> {
330 let config_path = Self::config_path()?;
331
332 if config_path.exists() {
333 let content = std::fs::read_to_string(&config_path)?;
334 let config: Self = serde_json::from_str(&content)?;
335 Ok(config)
336 } else {
337 Ok(Self::default())
338 }
339 }
340
341 pub fn save(&self) -> anyhow::Result<()> {
343 let config_path = Self::config_path()?;
344
345 if let Some(parent) = config_path.parent() {
346 std::fs::create_dir_all(parent)?;
347 }
348
349 let content = serde_json::to_string_pretty(self)?;
350 std::fs::write(&config_path, content)?;
351 Ok(())
352 }
353
354 pub fn config_path() -> anyhow::Result<PathBuf> {
356 let config_dir =
357 dirs::config_dir().ok_or_else(|| anyhow::anyhow!("Could not find config directory"))?;
358 Ok(config_dir.join("csm").join("config.json"))
359 }
360
361 pub fn get_provider(&self, provider_type: ProviderType) -> Option<&ProviderConfig> {
363 self.providers
364 .iter()
365 .find(|p| p.provider_type == provider_type)
366 }
367
368 pub fn set_provider(&mut self, config: ProviderConfig) {
370 if let Some(existing) = self
371 .providers
372 .iter_mut()
373 .find(|p| p.provider_type == config.provider_type)
374 {
375 *existing = config;
376 } else {
377 self.providers.push(config);
378 }
379 }
380}