use thiserror::Error;
#[derive(Clone, Debug, Error, PartialEq, Eq)]
pub enum ProtocolError {
#[error("protocol '{protocol}' is not allowed by GIT_ALLOW_PROTOCOL")]
NotAllowedByEnvironment {
protocol: String,
},
#[error("protocol '{protocol}' is not allowed")]
NotAllowed {
protocol: String,
},
#[error("unknown protocol.allow value '{value}' for protocol '{protocol}'")]
UnknownAllowValue {
protocol: String,
value: String,
},
}
/// Raw, already-collected settings that drive [`check_protocol_allowed_with`].
///
/// Every field is optional; `Default` yields a value with nothing set, in
/// which case only the built-in per-protocol defaults apply.
#[derive(Debug, Clone, Default)]
pub struct ProtocolPolicyInputs {
    /// Explicit allow-list. When present it is the sole decider: entries are
    /// separated by `:` or `,` and surrounding whitespace is ignored.
    /// (Presumably the `GIT_ALLOW_PROTOCOL` environment variable — confirm
    /// against the caller.)
    pub git_allow_protocol: Option<String>,

    /// Raw boolean-like text interpreted by [`protocol_from_user`]; only
    /// consulted when the effective allow value is `user`.
    pub git_protocol_from_user: Option<String>,

    /// Allow value for this particular protocol; takes precedence over
    /// `blanket_allow`. (Presumably `protocol.<name>.allow` — confirm.)
    pub specific_allow: Option<String>,

    /// Allow value applying to protocols without a specific one.
    /// (Presumably `protocol.allow` — confirm.)
    pub blanket_allow: Option<String>,
}
/// Decides whether `protocol` may be used, given the collected policy `inputs`.
///
/// Precedence, highest first:
/// 1. `inputs.git_allow_protocol` — if set, the protocol must appear verbatim
///    (case-sensitively) among its `:`/`,`-separated entries; nothing else is
///    consulted.
/// 2. `inputs.specific_allow`, then `inputs.blanket_allow` — evaluated as an
///    allow value (`always` / `never` / `user`).
/// 3. Built-in defaults: `http`, `https`, `git` and `ssh` are always allowed,
///    `ext` is never allowed, and everything else is treated as `user`.
///    The protocol name is compared case-insensitively (ASCII) here.
///
/// # Errors
///
/// * [`ProtocolError::NotAllowedByEnvironment`] when the allow-list is set and
///   does not contain `protocol`.
/// * [`ProtocolError::NotAllowed`] when the effective allow value denies it.
/// * [`ProtocolError::UnknownAllowValue`] when a configured allow value is not
///   recognised.
pub fn check_protocol_allowed_with(
    protocol: &str,
    inputs: &ProtocolPolicyInputs,
) -> Result<(), ProtocolError> {
    // An explicit allow-list short-circuits every other source of policy.
    if let Some(list) = inputs.git_allow_protocol.as_deref() {
        let listed = list
            .split([':', ','])
            .map(str::trim)
            // Blank entries never match, even for an empty protocol name.
            .any(|entry| !entry.is_empty() && entry == protocol);
        return if listed {
            Ok(())
        } else {
            Err(ProtocolError::NotAllowedByEnvironment {
                protocol: protocol.to_owned(),
            })
        };
    }

    let from_user = inputs.git_protocol_from_user.as_deref();

    // The protocol-specific value wins over the blanket one.
    let configured = inputs
        .specific_allow
        .as_deref()
        .or(inputs.blanket_allow.as_deref());
    if let Some(value) = configured {
        return check_allow_value(protocol, value, from_user);
    }

    // Nothing configured: fall back to the built-in default for this protocol.
    let fallback = match protocol.to_ascii_lowercase().as_str() {
        "http" | "https" | "git" | "ssh" => "always",
        "ext" => "never",
        _ => "user",
    };
    check_allow_value(protocol, fallback, from_user)
}
/// Evaluates a single allow `value` for `protocol`.
///
/// The value is matched case-insensitively:
/// * `always` — allowed;
/// * `never` — denied;
/// * `user` — allowed only if [`protocol_from_user`] says the request comes
///   from the user (based on `git_protocol_from_user`).
///
/// # Errors
///
/// Returns [`ProtocolError::NotAllowed`] when the value denies the protocol,
/// and [`ProtocolError::UnknownAllowValue`] (carrying the lower-cased value)
/// for anything else.
fn check_allow_value(
    protocol: &str,
    value: &str,
    git_protocol_from_user: Option<&str>,
) -> Result<(), ProtocolError> {
    let normalized = value.to_lowercase();

    let permitted = match normalized.as_str() {
        "always" => true,
        "never" => false,
        "user" => protocol_from_user(git_protocol_from_user),
        unknown => {
            return Err(ProtocolError::UnknownAllowValue {
                protocol: protocol.to_owned(),
                value: unknown.to_owned(),
            });
        }
    };

    if permitted {
        Ok(())
    } else {
        Err(ProtocolError::NotAllowed {
            protocol: protocol.to_owned(),
        })
    }
}
/// Interprets the raw "protocol comes from the user" flag.
///
/// Returns `true` when the flag is absent, and for any text other than the
/// explicit negatives `0`, `false`, `no` and `off`. The text is trimmed and
/// compared ASCII case-insensitively, so an empty or whitespace-only value
/// also counts as `true`.
#[must_use]
pub fn protocol_from_user(raw: Option<&str>) -> bool {
    const NEGATIVES: [&str; 4] = ["0", "false", "no", "off"];

    let Some(text) = raw else {
        return true;
    };
    let normalized = text.trim().to_ascii_lowercase();
    !NEGATIVES.contains(&normalized.as_str())
}
/// Candidate sources for the client's wire-protocol version, listed in the
/// order [`effective_client_protocol_version_from_inputs`] consults them.
#[derive(Debug, Clone, Default)]
pub struct ClientProtocolVersionInputs {
    /// Highest-priority source. (Presumably a command-line config override of
    /// `protocol.version` — confirm against the caller.)
    pub config_param_version: Option<String>,

