brainos_mcphost/
aud_check.rs1use base64::engine::general_purpose::URL_SAFE_NO_PAD;
40use base64::Engine;
41
/// Result of comparing an access token's `aud` (audience) claim against an
/// expected resource identifier.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AudCheckOutcome {
    /// At least one `aud` value equals the expected resource exactly.
    Match,
    /// The token carries one or more string `aud` values, but none of them
    /// equals the expected resource.
    Mismatch {
        /// Every string value found in the token's `aud` claim.
        found: Vec<String>,
        /// The resource the token was expected to be issued for.
        expected: String,
    },
    /// The payload has no `aud` claim, or the claim holds no string values
    /// (for example a number, `null`, or an array containing no strings).
    MissingAud,
    /// The token could not be inspected as a JWT (for example it is not made of
    /// three non-empty `.`-separated segments, or its payload segment is not
    /// base64url-encoded JSON), so no audience check was possible.
    OpaqueToken,
}

impl AudCheckOutcome {
    /// Returns `true` only for [`AudCheckOutcome::Mismatch`].
    ///
    /// `Match`, `MissingAud` and `OpaqueToken` all return `false`. Callers that
    /// want to treat a missing or unreadable audience as a failure must check
    /// those variants separately.
    #[must_use]
    pub fn is_mismatch(&self) -> bool {
        matches!(self, AudCheckOutcome::Mismatch { .. })
    }
}
72
73pub fn validate_token_aud(access_token: &str, expected_resource: &str) -> AudCheckOutcome {
80 let Some(payload_b64) = jwt_payload_segment(access_token) else {
81 return AudCheckOutcome::OpaqueToken;
82 };
83 let Ok(payload_bytes) = URL_SAFE_NO_PAD.decode(payload_b64) else {
84 return AudCheckOutcome::OpaqueToken;
85 };
86 let Ok(payload) = serde_json::from_slice::<serde_json::Value>(&payload_bytes) else {
87 return AudCheckOutcome::OpaqueToken;
88 };
89
90 let Some(aud_field) = payload.get("aud") else {
91 return AudCheckOutcome::MissingAud;
92 };
93
94 let found = collect_aud(aud_field);
95 if found.is_empty() {
96 return AudCheckOutcome::MissingAud;
97 }
98
99 if found.iter().any(|v| v == expected_resource) {
100 AudCheckOutcome::Match
101 } else {
102 AudCheckOutcome::Mismatch {
103 found,
104 expected: expected_resource.to_string(),
105 }
106 }
107}
108
/// Returns the middle (payload) segment of a compact-serialized token.
///
/// The token must consist of exactly three `.`-separated segments, all of them
/// non-empty. Anything else (fewer or more segments, or any empty segment)
/// yields `None`.
fn jwt_payload_segment(token: &str) -> Option<&str> {
    let (header, rest) = token.split_once('.')?;
    let (payload, signature) = rest.split_once('.')?;

    // A '.' left inside `signature` means there was a fourth segment.
    let well_formed = !header.is_empty()
        && !payload.is_empty()
        && !signature.is_empty()
        && !signature.contains('.');

    well_formed.then_some(payload)
}
122
123fn collect_aud(value: &serde_json::Value) -> Vec<String> {
126 match value {
127 serde_json::Value::String(s) => vec![s.clone()],
128 serde_json::Value::Array(items) => items
129 .iter()
130 .filter_map(|v| v.as_str().map(String::from))
131 .collect(),
132 _ => Vec::new(),
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;

    /// Builds an unsigned, JWT-shaped token whose payload segment is the
    /// base64url (no padding) encoding of `payload`. The header and signature
    /// are fixed placeholders; nothing here is cryptographically valid.
    pub(crate) fn fake_jwt(payload: serde_json::Value) -> String {
        let header_b64 = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\",\"typ\":\"JWT\"}");
        let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&payload).unwrap());
        let sig_b64 = URL_SAFE_NO_PAD.encode(b"signature-bytes");
        format!("{header_b64}.{payload_b64}.{sig_b64}")
    }

    #[test]
    fn match_for_string_aud() {
        let token = fake_jwt(serde_json::json!({
            "aud": "https://server.example.com",
            "sub": "user-1",
        }));
        let outcome = validate_token_aud(&token, "https://server.example.com");
        assert_eq!(outcome, AudCheckOutcome::Match);
    }

    #[test]
    fn match_for_array_aud_containing_expected() {
        let token = fake_jwt(serde_json::json!({
            "aud": ["https://other.example.com", "https://server.example.com"],
        }));
        let outcome = validate_token_aud(&token, "https://server.example.com");
        assert_eq!(outcome, AudCheckOutcome::Match);
    }

    #[test]
    fn match_ignores_non_string_entries_in_array_aud() {
        let token = fake_jwt(serde_json::json!({
            "aud": [7, null, "https://server.example.com"],
        }));
        assert_eq!(
            validate_token_aud(&token, "https://server.example.com"),
            AudCheckOutcome::Match
        );
    }

    #[test]
    fn mismatch_for_string_aud_pointing_elsewhere() {
        let token = fake_jwt(serde_json::json!({
            "aud": "https://other.example.com",
        }));
        let outcome = validate_token_aud(&token, "https://server.example.com");
        match outcome {
            AudCheckOutcome::Mismatch { found, expected } => {
                assert_eq!(found, vec!["https://other.example.com".to_string()]);
                assert_eq!(expected, "https://server.example.com");
            }
            other => panic!("expected Mismatch, got {other:?}"),
        }
    }

    #[test]
    fn mismatch_for_array_aud_without_expected() {
        let token = fake_jwt(serde_json::json!({
            "aud": ["a", "b", "c"],
        }));
        let outcome = validate_token_aud(&token, "z");
        assert!(outcome.is_mismatch());
    }

    #[test]
    fn mismatch_reports_only_string_entries_of_mixed_array_aud() {
        let token = fake_jwt(serde_json::json!({ "aud": ["a", 1, "b"] }));
        assert_eq!(
            validate_token_aud(&token, "z"),
            AudCheckOutcome::Mismatch {
                found: vec!["a".to_string(), "b".to_string()],
                expected: "z".to_string(),
            }
        );
    }

    #[test]
    fn missing_aud_when_payload_has_no_field() {
        let token = fake_jwt(serde_json::json!({
            "sub": "user-1",
        }));
        assert_eq!(
            validate_token_aud(&token, "anything"),
            AudCheckOutcome::MissingAud
        );
    }

    #[test]
    fn missing_aud_when_field_is_unexpected_type() {
        // Neither a number nor an array of numbers yields any string audience.
        let token = fake_jwt(serde_json::json!({ "aud": 42 }));
        assert_eq!(
            validate_token_aud(&token, "anything"),
            AudCheckOutcome::MissingAud
        );
        let token = fake_jwt(serde_json::json!({ "aud": [123, 456] }));
        assert_eq!(
            validate_token_aud(&token, "anything"),
            AudCheckOutcome::MissingAud
        );
    }

    #[test]
    fn missing_aud_when_field_is_null_or_empty_array() {
        for aud in [serde_json::Value::Null, serde_json::json!([])] {
            let token = fake_jwt(serde_json::json!({ "aud": aud }));
            assert_eq!(
                validate_token_aud(&token, "anything"),
                AudCheckOutcome::MissingAud
            );
        }
    }

    #[test]
    fn opaque_token_when_not_three_segments() {
        assert_eq!(
            validate_token_aud("abc.def", "anything"),
            AudCheckOutcome::OpaqueToken
        );
        assert_eq!(
            validate_token_aud("just-a-random-string", "anything"),
            AudCheckOutcome::OpaqueToken
        );
        assert_eq!(
            validate_token_aud("a.b.c.d", "anything"),
            AudCheckOutcome::OpaqueToken
        );
        assert_eq!(
            validate_token_aud("a..c", "anything"),
            AudCheckOutcome::OpaqueToken
        );
    }

    #[test]
    fn opaque_token_when_leading_or_trailing_dot() {
        // An empty first/last segment, or a trailing dot that creates a fourth
        // (empty) segment, must not be treated as a JWT.
        for token in [".b.c", "a.b.", "a.b.c."] {
            assert_eq!(
                validate_token_aud(token, "anything"),
                AudCheckOutcome::OpaqueToken,
                "token: {token}"
            );
        }
    }

    #[test]
    fn opaque_token_when_middle_segment_isnt_json() {
        let header = URL_SAFE_NO_PAD.encode(b"{}");
        let payload = URL_SAFE_NO_PAD.encode(b"not-actually-json");
        let sig = URL_SAFE_NO_PAD.encode(b"sig");
        let token = format!("{header}.{payload}.{sig}");
        assert_eq!(
            validate_token_aud(&token, "anything"),
            AudCheckOutcome::OpaqueToken
        );
    }

    #[test]
    fn opaque_token_when_middle_segment_isnt_base64url() {
        let token = "abc.@@@not-base64@@@.def";
        assert_eq!(
            validate_token_aud(token, "anything"),
            AudCheckOutcome::OpaqueToken
        );
    }
}