odoh_rs/
protocol.rs

1//! API for protocol functionality such as creating and parsing ODoH queries and responses.
2
3#![deny(missing_docs)]
4
5use aes_gcm::aead::generic_array::GenericArray;
6use aes_gcm::aead::{AeadInPlace, KeyInit};
7use aes_gcm::Aes128Gcm;
8use bytes::{Buf, BufMut, Bytes, BytesMut};
9use hkdf::Hkdf;
10use hpke::aead::{Aead as AeadTrait, AesGcm128};
11use hpke::kdf::{HkdfSha256, Kdf as KdfTrait};
12use hpke::kem::X25519HkdfSha256;
13use hpke::rand_core::{CryptoRng, RngCore};
14use hpke::{Deserializable, HpkeError, Kem as KemTrait, OpModeR, OpModeS, Serializable};
15use std::convert::{TryFrom, TryInto};
16use thiserror::Error as ThisError;
17
// Extra info strings (labels) passed to the HPKE / HKDF routines in this file.
const LABEL_QUERY: &[u8] = b"odoh query";
const LABEL_KEY: &[u8] = b"odoh key";
const LABEL_NONCE: &[u8] = b"odoh nonce";
const LABEL_KEY_ID: &[u8] = b"odoh key id";
const LABEL_RESPONSE: &[u8] = b"odoh response";

// The fixed HPKE ciphersuite this crate supports, and its associated constants.
type Kem = X25519HkdfSha256;
type Aead = AesGcm128;
type Kdf = HkdfSha256;
const KEM_ID: u16 = Kem::KEM_ID;
const KDF_ID: u16 = Kdf::KDF_ID;
const AEAD_ID: u16 = Aead::AEAD_ID;

/// Output size in bytes of the selected KDF (SHA256).
const KDF_OUTPUT_SIZE: usize = 32;
// Size in bytes of the AEAD key derived for responses (see `AeadKey`).
const AEAD_KEY_SIZE: usize = 16;
// Size in bytes of the AEAD nonce derived for responses (see `AeadNonce`).
const AEAD_NONCE_SIZE: usize = 12;

/// This is the maximum of `AEAD_KEY_SIZE` and `AEAD_NONCE_SIZE`.
const RESPONSE_NONCE_SIZE: usize = 16;

/// Length in bytes of the public key carried in a config.
const PUBLIC_KEY_SIZE: usize = 32;

// Fixed-size byte arrays returned by `derive_secrets`.
type AeadKey = [u8; AEAD_KEY_SIZE];
type AeadNonce = [u8; AEAD_NONCE_SIZE];

/// Secret used in encrypt/decrypt API.
///
/// It is returned by the query functions and must be handed back to the
/// matching response function.
pub type OdohSecret = [u8; AEAD_KEY_SIZE];

/// Response nonce needed by [`encrypt_response`](fn.encrypt_response.html)
pub type ResponseNonce = [u8; RESPONSE_NONCE_SIZE];

/// HTTP content-type header value required for sending queries and responses
pub const ODOH_HTTP_HEADER: &str = "application/oblivious-dns-message";

