flake_edit/
api.rs

1use std::collections::HashMap;
2use std::process::Command;
3
4use semver::Version;
5use serde::Deserialize;
6use thiserror::Error;
7use ureq::Agent;
8
9use crate::version::parse_ref;
10
11#[derive(Error, Debug)]
12pub enum ApiError {
13    #[error("HTTP request failed: {0}")]
14    HttpError(#[from] ureq::Error),
15
16    #[error("JSON parsing failed: {0}")]
17    JsonError(#[from] serde_json::Error),
18
19    #[error("UTF-8 conversion failed: {0}")]
20    Utf8Error(#[from] std::string::FromUtf8Error),
21
22    #[error("IO error: {0}")]
23    IoError(#[from] std::io::Error),
24
25    #[error("No tags found for repository")]
26    NoTagsFound,
27
28    #[error("Invalid domain or repository: {0}")]
29    InvalidInput(String),
30}
31
/// Optional HTTP headers attached to forge API requests.
///
/// A `None` field means the corresponding header is not sent at all.
#[derive(Default, Clone)]
struct Headers {
    /// Value of the `User-Agent` header.
    user_agent: Option<String>,
    /// Value of the `Authorization` header (e.g. `Bearer <token>`).
    authorization: Option<String>,
}
38
39pub struct ForgeClient {
40    agent: Agent,
41}
42
43impl Default for ForgeClient {
44    fn default() -> Self {
45        Self {
46            agent: Agent::new_with_defaults(),
47        }
48    }
49}
50
51impl ForgeClient {
52    fn get(&self, url: &str, headers: &Headers) -> Result<String, ApiError> {
53        let mut request = self.agent.get(url);
54        if let Some(ref ua) = headers.user_agent {
55            request = request.header("User-Agent", ua);
56        }
57        if let Some(ref auth) = headers.authorization {
58            request = request.header("Authorization", auth);
59        }
60        let body = request.call()?.body_mut().read_to_string()?;
61        Ok(body)
62    }
63
64    /// Check if a URL returns a successful (2xx) response
65    fn head_ok(&self, url: &str, headers: &Headers) -> bool {
66        let mut request = self.agent.get(url);
67        if let Some(ref ua) = headers.user_agent {
68            request = request.header("User-Agent", ua);
69        }
70        if let Some(ref auth) = headers.authorization {
71            request = request.header("Authorization", auth);
72        }
73        request.call().is_ok()
74    }
75
76    fn base_headers() -> Headers {
77        Headers {
78            user_agent: Some("flake-edit".to_string()),
79            authorization: None,
80        }
81    }
82
83    /// Create headers with optional Bearer token authentication for the given domain.
84    fn auth_headers(domain: &str) -> Headers {
85        let mut headers = Self::base_headers();
86        if let Some(token) = get_forge_token(domain) {
87            tracing::debug!("Found token for {}", domain);
88            headers.authorization = Some(format!("Bearer {token}"));
89        }
90        headers
91    }
92
93    pub fn detect_forge_type(&self, domain: &str) -> ForgeType {
94        if domain == "github.com" {
95            return ForgeType::GitHub;
96        }
97
98        tracing::debug!("Attempting to detect forge type for domain: {}", domain);
99        let headers = Self::base_headers();
100
101        // Try both HTTPS and HTTP for each endpoint
102        for scheme in ["https", "http"] {
103            // Try Forgejo version endpoint first
104            let forgejo_url = format!("{}://{}/api/forgejo/v1/version", scheme, domain);
105            tracing::debug!("Trying Forgejo endpoint: {}", forgejo_url);
106            if let Ok(text) = self.get(&forgejo_url, &headers) {
107                tracing::debug!("Forgejo endpoint response body: {}", text);
108                if let Some(forge_type) = parse_forge_version(&text) {
109                    tracing::info!("Detected Forgejo/Gitea at {}", domain);
110                    return forge_type;
111                }
112            }
113
114            // Try Gitea version endpoint
115            let gitea_url = format!("{}://{}/api/v1/version", scheme, domain);
116            tracing::debug!("Trying Gitea endpoint: {}", gitea_url);
117            if let Ok(text) = self.get(&gitea_url, &headers) {
118                tracing::debug!("Gitea endpoint response body: {}", text);
119                if let Some(forge_type) = parse_forge_version(&text) {
120                    tracing::info!("Detected Forgejo/Gitea at {}", domain);
121                    return forge_type;
122                }
123                // Plain Gitea just has a version number without +gitea or +forgejo
124                if serde_json::from_str::<ForgeVersion>(&text).is_ok() {
125                    tracing::info!("Detected Gitea at {}", domain);
126                    return ForgeType::Gitea;
127                }
128            }
129        }
130
131        tracing::warn!(
132            "Could not detect forge type for {}, will try GitHub API as fallback",
133            domain
134        );
135        ForgeType::Unknown
136    }
137
138    pub fn query_github_tags(&self, repo: &str, owner: &str) -> Result<IntermediaryTags, ApiError> {
139        let headers = Self::auth_headers("github.com");
140        let body = self.get(
141            &format!("https://api.github.com/repos/{}/{}/tags", owner, repo),
142            &headers,
143        )?;
144
145        tracing::debug!("Body from api: {body}");
146        let tags = serde_json::from_str::<IntermediaryTags>(&body)?;
147        Ok(tags)
148    }
149
150    /// Check if a specific branch exists (returns true/false, no error on 404)
151    pub fn branch_exists_github(&self, repo: &str, owner: &str, branch: &str) -> bool {
152        let headers = Self::auth_headers("github.com");
153        let url = format!(
154            "https://api.github.com/repos/{}/{}/branches/{}",
155            owner, repo, branch
156        );
157
158        self.head_ok(&url, &headers)
159    }
160
161    /// Check if a specific branch exists on Gitea/Forgejo
162    pub fn branch_exists_gitea(&self, repo: &str, owner: &str, domain: &str, branch: &str) -> bool {
163        let headers = Self::auth_headers(domain);
164        for scheme in ["https", "http"] {
165            let url = format!(
166                "{}://{}/api/v1/repos/{}/{}/branches/{}",
167                scheme, domain, owner, repo, branch
168            );
169            if self.head_ok(&url, &headers) {
170                return true;
171            }
172        }
173        false
174    }
175
176    pub fn query_gitea_tags(
177        &self,
178        repo: &str,
179        owner: &str,
180        domain: &str,
181    ) -> Result<IntermediaryTags, ApiError> {
182        let headers = Self::auth_headers(domain);
183
184        // Try HTTPS first, then HTTP
185        for scheme in ["https", "http"] {
186            let url = format!(
187                "{}://{}/api/v1/repos/{}/{}/tags",
188                scheme, domain, owner, repo
189            );
190            tracing::debug!("Trying Gitea tags endpoint: {}", url);
191
192            if let Ok(body) = self.get(&url, &headers) {
193                tracing::debug!("Body from Gitea API: {body}");
194                if let Ok(tags) = serde_json::from_str::<IntermediaryTags>(&body) {
195                    return Ok(tags);
196                }
197            }
198        }
199
200        Err(ApiError::NoTagsFound)
201    }
202
203    pub fn query_github_branches(
204        &self,
205        repo: &str,
206        owner: &str,
207    ) -> Result<IntermediaryBranches, ApiError> {
208        let headers = Self::auth_headers("github.com");
209
210        let mut all_branches = Vec::new();
211        let mut page = 1;
212        const MAX_PAGES: u32 = 20; // Safety limit to avoid infinite loops
213
214        loop {
215            let url = format!(
216                "https://api.github.com/repos/{}/{}/branches?per_page=100&page={}",
217                owner, repo, page
218            );
219            tracing::debug!("Fetching branches page {}: {}", page, url);
220
221            let body = self.get(&url, &headers)?;
222            let page_branches = serde_json::from_str::<IntermediaryBranches>(&body)?;
223
224            let count = page_branches.0.len();
225            tracing::debug!("Got {} branches on page {}", count, page);
226
227            all_branches.extend(page_branches.0);
228
229            // Stop if we got fewer than 100 (last page) or hit max pages
230            if count < 100 || page >= MAX_PAGES {
231                break;
232            }
233
234            page += 1;
235        }
236
237        tracing::debug!("Total branches fetched: {}", all_branches.len());
238        Ok(IntermediaryBranches(all_branches))
239    }
240
241    pub fn query_gitea_branches(
242        &self,
243        repo: &str,
244        owner: &str,
245        domain: &str,
246    ) -> Result<IntermediaryBranches, ApiError> {
247        let headers = Self::auth_headers(domain);
248
249        let mut all_branches = Vec::new();
250        let mut page = 1;
251        const MAX_PAGES: u32 = 20;
252
253        // Try HTTPS first, then HTTP
254        for scheme in ["https", "http"] {
255            loop {
256                let url = format!(
257                    "{}://{}/api/v1/repos/{}/{}/branches?limit=50&page={}",
258                    scheme, domain, owner, repo, page
259                );
260                tracing::debug!("Trying Gitea branches endpoint: {}", url);
261
262                match self.get(&url, &headers) {
263                    Ok(body) => {
264                        tracing::debug!("Body from Gitea API: {body}");
265                        match serde_json::from_str::<IntermediaryBranches>(&body) {
266                            Ok(page_branches) => {
267                                let count = page_branches.0.len();
268                                all_branches.extend(page_branches.0);
269
270                                if count < 50 || page >= MAX_PAGES {
271                                    return Ok(IntermediaryBranches(all_branches));
272                                }
273                                page += 1;
274                            }
275                            Err(_) => break, // Try next scheme
276                        }
277                    }
278                    Err(_) => break, // Try next scheme
279                }
280            }
281
282            if !all_branches.is_empty() {
283                return Ok(IntermediaryBranches(all_branches));
284            }
285            page = 1; // Reset for next scheme
286        }
287
288        Err(ApiError::InvalidInput("Could not fetch branches".into()))
289    }
290}
291
292#[derive(Deserialize, Debug)]
293pub struct IntermediaryTags(Vec<IntermediaryTag>);
294
295#[derive(Deserialize, Debug)]
296pub struct IntermediaryBranches(Vec<IntermediaryBranch>);
297
298#[derive(Deserialize, Debug)]
299pub struct IntermediaryBranch {
300    name: String,
301}
302
/// Branch names of a repository, in the order the forge API returned them.
#[derive(Default, Debug)]
pub struct Branches {
    /// The branch names.
    pub names: Vec<String>,
}
307
308#[derive(Debug)]
309pub struct Tags {
310    versions: Vec<TagVersion>,
311}
312
313impl Tags {
314    pub fn get_latest_tag(&mut self) -> Option<String> {
315        self.sort();
316        self.versions.last().map(|tag| tag.original.clone())
317    }
318    pub fn sort(&mut self) {
319        self.versions
320            .sort_by(|a, b| a.version.cmp_precedence(&b.version));
321    }
322}
323
324#[derive(Deserialize, Debug)]
325pub struct IntermediaryTag {
326    name: String,
327}
328
329#[derive(Debug)]
330struct TagVersion {
331    version: Version,
332    original: String,
333}
334
335#[derive(Deserialize, Debug)]
336struct ForgeVersion {
337    version: String,
338}
339
/// Kind of forge software serving a domain, as reported by `ForgeClient::detect_forge_type`.
///
/// A fieldless enum, so it is `Copy` and can be passed around and compared freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ForgeType {
    /// `github.com`, queried through the GitHub REST API.
    GitHub,
    /// Covers both Gitea and Forgejo, which are queried through the same `/api/v1` API.
    Gitea,
    /// Detection failed; callers in this module fall back to the Gitea API.
    Unknown,
}
346
347fn parse_forge_version(json: &str) -> Option<ForgeType> {
348    serde_json::from_str::<ForgeVersion>(json)
349        .ok()
350        .and_then(|v| {
351            if v.version.contains("+forgejo") || v.version.contains("+gitea") {
352                Some(ForgeType::Gitea)
353            } else {
354                None
355            }
356        })
357}
358
359// Test helpers are always available but not documented
360#[doc(hidden)]
361pub mod test_helpers {
362    use super::*;
363
364    pub fn parse_forge_version_test(json: &str) -> Option<ForgeType> {
365        parse_forge_version(json)
366    }
367}
368
369pub fn get_tags(repo: &str, owner: &str, domain: Option<&str>) -> Result<Tags, ApiError> {
370    let domain = domain.unwrap_or("github.com");
371    let client = ForgeClient::default();
372    let forge_type = client.detect_forge_type(domain);
373
374    tracing::debug!("Detected forge type for {}: {:?}", domain, forge_type);
375
376    let tags = match forge_type {
377        ForgeType::GitHub => client.query_github_tags(repo, owner)?,
378        ForgeType::Gitea => client.query_gitea_tags(repo, owner, domain)?,
379        ForgeType::Unknown => {
380            tracing::warn!("Unknown forge type for {}, trying Gitea API", domain);
381            client.query_gitea_tags(repo, owner, domain)?
382        }
383    };
384
385    Ok(tags.into())
386}
387
388pub fn get_branches(repo: &str, owner: &str, domain: Option<&str>) -> Result<Branches, ApiError> {
389    let domain = domain.unwrap_or("github.com");
390    let client = ForgeClient::default();
391    let forge_type = client.detect_forge_type(domain);
392
393    tracing::debug!(
394        "Fetching branches for {}/{} on {} ({:?})",
395        owner,
396        repo,
397        domain,
398        forge_type
399    );
400
401    let branches = match forge_type {
402        ForgeType::GitHub => client.query_github_branches(repo, owner)?,
403        ForgeType::Gitea => client.query_gitea_branches(repo, owner, domain)?,
404        ForgeType::Unknown => {
405            tracing::warn!("Unknown forge type for {}, trying Gitea API", domain);
406            client.query_gitea_branches(repo, owner, domain)?
407        }
408    };
409
410    Ok(branches.into())
411}
412
413/// Check if a specific branch exists without listing all branches.
414/// Much more efficient for repos with many branches (like nixpkgs).
415pub fn branch_exists(repo: &str, owner: &str, branch: &str, domain: Option<&str>) -> bool {
416    let domain = domain.unwrap_or("github.com");
417    let client = ForgeClient::default();
418    let forge_type = client.detect_forge_type(domain);
419
420    match forge_type {
421        ForgeType::GitHub => client.branch_exists_github(repo, owner, branch),
422        ForgeType::Gitea => client.branch_exists_gitea(repo, owner, domain, branch),
423        ForgeType::Unknown => client.branch_exists_gitea(repo, owner, domain, branch),
424    }
425}
426
427/// Check multiple branches and return which ones exist.
428/// More efficient than get_branches for known candidate branches.
429pub fn filter_existing_branches(
430    repo: &str,
431    owner: &str,
432    candidates: &[String],
433    domain: Option<&str>,
434) -> Vec<String> {
435    candidates
436        .iter()
437        .filter(|branch| branch_exists(repo, owner, branch, domain))
438        .cloned()
439        .collect()
440}
441
442#[derive(Deserialize, Debug, Clone)]
443struct NixConfig {
444    #[serde(rename = "access-tokens")]
445    access_tokens: Option<AccessTokens>,
446}
447
448impl NixConfig {
449    fn forge_token(&self, domain: &str) -> Option<String> {
450        self.access_tokens.as_ref()?.value.get(domain).cloned()
451    }
452}
453
454#[derive(Deserialize, Debug, Clone)]
455struct AccessTokens {
456    value: HashMap<String, String>,
457}
458
459fn get_forge_token(domain: &str) -> Option<String> {
460    // Try to get token from nix config
461    if let Ok(output) = Command::new("nix")
462        .arg("config")
463        .arg("show")
464        .arg("--json")
465        .output()
466        && let Ok(stdout) = String::from_utf8(output.stdout)
467        && let Ok(config) = serde_json::from_str::<NixConfig>(&stdout)
468        && let Some(token) = config.forge_token(domain)
469    {
470        return Some(token);
471    }
472
473    // Fallback to environment variables
474    if let Ok(token) = std::env::var("GITEA_TOKEN") {
475        return Some(token);
476    }
477    if let Ok(token) = std::env::var("FORGEJO_TOKEN") {
478        return Some(token);
479    }
480    if domain == "github.com"
481        && let Ok(token) = std::env::var("GITHUB_TOKEN")
482    {
483        return Some(token);
484    }
485
486    None
487}
488
489impl From<IntermediaryTags> for Tags {
490    fn from(value: IntermediaryTags) -> Self {
491        let mut versions = vec![];
492        for itag in value.0 {
493            let parsed = parse_ref(&itag.name, false);
494            let normalized = parsed.normalized_for_semver;
495            match Version::parse(&normalized) {
496                Ok(semver) => {
497                    versions.push(TagVersion {
498                        version: semver,
499                        original: parsed.original_ref,
500                    });
501                }
502                Err(e) => {
503                    tracing::error!("Could not parse version {:?}", e);
504                }
505            }
506        }
507        Tags { versions }
508    }
509}
510
511impl From<IntermediaryBranches> for Branches {
512    fn from(value: IntermediaryBranches) -> Self {
513        Branches {
514            names: value.0.into_iter().map(|b| b.name).collect(),
515        }
516    }
517}