//! GitHub Copilot provider (`aiclient_api/providers/copilot/mod.rs`).

1pub mod client;
2pub mod headers;
3pub mod models;
4
5use anyhow::{Context, Result};
6use async_trait::async_trait;
7use futures::StreamExt;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::RwLock;
12use tokio::time::sleep;
13
14use crate::auth::copilot::fetch_copilot_token;
15use crate::config::types::AccountType;
16use crate::providers::{Model, OutputFormat, Provider, ProviderRequest, ProviderResponse};
17use client::CopilotClient;
18use headers::CopilotHeaders;
19
/// A Copilot API token as cached by the Copilot provider.
pub struct CopilotToken {
    /// Token string handed to the header builder for each request.
    pub copilot_token: String,
    /// Expiry value reported by the token endpoint.
    /// NOTE(review): presumably a Unix timestamp in seconds — confirm.
    pub expires_at: i64,
    /// Suggested number of seconds to wait before refreshing the token.
    pub refresh_in: u64,
}
25
26pub struct CopilotProvider {
27    client: CopilotClient,
28    headers: Arc<headers::CopilotHeaders>,
29    token: Arc<RwLock<Option<CopilotToken>>>,
30    github_token: String,
31    #[allow(dead_code)]
32    account_type: AccountType,
33    healthy: AtomicBool,
34}
35
36impl CopilotProvider {
37    pub fn new(
38        github_token: String,
39        account_type: AccountType,
40        vscode_version: &str,
41    ) -> Arc<Self> {
42        let client = CopilotClient::new(&account_type);
43        let headers = Arc::new(CopilotHeaders::new(vscode_version));
44
45        Arc::new(Self {
46            client,
47            headers,
48            token: Arc::new(RwLock::new(None)),
49            github_token,
50            account_type,
51            healthy: AtomicBool::new(false),
52        })
53    }
54
55    pub fn start(self: &Arc<Self>) {
56        self.headers.start_session_rotation();
57        self.start_token_refresh();
58    }
59
60    fn start_token_refresh(self: &Arc<Self>) {
61        let provider = self.clone();
62        tokio::spawn(async move {
63            let mut consecutive_failures: u32 = 0;
64            loop {
65                match fetch_copilot_token(
66                    provider.client.http_client(),
67                    &provider.github_token,
68                )
69                .await
70                {
71                    Ok(resp) => {
72                        consecutive_failures = 0;
73                        let refresh_in = resp.refresh_in;
74                        {
75                            let mut token = provider.token.write().await;
76                            *token = Some(CopilotToken {
77                                copilot_token: resp.token,
78                                expires_at: resp.expires_at,
79                                refresh_in: resp.refresh_in,
80                            });
81                        }
82                        provider.healthy.store(true, Ordering::Relaxed);
83                        tracing::info!("Copilot token refreshed successfully");
84
85                        let sleep_secs = if refresh_in > 60 {
86                            refresh_in - 60
87                        } else {
88                            1
89                        };
90                        sleep(Duration::from_secs(sleep_secs)).await;
91                    }
92                    Err(e) => {
93                        consecutive_failures += 1;
94                        tracing::warn!(
95                            "Failed to fetch Copilot token ({} consecutive): {:#}",
96                            consecutive_failures,
97                            e
98                        );
99                        if consecutive_failures >= 3 {
100                            provider.healthy.store(false, Ordering::Relaxed);
101                        }
102                        sleep(Duration::from_secs(15)).await;
103                    }
104                }
105            }
106        });
107    }
108
109    async fn get_copilot_token(&self) -> Result<String> {
110        let token = self.token.read().await;
111        token
112            .as_ref()
113            .map(|t| t.copilot_token.clone())
114            .context("Copilot token not yet available")
115    }
116}
117
118#[async_trait]
119impl Provider for CopilotProvider {
120    fn name(&self) -> &str {
121        "copilot"
122    }
123
124    fn is_healthy(&self) -> bool {
125        self.healthy.load(Ordering::Relaxed)
126    }
127
128    async fn list_models(&self) -> Result<Vec<Model>> {
129        let copilot_token = self.get_copilot_token().await?;
130        models::fetch_models(&self.client, &self.headers, &copilot_token).await
131    }
132
133    async fn chat(&self, request: ProviderRequest) -> Result<ProviderResponse> {
134        let copilot_token = self.get_copilot_token().await?;
135        let headers = self.headers.build(&copilot_token);
136
137        // Strip provider prefix from model id if present
138        let model_id = if let Some(stripped) = request.model.strip_prefix("copilot/") {
139            stripped.to_string()
140        } else {
141            request.model.clone()
142        };
143
144        let mut body = serde_json::json!({
145            "model": model_id,
146            "messages": request.messages,
147            "stream": request.stream,
148        });
149
150        if let Some(temp) = request.temperature {
151            body["temperature"] = serde_json::json!(temp);
152        }
153        if let Some(max_tok) = request.max_tokens {
154            body["max_tokens"] = serde_json::json!(max_tok);
155        }
156        if let Some(tools) = request.tools {
157            body["tools"] = serde_json::json!(tools);
158        }
159        if let Some(tc) = request.tool_choice {
160            body["tool_choice"] = tc;
161        }
162        if let Some(system) = request.system {
163            // Prepend system as a system message
164            if let Some(messages) = body["messages"].as_array_mut() {
165                messages.insert(0, serde_json::json!({"role": "system", "content": system}));
166            }
167        }
168
169        if request.stream {
170            let resp = self
171                .client
172                .chat_completions(headers, body, true)
173                .await?;
174
175            let byte_stream = resp
176                .bytes_stream()
177                .map(|r| r.map_err(|e| anyhow::anyhow!(e)));
178
179            Ok(ProviderResponse::Stream(Box::pin(byte_stream)))
180        } else {
181            let resp = self
182                .client
183                .chat_completions(headers, body, false)
184                .await?;
185
186            let json: serde_json::Value = resp.json().await.context("Failed to parse chat response")?;
187            Ok(ProviderResponse::Complete(json))
188        }
189    }
190
191    fn supports_passthrough(&self, _format: OutputFormat) -> bool {
192        true
193    }
194
195    async fn passthrough(
196        &self,
197        _model: &str,
198        body: serde_json::Value,
199        format: OutputFormat,
200        stream: bool,
201    ) -> Result<ProviderResponse> {
202        let copilot_token = self.get_copilot_token().await?;
203        let headers = self.headers.build(&copilot_token);
204
205        let resp = match format {
206            OutputFormat::OpenAI => {
207                self.client.chat_completions(headers, body, stream).await?
208            }
209            OutputFormat::Anthropic => {
210                self.client.messages(headers, body, stream).await?
211            }
212        };
213
214        if stream {
215            let byte_stream = resp
216                .bytes_stream()
217                .map(|r| r.map_err(|e| anyhow::anyhow!(e)));
218
219            Ok(ProviderResponse::Stream(Box::pin(byte_stream)))
220        } else {
221            let json: serde_json::Value =
222                resp.json().await.context("Failed to parse passthrough response")?;
223            Ok(ProviderResponse::Complete(json))
224        }
225    }
226}