Skip to main content

brainos_mcphost/
aud_check.rs

1//! OAuth access-token audience-claim (`aud`) validation.
2//!
3//! Threat: CVE-2025-6514 class "confused-deputy via token passthrough."
4//! A legitimate access token issued for MCP server A is replayed
5//! against server B. The mitigation is RFC 8707 `resource` indicators
6//! at token-request time, plus `aud`-claim validation at token-use
7//! time.
8//!
9//! Brain mcphost is the OAuth client. The complementary risk on the
10//! client side is that the vault stores a token whose `aud` doesn't
11//! match the configured `resource` for the server name we're keyed
12//! by — either because the token was minted with the wrong resource
13//! indicator at the time, or because something tampered with the
14//! vault rows after the fact. This module's job is to catch that at
15//! load time and fail closed before the bad token reaches the wire.
16//!
17//! ## What we validate
18//!
19//! - If the access token parses as a JWT (three base64url-encoded
20//!   segments, middle segment is JSON) AND the JSON carries an `aud`
21//!   claim: the claim must contain the configured resource string.
22//! - If the token doesn't parse as a JWT (opaque token — common for
23//!   first-party AS / OAuth 2.1 confidential clients): the audience
24//!   check is skipped. Opaque tokens carry no in-band audience info;
25//!   the resource indicator we passed at request time is the
26//!   server-side enforcement.
27//! - If the token parses as a JWT but has no `aud` claim: log a
28//!   warning. RFC 9068 says JWT access tokens MUST carry `aud`, but
29//!   real-world ASes deviate, so we don't fail closed on this.
30//!
31//! Signature verification is intentionally out of scope. Verifying
32//! would require fetching the AS's JWKS, which adds a network round
33//! trip and key-rotation complexity that rmcp's transport layer
34//! already owns. The vault is the trust boundary here — we got the
35//! token from a successful PKCE flow through TLS to the AS, and the
36//! vault row is gated by the OS keychain or the encrypted-file
37//! backend.
38
39use base64::engine::general_purpose::URL_SAFE_NO_PAD;
40use base64::Engine;
41
/// Result of comparing an access token's `aud` claim with the resource
/// indicator configured for a server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AudCheckOutcome {
    /// The token is a JWT whose `aud` claim includes the configured
    /// resource.
    Match,
    /// The token is a JWT with an `aud` claim that does **not** include the
    /// configured resource. This is the confused-deputy signal: the caller
    /// should reject the token and emit a `BrainEvent::Error`.
    Mismatch {
        /// The string audiences that were found in the token's `aud` claim.
        found: Vec<String>,
        /// The resource indicator the token was checked against.
        expected: String,
    },
    /// The token is a JWT but its payload carries no usable `aud` claim.
    /// Warn rather than fail closed — some ASes still omit it.
    MissingAud,
    /// The token is not a JWT (not three `.`-separated base64url segments,
    /// or the payload segment is not valid JSON). This is the usual shape of
    /// opaque OAuth 2.1 tokens; with no in-band audience to inspect, the
    /// check is skipped.
    OpaqueToken,
}

