1use std::collections::HashMap;
2use std::process::Command;
3
4use reqwest::blocking::Client;
5use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT};
6use semver::Version;
7use serde::Deserialize;
8use thiserror::Error;
9
10use crate::version::parse_ref;
/// Errors produced while querying forge APIs or reading local configuration.
#[derive(Error, Debug)]
pub enum ApiError {
    /// A `reqwest` error raised while sending a request or reading its response.
    #[error("HTTP request failed: {0}")]
    HttpError(#[from] reqwest::Error),

    /// A response body (or other JSON input) did not match the expected shape.
    #[error("JSON parsing failed: {0}")]
    JsonError(#[from] serde_json::Error),

    /// A value (e.g. an `Authorization` token) could not be encoded as an HTTP header.
    #[error("Invalid header value: {0}")]
    InvalidHeader(#[from] reqwest::header::InvalidHeaderValue),

    /// Bytes could not be decoded as UTF-8.
    #[error("UTF-8 conversion failed: {0}")]
    Utf8Error(#[from] std::string::FromUtf8Error),

    /// An I/O error; via `#[from]`, any `std::io::Error` converts into this variant.
    #[error("Failed to execute command: {0}")]
    CommandError(#[from] std::io::Error),

    /// No tag list could be obtained for the repository.
    #[error("No tags found for repository")]
    NoTagsFound,

    /// Invalid domain or repository. Also used with the message
    /// "Could not fetch branches" when no branch listing could be retrieved.
    #[error("Invalid domain or repository: {0}")]
    InvalidInput(String),
}
34
/// Thin wrapper around a blocking `reqwest` client used to talk to forge APIs
/// (GitHub, Gitea, Forgejo).
#[derive(Default)]
pub struct ForgeClient {
    client: Client,
}
39
40impl ForgeClient {
41 fn get(&self, url: &str, headers: HeaderMap) -> Result<String, ApiError> {
42 let response = self.client.get(url).headers(headers).send()?;
43 let text = response.text()?;
44 Ok(text)
45 }
46
47 fn head_ok(&self, url: &str, headers: HeaderMap) -> bool {
49 match self.client.get(url).headers(headers).send() {
50 Ok(response) => response.status().is_success(),
51 Err(_) => false,
52 }
53 }
54
55 fn make_headers() -> HeaderMap {
56 let mut headers = HeaderMap::new();
57 if let Ok(user_agent) = HeaderValue::from_str("flake-edit") {
58 headers.insert(USER_AGENT, user_agent);
59 }
60 headers
61 }
62
63 pub fn detect_forge_type(&self, domain: &str) -> ForgeType {
64 if domain == "github.com" {
65 return ForgeType::GitHub;
66 }
67
68 tracing::debug!("Attempting to detect forge type for domain: {}", domain);
69 let headers = Self::make_headers();
70
71 for scheme in ["https", "http"] {
73 let forgejo_url = format!("{}://{}/api/forgejo/v1/version", scheme, domain);
75 tracing::debug!("Trying Forgejo endpoint: {}", forgejo_url);
76 if let Ok(text) = self.get(&forgejo_url, headers.clone()) {
77 tracing::debug!("Forgejo endpoint response body: {}", text);
78 if let Some(forge_type) = parse_forge_version(&text) {
79 tracing::info!("Detected Forgejo/Gitea at {}", domain);
80 return forge_type;
81 }
82 }
83
84 let gitea_url = format!("{}://{}/api/v1/version", scheme, domain);
86 tracing::debug!("Trying Gitea endpoint: {}", gitea_url);
87 if let Ok(text) = self.get(&gitea_url, headers.clone()) {
88 tracing::debug!("Gitea endpoint response body: {}", text);
89 if let Some(forge_type) = parse_forge_version(&text) {
90 tracing::info!("Detected Forgejo/Gitea at {}", domain);
91 return forge_type;
92 }
93 if serde_json::from_str::<ForgeVersion>(&text).is_ok() {
95 tracing::info!("Detected Gitea at {}", domain);
96 return ForgeType::Gitea;
97 }
98 }
99 }
100
101 tracing::warn!(
102 "Could not detect forge type for {}, will try GitHub API as fallback",
103 domain
104 );
105 ForgeType::Unknown
106 }
107
108 pub fn query_github_tags(&self, repo: &str, owner: &str) -> Result<IntermediaryTags, ApiError> {
109 let mut headers = Self::make_headers();
110 if let Some(token) = get_forge_token("github.com") {
111 tracing::debug!("Found github token.");
112 headers.insert(
113 AUTHORIZATION,
114 HeaderValue::from_str(&format!("Bearer {token}"))?,
115 );
116 }
117
118 let body = self.get(
119 &format!("https://api.github.com/repos/{}/{}/tags", owner, repo),
120 headers,
121 )?;
122
123 tracing::debug!("Body from api: {body}");
124 let tags = serde_json::from_str::<IntermediaryTags>(&body)?;
125 Ok(tags)
126 }
127
128 pub fn branch_exists_github(&self, repo: &str, owner: &str, branch: &str) -> bool {
130 let mut headers = Self::make_headers();
131 if let Some(token) = get_forge_token("github.com") {
132 headers.insert(
133 AUTHORIZATION,
134 HeaderValue::from_str(&format!("Bearer {token}")).unwrap(),
135 );
136 }
137
138 let url = format!(
139 "https://api.github.com/repos/{}/{}/branches/{}",
140 owner, repo, branch
141 );
142
143 self.head_ok(&url, headers)
144 }
145
146 pub fn branch_exists_gitea(&self, repo: &str, owner: &str, domain: &str, branch: &str) -> bool {
148 let mut headers = Self::make_headers();
149 if let Some(token) = get_forge_token(domain) {
150 headers.insert(
151 AUTHORIZATION,
152 HeaderValue::from_str(&format!("Bearer {token}")).unwrap(),
153 );
154 }
155
156 for scheme in ["https", "http"] {
157 let url = format!(
158 "{}://{}/api/v1/repos/{}/{}/branches/{}",
159 scheme, domain, owner, repo, branch
160 );
161 if self.head_ok(&url, headers.clone()) {
162 return true;
163 }
164 }
165 false
166 }
167
168 pub fn query_gitea_tags(
169 &self,
170 repo: &str,
171 owner: &str,
172 domain: &str,
173 ) -> Result<IntermediaryTags, ApiError> {
174 let mut headers = Self::make_headers();
175
176 if let Some(token) = get_forge_token(domain) {
177 tracing::debug!("Found token for {}", domain);
178 headers.insert(
179 AUTHORIZATION,
180 HeaderValue::from_str(&format!("Bearer {token}"))?,
181 );
182 }
183
184 for scheme in ["https", "http"] {
186 let url = format!(
187 "{}://{}/api/v1/repos/{}/{}/tags",
188 scheme, domain, owner, repo
189 );
190 tracing::debug!("Trying Gitea tags endpoint: {}", url);
191
192 if let Ok(body) = self.get(&url, headers.clone()) {
193 tracing::debug!("Body from Gitea API: {body}");
194 if let Ok(tags) = serde_json::from_str::<IntermediaryTags>(&body) {
195 return Ok(tags);
196 }
197 }
198 }
199
200 Err(ApiError::NoTagsFound)
201 }
202
203 pub fn query_github_branches(
204 &self,
205 repo: &str,
206 owner: &str,
207 ) -> Result<IntermediaryBranches, ApiError> {
208 let mut headers = Self::make_headers();
209 if let Some(token) = get_forge_token("github.com") {
210 tracing::debug!("Found github token.");
211 headers.insert(
212 AUTHORIZATION,
213 HeaderValue::from_str(&format!("Bearer {token}"))?,
214 );
215 }
216
217 let mut all_branches = Vec::new();
218 let mut page = 1;
219 const MAX_PAGES: u32 = 20; loop {
222 let url = format!(
223 "https://api.github.com/repos/{}/{}/branches?per_page=100&page={}",
224 owner, repo, page
225 );
226 tracing::debug!("Fetching branches page {}: {}", page, url);
227
228 let body = self.get(&url, headers.clone())?;
229 let page_branches = serde_json::from_str::<IntermediaryBranches>(&body)?;
230
231 let count = page_branches.0.len();
232 tracing::debug!("Got {} branches on page {}", count, page);
233
234 all_branches.extend(page_branches.0);
235
236 if count < 100 || page >= MAX_PAGES {
238 break;
239 }
240
241 page += 1;
242 }
243
244 tracing::debug!("Total branches fetched: {}", all_branches.len());
245 Ok(IntermediaryBranches(all_branches))
246 }
247
248 pub fn query_gitea_branches(
249 &self,
250 repo: &str,
251 owner: &str,
252 domain: &str,
253 ) -> Result<IntermediaryBranches, ApiError> {
254 let mut headers = Self::make_headers();
255
256 if let Some(token) = get_forge_token(domain) {
257 tracing::debug!("Found token for {}", domain);
258 headers.insert(
259 AUTHORIZATION,
260 HeaderValue::from_str(&format!("Bearer {token}"))?,
261 );
262 }
263
264 let mut all_branches = Vec::new();
265 let mut page = 1;
266 const MAX_PAGES: u32 = 20;
267
268 for scheme in ["https", "http"] {
270 loop {
271 let url = format!(
272 "{}://{}/api/v1/repos/{}/{}/branches?limit=50&page={}",
273 scheme, domain, owner, repo, page
274 );
275 tracing::debug!("Trying Gitea branches endpoint: {}", url);
276
277 match self.get(&url, headers.clone()) {
278 Ok(body) => {
279 tracing::debug!("Body from Gitea API: {body}");
280 match serde_json::from_str::<IntermediaryBranches>(&body) {
281 Ok(page_branches) => {
282 let count = page_branches.0.len();
283 all_branches.extend(page_branches.0);
284
285 if count < 50 || page >= MAX_PAGES {
286 return Ok(IntermediaryBranches(all_branches));
287 }
288 page += 1;
289 }
290 Err(_) => break, }
292 }
293 Err(_) => break, }
295 }
296
297 if !all_branches.is_empty() {
298 return Ok(IntermediaryBranches(all_branches));
299 }
300 page = 1; }
302
303 Err(ApiError::InvalidInput("Could not fetch branches".into()))
304 }
305}
306
/// Raw tag list as deserialized from a forge's tags endpoint (a JSON array).
#[derive(Deserialize, Debug)]
pub struct IntermediaryTags(Vec<IntermediaryTag>);
309
/// Raw branch list as deserialized from a forge's branches endpoint (a JSON array).
#[derive(Deserialize, Debug)]
pub struct IntermediaryBranches(Vec<IntermediaryBranch>);
312
/// One entry of a branches listing. Only `name` is read; serde ignores any
/// other fields in the JSON object.
#[derive(Deserialize, Debug)]
pub struct IntermediaryBranch {
    name: String,
}
317
/// Branch names of a repository, in the order the forge API returned them.
#[derive(Debug, Default)]
pub struct Branches {
    pub names: Vec<String>,
}
322
/// Repository tags that could be parsed as semver versions.
#[derive(Debug)]
pub struct Tags {
    versions: Vec<TagVersion>,
}
327
328impl Tags {
329 pub fn get_latest_tag(&mut self) -> Option<String> {
330 self.sort();
331 self.versions.last().map(|tag| tag.original.clone())
332 }
333 pub fn sort(&mut self) {
334 self.versions
335 .sort_by(|a, b| a.version.cmp_precedence(&b.version));
336 }
337}
338
/// One entry of a tags listing. Only `name` is read; serde ignores any other
/// fields in the JSON object.
#[derive(Deserialize, Debug)]
pub struct IntermediaryTag {
    name: String,
}
343
/// A tag's parsed semver `version`, paired with the `original` ref text
/// (as reported by `parse_ref`) that is handed back to callers.
#[derive(Debug)]
struct TagVersion {
    version: Version,
    original: String,
}
349
/// Body returned by the Gitea `/api/v1/version` and Forgejo
/// `/api/forgejo/v1/version` endpoints, i.e. `{"version": "..."}`.
#[derive(Deserialize, Debug)]
struct ForgeVersion {
    version: String,
}
354
/// Kind of forge hosting a repository, as reported by `ForgeClient::detect_forge_type`.
#[derive(Debug, PartialEq)]
pub enum ForgeType {
    /// `github.com`, queried through `api.github.com`.
    GitHub,
    /// A Gitea or Forgejo instance, queried through its `/api/v1` endpoints.
    Gitea,
    /// Detection failed; the public helpers in this module fall back to the Gitea API.
    Unknown,
}
361
362fn parse_forge_version(json: &str) -> Option<ForgeType> {
363 serde_json::from_str::<ForgeVersion>(json)
364 .ok()
365 .and_then(|v| {
366 if v.version.contains("+forgejo") || v.version.contains("+gitea") {
367 Some(ForgeType::Gitea)
368 } else {
369 None
370 }
371 })
372}
373
374#[doc(hidden)]
376pub mod test_helpers {
377 use super::*;
378
379 pub fn parse_forge_version_test(json: &str) -> Option<ForgeType> {
380 parse_forge_version(json)
381 }
382}
383
384pub fn get_tags(repo: &str, owner: &str, domain: Option<&str>) -> Result<Tags, ApiError> {
385 let domain = domain.unwrap_or("github.com");
386 let client = ForgeClient::default();
387 let forge_type = client.detect_forge_type(domain);
388
389 tracing::debug!("Detected forge type for {}: {:?}", domain, forge_type);
390
391 let tags = match forge_type {
392 ForgeType::GitHub => client.query_github_tags(repo, owner)?,
393 ForgeType::Gitea => client.query_gitea_tags(repo, owner, domain)?,
394 ForgeType::Unknown => {
395 tracing::warn!("Unknown forge type for {}, trying Gitea API", domain);
396 client.query_gitea_tags(repo, owner, domain)?
397 }
398 };
399
400 Ok(tags.into())
401}
402
403pub fn get_branches(repo: &str, owner: &str, domain: Option<&str>) -> Result<Branches, ApiError> {
404 let domain = domain.unwrap_or("github.com");
405 let client = ForgeClient::default();
406 let forge_type = client.detect_forge_type(domain);
407
408 tracing::debug!(
409 "Fetching branches for {}/{} on {} ({:?})",
410 owner,
411 repo,
412 domain,
413 forge_type
414 );
415
416 let branches = match forge_type {
417 ForgeType::GitHub => client.query_github_branches(repo, owner)?,
418 ForgeType::Gitea => client.query_gitea_branches(repo, owner, domain)?,
419 ForgeType::Unknown => {
420 tracing::warn!("Unknown forge type for {}, trying Gitea API", domain);
421 client.query_gitea_branches(repo, owner, domain)?
422 }
423 };
424
425 Ok(branches.into())
426}
427
428pub fn branch_exists(repo: &str, owner: &str, branch: &str, domain: Option<&str>) -> bool {
431 let domain = domain.unwrap_or("github.com");
432 let client = ForgeClient::default();
433 let forge_type = client.detect_forge_type(domain);
434
435 match forge_type {
436 ForgeType::GitHub => client.branch_exists_github(repo, owner, branch),
437 ForgeType::Gitea => client.branch_exists_gitea(repo, owner, domain, branch),
438 ForgeType::Unknown => client.branch_exists_gitea(repo, owner, domain, branch),
439 }
440}
441
442pub fn filter_existing_branches(
445 repo: &str,
446 owner: &str,
447 candidates: &[String],
448 domain: Option<&str>,
449) -> Vec<String> {
450 candidates
451 .iter()
452 .filter(|branch| branch_exists(repo, owner, branch, domain))
453 .cloned()
454 .collect()
455}
456
/// Subset of `nix config show --json` output; only `access-tokens` is read.
#[derive(Deserialize, Debug, Clone)]
struct NixConfig {
    #[serde(rename = "access-tokens")]
    access_tokens: Option<AccessTokens>,
}
462
463impl NixConfig {
464 fn forge_token(&self, domain: &str) -> Option<String> {
465 self.access_tokens.as_ref()?.value.get(domain).cloned()
466 }
467}
468
/// The `access-tokens` Nix setting. `value` maps a host name
/// (e.g. `github.com`) to its token.
#[derive(Deserialize, Debug, Clone)]
struct AccessTokens {
    value: HashMap<String, String>,
}
473
474fn get_forge_token(domain: &str) -> Option<String> {
475 if let Ok(output) = Command::new("nix")
477 .arg("config")
478 .arg("show")
479 .arg("--json")
480 .output()
481 {
482 if let Ok(stdout) = String::from_utf8(output.stdout) {
483 if let Ok(config) = serde_json::from_str::<NixConfig>(&stdout) {
484 if let Some(token) = config.forge_token(domain) {
485 return Some(token);
486 }
487 }
488 }
489 }
490
491 if let Ok(token) = std::env::var("GITEA_TOKEN") {
493 return Some(token);
494 }
495 if let Ok(token) = std::env::var("FORGEJO_TOKEN") {
496 return Some(token);
497 }
498 if domain == "github.com" {
499 if let Ok(token) = std::env::var("GITHUB_TOKEN") {
500 return Some(token);
501 }
502 }
503
504 None
505}
506
507impl From<IntermediaryTags> for Tags {
508 fn from(value: IntermediaryTags) -> Self {
509 let mut versions = vec![];
510 for itag in value.0 {
511 let parsed = parse_ref(&itag.name, false);
512 let normalized = parsed.normalized_for_semver;
513 match Version::parse(&normalized) {
514 Ok(semver) => {
515 versions.push(TagVersion {
516 version: semver,
517 original: parsed.original_ref,
518 });
519 }
520 Err(e) => {
521 tracing::error!("Could not parse version {:?}", e);
522 }
523 }
524 }
525 Tags { versions }
526 }
527}
528
529impl From<IntermediaryBranches> for Branches {
530 fn from(value: IntermediaryBranches) -> Self {
531 Branches {
532 names: value.0.into_iter().map(|b| b.name).collect(),
533 }
534 }
535}