1use serde_json::json;
2
3use crate::{
4 acme::{get_header, AcmeError, Auth, Challenge, ChallengeType, Directory, Identifier, Order},
5 cache::AcmeCache,
6 crypto::EcdsaP256SHA256KeyPair,
7 jose::{jose_req, key_authorization_sha256},
8 B64_URL_SAFE_NO_PAD,
9};
10use base64::Engine;
11use generic_async_http_client::Response;
12
/// An ACME account registered with a CA directory.
///
/// Bundles the key that signs every request, the directory the account
/// talks to, and the key ID (`kid`) obtained from the `Location` header of
/// the `newAccount` response.
#[derive(Debug)]
pub struct Account {
    /// Account key (ES256) used to sign every JWS request sent by this account.
    key_pair: EcdsaP256SHA256KeyPair,
    /// Directory that supplies fresh nonces and the `newAccount`/`newOrder` URLs.
    directory: Directory,
    /// Value of the `Location` header returned by `newAccount`; passed as the
    /// JWS key ID on every request made after registration.
    kid: String,
}
29
30impl Account {
31 pub async fn load_or_create<'a, C, S, I>(
35 directory: Directory,
36 cache: Option<&C>,
37 contact: I,
38 ) -> Result<Self, AcmeError>
39 where
40 C: AcmeCache,
41 S: AsRef<str> + 'a,
42 I: IntoIterator<Item = &'a S>,
43 {
44 let contact: Vec<&'a str> = contact.into_iter().map(AsRef::<str>::as_ref).collect();
45 let pkcs8 = match &cache {
46 Some(cache) => cache
47 .read_account(&contact)
48 .await
49 .map_err(AcmeError::cache)?,
50 None => None,
51 };
52 let key_pair = match pkcs8 {
53 Some(pkcs8) => {
54 log::info!("found cached account key");
55 EcdsaP256SHA256KeyPair::load(&pkcs8)
56 }
57 None => {
58 log::info!("creating a new account key");
59 match EcdsaP256SHA256KeyPair::generate() {
60 Ok(pkcs8) => {
61 let data = pkcs8.as_ref();
62 if let Some(cache) = &cache {
63 cache
64 .write_account(&contact, data)
65 .await
66 .map_err(AcmeError::cache)?;
67 }
68 EcdsaP256SHA256KeyPair::load(data)
69 }
70 Err(_) => Err(()),
71 }
72 }
73 }
74 .map_err(|_| {
75 AcmeError::Io(std::io::Error::new(
76 std::io::ErrorKind::InvalidData,
77 "could not create key pair",
78 ))
79 })?;
80 let payload = json!({
81 "termsOfServiceAgreed": true,
82 "contact": contact,
83 })
84 .to_string();
85 let response = jose_req(
86 &key_pair,
87 None,
88 &directory.nonce().await?,
89 &directory.new_account,
90 &payload,
91 )
92 .await?;
93 let kid = get_header(&response, "Location")?;
94 Ok(Account {
95 key_pair,
96 kid,
97 directory,
98 })
99 }
100 async fn request(&self, url: impl AsRef<str>, payload: &str) -> Result<Response, AcmeError> {
102 jose_req(
103 &self.key_pair,
104 Some(&self.kid),
105 &self.directory.nonce().await?,
106 url.as_ref(),
107 payload,
108 )
109 .await
110 }
111 pub async fn new_order(&self, domains: Vec<String>) -> Result<Order, AcmeError> {
113 let domains: Vec<Identifier> = domains.into_iter().map(Identifier::Dns).collect();
114 let payload = format!("{{\"identifiers\":{}}}", serde_json::to_string(&domains)?);
115 let mut response = self.request(&self.directory.new_order, &payload).await?;
116 Ok(response.json().await?)
117 }
118 pub async fn check_auth(&self, url: impl AsRef<str>) -> Result<Auth, AcmeError> {
120 let payload = "".to_string();
121 let mut response = self.request(url, &payload).await?;
122 Ok(response.json().await?)
123 }
124 pub async fn trigger_challenge(&self, url: impl AsRef<str>) -> Result<(), AcmeError> {
126 self.request(&url, "{}").await?;
127 Ok(())
128 }
129 pub async fn send_csr(&self, url: impl AsRef<str>, csr: Vec<u8>) -> Result<Order, AcmeError> {
131 let payload = format!("{{\"csr\":\"{}\"}}", B64_URL_SAFE_NO_PAD.encode(csr));
132 let mut response = self.request(&url, &payload).await?;
133 Ok(response.json().await?)
134 }
135 pub async fn obtain_certificate(&self, url: impl AsRef<str>) -> Result<String, AcmeError> {
137 Ok(self.request(&url, "").await?.text().await?)
138 }
139 pub fn tls_alpn_01<'a>(
142 &self,
143 challenges: &'a [Challenge],
144 ) -> Result<(&'a Challenge, impl AsRef<[u8]>), AcmeError> {
145 let challenge = challenges
146 .iter()
147 .find(|c| c.typ == ChallengeType::TlsAlpn01);
148 let challenge = match challenge {
149 Some(challenge) => challenge,
150 None => return Err(AcmeError::NoTlsAlpn01Challenge),
151 };
152 let key_auth = key_authorization_sha256(&self.key_pair, &*challenge.token)?;
153
154 Ok((challenge, key_auth))
155 }
156}
157
#[cfg(test)]
#[cfg(any(feature = "use_async_std", feature = "use_tokio"))]
mod test {
    //! Tests for [`Account`] against hand-rolled mock ACME servers.
    //!
    //! Each mock server first calls `return_nounce` (the signed requests are
    //! expected to carry the nonce `abc`). It then accepts one connection per
    //! signed request, checks the raw HTTP/JWS request and writes a canned
    //! response.