/// ODoH version supported by this library
pub const ODOH_VERSION: u16 = 0x0001;
58
/// Errors generated by this crate.
#[derive(ThisError, Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// Input data is too short: fewer bytes remain than a fixed-size
    /// field (type byte, version, suite ids or a length prefix) requires.
    #[error("Input data is too short")]
    ShortInput,
    /// Input data has incorrect length: a length prefix exceeds the
    /// remaining input, a public key or encrypted message has an
    /// unexpected size, or a length does not fit in 16 bits.
    #[error("Input data has incorrect length")]
    InvalidInputLength,
    /// Padding is not zero.
    #[error("Padding is not zero")]
    InvalidPadding,
    /// Config parameter is invalid: the KEM/KDF/AEAD ids do not match
    /// the ciphersuite supported by this crate.
    #[error("Config parameter is invalid")]
    InvalidParameter,
    /// Type byte in ObliviousDoHMessage is invalid, or the message is not
    /// of the type expected by the called function.
    #[error("Type byte in ObliviousDoHMessage is invalid")]
    InvalidMessageType,
    /// Message key_id does not match public key.
    #[error("Message key_id does not match public key")]
    KeyIdMismatch,
    /// Response nonce is not equal to max(key, nonce) size.
    #[error("Response nonce is not equal to max(key, nonce) size")]
    InvalidResponseNonceLength,

    /// Errors from hpke crate.
    #[error(transparent)]
    Hpke(#[from] HpkeError),

    /// Errors from aes-gcm crate.
    #[error(transparent)]
    AesGcm(#[from] aes_gcm::Error),

    /// Unexpected internal error.
    #[error("Unexpected internal error")]
    Internal,
}

// Crate-local `Result` whose error type defaults to this crate's `Error`.
type Result<T, E = Error> = std::result::Result<T, E>;
98
/// Serialize to IETF wireformat that is similar to [XDR](https://tools.ietf.org/html/rfc1014)
///
/// `serialize` takes `self` by value; the implementations in this module
/// are written for references (`&T`), so serializing does not consume the
/// underlying structure.
pub trait Serialize {
    /// Serialize the provided struct into the buf.
    ///
    /// Bytes are appended to `buf`. On error, `buf` may already contain a
    /// partially written structure.
    fn serialize<B: BufMut>(self, buf: &mut B) -> Result<()>;
}
104
/// Deserialize from IETF wireformat that is similar to [XDR](https://tools.ietf.org/html/rfc1014)
pub trait Deserialize {
    /// Deserialize a struct from the buf.
    ///
    /// The bytes that make up the structure are read from the front of
    /// `buf`; an error is returned when the input is malformed.
    fn deserialize<B: Buf>(buf: &mut B) -> Result<Self>
    where
        Self: Sized;
}
112
113/// Convenient function to deserialize a structure from Bytes.
114pub fn parse<D: Deserialize, B: Buf>(buf: &mut B) -> Result<D> {
115    D::deserialize(buf)
116}
117
118/// Convenient function to serialize a structure into a new BytesMut.
119pub fn compose<S: Serialize>(s: S) -> Result<BytesMut> {
120    let mut buf = BytesMut::new();
121    s.serialize(&mut buf)?;
122    Ok(buf)
123}
124
125fn read_lengthed<B: Buf>(b: &mut B) -> Result<Bytes> {
126    if b.remaining() < 2 {
127        return Err(Error::ShortInput);
128    }
129
130    let len = b.get_u16() as usize;
131
132    if len > b.remaining() {
133        return Err(Error::InvalidInputLength);
134    }
135
136    Ok(b.copy_to_bytes(len))
137}
138
/// Supplies config information to the client.
///
/// It contains one or more `ObliviousDoHConfig` structures in
/// decreasing order of preference. This allows a server to support multiple versions
/// of ODoH and multiple sets of ODoH HPKE suite parameters.
///
/// This information is designed to be disseminated via [DNS HTTPS
/// records](https://tools.ietf.org/html/draft-ietf-dnsop-svcb-httpssvc-03),
/// using the param `odohconfig`.
#[derive(Debug, Clone)]
pub struct ObliviousDoHConfigs {
    // protocol: the list is written with a 16-bit length prefix holding the
    // total size in bytes of all serialized configs
    configs: Vec<ObliviousDoHConfig>,
}
153
154impl ObliviousDoHConfigs {
155    /// Filter the list of configs, leave ones matches ODOH_VERSION.
156    pub fn supported(self) -> Vec<ObliviousDoHConfig> {
157        self.into_iter().collect()
158    }
159}
160
161type VecIter = std::vec::IntoIter<ObliviousDoHConfig>;
162impl IntoIterator for ObliviousDoHConfigs {
163    type Item = ObliviousDoHConfig;
164    type IntoIter = std::iter::Filter<VecIter, fn(&Self::Item) -> bool>;
165
166    fn into_iter(self) -> Self::IntoIter {
167        self.configs
168            .into_iter()
169            .filter(|c| c.version == ODOH_VERSION)
170    }
171}
172
173impl From<Vec<ObliviousDoHConfig>> for ObliviousDoHConfigs {
174    fn from(configs: Vec<ObliviousDoHConfig>) -> Self {
175        Self { configs }
176    }
177}
178
179impl Serialize for &ObliviousDoHConfigs {
180    fn serialize<B: BufMut>(self, buf: &mut B) -> Result<()> {
181        // calculate total length
182        let mut len = 0;
183        for c in self.configs.iter() {
184            // 2 bytes of version and 2 bytes of length
185            len += 2 + 2 + c.length;
186        }
187
188        buf.put_u16(len);
189        for c in self.configs.iter() {
190            c.serialize(buf)?;
191        }
192
193        Ok(())
194    }
195}
196
197impl Deserialize for ObliviousDoHConfigs {
198    fn deserialize<B: Buf>(buf: &mut B) -> Result<Self> {
199        let mut buf = read_lengthed(buf)?;
200
201        let mut configs = Vec::new();
202        loop {
203            if buf.is_empty() {
204                break;
205            }
206            let c = parse(&mut buf)?;
207            configs.push(c);
208        }
209
210        Ok(Self { configs })
211    }
212}
213
/// Contains version and encryption information. Based on the version specified,
/// the contents can differ.
///
/// For `ODOH_VERSION = 0x0001`, `ObliviousDoHConfig::contents`
/// deserializes into
/// [ObliviousDoHConfigContents](./../struct.ObliviousDoHConfigContents.html).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ObliviousDoHConfig {
    // protocol version of this config
    version: u16,
    // size in bytes of the serialized `contents`
    length: u16,
    contents: ObliviousDoHConfigContents,
}
226
227impl Serialize for &ObliviousDoHConfig {
228    fn serialize<B: BufMut>(self, buf: &mut B) -> Result<()> {
229        buf.put_u16(self.version);
230        buf.put_u16(self.length);
231        self.contents.serialize(buf)
232    }
233}
234
235impl Deserialize for ObliviousDoHConfig {
236    fn deserialize<B: Buf>(mut buf: &mut B) -> Result<Self> {
237        if buf.remaining() < 2 {
238            return Err(Error::ShortInput);
239        }
240        let version = buf.get_u16();
241        let mut contents = read_lengthed(&mut buf)?;
242        let length = contents.len() as u16;
243
244        Ok(Self {
245            version,
246            length,
247            contents: parse(&mut contents)?,
248        })
249    }
250}
251
252impl From<ObliviousDoHConfig> for ObliviousDoHConfigContents {
253    fn from(c: ObliviousDoHConfig) -> Self {
254        c.contents
255    }
256}
257
258impl From<ObliviousDoHConfigContents> for ObliviousDoHConfig {
259    fn from(c: ObliviousDoHConfigContents) -> Self {
260        Self {
261            version: ODOH_VERSION,
262            length: c.len() as u16,
263            contents: c,
264        }
265    }
266}
267
/// Contains the HPKE suite parameters and the
/// resolver (target's) public key.
///
/// Values obtained through deserialization always carry the suite ids
/// supported by this crate and a public key of `PUBLIC_KEY_SIZE` bytes;
/// anything else is rejected while parsing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ObliviousDoHConfigContents {
    kem_id: u16,
    kdf_id: u16,
    aead_id: u16,
    // protocol: length prefix
    public_key: Bytes,
}
278
279impl ObliviousDoHConfigContents {
280    /// Creates a KeyID for an `ObliviousDoHConfigContents` struct
281    pub fn identifier(&self) -> Result<Vec<u8>> {
282        let buf = compose(self)?;
283
284        let key_id_info = LABEL_KEY_ID.to_vec();
285        let prk = Hkdf::<<Kdf as KdfTrait>::HashImpl>::new(None, &buf);
286        let mut key_id = [0; KDF_OUTPUT_SIZE];
287        prk.expand(&key_id_info, &mut key_id)
288            .map_err(|_| Error::from(HpkeError::KdfOutputTooLong))?;
289        Ok(key_id.to_vec())
290    }
291
292    fn len(&self) -> usize {
293        2 + 2 + 2 + 2 + self.public_key.len()
294    }
295}
296
297impl Serialize for &ObliviousDoHConfigContents {
298    fn serialize<B: BufMut>(self, buf: &mut B) -> Result<()> {
299        buf.put_u16(self.kem_id);
300        buf.put_u16(self.kdf_id);
301        buf.put_u16(self.aead_id);
302
303        buf.put_u16(to_u16(self.public_key.len())?);
304        buf.put(self.public_key.clone());
305        Ok(())
306    }
307}
308
309impl Deserialize for ObliviousDoHConfigContents {
310    fn deserialize<B: Buf>(mut buf: &mut B) -> Result<Self> {
311        if buf.remaining() < 2 + 2 + 2 {
312            return Err(Error::ShortInput);
313        }
314
315        let kem_id = buf.get_u16();
316        let kdf_id = buf.get_u16();
317        let aead_id = buf.get_u16();
318
319        if kem_id != KEM_ID || kdf_id != KDF_ID || aead_id != AEAD_ID {
320            return Err(Error::InvalidParameter);
321        }
322
323        let public_key = read_lengthed(&mut buf)?;
324        if public_key.len() != PUBLIC_KEY_SIZE {
325            return Err(Error::InvalidInputLength);
326        }
327
328        Ok(Self {
329            kem_id,
330            kdf_id,
331            aead_id,
332            public_key,
333        })
334    }
335}
336
/// `ObliviousDoHMessageType` is supplied at the beginning of every ODoH message.
/// It is used to specify whether a message is a query or a response.
///
/// The discriminants are the values of the type byte written on the wire.
#[derive(Debug, Clone, Eq, PartialEq, Copy)]
enum ObliviousDoHMessageType {
    Query = 1,
    Response = 2,
}
344
345impl TryFrom<u8> for ObliviousDoHMessageType {
346    type Error = Error;
347    fn try_from(n: u8) -> Result<Self> {
348        match n {
349            1 => Ok(Self::Query),
350            2 => Ok(Self::Response),
351            _ => Err(Error::InvalidMessageType),
352        }
353    }
354}
355
356/// Main structure used to transfer queries and responses.
357///
358/// It specifies a message type, an identifier of the corresponding `ObliviousDoHConfigContents`
359/// structure being used, and the encrypted message for the target resolver, or a DNS response
360/// message for the client.
361pub struct ObliviousDoHMessage {
362    msg_type: ObliviousDoHMessageType,
363    // protocol: length prefix
364    key_id: Bytes,
365    // protocol: length prefix
366    encrypted_msg: Bytes,
367}
368
369impl ObliviousDoHMessage {
370    /// Returns the key ID contained in this message.
371    pub fn key_id(&self) -> &[u8] {
372        self.key_id.as_ref()
373    }
374}
375
376impl Deserialize for ObliviousDoHMessage {
377    fn deserialize<B: Buf>(mut buf: &mut B) -> Result<Self> {
378        if !buf.has_remaining() {
379            return Err(Error::ShortInput);
380        }
381
382        let msg_type = buf.get_u8().try_into()?;
383        let key_id = read_lengthed(&mut buf)?;
384        let encrypted_msg = read_lengthed(&mut buf)?;
385
386        Ok(Self {
387            msg_type,
388            key_id,
389            encrypted_msg,
390        })
391    }
392}
393
394impl Serialize for &ObliviousDoHMessage {
395    fn serialize<B: BufMut>(self, buf: &mut B) -> Result<()> {
396        buf.put_u8(self.msg_type as u8);
397        buf.put_u16(to_u16(self.key_id.len())?);
398        buf.put(self.key_id.clone());
399        buf.put_u16(to_u16(self.encrypted_msg.len())?);
400        buf.put(self.encrypted_msg.clone());
401        Ok(())
402    }
403}
404
/// Structure holding unencrypted dns message and padding.
///
/// The padding must consist of zero bytes only; both serialization and
/// deserialization reject anything else with `Error::InvalidPadding`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ObliviousDoHMessagePlaintext {
    // protocol: length prefix
    dns_msg: Bytes,
    // protocol: length prefix
    padding: Bytes,
}
413
414impl ObliviousDoHMessagePlaintext {
415    /// Create a new [`ObliviousDoHMessagePlaintext`] from DNS message
416    /// bytes and an optional padding.
417    ///
418    /// [`ObliviousDoHMessagePlaintext`]: struct.ObliviousDoHMessagePlaintext.html
419    pub fn new<M: AsRef<[u8]>>(msg: M, padding_len: usize) -> Self {
420        Self {
421            dns_msg: msg.as_ref().to_vec().into(),
422            padding: vec![0; padding_len].into(),
423        }
424    }
425
426    /// Consume the struct, return the inner DNS message bytes.
427    pub fn into_msg(self) -> Bytes {
428        self.dns_msg
429    }
430
431    /// Return the length of padding.
432    pub fn padding_len(&self) -> usize {
433        self.padding.len()
434    }
435}
436
437impl Deserialize for ObliviousDoHMessagePlaintext {
438    fn deserialize<B: Buf>(buf: &mut B) -> Result<Self> {
439        let dns_msg = read_lengthed(buf)?;
440        let padding = read_lengthed(buf)?;
441
442        if !padding.iter().all(|&x| x == 0x00) {
443            return Err(Error::InvalidPadding);
444        }
445
446        Ok(Self { dns_msg, padding })
447    }
448}
449
450impl Serialize for &ObliviousDoHMessagePlaintext {
451    fn serialize<B: BufMut>(self, buf: &mut B) -> Result<()> {
452        if !self.padding.iter().all(|&x| x == 0x00) {
453            return Err(Error::InvalidPadding);
454        }
455
456        buf.put_u16(to_u16(self.dns_msg.len())?);
457        buf.put(self.dns_msg.clone());
458
459        buf.put_u16(to_u16(self.padding.len())?);
460        buf.put(self.padding.clone());
461
462        Ok(())
463    }
464}
465
/// `ObliviousDoHKeyPair` supplies relevant encryption/decryption information
/// required by the target resolver to process DNS queries.
#[derive(Clone)]
pub struct ObliviousDoHKeyPair {
    // KEM private key used to decrypt incoming queries
    private_key: <Kem as KemTrait>::PrivateKey,
    // suite ids plus the serialized public key, as handed out to clients
    public_key: ObliviousDoHConfigContents,
}
473
474impl ObliviousDoHKeyPair {
475    /// Generate a new keypair from given RNG.
476    pub fn new<R: RngCore + CryptoRng>(mut rng: &mut R) -> Self {
477        let (private_key, public_key) = Kem::gen_keypair(&mut rng);
478
479        let contents = ObliviousDoHConfigContents {
480            kem_id: KEM_ID,
481            kdf_id: KDF_ID,
482            aead_id: AEAD_ID,
483            public_key: public_key.to_bytes().to_vec().into(),
484        };
485
486        Self {
487            private_key,
488            public_key: contents,
489        }
490    }
491
492    /// Create a key pair from provided parameters.
493    pub fn from_parameters(kem_id: u16, kdf_id: u16, aead_id: u16, ikm: &[u8]) -> Self {
494        // derive keypair from ikm
495        let (private_key, public_key) = Kem::derive_keypair(ikm);
496        Self {
497            private_key,
498            public_key: ObliviousDoHConfigContents {
499                kem_id,
500                kdf_id,
501                aead_id,
502                public_key: public_key.to_bytes().to_vec().into(),
503            },
504        }
505    }
506
507    /// Return a reference of the private key.
508    pub fn private(&self) -> &<Kem as KemTrait>::PrivateKey {
509        &self.private_key
510    }
511
512    /// Return a reference of the public key.
513    pub fn public(&self) -> &ObliviousDoHConfigContents {
514        &self.public_key
515    }
516}
517
518/// Encrypt a client DNS query with a proper config, return the
519/// encrypted query and client secret.
520pub fn encrypt_query<R: RngCore + CryptoRng>(
521    query: &ObliviousDoHMessagePlaintext,
522    config: &ObliviousDoHConfigContents,
523    rng: &mut R,
524) -> Result<(ObliviousDoHMessage, OdohSecret)> {
525    let server_pk = <Kem as KemTrait>::PublicKey::from_bytes(&config.public_key)?;
526    let (encapped_key, mut send_ctx) =
527        hpke::setup_sender::<Aead, Kdf, Kem, _>(&OpModeS::Base, &server_pk, LABEL_QUERY, rng)?;
528
529    let key_id = config.identifier()?;
530    let aad = build_aad(ObliviousDoHMessageType::Query, &key_id)?;
531
532    let mut odoh_secret = OdohSecret::default();
533    send_ctx.export(LABEL_RESPONSE, &mut odoh_secret)?;
534
535    let mut buf = compose(query)?;
536
537    let tag = send_ctx.seal_in_place_detached(&mut buf, &aad)?;
538
539    let result = [
540        encapped_key.to_bytes().as_slice(),
541        &buf,
542        tag.to_bytes().as_slice(),
543    ]
544    .concat();
545
546    let msg = ObliviousDoHMessage {
547        msg_type: ObliviousDoHMessageType::Query,
548        key_id: key_id.to_vec().into(),
549        encrypted_msg: result.into(),
550    };
551
552    Ok((msg, odoh_secret))
553}
554
555/// Decrypt a DNS response from the server.
556pub fn decrypt_response(
557    query: &ObliviousDoHMessagePlaintext,
558    response: &ObliviousDoHMessage,
559    secret: OdohSecret,
560) -> Result<ObliviousDoHMessagePlaintext> {
561    if response.msg_type != ObliviousDoHMessageType::Response {
562        return Err(Error::InvalidMessageType);
563    }
564
565    let response_nonce = response
566        .key_id
567        .as_ref()
568        .try_into()
569        .map_err(|_| Error::InvalidResponseNonceLength)?;
570    let (key, nonce) = derive_secrets(secret, query, response_nonce)?;
571    let cipher = Aes128Gcm::new(GenericArray::from_slice(&key));
572    let mut data = response.encrypted_msg.to_vec();
573
574    let aad = build_aad(ObliviousDoHMessageType::Response, &response.key_id)?;
575
576    cipher.decrypt_in_place(GenericArray::from_slice(&nonce), &aad, &mut data)?;
577
578    let response_decrypted = parse(&mut Bytes::from(data))?;
579    Ok(response_decrypted)
580}
581
582/// Decrypt a client query.
583pub fn decrypt_query(
584    query: &ObliviousDoHMessage,
585    key_pair: &ObliviousDoHKeyPair,
586) -> Result<(ObliviousDoHMessagePlaintext, OdohSecret)> {
587    if query.msg_type != ObliviousDoHMessageType::Query {
588        return Err(Error::InvalidMessageType);
589    }
590
591    let key_id = key_pair.public().identifier()?;
592    let key_id_recv = &query.key_id;
593
594    if !key_id_recv.eq(&key_id) {
595        return Err(Error::KeyIdMismatch);
596    }
597
598    let server_sk = key_pair.private();
599    let key_size = <Kem as KemTrait>::PublicKey::size();
600    if key_size > query.encrypted_msg.len() {
601        return Err(Error::InvalidInputLength);
602    }
603    let (enc, ct) = query.encrypted_msg.split_at(key_size);
604
605    let encapped_key = <Kem as KemTrait>::EncappedKey::from_bytes(enc)?;
606
607    let mut recv_ctx = hpke::setup_receiver::<Aead, Kdf, Kem>(
608        &OpModeR::Base,
609        server_sk,
610        &encapped_key,
611        LABEL_QUERY,
612    )?;
613
614    // Open the payload
615    let aad = build_aad(ObliviousDoHMessageType::Query, &key_id)?;
616    let plaintext = recv_ctx.open(ct, &aad)?;
617
618    let mut odoh_secret = OdohSecret::default();
619    recv_ctx.export(LABEL_RESPONSE, &mut odoh_secret)?;
620
621    let query_decrypted = parse(&mut Bytes::from(plaintext))?;
622    Ok((query_decrypted, odoh_secret))
623}
624
625/// Encrypt a server response.
626pub fn encrypt_response(
627    query: &ObliviousDoHMessagePlaintext,
628    response: &ObliviousDoHMessagePlaintext,
629    secret: OdohSecret,
630    response_nonce: ResponseNonce,
631) -> Result<ObliviousDoHMessage> {
632    let (key, nonce) = derive_secrets(secret, query, response_nonce)?;
633    let cipher = Aes128Gcm::new(GenericArray::from_slice(&key));
634    let aad = build_aad(ObliviousDoHMessageType::Response, &response_nonce)?;
635
636    let mut buf = Vec::new();
637    response.serialize(&mut buf)?;
638    cipher.encrypt_in_place(GenericArray::from_slice(&nonce), &aad, &mut buf)?;
639
640    Ok(ObliviousDoHMessage {
641        msg_type: ObliviousDoHMessageType::Response,
642        key_id: response_nonce.to_vec().into(),
643        encrypted_msg: buf.into(),
644    })
645}
646
647// TODO: try to use a static buffer for aad building
648fn build_aad(t: ObliviousDoHMessageType, key_id: &[u8]) -> Result<Vec<u8>> {
649    let mut aad = vec![t as u8];
650    aad.extend(&to_u16(key_id.len())?.to_be_bytes());
651    aad.extend(key_id);
652    Ok(aad)
653}
654
655/// Derives a key and nonce pair using the odoh secret and
656/// response_nonce.
657fn derive_secrets(
658    odoh_secret: OdohSecret,
659    query: &ObliviousDoHMessagePlaintext,
660    response_nonce: ResponseNonce,
661) -> Result<(AeadKey, AeadNonce)> {
662    let buf = compose(query)?;
663    let salt = [
664        buf.as_ref(),
665        &to_u16(response_nonce.len())?.to_be_bytes(),
666        &response_nonce,
667    ]
668    .concat();
669
670    let h_key = Hkdf::<<Kdf as KdfTrait>::HashImpl>::new(Some(&salt), &odoh_secret);
671    let mut key = AeadKey::default();
672    h_key
673        .expand(LABEL_KEY, &mut key)
674        .map_err(|_| Error::from(HpkeError::KdfOutputTooLong))?;
675
676    let h_nonce = Hkdf::<<Kdf as KdfTrait>::HashImpl>::new(Some(&salt), &odoh_secret);
677    let mut nonce = AeadNonce::default();
678    h_nonce
679        .expand(LABEL_NONCE, &mut nonce)
680        .map_err(|_| Error::from(HpkeError::KdfOutputTooLong))?;
681
682    Ok((key, nonce))
683}
684
685#[inline]
686fn to_u16(n: usize) -> Result<u16> {
687    n.try_into().map_err(|_| Error::InvalidInputLength)
688}
689
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    // Round-trips a serialized config list and checks version filtering.
    #[test]
    fn configs() {
        // parse
        let configs_hex = "002c000100280020000100010020bbd80565312cff62c44020a60c511711a6754425d5f42be1de3bca6b9bb3c50f";
        let mut configs_bin: Bytes = hex::decode(configs_hex).unwrap().into();
        let configs: ObliviousDoHConfigs = parse(&mut configs_bin).unwrap();
        assert_eq!(configs.configs.len(), 1);
        // check all bytes have been consumed
        assert!(configs_bin.is_empty());

        // compose
        let buf = compose(&configs).unwrap();
        assert_eq!(configs_hex, hex::encode(&buf));

        // check support: only configs with the supported version are kept
        let mut c1 = configs.configs[0].clone();
        let mut c2 = c1.clone();
        c1.version = 0xff;
        let supported = ObliviousDoHConfigs::from(vec![c1.clone(), c2.clone()]).supported();
        assert_eq!(supported[0], c2);

        c2.version = 0xff;
        let supported = ObliviousDoHConfigs::from(vec![c1, c2]).supported();
        assert!(supported.is_empty());
    }

    // Round-trips serialized config contents (suite ids and public key).
    #[test]
    fn pubkey() {
        // parse
        let key_hex =
            "0020000100010020aacc53b3df0c6eb2d7d5ce4ddf399593376c9903ba6a52a52c3a2340f97bb764";
        let mut key_bin: Bytes = hex::decode(key_hex).unwrap().into();
        let key: ObliviousDoHConfigContents = parse(&mut key_bin).unwrap();
        assert!(key_bin.is_empty());

        // compose
        let buf = compose(&key).unwrap();
        assert_eq!(key_hex, hex::encode(&buf));
    }

    // Full client/server exchange: encrypt query, decrypt it, encrypt the
    // response and decrypt it again.
    #[test]
    fn exchange() {
        // Use a seed to initialize a RNG. *Note* you should rely on some
        // random source.
        let mut rng = StdRng::from_seed([0; 32]);

        // Generate a key pair on server side.
        let key_pair = ObliviousDoHKeyPair::new(&mut rng);

        // Create client configs from the key pair. It can be distributed
        // to the clients.
        let public_key = key_pair.public().clone();
        let client_configs: ObliviousDoHConfigs = vec![ObliviousDoHConfig::from(public_key)].into();
        let mut client_configs_bytes = compose(&client_configs).unwrap().freeze();

        // ... distributing client_configs_bytes ...

        // Parse and extract first supported config from client configs on client side.
        let client_configs: ObliviousDoHConfigs = parse(&mut client_configs_bytes).unwrap();
        let config_contents = client_configs.supported()[0].clone().into();

        // This is an example client request. This library doesn't validate
        // DNS messages.
        let query = ObliviousDoHMessagePlaintext::new(b"What's the IP of one.one.one.one?", 0);

        // Encrypt the above request. The client_secret returned will be
        // used later to decrypt server's response.
        let (query_enc, cli_secret) = encrypt_query(&query, &config_contents, &mut rng).unwrap();

        // ... sending query_enc to the server ...

        // Server decrypts the request.
        let (query_dec, srv_secret) = decrypt_query(&query_enc, &key_pair).unwrap();
        assert_eq!(query, query_dec);

        // Server could now resolve the decrypted query, and compose a response.
        let response = ObliviousDoHMessagePlaintext::new(b"The IP is 1.1.1.1", 0);

        // Server encrypts the response.
        let nonce = ResponseNonce::default();
        let response_enc = encrypt_response(&query_dec, &response, srv_secret, nonce).unwrap();

        // ... sending response_enc back to the client ...

        // Client decrypts the response.
        let response_dec = decrypt_response(&query, &response_enc, cli_secret).unwrap();
        assert_eq!(response, response_dec);
    }

    // Checks key ids, query decryption and response encryption against the
    // vectors stored in tests/test-vectors.json.
    #[test]
    fn test_vector() {
        use super::*;
        use serde::Deserialize as SerdeDeserialize;

        const TEST_VECTORS: &str = std::include_str!("../tests/test-vectors.json");

        #[derive(SerdeDeserialize, Debug, Clone)]
        pub struct TestVector {
            pub aead_id: u16,
            pub kdf_id: u16,
            pub kem_id: u16,
            pub key_id: String,
            pub odohconfigs: String,
            pub public_key_seed: String,
            pub transactions: Vec<Transaction>,
        }

        #[derive(SerdeDeserialize, Debug, Clone)]
        #[serde(rename_all = "camelCase")]
        pub struct Transaction {
            pub oblivious_query: String,
            pub oblivious_response: String,
            pub query: String,
            pub response: String,
            pub query_padding_length: usize,
            pub response_padding_length: usize,
        }

        let test_vectors: Vec<TestVector> = serde_json::from_str(TEST_VECTORS).unwrap();
        for tv in test_vectors {
            // Re-derive the server's private key from the seed in the vector.
            let ikm_bytes = hex::decode(tv.public_key_seed).unwrap();
            let (secret_key, _) = Kem::derive_keypair(&ikm_bytes);

            let mut configs_bytes: Bytes = hex::decode(tv.odohconfigs).unwrap().into();
            let configs: ObliviousDoHConfigs = parse(&mut configs_bytes).unwrap();
            let odoh_public_key: ObliviousDoHConfigContents =
                configs.supported().into_iter().next().unwrap().into();

            assert_eq!(
                odoh_public_key.identifier().unwrap(),
                hex::decode(tv.key_id).unwrap(),
            );

            let key_pair = ObliviousDoHKeyPair {
                private_key: secret_key,
                public_key: odoh_public_key,
            };

            for t in tv.transactions {
                let query = ObliviousDoHMessagePlaintext::new(
                    &hex::decode(t.query).unwrap(),
                    t.query_padding_length,
                );

                let mut odoh_query_bytes: Bytes = hex::decode(t.oblivious_query).unwrap().into();
                let odoh_query = parse(&mut odoh_query_bytes).unwrap();

                // decrypt oblivious_query from test should match its query
                let (odoh_query_dec, srv_secret) = decrypt_query(&odoh_query, &key_pair).unwrap();
                assert_eq!(odoh_query_dec, query);

                let odoh_response_bytes: Bytes = hex::decode(t.oblivious_response).unwrap().into();
                let odoh_response: ObliviousDoHMessage =
                    parse(&mut odoh_response_bytes.clone()).unwrap();

                let response = ObliviousDoHMessagePlaintext::new(
                    &hex::decode(t.response).unwrap(),
                    t.response_padding_length,
                );

                // assert with fixed response nonce to make sure the
                // right hpke version is being used
                let response_enc = encrypt_response(
                    &query,
                    &response,
                    srv_secret,
                    odoh_response.key_id[..16].try_into().unwrap(),
                )
                .unwrap();

                // encrypted response is the same as the one parsed from test
                let response_enc_bytes = compose(&response_enc).unwrap();
                assert_eq!(response_enc_bytes.as_ref(), odoh_response_bytes.as_ref(),);
            }
        }
    }

    // Non-zero padding must be rejected both when parsing and when composing.
    #[test]
    fn padding() {
        let query = ObliviousDoHMessagePlaintext::new(&[], 0);
        assert_eq!(query.padding_len(), 0);

        let query = ObliviousDoHMessagePlaintext::new(&[], 2);
        assert_eq!(query.padding_len(), 2);

        // Corrupt the last padding byte of the serialized form.
        let mut query_bytes = compose(&query).unwrap();
        let last = query_bytes.len() - 1;
        query_bytes[last] = 0x01;
        assert_eq!(
            Error::InvalidPadding,
            parse::<ObliviousDoHMessagePlaintext, _>(&mut query_bytes.freeze()).unwrap_err()
        );

        // Put non-zero padding directly into the struct.
        let mut query = query;
        query.padding = vec![1, 2].into();
        assert_eq!(Error::InvalidPadding, compose(&query).unwrap_err());
    }

    // A query whose encrypted message is shorter than the encapsulated key
    // must be rejected instead of panicking.
    #[test]
    fn parse_encapsulated_key() {
        // Use a seed to initialize a RNG. *Note* you should rely on some
        // random source.
        let mut rng = StdRng::from_seed([0; 32]);
        let key_pair = ObliviousDoHKeyPair::new(&mut rng);

        // Construct a malformed payload. Parsing the encrypted message should fail because it is
        // too short to include the encapsulated key.
        let query_enc = ObliviousDoHMessage {
            msg_type: ObliviousDoHMessageType::Query,
            key_id: key_pair.public().identifier().unwrap().to_vec().into(),
            encrypted_msg: b"too short".to_vec().into(),
        };
        assert!(decrypt_query(&query_enc, &key_pair).is_err());
    }
}