atrium_oauth/http_client/
dpop.rs1use crate::jose::{
2 create_signed_jwt,
3 jws::RegisteredHeader,
4 jwt::{Claims, PublicClaims, RegisteredClaims},
5};
6use atrium_common::store::{memory::MemoryStore, Store};
7use atrium_xrpc::{
8 http::{Request, Response},
9 HttpClient,
10};
11use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
12use chrono::Utc;
13use jose_jwa::{Algorithm, Signing};
14use jose_jwk::{crypto, EcCurves, Jwk, Key};
15use rand::{
16 rngs::SmallRng,
17 {RngCore, SeedableRng},
18};
19use serde::Deserialize;
20use sha2::{Digest, Sha256};
21use std::sync::Arc;
22use thiserror::Error;
23
24const JWT_HEADER_TYP_DPOP: &str = "dpop+jwt";
25
26#[derive(Deserialize)]
27struct ErrorResponse {
28 error: String,
29}
30
31#[derive(Error, Debug)]
32pub enum Error {
33 #[error("crypto error: {0:?}")]
34 JwkCrypto(crypto::Error),
35 #[error("key does not match any alg supported by the server")]
36 UnsupportedKey,
37 #[error(transparent)]
38 SerdeJson(#[from] serde_json::Error),
39}
40
41type Result<T> = core::result::Result<T, Error>;
42
43pub struct DpopClient<T, S = MemoryStore<String, String>>
44where
45 S: Store<String, String>,
46{
47 inner: Arc<T>,
48 pub(crate) key: Key,
49 nonces: S,
50 is_auth_server: bool,
51}
52
53impl<T> DpopClient<T> {
54 pub fn new(
55 key: Key,
56 http_client: Arc<T>,
57 is_auth_server: bool,
58 supported_algs: &Option<Vec<String>>,
59 ) -> Result<Self> {
60 if let Some(algs) = supported_algs {
61 let alg = String::from(match &key {
62 Key::Ec(ec) => match &ec.crv {
63 EcCurves::P256 => "ES256",
64 _ => unimplemented!(),
65 },
66 _ => unimplemented!(),
67 });
68 if !algs.contains(&alg) {
69 return Err(Error::UnsupportedKey);
70 }
71 }
72 let nonces = MemoryStore::<String, String>::default();
73 Ok(Self { inner: http_client, key, nonces, is_auth_server })
74 }
75}
76
77impl<T, S> DpopClient<T, S>
78where
79 S: Store<String, String>,
80{
81 fn build_proof(
82 &self,
83 htm: String,
84 htu: String,
85 ath: Option<String>,
86 nonce: Option<String>,
87 ) -> Result<String> {
88 match crypto::Key::try_from(&self.key).map_err(Error::JwkCrypto)? {
89 crypto::Key::P256(crypto::Kind::Secret(secret_key)) => {
90 let mut header = RegisteredHeader::from(Algorithm::Signing(Signing::Es256));
91 header.typ = Some(JWT_HEADER_TYP_DPOP.into());
92 header.jwk = Some(Jwk {
93 key: Key::from(&crypto::Key::from(secret_key.public_key())),
94 prm: Default::default(),
95 });
96 let claims = Claims {
97 registered: RegisteredClaims {
98 jti: Some(Self::generate_jti()),
99 iat: Some(Utc::now().timestamp()),
100 ..Default::default()
101 },
102 public: PublicClaims { htm: Some(htm), htu: Some(htu), ath, nonce },
103 };
104 Ok(create_signed_jwt(secret_key.into(), header.into(), claims)?)
105 }
106 _ => unimplemented!(),
107 }
108 }
109 fn is_use_dpop_nonce_error(&self, response: &Response<Vec<u8>>) -> bool {
110 if self.is_auth_server {
112 if response.status() == 400 {
113 if let Ok(res) = serde_json::from_slice::<ErrorResponse>(response.body()) {
114 return res.error == "use_dpop_nonce";
115 };
116 }
117 }
118 else if response.status() == 401 {
121 if let Some(www_auth) =
122 response.headers().get("WWW-Authenticate").and_then(|v| v.to_str().ok())
123 {
124 return www_auth.starts_with("DPoP")
125 && www_auth.contains(r#"error="use_dpop_nonce""#);
126 }
127 }
128 false
129 }
130 fn generate_jti() -> String {
132 let mut rng = SmallRng::from_entropy();
133 let mut bytes = [0u8; 12];
134 rng.fill_bytes(&mut bytes);
135 URL_SAFE_NO_PAD.encode(bytes)
136 }
137}
138
139impl<T, S> HttpClient for DpopClient<T, S>
140where
141 T: HttpClient + Send + Sync + 'static,
142 S: Store<String, String> + Send + Sync + 'static,
143 S::Error: std::error::Error + Send + Sync + 'static,
144{
145 async fn send_http(
146 &self,
147 mut request: Request<Vec<u8>>,
148 ) -> core::result::Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>>
149 {
150 let uri = request.uri();
151 let nonce_key = uri.authority().unwrap().to_string();
152 let htm = request.method().to_string();
153 let htu = uri.to_string();
154 let ath = request
156 .headers()
157 .get("Authorization")
158 .filter(|v| v.to_str().is_ok_and(|s| s.starts_with("DPoP ")))
159 .map(|auth| URL_SAFE_NO_PAD.encode(Sha256::digest(&auth.as_bytes()[5..])));
160
161 let init_nonce = self.nonces.get(&nonce_key).await?;
162 let init_proof =
163 self.build_proof(htm.clone(), htu.clone(), ath.clone(), init_nonce.clone())?;
164 request.headers_mut().insert("DPoP", init_proof.parse()?);
165 let response = self.inner.send_http(request.clone()).await?;
166
167 let next_nonce =
168 response.headers().get("DPoP-Nonce").and_then(|v| v.to_str().ok()).map(String::from);
169 match &next_nonce {
170 Some(s) if next_nonce != init_nonce => {
171 self.nonces.set(nonce_key, s.clone()).await?;
173 }
174 _ => {
175 return Ok(response);
178 }
179 }
180
181 if !self.is_use_dpop_nonce_error(&response) {
182 return Ok(response);
183 }
184 let next_proof = self.build_proof(htm, htu, ath, next_nonce)?;
185 request.headers_mut().insert("DPoP", next_proof.parse()?);
186 let response = self.inner.send_http(request).await?;
187 Ok(response)
188 }
189}
190
191impl<T> Clone for DpopClient<T> {
192 fn clone(&self) -> Self {
193 Self {
194 inner: Arc::clone(&self.inner),
195 key: self.key.clone(),
196 nonces: self.nonces.clone(),
197 is_auth_server: self.is_auth_server,
198 }
199 }
200}