    use std::collections::HashMap;

    use super::*;
    use crate::acme::test::{new_dir, return_nounce};
    use crate::test::*;

    /// A decoded JSON object (JWS payload or protected header).
    type JsonObject = serde_json::Map<String, serde_json::Value>;

    /// Splits a raw HTTP request captured by a mock server into its parts.
    ///
    /// Returns `(head, payload, protected)`:
    /// - `head` is the request line plus headers.
    /// - `payload` is the decoded JWS payload, or `None` when it is empty.
    /// - `protected` is the decoded JWS protected header.
    ///
    /// Panics with a descriptive message if any part is malformed.
    fn parse_req(req: &[u8]) -> (String, Option<JsonObject>, JsonObject) {
        let req = std::str::from_utf8(req).expect("request not utf8");

        let (head, body) = req.split_once("\r\n\r\n").expect("no body");

        // The body is read as a flat string-to-string JSON object. Only its
        // `payload` and `protected` members are inspected.
        let body: HashMap<String, String> = serde_json::from_str(body).expect("body not json");

        let payload = body.get("payload").expect("no payload");
        let payload = if payload.is_empty() {
            None
        } else {
            let payload: JsonObject = serde_json::from_slice(
                &B64_URL_SAFE_NO_PAD.decode(payload).expect("b64"),
            )
            .expect("payload not json");
            Some(payload)
        };

        let protected: JsonObject = serde_json::from_slice(
            &B64_URL_SAFE_NO_PAD
                .decode(body.get("protected").expect("no protected"))
                .expect("b64"),
        )
        .expect("protected not json");
        (head.to_owned(), payload, protected)
    }

