use oauth2::AsyncHttpClient;
use oauth2::http::{self, HeaderValue, StatusCode};
use std::error::Error as StdError;
use std::future::Future;
use std::pin::Pin;
#[cfg(feature = "dpop")]
use std::sync::Arc;
#[cfg(feature = "dpop")]
use tokio::sync::Mutex;
#[cfg(feature = "dpop")]
use turbomcp_dpop::{DpopKeyPair, DpopProofGenerator};
/// Request type the `oauth2` crate hands to [`OAuth2HttpClient`]: an `http`
/// request whose body is a byte vector.
pub type HttpRequest = http::Request<Vec<u8>>;
/// Response type returned to the `oauth2` crate: an `http` response whose body
/// has been fully read into a byte vector.
pub type HttpResponse = http::Response<Vec<u8>>;
/// DPoP configuration attached to an [`OAuth2HttpClient`] via `with_dpop`.
///
/// Every field is wrapped in an `Arc`, so clones share the same generator,
/// key pair, and cached server nonce.
#[cfg(feature = "dpop")]
#[derive(Clone)]
pub struct DpopBinding {
/// Produces the proof that is sent in the `DPoP` request header.
generator: Arc<DpopProofGenerator>,
/// Optional explicit key pair forwarded to the generator for every proof.
/// NOTE(review): when `None` the generator presumably uses its own default
/// key — confirm against `turbomcp_dpop`.
key_pair: Option<Arc<DpopKeyPair>>,
/// Most recent `DPoP-Nonce` value received from a server; included in the
/// next proof that is generated.
server_nonce: Arc<Mutex<Option<String>>>,
}
#[cfg(feature = "dpop")]
impl DpopBinding {
    /// Creates a binding that signs proofs with `generator`.
    ///
    /// The binding starts without an explicit key pair and with an empty
    /// server-nonce cache.
    pub fn new(generator: Arc<DpopProofGenerator>) -> Self {
        let server_nonce = Arc::new(Mutex::new(None));
        Self {
            generator,
            server_nonce,
            key_pair: None,
        }
    }

