1use reqwest::{
16 header::{HeaderMap, HeaderValue, ACCEPT, AUTHORIZATION, USER_AGENT},
17 Client as HttpClient,
18};
19use serde::{Deserialize, Serialize};
20use serde_json::Value;
21use std::{env, error::Error, fmt, fs, path::Path};
22
/// Errors produced by the Copilot client in this module.
///
/// Every variant carries a human-readable message; the underlying error
/// value is flattened to a string at the point where it is mapped.
#[derive(Debug)]
pub enum CopilotError {
    /// The requested model id is not in the client's model list.
    InvalidModel(String),
    /// A GitHub token could not be obtained.
    TokenError(String),
    /// Sending a request failed or the server replied with an error status.
    HttpError(String),
    /// Any other failure (for example header construction or body decoding).
    Other(String),
}

impl fmt::Display for CopilotError {
    /// Renders the variant's message, prefixed with a short category label
    /// for every variant except [`CopilotError::Other`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Other(msg) => f.write_str(msg),
            Self::HttpError(msg) => write!(f, "HTTP error: {msg}"),
            Self::TokenError(msg) => write!(f, "Token error: {msg}"),
            Self::InvalidModel(model) => write!(f, "Invalid model specified: {model}"),
        }
    }
}

// No `source()` override: the variants only hold already-formatted strings.
impl Error for CopilotError {}
50
51#[derive(Debug, Serialize, Deserialize)]
55pub struct CopilotTokenResponse {
56 pub token: String,
58 pub expires_at: u64,
60}
61
62#[derive(Debug, Serialize, Deserialize)]
64pub struct Agent {
65 pub id: String,
67 pub name: String,
69 pub description: Option<String>,
71}
72
73#[derive(Debug, Serialize, Deserialize)]
75pub struct AgentsResponse {
76 pub agents: Vec<Agent>,
78}
79
80#[derive(Debug, Serialize, Deserialize)]
82pub struct Model {
83 pub id: String,
85 pub name: String,
87 pub version: Option<String>,
89 pub tokenizer: Option<String>,
91 pub max_input_tokens: Option<u32>,
93 pub max_output_tokens: Option<u32>,
95}
96
97#[derive(Debug, Serialize, Deserialize)]
99pub struct ModelsResponse {
100 pub data: Vec<Model>,
102}
103
104#[derive(Debug, Serialize, Deserialize)]
108pub struct Message {
109 pub role: String,
111 pub content: String,
113}
114
115#[derive(Debug, Serialize, Deserialize)]
117pub struct ChatRequest {
118 pub model: String,
120 pub messages: Vec<Message>,
122 pub n: u32,
124 pub top_p: f64,
126 pub stream: bool,
128 pub temperature: f64,
130 #[serde(skip_serializing_if = "Option::is_none")]
132 pub max_tokens: Option<u32>,
133}
134
135#[derive(Debug, Serialize, Deserialize)]
137pub struct ChatChoice {
138 pub message: Message,
140 pub finish_reason: Option<String>,
142 pub usage: Option<TokenUsage>,
144}
145
146#[derive(Debug, Serialize, Deserialize)]
148pub struct TokenUsage {
149 pub total_tokens: u32,
151}
152
153#[derive(Debug, Serialize, Deserialize)]
155pub struct ChatResponse {
156 pub choices: Vec<ChatChoice>,
158}
159
160#[derive(Debug, Serialize, Deserialize)]
162pub struct EmbeddingRequest {
163 pub dimensions: u32,
165 pub input: Vec<String>,
167 pub model: String,
169}
170
171#[derive(Debug, Serialize, Deserialize)]
173pub struct Embedding {
174 pub index: usize,
176 pub embedding: Vec<f64>,
178}
179
180#[derive(Debug, Serialize, Deserialize)]
182pub struct EmbeddingResponse {
183 pub data: Vec<Embedding>,
185}
186
187pub struct CopilotClient {
192 http_client: HttpClient,
193 github_token: String,
194 editor_version: String,
195 models: Vec<Model>,
197}
198
199impl CopilotClient {
200 pub async fn from_env_with_models(editor_version: String) -> Result<Self, CopilotError> {
211 let github_token =
212 get_github_token().map_err(|e| CopilotError::TokenError(e.to_string()))?;
213 Self::new_with_models(github_token, editor_version).await
214 }
215
216 pub async fn new_with_models(
228 github_token: String,
229 editor_version: String,
230 ) -> Result<Self, CopilotError> {
231 let http_client = HttpClient::new();
232 let mut client = CopilotClient {
233 http_client,
234 github_token,
235 editor_version,
236 models: Vec::new(),
237 };
238 let models = client.get_models().await?;
240 client.models = models;
241 Ok(client)
242 }
243
244 async fn get_headers(&self) -> Result<HeaderMap, CopilotError> {
249 let token = self.get_copilot_token().await?;
250 let mut headers = HeaderMap::new();
251 headers.insert(
252 AUTHORIZATION,
253 HeaderValue::from_str(&format!("Bearer {token}"))
254 .map_err(|e| CopilotError::Other(e.to_string()))?,
255 );
256 headers.insert(
257 "Editor-Version",
258 HeaderValue::from_str(&self.editor_version)
259 .map_err(|e| CopilotError::Other(e.to_string()))?,
260 );
261 headers.insert(
262 "Editor-Plugin-Version",
263 HeaderValue::from_static("CopilotChat.nvim/*"),
264 );
265 headers.insert(
266 "Copilot-Integration-Id",
267 HeaderValue::from_static("vscode-chat"),
268 );
269 headers.insert(USER_AGENT, HeaderValue::from_static("CopilotChat.nvim"));
270 headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
271 Ok(headers)
272 }
273
274 async fn get_copilot_token(&self) -> Result<String, CopilotError> {
280 let url = "https://api.github.com/copilot_internal/v2/token";
281 let mut headers = HeaderMap::new();
282 headers.insert(USER_AGENT, HeaderValue::from_static("CopilotChat.nvim"));
283 headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
284 headers.insert(
285 "Authorization",
286 HeaderValue::from_str(&format!("Token {}", self.github_token))
287 .map_err(|e| CopilotError::Other(e.to_string()))?,
288 );
289 let res = self
290 .http_client
291 .get(url)
292 .headers(headers)
293 .send()
294 .await
295 .map_err(|e| CopilotError::HttpError(e.to_string()))?
296 .error_for_status()
297 .map_err(|e| CopilotError::HttpError(e.to_string()))?;
298 let token_response: CopilotTokenResponse = res
299 .json()
300 .await
301 .map_err(|e| CopilotError::Other(e.to_string()))?;
302 Ok(token_response.token)
303 }
304
305 pub async fn get_agents(&self) -> Result<Vec<Agent>, CopilotError> {
311 let url = "https://api.githubcopilot.com/agents";
312 let headers = self.get_headers().await?;
313 let res = self
314 .http_client
315 .get(url)
316 .headers(headers)
317 .send()
318 .await
319 .map_err(|e| CopilotError::HttpError(e.to_string()))?
320 .error_for_status()
321 .map_err(|e| CopilotError::HttpError(e.to_string()))?;
322 let agents_response: AgentsResponse = res
323 .json()
324 .await
325 .map_err(|e| CopilotError::Other(e.to_string()))?;
326 Ok(agents_response.agents)
327 }
328
329 pub async fn get_models(&self) -> Result<Vec<Model>, CopilotError> {
335 let url = "https://api.githubcopilot.com/models";
336 let headers = self.get_headers().await?;
337 let res = self
338 .http_client
339 .get(url)
340 .headers(headers)
341 .send()
342 .await
343 .map_err(|e| CopilotError::HttpError(e.to_string()))?
344 .error_for_status()
345 .map_err(|e| CopilotError::HttpError(e.to_string()))?;
346 let models_response: ModelsResponse = res
347 .json()
348 .await
349 .map_err(|e| CopilotError::Other(e.to_string()))?;
350 Ok(models_response.data)
351 }
352
353 pub async fn chat_completion(
365 &self,
366 messages: Vec<Message>,
367 model_id: String,
368 ) -> Result<ChatResponse, CopilotError> {
369 if !self.models.iter().any(|m| m.id == model_id) {
371 return Err(CopilotError::InvalidModel(model_id));
372 }
373 let url = "https://api.githubcopilot.com/chat/completions";
374 let headers = self.get_headers().await?;
375 let request_body = ChatRequest {
376 model: model_id,
377 messages,
378 n: 1,
379 top_p: 1.0,
380 stream: false,
381 temperature: 0.5,
382 max_tokens: None,
383 };
384 let res = self
385 .http_client
386 .post(url)
387 .headers(headers)
388 .json(&request_body)
389 .send()
390 .await
391 .map_err(|e| CopilotError::HttpError(e.to_string()))?
392 .error_for_status()
393 .map_err(|e| CopilotError::HttpError(e.to_string()))?;
394 let chat_response: ChatResponse = res
395 .json()
396 .await
397 .map_err(|e| CopilotError::Other(e.to_string()))?;
398 Ok(chat_response)
399 }
400
401 pub async fn get_embeddings(
411 &self,
412 inputs: Vec<String>,
413 ) -> Result<Vec<Embedding>, CopilotError> {
414 let url = "https://api.githubcopilot.com/embeddings";
415 let headers = self.get_headers().await?;
416 let request_body = EmbeddingRequest {
417 dimensions: 512,
418 input: inputs,
419 model: "text-embedding-3-small".to_string(),
420 };
421 let res = self
422 .http_client
423 .post(url)
424 .headers(headers)
425 .json(&request_body)
426 .send()
427 .await
428 .map_err(|e| CopilotError::HttpError(e.to_string()))?
429 .error_for_status()
430 .map_err(|e| CopilotError::HttpError(e.to_string()))?;
431 let embedding_response: EmbeddingResponse = res
432 .json()
433 .await
434 .map_err(|e| CopilotError::Other(e.to_string()))?;
435 Ok(embedding_response.data)
436 }
437}
438
439pub fn get_github_token() -> Result<String, Box<dyn Error>> {
445 if let Ok(token) = env::var("GITHUB_TOKEN") {
446 if env::var("CODESPACES").is_ok() {
447 return Ok(token);
448 }
449 }
450 let config_dir = get_config_path()?;
451 let file_paths = vec![
452 format!("{config_dir}/github-copilot/hosts.json"),
453 format!("{config_dir}/github-copilot/apps.json"),
454 ];
455 for file_path in file_paths {
456 if Path::new(&file_path).exists() {
457 let content = fs::read_to_string(&file_path)?;
458 let json_value: Value = serde_json::from_str(&content)?;
459 if let Some(obj) = json_value.as_object() {
460 for (key, value) in obj {
461 if key.contains("github.com") {
462 if let Some(oauth_token) = value.get("oauth_token") {
463 if let Some(token_str) = oauth_token.as_str() {
464 return Ok(token_str.to_string());
465 }
466 }
467 }
468 }
469 }
470 }
471 }
472 Err("Failed to find GitHub token".into())
473}
474
/// Determines the user's configuration directory.
///
/// Resolution order:
/// 1. `XDG_CONFIG_HOME`, on every platform.
/// 2. On Windows: `LOCALAPPDATA`.
/// 3. Elsewhere: `$HOME/.config`.
///
/// A variable that is unset, empty, or not valid Unicode is treated as
/// missing. This includes `HOME`, so an empty `HOME` no longer produces the
/// bogus path `"/.config"`.
///
/// # Errors
///
/// Returns an error with the message `"Failed to find config directory"` when
/// none of the applicable variables yields a value.
pub fn get_config_path() -> Result<String, Box<dyn Error>> {
    let non_empty_var = |name: &str| env::var(name).ok().filter(|value| !value.is_empty());

    if let Some(xdg) = non_empty_var("XDG_CONFIG_HOME") {
        return Ok(xdg);
    }

    let fallback = if cfg!(target_os = "windows") {
        non_empty_var("LOCALAPPDATA")
    } else {
        non_empty_var("HOME").map(|home| format!("{home}/.config"))
    };

    fallback.ok_or_else(|| "Failed to find config directory".into())
}