// mars_agents/source/parse.rs

1use std::path::PathBuf;
2
3use crate::types::SourceUrl;
4
/// Syntactic category assigned to a raw source specifier by [`classify`].
///
/// The category is decided from the text of the input alone.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SourceFormat {
    /// Input that begins with `.`, `/`, or `~`.
    LocalPath,
    /// `owner/repo` form: exactly one `/`, no `:`, and no `.` in the owner part.
    GitHubShorthand,
    /// Input that begins with `https://` or `http://`.
    HttpsUrl,
    /// scp-style `user@host:path` form without a `://` scheme.
    SshUrl,
    /// Scheme-less input containing both `.` and `/` but no `:`, e.g. `github.com/org/repo`.
    BareDomain,
    /// Input that matched none of the other categories.
    Unknown,
}
15
16/// Structured result of parsing a CLI source specifier.
17#[derive(Debug, Clone, PartialEq, Eq)]
18pub struct ParsedSourceSpec {
19    pub format: SourceFormat,
20    pub raw: String,
21    pub url: Option<SourceUrl>,
22    pub path: Option<PathBuf>,
23    pub version: Option<String>,
24    pub name: String,
25}
26
27/// Errors raised while parsing a source specifier.
28#[derive(Debug, thiserror::Error, PartialEq, Eq)]
29pub enum ParseError {
30    #[error(
31        "cannot determine source type for {input:?} — expected a path, URL, or owner/repo shorthand"
32    )]
33    UnrecognizedFormat { input: String },
34
35    #[error("SSH URL {input:?} is missing the colon-separated path (expected git@host:owner/repo)")]
36    MalformedSshUrl { input: String },
37
38    #[error("cannot derive a name from {input:?}")]
39    CannotDeriveName { input: String },
40
41    #[error("URL {input:?} has no path component")]
42    EmptyUrlPath { input: String },
43}
44
45/// Normalized source kind produced by `normalize()`.
46#[derive(Debug, Clone, PartialEq, Eq)]
47pub enum NormalizedSource {
48    Git(SourceUrl),
49    Path(PathBuf),
50}
51
52/// Classify source input without mutating it.
53pub fn classify(input: &str) -> SourceFormat {
54    if input.starts_with('.') || input.starts_with('/') || input.starts_with('~') {
55        return SourceFormat::LocalPath;
56    }
57
58    if input.starts_with("https://") || input.starts_with("http://") {
59        return SourceFormat::HttpsUrl;
60    }
61
62    if !input.contains("://")
63        && let Some(at_pos) = input.find('@')
64        && let Some(colon_rel) = input[at_pos + 1..].find(':')
65    {
66        let colon_abs = at_pos + 1 + colon_rel;
67        if colon_abs + 1 < input.len() {
68            return SourceFormat::SshUrl;
69        }
70    }
71
72    let shorthand_base = strip_suffix_at(input);
73    let slash_count = shorthand_base.chars().filter(|&c| c == '/').count();
74    if slash_count == 1 && !shorthand_base.contains(':') {
75        let mut segments = shorthand_base.split('/');
76        let owner = segments.next().unwrap_or_default();
77        let repo = segments.next().unwrap_or_default();
78        if !owner.is_empty() && !repo.is_empty() && !owner.contains('.') {
79            return SourceFormat::GitHubShorthand;
80        }
81    }
82
83    if input.contains('.') && input.contains('/') && !input.contains(':') {
84        return SourceFormat::BareDomain;
85    }
86
87    SourceFormat::Unknown
88}
89
90/// Split optional `@version` suffix.
91///
92/// Version suffix extraction is format-aware:
93/// - Enabled for shorthand and HTTPS/bare-domain URLs.
94/// - Disabled for SSH URLs and local paths.
95pub fn split_version(input: &str, format: SourceFormat) -> (&str, Option<&str>) {
96    match format {
97        SourceFormat::LocalPath | SourceFormat::SshUrl | SourceFormat::Unknown => (input, None),
98        SourceFormat::GitHubShorthand | SourceFormat::HttpsUrl | SourceFormat::BareDomain => {
99            if let Some((base, suffix)) = input.rsplit_once('@') {
100                if suffix.is_empty() {
101                    return (input, None);
102                }
103                (base, Some(suffix))
104            } else {
105                (input, None)
106            }
107        }
108    }
109}
110
111/// Normalize input into canonical URL/path form.
112pub fn normalize(input: &str, format: SourceFormat) -> Result<NormalizedSource, ParseError> {
113    match format {
114        SourceFormat::LocalPath => Ok(NormalizedSource::Path(PathBuf::from(input))),
115        SourceFormat::GitHubShorthand => Ok(NormalizedSource::Git(SourceUrl::from(format!(
116            "https://github.com/{input}"
117        )))),
118        SourceFormat::HttpsUrl => {
119            let stripped = input.strip_suffix(".git").unwrap_or(input);
120            let stripped = stripped.trim_end_matches('/');
121            if !has_non_empty_path(stripped) {
122                return Err(ParseError::EmptyUrlPath {
123                    input: input.to_string(),
124                });
125            }
126            Ok(NormalizedSource::Git(SourceUrl::from(stripped.to_string())))
127        }
128        SourceFormat::SshUrl => {
129            let (user_host, path) =
130                input
131                    .split_once(':')
132                    .ok_or_else(|| ParseError::MalformedSshUrl {
133                        input: input.to_string(),
134                    })?;
135            let host = user_host
136                .split_once('@')
137                .map(|(_, host)| host)
138                .ok_or_else(|| ParseError::MalformedSshUrl {
139                    input: input.to_string(),
140                })?;
141            let path = path.trim_end_matches('/');
142            let path = path.trim_start_matches('/');
143            if host.is_empty() || path.is_empty() {
144                return Err(ParseError::MalformedSshUrl {
145                    input: input.to_string(),
146                });
147            }
148            Ok(NormalizedSource::Git(SourceUrl::from(input.to_string())))
149        }
150        SourceFormat::BareDomain => {
151            let stripped = input.strip_suffix(".git").unwrap_or(input);
152            let stripped = stripped.trim_end_matches('/');
153            if !has_non_empty_path(stripped) {
154                return Err(ParseError::EmptyUrlPath {
155                    input: input.to_string(),
156                });
157            }
158            Ok(NormalizedSource::Git(SourceUrl::from(format!(
159                "https://{stripped}"
160            ))))
161        }
162        SourceFormat::Unknown => Err(ParseError::UnrecognizedFormat {
163            input: input.to_string(),
164        }),
165    }
166}
167
/// Extract hostname from a URL-like git source string.
///
/// Removes scheme, user info, port, and path; returns only the host.
///
/// Surrounding whitespace is ignored. Two shapes are understood:
///
/// - scp-style `user@host:path` (no `://`): the text between `@` and the first
///   `:` is the host, provided the path is not empty.
/// - everything else: an optional `scheme://` and `userinfo@` are dropped, and
///   the host is whatever precedes the first `/` and the first `:`.
///
/// Returns `None` for blank input and whenever the host would be empty.
pub fn extract_hostname(input: &str) -> Option<String> {
    let trimmed = input.trim();
    if trimmed.is_empty() {
        return None;
    }

    // SSH shorthand: git@host:owner/repo(.git)
    if !trimmed.contains("://") {
        let scp_parts = trimmed.split_once(':').and_then(|(user_host, path)| {
            user_host.split_once('@').map(|(_, host)| (host, path))
        });
        match scp_parts {
            Some((host, path)) if !path.trim_matches('/').is_empty() => {
                // An empty host (`git@:path`) is reported as absent, matching
                // the general branch below, rather than as `Some("")`.
                return Some(host).filter(|h| !h.is_empty()).map(str::to_string);
            }
            _ => {}
        }
    }

    let after_scheme = trimmed
        .split_once("://")
        .map_or(trimmed, |(_, tail)| tail);

    // An '@' only delimits userinfo when no '/' precedes it; otherwise it is
    // part of the path (for example a `@version` suffix).
    let after_userinfo = match after_scheme.split_once('@') {
        Some((userinfo, tail)) if !userinfo.contains('/') => tail,
        _ => after_scheme,
    };

    // NOTE: bracketed IPv6 literals are not special-cased; the authority is
    // simply cut at its first ':'.
    after_userinfo
        .split('/')
        .next()
        .and_then(|authority| authority.split(':').next())
        .filter(|host| !host.is_empty())
        .map(str::to_string)
}
208
209/// Derive source display name from normalized source.
210pub fn derive_name(source: &NormalizedSource) -> Result<String, ParseError> {
211    match source {
212        NormalizedSource::Git(url) => {
213            let name = url
214                .rsplit('/')
215                .next()
216                .filter(|s| !s.is_empty())
217                .ok_or_else(|| ParseError::CannotDeriveName {
218                    input: url.to_string(),
219                })?;
220            Ok(name.strip_suffix(".git").unwrap_or(name).to_string())
221        }
222        NormalizedSource::Path(path) => {
223            let name = path
224                .file_name()
225                .and_then(|n| n.to_str())
226                .filter(|s| !s.is_empty())
227                .ok_or_else(|| ParseError::CannotDeriveName {
228                    input: path.display().to_string(),
229                })?;
230            Ok(name.to_string())
231        }
232    }
233}
234
235/// Parse a source specifier into a normalized structured value.
236pub fn parse(input: &str) -> Result<ParsedSourceSpec, ParseError> {
237    let format = classify(input);
238    if format == SourceFormat::Unknown {
239        return Err(ParseError::UnrecognizedFormat {
240            input: input.to_string(),
241        });
242    }
243
244    let (base, version) = split_version(input, format);
245    let normalized = normalize(base, format)?;
246    let name = derive_name(&normalized)?;
247
248    let (url, path) = match normalized {
249        NormalizedSource::Git(url) => (Some(url), None),
250        NormalizedSource::Path(path) => (None, Some(path)),
251    };
252
253    Ok(ParsedSourceSpec {
254        format,
255        raw: input.to_string(),
256        url,
257        path,
258        version: version.map(str::to_string),
259        name,
260    })
261}
262
/// Returns `input` without its final `@suffix`, if it has one.
///
/// The cut is made at the last `@`. When there is no `@`, or nothing follows
/// the last one, the input is returned unchanged.
fn strip_suffix_at(input: &str) -> &str {
    input
        .rfind('@')
        .filter(|&at_idx| at_idx + 1 < input.len())
        .map_or(input, |at_idx| &input[..at_idx])
}
269
/// Reports whether a URL-like string has a path that is more than slashes.
///
/// An optional `scheme://` prefix is skipped, then an optional `userinfo@`
/// prefix (only when no `/` precedes the `@`). Whatever follows the first `/`
/// of the remainder is the path; it counts as non-empty if anything is left
/// after trimming `/` from both ends.
fn has_non_empty_path(input: &str) -> bool {
    let after_scheme = input.split_once("://").map_or(input, |(_, tail)| tail);

    let after_userinfo = match after_scheme.split_once('@') {
        Some((userinfo, tail)) if !userinfo.contains('/') => tail,
        _ => after_scheme,
    };

    after_userinfo
        .split_once('/')
        .is_some_and(|(_, path)| !path.trim_matches('/').is_empty())
}
287
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// Wraps a URL literal in `NormalizedSource::Git`.
    fn git(url: &'static str) -> NormalizedSource {
        NormalizedSource::Git(SourceUrl::from(url))
    }

    /// Builds the owned form that `extract_hostname` returns on success.
    fn host(name: &str) -> Option<String> {
        Some(name.to_string())
    }

    /// One row of the end-to-end `parse` expectations.
    struct Expectation {
        input: &'static str,
        format: SourceFormat,
        url: Option<&'static str>,
        path: Option<&'static str>,
        version: Option<&'static str>,
        name: &'static str,
    }

    /// Row for an input that should resolve to a git URL.
    fn git_case(
        input: &'static str,
        format: SourceFormat,
        url: &'static str,
        version: Option<&'static str>,
        name: &'static str,
    ) -> Expectation {
        Expectation {
            input,
            format,
            url: Some(url),
            path: None,
            version,
            name,
        }
    }

    /// Row for an input that should resolve to a local path.
    fn path_case(input: &'static str, path: &'static str, name: &'static str) -> Expectation {
        Expectation {
            input,
            format: SourceFormat::LocalPath,
            url: None,
            path: Some(path),
            version: None,
            name,
        }
    }

    #[test]
    fn classify_detects_known_formats() {
        let table = [
            ("./local", SourceFormat::LocalPath),
            ("owner/repo", SourceFormat::GitHubShorthand),
            ("https://github.com/org/repo", SourceFormat::HttpsUrl),
            ("git@github.com:org/repo.git", SourceFormat::SshUrl),
            ("github.com/org/repo", SourceFormat::BareDomain),
            ("invalid", SourceFormat::Unknown),
        ];
        for (input, expected) in table {
            assert_eq!(classify(input), expected);
        }
    }

    #[test]
    fn split_version_only_for_supported_formats() {
        let table = [
            (
                "owner/repo@v1",
                SourceFormat::GitHubShorthand,
                ("owner/repo", Some("v1")),
            ),
            (
                "https://github.com/org/repo@v2",
                SourceFormat::HttpsUrl,
                ("https://github.com/org/repo", Some("v2")),
            ),
            (
                "github.com/org/repo@latest",
                SourceFormat::BareDomain,
                ("github.com/org/repo", Some("latest")),
            ),
            (
                "git@github.com:org/repo.git@v1",
                SourceFormat::SshUrl,
                ("git@github.com:org/repo.git@v1", None),
            ),
            (
                "./local@v1",
                SourceFormat::LocalPath,
                ("./local@v1", None),
            ),
        ];
        for (input, format, expected) in table {
            assert_eq!(split_version(input, format), expected);
        }
    }

    #[test]
    fn normalize_handles_all_git_formats() {
        let table = [
            (
                "owner/repo",
                SourceFormat::GitHubShorthand,
                "https://github.com/owner/repo",
            ),
            (
                "https://github.com/org/repo.git",
                SourceFormat::HttpsUrl,
                "https://github.com/org/repo",
            ),
            (
                "git@github.com:org/repo.git",
                SourceFormat::SshUrl,
                "git@github.com:org/repo.git",
            ),
            (
                "github.com/org/repo.git",
                SourceFormat::BareDomain,
                "https://github.com/org/repo",
            ),
        ];
        for (input, format, expected_url) in table {
            assert_eq!(normalize(input, format).unwrap(), git(expected_url));
        }
    }

    #[test]
    fn normalize_ssh_rejects_malformed() {
        let failure = normalize("git@github.com", SourceFormat::SshUrl).unwrap_err();
        assert!(matches!(failure, ParseError::MalformedSshUrl { .. }));
    }

    #[test]
    fn derive_name_from_git_and_path() {
        let table = [
            (git("https://github.com/org/repo"), "repo"),
            (git("git@github.com:org/repo.git"), "repo"),
            (
                NormalizedSource::Path(PathBuf::from("../my-agents")),
                "my-agents",
            ),
        ];
        for (source, expected) in &table {
            assert_eq!(derive_name(source).unwrap(), *expected);
        }
    }

    #[test]
    fn extract_hostname_handles_supported_formats() {
        let inputs = [
            "https://github.com/org/repo",
            "https://git@github.com:8443/org/repo",
            "git@github.com:org/repo.git",
            "github.com/org/repo",
        ];
        for input in inputs {
            assert_eq!(extract_hostname(input), host("github.com"));
        }
    }

    #[test]
    fn parse_matrix_examples() {
        let cases = [
            path_case("./my-agents", "./my-agents", "my-agents"),
            git_case(
                "haowjy/meridian-base",
                SourceFormat::GitHubShorthand,
                "https://github.com/haowjy/meridian-base",
                None,
                "meridian-base",
            ),
            git_case(
                "haowjy/meridian-base@v1.0",
                SourceFormat::GitHubShorthand,
                "https://github.com/haowjy/meridian-base",
                Some("v1.0"),
                "meridian-base",
            ),
            git_case(
                "https://github.com/org/repo.git",
                SourceFormat::HttpsUrl,
                "https://github.com/org/repo",
                None,
                "repo",
            ),
            git_case(
                "https://github.com/org/repo@v2",
                SourceFormat::HttpsUrl,
                "https://github.com/org/repo",
                Some("v2"),
                "repo",
            ),
            git_case(
                "git@github.com:org/repo.git",
                SourceFormat::SshUrl,
                "git@github.com:org/repo.git",
                None,
                "repo",
            ),
            // SSH inputs never have a version split off, so the suffix stays in
            // both the URL and the derived name.
            git_case(
                "git@github.com:org/repo.git@v1.0",
                SourceFormat::SshUrl,
                "git@github.com:org/repo.git@v1.0",
                None,
                "repo.git@v1.0",
            ),
            git_case(
                "github.com/haowjy/meridian-base",
                SourceFormat::BareDomain,
                "https://github.com/haowjy/meridian-base",
                None,
                "meridian-base",
            ),
            git_case(
                "github.com/haowjy/meridian-base@latest",
                SourceFormat::BareDomain,
                "https://github.com/haowjy/meridian-base",
                Some("latest"),
                "meridian-base",
            ),
        ];

        for case in cases {
            let parsed = parse(case.input).unwrap();
            assert_eq!(
                parsed.format, case.format,
                "format mismatch for {}",
                case.input
            );
            assert_eq!(
                parsed.url.as_deref(),
                case.url,
                "url mismatch for {}",
                case.input
            );
            assert_eq!(
                parsed.path.as_deref(),
                case.path.map(Path::new),
                "path mismatch for {}",
                case.input
            );
            assert_eq!(
                parsed.version.as_deref(),
                case.version,
                "version mismatch for {}",
                case.input
            );
            assert_eq!(parsed.name, case.name, "name mismatch for {}", case.input);
        }
    }

    #[test]
    fn parse_unknown_returns_error() {
        let failure = parse("source").unwrap_err();
        assert!(matches!(failure, ParseError::UnrecognizedFormat { .. }));
    }
}
516    #[test]
517    fn parse_unknown_returns_error() {
518        let err = parse("source").unwrap_err();
519        assert!(matches!(err, ParseError::UnrecognizedFormat { .. }));
520    }
521}