aiclient_api/providers/copilot/
mod.rs1pub mod client;
2pub mod headers;
3pub mod models;
4
5use anyhow::{Context, Result};
6use async_trait::async_trait;
7use futures::StreamExt;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::RwLock;
12use tokio::time::sleep;
13
14use crate::auth::copilot::fetch_copilot_token;
15use crate::config::types::AccountType;
16use crate::providers::{Model, OutputFormat, Provider, ProviderRequest, ProviderResponse};
17use client::CopilotClient;
18use headers::CopilotHeaders;
19
/// A Copilot API token together with the timing metadata returned with it.
///
/// `Debug` is implemented by hand so the secret never ends up in log output;
/// only the timing fields are printed.
#[derive(Clone)]
pub struct CopilotToken {
    /// Secret token string handed to `CopilotHeaders::build` for each request.
    pub copilot_token: String,
    /// Expiry value reported by the token endpoint.
    /// NOTE(review): presumably a Unix timestamp in seconds — confirm against
    /// the `fetch_copilot_token` response type.
    pub expires_at: i64,
    /// Refresh interval in seconds, as reported by the token endpoint.
    pub refresh_in: u64,
}

impl std::fmt::Debug for CopilotToken {
    /// Formats the token with `copilot_token` replaced by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CopilotToken")
            .field("copilot_token", &"<redacted>")
            .field("expires_at", &self.expires_at)
            .field("refresh_in", &self.refresh_in)
            .finish()
    }
}

26pub struct CopilotProvider {
27 client: CopilotClient,
28 headers: Arc<headers::CopilotHeaders>,
29 token: Arc<RwLock<Option<CopilotToken>>>,
30 github_token: String,
31 #[allow(dead_code)]
32 account_type: AccountType,
33 healthy: AtomicBool,
34}
35
36impl CopilotProvider {
37 pub fn new(
38 github_token: String,
39 account_type: AccountType,
40 vscode_version: &str,
41 ) -> Arc<Self> {
42 let client = CopilotClient::new(&account_type);
43 let headers = Arc::new(CopilotHeaders::new(vscode_version));
44
45 Arc::new(Self {
46 client,
47 headers,
48 token: Arc::new(RwLock::new(None)),
49 github_token,
50 account_type,
51 healthy: AtomicBool::new(false),
52 })
53 }
54
55 pub fn start(self: &Arc<Self>) {
56 self.headers.start_session_rotation();
57 self.start_token_refresh();
58 }
59
60 fn start_token_refresh(self: &Arc<Self>) {
61 let provider = self.clone();
62 tokio::spawn(async move {
63 let mut consecutive_failures: u32 = 0;
64 loop {
65 match fetch_copilot_token(
66 provider.client.http_client(),
67 &provider.github_token,
68 )
69 .await
70 {
71 Ok(resp) => {
72 consecutive_failures = 0;
73 let refresh_in = resp.refresh_in;
74 {
75 let mut token = provider.token.write().await;
76 *token = Some(CopilotToken {
77 copilot_token: resp.token,
78 expires_at: resp.expires_at,
79 refresh_in: resp.refresh_in,
80 });
81 }
82 provider.healthy.store(true, Ordering::Relaxed);
83 tracing::info!("Copilot token refreshed successfully");
84
85 let sleep_secs = if refresh_in > 60 {
86 refresh_in - 60
87 } else {
88 1
89 };
90 sleep(Duration::from_secs(sleep_secs)).await;
91 }
92 Err(e) => {
93 consecutive_failures += 1;
94 tracing::warn!(
95 "Failed to fetch Copilot token ({} consecutive): {:#}",
96 consecutive_failures,
97 e
98 );
99 if consecutive_failures >= 3 {
100 provider.healthy.store(false, Ordering::Relaxed);
101 }
102 sleep(Duration::from_secs(15)).await;
103 }
104 }
105 }
106 });
107 }
108
109 async fn get_copilot_token(&self) -> Result<String> {
110 let token = self.token.read().await;
111 token
112 .as_ref()
113 .map(|t| t.copilot_token.clone())
114 .context("Copilot token not yet available")
115 }
116}
117
118#[async_trait]
119impl Provider for CopilotProvider {
120 fn name(&self) -> &str {
121 "copilot"
122 }
123
124 fn is_healthy(&self) -> bool {
125 self.healthy.load(Ordering::Relaxed)
126 }
127
128 async fn list_models(&self) -> Result<Vec<Model>> {
129 let copilot_token = self.get_copilot_token().await?;
130 models::fetch_models(&self.client, &self.headers, &copilot_token).await
131 }
132
133 async fn chat(&self, request: ProviderRequest) -> Result<ProviderResponse> {
134 let copilot_token = self.get_copilot_token().await?;
135 let headers = self.headers.build(&copilot_token);
136
137 let model_id = if let Some(stripped) = request.model.strip_prefix("copilot/") {
139 stripped.to_string()
140 } else {
141 request.model.clone()
142 };
143
144 let mut body = serde_json::json!({
145 "model": model_id,
146 "messages": request.messages,
147 "stream": request.stream,
148 });
149
150 if let Some(temp) = request.temperature {
151 body["temperature"] = serde_json::json!(temp);
152 }
153 if let Some(max_tok) = request.max_tokens {
154 body["max_tokens"] = serde_json::json!(max_tok);
155 }
156 if let Some(tools) = request.tools {
157 body["tools"] = serde_json::json!(tools);
158 }
159 if let Some(tc) = request.tool_choice {
160 body["tool_choice"] = tc;
161 }
162 if let Some(system) = request.system {
163 if let Some(messages) = body["messages"].as_array_mut() {
165 messages.insert(0, serde_json::json!({"role": "system", "content": system}));
166 }
167 }
168
169 if request.stream {
170 let resp = self
171 .client
172 .chat_completions(headers, body, true)
173 .await?;
174
175 let byte_stream = resp
176 .bytes_stream()
177 .map(|r| r.map_err(|e| anyhow::anyhow!(e)));
178
179 Ok(ProviderResponse::Stream(Box::pin(byte_stream)))
180 } else {
181 let resp = self
182 .client
183 .chat_completions(headers, body, false)
184 .await?;
185
186 let json: serde_json::Value = resp.json().await.context("Failed to parse chat response")?;
187 Ok(ProviderResponse::Complete(json))
188 }
189 }
190
191 fn supports_passthrough(&self, _format: OutputFormat) -> bool {
192 true
193 }
194
195 async fn passthrough(
196 &self,
197 _model: &str,
198 body: serde_json::Value,
199 format: OutputFormat,
200 stream: bool,
201 ) -> Result<ProviderResponse> {
202 let copilot_token = self.get_copilot_token().await?;
203 let headers = self.headers.build(&copilot_token);
204
205 let resp = match format {
206 OutputFormat::OpenAI => {
207 self.client.chat_completions(headers, body, stream).await?
208 }
209 OutputFormat::Anthropic => {
210 self.client.messages(headers, body, stream).await?
211 }
212 };
213
214 if stream {
215 let byte_stream = resp
216 .bytes_stream()
217 .map(|r| r.map_err(|e| anyhow::anyhow!(e)));
218
219 Ok(ProviderResponse::Stream(Box::pin(byte_stream)))
220 } else {
221 let json: serde_json::Value =
222 resp.json().await.context("Failed to parse passthrough response")?;
223 Ok(ProviderResponse::Complete(json))
224 }
225 }
226}