1use std::collections::HashMap;
2use std::process::Command;
3
4use semver::Version;
5use serde::Deserialize;
6use thiserror::Error;
7use ureq::Agent;
8
9use crate::version::parse_ref;
10
11#[derive(Error, Debug)]
12pub enum ApiError {
13 #[error("HTTP request failed: {0}")]
14 HttpError(#[from] ureq::Error),
15
16 #[error("JSON parsing failed: {0}")]
17 JsonError(#[from] serde_json::Error),
18
19 #[error("UTF-8 conversion failed: {0}")]
20 Utf8Error(#[from] std::string::FromUtf8Error),
21
22 #[error("IO error: {0}")]
23 IoError(#[from] std::io::Error),
24
25 #[error("No tags found for repository")]
26 NoTagsFound,
27
28 #[error("Invalid domain or repository: {0}")]
29 InvalidInput(String),
30}
31
/// Optional headers attached to every forge API request.
#[derive(Default, Clone)]
struct Headers {
    /// Value sent as the `User-Agent` header, when present.
    user_agent: Option<String>,
    /// Value sent as the `Authorization` header (e.g. `Bearer <token>`), when present.
    authorization: Option<String>,
}
38
39pub struct ForgeClient {
40 agent: Agent,
41}
42
43impl Default for ForgeClient {
44 fn default() -> Self {
45 Self {
46 agent: Agent::new_with_defaults(),
47 }
48 }
49}
50
51impl ForgeClient {
52 fn get(&self, url: &str, headers: &Headers) -> Result<String, ApiError> {
53 let mut request = self.agent.get(url);
54 if let Some(ref ua) = headers.user_agent {
55 request = request.header("User-Agent", ua);
56 }
57 if let Some(ref auth) = headers.authorization {
58 request = request.header("Authorization", auth);
59 }
60 let body = request.call()?.body_mut().read_to_string()?;
61 Ok(body)
62 }
63
64 fn head_ok(&self, url: &str, headers: &Headers) -> bool {
66 let mut request = self.agent.get(url);
67 if let Some(ref ua) = headers.user_agent {
68 request = request.header("User-Agent", ua);
69 }
70 if let Some(ref auth) = headers.authorization {
71 request = request.header("Authorization", auth);
72 }
73 request.call().is_ok()
74 }
75
76 fn base_headers() -> Headers {
77 Headers {
78 user_agent: Some("flake-edit".to_string()),
79 authorization: None,
80 }
81 }
82
83 fn auth_headers(domain: &str) -> Headers {
85 let mut headers = Self::base_headers();
86 if let Some(token) = get_forge_token(domain) {
87 tracing::debug!("Found token for {}", domain);
88 headers.authorization = Some(format!("Bearer {token}"));
89 }
90 headers
91 }
92
93 pub fn detect_forge_type(&self, domain: &str) -> ForgeType {
94 if domain == "github.com" {
95 return ForgeType::GitHub;
96 }
97
98 tracing::debug!("Attempting to detect forge type for domain: {}", domain);
99 let headers = Self::base_headers();
100
101 for scheme in ["https", "http"] {
103 let forgejo_url = format!("{}://{}/api/forgejo/v1/version", scheme, domain);
105 tracing::debug!("Trying Forgejo endpoint: {}", forgejo_url);
106 if let Ok(text) = self.get(&forgejo_url, &headers) {
107 tracing::debug!("Forgejo endpoint response body: {}", text);
108 if let Some(forge_type) = parse_forge_version(&text) {
109 tracing::info!("Detected Forgejo/Gitea at {}", domain);
110 return forge_type;
111 }
112 }
113
114 let gitea_url = format!("{}://{}/api/v1/version", scheme, domain);
116 tracing::debug!("Trying Gitea endpoint: {}", gitea_url);
117 if let Ok(text) = self.get(&gitea_url, &headers) {
118 tracing::debug!("Gitea endpoint response body: {}", text);
119 if let Some(forge_type) = parse_forge_version(&text) {
120 tracing::info!("Detected Forgejo/Gitea at {}", domain);
121 return forge_type;
122 }
123 if serde_json::from_str::<ForgeVersion>(&text).is_ok() {
125 tracing::info!("Detected Gitea at {}", domain);
126 return ForgeType::Gitea;
127 }
128 }
129 }
130
131 tracing::warn!(
132 "Could not detect forge type for {}, will try GitHub API as fallback",
133 domain
134 );
135 ForgeType::Unknown
136 }
137
138 pub fn query_github_tags(&self, repo: &str, owner: &str) -> Result<IntermediaryTags, ApiError> {
139 let headers = Self::auth_headers("github.com");
140 let body = self.get(
141 &format!("https://api.github.com/repos/{}/{}/tags", owner, repo),
142 &headers,
143 )?;
144
145 tracing::debug!("Body from api: {body}");
146 let tags = serde_json::from_str::<IntermediaryTags>(&body)?;
147 Ok(tags)
148 }
149
150 pub fn branch_exists_github(&self, repo: &str, owner: &str, branch: &str) -> bool {
152 let headers = Self::auth_headers("github.com");
153 let url = format!(
154 "https://api.github.com/repos/{}/{}/branches/{}",
155 owner, repo, branch
156 );
157
158 self.head_ok(&url, &headers)
159 }
160
161 pub fn branch_exists_gitea(&self, repo: &str, owner: &str, domain: &str, branch: &str) -> bool {
163 let headers = Self::auth_headers(domain);
164 for scheme in ["https", "http"] {
165 let url = format!(
166 "{}://{}/api/v1/repos/{}/{}/branches/{}",
167 scheme, domain, owner, repo, branch
168 );
169 if self.head_ok(&url, &headers) {
170 return true;
171 }
172 }
173 false
174 }
175
176 pub fn query_gitea_tags(
177 &self,
178 repo: &str,
179 owner: &str,
180 domain: &str,
181 ) -> Result<IntermediaryTags, ApiError> {
182 let headers = Self::auth_headers(domain);
183
184 for scheme in ["https", "http"] {
186 let url = format!(
187 "{}://{}/api/v1/repos/{}/{}/tags",
188 scheme, domain, owner, repo
189 );
190 tracing::debug!("Trying Gitea tags endpoint: {}", url);
191
192 if let Ok(body) = self.get(&url, &headers) {
193 tracing::debug!("Body from Gitea API: {body}");
194 if let Ok(tags) = serde_json::from_str::<IntermediaryTags>(&body) {
195 return Ok(tags);
196 }
197 }
198 }
199
200 Err(ApiError::NoTagsFound)
201 }
202
203 pub fn query_github_branches(
204 &self,
205 repo: &str,
206 owner: &str,
207 ) -> Result<IntermediaryBranches, ApiError> {
208 let headers = Self::auth_headers("github.com");
209
210 let mut all_branches = Vec::new();
211 let mut page = 1;
212 const MAX_PAGES: u32 = 20; loop {
215 let url = format!(
216 "https://api.github.com/repos/{}/{}/branches?per_page=100&page={}",
217 owner, repo, page
218 );
219 tracing::debug!("Fetching branches page {}: {}", page, url);
220
221 let body = self.get(&url, &headers)?;
222 let page_branches = serde_json::from_str::<IntermediaryBranches>(&body)?;
223
224 let count = page_branches.0.len();
225 tracing::debug!("Got {} branches on page {}", count, page);
226
227 all_branches.extend(page_branches.0);
228
229 if count < 100 || page >= MAX_PAGES {
231 break;
232 }
233
234 page += 1;
235 }
236
237 tracing::debug!("Total branches fetched: {}", all_branches.len());
238 Ok(IntermediaryBranches(all_branches))
239 }
240
241 pub fn query_gitea_branches(
242 &self,
243 repo: &str,
244 owner: &str,
245 domain: &str,
246 ) -> Result<IntermediaryBranches, ApiError> {
247 let headers = Self::auth_headers(domain);
248
249 let mut all_branches = Vec::new();
250 let mut page = 1;
251 const MAX_PAGES: u32 = 20;
252
253 for scheme in ["https", "http"] {
255 loop {
256 let url = format!(
257 "{}://{}/api/v1/repos/{}/{}/branches?limit=50&page={}",
258 scheme, domain, owner, repo, page
259 );
260 tracing::debug!("Trying Gitea branches endpoint: {}", url);
261
262 match self.get(&url, &headers) {
263 Ok(body) => {
264 tracing::debug!("Body from Gitea API: {body}");
265 match serde_json::from_str::<IntermediaryBranches>(&body) {
266 Ok(page_branches) => {
267 let count = page_branches.0.len();
268 all_branches.extend(page_branches.0);
269
270 if count < 50 || page >= MAX_PAGES {
271 return Ok(IntermediaryBranches(all_branches));
272 }
273 page += 1;
274 }
275 Err(_) => break, }
277 }
278 Err(_) => break, }
280 }
281
282 if !all_branches.is_empty() {
283 return Ok(IntermediaryBranches(all_branches));
284 }
285 page = 1; }
287
288 Err(ApiError::InvalidInput("Could not fetch branches".into()))
289 }
290}
291
292#[derive(Deserialize, Debug)]
293pub struct IntermediaryTags(Vec<IntermediaryTag>);
294
295#[derive(Deserialize, Debug)]
296pub struct IntermediaryBranches(Vec<IntermediaryBranch>);
297
298#[derive(Deserialize, Debug)]
299pub struct IntermediaryBranch {
300 name: String,
301}
302
303#[derive(Debug, Default)]
304pub struct Branches {
305 pub names: Vec<String>,
306}
307
308#[derive(Debug)]
309pub struct Tags {
310 versions: Vec<TagVersion>,
311}
312
313impl Tags {
314 pub fn get_latest_tag(&mut self) -> Option<String> {
315 self.sort();
316 self.versions.last().map(|tag| tag.original.clone())
317 }
318 pub fn sort(&mut self) {
319 self.versions
320 .sort_by(|a, b| a.version.cmp_precedence(&b.version));
321 }
322}
323
324#[derive(Deserialize, Debug)]
325pub struct IntermediaryTag {
326 name: String,
327}
328
329#[derive(Debug)]
330struct TagVersion {
331 version: Version,
332 original: String,
333}
334
335#[derive(Deserialize, Debug)]
336struct ForgeVersion {
337 version: String,
338}
339
/// Kind of code forge hosting a repository.
///
/// Fieldless, so it is cheap to copy and compare; `Clone`, `Copy` and `Eq`
/// are derived alongside `PartialEq` for that reason.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ForgeType {
    /// `github.com`, queried through `api.github.com`.
    GitHub,
    /// A Gitea-compatible forge; Forgejo instances are reported as this variant too.
    Gitea,
    /// Detection failed; the query functions fall back to the Gitea API.
    Unknown,
}
346
347fn parse_forge_version(json: &str) -> Option<ForgeType> {
348 serde_json::from_str::<ForgeVersion>(json)
349 .ok()
350 .and_then(|v| {
351 if v.version.contains("+forgejo") || v.version.contains("+gitea") {
352 Some(ForgeType::Gitea)
353 } else {
354 None
355 }
356 })
357}
358
359#[doc(hidden)]
361pub mod test_helpers {
362 use super::*;
363
364 pub fn parse_forge_version_test(json: &str) -> Option<ForgeType> {
365 parse_forge_version(json)
366 }
367}
368
369pub fn get_tags(repo: &str, owner: &str, domain: Option<&str>) -> Result<Tags, ApiError> {
370 let domain = domain.unwrap_or("github.com");
371 let client = ForgeClient::default();
372 let forge_type = client.detect_forge_type(domain);
373
374 tracing::debug!("Detected forge type for {}: {:?}", domain, forge_type);
375
376 let tags = match forge_type {
377 ForgeType::GitHub => client.query_github_tags(repo, owner)?,
378 ForgeType::Gitea => client.query_gitea_tags(repo, owner, domain)?,
379 ForgeType::Unknown => {
380 tracing::warn!("Unknown forge type for {}, trying Gitea API", domain);
381 client.query_gitea_tags(repo, owner, domain)?
382 }
383 };
384
385 Ok(tags.into())
386}
387
388pub fn get_branches(repo: &str, owner: &str, domain: Option<&str>) -> Result<Branches, ApiError> {
389 let domain = domain.unwrap_or("github.com");
390 let client = ForgeClient::default();
391 let forge_type = client.detect_forge_type(domain);
392
393 tracing::debug!(
394 "Fetching branches for {}/{} on {} ({:?})",
395 owner,
396 repo,
397 domain,
398 forge_type
399 );
400
401 let branches = match forge_type {
402 ForgeType::GitHub => client.query_github_branches(repo, owner)?,
403 ForgeType::Gitea => client.query_gitea_branches(repo, owner, domain)?,
404 ForgeType::Unknown => {
405 tracing::warn!("Unknown forge type for {}, trying Gitea API", domain);
406 client.query_gitea_branches(repo, owner, domain)?
407 }
408 };
409
410 Ok(branches.into())
411}
412
413pub fn branch_exists(repo: &str, owner: &str, branch: &str, domain: Option<&str>) -> bool {
416 let domain = domain.unwrap_or("github.com");
417 let client = ForgeClient::default();
418 let forge_type = client.detect_forge_type(domain);
419
420 match forge_type {
421 ForgeType::GitHub => client.branch_exists_github(repo, owner, branch),
422 ForgeType::Gitea => client.branch_exists_gitea(repo, owner, domain, branch),
423 ForgeType::Unknown => client.branch_exists_gitea(repo, owner, domain, branch),
424 }
425}
426
427pub fn filter_existing_branches(
430 repo: &str,
431 owner: &str,
432 candidates: &[String],
433 domain: Option<&str>,
434) -> Vec<String> {
435 candidates
436 .iter()
437 .filter(|branch| branch_exists(repo, owner, branch, domain))
438 .cloned()
439 .collect()
440}
441
442#[derive(Deserialize, Debug, Clone)]
443struct NixConfig {
444 #[serde(rename = "access-tokens")]
445 access_tokens: Option<AccessTokens>,
446}
447
448impl NixConfig {
449 fn forge_token(&self, domain: &str) -> Option<String> {
450 self.access_tokens.as_ref()?.value.get(domain).cloned()
451 }
452}
453
454#[derive(Deserialize, Debug, Clone)]
455struct AccessTokens {
456 value: HashMap<String, String>,
457}
458
459fn get_forge_token(domain: &str) -> Option<String> {
460 if let Ok(output) = Command::new("nix")
462 .arg("config")
463 .arg("show")
464 .arg("--json")
465 .output()
466 && let Ok(stdout) = String::from_utf8(output.stdout)
467 && let Ok(config) = serde_json::from_str::<NixConfig>(&stdout)
468 && let Some(token) = config.forge_token(domain)
469 {
470 return Some(token);
471 }
472
473 if let Ok(token) = std::env::var("GITEA_TOKEN") {
475 return Some(token);
476 }
477 if let Ok(token) = std::env::var("FORGEJO_TOKEN") {
478 return Some(token);
479 }
480 if domain == "github.com"
481 && let Ok(token) = std::env::var("GITHUB_TOKEN")
482 {
483 return Some(token);
484 }
485
486 None
487}
488
489impl From<IntermediaryTags> for Tags {
490 fn from(value: IntermediaryTags) -> Self {
491 let mut versions = vec![];
492 for itag in value.0 {
493 let parsed = parse_ref(&itag.name, false);
494 let normalized = parsed.normalized_for_semver;
495 match Version::parse(&normalized) {
496 Ok(semver) => {
497 versions.push(TagVersion {
498 version: semver,
499 original: parsed.original_ref,
500 });
501 }
502 Err(e) => {
503 tracing::error!("Could not parse version {:?}", e);
504 }
505 }
506 }
507 Tags { versions }
508 }
509}
510
511impl From<IntermediaryBranches> for Branches {
512 fn from(value: IntermediaryBranches) -> Self {
513 Branches {
514 names: value.0.into_iter().map(|b| b.name).collect(),
515 }
516 }
517}