impl AudCheckOutcome {
    /// Returns `true` only for [`AudCheckOutcome::Mismatch`], the one
    /// outcome that is a hard failure and must cause the token to be
    /// rejected.
    pub fn is_mismatch(&self) -> bool {
        // Exhaustive on purpose: a new variant must decide explicitly
        // whether it rejects the token.
        match self {
            Self::Mismatch { .. } => true,
            Self::Match | Self::MissingAud | Self::OpaqueToken => false,
        }
    }
}
72
73/// Decode the `aud` claim from `access_token` and compare against
74/// `expected_resource`. See module docs for what each outcome means.
75///
76/// `expected_resource` is the RFC 8707 resource indicator the client
77/// configured for the server — typically the server's base URL, or a
78/// distinct logical identifier the AS expects.
79pub fn validate_token_aud(access_token: &str, expected_resource: &str) -> AudCheckOutcome {
80    let Some(payload_b64) = jwt_payload_segment(access_token) else {
81        return AudCheckOutcome::OpaqueToken;
82    };
83    let Ok(payload_bytes) = URL_SAFE_NO_PAD.decode(payload_b64) else {
84        return AudCheckOutcome::OpaqueToken;
85    };
86    let Ok(payload) = serde_json::from_slice::<serde_json::Value>(&payload_bytes) else {
87        return AudCheckOutcome::OpaqueToken;
88    };
89
90    let Some(aud_field) = payload.get("aud") else {
91        return AudCheckOutcome::MissingAud;
92    };
93
94    let found = collect_aud(aud_field);
95    if found.is_empty() {
96        return AudCheckOutcome::MissingAud;
97    }
98
99    if found.iter().any(|v| v == expected_resource) {
100        AudCheckOutcome::Match
101    } else {
102        AudCheckOutcome::Mismatch {
103            found,
104            expected: expected_resource.to_string(),
105        }
106    }
107}
108
/// Return the payload (second) segment of a compact-serialised JWS/JWT.
///
/// The token qualifies only when it splits on `.` into exactly three
/// segments and none of them is empty; every other shape yields `None`.
fn jwt_payload_segment(token: &str) -> Option<&str> {
    let mut segments = token.split('.');
    match (
        segments.next(),
        segments.next(),
        segments.next(),
        segments.next(),
    ) {
        (Some(header), Some(payload), Some(signature), None)
            if !header.is_empty() && !payload.is_empty() && !signature.is_empty() =>
        {
            Some(payload)
        }
        _ => None,
    }
}
122
123/// RFC 7519 `aud` is either a single string or an array of strings.
124/// Other shapes (number, object) are treated as missing.
125fn collect_aud(value: &serde_json::Value) -> Vec<String> {
126    match value {
127        serde_json::Value::String(s) => vec![s.clone()],
128        serde_json::Value::Array(items) => items
129            .iter()
130            .filter_map(|v| v.as_str().map(String::from))
131            .collect(),
132        _ => Vec::new(),
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;

    /// Build a fake compact-JWS-shaped token: `header.payload.sig`,
    /// where every segment is base64url(no-pad). Header + signature
    /// are placeholders — only the payload is read.
    pub(crate) fn fake_jwt(payload: serde_json::Value) -> String {
        let header_b64 = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\",\"typ\":\"JWT\"}");
        let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&payload).unwrap());
        let sig_b64 = URL_SAFE_NO_PAD.encode(b"signature-bytes");
        format!("{header_b64}.{payload_b64}.{sig_b64}")
    }

    #[test]
    fn match_for_string_aud() {
        let token = fake_jwt(serde_json::json!({
            "aud": "https://server.example.com",
            "sub": "user-1",
        }));
        let outcome = validate_token_aud(&token, "https://server.example.com");
        assert_eq!(outcome, AudCheckOutcome::Match);
    }

    #[test]
    fn match_for_array_aud_containing_expected() {
        let token = fake_jwt(serde_json::json!({
            "aud": ["https://other.example.com", "https://server.example.com"],
        }));
        let outcome = validate_token_aud(&token, "https://server.example.com");
        assert_eq!(outcome, AudCheckOutcome::Match);
    }

    #[test]
    fn match_when_array_mixes_strings_and_non_strings() {
        // Non-string entries are ignored; the string entry still matches.
        let token = fake_jwt(serde_json::json!({
            "aud": ["https://server.example.com", 7, null],
        }));
        assert_eq!(
            validate_token_aud(&token, "https://server.example.com"),
            AudCheckOutcome::Match
        );
    }

    #[test]
    fn mismatch_for_string_aud_pointing_elsewhere() {
        let token = fake_jwt(serde_json::json!({
            "aud": "https://other.example.com",
        }));
        let outcome = validate_token_aud(&token, "https://server.example.com");
        match outcome {
            AudCheckOutcome::Mismatch { found, expected } => {
                assert_eq!(found, vec!["https://other.example.com".to_string()]);
                assert_eq!(expected, "https://server.example.com");
            }
            other => panic!("expected Mismatch, got {other:?}"),
        }
    }

    #[test]
    fn mismatch_for_array_aud_without_expected() {
        let token = fake_jwt(serde_json::json!({
            "aud": ["a", "b", "c"],
        }));
        let outcome = validate_token_aud(&token, "z");
        assert!(outcome.is_mismatch());
    }

    #[test]
    fn mismatch_reports_only_string_audiences() {
        let token = fake_jwt(serde_json::json!({ "aud": ["a", 7, null] }));
        assert_eq!(
            validate_token_aud(&token, "z"),
            AudCheckOutcome::Mismatch {
                found: vec!["a".to_string()],
                expected: "z".to_string(),
            }
        );
    }

    #[test]
    fn aud_comparison_is_exact() {
        // No URL normalisation: trailing slash and case both matter.
        let token = fake_jwt(serde_json::json!({ "aud": "https://server.example.com/" }));
        assert!(validate_token_aud(&token, "https://server.example.com").is_mismatch());
        let token = fake_jwt(serde_json::json!({ "aud": "HTTPS://SERVER.EXAMPLE.COM" }));
        assert!(validate_token_aud(&token, "https://server.example.com").is_mismatch());
    }

    #[test]
    fn missing_aud_when_payload_has_no_field() {
        let token = fake_jwt(serde_json::json!({
            "sub": "user-1",
        }));
        assert_eq!(
            validate_token_aud(&token, "anything"),
            AudCheckOutcome::MissingAud
        );
    }

    #[test]
    fn missing_aud_when_field_is_unexpected_type() {
        // RFC 7519 says aud is string or array-of-strings; numbers /
        // objects / arrays-of-non-strings are treated as missing.
        let token = fake_jwt(serde_json::json!({ "aud": 42 }));
        assert_eq!(
            validate_token_aud(&token, "anything"),
            AudCheckOutcome::MissingAud
        );
        let token = fake_jwt(serde_json::json!({ "aud": [123, 456] }));
        assert_eq!(
            validate_token_aud(&token, "anything"),
            AudCheckOutcome::MissingAud
        );
    }

    #[test]
    fn missing_aud_when_field_is_null_or_empty_array() {
        let token = fake_jwt(serde_json::json!({ "aud": null }));
        assert_eq!(
            validate_token_aud(&token, "anything"),
            AudCheckOutcome::MissingAud
        );
        let token = fake_jwt(serde_json::json!({ "aud": [] }));
        assert_eq!(
            validate_token_aud(&token, "anything"),
            AudCheckOutcome::MissingAud
        );
    }

    #[test]
    fn opaque_token_when_not_three_segments() {
        // Common opaque-token shapes from real-world OAuth deployments.
        assert_eq!(
            validate_token_aud("abc.def", "anything"),
            AudCheckOutcome::OpaqueToken
        );
        assert_eq!(
            validate_token_aud("just-a-random-string", "anything"),
            AudCheckOutcome::OpaqueToken
        );
        assert_eq!(
            validate_token_aud("a.b.c.d", "anything"),
            AudCheckOutcome::OpaqueToken
        );
        assert_eq!(
            validate_token_aud("a..c", "anything"),
            AudCheckOutcome::OpaqueToken
        );
    }

    #[test]
    fn opaque_token_for_empty_and_jwe_shaped_input() {
        assert_eq!(validate_token_aud("", "anything"), AudCheckOutcome::OpaqueToken);
        // Compact JWE has five segments; its payload is encrypted, so no
        // in-band audience can be read.
        assert_eq!(
            validate_token_aud("h.k.iv.ct.tag", "anything"),
            AudCheckOutcome::OpaqueToken
        );
    }

    #[test]
    fn opaque_token_when_middle_segment_isnt_json() {
        let header = URL_SAFE_NO_PAD.encode(b"{}");
        let payload = URL_SAFE_NO_PAD.encode(b"not-actually-json");
        let sig = URL_SAFE_NO_PAD.encode(b"sig");
        let token = format!("{header}.{payload}.{sig}");
        assert_eq!(
            validate_token_aud(&token, "anything"),
            AudCheckOutcome::OpaqueToken
        );
    }

    #[test]
    fn opaque_token_when_middle_segment_isnt_base64url() {
        let token = "abc.@@@not-base64@@@.def";
        assert_eq!(
            validate_token_aud(token, "anything"),
            AudCheckOutcome::OpaqueToken
        );
    }
}