use reqwest::{
header::{HeaderMap, HeaderValue, ACCEPT, AUTHORIZATION, USER_AGENT},
Client as HttpClient,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{env, error::Error, fmt, fs, path::Path};
/// Errors produced by the Copilot client and its helper functions.
///
/// Every variant carries a human-readable message; the `Display` output prefixes it with the
/// error category (except `Other`, which is shown verbatim).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CopilotError {
    /// The requested model id is not among the models known to the client.
    InvalidModel(String),
    /// The GitHub token could not be obtained.
    TokenError(String),
    /// A request could not be sent or the server answered with an error status.
    HttpError(String),
    /// Any other failure (e.g. building a header or decoding a response body).
    Other(String),
}

impl fmt::Display for CopilotError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            CopilotError::InvalidModel(model) => write!(f, "Invalid model specified: {model}"),
            CopilotError::TokenError(msg) => write!(f, "Token error: {msg}"),
            CopilotError::HttpError(msg) => write!(f, "HTTP error: {msg}"),
            CopilotError::Other(msg) => write!(f, "{msg}"),
        }
    }
}

impl Error for CopilotError {}
/// Response body of the Copilot token exchange endpoint (`copilot_internal/v2/token`).
#[derive(Debug, Deserialize, Serialize)]
pub struct CopilotTokenResponse {
    /// Short-lived token, sent as `Authorization: Bearer <token>` to the Copilot API.
    pub token: String,
    /// Expiry reported by the endpoint (presumably a Unix timestamp in seconds — unverified here).
    pub expires_at: u64,
}
/// A single agent entry listed by the Copilot `/agents` endpoint.
#[derive(Debug, Deserialize, Serialize)]
pub struct Agent {
    /// Identifier of the agent.
    pub id: String,
    /// Name of the agent.
    pub name: String,
    /// Optional description; `None` when the response omits it or sends `null`.
    pub description: Option<String>,
}
/// Response body of the Copilot `/agents` endpoint.
#[derive(Debug, Deserialize, Serialize)]
pub struct AgentsResponse {
    /// Agents available to the authenticated user.
    pub agents: Vec<Agent>,
}
/// A model entry listed by the Copilot `/models` endpoint.
///
/// All fields except `id` and `name` are optional and become `None` when absent.
#[derive(Debug, Deserialize, Serialize)]
pub struct Model {
    /// Model identifier; this is the value callers pass to `chat_completion`.
    pub id: String,
    /// Model name as reported by the API.
    pub name: String,
    /// Model version, if reported.
    pub version: Option<String>,
    /// Tokenizer name, if reported.
    pub tokenizer: Option<String>,
    /// Input token limit, if reported.
    pub max_input_tokens: Option<u32>,
    /// Output token limit, if reported.
    pub max_output_tokens: Option<u32>,
}
/// Response body of the Copilot `/models` endpoint.
#[derive(Debug, Deserialize, Serialize)]
pub struct ModelsResponse {
    /// Models available to the authenticated user.
    pub data: Vec<Model>,
}
/// One chat message, used both in requests and in responses.
///
/// `Clone` lets callers keep a conversation history even though `chat_completion` takes the
/// message list by value.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Message {
    /// Author role of the message, as understood by the API.
    pub role: String,
    /// Text content of the message.
    pub content: String,
}
/// JSON body posted to the Copilot `/chat/completions` endpoint.
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatRequest {
    /// Id of the model to use.
    pub model: String,
    /// Conversation so far.
    pub messages: Vec<Message>,
    /// Number of completions requested.
    pub n: u32,
    /// Nucleus-sampling parameter.
    pub top_p: f64,
    /// Whether a streamed response is requested.
    pub stream: bool,
    /// Sampling temperature.
    pub temperature: f64,
    /// Optional output-token cap; the key is left out of the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
}
/// One completion choice in a `/chat/completions` response.
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatChoice {
    /// Generated message.
    pub message: Message,
    /// Reason generation stopped, if reported.
    pub finish_reason: Option<String>,
    /// Token accounting, if reported on the choice.
    pub usage: Option<TokenUsage>,
}
/// Token accounting attached to a chat choice.
#[derive(Debug, Deserialize, Serialize)]
pub struct TokenUsage {
    /// Total number of tokens reported by the API.
    pub total_tokens: u32,
}
/// Response body of the Copilot `/chat/completions` endpoint.
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatResponse {
    /// Completion choices returned by the model.
    pub choices: Vec<ChatChoice>,
}
/// JSON body posted to the Copilot `/embeddings` endpoint.
#[derive(Debug, Deserialize, Serialize)]
pub struct EmbeddingRequest {
    /// Requested embedding dimensionality.
    pub dimensions: u32,
    /// Texts to embed.
    pub input: Vec<String>,
    /// Id of the embedding model.
    pub model: String,
}
/// One embedding vector in an `/embeddings` response.
#[derive(Debug, Deserialize, Serialize)]
pub struct Embedding {
    /// Position of the corresponding input text in the request.
    pub index: usize,
    /// The embedding vector itself.
    pub embedding: Vec<f64>,
}
/// Response body of the Copilot `/embeddings` endpoint.
#[derive(Debug, Deserialize, Serialize)]
pub struct EmbeddingResponse {
    /// Embeddings, one entry per input text.
    pub data: Vec<Embedding>,
}
pub struct CopilotClient {
http_client: HttpClient,
github_token: String,
editor_version: String,
models: Vec<Model>,
}
impl CopilotClient {
pub async fn from_env_with_models(editor_version: String) -> Result<Self, CopilotError> {
let github_token =
get_github_token().map_err(|e| CopilotError::TokenError(e.to_string()))?;
Self::new_with_models(github_token, editor_version).await
}
pub async fn new_with_models(
github_token: String,
editor_version: String,
) -> Result<Self, CopilotError> {
let http_client = HttpClient::new();
let mut client = CopilotClient {
http_client,
github_token,
editor_version,
models: Vec::new(),
};
let models = client.get_models().await?;
client.models = models;
Ok(client)
}
async fn get_headers(&self) -> Result<HeaderMap, CopilotError> {
let token = self.get_copilot_token().await?;
let mut headers = HeaderMap::new();
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {token}"))
.map_err(|e| CopilotError::Other(e.to_string()))?,
);
headers.insert(
"Editor-Version",
HeaderValue::from_str(&self.editor_version)
.map_err(|e| CopilotError::Other(e.to_string()))?,
);
headers.insert(
"Editor-Plugin-Version",
HeaderValue::from_static("CopilotChat.nvim/*"),
);
headers.insert(
"Copilot-Integration-Id",
HeaderValue::from_static("vscode-chat"),
);
headers.insert(USER_AGENT, HeaderValue::from_static("CopilotChat.nvim"));
headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
Ok(headers)
}
async fn get_copilot_token(&self) -> Result<String, CopilotError> {
let url = "https://api.github.com/copilot_internal/v2/token";
let mut headers = HeaderMap::new();
headers.insert(USER_AGENT, HeaderValue::from_static("CopilotChat.nvim"));
headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
headers.insert(
"Authorization",
HeaderValue::from_str(&format!("Token {}", self.github_token))
.map_err(|e| CopilotError::Other(e.to_string()))?,
);
let res = self
.http_client
.get(url)
.headers(headers)
.send()
.await
.map_err(|e| CopilotError::HttpError(e.to_string()))?
.error_for_status()
.map_err(|e| CopilotError::HttpError(e.to_string()))?;
let token_response: CopilotTokenResponse = res
.json()
.await
.map_err(|e| CopilotError::Other(e.to_string()))?;
Ok(token_response.token)
}
pub async fn get_agents(&self) -> Result<Vec<Agent>, CopilotError> {
let url = "https://api.githubcopilot.com/agents";
let headers = self.get_headers().await?;
let res = self
.http_client
.get(url)
.headers(headers)
.send()
.await
.map_err(|e| CopilotError::HttpError(e.to_string()))?
.error_for_status()
.map_err(|e| CopilotError::HttpError(e.to_string()))?;
let agents_response: AgentsResponse = res
.json()
.await
.map_err(|e| CopilotError::Other(e.to_string()))?;
Ok(agents_response.agents)
}
pub async fn get_models(&self) -> Result<Vec<Model>, CopilotError> {
let url = "https://api.githubcopilot.com/models";
let headers = self.get_headers().await?;
let res = self
.http_client
.get(url)
.headers(headers)
.send()
.await
.map_err(|e| CopilotError::HttpError(e.to_string()))?
.error_for_status()
.map_err(|e| CopilotError::HttpError(e.to_string()))?;
let models_response: ModelsResponse = res
.json()
.await
.map_err(|e| CopilotError::Other(e.to_string()))?;
Ok(models_response.data)
}
pub async fn chat_completion(
&self,
messages: Vec<Message>,
model_id: String,
) -> Result<ChatResponse, CopilotError> {
if !self.models.iter().any(|m| m.id == model_id) {
return Err(CopilotError::InvalidModel(model_id));
}
let url = "https://api.githubcopilot.com/chat/completions";
let headers = self.get_headers().await?;
let request_body = ChatRequest {
model: model_id,
messages,
n: 1,
top_p: 1.0,
stream: false,
temperature: 0.5,
max_tokens: None,
};
let res = self
.http_client
.post(url)
.headers(headers)
.json(&request_body)
.send()
.await
.map_err(|e| CopilotError::HttpError(e.to_string()))?
.error_for_status()
.map_err(|e| CopilotError::HttpError(e.to_string()))?;
let chat_response: ChatResponse = res
.json()
.await
.map_err(|e| CopilotError::Other(e.to_string()))?;
Ok(chat_response)
}
pub async fn get_embeddings(
&self,
inputs: Vec<String>,
) -> Result<Vec<Embedding>, CopilotError> {
let url = "https://api.githubcopilot.com/embeddings";
let headers = self.get_headers().await?;
let request_body = EmbeddingRequest {
dimensions: 512,
input: inputs,
model: "text-embedding-3-small".to_string(),
};
let res = self
.http_client
.post(url)
.headers(headers)
.json(&request_body)
.send()
.await
.map_err(|e| CopilotError::HttpError(e.to_string()))?
.error_for_status()
.map_err(|e| CopilotError::HttpError(e.to_string()))?;
let embedding_response: EmbeddingResponse = res
.json()
.await
.map_err(|e| CopilotError::Other(e.to_string()))?;
Ok(embedding_response.data)
}
}
/// Finds the GitHub OAuth token used to request Copilot API tokens.
///
/// Sources, in order:
/// 1. `GITHUB_TOKEN`, but only when `CODESPACES` is also set.
/// 2. `<config>/github-copilot/hosts.json`, then `<config>/github-copilot/apps.json`, where
///    `<config>` comes from `get_config_path`. In each file, the `oauth_token` string of the
///    first top-level entry whose key contains `github.com` is used.
///
/// # Errors
/// - The config directory cannot be determined.
/// - No token was found. If a credentials file existed but could not be read or parsed, that
///   error is returned, but only after the other file has also been tried. Otherwise the
///   error is "Failed to find GitHub token".
pub fn get_github_token() -> Result<String, Box<dyn Error>> {
    if let Ok(token) = env::var("GITHUB_TOKEN") {
        if env::var("CODESPACES").is_ok() {
            return Ok(token);
        }
    }
    let copilot_dir = Path::new(&get_config_path()?).join("github-copilot");
    // Remembered so that a broken hosts.json does not hide a valid apps.json, yet is still
    // reported when neither file yields a token.
    let mut first_error: Option<Box<dyn Error>> = None;
    for file_name in ["hosts.json", "apps.json"] {
        // Reading directly (instead of checking `exists()` first) avoids a check-then-use race.
        let content = match fs::read_to_string(copilot_dir.join(file_name)) {
            Ok(content) => content,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => continue,
            Err(e) => {
                if first_error.is_none() {
                    first_error = Some(e.into());
                }
                continue;
            }
        };
        let json: Value = match serde_json::from_str(&content) {
            Ok(json) => json,
            Err(e) => {
                if first_error.is_none() {
                    first_error = Some(e.into());
                }
                continue;
            }
        };
        let token = json.as_object().and_then(|hosts| {
            hosts
                .iter()
                .filter(|(key, _)| key.contains("github.com"))
                .find_map(|(_, entry)| entry.get("oauth_token")?.as_str())
        });
        if let Some(token) = token {
            return Ok(token.to_string());
        }
    }
    Err(first_error.unwrap_or_else(|| "Failed to find GitHub token".into()))
}
/// Returns the base configuration directory that holds the `github-copilot` folder.
///
/// Resolution order:
/// 1. `XDG_CONFIG_HOME`, if set and non-empty.
/// 2. On Windows, `LOCALAPPDATA`, if set and non-empty. Elsewhere, `$HOME/.config`, if `HOME`
///    is set and non-empty.
///
/// # Errors
/// Returns "Failed to find config directory" when none of the above yields a directory.
pub fn get_config_path() -> Result<String, Box<dyn Error>> {
    // Unset, non-Unicode, and empty variables are all treated as "not available".
    let non_empty = |name: &str| env::var(name).ok().filter(|value| !value.is_empty());
    if let Some(xdg) = non_empty("XDG_CONFIG_HOME") {
        return Ok(xdg);
    }
    let fallback = if cfg!(target_os = "windows") {
        non_empty("LOCALAPPDATA")
    } else {
        // An empty HOME would otherwise resolve to `/.config` at the filesystem root.
        non_empty("HOME").map(|home| format!("{home}/.config"))
    };
    fallback.ok_or_else(|| "Failed to find config directory".into())
}