Skip to main content

sley_remote/
protocol.rs

1use std::env;
2use std::fmt;
3
4use sley_config::GitConfig;
5use sley_core::GitError;
6use sley_transport::{RemoteTransport, RemoteUrl};
7
/// Reasons a transport can be rejected by the protocol policy checks in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransportPolicyError {
    /// The active policy (`GIT_ALLOW_PROTOCOL`, `protocol.*.allow`, or the built-in
    /// defaults) refuses the transport named `scheme`.
    NotAllowed { scheme: String },
    /// The config entry `key` holds `value`, which is not `always`, `never`, or `user`.
    UnknownConfigValue { key: String, value: String },
}

impl fmt::Display for TransportPolicyError {
    /// Renders the bare message; `transport_policy_git_error` adds the `fatal: ` prefix.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NotAllowed { scheme } => write!(f, "transport '{scheme}' not allowed"),
            Self::UnknownConfigValue { key, value } => {
                write!(f, "unknown value for config '{key}': {value}")
            }
        }
    }
}

/// Makes the policy error usable with `?` into `Box<dyn std::error::Error>` and other
/// error-trait based plumbing. There is no underlying source error to expose.
impl std::error::Error for TransportPolicyError {}
24
/// Policy level resolved for one transport, as spelled in `protocol.*.allow`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ProtocolAllow {
    /// `never`: the transport is always refused.
    Never,
    /// `user`: allowed only when the operation was requested directly by the user.
    UserOnly,
    /// `always`: the transport is allowed unconditionally.
    Always,
}
31
32/// Classify the policy protocol name for a raw transport URL.
33///
34/// This mirrors the names `transport_check_allowed()` receives in upstream Git:
35/// local paths and `file://` are `file`; scp-like and SSH URL forms are `ssh`;
36/// `git+ssh://` and `ssh+git://` are deprecated aliases for `ssh`; `<name>::...`
37/// and unknown URL schemes use their helper name.
38pub fn transport_scheme_for_url(url: &str) -> String {
39    if let Some(helper_end) = helper_scheme_end(url) {
40        return url[..helper_end].to_ascii_lowercase();
41    }
42    if let Some(scheme_end) = url.find("://") {
43        let scheme = url[..scheme_end].to_ascii_lowercase();
44        return match scheme.as_str() {
45            "git+ssh" | "ssh+git" => "ssh".to_string(),
46            _ => scheme,
47        };
48    }
49    if scp_like_separator(url).is_some() {
50        "ssh".to_string()
51    } else {
52        "file".to_string()
53    }
54}
55
56pub fn transport_scheme_for_remote(remote: &RemoteUrl) -> &'static str {
57    match remote.transport {
58        RemoteTransport::Local | RemoteTransport::File => "file",
59        RemoteTransport::Ext => "ext",
60        RemoteTransport::Ssh => "ssh",
61        RemoteTransport::Git => "git",
62        RemoteTransport::Http => "http",
63        RemoteTransport::Https => "https",
64    }
65}
66
67pub fn check_transport_allowed(
68    scheme: &str,
69    config: Option<&GitConfig>,
70    from_user: Option<bool>,
71) -> std::result::Result<(), TransportPolicyError> {
72    if is_transport_allowed(scheme, config, from_user)? {
73        Ok(())
74    } else {
75        Err(TransportPolicyError::NotAllowed {
76            scheme: scheme.to_string(),
77        })
78    }
79}
80
81pub(crate) fn transport_policy_git_error(err: TransportPolicyError) -> GitError {
82    GitError::InvalidFormat(format!("fatal: {err}"))
83}
84
85pub fn is_transport_allowed(
86    scheme: &str,
87    config: Option<&GitConfig>,
88    from_user: Option<bool>,
89) -> std::result::Result<bool, TransportPolicyError> {
90    if let Ok(allow) = env::var("GIT_ALLOW_PROTOCOL") {
91        return Ok(allow.split(':').any(|entry| entry == scheme));
92    }
93    Ok(match protocol_config(scheme, config)? {
94        ProtocolAllow::Always => true,
95        ProtocolAllow::Never => false,
96        ProtocolAllow::UserOnly => from_user.unwrap_or_else(protocol_from_user),
97    })
98}
99
100fn protocol_config(
101    scheme: &str,
102    config: Option<&GitConfig>,
103) -> std::result::Result<ProtocolAllow, TransportPolicyError> {
104    if let Some(config) = config {
105        let key = format!("protocol.{scheme}.allow");
106        if let Some(value) = config.get("protocol", Some(scheme), "allow") {
107            return parse_protocol_config(&key, value);
108        }
109        if let Some(value) = config.get("protocol", None, "allow") {
110            return parse_protocol_config("protocol.allow", value);
111        }
112    }
113    Ok(match scheme {
114        "http" | "https" | "git" | "ssh" => ProtocolAllow::Always,
115        "ext" => ProtocolAllow::Never,
116        _ => ProtocolAllow::UserOnly,
117    })
118}
119
120fn parse_protocol_config(
121    key: &str,
122    value: &str,
123) -> std::result::Result<ProtocolAllow, TransportPolicyError> {
124    if value.eq_ignore_ascii_case("always") {
125        Ok(ProtocolAllow::Always)
126    } else if value.eq_ignore_ascii_case("never") {
127        Ok(ProtocolAllow::Never)
128    } else if value.eq_ignore_ascii_case("user") {
129        Ok(ProtocolAllow::UserOnly)
130    } else {
131        Err(TransportPolicyError::UnknownConfigValue {
132            key: key.to_string(),
133            value: value.to_string(),
134        })
135    }
136}
137
138fn protocol_from_user() -> bool {
139    env::var("GIT_PROTOCOL_FROM_USER")
140        .ok()
141        .and_then(|value| parse_git_bool(&value))
142        .unwrap_or(true)
143}
144
/// Parse a Git-style boolean.
///
/// - The empty string is `false`.
/// - `true`/`yes`/`on` and `false`/`no`/`off` are matched ASCII case-insensitively.
/// - Otherwise a value that parses as an `i64` is `true` when non-zero.
///
/// Anything else yields `None`.
fn parse_git_bool(value: &str) -> Option<bool> {
    const TRUE_WORDS: [&str; 3] = ["true", "yes", "on"];
    const FALSE_WORDS: [&str; 3] = ["false", "no", "off"];

    if value.is_empty() {
        Some(false)
    } else if TRUE_WORDS.iter().any(|word| value.eq_ignore_ascii_case(word)) {
        Some(true)
    } else if FALSE_WORDS.iter().any(|word| value.eq_ignore_ascii_case(word)) {
        Some(false)
    } else {
        value.parse::<i64>().map(|number| number != 0).ok()
    }
}
155
156fn helper_scheme_end(url: &str) -> Option<usize> {
157    let mut chars = url.char_indices();
158    let (_, first) = chars.next()?;
159    if !first.is_ascii_alphabetic() {
160        return None;
161    }
162    for (idx, ch) in chars {
163        if ch == ':' {
164            return url[idx..].starts_with("::").then_some(idx);
165        }
166        if !is_url_scheme_char(ch) {
167            return None;
168        }
169    }
170    None
171}
172
/// Characters permitted in a URL scheme or remote-helper name: ASCII letters and digits
/// plus `+`, `-`, and `.`.
///
/// Callers apply any extra restriction on the first character themselves.
fn is_url_scheme_char(ch: char) -> bool {
    matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '+' | '-' | '.')
}
176
/// Find the `:` that separates host from path in an scp-like address and return its byte
/// index. Accepted forms are `host:path`, `user@host:path`, and `[host]:path`.
///
/// Returns `None` when `value` should be treated as a local path instead:
/// - There is no `:`. For a value starting with `[`, this means no `]` immediately
///   followed by `:`.
/// - A `/` appears before the separator.
/// - It looks like a drive letter: `c:/...`, or any `c:...` on Windows.
fn scp_like_separator(value: &str) -> Option<usize> {
    let bytes = value.as_bytes();
    let colon = match value.strip_prefix('[') {
        Some(bracketed) => {
            // `]` sits at `close + 1` in `value`, so the separator must be at `close + 2`.
            let separator = bracketed.find(']')? + 2;
            if bytes.get(separator) != Some(&b':') {
                return None;
            }
            separator
        }
        None => value.find(':')?,
    };
    if value[..colon].contains('/') {
        return None;
    }
    // `colon == 1` guarantees at least two bytes, so indexing `bytes[0]` cannot panic.
    let drive_letter = colon == 1
        && bytes[0].is_ascii_alphabetic()
        && (cfg!(windows) || bytes.get(2) == Some(&b'/'));
    (!drive_letter).then_some(colon)
}
203
#[cfg(test)]
mod tests {
    use super::*;
    use sley_config::{ConfigEntry, ConfigSection};

