use bitflags::bitflags;
use zeroize::Zeroize;
use crate::classic::crypto_secretstream_xchacha20poly1305::{
State, crypto_secretstream_xchacha20poly1305_init_pull,
crypto_secretstream_xchacha20poly1305_init_push, crypto_secretstream_xchacha20poly1305_pull,
crypto_secretstream_xchacha20poly1305_push, crypto_secretstream_xchacha20poly1305_rekey,
};
use crate::constants::{
CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_HEADERBYTES,
CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_KEYBYTES,
CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_TAG_MESSAGE,
CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_TAG_PUSH,
CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_TAG_REKEY, CRYPTO_STREAM_CHACHA20_IETF_NONCEBYTES,
};
use crate::error::Error;
pub use crate::types::*;
/// Marker trait implemented by the type-level stream directions, [`Push`] and
/// [`Pull`]. It declares no methods.
pub trait Mode {}

/// Zero-sized marker selecting the push-side methods of `DryocStream`
/// (`DryocStream<Push>`).
///
/// The derives matter: `DryocStream<Mode>` derives `Clone`, `PartialEq` and
/// `Eq`, and derived impls require the type parameter to implement the same
/// traits. Without them here, those impls could never apply to a push stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Push;

/// Zero-sized marker selecting the pull-side methods of `DryocStream`
/// (`DryocStream<Pull>`).
///
/// Derives the same traits as [`Push`], for the same reason.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Pull;

impl Mode for Push {}
impl Mode for Pull {}
/// Secret-stream key: a `StackByteArray` of
/// `CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_KEYBYTES` bytes. Usable as the `key`
/// argument of `DryocStream::init_push` and `DryocStream::init_pull`.
pub type Key = StackByteArray<CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_KEYBYTES>;
/// Nonce: a `StackByteArray` of `CRYPTO_STREAM_CHACHA20_IETF_NONCEBYTES` bytes.
// NOTE(review): nothing in this file consumes `Nonce` directly; presumably it
// is exported for callers that work with the stream state — confirm.
pub type Nonce = StackByteArray<CRYPTO_STREAM_CHACHA20_IETF_NONCEBYTES>;
/// Stream header: a `StackByteArray` of
/// `CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_HEADERBYTES` bytes. Matches the size
/// bound of the header produced by `DryocStream::init_push` and consumed by
/// `DryocStream::init_pull`.
pub type Header = StackByteArray<CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_HEADERBYTES>;
/// Variants of the stream's `Key`, `Nonce` and `Header` aliases built on
/// `HeapByteArray` instead of `StackByteArray`.
///
/// Compiled only with the `nightly` feature enabled, or when building
/// documentation (but not doctests).
#[cfg(any(feature = "nightly", all(doc, not(doctest))))]
#[cfg_attr(all(feature = "nightly", doc), doc(cfg(feature = "nightly")))]
pub mod protected {
use super::*;
// Re-export the crate-wide protected-memory items so callers need one import.
pub use crate::protected::*;
/// `HeapByteArray` key of `CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_KEYBYTES` bytes.
pub type Key = HeapByteArray<CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_KEYBYTES>;
/// `HeapByteArray` nonce of `CRYPTO_STREAM_CHACHA20_IETF_NONCEBYTES` bytes.
pub type Nonce = HeapByteArray<CRYPTO_STREAM_CHACHA20_IETF_NONCEBYTES>;
/// `HeapByteArray` header of `CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_HEADERBYTES` bytes.
pub type Header = HeapByteArray<CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_HEADERBYTES>;
}
bitflags! {
/// Message tag attached to each pushed message and returned by each pull,
/// stored as a `u8` bit set.
///
/// The flag values come from the crate's
/// `CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_TAG_*` constants.
// NOTE(review): the meaning of each tag presumably follows libsodium's
// secretstream tag semantics — confirm against the libsodium documentation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Tag: u8 {
/// Value of `CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_TAG_MESSAGE`.
const MESSAGE = CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_TAG_MESSAGE;
/// Value of `CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_TAG_PUSH`.
const PUSH = CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_TAG_PUSH;
/// Value of `CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_TAG_REKEY`.
const REKEY = CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_TAG_REKEY;
/// Both the `PUSH` and `REKEY` bits set.
const FINAL = Self::PUSH.bits() | Self::REKEY.bits();
}
}
impl From<u8> for Tag {
    /// Converts a raw tag byte into a [`Tag`].
    ///
    /// # Panics
    ///
    /// Panics with "Unable to parse tag" when `Tag::from_bits` rejects
    /// `other`, i.e. when the byte is not a valid combination of the declared
    /// flags.
    fn from(other: u8) -> Self {
        let parsed: Option<Self> = Self::from_bits(other);
        parsed.expect("Unable to parse tag")
    }
}
/// Secret stream whose direction is fixed at the type level.
///
/// `DryocStream<Push>` exposes `init_push`/`push`, `DryocStream<Pull>` exposes
/// `init_pull`/`pull`, and `rekey` is available for any direction. The inner
/// state is zeroized when the stream is dropped.
#[derive(Clone, PartialEq, Eq, Zeroize)]
pub struct DryocStream<M> {
    /// Secretstream state handed to every init, push, pull and rekey call.
    state: State,
    /// Zero-sized marker recording the stream direction; holds no data.
    phantom: core::marker::PhantomData<M>,
}
impl<Mode> Drop for DryocStream<Mode> {
fn drop(&mut self) {
self.state.zeroize()
}
}
impl<M> DryocStream<M> {
    /// Rekeys the stream by running
    /// `crypto_secretstream_xchacha20poly1305_rekey` on the internal state.
    ///
    /// Available for every stream direction.
    // NOTE(review): presumably both ends of a stream must rekey at the same
    // point for later messages to decrypt — confirm against the primitive's
    // documentation.
    pub fn rekey(&mut self) {
        let state = &mut self.state;
        crypto_secretstream_xchacha20poly1305_rekey(state)
    }
}
impl DryocStream<Push> {
    /// Initializes a push stream from `key`.
    ///
    /// Returns the new stream together with the header written by
    /// `crypto_secretstream_xchacha20poly1305_init_push`. A header of the same
    /// size is what [`DryocStream::init_pull`] takes as input.
    pub fn init_push<Key, Header>(key: &Key) -> (Self, Header)
    where
        Key: ByteArray<CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_KEYBYTES>,
        Header: NewByteArray<CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_HEADERBYTES>,
    {
        let mut stream_state = State::new();
        let mut out_header = Header::new_byte_array();
        crypto_secretstream_xchacha20poly1305_init_push(
            &mut stream_state,
            out_header.as_mut_array(),
            key.as_array(),
        );

        let stream = Self {
            state: stream_state,
            phantom: std::marker::PhantomData,
        };
        (stream, out_header)
    }

