acme_lite/
trans.rs

1use std::{collections::VecDeque, sync::Arc};
2
3use base64::prelude::*;
4use parking_lot::Mutex;
5use serde::Serialize;
6
7use crate::{
8    acc::AcmeKey,
9    jws::{FlattenedJsonJws, Jwk, JwsProtectedHeader},
10    req::{req_expect_header, req_handle_error, req_head, req_post},
11};
12
13/// JWS payload and nonce handling for requests to the API.
14///
15/// Setup is:
16///
17/// 1. `Transport::new()`
18/// 2. `call_jwk()` against newAccount url
19/// 3. `set_key_id` from the returned `Location` header.
20/// 4. `call()` for all calls after that.
21#[derive(Clone, Debug)]
22pub(crate) struct Transport {
23    acme_key: AcmeKey,
24    nonce_pool: Arc<NoncePool>,
25}
26
27impl Transport {
28    pub fn new(nonce_pool: Arc<NoncePool>, acme_key: AcmeKey) -> Self {
29        Transport {
30            acme_key,
31            nonce_pool,
32        }
33    }
34
35    /// Update the key id once it is known (part of setting up the transport).
36    pub fn set_key_id(&mut self, kid: String) {
37        self.acme_key.set_key_id(kid);
38    }
39
40    /// The key used in the transport
41    pub fn acme_key(&self) -> &AcmeKey {
42        &self.acme_key
43    }
44
45    /// Make call using the full jwk. Only for the first newAccount request.
46    pub async fn call_jwk<T>(&self, url: &str, body: &T) -> eyre::Result<reqwest::Response>
47    where
48        T: Serialize + ?Sized,
49    {
50        self.do_call(url, body, jws_with_jwk).await
51    }
52
53    /// Make call using the key id
54    pub async fn call<T>(&self, url: &str, body: &T) -> eyre::Result<reqwest::Response>
55    where
56        T: Serialize + ?Sized,
57    {
58        self.do_call(url, body, jws_with_kid).await
59    }
60
61    async fn do_call<T, F>(
62        &self,
63        url: &str,
64        body: &T,
65        make_body: F,
66    ) -> eyre::Result<reqwest::Response>
67    where
68        T: Serialize + ?Sized,
69        F: Fn(&str, String, &AcmeKey, &T) -> eyre::Result<String>,
70    {
71        // The ACME API may at any point invalidate all nonces. If we detect such an
72        // error, we loop until the server accepts the nonce.
73        loop {
74            // Either get a new nonce, or reuse one from a previous request.
75            let nonce = self.nonce_pool.get_nonce().await?;
76
77            // Sign the body.
78            let body = make_body(url, nonce, &self.acme_key, body)?;
79
80            log::debug!("Call endpoint: {url}");
81
82            // Post it to the URL
83            let response = req_post(url, &body).await;
84
85            // Regardless of the request being a success or not, there might be a nonce in the
86            // response.
87            self.nonce_pool.extract_nonce(&response);
88
89            // Turn errors into ApiProblem.
90            let result = req_handle_error(response).await;
91
92            if let Err(problem) = &result {
93                if problem.is_bad_nonce() {
94                    // retry the request with a new nonce.
95                    log::debug!("Retrying on bad nonce");
96                    continue;
97                }
98
99                // it seems we sometimes make bad JWTs. Why?!
100                if problem.is_jwt_verification_error() {
101                    log::debug!("Retrying on: {problem}");
102                    continue;
103                }
104            }
105
106            return Ok(result?);
107        }
108    }
109}
110
111/// Shared pool of nonces.
112#[derive(Default, Debug)]
113pub(crate) struct NoncePool {
114    nonce_url: String,
115    pool: Mutex<VecDeque<String>>,
116}
117
118impl NoncePool {
119    pub fn new(nonce_url: &str) -> Self {
120        NoncePool {
121            nonce_url: nonce_url.to_owned(),
122            ..Default::default()
123        }
124    }
125
126    fn extract_nonce(&self, res: &reqwest::Response) {
127        if let Some(nonce) = res.headers().get("replay-nonce") {
128            log::trace!("Extracting new nonce");
129
130            let mut pool = self.pool.lock();
131
132            // TODO: ignore invalid replay-nonce values
133            // see https://datatracker.ietf.org/doc/html/rfc8555#section-6.5.1
134            pool.push_back(nonce.to_str().unwrap().to_owned());
135
136            if pool.len() > 10 {
137                pool.pop_front();
138            }
139        }
140    }
141
142    async fn get_nonce(&self) -> eyre::Result<String> {
143        {
144            let mut pool = self.pool.lock();
145
146            if let Some(nonce) = pool.pop_front() {
147                log::trace!("Use previous nonce");
148                return Ok(nonce);
149            }
150        }
151
152        log::debug!("Request new nonce");
153        let res = req_head(&self.nonce_url).await;
154
155        // TODO: ignore invalid replay-nonce values
156        // see https://datatracker.ietf.org/doc/html/rfc8555#section-6.5.1
157        Ok(req_expect_header(&res, "replay-nonce")?)
158    }
159}
160
161fn jws_with_kid<T: Serialize + ?Sized>(
162    url: &str,
163    nonce: String,
164    key: &AcmeKey,
165    payload: &T,
166) -> eyre::Result<String> {
167    let protected = JwsProtectedHeader::new_kid(key.key_id(), url, nonce);
168    jws_with(protected, key, payload)
169}
170
171fn jws_with_jwk<T: Serialize + ?Sized>(
172    url: &str,
173    nonce: String,
174    key: &AcmeKey,
175    payload: &T,
176) -> eyre::Result<String> {
177    let jwk = Jwk::try_from(key)?;
178    let protected = JwsProtectedHeader::new_jwk(jwk, url, nonce);
179    jws_with(protected, key, payload)
180}
181
182/// Construct JWS with protected header according to [RFC 7515 §5.1].
183///
184/// [RFC 7515 §5.1]: https://datatracker.ietf.org/doc/html/rfc7515#section-5.1
185fn jws_with<T: Serialize + ?Sized>(
186    protected: JwsProtectedHeader,
187    key: &AcmeKey,
188    payload: &T,
189) -> eyre::Result<String> {
190    let header = {
191        let pro_json = serde_json::to_string(&protected)?;
192        BASE64_URL_SAFE_NO_PAD.encode(pro_json)
193    };
194
195    let payload = {
196        let payload_json = serde_json::to_string(payload)?;
197
198        if payload_json == "\"\"" {
199            // This is a special case produced by ApiEmptyString and should
200            // not be further base64url encoded.
201            String::new()
202        } else {
203            BASE64_URL_SAFE_NO_PAD.encode(payload_json)
204        }
205    };
206
207    let to_sign = format!("{header}.{payload}");
208    let (signature, _rec_id) = key
209        .signing_key()
210        .sign_recoverable(to_sign.as_bytes())
211        .unwrap();
212
213    let signature = BASE64_URL_SAFE_NO_PAD.encode(signature.to_bytes());
214
215    let jws = FlattenedJsonJws::new(header, payload, signature);
216
217    Ok(serde_json::to_string(&jws)?)
218}