flake_edit/api.rs

1use std::collections::HashMap;
2use std::process::Command;
3
4use reqwest::blocking::Client;
5use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT};
6use semver::Version;
7use serde::Deserialize;
8use thiserror::Error;
9
10use crate::version::parse_ref;
11#[derive(Error, Debug)]
12pub enum ApiError {
13    #[error("HTTP request failed: {0}")]
14    HttpError(#[from] reqwest::Error),
15
16    #[error("JSON parsing failed: {0}")]
17    JsonError(#[from] serde_json::Error),
18
19    #[error("Invalid header value: {0}")]
20    InvalidHeader(#[from] reqwest::header::InvalidHeaderValue),
21
22    #[error("UTF-8 conversion failed: {0}")]
23    Utf8Error(#[from] std::string::FromUtf8Error),
24
25    #[error("Failed to execute command: {0}")]
26    CommandError(#[from] std::io::Error),
27
28    #[error("No tags found for repository")]
29    NoTagsFound,
30
31    #[error("Invalid domain or repository: {0}")]
32    InvalidInput(String),
33}
34
35#[derive(Default)]
36pub struct ForgeClient {
37    client: Client,
38}
39
40impl ForgeClient {
41    fn get(&self, url: &str, headers: HeaderMap) -> Result<String, ApiError> {
42        let response = self.client.get(url).headers(headers).send()?;
43        let text = response.text()?;
44        Ok(text)
45    }
46
47    /// Check if a URL returns a successful (2xx) response
48    fn head_ok(&self, url: &str, headers: HeaderMap) -> bool {
49        match self.client.get(url).headers(headers).send() {
50            Ok(response) => response.status().is_success(),
51            Err(_) => false,
52        }
53    }
54
55    fn base_headers() -> HeaderMap {
56        let mut headers = HeaderMap::new();
57        if let Ok(user_agent) = HeaderValue::from_str("flake-edit") {
58            headers.insert(USER_AGENT, user_agent);
59        }
60        headers
61    }
62
63    /// Create headers with optional Bearer token authentication for the given domain.
64    fn auth_headers(domain: &str) -> Result<HeaderMap, ApiError> {
65        let mut headers = Self::base_headers();
66        if let Some(token) = get_forge_token(domain) {
67            tracing::debug!("Found token for {}", domain);
68            headers.insert(
69                AUTHORIZATION,
70                HeaderValue::from_str(&format!("Bearer {token}"))?,
71            );
72        }
73        Ok(headers)
74    }
75
76    pub fn detect_forge_type(&self, domain: &str) -> ForgeType {
77        if domain == "github.com" {
78            return ForgeType::GitHub;
79        }
80
81        tracing::debug!("Attempting to detect forge type for domain: {}", domain);
82        let headers = Self::base_headers();
83
84        // Try both HTTPS and HTTP for each endpoint
85        for scheme in ["https", "http"] {
86            // Try Forgejo version endpoint first
87            let forgejo_url = format!("{}://{}/api/forgejo/v1/version", scheme, domain);
88            tracing::debug!("Trying Forgejo endpoint: {}", forgejo_url);
89            if let Ok(text) = self.get(&forgejo_url, headers.clone()) {
90                tracing::debug!("Forgejo endpoint response body: {}", text);
91                if let Some(forge_type) = parse_forge_version(&text) {
92                    tracing::info!("Detected Forgejo/Gitea at {}", domain);
93                    return forge_type;
94                }
95            }
96
97            // Try Gitea version endpoint
98            let gitea_url = format!("{}://{}/api/v1/version", scheme, domain);
99            tracing::debug!("Trying Gitea endpoint: {}", gitea_url);
100            if let Ok(text) = self.get(&gitea_url, headers.clone()) {
101                tracing::debug!("Gitea endpoint response body: {}", text);
102                if let Some(forge_type) = parse_forge_version(&text) {
103                    tracing::info!("Detected Forgejo/Gitea at {}", domain);
104                    return forge_type;
105                }
106                // Plain Gitea just has a version number without +gitea or +forgejo
107                if serde_json::from_str::<ForgeVersion>(&text).is_ok() {
108                    tracing::info!("Detected Gitea at {}", domain);
109                    return ForgeType::Gitea;
110                }
111            }
112        }
113
114        tracing::warn!(
115            "Could not detect forge type for {}, will try GitHub API as fallback",
116            domain
117        );
118        ForgeType::Unknown
119    }
120
121    pub fn query_github_tags(&self, repo: &str, owner: &str) -> Result<IntermediaryTags, ApiError> {
122        let headers = Self::auth_headers("github.com")?;
123        let body = self.get(
124            &format!("https://api.github.com/repos/{}/{}/tags", owner, repo),
125            headers,
126        )?;
127
128        tracing::debug!("Body from api: {body}");
129        let tags = serde_json::from_str::<IntermediaryTags>(&body)?;
130        Ok(tags)
131    }
132
133    /// Check if a specific branch exists (returns true/false, no error on 404)
134    pub fn branch_exists_github(&self, repo: &str, owner: &str, branch: &str) -> bool {
135        let headers = Self::auth_headers("github.com").unwrap_or_else(|_| Self::base_headers());
136        let url = format!(
137            "https://api.github.com/repos/{}/{}/branches/{}",
138            owner, repo, branch
139        );
140
141        self.head_ok(&url, headers)
142    }
143
144    /// Check if a specific branch exists on Gitea/Forgejo
145    pub fn branch_exists_gitea(&self, repo: &str, owner: &str, domain: &str, branch: &str) -> bool {
146        let headers = Self::auth_headers(domain).unwrap_or_else(|_| Self::base_headers());
147        for scheme in ["https", "http"] {
148            let url = format!(
149                "{}://{}/api/v1/repos/{}/{}/branches/{}",
150                scheme, domain, owner, repo, branch
151            );
152            if self.head_ok(&url, headers.clone()) {
153                return true;
154            }
155        }
156        false
157    }
158
159    pub fn query_gitea_tags(
160        &self,
161        repo: &str,
162        owner: &str,
163        domain: &str,
164    ) -> Result<IntermediaryTags, ApiError> {
165        let headers = Self::auth_headers(domain)?;
166
167        // Try HTTPS first, then HTTP
168        for scheme in ["https", "http"] {
169            let url = format!(
170                "{}://{}/api/v1/repos/{}/{}/tags",
171                scheme, domain, owner, repo
172            );
173            tracing::debug!("Trying Gitea tags endpoint: {}", url);
174
175            if let Ok(body) = self.get(&url, headers.clone()) {
176                tracing::debug!("Body from Gitea API: {body}");
177                if let Ok(tags) = serde_json::from_str::<IntermediaryTags>(&body) {
178                    return Ok(tags);
179                }
180            }
181        }
182
183        Err(ApiError::NoTagsFound)
184    }
185
186    pub fn query_github_branches(
187        &self,
188        repo: &str,
189        owner: &str,
190    ) -> Result<IntermediaryBranches, ApiError> {
191        let headers = Self::auth_headers("github.com")?;
192
193        let mut all_branches = Vec::new();
194        let mut page = 1;
195        const MAX_PAGES: u32 = 20; // Safety limit to avoid infinite loops
196
197        loop {
198            let url = format!(
199                "https://api.github.com/repos/{}/{}/branches?per_page=100&page={}",
200                owner, repo, page
201            );
202            tracing::debug!("Fetching branches page {}: {}", page, url);
203
204            let body = self.get(&url, headers.clone())?;
205            let page_branches = serde_json::from_str::<IntermediaryBranches>(&body)?;
206
207            let count = page_branches.0.len();
208            tracing::debug!("Got {} branches on page {}", count, page);
209
210            all_branches.extend(page_branches.0);
211
212            // Stop if we got fewer than 100 (last page) or hit max pages
213            if count < 100 || page >= MAX_PAGES {
214                break;
215            }
216
217            page += 1;
218        }
219
220        tracing::debug!("Total branches fetched: {}", all_branches.len());
221        Ok(IntermediaryBranches(all_branches))
222    }
223
224    pub fn query_gitea_branches(
225        &self,
226        repo: &str,
227        owner: &str,
228        domain: &str,
229    ) -> Result<IntermediaryBranches, ApiError> {
230        let headers = Self::auth_headers(domain)?;
231
232        let mut all_branches = Vec::new();
233        let mut page = 1;
234        const MAX_PAGES: u32 = 20;
235
236        // Try HTTPS first, then HTTP
237        for scheme in ["https", "http"] {
238            loop {
239                let url = format!(
240                    "{}://{}/api/v1/repos/{}/{}/branches?limit=50&page={}",
241                    scheme, domain, owner, repo, page
242                );
243                tracing::debug!("Trying Gitea branches endpoint: {}", url);
244
245                match self.get(&url, headers.clone()) {
246                    Ok(body) => {
247                        tracing::debug!("Body from Gitea API: {body}");
248                        match serde_json::from_str::<IntermediaryBranches>(&body) {
249                            Ok(page_branches) => {
250                                let count = page_branches.0.len();
251                                all_branches.extend(page_branches.0);
252
253                                if count < 50 || page >= MAX_PAGES {
254                                    return Ok(IntermediaryBranches(all_branches));
255                                }
256                                page += 1;
257                            }
258                            Err(_) => break, // Try next scheme
259                        }
260                    }
261                    Err(_) => break, // Try next scheme
262                }
263            }
264
265            if !all_branches.is_empty() {
266                return Ok(IntermediaryBranches(all_branches));
267            }
268            page = 1; // Reset for next scheme
269        }
270
271        Err(ApiError::InvalidInput("Could not fetch branches".into()))
272    }
273}
274
275#[derive(Deserialize, Debug)]
276pub struct IntermediaryTags(Vec<IntermediaryTag>);
277
278#[derive(Deserialize, Debug)]
279pub struct IntermediaryBranches(Vec<IntermediaryBranch>);
280
281#[derive(Deserialize, Debug)]
282pub struct IntermediaryBranch {
283    name: String,
284}
285
286#[derive(Debug, Default)]
287pub struct Branches {
288    pub names: Vec<String>,
289}
290
291#[derive(Debug)]
292pub struct Tags {
293    versions: Vec<TagVersion>,
294}
295
296impl Tags {
297    pub fn get_latest_tag(&mut self) -> Option<String> {
298        self.sort();
299        self.versions.last().map(|tag| tag.original.clone())
300    }
301    pub fn sort(&mut self) {
302        self.versions
303            .sort_by(|a, b| a.version.cmp_precedence(&b.version));
304    }
305}
306
307#[derive(Deserialize, Debug)]
308pub struct IntermediaryTag {
309    name: String,
310}
311
312#[derive(Debug)]
313struct TagVersion {
314    version: Version,
315    original: String,
316}
317
318#[derive(Deserialize, Debug)]
319struct ForgeVersion {
320    version: String,
321}
322
/// Kind of code forge serving a repository.
///
/// A fieldless enum, so it is `Copy` and usable as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ForgeType {
    /// github.com, queried through `api.github.com`.
    GitHub,
    /// Gitea or Forgejo; both are queried through the same `/api/v1` endpoints.
    Gitea,
    /// Detection failed.
    Unknown,
}
329
330fn parse_forge_version(json: &str) -> Option<ForgeType> {
331    serde_json::from_str::<ForgeVersion>(json)
332        .ok()
333        .and_then(|v| {
334            if v.version.contains("+forgejo") || v.version.contains("+gitea") {
335                Some(ForgeType::Gitea)
336            } else {
337                None
338            }
339        })
340}
341
342// Test helpers are always available but not documented
343#[doc(hidden)]
344pub mod test_helpers {
345    use super::*;
346
347    pub fn parse_forge_version_test(json: &str) -> Option<ForgeType> {
348        parse_forge_version(json)
349    }
350}
351
352pub fn get_tags(repo: &str, owner: &str, domain: Option<&str>) -> Result<Tags, ApiError> {
353    let domain = domain.unwrap_or("github.com");
354    let client = ForgeClient::default();
355    let forge_type = client.detect_forge_type(domain);
356
357    tracing::debug!("Detected forge type for {}: {:?}", domain, forge_type);
358
359    let tags = match forge_type {
360        ForgeType::GitHub => client.query_github_tags(repo, owner)?,
361        ForgeType::Gitea => client.query_gitea_tags(repo, owner, domain)?,
362        ForgeType::Unknown => {
363            tracing::warn!("Unknown forge type for {}, trying Gitea API", domain);
364            client.query_gitea_tags(repo, owner, domain)?
365        }
366    };
367
368    Ok(tags.into())
369}
370
371pub fn get_branches(repo: &str, owner: &str, domain: Option<&str>) -> Result<Branches, ApiError> {
372    let domain = domain.unwrap_or("github.com");
373    let client = ForgeClient::default();
374    let forge_type = client.detect_forge_type(domain);
375
376    tracing::debug!(
377        "Fetching branches for {}/{} on {} ({:?})",
378        owner,
379        repo,
380        domain,
381        forge_type
382    );
383
384    let branches = match forge_type {
385        ForgeType::GitHub => client.query_github_branches(repo, owner)?,
386        ForgeType::Gitea => client.query_gitea_branches(repo, owner, domain)?,
387        ForgeType::Unknown => {
388            tracing::warn!("Unknown forge type for {}, trying Gitea API", domain);
389            client.query_gitea_branches(repo, owner, domain)?
390        }
391    };
392
393    Ok(branches.into())
394}
395
396/// Check if a specific branch exists without listing all branches.
397/// Much more efficient for repos with many branches (like nixpkgs).
398pub fn branch_exists(repo: &str, owner: &str, branch: &str, domain: Option<&str>) -> bool {
399    let domain = domain.unwrap_or("github.com");
400    let client = ForgeClient::default();
401    let forge_type = client.detect_forge_type(domain);
402
403    match forge_type {
404        ForgeType::GitHub => client.branch_exists_github(repo, owner, branch),
405        ForgeType::Gitea => client.branch_exists_gitea(repo, owner, domain, branch),
406        ForgeType::Unknown => client.branch_exists_gitea(repo, owner, domain, branch),
407    }
408}
409
410/// Check multiple branches and return which ones exist.
411/// More efficient than get_branches for known candidate branches.
412pub fn filter_existing_branches(
413    repo: &str,
414    owner: &str,
415    candidates: &[String],
416    domain: Option<&str>,
417) -> Vec<String> {
418    candidates
419        .iter()
420        .filter(|branch| branch_exists(repo, owner, branch, domain))
421        .cloned()
422        .collect()
423}
424
425#[derive(Deserialize, Debug, Clone)]
426struct NixConfig {
427    #[serde(rename = "access-tokens")]
428    access_tokens: Option<AccessTokens>,
429}
430
431impl NixConfig {
432    fn forge_token(&self, domain: &str) -> Option<String> {
433        self.access_tokens.as_ref()?.value.get(domain).cloned()
434    }
435}
436
437#[derive(Deserialize, Debug, Clone)]
438struct AccessTokens {
439    value: HashMap<String, String>,
440}
441
442fn get_forge_token(domain: &str) -> Option<String> {
443    // Try to get token from nix config
444    if let Ok(output) = Command::new("nix")
445        .arg("config")
446        .arg("show")
447        .arg("--json")
448        .output()
449        && let Ok(stdout) = String::from_utf8(output.stdout)
450        && let Ok(config) = serde_json::from_str::<NixConfig>(&stdout)
451        && let Some(token) = config.forge_token(domain)
452    {
453        return Some(token);
454    }
455
456    // Fallback to environment variables
457    if let Ok(token) = std::env::var("GITEA_TOKEN") {
458        return Some(token);
459    }
460    if let Ok(token) = std::env::var("FORGEJO_TOKEN") {
461        return Some(token);
462    }
463    if domain == "github.com"
464        && let Ok(token) = std::env::var("GITHUB_TOKEN")
465    {
466        return Some(token);
467    }
468
469    None
470}
471
472impl From<IntermediaryTags> for Tags {
473    fn from(value: IntermediaryTags) -> Self {
474        let mut versions = vec![];
475        for itag in value.0 {
476            let parsed = parse_ref(&itag.name, false);
477            let normalized = parsed.normalized_for_semver;
478            match Version::parse(&normalized) {
479                Ok(semver) => {
480                    versions.push(TagVersion {
481                        version: semver,
482                        original: parsed.original_ref,
483                    });
484                }
485                Err(e) => {
486                    tracing::error!("Could not parse version {:?}", e);
487                }
488            }
489        }
490        Tags { versions }
491    }
492}
493
494impl From<IntermediaryBranches> for Branches {
495    fn from(value: IntermediaryBranches) -> Self {
496        Branches {
497            names: value.0.into_iter().map(|b| b.name).collect(),
498        }
499    }
500}