    /// Pushes `message` into the stream and returns the resulting ciphertext.
    ///
    /// The output buffer is sized to the message length plus
    /// `CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_ABYTES`. `associated_data`, when
    /// present, is forwarded to the underlying primitive, and `tag` is passed
    /// as its raw bits.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by
    /// `crypto_secretstream_xchacha20poly1305_push`.
    pub fn push<Input: Bytes, Output: NewBytes + ResizableBytes>(
        &mut self,
        message: &Input,
        associated_data: Option<&Input>,
        tag: Tag,
    ) -> Result<Output, Error> {
        use crate::constants::CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_ABYTES;

        let plaintext = message.as_slice();
        let aad = associated_data.map(|data| data.as_slice());
        let sealed_len = plaintext.len() + CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_ABYTES;

        let mut sealed = Output::new_bytes();
        sealed.resize(sealed_len, 0);
        crypto_secretstream_xchacha20poly1305_push(
            &mut self.state,
            sealed.as_mut_slice(),
            plaintext,
            aad,
            tag.bits(),
        )?;
        Ok(sealed)
    }

    /// Same as [`DryocStream::push`], with the output type fixed to `Vec<u8>`.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`DryocStream::push`] returns.
    pub fn push_to_vec<Input: Bytes>(
        &mut self,
        message: &Input,
        associated_data: Option<&Input>,
        tag: Tag,
    ) -> Result<Vec<u8>, Error> {
        self.push::<Input, Vec<u8>>(message, associated_data, tag)
    }
}
impl DryocStream<Pull> {
    /// Initializes a pull stream from `key` and the `header` that was produced
    /// when the matching push stream was initialized.
    pub fn init_pull<
        Key: ByteArray<CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_KEYBYTES>,
        Header: ByteArray<CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_HEADERBYTES>,
    >(
        key: &Key,
        header: &Header,
    ) -> Self {
        let mut state = State::new();
        crypto_secretstream_xchacha20poly1305_init_pull(
            &mut state,
            header.as_array(),
            key.as_array(),
        );
        Self {
            state,
            phantom: std::marker::PhantomData,
        }
    }

