use std::io::Cursor;
use rand::CryptoRng;
use tls_codec::{Deserialize, Serialize, Size, VLByteSlice};
use super::InitiatorOuterPayloadOut;
use crate::{
aead::{AEADKeyNonce, AeadType},
handshake::{
ciphersuite::CiphersuiteName,
derive_k0,
dhkem::{DHKeyPair, DHPrivateKey, DHPublicKey, DHSharedSecret},
responder::ResponderQueryPayload,
transcript::{tx2, Transcript},
write_output, HandshakeError as Error, HandshakeMessage, HandshakeMessageOut, K2IkmQuery,
},
traits::Channel,
};
/// Initiator-side state for a query handshake.
///
/// An instance is built by `QueryInitiator::new`, which generates a fresh
/// ephemeral DH key pair and derives the initial transcript and key via
/// `derive_k0`. The same state is then used to encrypt the outgoing query and
/// to decrypt the responder's reply.
pub struct QueryInitiator<'a> {
/// Borrowed responder long-term ECDH public key; fed into `derive_k0` at
/// construction and into the `k2` derivation when reading the response.
responder_longterm_ecdh_pk: &'a DHPublicKey,
/// Ephemeral DH key pair generated in `new`. The public half is placed in
/// outgoing messages; the private half is used when deriving `k2`.
initiator_ephemeral_keys: DHKeyPair,
/// Transcript returned by `derive_k0`; extended with the responder's public
/// key (via `tx2`) when processing the response.
tx0: Transcript,
/// AEAD key returned by `derive_k0`; encrypts the outgoing query payload and
/// is part of the input key material for `k2`.
k0: AEADKeyNonce,
/// Borrowed associated data: passed to the encryption of the outgoing
/// payload and also copied into the outgoing message's `aad` field.
outer_aad: &'a [u8],
}
/// Derives the AEAD key `k2` on the initiator side of a query handshake.
///
/// The input key material bundles `k0` with two Diffie-Hellman shared secrets,
/// both computed from the initiator's ephemeral private key:
///
/// * `g_xs` — against the responder's long-term public key, and
/// * `g_xy` — against the responder's ephemeral public key.
///
/// The key is then created from that material together with the transcript
/// `tx2`, for ChaCha20-Poly1305.
///
/// # Errors
///
/// Fails if either shared-secret derivation fails, or if constructing the key
/// fails (that error is converted into the handshake error type).
fn derive_k2_query_initiator(
    k0: &AEADKeyNonce,
    responder_ephemeral_ecdh_pk: &DHPublicKey,
    initiator_ephemeral_ecdh_sk: &DHPrivateKey,
    responder_longterm_ecdh_pk: &DHPublicKey,
    tx2: &Transcript,
) -> Result<AEADKeyNonce, Error> {
    // Keep this order (long-term first, ephemeral second) so that the first
    // failing derivation is the one whose error is reported.
    let longterm_secret =
        DHSharedSecret::derive(initiator_ephemeral_ecdh_sk, responder_longterm_ecdh_pk)?;
    let ephemeral_secret =
        DHSharedSecret::derive(initiator_ephemeral_ecdh_sk, responder_ephemeral_ecdh_pk)?;

    let key_material = K2IkmQuery {
        k0,
        g_xs: &longterm_secret,
        g_xy: &ephemeral_secret,
    };

    AEADKeyNonce::new(&key_material, tx2, AeadType::ChaCha20Poly1305).map_err(Into::into)
}
impl<'a> QueryInitiator<'a> {
    /// Builds the initiator state for a query handshake.
    ///
    /// A fresh ephemeral DH key pair is drawn from `rng`, and the initial
    /// transcript and key (`tx0`, `k0`) are obtained from `derive_k0` using the
    /// responder's long-term public key, the new ephemeral key pair, `ctx` and
    /// ChaCha20-Poly1305.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `derive_k0`.
    pub(crate) fn new(
        responder_longterm_ecdh_pk: &'a DHPublicKey,
        ctx: &[u8],
        outer_aad: &'a [u8],
        mut rng: impl CryptoRng,
    ) -> Result<Self, Error> {
        let ephemeral = DHKeyPair::new(&mut rng);

        // NOTE(review): the `false` flag is forwarded verbatim; it presumably
        // selects the initiator role — confirm against `derive_k0`.
        let (tx0, k0) = derive_k0(
            responder_longterm_ecdh_pk,
            &ephemeral.pk,
            &ephemeral.sk,
            ctx,
            false,
            AeadType::ChaCha20Poly1305,
        )?;

        Ok(Self {
            responder_longterm_ecdh_pk,
            initiator_ephemeral_keys: ephemeral,
            tx0,
            k0,
            outer_aad,
        })
    }