    /// Used when `config_param_version` is absent. (Presumably the
    /// repository's `protocol.version` setting — confirm.)
    pub repo_config_version: Option<String>,

    /// Lowest-priority source; ignored when empty. (Presumably the
    /// `GIT_TEST_PROTOCOL_VERSION` environment variable — confirm.)
    pub git_test_protocol_version: Option<String>,
}
/// Parses a protocol version given as exactly `0`, `1` or `2`.
///
/// Surrounding whitespace is ignored. Any other text — including otherwise
/// numeric forms such as `02` or `+1` — yields `None`.
#[must_use]
pub fn parse_protocol_version_digit(s: &str) -> Option<u8> {
    const KNOWN_VERSIONS: [(&str, u8); 3] = [("0", 0), ("1", 1), ("2", 2)];

    let trimmed = s.trim();
    KNOWN_VERSIONS
        .iter()
        .find(|entry| entry.0 == trimmed)
        .map(|entry| entry.1)
}
#[must_use]
pub fn effective_client_protocol_version_from_inputs(inputs: &ClientProtocolVersionInputs) -> u8 {
if let Some(v) = inputs.config_param_version.as_deref() {
return parse_protocol_version_digit(v).unwrap_or(2);
}
if let Some(v) = inputs.repo_config_version.as_deref() {
return parse_protocol_version_digit(v).unwrap_or(2);
}
if let Some(raw) = inputs
.git_test_protocol_version
.as_deref()
.filter(|s| !s.is_empty())
{
return parse_protocol_version_digit(raw).unwrap_or(2);
}
2
}
/// Determines the protocol version a server should speak from the raw,
/// colon-separated protocol advertisement sent by the client.
///
/// Each `version=N` entry is considered and the highest *supported* version
/// (0, 1 or 2) is returned. Entries that are not `version=` keys, that do not
/// parse as a number, or that name a version above the highest supported one
/// are ignored, so an unknown future version can never be selected nor mask a
/// supported version listed alongside it.
///
/// Returns `0` when `raw` is `None` or contains no usable version entry.
#[must_use]
pub fn server_protocol_version_from_git_protocol(raw: Option<&str>) -> u8 {
    // Highest wire-protocol version this implementation knows about.
    const MAX_SUPPORTED_VERSION: u8 = 2;

    raw.into_iter()
        .flat_map(|value| value.split(':'))
        .filter_map(|entry| entry.strip_prefix("version="))
        .filter_map(|digits| digits.parse::<u8>().ok())
        // Skip versions we cannot speak instead of reporting them back.
        .filter(|&version| version <= MAX_SUPPORTED_VERSION)
        .max()
        .unwrap_or(0)
}
/// Removes every `version=…` entry (and every empty entry) from a
/// colon-separated list, returning the remaining entries re-joined with `:`.
///
/// Leading whitespace is ignored when recognising a `version=` entry, but
/// entries that are kept are copied through unmodified.
fn strip_protocol_version_entries(s: &str) -> String {
    let mut kept = String::with_capacity(s.len());
    for entry in s.split(':') {
        if entry.is_empty() || entry.trim_start().starts_with("version=") {
            continue;
        }
        // Kept entries are never empty, so a non-empty buffer means at least
        // one entry precedes this one and a separator is needed.
        if !kept.is_empty() {
            kept.push(':');
        }
        kept.push_str(entry);
    }
    kept
}
/// Builds the protocol advertisement string for a client that wants version
/// `client_wants`, preserving unrelated entries from an `existing` value.
///
/// * Returns `None` when `client_wants` is `0` (nothing is advertised).
/// * Otherwise any `version=…` entries in `existing` are dropped and a single
///   `version=<client_wants>` entry is appended after the remaining entries,
///   separated by `:`.
/// * An absent, empty, or version-only `existing` value yields just the new
///   version entry.
#[must_use]
pub fn merged_git_protocol_value(client_wants: u8, existing: Option<&str>) -> Option<String> {
    if client_wants == 0 {
        return None;
    }

    let version_entry = format!("version={client_wants}");

    // Entries from the previous value that are worth carrying over.
    let carried = existing
        .filter(|previous| !previous.is_empty())
        .map(|previous| strip_protocol_version_entries(previous.trim()))
        .unwrap_or_default();

    let merged = if carried.is_empty() {
        version_entry
    } else {
        format!("{carried}:{version_entry}")
    };
    Some(merged)
}