agentroot_core/providers/
github.rs

1//! GitHub provider
2//!
3//! Provides content from GitHub repositories, files, and gists.
4//! Supports both public and private repositories with authentication.
5
6use super::{ProviderConfig, SourceItem, SourceProvider};
7use crate::db::hash_content;
8use crate::error::{AgentRootError, Result};
9use crate::index::extract_title;
10use base64::Engine;
11use serde::Deserialize;
12
13/// GitHub provider
14pub struct GitHubProvider {
15    client: reqwest::Client,
16}
17
/// Maximum number of retries `send_with_retry` performs before returning
/// whatever the last attempt produced.
const MAX_RETRIES: u32 = 3;
/// Starting retry delay in milliseconds; `send_with_retry` doubles it after
/// every retry.
const INITIAL_BACKOFF_MS: u64 = 1000;
20
21impl GitHubProvider {
22    /// Create new GitHub provider
23    pub fn new() -> Self {
24        let client = reqwest::Client::builder()
25            .user_agent("agentroot/1.0")
26            .build()
27            .unwrap_or_else(|_| reqwest::Client::new());
28
29        Self { client }
30    }
31
32    /// Parse GitHub URL into components
33    fn parse_github_url(&self, url: &str) -> Result<GitHubUrl> {
34        let url = url.trim();
35
36        if url.starts_with("https://github.com/") || url.starts_with("http://github.com/") {
37            let parts: Vec<&str> = url
38                .trim_start_matches("https://github.com/")
39                .trim_start_matches("http://github.com/")
40                .split('/')
41                .collect();
42
43            if parts.len() >= 2 {
44                let owner = parts[0].to_string();
45                let repo = parts[1].to_string();
46
47                if parts.len() == 2 {
48                    return Ok(GitHubUrl::Repository { owner, repo });
49                }
50
51                if parts.len() >= 5 && parts[2] == "blob" {
52                    let branch = parts[3].to_string();
53                    let path = parts[4..].join("/");
54                    return Ok(GitHubUrl::File {
55                        owner,
56                        repo,
57                        branch,
58                        path,
59                    });
60                }
61            }
62        }
63
64        Err(AgentRootError::InvalidInput(format!(
65            "Invalid GitHub URL: {}. \
66             Expected format: https://github.com/owner/repo or https://github.com/owner/repo/blob/branch/path",
67            url
68        )))
69    }
70
71    /// Get GitHub API token from environment
72    fn get_token(&self, config: &ProviderConfig) -> Option<String> {
73        config
74            .get_option("github_token")
75            .cloned()
76            .or_else(|| std::env::var("GITHUB_TOKEN").ok())
77    }
78
79    /// Check rate limit from response headers and log warnings
80    fn check_rate_limit(&self, response: &reqwest::Response) {
81        if let Some(remaining) = response.headers().get("x-ratelimit-remaining") {
82            if let Ok(remaining_str) = remaining.to_str() {
83                if let Ok(remaining_count) = remaining_str.parse::<i32>() {
84                    if remaining_count < 10 {
85                        eprintln!(
86                            "Warning: GitHub API rate limit low ({} requests remaining). \
87                             Set GITHUB_TOKEN to increase limits.",
88                            remaining_count
89                        );
90                    }
91                }
92            }
93        }
94    }
95
96    /// Send request with retry logic for rate limits
97    async fn send_with_retry(&self, request: reqwest::RequestBuilder) -> Result<reqwest::Response> {
98        let mut retries = 0;
99        let mut backoff_ms = INITIAL_BACKOFF_MS;
100
101        loop {
102            let req = request.try_clone().ok_or_else(|| {
103                AgentRootError::ExternalError("Failed to clone request".to_string())
104            })?;
105
106            match req.send().await {
107                Ok(response) => {
108                    self.check_rate_limit(&response);
109
110                    if response.status() == 429 && retries < MAX_RETRIES {
111                        let retry_after = response
112                            .headers()
113                            .get("retry-after")
114                            .and_then(|v| v.to_str().ok())
115                            .and_then(|v| v.parse::<u64>().ok())
116                            .unwrap_or(backoff_ms / 1000);
117
118                        eprintln!(
119                            "Rate limit exceeded. Retrying after {} seconds (attempt {}/{})",
120                            retry_after,
121                            retries + 1,
122                            MAX_RETRIES
123                        );
124
125                        tokio::time::sleep(tokio::time::Duration::from_secs(retry_after)).await;
126                        retries += 1;
127                        backoff_ms *= 2;
128                        continue;
129                    }
130
131                    return Ok(response);
132                }
133                Err(e) if retries < MAX_RETRIES && e.is_timeout() => {
134                    eprintln!(
135                        "Request timeout. Retrying in {} seconds (attempt {}/{})",
136                        backoff_ms / 1000,
137                        retries + 1,
138                        MAX_RETRIES
139                    );
140                    tokio::time::sleep(tokio::time::Duration::from_millis(backoff_ms)).await;
141                    retries += 1;
142                    backoff_ms *= 2;
143                }
144                Err(e) => return Err(e.into()),
145            }
146        }
147    }
148
149    /// Fetch file from GitHub
150    async fn fetch_file(
151        &self,
152        owner: &str,
153        repo: &str,
154        branch: &str,
155        path: &str,
156        token: Option<&str>,
157    ) -> Result<String> {
158        let raw_url = format!(
159            "https://raw.githubusercontent.com/{}/{}/{}/{}",
160            owner, repo, branch, path
161        );
162
163        let mut request = self.client.get(&raw_url);
164
165        if let Some(token) = token {
166            request = request.header("Authorization", format!("token {}", token));
167        }
168
169        let response = self.send_with_retry(request).await.map_err(|e| {
170            AgentRootError::ExternalError(format!(
171                "Failed to fetch file from GitHub: {}. Check your internet connection.",
172                e
173            ))
174        })?;
175
176        let status = response.status();
177        if !status.is_success() {
178            let error_msg = match status.as_u16() {
179                404 => format!(
180                    "File not found: {}/{}/{}/{}. Verify the repository, branch, and file path are correct.",
181                    owner, repo, branch, path
182                ),
183                403 => {
184                    "GitHub API rate limit exceeded or access forbidden. \
185                     Set GITHUB_TOKEN environment variable with a personal access token to increase rate limits. \
186                     Get token from: https://github.com/settings/tokens".to_string()
187                }
188                401 => {
189                    "Authentication failed. Your GITHUB_TOKEN may be invalid or expired. \
190                     Generate a new token at: https://github.com/settings/tokens".to_string()
191                }
192                _ => format!("GitHub API error {}: {}", status.as_u16(), status.canonical_reason().unwrap_or("Unknown error")),
193            };
194            return Err(AgentRootError::ExternalError(error_msg));
195        }
196
197        response.text().await.map_err(|e| {
198            AgentRootError::ExternalError(format!("Failed to read file content: {}", e))
199        })
200    }
201
202    /// Fetch README from repository
203    async fn fetch_readme(
204        &self,
205        owner: &str,
206        repo: &str,
207        token: Option<&str>,
208    ) -> Result<(String, String)> {
209        let api_url = format!("https://api.github.com/repos/{}/{}/readme", owner, repo);
210
211        let mut request = self.client.get(&api_url);
212
213        if let Some(token) = token {
214            request = request.header("Authorization", format!("token {}", token));
215        }
216
217        request = request.header("Accept", "application/vnd.github.v3+json");
218
219        let response = self.send_with_retry(request).await.map_err(|e| {
220            AgentRootError::ExternalError(format!(
221                "Failed to fetch README from GitHub: {}. Check your internet connection.",
222                e
223            ))
224        })?;
225
226        let status = response.status();
227        if !status.is_success() {
228            let error_msg = match status.as_u16() {
229                404 => format!(
230                    "README not found for repository {}/{}. The repository may not have a README file, or it may not exist.",
231                    owner, repo
232                ),
233                403 => {
234                    "GitHub API rate limit exceeded or repository access forbidden. \
235                     For public repositories, set GITHUB_TOKEN environment variable to increase rate limits. \
236                     For private repositories, ensure your token has 'repo' scope. \
237                     Get token from: https://github.com/settings/tokens".to_string()
238                }
239                401 => {
240                    "Authentication failed. Your GITHUB_TOKEN may be invalid or expired. \
241                     Generate a new token at: https://github.com/settings/tokens".to_string()
242                }
243                _ => format!("GitHub API error {}: {}", status.as_u16(), status.canonical_reason().unwrap_or("Unknown error")),
244            };
245            return Err(AgentRootError::ExternalError(error_msg));
246        }
247
248        let readme: ReadmeResponse = response.json().await.map_err(|e| {
249            AgentRootError::ExternalError(format!("Failed to parse README response: {}", e))
250        })?;
251        let content = String::from_utf8(
252            base64::engine::general_purpose::STANDARD
253                .decode(readme.content.replace('\n', ""))
254                .map_err(|e| {
255                    AgentRootError::ExternalError(format!("Base64 decode error: {}", e))
256                })?,
257        )
258        .map_err(|e| AgentRootError::ExternalError(format!("UTF-8 decode error: {}", e)))?;
259
260        Ok((readme.name, content))
261    }
262
263    /// List files in repository
264    async fn list_repo_files(
265        &self,
266        owner: &str,
267        repo: &str,
268        token: Option<&str>,
269    ) -> Result<Vec<RepoFile>> {
270        let api_url = format!(
271            "https://api.github.com/repos/{}/{}/git/trees/HEAD?recursive=1",
272            owner, repo
273        );
274
275        let mut request = self.client.get(&api_url);
276
277        if let Some(token) = token {
278            request = request.header("Authorization", format!("token {}", token));
279        }
280
281        request = request.header("Accept", "application/vnd.github.v3+json");
282
283        let response = self.send_with_retry(request).await.map_err(|e| {
284            AgentRootError::ExternalError(format!(
285                "Failed to list files from GitHub repository: {}. Check your internet connection.",
286                e
287            ))
288        })?;
289
290        let status = response.status();
291        if !status.is_success() {
292            let error_msg = match status.as_u16() {
293                404 => format!(
294                    "Repository not found: {}/{}. Verify the repository owner and name are correct.",
295                    owner, repo
296                ),
297                403 => {
298                    "GitHub API rate limit exceeded or repository access forbidden. \
299                     For public repositories, set GITHUB_TOKEN environment variable to increase rate limits. \
300                     For private repositories, ensure your token has 'repo' scope. \
301                     Get token from: https://github.com/settings/tokens".to_string()
302                }
303                401 => {
304                    "Authentication failed. Your GITHUB_TOKEN may be invalid or expired. \
305                     Generate a new token at: https://github.com/settings/tokens".to_string()
306                }
307                409 => format!(
308                    "Repository {}/{} is empty or has no commits yet.",
309                    owner, repo
310                ),
311                _ => format!("GitHub API error {}: {}", status.as_u16(), status.canonical_reason().unwrap_or("Unknown error")),
312            };
313            return Err(AgentRootError::ExternalError(error_msg));
314        }
315
316        let tree: TreeResponse = response.json().await.map_err(|e| {
317            AgentRootError::ExternalError(format!("Failed to parse repository file tree: {}", e))
318        })?;
319        Ok(tree.tree)
320    }
321}
322
323impl Default for GitHubProvider {
324    fn default() -> Self {
325        Self::new()
326    }
327}
328
329#[async_trait::async_trait]
330impl SourceProvider for GitHubProvider {
331    fn provider_type(&self) -> &'static str {
332        "github"
333    }
334
335    async fn list_items(&self, config: &ProviderConfig) -> Result<Vec<SourceItem>> {
336        let github_url = self.parse_github_url(&config.base_path)?;
337        let token = self.get_token(config);
338
339        match github_url {
340            GitHubUrl::Repository { owner, repo } => {
341                let files = self
342                    .list_repo_files(&owner, &repo, token.as_deref())
343                    .await?;
344                let pattern = glob::Pattern::new(&config.pattern)?;
345
346                let mut items = Vec::new();
347
348                for file in files {
349                    if file.file_type == "blob" && pattern.matches(&file.path) {
350                        let url = format!(
351                            "https://github.com/{}/{}/blob/HEAD/{}",
352                            owner, repo, file.path
353                        );
354                        match self.fetch_item(&url).await {
355                            Ok(item) => items.push(item),
356                            Err(_) => continue,
357                        }
358                    }
359                }
360
361                Ok(items)
362            }
363            GitHubUrl::File { .. } => {
364                let item = self.fetch_item(&config.base_path).await?;
365                Ok(vec![item])
366            }
367        }
368    }
369
370    async fn fetch_item(&self, uri: &str) -> Result<SourceItem> {
371        let github_url = self.parse_github_url(uri)?;
372        let token = std::env::var("GITHUB_TOKEN").ok();
373
374        match github_url {
375            GitHubUrl::Repository { owner, repo } => {
376                let (filename, content) =
377                    self.fetch_readme(&owner, &repo, token.as_deref()).await?;
378                let title = extract_title(&content, &filename);
379                let hash = hash_content(&content);
380                let uri = format!("{}/{}/{}", owner, repo, filename);
381
382                Ok(
383                    SourceItem::new(uri, title, content, hash, "github".to_string())
384                        .with_metadata("owner".to_string(), owner)
385                        .with_metadata("repo".to_string(), repo)
386                        .with_metadata("file".to_string(), filename),
387                )
388            }
389            GitHubUrl::File {
390                owner,
391                repo,
392                branch,
393                path,
394            } => {
395                let content = self
396                    .fetch_file(&owner, &repo, &branch, &path, token.as_deref())
397                    .await?;
398                let title = extract_title(&content, &path);
399                let hash = hash_content(&content);
400                let uri = format!("{}/{}/{}", owner, repo, path);
401
402                Ok(
403                    SourceItem::new(uri, title, content, hash, "github".to_string())
404                        .with_metadata("owner".to_string(), owner)
405                        .with_metadata("repo".to_string(), repo)
406                        .with_metadata("branch".to_string(), branch)
407                        .with_metadata("path".to_string(), path),
408                )
409            }
410        }
411    }
412}
413
/// Parsed form of a GitHub URL, as produced by `parse_github_url`.
#[derive(Debug, Clone)]
enum GitHubUrl {
    /// A repository root URL (`github.com/<owner>/<repo>`).
    Repository {
        // First path segment of the URL.
        owner: String,
        // Second path segment of the URL.
        repo: String,
    },
    /// A file URL (`github.com/<owner>/<repo>/blob/<branch>/<path>`).
    File {
        // First path segment of the URL.
        owner: String,
        // Second path segment of the URL.
        repo: String,
        // The single segment that follows `blob`.
        branch: String,
        // Remaining segments joined with `/`.
        path: String,
    },
}
428
/// Subset of the GitHub README API response that `fetch_readme` reads.
#[derive(Debug, Deserialize)]
struct ReadmeResponse {
    // File name of the README; returned to the caller and used in the item URI.
    name: String,
    // README body; `fetch_readme` strips newlines and base64-decodes it.
    content: String,
}
435
/// Subset of the GitHub tree API response that `list_repo_files` reads.
///
/// NOTE(review): only `tree` is deserialized; any flag the API may use to
/// signal an incomplete listing is ignored. Confirm whether that matters
/// for large repositories.
#[derive(Debug, Deserialize)]
struct TreeResponse {
    // Entries of the repository tree, returned to callers unfiltered.
    tree: Vec<RepoFile>,
}
441
/// One entry of a repository tree as returned by the tree API.
#[derive(Debug, Deserialize)]
struct RepoFile {
    // Path of the entry; matched against the configured glob pattern.
    path: String,
    // Deserialized from the JSON field `type`; `list_items` only keeps
    // entries whose value is "blob".
    #[serde(rename = "type")]
    file_type: String,
}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// The provider reports the "github" type identifier.
    #[test]
    fn test_github_provider_type() {
        assert_eq!(GitHubProvider::new().provider_type(), "github");
    }

    /// A plain repository URL yields the owner and repository name.
    #[test]
    fn test_parse_github_repo_url() {
        let parsed = GitHubProvider::new()
            .parse_github_url("https://github.com/rust-lang/rust")
            .unwrap();

        if let GitHubUrl::Repository { owner, repo } = parsed {
            assert_eq!(owner, "rust-lang");
            assert_eq!(repo, "rust");
        } else {
            panic!("Expected Repository variant");
        }
    }

    /// A blob URL yields owner, repository, branch and file path.
    #[test]
    fn test_parse_github_file_url() {
        let parsed = GitHubProvider::new()
            .parse_github_url("https://github.com/rust-lang/rust/blob/master/README.md")
            .unwrap();

        if let GitHubUrl::File {
            owner,
            repo,
            branch,
            path,
        } = parsed
        {
            assert_eq!(owner, "rust-lang");
            assert_eq!(repo, "rust");
            assert_eq!(branch, "master");
            assert_eq!(path, "README.md");
        } else {
            panic!("Expected File variant");
        }
    }

    /// URLs on other hosts are rejected.
    #[test]
    fn test_parse_invalid_url() {
        let outcome = GitHubProvider::new().parse_github_url("https://example.com/not-github");
        assert!(outcome.is_err());
    }

    /// Table of accepted and rejected URL shapes.
    #[test]
    fn test_parse_github_url_variants() {
        let provider = GitHubProvider::new();

        let accepted = [
            "https://github.com/rust-lang/rust",
            "http://github.com/rust-lang/rust",
            "https://github.com/user/repo/blob/main/README.md",
            "https://github.com/user/repo/blob/feature-branch/src/main.rs",
        ];
        let rejected = [
            "https://gitlab.com/user/repo",
            "github.com/user/repo",
            "https://github.com/",
            "https://github.com/user",
        ];

        let cases = accepted
            .iter()
            .map(|url| (*url, true))
            .chain(rejected.iter().map(|url| (*url, false)));

        for (url, should_succeed) in cases {
            let result = provider.parse_github_url(url);
            assert_eq!(
                result.is_ok(),
                should_succeed,
                "URL: {} - Expected success: {}, Got: {:?}",
                url,
                should_succeed,
                result
            );
        }
    }

    /// Every component of a blob URL with a nested path is extracted.
    #[test]
    fn test_parse_github_file_url_components() {
        let parsed = GitHubProvider::new()
            .parse_github_url("https://github.com/rust-lang/rust/blob/master/src/main.rs")
            .unwrap();

        if let GitHubUrl::File {
            owner,
            repo,
            branch,
            path,
        } = parsed
        {
            assert_eq!(owner, "rust-lang");
            assert_eq!(repo, "rust");
            assert_eq!(branch, "master");
            assert_eq!(path, "src/main.rs");
        } else {
            panic!("Expected File variant");
        }
    }

    /// Deeply nested paths are preserved with their separators.
    #[test]
    fn test_parse_github_file_url_nested_path() {
        let parsed = GitHubProvider::new()
            .parse_github_url("https://github.com/owner/repo/blob/main/deep/nested/path/file.md")
            .unwrap();

        if let GitHubUrl::File { path, .. } = parsed {
            assert_eq!(path, "deep/nested/path/file.md");
        } else {
            panic!("Expected File variant");
        }
    }

    /// A token set through the provider options is returned.
    #[test]
    fn test_get_token_from_config() {
        let config = ProviderConfig::new(
            "https://github.com/user/repo".to_string(),
            "*.md".to_string(),
        )
        .with_option("github_token".to_string(), "ghp_test123".to_string());

        let resolved = GitHubProvider::new().get_token(&config);
        assert_eq!(resolved, Some("ghp_test123".to_string()));
    }

    /// The option value is the one returned when it is present.
    #[test]
    fn test_get_token_priority() {
        let config_with_token = ProviderConfig::new(
            "https://github.com/user/repo".to_string(),
            "*.md".to_string(),
        )
        .with_option("github_token".to_string(), "ghp_config".to_string());

        let resolved = GitHubProvider::new().get_token(&config_with_token);
        assert_eq!(resolved, Some("ghp_config".to_string()));
    }

    /// Same check as `test_github_provider_type`, kept under its own name.
    #[test]
    fn test_provider_type() {
        assert_eq!(GitHubProvider::new().provider_type(), "github");
    }

    /// Owner and repository names with dashes, underscores and dots parse.
    #[test]
    fn test_parse_github_url_edge_cases() {
        let provider = GitHubProvider::new();

        let edge_cases = [
            "https://github.com/user/repo-with-dashes",
            "https://github.com/user/repo_with_underscores",
            "https://github.com/user/repo.with.dots",
            "https://github.com/user-with-dash/repo",
            "https://github.com/user_with_underscore/repo",
        ];

        for url in edge_cases.iter() {
            let result = provider.parse_github_url(url);
            assert!(result.is_ok(), "Failed to parse valid URL: {}", url);
        }
    }
}