    /// Build a `GitConfig` containing only `section`.
    fn config(section: ConfigSection) -> GitConfig {
        GitConfig {
            sections: vec![section],
            ..GitConfig::default()
        }
    }

    #[test]
    fn classifies_builtin_and_helper_url_forms() {
        let cases = [
            ("/repo.git", "file"),
            ("file:///repo.git", "file"),
            ("git://host/repo.git", "git"),
            ("ssh://host/repo.git", "ssh"),
            ("git+ssh://host/repo.git", "ssh"),
            ("user@host:repo.git", "ssh"),
            ("ext::fake-remote %S repo.git", "ext"),
            ("foo://host/repo.git", "foo"),
        ];
        for &(url, expected) in &cases {
            assert_eq!(transport_scheme_for_url(url), expected, "classifying {url:?}");
        }
    }

    #[test]
    fn user_policy_honors_from_user_env() {
        // NOTE: a `GIT_ALLOW_PROTOCOL` set in the test environment takes precedence over
        // config in `is_transport_allowed` and would change the outcome here.
        let allow_file_for_user = config(ConfigSection::new(
            "protocol",
            Some("file".into()),
            vec![ConfigEntry::new("allow", Some("user".into()))],
        ));
        let allowed = |from_user| {
            is_transport_allowed("file", Some(&allow_file_for_user), Some(from_user)).unwrap()
        };
        assert!(allowed(true));
        assert!(!allowed(false));
    }
}