    /// Pulls one message out of the stream, returning the decrypted bytes and
    /// the tag that accompanied them.
    ///
    /// The output buffer is sized to the ciphertext length minus
    /// `CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_ABYTES`. `associated_data`, when
    /// present, is forwarded to the underlying primitive.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by
    /// `crypto_secretstream_xchacha20poly1305_pull`.
    ///
    /// # Panics
    ///
    /// Panics with "invalid tag" if the tag byte produced by the primitive is
    /// rejected by `Tag::from_bits`.
    pub fn pull<Input: Bytes, Output: MutBytes + Default + ResizableBytes>(
        &mut self,
        ciphertext: &Input,
        associated_data: Option<&Input>,
    ) -> Result<(Output, Tag), Error> {
        use crate::constants::CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_ABYTES;
        let mut message = Output::default();
        // A ciphertext shorter than ABYTES must not underflow the length
        // computation: a plain subtraction panics in debug builds and wraps to
        // an enormous allocation request in release builds. Saturate to zero
        // and hand the short input to the primitive instead.
        // NOTE(review): the primitive is expected to reject too-short
        // ciphertexts with an `Err` — confirm.
        message.resize(
            ciphertext
                .as_slice()
                .len()
                .saturating_sub(CRYPTO_SECRETSTREAM_XCHACHA20POLY1305_ABYTES),
            0,
        );
        let mut tag = 0u8;
        crypto_secretstream_xchacha20poly1305_pull(
            &mut self.state,
            message.as_mut_slice(),
            &mut tag,
            ciphertext.as_slice(),
            associated_data.map(|aad| aad.as_slice()),
        )?;
        Ok((message, Tag::from_bits(tag).expect("invalid tag")))
    }

