use std::fmt::Debug;
use prost::Message as ProstMessage;
use serde::{Deserialize, Serialize};
use super::Message;
use crate::{
Curve25519PublicKey, DecodeError,
olm::SessionKeys,
utilities::{base64_decode, base64_encode},
};
/// A pre-key message: an inner [`Message`] bundled together with the
/// [`SessionKeys`] (one-time key, base key and identity key) that accompany it.
///
/// On the wire it is a single version byte followed by a protobuf-encoded body
/// (see `to_bytes` / `TryFrom<&[u8]>`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PreKeyMessage {
    /// The one-time, base and identity public keys carried by this message.
    pub(crate) session_keys: SessionKeys,
    /// The inner message, serialized into the protobuf `message` field.
    pub(crate) message: Message,
}
impl PreKeyMessage {
    /// Version byte written before the protobuf payload and required when decoding.
    const VERSION: u8 = 3;

    /// The one-time key stored in this message's [`SessionKeys`].
    pub const fn one_time_key(&self) -> Curve25519PublicKey {
        self.session_keys.one_time_key
    }

    /// The base key stored in this message's [`SessionKeys`].
    pub const fn base_key(&self) -> Curve25519PublicKey {
        self.session_keys.base_key
    }

    /// The identity key stored in this message's [`SessionKeys`].
    pub const fn identity_key(&self) -> Curve25519PublicKey {
        self.session_keys.identity_key
    }

    /// A copy of all three keys carried by this message.
    pub const fn session_keys(&self) -> SessionKeys {
        self.session_keys
    }

    /// The session ID computed by [`SessionKeys::session_id`] over this message's keys.
    pub fn session_id(&self) -> String {
        self.session_keys.session_id()
    }

    /// Borrows the inner [`Message`].
    pub const fn message(&self) -> &Message {
        &self.message
    }

    /// Parses a pre-key message from its binary encoding.
    ///
    /// # Errors
    ///
    /// Fails with a [`DecodeError`] when the version byte is absent or is not
    /// the expected version, when the protobuf payload cannot be decoded, or
    /// when any of the keys or the inner message cannot be parsed.
    pub fn from_bytes(message: &[u8]) -> Result<Self, DecodeError> {
        Self::try_from(message)
    }

    /// Encodes this message as the version byte followed by the protobuf body.
    pub fn to_bytes(&self) -> Vec<u8> {
        let keys = &self.session_keys;
        let proto = ProtoBufPreKeyMessage {
            one_time_key: keys.one_time_key.as_bytes().to_vec(),
            base_key: keys.base_key.as_bytes().to_vec(),
            identity_key: keys.identity_key.as_bytes().to_vec(),
            message: self.message.to_bytes(),
        };

        let mut bytes: Vec<u8> = Vec::with_capacity(proto.encoded_len() + 1);
        bytes.push(Self::VERSION);

        // A growable `Vec<u8>` always has room for the encoded body, so this
        // encode call is not expected to fail.
        #[allow(clippy::expect_used)]
        proto
            .encode(&mut bytes)
            .expect("We should be able to encode a pre-key message into protobuf.");

        bytes
    }

    /// Parses a pre-key message from base64 text.
    ///
    /// # Errors
    ///
    /// Fails with a [`DecodeError`] if the text is not valid base64 or if the
    /// decoded bytes are rejected by [`PreKeyMessage::from_bytes`].
    pub fn from_base64(message: &str) -> Result<Self, DecodeError> {
        Self::try_from(message)
    }

    /// Base64-encodes the output of [`PreKeyMessage::to_bytes`].
    pub fn to_base64(&self) -> String {
        base64_encode(self.to_bytes())
    }

    /// Builds a pre-key message from its parts; only exposed with the
    /// `low-level-api` feature.
    #[cfg(feature = "low-level-api")]
    pub const fn wrap(session_keys: SessionKeys, message: Message) -> Self {
        Self::new(session_keys, message)
    }

    /// Crate-internal constructor that stores the given keys and message as-is.
    pub(crate) const fn new(session_keys: SessionKeys, message: Message) -> Self {
        Self { session_keys, message }
    }
}
impl Serialize for PreKeyMessage {
    /// Serializes the message as the string returned by `to_base64`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_str(self.to_base64().as_str())
    }
}
impl<'de> Deserialize<'de> for PreKeyMessage {
fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let ciphertext = String::deserialize(d)?;
PreKeyMessage::from_base64(&ciphertext).map_err(serde::de::Error::custom)
}
}
impl TryFrom<&str> for PreKeyMessage {
    type Error = DecodeError;

    /// Base64-decodes `value`, then parses the resulting bytes.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let bytes = base64_decode(value)?;
        Self::try_from(bytes.as_slice())
    }
}
impl TryFrom<Vec<u8>> for PreKeyMessage {
    type Error = DecodeError;

    /// Parses an owned byte buffer by delegating to the slice conversion.
    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
        Self::try_from(&value[..])
    }
}
impl TryFrom<&[u8]> for PreKeyMessage {
type Error = DecodeError;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
let version = *value.first().ok_or(DecodeError::MissingVersion)?;
if version != Self::VERSION {
Err(DecodeError::InvalidVersion(Self::VERSION, version))
} else {
let decoded = ProtoBufPreKeyMessage::decode(&value[1..value.len()])?;
let one_time_key = Curve25519PublicKey::from_slice(&decoded.one_time_key)?;
let base_key = Curve25519PublicKey::from_slice(&decoded.base_key)?;
let identity_key = Curve25519PublicKey::from_slice(&decoded.identity_key)?;
let message = decoded.message.try_into()?;
let session_keys = SessionKeys { one_time_key, identity_key, base_key };
Ok(Self { session_keys, message })
}
}
}
/// Protobuf body of an encoded [`PreKeyMessage`]; on the wire it follows a
/// single leading version byte.
#[derive(Clone, ProstMessage)]
struct ProtoBufPreKeyMessage {
    /// Raw bytes of the one-time public key (field 1).
    #[prost(bytes, tag = "1")]
    one_time_key: Vec<u8>,
    /// Raw bytes of the base public key (field 2).
    #[prost(bytes, tag = "2")]
    base_key: Vec<u8>,
    /// Raw bytes of the identity public key (field 3).
    #[prost(bytes, tag = "3")]
    identity_key: Vec<u8>,
    /// The inner message as produced by `Message::to_bytes` (field 4).
    #[prost(bytes, tag = "4")]
    message: Vec<u8>,
}