1use super::{ProviderConfig, SourceItem, SourceProvider};
7use crate::db::hash_content;
8use crate::error::{AgentRootError, Result};
9use crate::index::extract_title;
10use base64::Engine;
11use serde::Deserialize;
12
13pub struct GitHubProvider {
15 client: reqwest::Client,
16}
17
/// Upper bound on how many times `send_with_retry` re-sends a request after
/// a 429 response or a timeout.
const MAX_RETRIES: u32 = 3;

/// Delay, in milliseconds, used before the first retry; `send_with_retry`
/// doubles it after every retry.
const INITIAL_BACKOFF_MS: u64 = 1_000;
20
21impl GitHubProvider {
22 pub fn new() -> Self {
24 let client = reqwest::Client::builder()
25 .user_agent("agentroot/1.0")
26 .build()
27 .unwrap_or_else(|_| reqwest::Client::new());
28
29 Self { client }
30 }
31
32 fn parse_github_url(&self, url: &str) -> Result<GitHubUrl> {
34 let url = url.trim();
35
36 if url.starts_with("https://github.com/") || url.starts_with("http://github.com/") {
37 let parts: Vec<&str> = url
38 .trim_start_matches("https://github.com/")
39 .trim_start_matches("http://github.com/")
40 .split('/')
41 .collect();
42
43 if parts.len() >= 2 {
44 let owner = parts[0].to_string();
45 let repo = parts[1].to_string();
46
47 if parts.len() == 2 {
48 return Ok(GitHubUrl::Repository { owner, repo });
49 }
50
51 if parts.len() >= 5 && parts[2] == "blob" {
52 let branch = parts[3].to_string();
53 let path = parts[4..].join("/");
54 return Ok(GitHubUrl::File {
55 owner,
56 repo,
57 branch,
58 path,
59 });
60 }
61 }
62 }
63
64 Err(AgentRootError::InvalidInput(format!(
65 "Invalid GitHub URL: {}. \
66 Expected format: https://github.com/owner/repo or https://github.com/owner/repo/blob/branch/path",
67 url
68 )))
69 }
70
71 fn get_token(&self, config: &ProviderConfig) -> Option<String> {
73 config
74 .get_option("github_token")
75 .cloned()
76 .or_else(|| std::env::var("GITHUB_TOKEN").ok())
77 }
78
79 fn check_rate_limit(&self, response: &reqwest::Response) {
81 if let Some(remaining) = response.headers().get("x-ratelimit-remaining") {
82 if let Ok(remaining_str) = remaining.to_str() {
83 if let Ok(remaining_count) = remaining_str.parse::<i32>() {
84 if remaining_count < 10 {
85 eprintln!(
86 "Warning: GitHub API rate limit low ({} requests remaining). \
87 Set GITHUB_TOKEN to increase limits.",
88 remaining_count
89 );
90 }
91 }
92 }
93 }
94 }
95
96 async fn send_with_retry(&self, request: reqwest::RequestBuilder) -> Result<reqwest::Response> {
98 let mut retries = 0;
99 let mut backoff_ms = INITIAL_BACKOFF_MS;
100
101 loop {
102 let req = request.try_clone().ok_or_else(|| {
103 AgentRootError::ExternalError("Failed to clone request".to_string())
104 })?;
105
106 match req.send().await {
107 Ok(response) => {
108 self.check_rate_limit(&response);
109
110 if response.status() == 429 && retries < MAX_RETRIES {
111 let retry_after = response
112 .headers()
113 .get("retry-after")
114 .and_then(|v| v.to_str().ok())
115 .and_then(|v| v.parse::<u64>().ok())
116 .unwrap_or(backoff_ms / 1000);
117
118 eprintln!(
119 "Rate limit exceeded. Retrying after {} seconds (attempt {}/{})",
120 retry_after,
121 retries + 1,
122 MAX_RETRIES
123 );
124
125 tokio::time::sleep(tokio::time::Duration::from_secs(retry_after)).await;
126 retries += 1;
127 backoff_ms *= 2;
128 continue;
129 }
130
131 return Ok(response);
132 }
133 Err(e) if retries < MAX_RETRIES && e.is_timeout() => {
134 eprintln!(
135 "Request timeout. Retrying in {} seconds (attempt {}/{})",
136 backoff_ms / 1000,
137 retries + 1,
138 MAX_RETRIES
139 );
140 tokio::time::sleep(tokio::time::Duration::from_millis(backoff_ms)).await;
141 retries += 1;
142 backoff_ms *= 2;
143 }
144 Err(e) => return Err(e.into()),
145 }
146 }
147 }
148
149 async fn fetch_file(
151 &self,
152 owner: &str,
153 repo: &str,
154 branch: &str,
155 path: &str,
156 token: Option<&str>,
157 ) -> Result<String> {
158 let raw_url = format!(
159 "https://raw.githubusercontent.com/{}/{}/{}/{}",
160 owner, repo, branch, path
161 );
162
163 let mut request = self.client.get(&raw_url);
164
165 if let Some(token) = token {
166 request = request.header("Authorization", format!("token {}", token));
167 }
168
169 let response = self.send_with_retry(request).await.map_err(|e| {
170 AgentRootError::ExternalError(format!(
171 "Failed to fetch file from GitHub: {}. Check your internet connection.",
172 e
173 ))
174 })?;
175
176 let status = response.status();
177 if !status.is_success() {
178 let error_msg = match status.as_u16() {
179 404 => format!(
180 "File not found: {}/{}/{}/{}. Verify the repository, branch, and file path are correct.",
181 owner, repo, branch, path
182 ),
183 403 => {
184 "GitHub API rate limit exceeded or access forbidden. \
185 Set GITHUB_TOKEN environment variable with a personal access token to increase rate limits. \
186 Get token from: https://github.com/settings/tokens".to_string()
187 }
188 401 => {
189 "Authentication failed. Your GITHUB_TOKEN may be invalid or expired. \
190 Generate a new token at: https://github.com/settings/tokens".to_string()
191 }
192 _ => format!("GitHub API error {}: {}", status.as_u16(), status.canonical_reason().unwrap_or("Unknown error")),
193 };
194 return Err(AgentRootError::ExternalError(error_msg));
195 }
196
197 response.text().await.map_err(|e| {
198 AgentRootError::ExternalError(format!("Failed to read file content: {}", e))
199 })
200 }
201
202 async fn fetch_readme(
204 &self,
205 owner: &str,
206 repo: &str,
207 token: Option<&str>,
208 ) -> Result<(String, String)> {
209 let api_url = format!("https://api.github.com/repos/{}/{}/readme", owner, repo);
210
211 let mut request = self.client.get(&api_url);
212
213 if let Some(token) = token {
214 request = request.header("Authorization", format!("token {}", token));
215 }
216
217 request = request.header("Accept", "application/vnd.github.v3+json");
218
219 let response = self.send_with_retry(request).await.map_err(|e| {
220 AgentRootError::ExternalError(format!(
221 "Failed to fetch README from GitHub: {}. Check your internet connection.",
222 e
223 ))
224 })?;
225
226 let status = response.status();
227 if !status.is_success() {
228 let error_msg = match status.as_u16() {
229 404 => format!(
230 "README not found for repository {}/{}. The repository may not have a README file, or it may not exist.",
231 owner, repo
232 ),
233 403 => {
234 "GitHub API rate limit exceeded or repository access forbidden. \
235 For public repositories, set GITHUB_TOKEN environment variable to increase rate limits. \
236 For private repositories, ensure your token has 'repo' scope. \
237 Get token from: https://github.com/settings/tokens".to_string()
238 }
239 401 => {
240 "Authentication failed. Your GITHUB_TOKEN may be invalid or expired. \
241 Generate a new token at: https://github.com/settings/tokens".to_string()
242 }
243 _ => format!("GitHub API error {}: {}", status.as_u16(), status.canonical_reason().unwrap_or("Unknown error")),
244 };
245 return Err(AgentRootError::ExternalError(error_msg));
246 }
247
248 let readme: ReadmeResponse = response.json().await.map_err(|e| {
249 AgentRootError::ExternalError(format!("Failed to parse README response: {}", e))
250 })?;
251 let content = String::from_utf8(
252 base64::engine::general_purpose::STANDARD
253 .decode(readme.content.replace('\n', ""))
254 .map_err(|e| {
255 AgentRootError::ExternalError(format!("Base64 decode error: {}", e))
256 })?,
257 )
258 .map_err(|e| AgentRootError::ExternalError(format!("UTF-8 decode error: {}", e)))?;
259
260 Ok((readme.name, content))
261 }
262
263 async fn list_repo_files(
265 &self,
266 owner: &str,
267 repo: &str,
268 token: Option<&str>,
269 ) -> Result<Vec<RepoFile>> {
270 let api_url = format!(
271 "https://api.github.com/repos/{}/{}/git/trees/HEAD?recursive=1",
272 owner, repo
273 );
274
275 let mut request = self.client.get(&api_url);
276
277 if let Some(token) = token {
278 request = request.header("Authorization", format!("token {}", token));
279 }
280
281 request = request.header("Accept", "application/vnd.github.v3+json");
282
283 let response = self.send_with_retry(request).await.map_err(|e| {
284 AgentRootError::ExternalError(format!(
285 "Failed to list files from GitHub repository: {}. Check your internet connection.",
286 e
287 ))
288 })?;
289
290 let status = response.status();
291 if !status.is_success() {
292 let error_msg = match status.as_u16() {
293 404 => format!(
294 "Repository not found: {}/{}. Verify the repository owner and name are correct.",
295 owner, repo
296 ),
297 403 => {
298 "GitHub API rate limit exceeded or repository access forbidden. \
299 For public repositories, set GITHUB_TOKEN environment variable to increase rate limits. \
300 For private repositories, ensure your token has 'repo' scope. \
301 Get token from: https://github.com/settings/tokens".to_string()
302 }
303 401 => {
304 "Authentication failed. Your GITHUB_TOKEN may be invalid or expired. \
305 Generate a new token at: https://github.com/settings/tokens".to_string()
306 }
307 409 => format!(
308 "Repository {}/{} is empty or has no commits yet.",
309 owner, repo
310 ),
311 _ => format!("GitHub API error {}: {}", status.as_u16(), status.canonical_reason().unwrap_or("Unknown error")),
312 };
313 return Err(AgentRootError::ExternalError(error_msg));
314 }
315
316 let tree: TreeResponse = response.json().await.map_err(|e| {
317 AgentRootError::ExternalError(format!("Failed to parse repository file tree: {}", e))
318 })?;
319 Ok(tree.tree)
320 }
321}
322
323impl Default for GitHubProvider {
324 fn default() -> Self {
325 Self::new()
326 }
327}
328
329#[async_trait::async_trait]
330impl SourceProvider for GitHubProvider {
331 fn provider_type(&self) -> &'static str {
332 "github"
333 }
334
335 async fn list_items(&self, config: &ProviderConfig) -> Result<Vec<SourceItem>> {
336 let github_url = self.parse_github_url(&config.base_path)?;
337 let token = self.get_token(config);
338
339 match github_url {
340 GitHubUrl::Repository { owner, repo } => {
341 let files = self
342 .list_repo_files(&owner, &repo, token.as_deref())
343 .await?;
344 let pattern = glob::Pattern::new(&config.pattern)?;
345
346 let mut items = Vec::new();
347
348 for file in files {
349 if file.file_type == "blob" && pattern.matches(&file.path) {
350 let url = format!(
351 "https://github.com/{}/{}/blob/HEAD/{}",
352 owner, repo, file.path
353 );
354 match self.fetch_item(&url).await {
355 Ok(item) => items.push(item),
356 Err(_) => continue,
357 }
358 }
359 }
360
361 Ok(items)
362 }
363 GitHubUrl::File { .. } => {
364 let item = self.fetch_item(&config.base_path).await?;
365 Ok(vec![item])
366 }
367 }
368 }
369
370 async fn fetch_item(&self, uri: &str) -> Result<SourceItem> {
371 let github_url = self.parse_github_url(uri)?;
372 let token = std::env::var("GITHUB_TOKEN").ok();
373
374 match github_url {
375 GitHubUrl::Repository { owner, repo } => {
376 let (filename, content) =
377 self.fetch_readme(&owner, &repo, token.as_deref()).await?;
378 let title = extract_title(&content, &filename);
379 let hash = hash_content(&content);
380 let uri = format!("{}/{}/{}", owner, repo, filename);
381
382 Ok(
383 SourceItem::new(uri, title, content, hash, "github".to_string())
384 .with_metadata("owner".to_string(), owner)
385 .with_metadata("repo".to_string(), repo)
386 .with_metadata("file".to_string(), filename),
387 )
388 }
389 GitHubUrl::File {
390 owner,
391 repo,
392 branch,
393 path,
394 } => {
395 let content = self
396 .fetch_file(&owner, &repo, &branch, &path, token.as_deref())
397 .await?;
398 let title = extract_title(&content, &path);
399 let hash = hash_content(&content);
400 let uri = format!("{}/{}/{}", owner, repo, path);
401
402 Ok(
403 SourceItem::new(uri, title, content, hash, "github".to_string())
404 .with_metadata("owner".to_string(), owner)
405 .with_metadata("repo".to_string(), repo)
406 .with_metadata("branch".to_string(), branch)
407 .with_metadata("path".to_string(), path),
408 )
409 }
410 }
411 }
412}
413
/// The two kinds of GitHub URL this module distinguishes.
#[derive(Clone, Debug)]
enum GitHubUrl {
    /// A whole repository, e.g. `https://github.com/owner/repo`.
    Repository { owner: String, repo: String },
    /// One file on a branch, e.g.
    /// `https://github.com/owner/repo/blob/branch/path`.
    File {
        owner: String,
        repo: String,
        branch: String,
        path: String,
    },
}
428
429#[derive(Debug, Deserialize)]
431struct ReadmeResponse {
432 name: String,
433 content: String,
434}
435
436#[derive(Debug, Deserialize)]
438struct TreeResponse {
439 tree: Vec<RepoFile>,
440}
441
442#[derive(Debug, Deserialize)]
444struct RepoFile {
445 path: String,
446 #[serde(rename = "type")]
447 file_type: String,
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `url` with a freshly constructed provider.
    fn parse(url: &str) -> Result<GitHubUrl> {
        GitHubProvider::new().parse_github_url(url)
    }

    /// Builds a sample repository config carrying `token` as its
    /// `github_token` option.
    fn repo_config_with_token(token: &str) -> ProviderConfig {
        ProviderConfig::new(
            "https://github.com/user/repo".to_string(),
            "*.md".to_string(),
        )
        .with_option("github_token".to_string(), token.to_string())
    }

    #[test]
    fn test_github_provider_type() {
        assert_eq!(GitHubProvider::new().provider_type(), "github");
    }

    #[test]
    fn test_parse_github_repo_url() {
        let parsed = parse("https://github.com/rust-lang/rust").unwrap();

        if let GitHubUrl::Repository { owner, repo } = parsed {
            assert_eq!(owner, "rust-lang");
            assert_eq!(repo, "rust");
        } else {
            panic!("Expected Repository variant");
        }
    }

    #[test]
    fn test_parse_github_file_url() {
        let parsed = parse("https://github.com/rust-lang/rust/blob/master/README.md").unwrap();

        if let GitHubUrl::File {
            owner,
            repo,
            branch,
            path,
        } = parsed
        {
            assert_eq!(owner, "rust-lang");
            assert_eq!(repo, "rust");
            assert_eq!(branch, "master");
            assert_eq!(path, "README.md");
        } else {
            panic!("Expected File variant");
        }
    }

    #[test]
    fn test_parse_invalid_url() {
        assert!(parse("https://example.com/not-github").is_err());
    }

    #[test]
    fn test_parse_github_url_variants() {
        let provider = GitHubProvider::new();

        // (input URL, whether parsing is expected to succeed)
        let cases: [(&str, bool); 8] = [
            ("https://github.com/rust-lang/rust", true),
            ("http://github.com/rust-lang/rust", true),
            ("https://github.com/user/repo/blob/main/README.md", true),
            (
                "https://github.com/user/repo/blob/feature-branch/src/main.rs",
                true,
            ),
            ("https://gitlab.com/user/repo", false),
            ("github.com/user/repo", false),
            ("https://github.com/", false),
            ("https://github.com/user", false),
        ];

        for (url, should_succeed) in cases.iter().copied() {
            let result = provider.parse_github_url(url);
            assert_eq!(
                result.is_ok(),
                should_succeed,
                "URL: {} - Expected success: {}, Got: {:?}",
                url,
                should_succeed,
                result
            );
        }
    }

    #[test]
    fn test_parse_github_file_url_components() {
        let result = parse("https://github.com/rust-lang/rust/blob/master/src/main.rs").unwrap();

        if let GitHubUrl::File {
            owner,
            repo,
            branch,
            path,
        } = result
        {
            assert_eq!(owner, "rust-lang");
            assert_eq!(repo, "rust");
            assert_eq!(branch, "master");
            assert_eq!(path, "src/main.rs");
        } else {
            panic!("Expected File variant");
        }
    }

    #[test]
    fn test_parse_github_file_url_nested_path() {
        let result =
            parse("https://github.com/owner/repo/blob/main/deep/nested/path/file.md").unwrap();

        if let GitHubUrl::File { path, .. } = result {
            assert_eq!(path, "deep/nested/path/file.md");
        } else {
            panic!("Expected File variant");
        }
    }

    #[test]
    fn test_get_token_from_config() {
        let provider = GitHubProvider::new();
        let config = repo_config_with_token("ghp_test123");

        assert_eq!(
            provider.get_token(&config),
            Some("ghp_test123".to_string())
        );
    }

    #[test]
    fn test_get_token_priority() {
        let provider = GitHubProvider::new();
        let config_with_token = repo_config_with_token("ghp_config");

        assert_eq!(
            provider.get_token(&config_with_token),
            Some("ghp_config".to_string())
        );
    }

    #[test]
    fn test_provider_type() {
        let provider = GitHubProvider::new();
        assert_eq!(provider.provider_type(), "github");
    }

    #[test]
    fn test_parse_github_url_edge_cases() {
        let provider = GitHubProvider::new();

        // Owner/repository names containing dashes, underscores and dots.
        let edge_cases = [
            "https://github.com/user/repo-with-dashes",
            "https://github.com/user/repo_with_underscores",
            "https://github.com/user/repo.with.dots",
            "https://github.com/user-with-dash/repo",
            "https://github.com/user_with_underscore/repo",
        ];

        for url in edge_cases.iter().copied() {
            let result = provider.parse_github_url(url);
            assert!(result.is_ok(), "Failed to parse valid URL: {}", url);
        }
    }
}