    /// Same as [`DryocStream::pull`], with the output type fixed to `Vec<u8>`.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`DryocStream::pull`] returns.
    pub fn pull_to_vec<Input: Bytes>(
        &mut self,
        ciphertext: &Input,
        associated_data: Option<&Input>,
    ) -> Result<(Vec<u8>, Tag), Error> {
        self.pull(ciphertext, associated_data)
    }
}
// Interoperability and round-trip tests for `DryocStream`.
#[cfg(test)]
mod tests {
use super::*;
// Encrypts three messages with `DryocStream` and decrypts them with
// sodiumoxide's secretstream, checking plaintexts and tags.
#[test]
fn test_stream_push() {
use sodiumoxide::crypto::secretstream::{
Header as SOHeader, Key as SOKey, Stream as SOStream, Tag as SOTag,
};
let message1 = b"Arbitrary data to encrypt";
let message2 = b"split into";
let message3 = b"three messages";
let key = Key::gen();
// The header type cannot be inferred from the call, so it is annotated.
let (mut push_stream, header): (_, Header) = DryocStream::init_push(&key);
let c1: Vec<u8> = push_stream
.push(message1, None, Tag::MESSAGE)
.expect("Encrypt failed");
let c2: Vec<u8> = push_stream
.push(message2, None, Tag::MESSAGE)
.expect("Encrypt failed");
// The last message is tagged FINAL.
let c3: Vec<u8> = push_stream
.push(message3, None, Tag::FINAL)
.expect("Encrypt failed");
// Feed the same header and key bytes to sodiumoxide's pull stream.
let mut so_stream_pull = SOStream::init_pull(
&SOHeader::from_slice(header.as_slice()).expect("header failed"),
&SOKey::from_slice(key.as_slice()).expect("key failed"),
)
.expect("pull init failed");
let (m1, tag1) = so_stream_pull.pull(&c1, None).expect("decrypt failed");
let (m2, tag2) = so_stream_pull.pull(&c2, None).expect("decrypt failed");
let (m3, tag3) = so_stream_pull.pull(&c3, None).expect("decrypt failed");
assert_eq!(message1, m1.as_slice());
assert_eq!(message2, m2.as_slice());
assert_eq!(message3, m3.as_slice());
assert_eq!(tag1, SOTag::Message);
assert_eq!(tag2, SOTag::Message);
assert_eq!(tag3, SOTag::Final);
}
// The reverse direction: encrypts with sodiumoxide and decrypts with
// `DryocStream`, checking plaintexts and tags.
#[test]
fn test_stream_pull() {
use std::convert::TryFrom;
use sodiumoxide::crypto::secretstream::{Key as SOKey, Stream as SOStream, Tag as SOTag};
let message1 = b"Arbitrary data to encrypt";
let message2 = b"split into";
let message3 = b"three messages";
let key = Key::gen();
let (mut so_push_stream, so_header) =
SOStream::init_push(&SOKey::from_slice(key.as_slice()).expect("key failed"))
.expect("init push failed");
let c1: Vec<u8> = so_push_stream
.push(message1, None, SOTag::Message)
.expect("Encrypt failed");
let c2: Vec<u8> = so_push_stream
.push(message2, None, SOTag::Message)
.expect("Encrypt failed");
let c3: Vec<u8> = so_push_stream
.push(message3, None, SOTag::Final)
.expect("Encrypt failed");
// Convert sodiumoxide's header bytes into this crate's `Header` type.
let mut pull_stream =
DryocStream::init_pull(&key, &Header::try_from(so_header.as_ref()).expect("header"));
let (m1, tag1): (Vec<u8>, Tag) = pull_stream.pull(&c1, None).expect("Decrypt failed");
let (m2, tag2): (Vec<u8>, Tag) = pull_stream.pull(&c2, None).expect("Decrypt failed");
let (m3, tag3): (Vec<u8>, Tag) = pull_stream.pull(&c3, None).expect("Decrypt failed");
assert_eq!(message1, m1.as_slice());
assert_eq!(message2, m2.as_slice());
assert_eq!(message3, m3.as_slice());
assert_eq!(tag1, Tag::MESSAGE);
assert_eq!(tag2, Tag::MESSAGE);
assert_eq!(tag3, Tag::FINAL);
}
// Round-trips three messages through push and pull streams using the
// protected-memory key and `Locked<HeapBytes>` buffers. Only built with the
// `nightly` feature.
#[cfg(feature = "nightly")]
#[test]
fn test_protected_memory() {
use crate::protected::*;
let message1 = b"Arbitrary data to encrypt";
let message2 = b"split into";
let message3 = b"three messages";
let key = protected::Key::gen_locked().expect("gen locked");
let (mut push_stream, header): (_, Header) = DryocStream::init_push(&key);
// After stream init the key is unlocked and set to no-access; the pushes
// below run without touching it.
let key = key
.munlock()
.expect("munlock")
.mprotect_noaccess()
.expect("mprotect");
// Mixes output buffer types: the first ciphertext lands in a
// `Locked<HeapBytes>`, the other two in plain `Vec<u8>`.
let c1: Locked<HeapBytes> = push_stream
.push(message1, None, Tag::MESSAGE)
.expect("Encrypt failed");
let c2: Vec<u8> = push_stream
.push(message2, None, Tag::MESSAGE)
.expect("Encrypt failed");
let c3: Vec<u8> = push_stream
.push(message3, None, Tag::FINAL)
.expect("Encrypt failed");
// The key is switched to read-only just for `init_pull`, then set back to
// no-access.
let key = key.mprotect_readonly().expect("mprotect");
let mut pull_stream = DryocStream::init_pull(&key, &header);
let _key = key.mprotect_noaccess().expect("mprotect");
let (m1, tag1): (Locked<HeapBytes>, Tag) =
pull_stream.pull(&c1, None).expect("Decrypt failed");
let (m2, tag2): (Locked<HeapBytes>, Tag) =
pull_stream.pull(&c2, None).expect("Decrypt failed");
let (m3, tag3): (Locked<HeapBytes>, Tag) =
pull_stream.pull(&c3, None).expect("Decrypt failed");
assert_eq!(message1, m1.as_slice());
assert_eq!(message2, m2.as_slice());
assert_eq!(message3, m3.as_slice());
assert_eq!(tag1, Tag::MESSAGE);
assert_eq!(tag2, Tag::MESSAGE);
assert_eq!(tag3, Tag::FINAL);
}
}