github_code_searching_lib/
github_searcher.rs

1//! # GitHub Searcher Module
2//!
3//! This module provides functionality for searching GitHub code repositories
4//! with advanced features like concurrent searches, rate-limit handling,
5//! and real-time progress visualization.
6//!
7//! The core component is the `GitHubSearcher` struct which manages
8//! all search operations and handles paginated results, GitHub API interactions,
9//! and concurrent execution.
10
11use chrono::Utc;
12use futures::future::join_all;
13use indicatif::{ MultiProgress, ProgressBar, ProgressStyle };
14use reqwest::{ Client, StatusCode };
15use serde_json::{ json, Value };
16use std::env;
17use std::error::Error;
18use std::sync::Arc;
19use tokio::fs::OpenOptions;
20use tokio::io::AsyncWriteExt;
21use tokio::sync::Semaphore;
22use tokio::time::{ Duration, Instant };
23use tracing::{ debug, error, info, warn };
24
25use crate::Args;
26
27/// GitHubSearcher handles all aspects of searching for code on GitHub,
28/// including authentication, API rate limiting, concurrent processing,
29/// and results management.
30///
31/// This struct encapsulates the search process flow with a focus on
32/// concurrent execution and user feedback through progress indicators.
33pub struct GitHubSearcher {
34    /// HTTP client for making API requests
35    client: Client,
36
37    /// GitHub API token for authentication
38    token: String,
39
40    /// Path where search results will be saved
41    output_path: String,
42
43    /// Optional limit on the number of pages to retrieve per search term
44    max_page_limit: Option<u32>,
45
46    /// Multi-progress display for showing search progress
47    progress: Arc<MultiProgress>,
48
49    /// Maximum number of concurrent searches to run
50    concurrency: usize,
51
52    /// List of file extensions to include in the search
53    include_extensions: Arc<tokio::sync::Mutex<Vec<String>>>,
54
55    /// List of file extensions to exclude from the search
56    exclude_extensions: Arc<tokio::sync::Mutex<Vec<String>>>,
57}
58
59impl GitHubSearcher {
60    /// Creates a new GitHubSearcher instance configured with the provided arguments.
61    ///
62    /// # Arguments
63    ///
64    /// * `args` - Command line arguments containing search configuration
65    ///
66    /// # Returns
67    ///
68    /// A Result containing either the configured GitHubSearcher or an error
69    ///
70    /// # Errors
71    ///
72    /// Will return an error if:
73    /// - GitHub token is not provided and not found in environment
74    /// - HTTP client creation fails
75    ///
76    /// # Example
77    ///
78    /// ```no_run
79    /// use github_code_searching::{Args, GitHubSearcher};
80    /// use clap::Parser;
81    ///
82    /// #[tokio::main]
83    /// async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
84    ///     let args = Args::parse();
85    ///     let searcher = GitHubSearcher::new(&args).await?;
86    ///     Ok(())
87    /// }
88    /// ```
89    pub async fn new(args: &Args) -> Result<Self, Box<dyn Error + Send + Sync>> {
90        // Get GitHub API token from arguments or environment
91        let token = match &args.token {
92            Some(t) if !t.trim().is_empty() => t.clone(),
93            _ =>
94                match env::var("GITHUB_TOKEN") {
95                    Ok(token) if !token.trim().is_empty() => token,
96                    _ => {
97                        error!("GitHub token not provided or found in environment");
98                        return Err("GitHub token is required".into());
99                    }
100                }
101        };
102
103        // Create HTTP client
104        let client = Client::builder()
105            .user_agent("Mozilla/5.0 (Windows NT 10.0; Win64; x64)")
106            .build()?;
107
108        // Create progress display
109        let progress = Arc::new(MultiProgress::new());
110
111        Ok(GitHubSearcher {
112            client,
113            token,
114            output_path: args.output.clone(),
115            max_page_limit: args.max_pages,
116            progress,
117            concurrency: args.concurrency,
118            include_extensions: Arc::new(
119                tokio::sync::Mutex::new(args.include_extensions.clone().unwrap_or_default())
120            ),
121            exclude_extensions: Arc::new(
122                tokio::sync::Mutex::new(args.exclude_extensions.clone().unwrap_or_default())
123            ),
124        })
125    }
126
127    /// Executes searches for all provided words with concurrency control
128    /// and displays progress in real-time.
129    ///
130    /// This method orchestrates the entire search process:
131    /// 1. Sets up progress bars for visualization
132    /// 2. Spawns a concurrent task for each search term
133    /// 3. Manages concurrency with a semaphore
134    /// 4. Handles result output and error reporting
135    ///
136    /// # Arguments
137    ///
138    /// * `words` - Vector of search terms to process
139    ///
140    /// # Returns
141    ///
142    /// A Result indicating success or an error encountered during search
143    ///
144    /// # Errors
145    ///
146    /// Will return an error if any search task fails, including:
147    /// - API request errors
148    /// - Authentication failures
149    /// - File I/O errors when saving results
150    ///
151    /// # Example
152    ///
153    /// ```no_run
154    /// use github_code_searching::{Args, GitHubSearcher};
155    /// use clap::Parser;
156    ///
157    /// #[tokio::main]
158    /// async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
159    ///     let args = Args::parse();
160    ///     let searcher = GitHubSearcher::new(&args).await?;
161    ///     searcher.run(vec!["rust concurrency".to_string(), "tokio async".to_string()]).await?;
162    ///     Ok(())
163    /// }
164    /// ```
165    pub async fn run(&self, words: Vec<String>) -> Result<(), Box<dyn Error + Send + Sync>> {
166        // Create semaphore for concurrency control
167        let semaphore = Arc::new(Semaphore::new(self.concurrency));
168
169        // Create a thread-safe collection to store all results
170        let all_results = Arc::new(tokio::sync::Mutex::new(Vec::new()));
171
172        // Create progress bars for all words upfront
173        let progress_bars: Arc<std::collections::HashMap<String, ProgressBar>> = {
174            let mut bars = std::collections::HashMap::new();
175
176            // Create a main spinner style
177            let spinner_style = ProgressStyle::default_spinner()
178                .template("{spinner:.green} [{elapsed_precise}] {wide_msg}")
179                .unwrap()
180                .progress_chars("=>-")
181                .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏");
182
183            // Create a progress bar style for rate limiting
184            let progress_style = ProgressStyle::default_bar()
185                .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {wide_msg}")
186                .unwrap()
187                .progress_chars("=>-");
188
189            // Add a progress bar for each word
190            for word in &words {
191                let pb = self.progress.add(ProgressBar::new_spinner());
192                pb.set_style(spinner_style.clone());
193                pb.set_message(format!("Waiting to search for '{}'", word));
194                bars.insert(word.clone(), pb);
195            }
196
197            Arc::new(bars)
198        };
199
200        // Add a rate limit progress bar at the bottom
201        let rate_limit_pb = Arc::new(self.progress.add(ProgressBar::new(100)));
202        rate_limit_pb.set_style(
203            ProgressStyle::default_bar()
204                .template("[{elapsed_precise}] {bar:40.red/yellow} {pos:>7}/{len:7} {wide_msg}")
205                .unwrap()
206                .progress_chars("=>-")
207        );
208        rate_limit_pb.set_message("Rate limit status: OK");
209        rate_limit_pb.set_position(100); // Start full
210
211        // Launch tasks for each word
212        let mut tasks = Vec::new();
213        let mut spinner_tasks = Vec::new();
214
215        for word in words {
216            let word_clone = word.clone();
217            let sem_clone = semaphore.clone();
218            let client_clone = self.client.clone();
219            let token_clone = self.token.clone();
220            let max_pages = self.max_page_limit;
221            let pb = progress_bars.get(&word).unwrap().clone();
222            let rate_limit_pb_clone = rate_limit_pb.clone();
223            let results_clone = all_results.clone();
224
225            // Start a background ticker to keep spinner animated
226            let pb_ticker = pb.clone();
227            let spinner_task = tokio::spawn(async move {
228                loop {
229                    pb_ticker.tick();
230                    tokio::time::sleep(Duration::from_millis(80)).await;
231                }
232            });
233            spinner_tasks.push(spinner_task);
234
235            let exclude_extensions = self.exclude_extensions.clone();
236            let include_extensions = self.include_extensions.clone();
237
238            // Main search task
239            let task = tokio::spawn(async move {
240                // Acquire semaphore permit
241                let _permit = sem_clone.acquire().await.unwrap();
242
243                pb.set_message(format!("Starting search for '{}'", word_clone));
244
245                // Process this word
246                let result = GitHubSearcher::search_word(
247                    &client_clone,
248                    &token_clone,
249                    &word_clone,
250                    max_pages,
251                    pb.clone(),
252                    rate_limit_pb_clone,
253                    exclude_extensions.clone(),
254                    include_extensions.clone(),
255                    results_clone
256                ).await;
257
258                // Finalize progress bar
259                if result.is_ok() {
260                    pb.set_message(format!("✓ Completed '{}'", word_clone));
261                } else {
262                    pb.set_message(format!("✗ Failed '{}'", word_clone));
263                }
264
265                result
266            });
267
268            tasks.push(task);
269        }
270
271        // Await all tasks
272        let results = join_all(tasks).await;
273
274        // Abort ticker tasks now that main tasks are done
275        for task in spinner_tasks {
276            task.abort();
277        }
278
279        // Clear rate limit progress bar
280        rate_limit_pb.finish_and_clear();
281
282        // Check for errors
283        for result in results {
284            if let Ok(Err(e)) = result {
285                error!("Search task error: {}", e);
286                return Err(e);
287            }
288        }
289
290        // Write all accumulated results to file as a single JSON array
291        info!("Writing all results to disk as a single JSON array");
292        let all_results_guard = all_results.lock().await;
293        let json_output = serde_json::to_string(&*all_results_guard)?;
294
295        // Create output file
296        let mut file = OpenOptions::new()
297            .create(true)
298            .write(true)
299            .truncate(true)
300            .open(&self.output_path).await?;
301
302        // Write JSON array to file
303        file.write_all(json_output.as_bytes()).await?;
304        file.flush().await?;
305
306        info!("All searches completed successfully");
307        Ok(())
308    }
309
310    /// Executes a search for a single word, handling all pages of results.
311    ///
312    /// This method handles the complete search process for one term, including:
313    /// - Pagination through all results
314    /// - Progress updates
315    /// - Respecting page limits
316    ///
317    /// # Arguments
318    ///
319    /// * `client` - HTTP client for making API requests
320    /// * `token` - GitHub API token for authentication
321    /// * `word` - The search term to look for
322    /// * `max_page_limit` - Optional maximum number of pages to retrieve
323    /// * `pb` - Progress bar for this search term
324    /// * `rate_limit_pb` - Progress bar for rate limit visualization
325    /// * `all_results` - Shared collection to store results
326    ///
327    /// # Returns
328    ///
329    /// A Result indicating success or an error encountered during search
330    ///
331    /// # Errors
332    ///
333    /// Will propagate errors from the page search process
334    async fn search_word(
335        client: &Client,
336        token: &str,
337        word: &str,
338        max_page_limit: Option<u32>,
339        pb: ProgressBar,
340        rate_limit_pb: Arc<ProgressBar>,
341        exclude_extensions: Arc<tokio::sync::Mutex<Vec<String>>>,
342        include_extensions: Arc<tokio::sync::Mutex<Vec<String>>>,
343        all_results: Arc<tokio::sync::Mutex<Vec<Value>>>
344    ) -> Result<(), Box<dyn Error + Send + Sync>> {
345        let mut page: u32 = 1;
346
347        loop {
348            pb.set_message(format!("Searching {} - page {}", word, page));
349
350            // Search this page
351            match
352                GitHubSearcher::search_page(
353                    client,
354                    token,
355                    word,
356                    page,
357                    &pb,
358                    &rate_limit_pb,
359                    &exclude_extensions,
360                    &include_extensions,
361                    &all_results
362                ).await
363            {
364                Ok(has_more_pages) => {
365                    if !has_more_pages {
366                        debug!("No more results for '{}'", word);
367                        break;
368                    }
369                }
370                Err(e) => {
371                    error!("Error searching '{}' page {}: {}", word, page, e);
372                    return Err(e);
373                }
374            }
375
376            page += 1;
377
378            // Check page limit
379            if let Some(max_page) = max_page_limit {
380                if page > max_page {
381                    info!("Max page limit reached for '{}' (limit: {})", word, max_page);
382                    break;
383                }
384            }
385        }
386
387        Ok(())
388    }
389
390    /// Searches a specific page of results for a search term.
391    ///
392    /// This method:
393    /// 1. Makes a GitHub API request
394    /// 2. Handles rate limiting
395    /// 3. Processes and filters the response
396    /// 4. Adds filtered results to the shared collection
397    ///
398    /// # Arguments
399    ///
400    /// * `client` - HTTP client for making API requests
401    /// * `token` - GitHub API token for authentication
402    /// * `word` - The search term to look for
403    /// * `page` - Page number to retrieve (1-based)
404    /// * `pb` - Progress bar for this search term
405    /// * `rate_limit_pb` - Progress bar for rate limit visualization
406    /// * `all_results` - Shared collection to store results
407    ///
408    /// # Returns
409    ///
410    /// A Result containing a boolean indicating if more pages exist (true)
411    /// or if this was the last page (false)
412    ///
413    /// # Errors
414    ///
415    /// Will return an error for:
416    /// - API request failures
417    /// - Authentication issues
418    /// - JSON parsing problems
419    async fn search_page(
420        client: &Client,
421        token: &str,
422        word: &str,
423        page: u32,
424        pb: &ProgressBar,
425        rate_limit_pb: &ProgressBar,
426        exclude_extensions: &Arc<tokio::sync::Mutex<Vec<String>>>,
427        include_extensions: &Arc<tokio::sync::Mutex<Vec<String>>>,
428        all_results: &Arc<tokio::sync::Mutex<Vec<Value>>>
429    ) -> Result<bool, Box<dyn Error + Send + Sync>> {
430        // Construct the search URL
431        // Include extensions if provided
432        let include_ext = &*include_extensions.lock().await;
433        let include_ext_str = include_ext.join("%20extension:");
434        let exclude_ext = &*exclude_extensions.lock().await;
435        let exclude_ext_str = exclude_ext.join("%20-extension:");
436        let word = if !include_ext.is_empty() {
437            format!("{} extension:{}", word, include_ext_str)
438        } else if !exclude_ext.is_empty() {
439            format!("{} -extension:{}", word, exclude_ext_str)
440        } else {
441            word.to_string()
442        };
443
444        let url = format!(
445            "https://api.github.com/search/code?q={}&page={}&per_page=100",
446            word,
447            page
448        );
449
450        debug!("Requesting URL: {}", url);
451        let response = client
452            .get(&url)
453            .header("Accept", "application/vnd.github.text-match+json")
454            .header("Authorization", format!("Bearer {}", token))
455            .header("X-GitHub-Api-Version", "2022-11-28")
456            .send().await?;
457
458        // Handle pagination limit
459        if response.status() == StatusCode::UNPROCESSABLE_ENTITY {
460            warn!("Reached search limit for '{}' at page {}", word, page);
461            return Ok(false);
462        }
463
464        // Handle rate limiting
465        GitHubSearcher::handle_rate_limit(response.headers(), pb, rate_limit_pb).await?;
466
467        // Check for other errors
468        if !response.status().is_success() && response.status() != StatusCode::FORBIDDEN {
469            return Err(
470                format!("API error: {} on word '{}' page {}", response.status(), word, page).into()
471            );
472        }
473
474        // Parse response
475        let json: Value = response.json().await?;
476        let mut filtered_items = Vec::new();
477
478        // Process items
479        if let Some(items) = json["items"].as_array() {
480            if items.is_empty() {
481                return Ok(false);
482            }
483
484            for item in items {
485                let name = item
486                    .get("name")
487                    .and_then(|v| v.as_str())
488                    .unwrap_or("");
489                let html_url = item
490                    .get("html_url")
491                    .and_then(|v| v.as_str())
492                    .unwrap_or("");
493                // Extract SHA hash
494                let sha = item
495                    .get("sha")
496                    .and_then(|v| v.as_str())
497                    .unwrap_or("");
498
499                let repo_owner = if let Some(repo) = item.get("repository") {
500                    repo.get("owner")
501                } else {
502                    None
503                };
504
505                let owner_login = repo_owner
506                    .and_then(|o| o.get("login"))
507                    .and_then(|v| v.as_str())
508                    .unwrap_or("");
509                let owner_avatar_url = repo_owner
510                    .and_then(|o| o.get("avatar_url"))
511                    .and_then(|v| v.as_str())
512                    .unwrap_or("");
513                let owner_html_url = repo_owner
514                    .and_then(|o| o.get("html_url"))
515                    .and_then(|v| v.as_str())
516                    .unwrap_or("");
517
518                let text_matches = item
519                    .get("text_matches")
520                    .cloned()
521                    .unwrap_or_else(|| json!([]));
522
523                // Build filtered JSON object
524                let new_item =
525                    json!({
526                    "name": name,
527                    "html_url": html_url,
528                    "sha": sha,
529                    "search_term": word,
530                    "repository_owner": {
531                        "login": owner_login,
532                        "avatar_url": owner_avatar_url,
533                        "html_url": owner_html_url,
534                    },
535                    "text_matches": text_matches
536                });
537
538                filtered_items.push(new_item);
539            }
540        } else {
541            warn!("No 'items' array found in response for '{}'", word);
542            return Ok(false);
543        }
544
545        // Add results to the shared collection
546        let mut results_guard = all_results.lock().await;
547        results_guard.extend(filtered_items.clone());
548        drop(results_guard);
549
550        info!("Collected {} results for '{}' page {}", filtered_items.len(), word, page);
551        Ok(true)
552    }
553
554    /// Handles GitHub API rate limiting with visual feedback.
555    ///
556    /// When rate limits are approached or reached, this method:
557    /// 1. Updates the rate limit progress bar
558    /// 2. Calculates and displays remaining capacity
559    /// 3. If limits are exceeded, waits with a countdown until reset
560    ///
561    /// # Arguments
562    ///
563    /// * `headers` - Response headers from GitHub API containing rate limit info
564    /// * `pb` - Progress bar for current search operation
565    /// * `rate_limit_pb` - Progress bar dedicated to rate limit visualization
566    ///
567    /// # Returns
568    ///
569    /// A Result indicating success or failure of rate limit handling
570    ///
571    /// # Errors
572    ///
573    /// Generally doesn't produce errors, but propagates any unexpected issues
574    async fn handle_rate_limit(
575        headers: &reqwest::header::HeaderMap,
576        pb: &ProgressBar,
577        rate_limit_pb: &ProgressBar
578    ) -> Result<(), Box<dyn Error + Send + Sync>> {
579        // Update rate limit indicator
580        if let Some(remaining_header) = headers.get("X-RateLimit-Remaining") {
581            if let Ok(remaining_str) = remaining_header.to_str() {
582                if let Ok(remaining) = remaining_str.parse::<u32>() {
583                    if let Some(limit_header) = headers.get("X-RateLimit-Limit") {
584                        if let Ok(limit_str) = limit_header.to_str() {
585                            if let Ok(limit) = limit_str.parse::<u32>() {
586                                // Calculate percentage of rate limit remaining
587                                let percentage = if limit > 0 {
588                                    (remaining * 100) / limit
589                                } else {
590                                    100
591                                };
592
593                                rate_limit_pb.set_message(
594                                    format!("Rate limit: {}/{}", remaining, limit)
595                                );
596                                rate_limit_pb.set_position(percentage.into());
597
598                                // If we're out of requests, wait for reset
599                                if remaining == 0 {
600                                    if let Some(reset_header) = headers.get("X-RateLimit-Reset") {
601                                        if let Ok(reset_str) = reset_header.to_str() {
602                                            if let Ok(reset_timestamp) = reset_str.parse::<u64>() {
603                                                let now = Utc::now().timestamp() as u64;
604                                                if reset_timestamp > now {
605                                                    let wait_secs = reset_timestamp - now;
606                                                    warn!(
607                                                        "Rate limit reached. Waiting {} seconds...",
608                                                        wait_secs + 1
609                                                    );
610
611                                                    // Save the current message to restore later
612                                                    let original_msg = pb.message();
613                                                    pb.set_message(
614                                                        format!(
615                                                            "Rate limited - waiting {} seconds",
616                                                            wait_secs + 1
617                                                        )
618                                                    );
619
620                                                    // Create a visual countdown
621                                                    let start = Instant::now();
622                                                    let duration = Duration::from_secs(
623                                                        wait_secs + 1
624                                                    );
625                                                    let end = start + duration;
626
627                                                    while Instant::now() < end {
628                                                        let elapsed = start.elapsed();
629                                                        if elapsed < duration {
630                                                            let remaining = duration - elapsed;
631                                                            let secs_remaining =
632                                                                remaining.as_secs();
633                                                            let percentage =
634                                                                ((
635                                                                    duration - remaining
636                                                                ).as_millis() *
637                                                                    100) /
638                                                                duration.as_millis();
639
640                                                            rate_limit_pb.set_position(
641                                                                percentage as u64
642                                                            );
643                                                            rate_limit_pb.set_message(
644                                                                format!("Rate limit cooldown: {}s remaining", secs_remaining)
645                                                            );
646                                                            pb.set_message(
647                                                                format!("Rate limited - waiting {}s", secs_remaining)
648                                                            );
649
650                                                            tokio::time::sleep(
651                                                                Duration::from_millis(500)
652                                                            ).await;
653                                                        } else {
654                                                            break;
655                                                        }
656                                                    }
657
658                                                    // Reset rate limit bar and restore original message
659                                                    rate_limit_pb.set_position(100);
660                                                    rate_limit_pb.set_message(
661                                                        "Rate limit status: Ready"
662                                                    );
663                                                    pb.set_message(original_msg);
664                                                }
665                                            }
666                                        }
667                                    }
668                                }
669                            }
670                        }
671                    }
672                }
673            }
674        }
675        Ok(())
676    }
677}