    #[test]
    fn new() {
        // Expects one newAccount request and answers with `Location: abc`,
        // which must end up as the account's `kid`.
        async fn server(listener: TcpListener, host: String, port: u16) -> std::io::Result<bool> {
            return_nounce(&listener).await?;
            let (mut stream, _) = listener.accept().await?;
            let mut req: Vec<u8> = vec![0; 1024];
            // NOTE(review): assumes the whole request arrives in one read.
            let r = stream.read(req.as_mut_slice()).await?;
            let (header, payload, protected) = parse_req(&req[..r]);
            let payload = payload.expect("no payload");
            assert!(header.starts_with("POST /acme/new-acct HTTP"));

            assert_eq!(payload.get("termsOfServiceAgreed"), Some(&true.into()));
            assert_eq!(
                payload
                    .get("contact")
                    .expect("no contact")
                    .as_array()
                    .expect("no contact array")
                    .first(),
                Some(&"mailto:admin@example.com".into())
            );

            // Protected header: signature algorithm, mock nonce and target URL.
            assert_eq!(protected.get("alg"), Some(&"ES256".into()));
            assert_eq!(protected.get("nonce"), Some(&"abc".into()));
            assert_eq!(
                protected.get("url"),
                Some(&format!("http://{host}:{port}/acme/new-acct").into())
            );

            stream
                .write_all(b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\nLocation: abc\r\n\r\n")
                .await?;

            close(stream).await?;

            Ok(true)
        }
        block_on(async {
            let (listener, port, host) = listen_somewhere().await?;
            let directory = new_dir(&host, port);
            let t = spawn(server(listener, host, port));

            let account = Account::load_or_create(
                directory,
                None::<&String>,
                &vec!["mailto:admin@example.com".to_string()],
            )
            .await?;
            assert_eq!(account.kid, "abc");

            assert!(t.await?, "not cool");
            Ok(())
        });
    }

    /// Builds an account with a fixed PKCS#8 test key and `kid` "kid",
    /// without talking to any server.
    fn new_account(directory: Directory) -> Account {
        let key_pair = EcdsaP256SHA256KeyPair::load(b"0\x81\x87\x02\x01\x000\x13\x06\x07*\x86H\xce=\x02\x01\x06\x08*\x86H\xce=\x03\x01\x07\x04m0k\x02\x01\x01\x04 \x9e!\xcd\x90u\x8d\xba\xe9\xa0-(S\x86\x9aCt\x9c\xcb\xda6Z2\xb8\x9a\xad\xac\x11\n\xb9J\xcei\xa1D\x03B\x00\x04\x834\xd0\xfb\xff\x83D\xfe\xeb\xabn\xb4$\xf5\xe7\xd0\x11\x1cE\xbfK\xb7\x85ZL\x15'\xdfs\x0c\xfb\xdd\xe5\x97|\x93\xf2g\xbd+\xc8\xd0\xaf\xe0\xc1\x88\x16\x99\xde\x9b\xbb\xe4\xb9`_\xe6=\xe2MLP\xa1Ab").unwrap();
        Account {
            key_pair,
            directory,
            kid: "kid".to_string(),
        }
    }

    #[test]
    fn new_order() {
        // Expects a newOrder request for `example.com` and answers with a
        // pending order.
        async fn server(listener: TcpListener) -> std::io::Result<bool> {
            return_nounce(&listener).await?;
            let (mut stream, _) = listener.accept().await?;
            let mut req: Vec<u8> = vec![0; 1024];
            let r = stream.read(req.as_mut_slice()).await?;
            let (header, payload, _) = parse_req(&req[..r]);
            let payload = payload.expect("no payload");

            assert!(header.starts_with("POST /acme/new-order HTTP"));

            let i = payload
                .get("identifiers")
                .expect("no identifiers")
                .as_array()
                .expect("no identifiers array")
                .first()
                .expect("no ele")
                .as_object()
                .expect("id not a obj");

            assert_eq!(i.get("type"), Some(&"dns".into()));
            assert_eq!(i.get("value"), Some(&"example.com".into()));

            let body = r##"{"status":"pending", "authorizations": ["http://example.com/auth"], "finalize": "finalize"}"##;

            stream
                .write_all(format!("HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: application/json\r\n\r\n{}", body.len(), body).as_bytes())
                .await?;

            close(stream).await?;

            Ok(true)
        }
        block_on(async {
            let (listener, port, host) = listen_somewhere().await?;
            let directory = new_dir(&host, port);
            let t = spawn(server(listener));

            let account = new_account(directory);
            let o = account.new_order(vec!["example.com".to_string()]).await?;

            let (a, f) = match o {
                Order::Pending {
                    authorizations,
                    finalize,
                } => (authorizations, finalize),
                _ => panic!("wrong variant"),
            };
            assert_eq!(a, vec!["http://example.com/auth".to_string()]);
            assert_eq!(f, "finalize");

            assert!(t.await?, "not cool");
            Ok(())
        });
    }

    #[test]
    fn check_auth() {
        // Serves two empty-payload requests to the same authorization URL:
        // the first returns a pending auth with one tls-alpn-01 challenge,
        // the second returns a valid auth.
        async fn server(listener: TcpListener) -> std::io::Result<bool> {
            return_nounce(&listener).await?;
            let (mut stream, _) = listener.accept().await?;
            let mut req: Vec<u8> = vec![0; 1024];
            let r = stream.read(req.as_mut_slice()).await?;

            let (header, payload, _) = parse_req(&req[..r]);

            assert!(payload.is_none());
            assert!(header.starts_with("POST /check_auth HTTP"));

            let body = r##"{"status":"pending", "challenges": [{"token":"t","type":"tls-alpn-01","url":"http://example.com/bla"}], "identifier": {"type": "dns", "value": "id"}}"##;

            stream
                .write_all(format!("HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: application/json\r\n\r\n{}", body.len(), body).as_bytes())
                .await?;

            // Close the first connection before serving the next request, as
            // the other mock servers do after each response.
            close(stream).await?;

            return_nounce(&listener).await?;
            let (mut stream, _) = listener.accept().await?;
            let mut req: Vec<u8> = vec![0; 1024];
            let r = stream.read(req.as_mut_slice()).await?;

            let (header, payload, _) = parse_req(&req[..r]);

            assert!(payload.is_none());
            assert!(header.starts_with("POST /check_auth HTTP"));

            let body = r##"{"status":"valid"}"##;

            stream
                .write_all(format!("HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: application/json\r\n\r\n{}", body.len(), body).as_bytes())
                .await?;

            close(stream).await?;
            Ok(true)
        }
        block_on(async {
            let (listener, port, host) = listen_somewhere().await?;
            let directory = new_dir(&host, port);
            let auth_url = format!("http://{host}:{port}/check_auth");
            let t = spawn(server(listener));

            let account = new_account(directory);
            let o = account.check_auth(&auth_url).await?;

            let (i, c) = match o {
                Auth::Pending {
                    identifier: Identifier::Dns(i),
                    challenges,
                } => (i, challenges),
                _ => panic!("wrong variant"),
            };
            assert_eq!(i, "id");
            let Challenge { typ, url, token } = c.first().expect("no challange");
            assert_eq!(*typ, ChallengeType::TlsAlpn01);
            assert_eq!(url, "http://example.com/bla");
            assert_eq!(token, "t");

            let o = account.check_auth(auth_url).await?;

            assert!(matches!(o, Auth::Valid));

            assert!(t.await?, "not cool");
            Ok(())
        });
    }

    #[test]
    fn send_csr() {
        // Expects the CSR bytes `csr` (base64url in the `csr` field) and
        // answers with a valid order carrying a certificate URL.
        async fn server(listener: TcpListener) -> std::io::Result<bool> {
            return_nounce(&listener).await?;
            let (mut stream, _) = listener.accept().await?;
            let mut req: Vec<u8> = vec![0; 1024];
            let r = stream.read(req.as_mut_slice()).await?;
            let (header, payload, _) = parse_req(&req[..r]);
            let payload = payload.expect("no payload");

            assert!(header.starts_with("POST /csr HTTP"));

            let i = payload
                .get("csr")
                .expect("no csr")
                .as_str()
                .expect("csr not str");
            let i = B64_URL_SAFE_NO_PAD.decode(i).expect("b64");

            assert_eq!(i, b"csr");

            let body = r##"{"status":"valid", "certificate": "your_cert"}"##;

            stream
                .write_all(format!("HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: application/json\r\n\r\n{}", body.len(), body).as_bytes())
                .await?;

            close(stream).await?;

            Ok(true)
        }
        block_on(async {
            let (listener, port, host) = listen_somewhere().await?;
            let directory = new_dir(&host, port);
            let t = spawn(server(listener));

            let account = new_account(directory);
            let o = account
                .send_csr(format!("http://{host}:{port}/csr"), b"csr".to_vec())
                .await?;

            let c = match o {
                Order::Valid { certificate } => certificate,
                _ => panic!("wrong variant"),
            };
            assert_eq!(c, "your_cert");

            assert!(t.await?, "not cool");
            Ok(())
        });
    }
}