use crate::jose::{
create_signed_jwt,
jws::RegisteredHeader,
jwt::{Claims, PublicClaims, RegisteredClaims},
};
use atrium_common::store::{Store, memory::MemoryStore};
use atrium_xrpc::{
HttpClient,
http::{Request, Response},
};
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use chrono::Utc;
use jose_jwa::{Algorithm, Signing};
use jose_jwk::{EcCurves, Jwk, Key, crypto};
use rand::{
rngs::SmallRng,
{RngCore, SeedableRng},
};
use serde::Deserialize;
use sha2::{Digest, Sha256};
use std::sync::Arc;
use thiserror::Error;
/// `typ` JOSE header value that marks a signed JWT as a DPoP proof.
const JWT_HEADER_TYP_DPOP: &str = "dpop+jwt";
/// Minimal view of a JSON error response body from the authorization server.
///
/// Only the `error` field is read; any other fields in the body are ignored
/// during deserialization.
#[derive(Deserialize)]
struct ErrorResponse {
    /// Error code, compared against `use_dpop_nonce`.
    error: String,
}
/// Errors produced while configuring a [`DpopClient`] or building DPoP proofs.
#[derive(Error, Debug)]
pub enum Error {
    /// The JWK could not be converted into a usable cryptographic key.
    #[error("crypto error: {0:?}")]
    JwkCrypto(crypto::Error),
    /// The key's algorithm is not supported (e.g. it is not listed among the
    /// algorithms the server advertises).
    #[error("key does not match any alg supported by the server")]
    UnsupportedKey,
    /// A JSON serialization/deserialization error, converted via `From`.
    #[error(transparent)]
    SerdeJson(#[from] serde_json::Error),
}
/// Module-local result type whose error is [`Error`].
type Result<T> = core::result::Result<T, Error>;
/// HTTP client wrapper that attaches a DPoP proof (`DPoP` request header) to every
/// request and remembers the DPoP nonces issued by servers.
///
/// `T` is the inner client that actually performs requests. `S` stores the most
/// recent nonce per server, keyed by the request URI's authority; it defaults to an
/// in-memory store.
pub struct DpopClient<T, S = MemoryStore<String, String>>
where
    S: Store<String, String>,
{
    /// Inner HTTP client that requests are forwarded to.
    inner: Arc<T>,
    /// Key used to sign DPoP proofs.
    pub(crate) key: Key,
    /// Latest DPoP nonce seen per URI authority.
    nonces: S,
    /// `true` when talking to an authorization server, `false` for a resource server;
    /// selects how a `use_dpop_nonce` error response is recognised.
    is_auth_server: bool,
}
impl<T> DpopClient<T> {
pub fn new(
key: Key,
http_client: Arc<T>,
is_auth_server: bool,
supported_algs: &Option<Vec<String>>,
) -> Result<Self> {
if let Some(algs) = supported_algs {
let alg = String::from(match &key {
Key::Ec(ec) => match &ec.crv {
EcCurves::P256 => "ES256",
_ => unimplemented!(),
},
_ => unimplemented!(),
});
if !algs.contains(&alg) {
return Err(Error::UnsupportedKey);
}
}
let nonces = MemoryStore::<String, String>::default();
Ok(Self { inner: http_client, key, nonces, is_auth_server })
}
}
impl<T, S> DpopClient<T, S>
where
S: Store<String, String>,
{
fn build_proof(
&self,
htm: String,
htu: String,
ath: Option<String>,
nonce: Option<String>,
) -> Result<String> {
match crypto::Key::try_from(&self.key).map_err(Error::JwkCrypto)? {
crypto::Key::P256(crypto::Kind::Secret(secret_key)) => {
let mut header = RegisteredHeader::from(Algorithm::Signing(Signing::Es256));
header.typ = Some(JWT_HEADER_TYP_DPOP.into());
header.jwk = Some(Jwk {
key: Key::from(&crypto::Key::from(secret_key.public_key())),
prm: Default::default(),
});
let claims = Claims {
registered: RegisteredClaims {
jti: Some(Self::generate_jti()),
iat: Some(Utc::now().timestamp()),
..Default::default()
},
public: PublicClaims { htm: Some(htm), htu: Some(htu), ath, nonce },
};
Ok(create_signed_jwt(secret_key.into(), header.into(), claims)?)
}
_ => unimplemented!(),
}
}
fn is_use_dpop_nonce_error(&self, response: &Response<Vec<u8>>) -> bool {
if self.is_auth_server {
if response.status() == 400 {
if let Ok(res) = serde_json::from_slice::<ErrorResponse>(response.body()) {
return res.error == "use_dpop_nonce";
};
}
}
else if response.status() == 401 {
if let Some(www_auth) =
response.headers().get("WWW-Authenticate").and_then(|v| v.to_str().ok())
{
return www_auth.starts_with("DPoP")
&& www_auth.contains(r#"error="use_dpop_nonce""#);
}
}
false
}
fn generate_jti() -> String {
let mut rng = SmallRng::from_entropy();
let mut bytes = [0u8; 12];
rng.fill_bytes(&mut bytes);
URL_SAFE_NO_PAD.encode(bytes)
}
}
impl<T, S> HttpClient for DpopClient<T, S>
where
    T: HttpClient + Send + Sync + 'static,
    S: Store<String, String> + Send + Sync + 'static,
    S::Error: std::error::Error + Send + Sync + 'static,
{
    /// Sends `request` with a DPoP proof attached, retrying once if the server asks
    /// for a fresh nonce.
    ///
    /// 1. Builds a proof for the request's method and URI. The proof includes an `ath`
    ///    hash when the request carries a `DPoP`-scheme `Authorization` header, and the
    ///    nonce last stored for the URI's authority, if any.
    /// 2. Sends the request. If the response has a `DPoP-Nonce` header different from
    ///    the nonce just used, that nonce is stored for the authority. Otherwise the
    ///    response is returned unchanged.
    /// 3. If the nonce changed and the response is a `use_dpop_nonce` error, a new
    ///    proof with the new nonce is built and the request is sent one more time.
    ///
    /// # Errors
    /// Fails when:
    /// * the request URI has no authority component (e.g. a relative URI);
    /// * the nonce store fails;
    /// * proof generation fails;
    /// * a proof is not a valid header value;
    /// * the inner client fails.
    async fn send_http(
        &self,
        mut request: Request<Vec<u8>>,
    ) -> core::result::Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>>
    {
        let uri = request.uri();
        // Nonces are tracked per authority. A URI without one (e.g. relative) cannot be
        // keyed, so report it as an error rather than panicking inside the library.
        let nonce_key = uri.authority().ok_or("request URI has no authority")?.to_string();
        let htm = request.method().to_string();
        let htu = uri.to_string();
        // `Authorization: DPoP <token>`: `ath` is base64url(SHA-256(<token>)).
        let ath = request
            .headers()
            .get("Authorization")
            .filter(|v| v.to_str().is_ok_and(|s| s.starts_with("DPoP ")))
            .map(|auth| URL_SAFE_NO_PAD.encode(Sha256::digest(&auth.as_bytes()[5..])));

        let init_nonce = self.nonces.get(&nonce_key).await?;
        let init_proof =
            self.build_proof(htm.clone(), htu.clone(), ath.clone(), init_nonce.clone())?;
        request.headers_mut().insert("DPoP", init_proof.parse()?);
        let response = self.inner.send_http(request.clone()).await?;

        let next_nonce =
            response.headers().get("DPoP-Nonce").and_then(|v| v.to_str().ok()).map(String::from);
        match &next_nonce {
            // Remember a newly issued nonce for subsequent requests to this authority.
            Some(s) if next_nonce != init_nonce => {
                self.nonces.set(nonce_key, s.clone()).await?;
            }
            // No nonce, or the same one we already used: a retry could not succeed
            // any differently, so hand back the response as-is.
            _ => {
                return Ok(response);
            }
        }
        if !self.is_use_dpop_nonce_error(&response) {
            return Ok(response);
        }

        // Single retry with a proof bound to the freshly issued nonce.
        let next_proof = self.build_proof(htm, htu, ath, next_nonce)?;
        request.headers_mut().insert("DPoP", next_proof.parse()?);
        let response = self.inner.send_http(request).await?;
        Ok(response)
    }
}
impl<T> Clone for DpopClient<T> {
    /// Returns a new client that points at the same inner HTTP client (an `Arc`
    /// reference-count bump) and holds clones of the key, nonce store and server-kind
    /// flag.
    ///
    /// NOTE(review): whether the clone shares nonce state with the original depends on
    /// `MemoryStore`'s `Clone` impl — confirm.
    fn clone(&self) -> Self {
        let Self { inner, key, nonces, is_auth_server } = self;
        Self {
            inner: Arc::clone(inner),
            key: key.clone(),
            nonces: nonces.clone(),
            is_auth_server: *is_auth_server,
        }
    }
}