    /// Returns this binding configured to pass `key` to the generator for
    /// every proof it produces.
    #[must_use]
    pub fn with_key_pair(self, key: Arc<DpopKeyPair>) -> Self {
        Self {
            key_pair: Some(key),
            ..self
        }
    }
}
#[cfg(feature = "dpop")]
impl std::fmt::Debug for DpopBinding {
    /// Reports only whether an explicit key pair is configured; the generator
    /// and the cached nonce are deliberately left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_key_pair = self.key_pair.is_some();
        let mut out = f.debug_struct("DpopBinding");
        out.field("key_pair", &has_key_pair);
        out.finish()
    }
}
/// `reqwest`-backed HTTP client implementing the `oauth2` crate's
/// [`AsyncHttpClient`] trait, optionally attaching DPoP proofs to requests.
#[derive(Clone)]
pub struct OAuth2HttpClient {
/// Underlying `reqwest` client used to send every request.
inner: reqwest::Client,
/// DPoP configuration; when `Some`, every request carries a `DPoP` proof header.
#[cfg(feature = "dpop")]
dpop: Option<DpopBinding>,
}
impl OAuth2HttpClient {
pub fn new() -> Result<Self, reqwest::Error> {
let inner = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.timeout(std::time::Duration::from_secs(30))
.build()?;
Ok(Self {
inner,
#[cfg(feature = "dpop")]
dpop: None,
})
}
pub fn from_client(client: reqwest::Client) -> Self {
Self {
inner: client,
#[cfg(feature = "dpop")]
dpop: None,
}
}
#[cfg(feature = "dpop")]
#[must_use]
pub fn with_dpop(mut self, binding: DpopBinding) -> Self {
self.dpop = Some(binding);
self
}
#[cfg(feature = "dpop")]
async fn build_dpop_proof(
&self,
method: &str,
url: &str,
access_token: Option<&str>,
nonce: Option<&str>,
) -> Result<String, OAuth2HttpError> {
let Some(binding) = &self.dpop else {
return Err(OAuth2HttpError::Dpop(
"DPoP binding missing when generating proof".to_string(),
));
};
let key_ref = binding.key_pair.as_deref();
let proof = binding
.generator
.generate_proof_with_params(method, url, access_token, nonce, key_ref)
.await
.map_err(|e| OAuth2HttpError::Dpop(e.to_string()))?;
Ok(proof.to_jwt_string())
}
async fn execute(&self, request: HttpRequest) -> Result<HttpResponse, OAuth2HttpError> {
let (parts, body) = request.into_parts();
let url = parts.uri.to_string();
let method = match parts.method.as_str() {
"GET" => reqwest::Method::GET,
"POST" => reqwest::Method::POST,
"PUT" => reqwest::Method::PUT,
"DELETE" => reqwest::Method::DELETE,
"PATCH" => reqwest::Method::PATCH,
"HEAD" => reqwest::Method::HEAD,
"OPTIONS" => reqwest::Method::OPTIONS,
other => reqwest::Method::from_bytes(other.as_bytes())
.map_err(|_| OAuth2HttpError::InvalidHeader(format!("Invalid method: {other}")))?,
};
#[cfg(feature = "dpop")]
if self.dpop.is_some() {
return self.send_with_dpop(&parts, &method, &url, body).await;
}
let mut req_builder = self.inner.request(method, &url);
for (name, value) in parts.headers.iter() {
req_builder = req_builder.header(name.as_str(), value.as_bytes());
}
req_builder = req_builder.body(body);
let response = req_builder.send().await?;
Self::convert_response(response).await
}
async fn convert_response(
response: reqwest::Response,
) -> Result<HttpResponse, OAuth2HttpError> {
let status = StatusCode::from_u16(response.status().as_u16())
.map_err(|_| OAuth2HttpError::InvalidHeader("Invalid status code".to_string()))?;
let mut builder = http::Response::builder().status(status);
for (name, value) in response.headers().iter() {
let header_value = HeaderValue::from_bytes(value.as_bytes())
.map_err(|e| OAuth2HttpError::InvalidHeader(e.to_string()))?;
builder = builder.header(name.as_str(), header_value);
}
let body_bytes = response
.bytes()
.await
.map_err(|e| OAuth2HttpError::BodyRead(e.to_string()))?;
builder
.body(body_bytes.to_vec())
.map_err(|e| OAuth2HttpError::InvalidHeader(e.to_string()))
}
#[cfg(feature = "dpop")]
async fn send_with_dpop(
&self,
parts: &http::request::Parts,
method: &reqwest::Method,
url: &str,
body: Vec<u8>,
) -> Result<HttpResponse, OAuth2HttpError> {
let cached_nonce = {
let guard = self.dpop.as_ref().unwrap().server_nonce.lock().await;
guard.clone()
};
let proof = self
.build_dpop_proof(method.as_str(), url, None, cached_nonce.as_deref())
.await?;
let mut req = self.inner.request(method.clone(), url);
for (name, value) in parts.headers.iter() {
req = req.header(name.as_str(), value.as_bytes());
}
req = req.header("DPoP", proof).body(body.clone());
let response = req.send().await?;
if let Some(nonce_value) = response.headers().get("DPoP-Nonce")
&& let Ok(s) = nonce_value.to_str()
{
let mut guard = self.dpop.as_ref().unwrap().server_nonce.lock().await;
*guard = Some(s.to_string());
}
if response.status().as_u16() == 400 || response.status().as_u16() == 401 {
let new_nonce = response
.headers()
.get("DPoP-Nonce")
.and_then(|v| v.to_str().ok())
.map(str::to_string);
let buffered = Self::convert_response(response).await?;
if let Some(nonce) = new_nonce.as_deref() {
let is_nonce_challenge =
serde_json::from_slice::<serde_json::Value>(buffered.body())
.ok()
.and_then(|v| {
v.get("error")
.and_then(|e| e.as_str())
.map(|s| s == "use_dpop_nonce")
})
.unwrap_or(false);
if is_nonce_challenge {
let proof = self
.build_dpop_proof(method.as_str(), url, None, Some(nonce))
.await?;
let mut retry = self.inner.request(method.clone(), url);
for (name, value) in parts.headers.iter() {
retry = retry.header(name.as_str(), value.as_bytes());
}
retry = retry.header("DPoP", proof).body(body);
let retry_response = retry.send().await?;
if let Some(n) = retry_response
.headers()
.get("DPoP-Nonce")
.and_then(|v| v.to_str().ok())
{
let mut guard = self.dpop.as_ref().unwrap().server_nonce.lock().await;
*guard = Some(n.to_string());
}
return Self::convert_response(retry_response).await;
}
}
return Ok(buffered);
}
Self::convert_response(response).await
}
}
impl Default for OAuth2HttpClient {
    /// Builds a client with the settings of [`OAuth2HttpClient::new`].
    ///
    /// # Panics
    ///
    /// Panics if the underlying `reqwest` client cannot be constructed.
    fn default() -> Self {
        match Self::new() {
            Ok(client) => client,
            Err(err) => panic!("Failed to create default HTTP client: {err:?}"),
        }
    }
}
impl std::fmt::Debug for OAuth2HttpClient {
    /// Formats the client without exposing `reqwest` internals.
    ///
    /// With the `dpop` feature enabled, the output also includes the DPoP
    /// binding, using that binding's own `Debug` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("OAuth2HttpClient");
        out.field("inner", &"<reqwest::Client>");
        // Whether DPoP is active changes request behavior, so surface it.
        #[cfg(feature = "dpop")]
        out.field("dpop", &self.dpop);
        out.finish()
    }
}
/// Errors produced by [`OAuth2HttpClient`] while sending a request or
/// converting its response.
#[derive(Debug)]
pub enum OAuth2HttpError {
/// The `reqwest` transport returned an error; it is exposed via `source()`.
Request(reqwest::Error),
/// A value could not be converted between `reqwest` and `oauth2::http` types.
///
/// Despite the name, this also covers invalid HTTP methods, unrepresentable
/// status codes, and failures building the final response.
InvalidHeader(String),
/// The response body could not be read.
BodyRead(String),
/// A DPoP proof could not be produced (no binding configured, or the proof
/// generator returned an error).
Dpop(String),
}
impl std::fmt::Display for OAuth2HttpError {
    /// Renders `"<category>: <detail>"`, where the category names the failing
    /// stage and the detail is the wrapped error or message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let category = match self {
            Self::Request(_) => "HTTP request failed",
            Self::InvalidHeader(_) => "Invalid header value",
            Self::BodyRead(_) => "Failed to read response body",
            Self::Dpop(_) => "DPoP proof generation failed",
        };
        write!(f, "{category}: ")?;
        match self {
            Self::Request(source) => write!(f, "{source}"),
            Self::InvalidHeader(detail) | Self::BodyRead(detail) | Self::Dpop(detail) => {
                f.write_str(detail)
            }
        }
    }
}
impl StdError for OAuth2HttpError {
    /// Only transport failures wrap another error. The message-only variants
    /// report no underlying source.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        let Self::Request(inner) = self else {
            return None;
        };
        Some(inner)
    }
}
impl From<reqwest::Error> for OAuth2HttpError {
fn from(e: reqwest::Error) -> Self {
Self::Request(e)
}
}
/// Boxed, `Send` future returned by [`OAuth2HttpClient`]'s
/// `AsyncHttpClient::call`, borrowing the client for `'c`.
pub type OAuth2HttpFuture<'c> =
Pin<Box<dyn Future<Output = Result<HttpResponse, OAuth2HttpError>> + Send + 'c>>;
impl<'c> AsyncHttpClient<'c> for OAuth2HttpClient {
    type Error = OAuth2HttpError;
    type Future = OAuth2HttpFuture<'c>;

