greppy/ai/
claude.rs

1use crate::core::error::{Error, Result};
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use std::sync::Mutex;
5use std::time::{Duration, Instant};
6
// Claude OAuth / Messages API constants.

/// OAuth client identifier sent as `client_id` when refreshing the access token.
const CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";
/// Messages API endpoint that every rerank / expansion request is posted to.
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
/// Value sent in the `anthropic-version` request header.
const ANTHROPIC_VERSION: &str = "2023-06-01";

/// Nominal access-token cache lifetime: 50 minutes (tokens typically expire in 1 hour).
///
/// Nothing in this file reads this constant; the effective cache lifetime is
/// derived from the `expires_in` value returned by the token endpoint.
// Kept for reference only, hence the lint suppression.
#[allow(dead_code)]
const TOKEN_CACHE_DURATION: Duration = Duration::from_secs(50 * 60);
15
16#[derive(Debug, Serialize)]
17struct MessageRequest {
18    model: String,
19    max_tokens: u32,
20    system: String,
21    messages: Vec<Message>,
22}
23
24#[derive(Debug, Serialize)]
25struct Message {
26    role: String,
27    content: String,
28}
29
30#[derive(Debug, Deserialize)]
31struct MessageResponse {
32    content: Option<Vec<ContentBlock>>,
33    error: Option<ApiError>,
34}
35
36#[derive(Debug, Deserialize)]
37struct ContentBlock {
38    text: Option<String>,
39}
40
41#[derive(Debug, Deserialize)]
42struct ApiError {
43    message: String,
44}
45
46#[derive(Debug, Deserialize)]
47struct TokenResponse {
48    access_token: String,
49    expires_in: u64,
50}
51
/// An access token held in memory together with the instant after which it
/// must no longer be served from the cache.
struct CachedToken {
    /// The bearer token itself.
    token: String,
    /// The cache entry is valid only while `Instant::now() < expires_at`.
    expires_at: Instant,
}
57
58pub struct ClaudeClient {
59    client: Client,
60    refresh_token: String,
61    cached_token: Mutex<Option<CachedToken>>,
62}
63
64impl ClaudeClient {
65    /// Create a new Claude client with OAuth refresh token
66    pub fn new(refresh_token: String) -> Self {
67        Self {
68            client: Client::new(),
69            refresh_token,
70            cached_token: Mutex::new(None),
71        }
72    }
73
74    /// Get access token, using cache if valid
75    async fn get_access_token(&self) -> Result<String> {
76        // Check cache first
77        if let Ok(guard) = self.cached_token.lock() {
78            if let Some(ref cached) = *guard {
79                if Instant::now() < cached.expires_at {
80                    return Ok(cached.token.clone());
81                }
82            }
83        }
84
85        // Refresh token
86        let params = serde_json::json!({
87            "grant_type": "refresh_token",
88            "refresh_token": self.refresh_token,
89            "client_id": CLIENT_ID,
90        });
91
92        let res = self
93            .client
94            .post("https://console.anthropic.com/v1/oauth/token")
95            .header("Content-Type", "application/json")
96            .json(&params)
97            .send()
98            .await
99            .map_err(|e| Error::DaemonError {
100                message: format!("Token refresh failed: {}", e),
101            })?;
102
103        if !res.status().is_success() {
104            let text = res.text().await.unwrap_or_default();
105            return Err(Error::DaemonError {
106                message: format!("Token refresh error: {}", text),
107            });
108        }
109
110        let token_response: TokenResponse = res.json().await.map_err(|e| Error::DaemonError {
111            message: format!("Failed to parse token response: {}", e),
112        })?;
113
114        // Cache the token
115        let expires_at = Instant::now()
116            + Duration::from_secs(token_response.expires_in.saturating_sub(600).max(60));
117        if let Ok(mut guard) = self.cached_token.lock() {
118            *guard = Some(CachedToken {
119                token: token_response.access_token.clone(),
120                expires_at,
121            });
122        }
123
124        Ok(token_response.access_token)
125    }
126
127    /// Rerank search results by relevance to query
128    /// Returns JSON array of indices in order of relevance: [2, 0, 5, 1, ...]
129    pub async fn rerank(&self, query: &str, chunks: &[String]) -> Result<Vec<usize>> {
130        let access_token = self.get_access_token().await?;
131
132        let system_prompt =
133            "You are a code search reranker. Given a query and numbered code chunks, \
134            return ONLY a JSON array of chunk indices ordered by relevance to the query. \
135            Most relevant first. Example response: [2, 0, 5, 1, 3, 4]";
136
137        let mut user_prompt = format!("Query: {}\n\nCode chunks:\n", query);
138        for (i, chunk) in chunks.iter().enumerate() {
139            user_prompt.push_str(&format!("\n--- Chunk {} ---\n{}\n", i, chunk));
140        }
141        user_prompt.push_str("\nReturn ONLY the JSON array of indices, nothing else.");
142
143        let request_body = MessageRequest {
144            model: "claude-3-5-haiku-latest".to_string(),
145            max_tokens: 256,
146            system: system_prompt.to_string(),
147            messages: vec![Message {
148                role: "user".to_string(),
149                content: user_prompt,
150            }],
151        };
152
153        let res = self
154            .client
155            .post(ANTHROPIC_API_URL)
156            .query(&[("beta", "true")])
157            .header("Authorization", format!("Bearer {}", access_token))
158            .header("anthropic-version", ANTHROPIC_VERSION)
159            .header("anthropic-beta", "oauth-2025-04-20")
160            .header("User-Agent", "greppy/0.9.0")
161            .header("Content-Type", "application/json")
162            .json(&request_body)
163            .send()
164            .await
165            .map_err(|e| Error::DaemonError {
166                message: format!("API request failed: {}", e),
167            })?;
168
169        if !res.status().is_success() {
170            let text = res.text().await.unwrap_or_default();
171            return Err(Error::DaemonError {
172                message: format!("Claude API Error: {}", text),
173            });
174        }
175
176        let response: MessageResponse = res.json().await.map_err(|e| Error::DaemonError {
177            message: format!("Failed to parse response: {}", e),
178        })?;
179
180        if let Some(error) = response.error {
181            return Err(Error::DaemonError {
182                message: format!("Claude API Error: {}", error.message),
183            });
184        }
185
186        // Parse the JSON array from response
187        if let Some(content) = response.content {
188            if let Some(block) = content.first() {
189                if let Some(text) = &block.text {
190                    // Extract JSON array from response
191                    let text = text.trim();
192                    if let Ok(indices) = serde_json::from_str::<Vec<usize>>(text) {
193                        return Ok(indices);
194                    }
195                    // Try to find JSON array in the text
196                    if let Some(start) = text.find('[') {
197                        if let Some(end) = text.rfind(']') {
198                            let json_str = &text[start..=end];
199                            if let Ok(indices) = serde_json::from_str::<Vec<usize>>(json_str) {
200                                return Ok(indices);
201                            }
202                        }
203                    }
204                }
205            }
206        }
207
208        // Fallback: return original order
209        Ok((0..chunks.len()).collect())
210    }
211
212    /// Expand a query into related symbol names for trace operations
213    /// Input: "auth" -> Output: ["auth", "login", "authenticate", "session", ...]
214    pub async fn expand_query(&self, query: &str) -> Result<Vec<String>> {
215        use crate::ai::trace_prompts::{
216            build_expansion_prompt, parse_expansion_response, QUERY_EXPANSION_SYSTEM,
217        };
218
219        let access_token = self.get_access_token().await?;
220
221        let request_body = MessageRequest {
222            model: "claude-3-5-haiku-latest".to_string(),
223            max_tokens: 256,
224            system: QUERY_EXPANSION_SYSTEM.to_string(),
225            messages: vec![Message {
226                role: "user".to_string(),
227                content: build_expansion_prompt(query),
228            }],
229        };
230
231        let res = self
232            .client
233            .post(ANTHROPIC_API_URL)
234            .query(&[("beta", "true")])
235            .header("Authorization", format!("Bearer {}", access_token))
236            .header("anthropic-version", ANTHROPIC_VERSION)
237            .header("anthropic-beta", "oauth-2025-04-20")
238            .header("User-Agent", "greppy/0.9.0")
239            .header("Content-Type", "application/json")
240            .json(&request_body)
241            .send()
242            .await
243            .map_err(|e| Error::DaemonError {
244                message: format!("API request failed: {}", e),
245            })?;
246
247        if !res.status().is_success() {
248            let text = res.text().await.unwrap_or_default();
249            return Err(Error::DaemonError {
250                message: format!("Claude API Error: {}", text),
251            });
252        }
253
254        let response: MessageResponse = res.json().await.map_err(|e| Error::DaemonError {
255            message: format!("Failed to parse response: {}", e),
256        })?;
257
258        if let Some(error) = response.error {
259            return Err(Error::DaemonError {
260                message: format!("Claude API Error: {}", error.message),
261            });
262        }
263
264        // Parse the expanded symbols from response
265        if let Some(content) = response.content {
266            if let Some(block) = content.first() {
267                if let Some(text) = &block.text {
268                    let symbols = parse_expansion_response(text);
269                    if !symbols.is_empty() {
270                        return Ok(symbols);
271                    }
272                }
273            }
274        }
275
276        // Fallback: return the original query as a single symbol
277        Ok(vec![query.to_string()])
278    }
279
280    /// Rerank trace invocation paths by relevance to query
281    /// Returns indices in order of relevance: [2, 0, 5, 1, ...]
282    pub async fn rerank_trace(&self, query: &str, paths: &[String]) -> Result<Vec<usize>> {
283        use crate::ai::trace_prompts::{
284            build_trace_rerank_prompt, parse_rerank_response, TRACE_RERANK_SYSTEM,
285        };
286
287        let access_token = self.get_access_token().await?;
288
289        let request_body = MessageRequest {
290            model: "claude-3-5-haiku-latest".to_string(),
291            max_tokens: 256,
292            system: TRACE_RERANK_SYSTEM.to_string(),
293            messages: vec![Message {
294                role: "user".to_string(),
295                content: build_trace_rerank_prompt(query, paths),
296            }],
297        };
298
299        let res = self
300            .client
301            .post(ANTHROPIC_API_URL)
302            .query(&[("beta", "true")])
303            .header("Authorization", format!("Bearer {}", access_token))
304            .header("anthropic-version", ANTHROPIC_VERSION)
305            .header("anthropic-beta", "oauth-2025-04-20")
306            .header("User-Agent", "greppy/0.9.0")
307            .header("Content-Type", "application/json")
308            .json(&request_body)
309            .send()
310            .await
311            .map_err(|e| Error::DaemonError {
312                message: format!("API request failed: {}", e),
313            })?;
314
315        if !res.status().is_success() {
316            let text = res.text().await.unwrap_or_default();
317            return Err(Error::DaemonError {
318                message: format!("Claude API Error: {}", text),
319            });
320        }
321
322        let response: MessageResponse = res.json().await.map_err(|e| Error::DaemonError {
323            message: format!("Failed to parse response: {}", e),
324        })?;
325
326        if let Some(error) = response.error {
327            return Err(Error::DaemonError {
328                message: format!("Claude API Error: {}", error.message),
329            });
330        }
331
332        // Parse the reranked indices from response
333        if let Some(content) = response.content {
334            if let Some(block) = content.first() {
335                if let Some(text) = &block.text {
336                    let indices = parse_rerank_response(text, paths.len());
337                    if !indices.is_empty() {
338                        return Ok(indices);
339                    }
340                }
341            }
342        }
343
344        // Fallback: return original order
345        Ok((0..paths.len()).collect())
346    }
347}