    /// Decrypts the responder's handshake message into its query payload.
    ///
    /// Extends the stored transcript with the responder's public key, derives
    /// `k2` from it, and uses `k2` to decrypt the message's ciphertext, checked
    /// against the message's tag and associated data.
    ///
    /// # Errors
    ///
    /// Fails if the transcript update, the key derivation or the decryption
    /// fails.
    fn read_response(
        &self,
        responder_msg: &HandshakeMessage,
    ) -> Result<ResponderQueryPayload, Error> {
        let transcript = tx2(&self.tx0, &responder_msg.pk)?;

        let mut k2 = derive_k2_query_initiator(
            &self.k0,
            &responder_msg.pk,
            &self.initiator_ephemeral_keys.sk,
            self.responder_longterm_ecdh_pk,
            &transcript,
        )?;

        let decrypted = k2.handshake_decrypt::<ResponderQueryPayload>(
            responder_msg.ciphertext.as_slice(),
            &responder_msg.tag,
            responder_msg.aad.as_slice(),
        );
        decrypted.map_err(Into::into)
    }

    /// Wraps `payload` as a query outer payload and encrypts it under `k0`,
    /// using the stored outer associated data.
    ///
    /// Returns the ciphertext together with its 16-byte tag.
    ///
    /// # Errors
    ///
    /// Propagates any encryption error.
    fn prepare_message_contents(&mut self, payload: &[u8]) -> Result<(Vec<u8>, [u8; 16]), Error> {
        let query = InitiatorOuterPayloadOut::Query(VLByteSlice(payload));
        let sealed = self.k0.handshake_encrypt(&query, self.outer_aad)?;
        Ok(sealed)
    }

    /// Decrypts `message` and copies the recovered payload bytes into `out`.
    ///
    /// Returns the value reported by `write_output` (stored by callers as the
    /// number of bytes written).
    ///
    /// # Errors
    ///
    /// Fails if the response cannot be read or if `write_output` fails.
    fn process_message(
        &mut self,
        message: &HandshakeMessage,
        out: &mut [u8],
    ) -> Result<usize, Error> {
        let response = self.read_response(message)?;
        let written = write_output(response.0.as_slice(), out)?;
        Ok(written)
    }
}
impl<'a> Channel<Error, HandshakeMessage> for QueryInitiator<'a> {
fn write_message(&mut self, payload: &[u8], out: &mut [u8]) -> Result<usize, Error> {
let (ciphertext, tag) = self.prepare_message_contents(payload)?;
let msg = HandshakeMessageOut {
pk: &self.initiator_ephemeral_keys.pk,
ciphertext: VLByteSlice(&ciphertext),
tag,
aad: VLByteSlice(self.outer_aad),
ciphersuite: CiphersuiteName::query_ciphersuite(),
};
msg.tls_serialize(&mut &mut out[..])
.map_err(Error::Serialize)
}
fn read_message(
&mut self,
message_bytes: &[u8],
out: &mut [u8],
) -> Result<(usize, usize), Error> {
let msg = HandshakeMessage::tls_deserialize(&mut Cursor::new(message_bytes))
.map_err(Error::Deserialize)?;
let out_bytes_written = self.process_message(&msg, out)?;
Ok((msg.tls_serialized_len(), out_bytes_written))
}
fn write_message_external_encoding(
&mut self,
payload: &[u8],
) -> Result<HandshakeMessage, Error> {
let (ciphertext, tag) = self.prepare_message_contents(payload)?;
Ok(HandshakeMessage {
pk: self.initiator_ephemeral_keys.pk.clone(),
ciphertext,
tag,
aad: self.outer_aad.to_vec(),
ciphersuite: CiphersuiteName::query_ciphersuite(),
})
}
fn read_message_external_encoding(
&mut self,
message: &HandshakeMessage,
) -> Result<Vec<u8>, Error> {
let mut out = vec![0u8; message.ciphertext.len()];
let out_bytes_written = self.process_message(&message, &mut out)?;
out.truncate(out_bytes_written);
Ok(out)
}
}