Skip to main content

ward/github/
client.rs

1use std::sync::Arc;
2
3use anyhow::{Context, Result};
4use reqwest::header::{self, HeaderMap, HeaderValue};
5use tokio::sync::Semaphore;
6
7use crate::config::auth;
8
9/// GitHub API client with rate limiting and concurrency control.
10pub struct Client {
11    pub http: reqwest::Client,
12    pub org: String,
13    pub semaphore: Arc<Semaphore>,
14    pub base_url: String,
15}
16
17impl Client {
18    pub async fn new(org: &str, parallelism: usize) -> Result<Self> {
19        let token = auth::resolve_token()?;
20
21        let mut headers = HeaderMap::new();
22        headers.insert(
23            header::ACCEPT,
24            HeaderValue::from_static("application/vnd.github+json"),
25        );
26        headers.insert(
27            "X-GitHub-Api-Version",
28            HeaderValue::from_static("2022-11-28"),
29        );
30        headers.insert(
31            header::AUTHORIZATION,
32            HeaderValue::from_str(&format!("Bearer {token}"))
33                .context("Invalid token characters")?,
34        );
35        headers.insert(
36            header::USER_AGENT,
37            HeaderValue::from_static("ward-cli/0.1.0"),
38        );
39
40        let http = reqwest::Client::builder()
41            .default_headers(headers)
42            .build()
43            .context("Failed to build HTTP client")?;
44
45        Ok(Self {
46            http,
47            org: org.to_owned(),
48            semaphore: Arc::new(Semaphore::new(parallelism)),
49            base_url: "https://api.github.com".to_owned(),
50        })
51    }
52
53    /// Make a GET request to the GitHub API.
54    pub async fn get(&self, path: &str) -> Result<reqwest::Response> {
55        let _permit = self.semaphore.acquire().await?;
56        let url = format!("{}{}", self.base_url, path);
57
58        tracing::debug!("GET {url}");
59
60        let resp = self
61            .http
62            .get(&url)
63            .send()
64            .await
65            .with_context(|| format!("GET {url} failed"))?;
66
67        check_rate_limit(&resp);
68        Ok(resp)
69    }
70
71    /// Make a PUT request to the GitHub API.
72    pub async fn put(&self, path: &str) -> Result<reqwest::Response> {
73        let _permit = self.semaphore.acquire().await?;
74        let url = format!("{}{}", self.base_url, path);
75
76        tracing::debug!("PUT {url}");
77
78        let resp = self
79            .http
80            .put(&url)
81            .header(header::CONTENT_LENGTH, 0)
82            .send()
83            .await
84            .with_context(|| format!("PUT {url} failed"))?;
85
86        check_rate_limit(&resp);
87        Ok(resp)
88    }
89
90    /// Make a PATCH request with a JSON body.
91    pub async fn patch_json<T: serde::Serialize>(
92        &self,
93        path: &str,
94        body: &T,
95    ) -> Result<reqwest::Response> {
96        let _permit = self.semaphore.acquire().await?;
97        let url = format!("{}{}", self.base_url, path);
98
99        tracing::debug!("PATCH {url}");
100
101        let resp = self
102            .http
103            .patch(&url)
104            .json(body)
105            .send()
106            .await
107            .with_context(|| format!("PATCH {url} failed"))?;
108
109        check_rate_limit(&resp);
110        Ok(resp)
111    }
112
113    /// Make a POST request with a JSON body.
114    pub async fn post_json<T: serde::Serialize>(
115        &self,
116        path: &str,
117        body: &T,
118    ) -> Result<reqwest::Response> {
119        let _permit = self.semaphore.acquire().await?;
120        let url = format!("{}{}", self.base_url, path);
121
122        tracing::debug!("POST {url}");
123
124        let resp = self
125            .http
126            .post(&url)
127            .json(body)
128            .send()
129            .await
130            .with_context(|| format!("POST {url} failed"))?;
131
132        check_rate_limit(&resp);
133        Ok(resp)
134    }
135
136    /// Make a PUT request with a JSON body.
137    pub async fn put_json<T: serde::Serialize>(
138        &self,
139        path: &str,
140        body: &T,
141    ) -> Result<reqwest::Response> {
142        let _permit = self.semaphore.acquire().await?;
143        let url = format!("{}{}", self.base_url, path);
144
145        tracing::debug!("PUT {url}");
146
147        let resp = self
148            .http
149            .put(&url)
150            .json(body)
151            .send()
152            .await
153            .with_context(|| format!("PUT {url} failed"))?;
154
155        check_rate_limit(&resp);
156        Ok(resp)
157    }
158
159    /// Make a DELETE request.
160    pub async fn delete(&self, path: &str) -> Result<reqwest::Response> {
161        let _permit = self.semaphore.acquire().await?;
162        let url = format!("{}{}", self.base_url, path);
163
164        tracing::debug!("DELETE {url}");
165
166        let resp = self
167            .http
168            .delete(&url)
169            .send()
170            .await
171            .with_context(|| format!("DELETE {url} failed"))?;
172
173        check_rate_limit(&resp);
174        Ok(resp)
175    }
176}
177
178fn check_rate_limit(resp: &reqwest::Response) {
179    if let Some(remaining) = resp.headers().get("x-ratelimit-remaining")
180        && let Ok(remaining) = remaining.to_str().unwrap_or("?").parse::<u32>()
181        && remaining < 100
182    {
183        tracing::warn!("GitHub API rate limit low: {remaining} remaining");
184    }
185}