    /// Creates the future for `request` and boxes it. The future is lazy, so
    /// no I/O happens until it is polled.
    fn call(&'c self, request: HttpRequest) -> Self::Future {
        let pending = self.execute(request);
        Box::pin(pending)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_creation() {
        let client = OAuth2HttpClient::new();
        assert!(client.is_ok());
    }

    #[test]
    fn test_default() {
        let _client = OAuth2HttpClient::default();
    }

    #[test]
    fn test_error_display() {
        let err = OAuth2HttpError::InvalidHeader("test".to_string());
        assert!(err.to_string().contains("Invalid header value"));
    }

    /// Pins the exact rendering of every message-carrying variant.
    #[test]
    fn test_error_display_message_variants() {
        assert_eq!(
            OAuth2HttpError::InvalidHeader("bad".to_string()).to_string(),
            "Invalid header value: bad"
        );
        assert_eq!(
            OAuth2HttpError::BodyRead("eof".to_string()).to_string(),
            "Failed to read response body: eof"
        );
        assert_eq!(
            OAuth2HttpError::Dpop("no key".to_string()).to_string(),
            "DPoP proof generation failed: no key"
        );
    }

    /// Only `Request` wraps another error; message variants have no source.
    #[test]
    fn test_message_variants_have_no_source() {
        for err in [
            OAuth2HttpError::InvalidHeader("x".to_string()),
            OAuth2HttpError::BodyRead("x".to_string()),
            OAuth2HttpError::Dpop("x".to_string()),
        ] {
            assert!(std::error::Error::source(&err).is_none());
        }
    }

    /// `Debug` must not dump the inner `reqwest::Client`.
    #[test]
    fn test_debug_hides_inner_client() {
        let client = OAuth2HttpClient::from_client(reqwest::Client::new());
        let rendered = format!("{client:?}");
        assert!(rendered.starts_with("OAuth2HttpClient"));
        assert!(rendered.contains("<reqwest::Client>"));
    }
}