1use super::secret::Secret;
5use aead::{generic_array::typenum::Unsigned, KeySizeUser, Payload};
6use chacha20poly1305::{
7 aead::{Aead, AeadCore, KeyInit, OsRng},
8 Key, XChaCha20Poly1305,
9};
10use cid::Cid;
11use co_primitives::{from_cbor, to_cbor, Block, KnownMultiCodec, MultiCodec, MultiCodecError};
12use derive_more::From;
13use multihash_codetable::{Code, MultihashDigest};
14use serde::{Deserialize, Serialize};
15use serde_repr::{Deserialize_repr, Serialize_repr};
16use std::{cmp::min, collections::BTreeMap, fmt::Debug, mem::take};
17
/// Derivation context for the key that wraps a block secret inside a [`KeySlot`].
///
/// Passed to `Secret::derive_serect_with_salt` together with the slot's salt. This string is
/// derivation input, so changing it would presumably change every derived wrapping key and make
/// existing key slots undecryptable.
pub const BLOCK_KEY_DERIVATION: &str = "co 2023-10-24T10:25:23Z block key derivation v1";
/// Derivation context for the payload key derived from a block secret.
///
/// Passed to `Secret::derive_serect`. As with [`BLOCK_KEY_DERIVATION`], changing it would
/// presumably break decryption of existing payloads.
pub const BLOCK_DERIVATION: &str = "co 2023-10-26T14:31:38Z block derivation v1";
/// Multicodec code of encrypted blocks (`KnownMultiCodec::CoEncryptedBlock`) as a `u64`.
pub const BLOCK_MULTICODEC: u64 = KnownMultiCodec::CoEncryptedBlock as u64;

