1use reqwest::{Client, Request, StatusCode};
2use serde::{Deserialize, Deserializer};
3use soft_assert::*;
4use url::Url;
5use zeroize::Zeroize;
6
/// A client for a single Forgejo instance.
pub struct Forgejo {
    /// Base URL of the instance; API and OAuth endpoint URLs are resolved relative to it.
    url: Url,
    /// HTTP client configured with the user agent and any authentication headers as
    /// default headers (see `Forgejo::with_user_agent`).
    client: Client,
}
11
12mod generated;
13
/// Errors returned by the [`Forgejo`] client.
#[derive(thiserror::Error, Debug)]
pub enum ForgejoError {
    /// The base URL has no host.
    ///
    /// NOTE(review): not returned anywhere in the visible code; `with_user_agent` only
    /// validates the scheme. Confirm whether anything still produces this.
    #[error("url must have a host")]
    HostRequired,
    /// The base URL's scheme was neither `http` nor `https`.
    #[error("scheme must be http or https")]
    HttpRequired,
    /// An error from the underlying `reqwest` client (building the client or a request,
    /// sending it, or reading the response body).
    #[error(transparent)]
    ReqwestError(#[from] reqwest::Error),
    /// A token, OAuth2 token, or MFA code could not be converted into a header value.
    #[error("API key should be ascii")]
    KeyNotAscii,
    /// A response did not have the expected shape; see [`StructureError`].
    #[error("the response from forgejo was not properly structured")]
    BadStructure(#[from] StructureError),
    /// A status that is neither a success nor a client error (e.g. 5xx), or a success
    /// status other than 200 where exactly 200 was required.
    #[error("unexpected status code {} {}", .0.as_u16(), .0.canonical_reason().unwrap_or(""))]
    UnexpectedStatusCode(StatusCode),
    /// A 4xx response, together with the `message` field of its JSON error body when
    /// that body could be decoded.
    #[error("{} {}{}", .0.as_u16(), .0.canonical_reason().unwrap_or(""), .1.as_ref().map(|s| format!(": {s}")).unwrap_or_default())]
    ApiError(StatusCode, Option<String>),
    /// The username/password pair was too long to base64-encode for Basic auth.
    #[error("the provided authorization was too long to accept")]
    AuthTooLong,
}
33
34#[derive(thiserror::Error, Debug)]
35pub enum StructureError {
36 #[error("{contents}")]
37 Serde {
38 e: serde_json::Error,
39 contents: String,
40 },
41 #[error("failed to find header `{0}`")]
42 HeaderMissing(&'static str),
43 #[error("header was not ascii")]
44 HeaderNotAscii,
45 #[error("failed to parse header")]
46 HeaderParseFailed,
47}
48
/// How a [`Forgejo`] client authenticates its requests.
pub enum Auth<'a> {
    /// An access token, sent as `Authorization: token <token>`.
    Token(&'a str),
    /// An OAuth2 access token, sent as `Authorization: Bearer <token>`.
    OAuth2(&'a str),
    /// HTTP Basic authentication.
    ///
    /// `username:password` is base64-encoded into `Authorization: Basic <...>`. If `mfa`
    /// is set, it is additionally sent in the `X-FORGEJO-OTP` header.
    Password {
        username: &'a str,
        password: &'a str,
        mfa: Option<&'a str>,
    },
    /// No authentication headers are sent.
    None,
}
76
77impl Forgejo {
78 pub fn new(auth: Auth, url: Url) -> Result<Self, ForgejoError> {
79 Self::with_user_agent(auth, url, "forgejo-api-rs")
80 }
81
82 pub fn with_user_agent(auth: Auth, url: Url, user_agent: &str) -> Result<Self, ForgejoError> {
83 soft_assert!(
84 matches!(url.scheme(), "http" | "https"),
85 Err(ForgejoError::HttpRequired)
86 );
87
88 let mut headers = reqwest::header::HeaderMap::new();
89 match auth {
90 Auth::Token(token) => {
91 let mut header: reqwest::header::HeaderValue = format!("token {token}")
92 .try_into()
93 .map_err(|_| ForgejoError::KeyNotAscii)?;
94 header.set_sensitive(true);
95 headers.insert("Authorization", header);
96 }
97 Auth::Password {
98 username,
99 password,
100 mfa,
101 } => {
102 let unencoded_len = username.len() + password.len() + 1;
103 let unpadded_len = unencoded_len
104 .checked_mul(4)
105 .ok_or(ForgejoError::AuthTooLong)?
106 .div_ceil(3);
107 let len = unpadded_len.div_ceil(4) * 4;
109 let mut bytes = vec![0; len];
110
111 let mut encoder = base64ct::Encoder::<base64ct::Base64>::new(&mut bytes).unwrap();
113
114 encoder.encode(username.as_bytes()).unwrap();
116 encoder.encode(b":").unwrap();
117 encoder.encode(password.as_bytes()).unwrap();
118
119 let b64 = encoder.finish().unwrap();
120
121 let mut header: reqwest::header::HeaderValue =
122 format!("Basic {b64}").try_into().unwrap(); header.set_sensitive(true);
124 headers.insert("Authorization", header);
125
126 bytes.zeroize();
127
128 if let Some(mfa) = mfa {
129 let mut key_header: reqwest::header::HeaderValue =
130 mfa.try_into().map_err(|_| ForgejoError::KeyNotAscii)?;
131 key_header.set_sensitive(true);
132 headers.insert("X-FORGEJO-OTP", key_header);
133 }
134 }
135 Auth::OAuth2(token) => {
136 let mut header: reqwest::header::HeaderValue = format!("Bearer {token}")
137 .try_into()
138 .map_err(|_| ForgejoError::KeyNotAscii)?;
139 header.set_sensitive(true);
140 headers.insert("Authorization", header);
141 }
142 Auth::None => (),
143 }
144 let client = Client::builder()
145 .user_agent(user_agent)
146 .default_headers(headers)
147 .build()?;
148 Ok(Self { url, client })
149 }
150
151 pub async fn download_release_attachment(
152 &self,
153 owner: &str,
154 repo: &str,
155 release: u64,
156 attach: u64,
157 ) -> Result<bytes::Bytes, ForgejoError> {
158 let release = self
159 .repo_get_release_attachment(owner, repo, release, attach)
160 .await?;
161 let mut url = self.url.clone();
162 url.path_segments_mut()
163 .unwrap()
164 .pop_if_empty()
165 .extend(["attachments", &release.uuid.unwrap().to_string()]);
166 let request = self.client.get(url).build()?;
167 Ok(self.execute(request).await?.bytes().await?)
168 }
169
170 pub async fn oauth_get_access_token(
174 &self,
175 body: structs::OAuthTokenRequest<'_>,
176 ) -> Result<structs::OAuthToken, ForgejoError> {
177 let url = self.url.join("login/oauth/access_token").unwrap();
178 let request = self.client.post(url).json(&body).build()?;
179 let response = self.execute(request).await?;
180 match response.status().as_u16() {
181 200 => Ok(response.json().await?),
182 _ => Err(ForgejoError::UnexpectedStatusCode(response.status())),
183 }
184 }
185
186 fn get(&self, path: &str) -> reqwest::RequestBuilder {
187 let url = self.url.join("api/v1/").unwrap().join(path).unwrap();
188 self.client.get(url)
189 }
190
191 fn put(&self, path: &str) -> reqwest::RequestBuilder {
192 let url = self.url.join("api/v1/").unwrap().join(path).unwrap();
193 self.client.put(url)
194 }
195
196 fn post(&self, path: &str) -> reqwest::RequestBuilder {
197 let url = self.url.join("api/v1/").unwrap().join(path).unwrap();
198 self.client.post(url)
199 }
200
201 fn delete(&self, path: &str) -> reqwest::RequestBuilder {
202 let url = self.url.join("api/v1/").unwrap().join(path).unwrap();
203 self.client.delete(url)
204 }
205
206 fn patch(&self, path: &str) -> reqwest::RequestBuilder {
207 let url = self.url.join("api/v1/").unwrap().join(path).unwrap();
208 self.client.patch(url)
209 }
210
211 async fn execute(&self, request: Request) -> Result<reqwest::Response, ForgejoError> {
212 let response = self.client.execute(request).await?;
213 match response.status() {
214 status if status.is_success() => Ok(response),
215 status if status.is_client_error() => {
216 Err(ForgejoError::ApiError(status, maybe_err(response).await))
217 }
218 status => Err(ForgejoError::UnexpectedStatusCode(status)),
219 }
220 }
221}
222
223async fn maybe_err(res: reqwest::Response) -> Option<String> {
224 res.json::<ErrorMessage>().await.ok().map(|e| e.message)
225}
226
/// The part of a Forgejo JSON error body that this client reads; other fields are
/// ignored during deserialization.
#[derive(serde::Deserialize)]
struct ErrorMessage {
    /// Human-readable error description, surfaced through `ForgejoError::ApiError`.
    message: String,
}
233
/// API data types: the generated structs plus hand-written OAuth types.
pub mod structs {
    pub use crate::generated::structs::*;

    /// Request body for `Forgejo::oauth_get_access_token`, sent as JSON.
    ///
    /// Serialized with an internal `grant_type` tag. `Confidential` and `Public` both
    /// serialize as `"authorization_code"`. Only `Serialize` is derived, so the shared
    /// tag never has to be disambiguated.
    #[derive(serde::Serialize)]
    #[serde(tag = "grant_type")]
    pub enum OAuthTokenRequest<'a> {
        /// Exchanges an authorization `code`, authenticating with the client's secret.
        #[serde(rename = "authorization_code")]
        Confidential {
            client_id: &'a str,
            client_secret: &'a str,
            code: &'a str,
            redirect_uri: url::Url,
        },
        /// Exchanges an authorization `code` using a `code_verifier` instead of a client
        /// secret (presumably a PKCE public client).
        #[serde(rename = "authorization_code")]
        Public {
            client_id: &'a str,
            code_verifier: &'a str,
            code: &'a str,
            redirect_uri: url::Url,
        },
        /// Exchanges a `refresh_token` for a new access token.
        #[serde(rename = "refresh_token")]
        Refresh {
            refresh_token: &'a str,
            client_id: &'a str,
            client_secret: &'a str,
        },
    }

    /// Successful response from the OAuth2 access-token endpoint.
    #[derive(serde::Deserialize)]
    pub struct OAuthToken {
        pub access_token: String,
        pub refresh_token: String,
        pub token_type: String,
        /// Lifetime of `access_token`; presumably seconds, as in RFC 6749 (TODO confirm).
        pub expires_in: u32,
    }
}
283
284fn none_if_blank_url<'de, D: serde::Deserializer<'de>>(
287 deserializer: D,
288) -> Result<Option<Url>, D::Error> {
289 use serde::de::{Error, Unexpected, Visitor};
290 use std::fmt;
291
292 struct EmptyUrlVisitor;
293
294 impl<'de> Visitor<'de> for EmptyUrlVisitor {
295 type Value = Option<Url>;
296
297 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
298 formatter.write_str("option")
299 }
300
301 #[inline]
302 fn visit_unit<E>(self) -> Result<Self::Value, E>
303 where
304 E: Error,
305 {
306 Ok(None)
307 }
308
309 #[inline]
310 fn visit_none<E>(self) -> Result<Self::Value, E>
311 where
312 E: Error,
313 {
314 Ok(None)
315 }
316
317 #[inline]
318 fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
319 where
320 E: Error,
321 {
322 if s.is_empty() {
323 return Ok(None);
324 }
325 Url::parse(s)
326 .map_err(|err| {
327 let err_s = format!("{}", err);
328 Error::invalid_value(Unexpected::Str(s), &err_s.as_str())
329 })
330 .map(Some)
331 }
332 }
333
334 deserializer.deserialize_str(EmptyUrlVisitor)
335}
336
337#[allow(dead_code)] fn deserialize_ssh_url<'de, D, DE>(deserializer: D) -> Result<Url, DE>
339where
340 D: Deserializer<'de>,
341 DE: serde::de::Error,
342{
343 let raw_url: String = String::deserialize(deserializer).map_err(DE::custom)?;
344 parse_ssh_url(&raw_url).map_err(DE::custom)
345}
346
347fn deserialize_optional_ssh_url<'de, D, DE>(deserializer: D) -> Result<Option<Url>, DE>
348where
349 D: Deserializer<'de>,
350 DE: serde::de::Error,
351{
352 let raw_url: Option<String> = Option::deserialize(deserializer).map_err(DE::custom)?;
353 raw_url
354 .as_ref()
355 .map(parse_ssh_url)
356 .map(|res| res.map_err(DE::custom))
357 .transpose()
358 .or(Ok(None))
359}
360
361fn requested_reviewers_ignore_null<'de, D, DE>(
362 deserializer: D,
363) -> Result<Option<Vec<structs::User>>, DE>
364where
365 D: Deserializer<'de>,
366 DE: serde::de::Error,
367{
368 let list: Option<Vec<Option<structs::User>>> =
369 Option::deserialize(deserializer).map_err(DE::custom)?;
370 Ok(list.map(|list| list.into_iter().flatten().collect::<Vec<_>>()))
371}
372
373fn parse_ssh_url(raw_url: &String) -> Result<Url, url::ParseError> {
374 Url::parse(raw_url).or_else(|_| {
377 let url = format!("ssh://{url}", url = raw_url.replace(":", "/"));
380 Url::parse(url.as_str())
381 })
382}
383
#[test]
fn ssh_url_deserialization() {
    // Wrapper exercising `deserialize_ssh_url` on a required field.
    #[derive(serde::Deserialize)]
    struct SshUrl {
        #[serde(deserialize_with = "deserialize_ssh_url")]
        url: url::Url,
    }

    // Wrapper exercising `deserialize_optional_ssh_url` on an optional field.
    #[derive(serde::Deserialize)]
    struct OptSshUrl {
        #[serde(deserialize_with = "deserialize_optional_ssh_url")]
        url: Option<url::Url>,
    }

    let full_url = r#"{ "url": "ssh://git@codeberg.org/Cyborus/forgejo-api" }"#;
    let ssh_url = r#"{ "url": "git@codeberg.org:Cyborus/forgejo-api" }"#;
    let null_url = r#"{ "url": null }"#;
    let expected = "ssh://git@codeberg.org/Cyborus/forgejo-api";

    // Both spellings must normalize to the same `ssh://` URL for a required field...
    let required_cases = [
        (full_url, "failed to deserialize full url"),
        (ssh_url, "failed to deserialize ssh url"),
    ];
    for (json, failure) in required_cases {
        let parsed = serde_json::from_str::<SshUrl>(json).expect(failure);
        assert_eq!(parsed.url.as_str(), expected);
    }

    // ...and for an optional one.
    let optional_cases = [
        (full_url, "failed to deserialize optional full url"),
        (ssh_url, "failed to deserialize optional ssh url"),
    ];
    for (json, failure) in optional_cases {
        let parsed = serde_json::from_str::<OptSshUrl>(json).expect(failure);
        assert_eq!(parsed.url.as_ref().map(url::Url::as_str), Some(expected));
    }

    // An explicit JSON null is accepted and yields no URL.
    let parsed_null =
        serde_json::from_str::<OptSshUrl>(null_url).expect("failed to deserialize null url");
    assert!(parsed_null.url.is_none());
}
422
423impl From<structs::DefaultMergeStyle> for structs::MergePullRequestOptionDo {
424 fn from(value: structs::DefaultMergeStyle) -> Self {
425 match value {
426 structs::DefaultMergeStyle::Merge => structs::MergePullRequestOptionDo::Merge,
427 structs::DefaultMergeStyle::Rebase => structs::MergePullRequestOptionDo::Rebase,
428 structs::DefaultMergeStyle::RebaseMerge => {
429 structs::MergePullRequestOptionDo::RebaseMerge
430 }
431 structs::DefaultMergeStyle::Squash => structs::MergePullRequestOptionDo::Squash,
432 structs::DefaultMergeStyle::FastForwardOnly => {
433 structs::MergePullRequestOptionDo::FastForwardOnly
434 }
435 }
436 }
437}