use async_trait::async_trait;
use brainwires_mcp::{JsonRpcError, JsonRpcRequest};
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
use super::{Middleware, MiddlewareResult};
use crate::connection::RequestContext;
/// JSON-RPC middleware that authenticates requests with a JWT bearer token.
///
/// The token is read from the `_bearer_token` string field of the request
/// params. It is verified with `decoding_key` against the rules held in
/// `validation`.
pub struct OAuthMiddleware {
// Key used to verify the token signature: an HMAC secret (`with_secret`) or an
// RSA public key parsed from PEM (`with_rsa_pem`).
decoding_key: DecodingKey,
// Expected signing algorithm plus any issuer/audience requirements added via
// `require_issuer` / `require_audience`.
validation: Validation,
}
impl OAuthMiddleware {
pub fn with_secret(secret: &[u8]) -> Self {
Self {
decoding_key: DecodingKey::from_secret(secret),
validation: Validation::new(Algorithm::HS256),
}
}
pub fn with_rsa_pem(pem: &str) -> Result<Self, jsonwebtoken::errors::Error> {
Ok(Self {
decoding_key: DecodingKey::from_rsa_pem(pem.as_bytes())?,
validation: Validation::new(Algorithm::RS256),
})
}
pub fn require_issuer(mut self, issuer: impl Into<String>) -> Self {
self.validation.set_issuer(&[issuer.into()]);
self
}
pub fn require_audience(mut self, audience: impl Into<String>) -> Self {
self.validation.set_audience(&[audience.into()]);
self
}
fn reject(msg: &str) -> MiddlewareResult {
MiddlewareResult::Reject(JsonRpcError {
code: -32003,
message: format!("Unauthorized: {}", msg),
data: None,
})
}
}
#[async_trait]
impl Middleware for OAuthMiddleware {
    /// Authenticates a JSON-RPC request with a JWT carried in `params._bearer_token`.
    ///
    /// - `initialize` requests always continue, with no token required.
    /// - If this middleware already validated a token on the same
    ///   `RequestContext` (metadata `oauth_validated == true`), the request
    ///   continues without re-checking.
    /// - Otherwise the token is decoded and its signature and claims are checked
    ///   against `self.validation`.
    ///   - On success, `oauth_validated = true` is recorded in the context.
    ///   - On failure, the request is rejected with JSON-RPC error code `-32003`.
    async fn process_request(
        &self,
        request: &JsonRpcRequest,
        ctx: &mut RequestContext,
    ) -> MiddlewareResult {
        // The handshake must be reachable before any credentials are exchanged.
        if request.method == "initialize" {
            return MiddlewareResult::Continue;
        }

        // Fail closed: only the exact marker written below counts as a prior
        // successful validation. Any other value stored under this key
        // (e.g. `false`) falls through to a full token check instead of
        // silently bypassing authentication.
        if matches!(
            ctx.get_metadata("oauth_validated"),
            Some(serde_json::Value::Bool(true))
        ) {
            return MiddlewareResult::Continue;
        }

        let Some(token) = request
            .params
            .as_ref()
            .and_then(|p| p.get("_bearer_token"))
            .and_then(|v| v.as_str())
        else {
            return Self::reject("missing _bearer_token in params");
        };

        match decode::<serde_json::Value>(token, &self.decoding_key, &self.validation) {
            Ok(_) => {
                ctx.set_metadata("oauth_validated".to_string(), serde_json::Value::Bool(true));
                MiddlewareResult::Continue
            }
            Err(e) => Self::reject(&e.to_string()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};
    use serde_json::json;

    /// Far-future expiry so tokens built by the helpers never expire during a test run.
    const FAR_FUTURE_EXP: u64 = 9_999_999_999;

    /// Signs arbitrary `claims` with `Header::default()` (HS256) using `secret`.
    fn make_token_with_claims(secret: &[u8], claims: &serde_json::Value) -> String {
        encode(&Header::default(), claims, &EncodingKey::from_secret(secret)).unwrap()
    }

    /// Signs a minimal, non-expired token for subject `test`.
    fn make_token(secret: &[u8]) -> String {
        make_token_with_claims(secret, &json!({ "sub": "test", "exp": FAR_FUTURE_EXP }))
    }

    /// Builds a request for `method`, with `_bearer_token` in params when `token` is given.
    fn make_request(method: &str, token: Option<&str>) -> JsonRpcRequest {
        let params = token.map(|t| json!({ "_bearer_token": t }));
        JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: json!(1),
            method: method.to_string(),
            params,
        }
    }

    /// Runs the middleware once against a fresh context.
    async fn run(mw: &OAuthMiddleware, req: &JsonRpcRequest) -> MiddlewareResult {
        let mut ctx = RequestContext::new(json!(1));
        mw.process_request(req, &mut ctx).await
    }

    #[tokio::test]
    async fn valid_jwt_passes() {
        let secret = b"supersecret";
        let mw = OAuthMiddleware::with_secret(secret);
        let req = make_request("tools/call", Some(&make_token(secret)));
        let mut ctx = RequestContext::new(json!(1));
        assert!(matches!(
            mw.process_request(&req, &mut ctx).await,
            MiddlewareResult::Continue
        ));
    }

    #[tokio::test]
    async fn missing_token_rejects() {
        let mw = OAuthMiddleware::with_secret(b"secret");
        let req = make_request("tools/call", None);
        let mut ctx = RequestContext::new(json!(1));
        assert!(matches!(
            mw.process_request(&req, &mut ctx).await,
            MiddlewareResult::Reject(_)
        ));
    }

    #[tokio::test]
    async fn wrong_secret_rejects() {
        let token = make_token(b"correct_secret");
        let mw = OAuthMiddleware::with_secret(b"wrong_secret");
        let req = make_request("tools/call", Some(&token));
        let mut ctx = RequestContext::new(json!(1));
        assert!(matches!(
            mw.process_request(&req, &mut ctx).await,
            MiddlewareResult::Reject(_)
        ));
    }

    #[tokio::test]
    async fn initialize_skips_auth() {
        let mw = OAuthMiddleware::with_secret(b"secret");
        let req = make_request("initialize", None);
        let mut ctx = RequestContext::new(json!(1));
        assert!(matches!(
            mw.process_request(&req, &mut ctx).await,
            MiddlewareResult::Continue
        ));
    }

    #[tokio::test]
    async fn validated_token_cached_in_context() {
        let secret = b"supersecret";
        let mw = OAuthMiddleware::with_secret(secret);
        let mut ctx = RequestContext::new(json!(1));
        mw.process_request(
            &make_request("tools/call", Some(&make_token(secret))),
            &mut ctx,
        )
        .await;
        assert!(matches!(
            mw.process_request(&make_request("tools/list", None), &mut ctx)
                .await,
            MiddlewareResult::Continue
        ));
    }

    #[tokio::test]
    async fn expired_token_rejects() {
        let secret = b"supersecret";
        let mw = OAuthMiddleware::with_secret(secret);
        // `exp` far in the past, well beyond any default clock-skew leeway.
        let token = make_token_with_claims(secret, &json!({ "sub": "test", "exp": 1u64 }));
        let req = make_request("tools/call", Some(&token));
        assert!(matches!(run(&mw, &req).await, MiddlewareResult::Reject(_)));
    }

    #[tokio::test]
    async fn matching_issuer_passes() {
        let secret = b"supersecret";
        let mw = OAuthMiddleware::with_secret(secret).require_issuer("trusted");
        let token = make_token_with_claims(
            secret,
            &json!({ "sub": "test", "exp": FAR_FUTURE_EXP, "iss": "trusted" }),
        );
        let req = make_request("tools/call", Some(&token));
        assert!(matches!(run(&mw, &req).await, MiddlewareResult::Continue));
    }

    #[tokio::test]
    async fn issuer_mismatch_rejects() {
        let secret = b"supersecret";
        let mw = OAuthMiddleware::with_secret(secret).require_issuer("trusted");
        let token = make_token_with_claims(
            secret,
            &json!({ "sub": "test", "exp": FAR_FUTURE_EXP, "iss": "someone-else" }),
        );
        let req = make_request("tools/call", Some(&token));
        assert!(matches!(run(&mw, &req).await, MiddlewareResult::Reject(_)));
    }

    #[tokio::test]
    async fn matching_audience_passes() {
        let secret = b"supersecret";
        let mw = OAuthMiddleware::with_secret(secret).require_audience("api");
        let token = make_token_with_claims(
            secret,
            &json!({ "sub": "test", "exp": FAR_FUTURE_EXP, "aud": "api" }),
        );
        let req = make_request("tools/call", Some(&token));
        assert!(matches!(run(&mw, &req).await, MiddlewareResult::Continue));
    }

    #[tokio::test]
    async fn audience_mismatch_rejects() {
        let secret = b"supersecret";
        let mw = OAuthMiddleware::with_secret(secret).require_audience("api");
        let token = make_token_with_claims(
            secret,
            &json!({ "sub": "test", "exp": FAR_FUTURE_EXP, "aud": "other" }),
        );
        let req = make_request("tools/call", Some(&token));
        assert!(matches!(run(&mw, &req).await, MiddlewareResult::Reject(_)));
    }

    #[tokio::test]
    async fn rejection_uses_unauthorized_error() {
        let mw = OAuthMiddleware::with_secret(b"secret");
        let req = make_request("tools/call", None);
        match run(&mw, &req).await {
            MiddlewareResult::Reject(err) => {
                assert_eq!(err.code, -32003);
                assert!(err.message.starts_with("Unauthorized: "));
            }
            _ => panic!("expected the request to be rejected"),
        }
    }

    #[test]
    fn invalid_rsa_pem_is_an_error() {
        assert!(OAuthMiddleware::with_rsa_pem("not a pem").is_err());
    }
}