/// Nonce bytes; a valid nonce is `Algorithm::nonce_size()` bytes long.
pub type Nonce = Vec<u8>;
/// Salt bytes mixed into the key-slot wrapping-key derivation.
pub type Salt = Vec<u8>;
/// Element type of ciphertext buffers.
pub type CipherU8 = u8;
31
32#[derive(Debug, thiserror::Error)]
33pub enum AlgorithmError {
34 #[error("Generic Cipher Error")]
35 Cipher,
36
37 #[error("Invalid arguments specified")]
38 InvalidArguments(#[source] anyhow::Error),
39
40 #[error("Generic decoding error")]
41 Decoding,
42
43 #[error("Generic encoding error")]
44 Encoding,
45
46 #[error("Size is to large")]
47 Size,
48}
49impl From<aead::Error> for AlgorithmError {
50 fn from(_: aead::Error) -> Self {
51 AlgorithmError::Cipher
52 }
53}
54impl From<MultiCodecError> for AlgorithmError {
55 fn from(value: MultiCodecError) -> Self {
56 AlgorithmError::InvalidArguments(value.into())
57 }
58}
59
60#[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr, PartialEq)]
61#[repr(u8)]
62#[derive(Default)]
63pub enum Algorithm {
64 #[default]
65 XChaCha20Poly1305 = 1,
66}
67impl Algorithm {
68 pub fn key_size(&self) -> usize {
70 match self {
71 Algorithm::XChaCha20Poly1305 => XChaCha20Poly1305::key_size(),
72 }
73 }
74
75 pub fn nonce_size(&self) -> usize {
77 match self {
78 Algorithm::XChaCha20Poly1305 => <XChaCha20Poly1305 as AeadCore>::NonceSize::USIZE,
79 }
80 }
81
82 pub fn tag_size(&self) -> usize {
84 match self {
85 Algorithm::XChaCha20Poly1305 => <XChaCha20Poly1305 as AeadCore>::TagSize::USIZE,
86 }
87 }
88
89 pub fn generate_serect(&self) -> Secret {
91 match self {
92 Algorithm::XChaCha20Poly1305 => Secret::new(XChaCha20Poly1305::generate_key(&mut OsRng).to_vec()),
93 }
94 }
95
96 pub fn generate_nonce(&self) -> Nonce {
98 match self {
99 Algorithm::XChaCha20Poly1305 => XChaCha20Poly1305::generate_nonce(&mut OsRng).to_vec(),
100 }
101 }
102
103 pub fn encrypt(
105 &self,
106 secret: &Secret,
107 nonce: &Nonce,
108 plaintext: &[u8],
109 aad: &[u8],
110 ) -> Result<Vec<u8>, AlgorithmError> {
111 if self.nonce_size() != nonce.len() {
113 return Err(AlgorithmError::InvalidArguments(anyhow::anyhow!("nonce size")));
114 }
115 if self.key_size() != secret.divulge().len() {
116 return Err(AlgorithmError::InvalidArguments(anyhow::anyhow!("key size")));
117 }
118
119 match self {
121 Algorithm::XChaCha20Poly1305 => {
122 let cipher = XChaCha20Poly1305::new(Key::from_slice(secret.divulge()));
123 let payload = Payload { msg: plaintext, aad };
124 cipher
125 .encrypt(aead::Nonce::<XChaCha20Poly1305>::from_slice(nonce.as_slice()), payload)
126 .map_err(|e| e.into())
127 },
128 }
129 }
130
131 pub fn decrypt(
133 &self,
134 secret: &Secret,
135 nonce: &Nonce,
136 ciphertext: &[CipherU8],
137 aad: &[u8],
138 ) -> Result<Vec<u8>, AlgorithmError> {
139 if self.nonce_size() != nonce.len() {
141 return Err(AlgorithmError::InvalidArguments(anyhow::anyhow!("nonce size")));
142 }
143 if self.key_size() != secret.divulge().len() {
144 return Err(AlgorithmError::InvalidArguments(anyhow::anyhow!("key size")));
145 }
146
147 match self {
149 Algorithm::XChaCha20Poly1305 => {
150 let cipher = XChaCha20Poly1305::new(Key::from_slice(secret.divulge()));
151 let payload = Payload { msg: ciphertext, aad };
152 cipher
153 .decrypt(aead::Nonce::<XChaCha20Poly1305>::from_slice(nonce.as_slice()), payload)
154 .map_err(|e| e.into())
155 },
156 }
157 }
158}
159
160#[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr, PartialEq)]
161#[repr(u8)]
162pub enum EncryptionVersion {
163 V1 = 1,
164}
165
166#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct EncryptedBlock {
169 #[serde(rename = "h")]
171 pub header: Header,
172
173 #[serde(rename = "d")]
175 pub payload: EncryptedData,
176}
177impl EncryptedBlock {
178 pub fn encrypt(
180 algorithm: Algorithm,
181 secret: &Secret,
182 block: impl Into<BlockPayload>,
183 ) -> Result<EncryptedBlock, AlgorithmError> {
184 let block_secret = algorithm.generate_serect();
185 Self::encrypt_with_block_secret(algorithm, secret, &block_secret, block)
186 }
187
188 pub fn encrypt_with_block_secret(
190 algorithm: Algorithm,
191 secret: &Secret,
192 block_secret: &Secret,
193 block: impl Into<BlockPayload>,
194 ) -> Result<EncryptedBlock, AlgorithmError> {
195 let block: BlockPayload = block.into();
196
197 let data_secret = block_secret.derive_serect(BLOCK_DERIVATION);
199
200 let key_slot = KeySlot::new(algorithm, secret, block_secret)?;
202 let header = Header::new(algorithm, vec![key_slot]);
203
204 let aad = header.aad();
206
207 let data = block.to_bytes().map_err(|_e| AlgorithmError::Encoding)?;
209
210 Ok(Self {
212 payload: header
213 .algorithm
214 .encrypt(&data_secret, &header.nonce, data.as_slice(), aad.as_slice())?
215 .into(),
216 header,
217 })
218 }
219
220 pub fn block(&self, secret: &Secret) -> Result<BlockPayload, AlgorithmError> {
222 let block_secret = self
223 .header
224 .block_secret(secret)
225 .ok_or(AlgorithmError::InvalidArguments(anyhow::anyhow!("key")))?;
226 let aad = self.header.aad();
227 let data = self
228 .payload
229 .inline()
230 .ok_or(AlgorithmError::InvalidArguments(anyhow::anyhow!("Expected inline data")))?;
231 let data_plain = self.decrypt_data(&block_secret, data, &aad)?;
232 from_cbor(&data_plain).map_err(|err| AlgorithmError::InvalidArguments(err.into()))
233 }
234
235 fn decrypt_data(&self, block_secret: &Secret, data: &[u8], aad: &[u8]) -> Result<Vec<u8>, AlgorithmError> {
236 let data_secret = block_secret.derive_serect(BLOCK_DERIVATION);
237 let data = self.header.algorithm.decrypt(&data_secret, &self.header.nonce, data, aad)?;
238 Ok(data)
239 }
240
241 pub fn is_valid(&self) -> bool {
243 self.header.is_valid()
244 }
245}
246impl TryInto<Block> for EncryptedBlock {
247 type Error = AlgorithmError;
248
249 fn try_into(self) -> Result<Block, Self::Error> {
251 let encrypted_data = to_cbor(&self).map_err(|_| AlgorithmError::Encoding)?;
252 let mh = Code::Blake3_256.digest(&encrypted_data);
253 let cid = Cid::new_v1(KnownMultiCodec::CoEncryptedBlock.into(), mh);
254 Ok(Block::new_unchecked(cid, encrypted_data))
255 }
256}
257impl TryFrom<Block> for EncryptedBlock {
258 type Error = AlgorithmError;
259
260 fn try_from(value: Block) -> Result<Self, Self::Error> {
262 MultiCodec::with_codec(KnownMultiCodec::CoEncryptedBlock, value.cid())?;
264
265 let block: EncryptedBlock = from_cbor(value.data()).map_err(|_| AlgorithmError::Decoding)?;
267
268 if !block.is_valid() {
270 return Err(AlgorithmError::Decoding);
271 }
272
273 Ok(block)
275 }
276}
277
278#[derive(Debug, Clone, Serialize, Deserialize, From)]
279#[serde(untagged)]
280pub enum EncryptedData {
281 #[from]
283 #[serde(with = "serde_bytes")]
284 Inline(Vec<CipherU8>),
285
286 #[from]
289 Block(Vec<Cid>),
290}
291impl EncryptedData {
292 pub fn inline(&self) -> Option<&[u8]> {
293 match self {
294 Self::Inline(data) => Some(data),
295 _ => None,
296 }
297 }
298
299 pub fn blocks(&self) -> Option<&[Cid]> {
300 match self {
301 Self::Block(data) => Some(data),
302 _ => None,
303 }
304 }
305
306 pub fn fit_into_blocks(&mut self, max_block_size: usize, inline_offset: Option<usize>) -> Vec<Block> {
314 let mut data = match self {
315 Self::Inline(data) => {
316 if max_block_size >= data.len() + inline_offset.unwrap_or(0) {
317 return vec![];
318 } else {
319 take(data)
320 }
321 },
322 Self::Block(_) => {
323 return vec![];
324 },
325 };
326 let mut extra_blocks = Vec::new();
327 while !data.is_empty() {
328 let rest = data.split_off(min(data.len(), max_block_size));
329 extra_blocks.push(Block::new_data(KnownMultiCodec::Raw, data));
330 data = rest;
331 }
332 *self = Self::Block(extra_blocks.iter().map(|block| *block.cid()).collect());
333 extra_blocks
334 }
335
336 pub fn try_inline_blocks(&mut self, blocks: impl IntoIterator<Item = (Cid, Vec<u8>)>) -> Result<(), ()> {
338 match self {
339 Self::Inline(_) => Ok(()),
340 Self::Block(cids) => {
341 let mut blocks: BTreeMap<Cid, Vec<u8>> = blocks.into_iter().collect();
342 if !cids.iter().all(|cid| blocks.contains_key(cid)) {
343 return Err(());
344 }
345 let mut inline = Vec::new();
346 for cid in cids {
347 if let Some(mut block) = blocks.remove(cid) {
348 inline.append(&mut block);
349 } else {
350 return Err(());
351 }
352 }
353 *self = Self::Inline(inline);
354 Ok(())
355 },
356 }
357 }
358}
359
360#[derive(Debug, Clone, Serialize, Deserialize)]
363pub struct BlockPayload {
364 #[serde(rename = "c")]
366 pub cid: Cid,
367
368 #[serde(rename = "r", default, skip_serializing_if = "BTreeMap::is_empty")]
371 pub references: BTreeMap<Cid, Cid>,
372
373 #[serde(with = "serde_bytes", rename = "d")]
375 pub data: Vec<u8>,
376}
377impl BlockPayload {
378 pub fn cid(&self) -> &Cid {
380 &self.cid
381 }
382
383 pub fn to_bytes(&self) -> Result<Vec<u8>, anyhow::Error> {
390 Ok(to_cbor(self)?)
391 }
392}
393impl From<Block> for BlockPayload {
394 fn from(value: Block) -> Self {
395 let (cid, data) = value.into_inner();
396 Self { cid, data, references: Default::default() }
397 }
398}
399impl From<BlockPayload> for Block {
400 fn from(value: BlockPayload) -> Self {
401 Block::new_unchecked(value.cid, value.data)
402 }
403}
404
405#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
406pub struct Header {
407 #[serde(rename = "v")]
409 pub version: EncryptionVersion,
410
411 #[serde(rename = "a")]
413 pub algorithm: Algorithm,
414
415 #[serde(rename = "k")]
417 pub key_slots: Vec<KeySlot>,
418
419 #[serde(rename = "n", with = "serde_bytes")]
421 pub nonce: Nonce,
422}
423impl Header {
424 pub fn new(algorithm: Algorithm, key_slots: Vec<KeySlot>) -> Self {
425 Self { version: EncryptionVersion::V1, algorithm, nonce: algorithm.generate_nonce(), key_slots }
426 }
427
428 pub fn is_valid(&self) -> bool {
430 self.version == EncryptionVersion::V1
431 && self.nonce.len() == self.algorithm.nonce_size()
432 && self.key_slots.iter().all(KeySlot::is_valid)
433 }
434
435 pub fn aad(&self) -> Vec<u8> {
437 let mut result = Vec::with_capacity(1 + 1 + self.nonce.len());
438 result.extend([self.version as u8, self.algorithm as u8].iter());
439 result.extend(self.nonce.iter());
440 result
447 }
448
449 pub fn block_secret(&self, secret: &Secret) -> Option<Secret> {
451 self.key_slots
452 .iter()
453 .map(|key_slot| key_slot.block_secret(secret))
454 .filter_map(|r| r.ok())
455 .next()
456 }
457
458 pub fn encoded_size(algorithm: Algorithm) -> usize {
462 let field_size = 1;
463 let cbor_size = 1;
464 cbor_size
465 + 1 + field_size + cbor_size
467 + 1 + field_size + cbor_size
469 + KeySlot::encoded_size(algorithm) + field_size + cbor_size + cbor_size
471 + algorithm.nonce_size() + field_size + cbor_size + cbor_size + cbor_size
473 }
474}
475#[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr, PartialEq)]
488#[repr(u8)]
489pub enum KeySlotVersion {
490 V1 = 1,
494}
495
496#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
497pub struct KeySlot {
498 #[serde(rename = "v")]
500 pub version: KeySlotVersion,
501
502 #[serde(rename = "a")]
504 pub algorithm: Algorithm,
505
506 #[serde(rename = "k", with = "serde_bytes")]
510 pub key: Vec<CipherU8>,
511
512 #[serde(rename = "s", with = "serde_bytes")]
514 pub salt: Salt,
515
516 #[serde(rename = "n", with = "serde_bytes")]
518 pub nonce: Nonce,
519}
520impl KeySlot {
521 pub fn encoded_size(algorithm: Algorithm) -> usize {
525 let tag_size = algorithm.tag_size();
526 let field_size = 1;
527 let cbor_size = 1;
528 cbor_size
529 + 1 + field_size + cbor_size
531 + 1 + field_size + cbor_size
533 + algorithm.key_size() + field_size + tag_size + cbor_size + cbor_size + cbor_size
535 + algorithm.nonce_size() + field_size + cbor_size + cbor_size + cbor_size
537 + algorithm.nonce_size() + field_size + cbor_size + cbor_size + cbor_size
539 }
540
541 pub fn new(algorithm: Algorithm, secret: &Secret, block_secret: &Secret) -> Result<Self, AlgorithmError> {
543 let salt = algorithm.generate_nonce(); let secret_derived = secret.derive_serect_with_salt(BLOCK_KEY_DERIVATION, &salt);
545 let nonce = algorithm.generate_nonce();
546 let block_secret_encrypted = algorithm.encrypt(&secret_derived, &nonce, block_secret.divulge(), b"")?;
547 Ok(Self { version: KeySlotVersion::V1, algorithm, key: block_secret_encrypted, nonce, salt })
548 }
549
550 pub fn is_valid(&self) -> bool {
552 self.version == KeySlotVersion::V1
553 && self.key.len() == self.algorithm.key_size() + self.algorithm.tag_size()
554 && self.nonce.len() == self.algorithm.nonce_size()
555 }
556
557 pub fn block_secret(&self, secret: &Secret) -> Result<Secret, AlgorithmError> {
559 let secret_derived = secret.derive_serect_with_salt(BLOCK_KEY_DERIVATION, &self.salt);
560 let block_secret = self.algorithm.decrypt(&secret_derived, &self.nonce, self.key.as_slice(), b"")?;
561 Ok(Secret::new(block_secret))
562 }
563}
#[cfg(test)]
mod tests {
    use super::{Algorithm, EncryptedBlock, Header, KeySlot};
    use crate::crypto::{block::EncryptedData, secret::Secret};
    use cid::Cid;
    use co_primitives::{from_cbor, to_cbor, Block, BlockSerializer, DefaultParams, KnownMultiCodec, StoreParams};
    use std::iter::repeat_n;

