flake_edit/
api.rs

1use std::collections::HashMap;
2use std::process::Command;
3
4use reqwest::blocking::Client;
5use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT};
6use semver::Version;
7use serde::Deserialize;
8use thiserror::Error;
9
10use crate::version::parse_ref;
11#[derive(Error, Debug)]
12pub enum ApiError {
13    #[error("HTTP request failed: {0}")]
14    HttpError(#[from] reqwest::Error),
15
16    #[error("JSON parsing failed: {0}")]
17    JsonError(#[from] serde_json::Error),
18
19    #[error("Invalid header value: {0}")]
20    InvalidHeader(#[from] reqwest::header::InvalidHeaderValue),
21
22    #[error("UTF-8 conversion failed: {0}")]
23    Utf8Error(#[from] std::string::FromUtf8Error),
24
25    #[error("Failed to execute command: {0}")]
26    CommandError(#[from] std::io::Error),
27
28    #[error("No tags found for repository")]
29    NoTagsFound,
30
31    #[error("Invalid domain or repository: {0}")]
32    InvalidInput(String),
33}
34
35#[derive(Default)]
36pub struct ForgeClient {
37    client: Client,
38}
39
40impl ForgeClient {
41    fn get(&self, url: &str, headers: HeaderMap) -> Result<String, ApiError> {
42        let response = self.client.get(url).headers(headers).send()?;
43        let text = response.text()?;
44        Ok(text)
45    }
46
47    /// Check if a URL returns a successful (2xx) response
48    fn head_ok(&self, url: &str, headers: HeaderMap) -> bool {
49        match self.client.get(url).headers(headers).send() {
50            Ok(response) => response.status().is_success(),
51            Err(_) => false,
52        }
53    }
54
55    fn make_headers() -> HeaderMap {
56        let mut headers = HeaderMap::new();
57        if let Ok(user_agent) = HeaderValue::from_str("flake-edit") {
58            headers.insert(USER_AGENT, user_agent);
59        }
60        headers
61    }
62
63    pub fn detect_forge_type(&self, domain: &str) -> ForgeType {
64        if domain == "github.com" {
65            return ForgeType::GitHub;
66        }
67
68        tracing::debug!("Attempting to detect forge type for domain: {}", domain);
69        let headers = Self::make_headers();
70
71        // Try both HTTPS and HTTP for each endpoint
72        for scheme in ["https", "http"] {
73            // Try Forgejo version endpoint first
74            let forgejo_url = format!("{}://{}/api/forgejo/v1/version", scheme, domain);
75            tracing::debug!("Trying Forgejo endpoint: {}", forgejo_url);
76            if let Ok(text) = self.get(&forgejo_url, headers.clone()) {
77                tracing::debug!("Forgejo endpoint response body: {}", text);
78                if let Some(forge_type) = parse_forge_version(&text) {
79                    tracing::info!("Detected Forgejo/Gitea at {}", domain);
80                    return forge_type;
81                }
82            }
83
84            // Try Gitea version endpoint
85            let gitea_url = format!("{}://{}/api/v1/version", scheme, domain);
86            tracing::debug!("Trying Gitea endpoint: {}", gitea_url);
87            if let Ok(text) = self.get(&gitea_url, headers.clone()) {
88                tracing::debug!("Gitea endpoint response body: {}", text);
89                if let Some(forge_type) = parse_forge_version(&text) {
90                    tracing::info!("Detected Forgejo/Gitea at {}", domain);
91                    return forge_type;
92                }
93                // Plain Gitea just has a version number without +gitea or +forgejo
94                if serde_json::from_str::<ForgeVersion>(&text).is_ok() {
95                    tracing::info!("Detected Gitea at {}", domain);
96                    return ForgeType::Gitea;
97                }
98            }
99        }
100
101        tracing::warn!(
102            "Could not detect forge type for {}, will try GitHub API as fallback",
103            domain
104        );
105        ForgeType::Unknown
106    }
107
108    pub fn query_github_tags(&self, repo: &str, owner: &str) -> Result<IntermediaryTags, ApiError> {
109        let mut headers = Self::make_headers();
110        if let Some(token) = get_forge_token("github.com") {
111            tracing::debug!("Found github token.");
112            headers.insert(
113                AUTHORIZATION,
114                HeaderValue::from_str(&format!("Bearer {token}"))?,
115            );
116        }
117
118        let body = self.get(
119            &format!("https://api.github.com/repos/{}/{}/tags", owner, repo),
120            headers,
121        )?;
122
123        tracing::debug!("Body from api: {body}");
124        let tags = serde_json::from_str::<IntermediaryTags>(&body)?;
125        Ok(tags)
126    }
127
128    /// Check if a specific branch exists (returns true/false, no error on 404)
129    pub fn branch_exists_github(&self, repo: &str, owner: &str, branch: &str) -> bool {
130        let mut headers = Self::make_headers();
131        if let Some(token) = get_forge_token("github.com") {
132            headers.insert(
133                AUTHORIZATION,
134                HeaderValue::from_str(&format!("Bearer {token}")).unwrap(),
135            );
136        }
137
138        let url = format!(
139            "https://api.github.com/repos/{}/{}/branches/{}",
140            owner, repo, branch
141        );
142
143        self.head_ok(&url, headers)
144    }
145
146    /// Check if a specific branch exists on Gitea/Forgejo
147    pub fn branch_exists_gitea(&self, repo: &str, owner: &str, domain: &str, branch: &str) -> bool {
148        let mut headers = Self::make_headers();
149        if let Some(token) = get_forge_token(domain) {
150            headers.insert(
151                AUTHORIZATION,
152                HeaderValue::from_str(&format!("Bearer {token}")).unwrap(),
153            );
154        }
155
156        for scheme in ["https", "http"] {
157            let url = format!(
158                "{}://{}/api/v1/repos/{}/{}/branches/{}",
159                scheme, domain, owner, repo, branch
160            );
161            if self.head_ok(&url, headers.clone()) {
162                return true;
163            }
164        }
165        false
166    }
167
168    pub fn query_gitea_tags(
169        &self,
170        repo: &str,
171        owner: &str,
172        domain: &str,
173    ) -> Result<IntermediaryTags, ApiError> {
174        let mut headers = Self::make_headers();
175
176        if let Some(token) = get_forge_token(domain) {
177            tracing::debug!("Found token for {}", domain);
178            headers.insert(
179                AUTHORIZATION,
180                HeaderValue::from_str(&format!("Bearer {token}"))?,
181            );
182        }
183
184        // Try HTTPS first, then HTTP
185        for scheme in ["https", "http"] {
186            let url = format!(
187                "{}://{}/api/v1/repos/{}/{}/tags",
188                scheme, domain, owner, repo
189            );
190            tracing::debug!("Trying Gitea tags endpoint: {}", url);
191
192            if let Ok(body) = self.get(&url, headers.clone()) {
193                tracing::debug!("Body from Gitea API: {body}");
194                if let Ok(tags) = serde_json::from_str::<IntermediaryTags>(&body) {
195                    return Ok(tags);
196                }
197            }
198        }
199
200        Err(ApiError::NoTagsFound)
201    }
202
203    pub fn query_github_branches(
204        &self,
205        repo: &str,
206        owner: &str,
207    ) -> Result<IntermediaryBranches, ApiError> {
208        let mut headers = Self::make_headers();
209        if let Some(token) = get_forge_token("github.com") {
210            tracing::debug!("Found github token.");
211            headers.insert(
212                AUTHORIZATION,
213                HeaderValue::from_str(&format!("Bearer {token}"))?,
214            );
215        }
216
217        let mut all_branches = Vec::new();
218        let mut page = 1;
219        const MAX_PAGES: u32 = 20; // Safety limit to avoid infinite loops
220
221        loop {
222            let url = format!(
223                "https://api.github.com/repos/{}/{}/branches?per_page=100&page={}",
224                owner, repo, page
225            );
226            tracing::debug!("Fetching branches page {}: {}", page, url);
227
228            let body = self.get(&url, headers.clone())?;
229            let page_branches = serde_json::from_str::<IntermediaryBranches>(&body)?;
230
231            let count = page_branches.0.len();
232            tracing::debug!("Got {} branches on page {}", count, page);
233
234            all_branches.extend(page_branches.0);
235
236            // Stop if we got fewer than 100 (last page) or hit max pages
237            if count < 100 || page >= MAX_PAGES {
238                break;
239            }
240
241            page += 1;
242        }
243
244        tracing::debug!("Total branches fetched: {}", all_branches.len());
245        Ok(IntermediaryBranches(all_branches))
246    }
247
248    pub fn query_gitea_branches(
249        &self,
250        repo: &str,
251        owner: &str,
252        domain: &str,
253    ) -> Result<IntermediaryBranches, ApiError> {
254        let mut headers = Self::make_headers();
255
256        if let Some(token) = get_forge_token(domain) {
257            tracing::debug!("Found token for {}", domain);
258            headers.insert(
259                AUTHORIZATION,
260                HeaderValue::from_str(&format!("Bearer {token}"))?,
261            );
262        }
263
264        let mut all_branches = Vec::new();
265        let mut page = 1;
266        const MAX_PAGES: u32 = 20;
267
268        // Try HTTPS first, then HTTP
269        for scheme in ["https", "http"] {
270            loop {
271                let url = format!(
272                    "{}://{}/api/v1/repos/{}/{}/branches?limit=50&page={}",
273                    scheme, domain, owner, repo, page
274                );
275                tracing::debug!("Trying Gitea branches endpoint: {}", url);
276
277                match self.get(&url, headers.clone()) {
278                    Ok(body) => {
279                        tracing::debug!("Body from Gitea API: {body}");
280                        match serde_json::from_str::<IntermediaryBranches>(&body) {
281                            Ok(page_branches) => {
282                                let count = page_branches.0.len();
283                                all_branches.extend(page_branches.0);
284
285                                if count < 50 || page >= MAX_PAGES {
286                                    return Ok(IntermediaryBranches(all_branches));
287                                }
288                                page += 1;
289                            }
290                            Err(_) => break, // Try next scheme
291                        }
292                    }
293                    Err(_) => break, // Try next scheme
294                }
295            }
296
297            if !all_branches.is_empty() {
298                return Ok(IntermediaryBranches(all_branches));
299            }
300            page = 1; // Reset for next scheme
301        }
302
303        Err(ApiError::InvalidInput("Could not fetch branches".into()))
304    }
305}
306
307#[derive(Deserialize, Debug)]
308pub struct IntermediaryTags(Vec<IntermediaryTag>);
309
310#[derive(Deserialize, Debug)]
311pub struct IntermediaryBranches(Vec<IntermediaryBranch>);
312
313#[derive(Deserialize, Debug)]
314pub struct IntermediaryBranch {
315    name: String,
316}
317
/// Branch names of a repository, as produced by `get_branches`.
#[derive(Default, Debug)]
pub struct Branches {
    /// Branch names, in the order they were listed by the forge API.
    pub names: Vec<String>,
}
322
323#[derive(Debug)]
324pub struct Tags {
325    versions: Vec<TagVersion>,
326}
327
328impl Tags {
329    pub fn get_latest_tag(&mut self) -> Option<String> {
330        self.sort();
331        self.versions.last().map(|tag| tag.original.clone())
332    }
333    pub fn sort(&mut self) {
334        self.versions
335            .sort_by(|a, b| a.version.cmp_precedence(&b.version));
336    }
337}
338
339#[derive(Deserialize, Debug)]
340pub struct IntermediaryTag {
341    name: String,
342}
343
344#[derive(Debug)]
345struct TagVersion {
346    version: Version,
347    original: String,
348}
349
350#[derive(Deserialize, Debug)]
351struct ForgeVersion {
352    version: String,
353}
354
/// Kind of forge software serving a host, as determined by probing its API.
///
/// Fieldless, so it is cheap to copy and usable as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ForgeType {
    /// github.com, recognised by its domain name.
    GitHub,
    /// Gitea or Forgejo; both are queried through the same `/api/v1` endpoints.
    Gitea,
    /// Detection failed; the helpers in this module then try the Gitea API.
    Unknown,
}
361
362fn parse_forge_version(json: &str) -> Option<ForgeType> {
363    serde_json::from_str::<ForgeVersion>(json)
364        .ok()
365        .and_then(|v| {
366            if v.version.contains("+forgejo") || v.version.contains("+gitea") {
367                Some(ForgeType::Gitea)
368            } else {
369                None
370            }
371        })
372}
373
374// Test helpers are always available but not documented
375#[doc(hidden)]
376pub mod test_helpers {
377    use super::*;
378
379    pub fn parse_forge_version_test(json: &str) -> Option<ForgeType> {
380        parse_forge_version(json)
381    }
382}
383
384pub fn get_tags(repo: &str, owner: &str, domain: Option<&str>) -> Result<Tags, ApiError> {
385    let domain = domain.unwrap_or("github.com");
386    let client = ForgeClient::default();
387    let forge_type = client.detect_forge_type(domain);
388
389    tracing::debug!("Detected forge type for {}: {:?}", domain, forge_type);
390
391    let tags = match forge_type {
392        ForgeType::GitHub => client.query_github_tags(repo, owner)?,
393        ForgeType::Gitea => client.query_gitea_tags(repo, owner, domain)?,
394        ForgeType::Unknown => {
395            tracing::warn!("Unknown forge type for {}, trying Gitea API", domain);
396            client.query_gitea_tags(repo, owner, domain)?
397        }
398    };
399
400    Ok(tags.into())
401}
402
403pub fn get_branches(repo: &str, owner: &str, domain: Option<&str>) -> Result<Branches, ApiError> {
404    let domain = domain.unwrap_or("github.com");
405    let client = ForgeClient::default();
406    let forge_type = client.detect_forge_type(domain);
407
408    tracing::debug!(
409        "Fetching branches for {}/{} on {} ({:?})",
410        owner,
411        repo,
412        domain,
413        forge_type
414    );
415
416    let branches = match forge_type {
417        ForgeType::GitHub => client.query_github_branches(repo, owner)?,
418        ForgeType::Gitea => client.query_gitea_branches(repo, owner, domain)?,
419        ForgeType::Unknown => {
420            tracing::warn!("Unknown forge type for {}, trying Gitea API", domain);
421            client.query_gitea_branches(repo, owner, domain)?
422        }
423    };
424
425    Ok(branches.into())
426}
427
428/// Check if a specific branch exists without listing all branches.
429/// Much more efficient for repos with many branches (like nixpkgs).
430pub fn branch_exists(repo: &str, owner: &str, branch: &str, domain: Option<&str>) -> bool {
431    let domain = domain.unwrap_or("github.com");
432    let client = ForgeClient::default();
433    let forge_type = client.detect_forge_type(domain);
434
435    match forge_type {
436        ForgeType::GitHub => client.branch_exists_github(repo, owner, branch),
437        ForgeType::Gitea => client.branch_exists_gitea(repo, owner, domain, branch),
438        ForgeType::Unknown => client.branch_exists_gitea(repo, owner, domain, branch),
439    }
440}
441
442/// Check multiple branches and return which ones exist.
443/// More efficient than get_branches for known candidate branches.
444pub fn filter_existing_branches(
445    repo: &str,
446    owner: &str,
447    candidates: &[String],
448    domain: Option<&str>,
449) -> Vec<String> {
450    candidates
451        .iter()
452        .filter(|branch| branch_exists(repo, owner, branch, domain))
453        .cloned()
454        .collect()
455}
456
457#[derive(Deserialize, Debug, Clone)]
458struct NixConfig {
459    #[serde(rename = "access-tokens")]
460    access_tokens: Option<AccessTokens>,
461}
462
463impl NixConfig {
464    fn forge_token(&self, domain: &str) -> Option<String> {
465        self.access_tokens.as_ref()?.value.get(domain).cloned()
466    }
467}
468
469#[derive(Deserialize, Debug, Clone)]
470struct AccessTokens {
471    value: HashMap<String, String>,
472}
473
474fn get_forge_token(domain: &str) -> Option<String> {
475    // Try to get token from nix config
476    if let Ok(output) = Command::new("nix")
477        .arg("config")
478        .arg("show")
479        .arg("--json")
480        .output()
481    {
482        if let Ok(stdout) = String::from_utf8(output.stdout) {
483            if let Ok(config) = serde_json::from_str::<NixConfig>(&stdout) {
484                if let Some(token) = config.forge_token(domain) {
485                    return Some(token);
486                }
487            }
488        }
489    }
490
491    // Fallback to environment variables
492    if let Ok(token) = std::env::var("GITEA_TOKEN") {
493        return Some(token);
494    }
495    if let Ok(token) = std::env::var("FORGEJO_TOKEN") {
496        return Some(token);
497    }
498    if domain == "github.com" {
499        if let Ok(token) = std::env::var("GITHUB_TOKEN") {
500            return Some(token);
501        }
502    }
503
504    None
505}
506
507impl From<IntermediaryTags> for Tags {
508    fn from(value: IntermediaryTags) -> Self {
509        let mut versions = vec![];
510        for itag in value.0 {
511            let parsed = parse_ref(&itag.name, false);
512            let normalized = parsed.normalized_for_semver;
513            match Version::parse(&normalized) {
514                Ok(semver) => {
515                    versions.push(TagVersion {
516                        version: semver,
517                        original: parsed.original_ref,
518                    });
519                }
520                Err(e) => {
521                    tracing::error!("Could not parse version {:?}", e);
522                }
523            }
524        }
525        Tags { versions }
526    }
527}
528
529impl From<IntermediaryBranches> for Branches {
530    fn from(value: IntermediaryBranches) -> Self {
531        Branches {
532            names: value.0.into_iter().map(|b| b.name).collect(),
533        }
534    }
535}