copilot_client/
lib.rs

1//! # GitHub Copilot Client Library
2//!
3//! This library provides a client for interacting with the GitHub Copilot API. It handles token
4//! management, model retrieval, chat completions, and embeddings requests. The client uses
5//! [reqwest](https://crates.io/crates/reqwest) for HTTP requests and [serde](https://crates.io/crates/serde)
6//! for JSON serialization/deserialization.
7//!
8//! ## Features
9//!
10//! - Retrieve a GitHub token from the environment or configuration files.
11//! - Fetch available Copilot models and agents.
12//! - Send chat completion requests and receive responses.
13//! - Request embeddings for provided input strings.
14
15use reqwest::{
16    header::{HeaderMap, HeaderValue, ACCEPT, AUTHORIZATION, USER_AGENT},
17    Client as HttpClient,
18};
19use serde::{Deserialize, Serialize};
20use serde_json::Value;
21use std::{env, error::Error, fmt, fs, path::Path};
22
/// Errors produced while talking to the GitHub Copilot API.
///
/// Every variant carries a human-readable message. The type is `Clone`, `PartialEq`
/// and `Eq` so callers can keep a copy of a failure and compare errors directly
/// (for example with `assert_eq!` in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CopilotError {
    /// The requested model identifier is not in the client's list of available models.
    InvalidModel(String),
    /// The GitHub token could not be retrieved or parsed.
    TokenError(String),
    /// Sending an HTTP request failed, or the server answered with an error status.
    HttpError(String),
    /// Any other failure, such as an invalid header value or an undecodable response body.
    Other(String),
}
35
36impl fmt::Display for CopilotError {
37    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38        match self {
39            CopilotError::InvalidModel(model) => {
40                write!(f, "Invalid model specified: {model}")
41            }
42            CopilotError::TokenError(msg) => write!(f, "Token error: {msg}"),
43            CopilotError::HttpError(msg) => write!(f, "HTTP error: {msg}"),
44            CopilotError::Other(msg) => write!(f, "{msg}"),
45        }
46    }
47}
48
49impl Error for CopilotError {}
50
51/// Response from the GitHub Copilot token endpoint.
52///
53/// The `expires_at` field is a Unix timestamp.
54#[derive(Debug, Serialize, Deserialize)]
55pub struct CopilotTokenResponse {
56    /// The Copilot token.
57    pub token: String,
58    /// Expiration time as a Unix timestamp.
59    pub expires_at: u64,
60}
61
62/// Represents an agent returned by the GitHub Copilot API.
63#[derive(Debug, Serialize, Deserialize)]
64pub struct Agent {
65    /// The agent's identifier.
66    pub id: String,
67    /// The agent's name.
68    pub name: String,
69    /// An optional description for the agent.
70    pub description: Option<String>,
71}
72
73/// Response payload for retrieving agents.
74#[derive(Debug, Serialize, Deserialize)]
75pub struct AgentsResponse {
76    /// List of agents.
77    pub agents: Vec<Agent>,
78}
79
80/// Represents a model available for GitHub Copilot.
81#[derive(Debug, Serialize, Deserialize)]
82pub struct Model {
83    /// The model identifier.
84    pub id: String,
85    /// The model name.
86    pub name: String,
87    /// The version of the model, if available.
88    pub version: Option<String>,
89    /// The tokenizer used by the model, if available.
90    pub tokenizer: Option<String>,
91    /// Maximum number of input tokens allowed.
92    pub max_input_tokens: Option<u32>,
93    /// Maximum number of output tokens allowed.
94    pub max_output_tokens: Option<u32>,
95}
96
97/// Response payload for retrieving models.
98#[derive(Debug, Serialize, Deserialize)]
99pub struct ModelsResponse {
100    /// List of models.
101    pub data: Vec<Model>,
102}
103
104/// Represents a chat message.
105///
106/// The `role` field typically contains values such as `"system"`, `"user"`, or `"assistant"`.
107#[derive(Debug, Serialize, Deserialize)]
108pub struct Message {
109    /// The role of the message sender.
110    pub role: String,
111    /// The content of the message.
112    pub content: String,
113}
114
115/// Request payload for a chat completion.
116#[derive(Debug, Serialize, Deserialize)]
117pub struct ChatRequest {
118    /// The model identifier to use.
119    pub model: String,
120    /// The list of messages to send.
121    pub messages: Vec<Message>,
122    /// Number of chat completions to generate.
123    pub n: u32,
124    /// Nucleus sampling probability.
125    pub top_p: f64,
126    /// Whether to stream the response.
127    pub stream: bool,
128    /// Sampling temperature.
129    pub temperature: f64,
130    /// Optional maximum number of tokens to generate.
131    #[serde(skip_serializing_if = "Option::is_none")]
132    pub max_tokens: Option<u32>,
133}
134
135/// Represents a single choice in a chat completion response.
136#[derive(Debug, Serialize, Deserialize)]
137pub struct ChatChoice {
138    /// The message generated by the model.
139    pub message: Message,
140    /// The reason why the generation finished.
141    pub finish_reason: Option<String>,
142    /// Optional token usage information.
143    pub usage: Option<TokenUsage>,
144}
145
146/// Information about token usage in a chat response.
147#[derive(Debug, Serialize, Deserialize)]
148pub struct TokenUsage {
149    /// Total tokens used.
150    pub total_tokens: u32,
151}
152
153/// Response payload for a chat completion request.
154#[derive(Debug, Serialize, Deserialize)]
155pub struct ChatResponse {
156    /// List of generated chat choices.
157    pub choices: Vec<ChatChoice>,
158}
159
160/// Request payload for an embeddings request.
161#[derive(Debug, Serialize, Deserialize)]
162pub struct EmbeddingRequest {
163    /// The dimensions of the embedding vector.
164    pub dimensions: u32,
165    /// List of input strings to embed.
166    pub input: Vec<String>,
167    /// The model identifier to use for embeddings.
168    pub model: String,
169}
170
171/// Represents an individual embedding.
172#[derive(Debug, Serialize, Deserialize)]
173pub struct Embedding {
174    /// The index corresponding to the input.
175    pub index: usize,
176    /// The embedding vector.
177    pub embedding: Vec<f64>,
178}
179
180/// Response payload for an embeddings request.
181#[derive(Debug, Serialize, Deserialize)]
182pub struct EmbeddingResponse {
183    /// List of embeddings.
184    pub data: Vec<Embedding>,
185}
186
187/// Client for interacting with the GitHub Copilot API.
188///
189/// This client handles GitHub token retrieval, fetching available models,
190/// and sending API requests for chat completions and embeddings.
191pub struct CopilotClient {
192    http_client: HttpClient,
193    github_token: String,
194    editor_version: String,
195    /// List of available models.
196    models: Vec<Model>,
197}
198
199impl CopilotClient {
200    /// Creates a new `CopilotClient` by retrieving the GitHub token from environment variables
201    /// or configuration files, and then fetching the list of available models.
202    ///
203    /// # Arguments
204    ///
205    /// * `editor_version` - The version of the editor (e.g., "1.0.0").
206    ///
207    /// # Errors
208    ///
209    /// Returns a `CopilotError` if the token retrieval or model fetching fails.
210    pub async fn from_env_with_models(editor_version: String) -> Result<Self, CopilotError> {
211        let github_token =
212            get_github_token().map_err(|e| CopilotError::TokenError(e.to_string()))?;
213        Self::new_with_models(github_token, editor_version).await
214    }
215
216    /// Creates a new `CopilotClient` with the provided GitHub token and editor version,
217    /// and fetches the list of available models.
218    ///
219    /// # Arguments
220    ///
221    /// * `github_token` - The GitHub token for authentication.
222    /// * `editor_version` - The version of the editor.
223    ///
224    /// # Errors
225    ///
226    /// Returns a `CopilotError` if the model fetching fails.
227    pub async fn new_with_models(
228        github_token: String,
229        editor_version: String,
230    ) -> Result<Self, CopilotError> {
231        let http_client = HttpClient::new();
232        let mut client = CopilotClient {
233            http_client,
234            github_token,
235            editor_version,
236            models: Vec::new(),
237        };
238        // Fetch and store the available models.
239        let models = client.get_models().await?;
240        client.models = models;
241        Ok(client)
242    }
243
244    /// Constructs the HTTP headers required for GitHub Copilot API requests.
245    ///
246    /// This includes the authentication token, editor version information,
247    /// and other necessary headers.
248    async fn get_headers(&self) -> Result<HeaderMap, CopilotError> {
249        let token = self.get_copilot_token().await?;
250        let mut headers = HeaderMap::new();
251        headers.insert(
252            AUTHORIZATION,
253            HeaderValue::from_str(&format!("Bearer {token}"))
254                .map_err(|e| CopilotError::Other(e.to_string()))?,
255        );
256        headers.insert(
257            "Editor-Version",
258            HeaderValue::from_str(&self.editor_version)
259                .map_err(|e| CopilotError::Other(e.to_string()))?,
260        );
261        headers.insert(
262            "Editor-Plugin-Version",
263            HeaderValue::from_static("CopilotChat.nvim/*"),
264        );
265        headers.insert(
266            "Copilot-Integration-Id",
267            HeaderValue::from_static("vscode-chat"),
268        );
269        headers.insert(USER_AGENT, HeaderValue::from_static("CopilotChat.nvim"));
270        headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
271        Ok(headers)
272    }
273
274    /// Retrieves a GitHub Copilot token using the stored GitHub token.
275    ///
276    /// # Errors
277    ///
278    /// Returns a `CopilotError` if the HTTP request fails or the response cannot be parsed.
279    async fn get_copilot_token(&self) -> Result<String, CopilotError> {
280        let url = "https://api.github.com/copilot_internal/v2/token";
281        let mut headers = HeaderMap::new();
282        headers.insert(USER_AGENT, HeaderValue::from_static("CopilotChat.nvim"));
283        headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
284        headers.insert(
285            "Authorization",
286            HeaderValue::from_str(&format!("Token {}", self.github_token))
287                .map_err(|e| CopilotError::Other(e.to_string()))?,
288        );
289        let res = self
290            .http_client
291            .get(url)
292            .headers(headers)
293            .send()
294            .await
295            .map_err(|e| CopilotError::HttpError(e.to_string()))?
296            .error_for_status()
297            .map_err(|e| CopilotError::HttpError(e.to_string()))?;
298        let token_response: CopilotTokenResponse = res
299            .json()
300            .await
301            .map_err(|e| CopilotError::Other(e.to_string()))?;
302        Ok(token_response.token)
303    }
304
305    /// Fetches the list of agents from the GitHub Copilot API.
306    ///
307    /// # Errors
308    ///
309    /// Returns a `CopilotError` if the HTTP request fails or the response cannot be parsed.
310    pub async fn get_agents(&self) -> Result<Vec<Agent>, CopilotError> {
311        let url = "https://api.githubcopilot.com/agents";
312        let headers = self.get_headers().await?;
313        let res = self
314            .http_client
315            .get(url)
316            .headers(headers)
317            .send()
318            .await
319            .map_err(|e| CopilotError::HttpError(e.to_string()))?
320            .error_for_status()
321            .map_err(|e| CopilotError::HttpError(e.to_string()))?;
322        let agents_response: AgentsResponse = res
323            .json()
324            .await
325            .map_err(|e| CopilotError::Other(e.to_string()))?;
326        Ok(agents_response.agents)
327    }
328
329    /// Fetches the list of available models from the GitHub Copilot API.
330    ///
331    /// # Errors
332    ///
333    /// Returns a `CopilotError` if the HTTP request fails or the response cannot be parsed.
334    pub async fn get_models(&self) -> Result<Vec<Model>, CopilotError> {
335        let url = "https://api.githubcopilot.com/models";
336        let headers = self.get_headers().await?;
337        let res = self
338            .http_client
339            .get(url)
340            .headers(headers)
341            .send()
342            .await
343            .map_err(|e| CopilotError::HttpError(e.to_string()))?
344            .error_for_status()
345            .map_err(|e| CopilotError::HttpError(e.to_string()))?;
346        let models_response: ModelsResponse = res
347            .json()
348            .await
349            .map_err(|e| CopilotError::Other(e.to_string()))?;
350        Ok(models_response.data)
351    }
352
353    /// Sends a chat completion request to the GitHub Copilot API.
354    ///
355    /// # Arguments
356    ///
357    /// * `messages` - A vector of chat messages to send.
358    /// * `model_id` - The identifier of the model to use.
359    ///
360    /// # Errors
361    ///
362    /// Returns a `CopilotError::InvalidModel` error if the specified model is not available,
363    /// or another `CopilotError` if the HTTP request or response parsing fails.
364    pub async fn chat_completion(
365        &self,
366        messages: Vec<Message>,
367        model_id: String,
368    ) -> Result<ChatResponse, CopilotError> {
369        // Check if the specified model is available.
370        if !self.models.iter().any(|m| m.id == model_id) {
371            return Err(CopilotError::InvalidModel(model_id));
372        }
373        let url = "https://api.githubcopilot.com/chat/completions";
374        let headers = self.get_headers().await?;
375        let request_body = ChatRequest {
376            model: model_id,
377            messages,
378            n: 1,
379            top_p: 1.0,
380            stream: false,
381            temperature: 0.5,
382            max_tokens: None,
383        };
384        let res = self
385            .http_client
386            .post(url)
387            .headers(headers)
388            .json(&request_body)
389            .send()
390            .await
391            .map_err(|e| CopilotError::HttpError(e.to_string()))?
392            .error_for_status()
393            .map_err(|e| CopilotError::HttpError(e.to_string()))?;
394        let chat_response: ChatResponse = res
395            .json()
396            .await
397            .map_err(|e| CopilotError::Other(e.to_string()))?;
398        Ok(chat_response)
399    }
400
401    /// Sends an embeddings request to the GitHub Copilot API.
402    ///
403    /// # Arguments
404    ///
405    /// * `inputs` - A vector of input strings to generate embeddings for.
406    ///
407    /// # Errors
408    ///
409    /// Returns a `CopilotError` if the HTTP request fails or the response cannot be parsed.
410    pub async fn get_embeddings(
411        &self,
412        inputs: Vec<String>,
413    ) -> Result<Vec<Embedding>, CopilotError> {
414        let url = "https://api.githubcopilot.com/embeddings";
415        let headers = self.get_headers().await?;
416        let request_body = EmbeddingRequest {
417            dimensions: 512,
418            input: inputs,
419            model: "text-embedding-3-small".to_string(),
420        };
421        let res = self
422            .http_client
423            .post(url)
424            .headers(headers)
425            .json(&request_body)
426            .send()
427            .await
428            .map_err(|e| CopilotError::HttpError(e.to_string()))?
429            .error_for_status()
430            .map_err(|e| CopilotError::HttpError(e.to_string()))?;
431        let embedding_response: EmbeddingResponse = res
432            .json()
433            .await
434            .map_err(|e| CopilotError::Other(e.to_string()))?;
435        Ok(embedding_response.data)
436    }
437}
438
439/// Retrieves the GitHub token from the `GITHUB_TOKEN` environment variable or from a configuration file.
440///
441/// # Errors
442///
443/// Returns an error if the token is not found in the environment or configuration files.
444pub fn get_github_token() -> Result<String, Box<dyn Error>> {
445    if let Ok(token) = env::var("GITHUB_TOKEN") {
446        if env::var("CODESPACES").is_ok() {
447            return Ok(token);
448        }
449    }
450    let config_dir = get_config_path()?;
451    let file_paths = vec![
452        format!("{config_dir}/github-copilot/hosts.json"),
453        format!("{config_dir}/github-copilot/apps.json"),
454    ];
455    for file_path in file_paths {
456        if Path::new(&file_path).exists() {
457            let content = fs::read_to_string(&file_path)?;
458            let json_value: Value = serde_json::from_str(&content)?;
459            if let Some(obj) = json_value.as_object() {
460                for (key, value) in obj {
461                    if key.contains("github.com") {
462                        if let Some(oauth_token) = value.get("oauth_token") {
463                            if let Some(token_str) = oauth_token.as_str() {
464                                return Ok(token_str.to_string());
465                            }
466                        }
467                    }
468                }
469            }
470        }
471    }
472    Err("Failed to find GitHub token".into())
473}
474
/// Returns the user's configuration directory.
///
/// `XDG_CONFIG_HOME` is used on every platform when it is set and non-empty. Otherwise
/// Windows builds fall back to `LOCALAPPDATA`, and all other targets fall back to
/// `$HOME/.config`. An environment variable that is unset, empty, or not valid Unicode
/// is treated as missing, so an empty `HOME` never produces the path `/.config`.
///
/// # Errors
///
/// Returns an error if the configuration directory cannot be determined.
pub fn get_config_path() -> Result<String, Box<dyn Error>> {
    // Treat unset, non-Unicode and empty values alike.
    let non_empty_var = |name: &str| env::var(name).ok().filter(|value| !value.is_empty());

    if let Some(xdg) = non_empty_var("XDG_CONFIG_HOME") {
        return Ok(xdg);
    }
    let fallback = if cfg!(target_os = "windows") {
        non_empty_var("LOCALAPPDATA")
    } else {
        non_empty_var("HOME").map(|home| format!("{home}/.config"))
    };
    fallback.ok_or_else(|| "Failed to find config directory".into())
}