    /// Secret of the default algorithm's key size with every byte set to `byte`.
    fn filled_secret(byte: u8) -> Secret {
        Secret::new(repeat_n(byte, Algorithm::default().key_size()).collect())
    }

    /// Key slot wrapping the all-ones block secret under the all-zeros secret.
    fn sample_key_slot() -> KeySlot {
        let secret = filled_secret(0);
        let block_secret = filled_secret(1);
        KeySlot::new(Algorithm::default(), &secret, &block_secret).unwrap()
    }

    /// Header holding the single [`sample_key_slot`].
    fn sample_header() -> Header {
        Header::new(Algorithm::default(), vec![sample_key_slot()])
    }

    #[test]
    fn algorithm_key_size() {
        assert_eq!(Algorithm::XChaCha20Poly1305.key_size(), 32);
    }

    #[test]
    fn algorithm_nonce_size() {
        assert_eq!(Algorithm::XChaCha20Poly1305.nonce_size(), 24);
    }

    #[test]
    fn is_valid() {
        assert!(sample_header().is_valid());
    }

    #[test]
    fn serialize_header() {
        let header = sample_header();

        let encoded = to_cbor(&header).unwrap();
        assert_eq!(encoded.len(), 153);

        let decoded: Header = from_cbor(encoded.as_slice()).unwrap();
        assert_eq!(decoded, header);
        assert!(header.is_valid());
    }

