1use thiserror::Error;
7
8#[derive(Clone, Debug, Error, PartialEq, Eq)]
10pub enum ProtocolError {
11 #[error("protocol '{protocol}' is not allowed by GIT_ALLOW_PROTOCOL")]
13 NotAllowedByEnvironment {
14 protocol: String,
16 },
17 #[error("protocol '{protocol}' is not allowed")]
19 NotAllowed {
20 protocol: String,
22 },
23 #[error("unknown protocol.allow value '{value}' for protocol '{protocol}'")]
25 UnknownAllowValue {
26 protocol: String,
28 value: String,
30 },
31}
32
/// Raw policy sources consulted by `check_protocol_allowed_with`.
///
/// Each field holds the unparsed text from its source, or `None` when that
/// source is absent. `Default` yields every field as `None`.
#[derive(Debug, Clone, Default)]
pub struct ProtocolPolicyInputs {
    /// Raw `GIT_ALLOW_PROTOCOL` list. When present it alone decides whether a
    /// protocol is allowed; the other policy fields are not consulted.
    pub git_allow_protocol: Option<String>,
    /// Raw `GIT_PROTOCOL_FROM_USER` value, interpreted by `protocol_from_user`
    /// whenever the effective policy is `user`.
    pub git_protocol_from_user: Option<String>,
    /// Per-protocol allow policy (`always` / `never` / `user`). It takes
    /// precedence over `blanket_allow`. Presumably `protocol.<name>.allow`;
    /// confirm with the caller.
    pub specific_allow: Option<String>,
    /// Fallback allow policy, used only when `specific_allow` is absent.
    /// Presumably `protocol.allow`; confirm with the caller.
    pub blanket_allow: Option<String>,
}
45
46pub fn check_protocol_allowed_with(
55 protocol: &str,
56 inputs: &ProtocolPolicyInputs,
57) -> Result<(), ProtocolError> {
58 if let Some(val) = inputs.git_allow_protocol.as_deref() {
59 let allowed: Vec<&str> = val
60 .split([':', ','])
61 .map(str::trim)
62 .filter(|s| !s.is_empty())
63 .collect();
64 if allowed.contains(&protocol) {
65 return Ok(());
66 }
67 return Err(ProtocolError::NotAllowedByEnvironment {
68 protocol: protocol.to_owned(),
69 });
70 }
71
72 if let Some(ref val) = inputs.specific_allow {
73 return check_allow_value(protocol, val, inputs.git_protocol_from_user.as_deref());
74 }
75
76 if let Some(ref val) = inputs.blanket_allow {
77 return check_allow_value(protocol, val, inputs.git_protocol_from_user.as_deref());
78 }
79
80 match protocol.to_ascii_lowercase().as_str() {
81 "http" | "https" | "git" | "ssh" => Ok(()),
82 "ext" => Err(ProtocolError::NotAllowed {
83 protocol: protocol.to_owned(),
84 }),
85 _ => check_allow_value(protocol, "user", inputs.git_protocol_from_user.as_deref()),
86 }
87}
88
89fn check_allow_value(
90 protocol: &str,
91 value: &str,
92 git_protocol_from_user: Option<&str>,
93) -> Result<(), ProtocolError> {
94 match value.to_lowercase().as_str() {
95 "always" => Ok(()),
96 "never" => Err(ProtocolError::NotAllowed {
97 protocol: protocol.to_owned(),
98 }),
99 "user" => {
100 if protocol_from_user(git_protocol_from_user) {
101 Ok(())
102 } else {
103 Err(ProtocolError::NotAllowed {
104 protocol: protocol.to_owned(),
105 })
106 }
107 }
108 other => Err(ProtocolError::UnknownAllowValue {
109 protocol: protocol.to_owned(),
110 value: other.to_owned(),
111 }),
112 }
113}
114
/// Interpret a raw `GIT_PROTOCOL_FROM_USER` value as a boolean.
///
/// * `None` (not set) is `true`.
/// * Otherwise the value is trimmed and read with git's boolean rules:
///   - the empty string and `false` / `no` / `off` (ASCII case-insensitive)
///     are `false`;
///   - a decimal integer is `true` exactly when it is non-zero, so `0`,
///     `00` and `+0` are all `false`.
/// * Any other text (including `true` / `yes` / `on`) is `true`.
#[must_use]
pub fn protocol_from_user(raw: Option<&str>) -> bool {
    let Some(raw) = raw else {
        return true;
    };
    let value = raw.trim();
    if value.is_empty() {
        // An explicitly empty boolean is false in git; treating it as true
        // would loosen the `user` policy instead of tightening it.
        return false;
    }
    if let Ok(number) = value.parse::<i64>() {
        return number != 0;
    }
    !["false", "no", "off"]
        .iter()
        .any(|word| value.eq_ignore_ascii_case(word))
}
126
/// Raw sources for the protocol version a client requests.
///
/// Consumed by `effective_client_protocol_version_from_inputs`, which checks
/// the fields in declaration order; the first present one wins. `Default`
/// yields every field as `None`.
#[derive(Debug, Clone, Default)]
pub struct ClientProtocolVersionInputs {
    /// Highest-precedence version string. Presumably `protocol.version`
    /// supplied as a command-line config parameter; confirm with the caller.
    pub config_param_version: Option<String>,
    /// Version string used when `config_param_version` is absent. Presumably
    /// `protocol.version` from repository config; confirm with the caller.
    pub repo_config_version: Option<String>,
    /// Raw `GIT_TEST_PROTOCOL_VERSION` value. Lowest precedence, and ignored
    /// when empty.
    pub git_test_protocol_version: Option<String>,
}
137
/// Parse a protocol version that is exactly `0`, `1` or `2` once surrounding
/// whitespace is trimmed.
///
/// Anything else (for example `""`, `"02"`, `"3"` or `"v2"`) yields `None`.
#[must_use]
pub fn parse_protocol_version_digit(s: &str) -> Option<u8> {
    match *s.trim().as_bytes() {
        [digit @ b'0'..=b'2'] => Some(digit - b'0'),
        _ => None,
    }
}
148
149#[must_use]
155pub fn effective_client_protocol_version_from_inputs(inputs: &ClientProtocolVersionInputs) -> u8 {
156 if let Some(v) = inputs.config_param_version.as_deref() {
157 return parse_protocol_version_digit(v).unwrap_or(2);
158 }
159 if let Some(v) = inputs.repo_config_version.as_deref() {
160 return parse_protocol_version_digit(v).unwrap_or(2);
161 }
162 if let Some(raw) = inputs
163 .git_test_protocol_version
164 .as_deref()
165 .filter(|s| !s.is_empty())
166 {
167 return parse_protocol_version_digit(raw).unwrap_or(2);
168 }
169 2
170}
171
/// Pick the protocol version a server should speak from a raw `GIT_PROTOCOL`
/// value.
///
/// `raw` is split on `:`. Each `version=<n>` entry whose `<n>` is exactly
/// `0`, `1` or `2` is a candidate, and the greatest candidate wins. These are
/// ignored:
/// - entries with other keys;
/// - version values outside that set, e.g. `version=3` from a newer client
///   or a malformed `version=abc` / `version=02`.
///
/// Returns `0` when `raw` is `None` or no entry qualifies.
#[must_use]
pub fn server_protocol_version_from_git_protocol(raw: Option<&str>) -> u8 {
    let Some(raw) = raw else {
        return 0;
    };
    raw.split(':')
        .filter_map(|entry| entry.strip_prefix("version="))
        .filter_map(|value| match value {
            // Only versions the server can actually speak are eligible; an
            // unknown version must not win just because it is numerically larger.
            "0" => Some(0_u8),
            "1" => Some(1_u8),
            "2" => Some(2_u8),
            _ => None,
        })
        .max()
        .unwrap_or(0)
}
190
/// Remove entries from a `:`-separated `GIT_PROTOCOL` string and re-join the
/// survivors with `:` in their original order.
///
/// Removed: every `version=…` entry (leading whitespace before `version=` is
/// ignored when testing) and every empty entry.
fn strip_protocol_version_entries(s: &str) -> String {
    let mut kept: Vec<&str> = Vec::new();
    for entry in s.split(':') {
        let is_version = entry.trim_start().starts_with("version=");
        if is_version || entry.is_empty() {
            continue;
        }
        kept.push(entry);
    }
    kept.join(":")
}
198
199#[must_use]
204pub fn merged_git_protocol_value(client_wants: u8, existing: Option<&str>) -> Option<String> {
205 if client_wants == 0 {
206 return None;
207 }
208 let entry = format!("version={client_wants}");
209 Some(match existing {
210 Some(e) if !e.is_empty() => {
211 let base = strip_protocol_version_entries(e.trim());
212 if base.is_empty() {
213 entry
214 } else {
215 format!("{base}:{entry}")
216 }
217 }
218 _ => entry,
219 })
220}