1use std::collections::HashMap;
2use std::process::Command;
3
4use reqwest::blocking::Client;
5use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT};
6use semver::Version;
7use serde::Deserialize;
8use thiserror::Error;
9
10use crate::version::parse_ref;
/// Errors produced while querying a forge's API or reading local configuration.
#[derive(Error, Debug)]
pub enum ApiError {
    /// An error reported by the HTTP client (`reqwest`).
    #[error("HTTP request failed: {0}")]
    HttpError(#[from] reqwest::Error),

    /// A response body could not be deserialized from JSON.
    #[error("JSON parsing failed: {0}")]
    JsonError(#[from] serde_json::Error),

    /// A header value could not be built, e.g. an `Authorization` header from a
    /// token containing characters that are not allowed in header values.
    #[error("Invalid header value: {0}")]
    InvalidHeader(#[from] reqwest::header::InvalidHeaderValue),

    /// Bytes could not be converted into a UTF-8 `String`.
    #[error("UTF-8 conversion failed: {0}")]
    Utf8Error(#[from] std::string::FromUtf8Error),

    /// An I/O error, e.g. while running an external command.
    #[error("Failed to execute command: {0}")]
    CommandError(#[from] std::io::Error),

    /// No endpoint returned a usable tag list for the repository.
    #[error("No tags found for repository")]
    NoTagsFound,

    /// Catch-all for unusable input or a failed lookup; carries a description
    /// (e.g. "Could not fetch branches").
    #[error("Invalid domain or repository: {0}")]
    InvalidInput(String),
}
34
/// Thin wrapper around a blocking `reqwest` client used to talk to forge APIs
/// (GitHub and Gitea/Forgejo).
#[derive(Default)]
pub struct ForgeClient {
    client: Client,
}
39
40impl ForgeClient {
41 fn get(&self, url: &str, headers: HeaderMap) -> Result<String, ApiError> {
42 let response = self.client.get(url).headers(headers).send()?;
43 let text = response.text()?;
44 Ok(text)
45 }
46
47 fn head_ok(&self, url: &str, headers: HeaderMap) -> bool {
49 match self.client.get(url).headers(headers).send() {
50 Ok(response) => response.status().is_success(),
51 Err(_) => false,
52 }
53 }
54
55 fn base_headers() -> HeaderMap {
56 let mut headers = HeaderMap::new();
57 if let Ok(user_agent) = HeaderValue::from_str("flake-edit") {
58 headers.insert(USER_AGENT, user_agent);
59 }
60 headers
61 }
62
63 fn auth_headers(domain: &str) -> Result<HeaderMap, ApiError> {
65 let mut headers = Self::base_headers();
66 if let Some(token) = get_forge_token(domain) {
67 tracing::debug!("Found token for {}", domain);
68 headers.insert(
69 AUTHORIZATION,
70 HeaderValue::from_str(&format!("Bearer {token}"))?,
71 );
72 }
73 Ok(headers)
74 }
75
76 pub fn detect_forge_type(&self, domain: &str) -> ForgeType {
77 if domain == "github.com" {
78 return ForgeType::GitHub;
79 }
80
81 tracing::debug!("Attempting to detect forge type for domain: {}", domain);
82 let headers = Self::base_headers();
83
84 for scheme in ["https", "http"] {
86 let forgejo_url = format!("{}://{}/api/forgejo/v1/version", scheme, domain);
88 tracing::debug!("Trying Forgejo endpoint: {}", forgejo_url);
89 if let Ok(text) = self.get(&forgejo_url, headers.clone()) {
90 tracing::debug!("Forgejo endpoint response body: {}", text);
91 if let Some(forge_type) = parse_forge_version(&text) {
92 tracing::info!("Detected Forgejo/Gitea at {}", domain);
93 return forge_type;
94 }
95 }
96
97 let gitea_url = format!("{}://{}/api/v1/version", scheme, domain);
99 tracing::debug!("Trying Gitea endpoint: {}", gitea_url);
100 if let Ok(text) = self.get(&gitea_url, headers.clone()) {
101 tracing::debug!("Gitea endpoint response body: {}", text);
102 if let Some(forge_type) = parse_forge_version(&text) {
103 tracing::info!("Detected Forgejo/Gitea at {}", domain);
104 return forge_type;
105 }
106 if serde_json::from_str::<ForgeVersion>(&text).is_ok() {
108 tracing::info!("Detected Gitea at {}", domain);
109 return ForgeType::Gitea;
110 }
111 }
112 }
113
114 tracing::warn!(
115 "Could not detect forge type for {}, will try GitHub API as fallback",
116 domain
117 );
118 ForgeType::Unknown
119 }
120
121 pub fn query_github_tags(&self, repo: &str, owner: &str) -> Result<IntermediaryTags, ApiError> {
122 let headers = Self::auth_headers("github.com")?;
123 let body = self.get(
124 &format!("https://api.github.com/repos/{}/{}/tags", owner, repo),
125 headers,
126 )?;
127
128 tracing::debug!("Body from api: {body}");
129 let tags = serde_json::from_str::<IntermediaryTags>(&body)?;
130 Ok(tags)
131 }
132
133 pub fn branch_exists_github(&self, repo: &str, owner: &str, branch: &str) -> bool {
135 let headers = Self::auth_headers("github.com").unwrap_or_else(|_| Self::base_headers());
136 let url = format!(
137 "https://api.github.com/repos/{}/{}/branches/{}",
138 owner, repo, branch
139 );
140
141 self.head_ok(&url, headers)
142 }
143
144 pub fn branch_exists_gitea(&self, repo: &str, owner: &str, domain: &str, branch: &str) -> bool {
146 let headers = Self::auth_headers(domain).unwrap_or_else(|_| Self::base_headers());
147 for scheme in ["https", "http"] {
148 let url = format!(
149 "{}://{}/api/v1/repos/{}/{}/branches/{}",
150 scheme, domain, owner, repo, branch
151 );
152 if self.head_ok(&url, headers.clone()) {
153 return true;
154 }
155 }
156 false
157 }
158
159 pub fn query_gitea_tags(
160 &self,
161 repo: &str,
162 owner: &str,
163 domain: &str,
164 ) -> Result<IntermediaryTags, ApiError> {
165 let headers = Self::auth_headers(domain)?;
166
167 for scheme in ["https", "http"] {
169 let url = format!(
170 "{}://{}/api/v1/repos/{}/{}/tags",
171 scheme, domain, owner, repo
172 );
173 tracing::debug!("Trying Gitea tags endpoint: {}", url);
174
175 if let Ok(body) = self.get(&url, headers.clone()) {
176 tracing::debug!("Body from Gitea API: {body}");
177 if let Ok(tags) = serde_json::from_str::<IntermediaryTags>(&body) {
178 return Ok(tags);
179 }
180 }
181 }
182
183 Err(ApiError::NoTagsFound)
184 }
185
186 pub fn query_github_branches(
187 &self,
188 repo: &str,
189 owner: &str,
190 ) -> Result<IntermediaryBranches, ApiError> {
191 let headers = Self::auth_headers("github.com")?;
192
193 let mut all_branches = Vec::new();
194 let mut page = 1;
195 const MAX_PAGES: u32 = 20; loop {
198 let url = format!(
199 "https://api.github.com/repos/{}/{}/branches?per_page=100&page={}",
200 owner, repo, page
201 );
202 tracing::debug!("Fetching branches page {}: {}", page, url);
203
204 let body = self.get(&url, headers.clone())?;
205 let page_branches = serde_json::from_str::<IntermediaryBranches>(&body)?;
206
207 let count = page_branches.0.len();
208 tracing::debug!("Got {} branches on page {}", count, page);
209
210 all_branches.extend(page_branches.0);
211
212 if count < 100 || page >= MAX_PAGES {
214 break;
215 }
216
217 page += 1;
218 }
219
220 tracing::debug!("Total branches fetched: {}", all_branches.len());
221 Ok(IntermediaryBranches(all_branches))
222 }
223
224 pub fn query_gitea_branches(
225 &self,
226 repo: &str,
227 owner: &str,
228 domain: &str,
229 ) -> Result<IntermediaryBranches, ApiError> {
230 let headers = Self::auth_headers(domain)?;
231
232 let mut all_branches = Vec::new();
233 let mut page = 1;
234 const MAX_PAGES: u32 = 20;
235
236 for scheme in ["https", "http"] {
238 loop {
239 let url = format!(
240 "{}://{}/api/v1/repos/{}/{}/branches?limit=50&page={}",
241 scheme, domain, owner, repo, page
242 );
243 tracing::debug!("Trying Gitea branches endpoint: {}", url);
244
245 match self.get(&url, headers.clone()) {
246 Ok(body) => {
247 tracing::debug!("Body from Gitea API: {body}");
248 match serde_json::from_str::<IntermediaryBranches>(&body) {
249 Ok(page_branches) => {
250 let count = page_branches.0.len();
251 all_branches.extend(page_branches.0);
252
253 if count < 50 || page >= MAX_PAGES {
254 return Ok(IntermediaryBranches(all_branches));
255 }
256 page += 1;
257 }
258 Err(_) => break, }
260 }
261 Err(_) => break, }
263 }
264
265 if !all_branches.is_empty() {
266 return Ok(IntermediaryBranches(all_branches));
267 }
268 page = 1; }
270
271 Err(ApiError::InvalidInput("Could not fetch branches".into()))
272 }
273}
274
/// Raw tag list as returned by a forge's tags endpoint (a JSON array).
#[derive(Deserialize, Debug)]
pub struct IntermediaryTags(Vec<IntermediaryTag>);

/// Raw branch list as returned by a forge's branches endpoint (a JSON array).
#[derive(Deserialize, Debug)]
pub struct IntermediaryBranches(Vec<IntermediaryBranch>);

/// A single branch entry. Only the `name` field is read.
#[derive(Deserialize, Debug)]
pub struct IntermediaryBranch {
    name: String,
}

/// Branch names of a repository, in the order the API returned them.
#[derive(Debug, Default)]
pub struct Branches {
    pub names: Vec<String>,
}
290
291#[derive(Debug)]
292pub struct Tags {
293 versions: Vec<TagVersion>,
294}
295
296impl Tags {
297 pub fn get_latest_tag(&mut self) -> Option<String> {
298 self.sort();
299 self.versions.last().map(|tag| tag.original.clone())
300 }
301 pub fn sort(&mut self) {
302 self.versions
303 .sort_by(|a, b| a.version.cmp_precedence(&b.version));
304 }
305}
306
/// A single tag entry from a forge's tags endpoint. Only `name` is read.
#[derive(Deserialize, Debug)]
pub struct IntermediaryTag {
    name: String,
}

/// A tag that parsed as a semantic version.
#[derive(Debug)]
struct TagVersion {
    /// Version parsed from the normalized tag name; used for ordering.
    version: Version,
    /// The `original_ref` produced by `parse_ref`; this is what
    /// `Tags::get_latest_tag` returns.
    original: String,
}

/// Body of a Gitea/Forgejo `/version` endpoint: `{"version": "..."}`.
#[derive(Deserialize, Debug)]
struct ForgeVersion {
    version: String,
}

/// Kind of forge hosting a repository.
#[derive(Debug, PartialEq)]
pub enum ForgeType {
    /// github.com, queried via the GitHub REST API.
    GitHub,
    /// Gitea or Forgejo, queried via the `/api/v1` endpoints.
    Gitea,
    /// Detection failed. The public helpers in this module then try the Gitea-style API.
    Unknown,
}
329
330fn parse_forge_version(json: &str) -> Option<ForgeType> {
331 serde_json::from_str::<ForgeVersion>(json)
332 .ok()
333 .and_then(|v| {
334 if v.version.contains("+forgejo") || v.version.contains("+gitea") {
335 Some(ForgeType::Gitea)
336 } else {
337 None
338 }
339 })
340}
341
342#[doc(hidden)]
344pub mod test_helpers {
345 use super::*;
346
347 pub fn parse_forge_version_test(json: &str) -> Option<ForgeType> {
348 parse_forge_version(json)
349 }
350}
351
352pub fn get_tags(repo: &str, owner: &str, domain: Option<&str>) -> Result<Tags, ApiError> {
353 let domain = domain.unwrap_or("github.com");
354 let client = ForgeClient::default();
355 let forge_type = client.detect_forge_type(domain);
356
357 tracing::debug!("Detected forge type for {}: {:?}", domain, forge_type);
358
359 let tags = match forge_type {
360 ForgeType::GitHub => client.query_github_tags(repo, owner)?,
361 ForgeType::Gitea => client.query_gitea_tags(repo, owner, domain)?,
362 ForgeType::Unknown => {
363 tracing::warn!("Unknown forge type for {}, trying Gitea API", domain);
364 client.query_gitea_tags(repo, owner, domain)?
365 }
366 };
367
368 Ok(tags.into())
369}
370
371pub fn get_branches(repo: &str, owner: &str, domain: Option<&str>) -> Result<Branches, ApiError> {
372 let domain = domain.unwrap_or("github.com");
373 let client = ForgeClient::default();
374 let forge_type = client.detect_forge_type(domain);
375
376 tracing::debug!(
377 "Fetching branches for {}/{} on {} ({:?})",
378 owner,
379 repo,
380 domain,
381 forge_type
382 );
383
384 let branches = match forge_type {
385 ForgeType::GitHub => client.query_github_branches(repo, owner)?,
386 ForgeType::Gitea => client.query_gitea_branches(repo, owner, domain)?,
387 ForgeType::Unknown => {
388 tracing::warn!("Unknown forge type for {}, trying Gitea API", domain);
389 client.query_gitea_branches(repo, owner, domain)?
390 }
391 };
392
393 Ok(branches.into())
394}
395
396pub fn branch_exists(repo: &str, owner: &str, branch: &str, domain: Option<&str>) -> bool {
399 let domain = domain.unwrap_or("github.com");
400 let client = ForgeClient::default();
401 let forge_type = client.detect_forge_type(domain);
402
403 match forge_type {
404 ForgeType::GitHub => client.branch_exists_github(repo, owner, branch),
405 ForgeType::Gitea => client.branch_exists_gitea(repo, owner, domain, branch),
406 ForgeType::Unknown => client.branch_exists_gitea(repo, owner, domain, branch),
407 }
408}
409
410pub fn filter_existing_branches(
413 repo: &str,
414 owner: &str,
415 candidates: &[String],
416 domain: Option<&str>,
417) -> Vec<String> {
418 candidates
419 .iter()
420 .filter(|branch| branch_exists(repo, owner, branch, domain))
421 .cloned()
422 .collect()
423}
424
425#[derive(Deserialize, Debug, Clone)]
426struct NixConfig {
427 #[serde(rename = "access-tokens")]
428 access_tokens: Option<AccessTokens>,
429}
430
431impl NixConfig {
432 fn forge_token(&self, domain: &str) -> Option<String> {
433 self.access_tokens.as_ref()?.value.get(domain).cloned()
434 }
435}
436
437#[derive(Deserialize, Debug, Clone)]
438struct AccessTokens {
439 value: HashMap<String, String>,
440}
441
442fn get_forge_token(domain: &str) -> Option<String> {
443 if let Ok(output) = Command::new("nix")
445 .arg("config")
446 .arg("show")
447 .arg("--json")
448 .output()
449 && let Ok(stdout) = String::from_utf8(output.stdout)
450 && let Ok(config) = serde_json::from_str::<NixConfig>(&stdout)
451 && let Some(token) = config.forge_token(domain)
452 {
453 return Some(token);
454 }
455
456 if let Ok(token) = std::env::var("GITEA_TOKEN") {
458 return Some(token);
459 }
460 if let Ok(token) = std::env::var("FORGEJO_TOKEN") {
461 return Some(token);
462 }
463 if domain == "github.com"
464 && let Ok(token) = std::env::var("GITHUB_TOKEN")
465 {
466 return Some(token);
467 }
468
469 None
470}
471
472impl From<IntermediaryTags> for Tags {
473 fn from(value: IntermediaryTags) -> Self {
474 let mut versions = vec![];
475 for itag in value.0 {
476 let parsed = parse_ref(&itag.name, false);
477 let normalized = parsed.normalized_for_semver;
478 match Version::parse(&normalized) {
479 Ok(semver) => {
480 versions.push(TagVersion {
481 version: semver,
482 original: parsed.original_ref,
483 });
484 }
485 Err(e) => {
486 tracing::error!("Could not parse version {:?}", e);
487 }
488 }
489 }
490 Tags { versions }
491 }
492}
493
494impl From<IntermediaryBranches> for Branches {
495 fn from(value: IntermediaryBranches) -> Self {
496 Branches {
497 names: value.0.into_iter().map(|b| b.name).collect(),
498 }
499 }
500}