    #[test]
    fn key_slot_encoded_size() {
        let encoded = to_cbor(&sample_key_slot()).unwrap();
        assert_eq!(encoded.len(), KeySlot::encoded_size(Algorithm::default()));
    }

    #[test]
    fn header_encoded_size() {
        let encoded = to_cbor(&sample_header()).unwrap();
        assert_eq!(encoded.len(), Header::encoded_size(Algorithm::default()));
    }

    #[test]
    fn encrypt_block_roundtrip() {
        let secret = filled_secret(0);
        let plain = BlockSerializer::default().serialize(&"Hello World!").unwrap();

        let sealed = EncryptedBlock::encrypt(Algorithm::default(), &secret, plain.clone()).unwrap();
        assert_ne!(sealed.payload.inline().unwrap(), plain.data());

        let sealed_bytes = to_cbor(&sealed).unwrap();
        assert_eq!(sealed_bytes.len(), 236);

        let reloaded: EncryptedBlock = from_cbor(&sealed_bytes).unwrap();
        let opened = reloaded.block(&secret).unwrap();
        assert_eq!(opened.cid(), plain.cid());
        assert_eq!(&opened.data, plain.data());
    }

    #[test]
    fn test_fit_to_blocks() {
        let secret = filled_secret(0);
        let payload: Vec<u8> = repeat_n(0u8, DefaultParams::MAX_BLOCK_SIZE).collect();
        let plain = Block::new_data(KnownMultiCodec::Raw, payload);

        let mut sealed = EncryptedBlock::encrypt(Algorithm::default(), &secret, plain.clone()).unwrap();

        let chunks = sealed
            .payload
            .fit_into_blocks(DefaultParams::MAX_BLOCK_SIZE, Some(Header::encoded_size(Algorithm::default())));
        let chunk_cids: Vec<Cid> = chunks.iter().map(|chunk| *chunk.cid()).collect();
        assert!(matches!(&sealed.payload, EncryptedData::Block(cids) if cids == &chunk_cids));

        sealed
            .payload
            .try_inline_blocks(chunks.into_iter().map(|chunk| chunk.into_inner()))
            .unwrap();

        let opened = sealed.block(&secret).unwrap();
        assert_eq!(opened.cid(), plain.cid());
        assert_eq!(&opened.data, plain.data());
    }
}