1use crate::utils::{xor_set1, xor_set2};
2use cipher::{
3 AsyncStreamCipher, Block, BlockCipher, BlockEncrypt, FromBlockCipher, NewBlockCipher,
4};
5use core::ops::Sub;
6use generic_array::{
7 typenum::{
8 type_operators::{IsGreater, IsGreaterOrEqual, IsLessOrEqual},
9 Diff, Unsigned, U0, U255,
10 },
11 ArrayLength, GenericArray,
12};
13
14type BlockSize<C> = <C as BlockCipher>::BlockSize;
15
16type Tail<C, M> = GenericArray<u8, Diff<M, <C as BlockCipher>::BlockSize>>;
17
18#[derive(Clone)]
28pub struct GostCfb<C, M = BlockSize<C>, S = BlockSize<C>>
29where
30 C: BlockCipher + BlockEncrypt + NewBlockCipher,
31 C::BlockSize: IsLessOrEqual<U255>,
32 M: Unsigned + ArrayLength<u8> + IsGreaterOrEqual<C::BlockSize> + Sub<C::BlockSize>,
33 S: Unsigned + ArrayLength<u8> + IsGreater<U0> + IsLessOrEqual<C::BlockSize>,
34 Diff<M, C::BlockSize>: ArrayLength<u8>,
35{
36 cipher: C,
37 block: GenericArray<u8, S>,
38 tail: Tail<C, M>,
39 pos: u8,
40}
41
42impl<C, M, S> GostCfb<C, M, S>
43where
44 C: BlockCipher + BlockEncrypt + NewBlockCipher,
45 C::BlockSize: IsLessOrEqual<U255>,
46 M: Unsigned + ArrayLength<u8> + IsGreaterOrEqual<C::BlockSize> + Sub<C::BlockSize>,
47 S: Unsigned + ArrayLength<u8> + IsGreater<U0> + IsLessOrEqual<C::BlockSize>,
48 Diff<M, C::BlockSize>: ArrayLength<u8>,
49{
50 fn gen_block(&mut self) {
51 let s = S::USIZE;
52 let ts = self.tail.len();
53 let mut block: Block<C> = Default::default();
54 if ts <= s {
55 let d = s - ts;
56 block[..ts].copy_from_slice(&self.tail);
57 block[ts..].copy_from_slice(&self.block[..d]);
58 self.tail = GenericArray::clone_from_slice(&self.block[d..]);
59 } else {
60 let d = ts - s;
61 let mut tail: Tail<C, M> = Default::default();
62 tail[..d].copy_from_slice(&self.tail[s..]);
63 tail[d..].copy_from_slice(&self.block);
64 block = GenericArray::clone_from_slice(&self.tail[..s]);
65 self.tail = tail;
66 }
67 self.cipher.encrypt_block(&mut block);
68 self.block.copy_from_slice(&block[..s]);
69 }
70}
71
72impl<C, M, S> FromBlockCipher for GostCfb<C, M, S>
73where
74 C: BlockCipher + BlockEncrypt + NewBlockCipher,
75 C::BlockSize: IsLessOrEqual<U255>,
76 M: Unsigned + ArrayLength<u8> + IsGreaterOrEqual<C::BlockSize> + Sub<C::BlockSize>,
77 S: Unsigned + ArrayLength<u8> + IsGreater<U0> + IsLessOrEqual<C::BlockSize>,
78 Diff<M, C::BlockSize>: ArrayLength<u8>,
79{
80 type BlockCipher = C;
81 type NonceSize = M;
82
83 fn from_block_cipher(cipher: C, nonce: &GenericArray<u8, M>) -> Self {
84 let bs = C::BlockSize::USIZE;
85 let mut full_block = Block::<C>::clone_from_slice(&nonce[..bs]);
86 cipher.encrypt_block(&mut full_block);
87 Self {
88 cipher,
89 block: GenericArray::clone_from_slice(&full_block[..S::USIZE]),
90 tail: GenericArray::clone_from_slice(&nonce[bs..]),
91 pos: 0,
92 }
93 }
94}
95
96impl<C, M, S> AsyncStreamCipher for GostCfb<C, M, S>
97where
98 C: BlockCipher + BlockEncrypt + NewBlockCipher,
99 C::BlockSize: IsLessOrEqual<U255>,
100 M: Unsigned + ArrayLength<u8> + IsGreaterOrEqual<C::BlockSize> + Sub<C::BlockSize>,
101 S: Unsigned + ArrayLength<u8> + IsGreater<U0> + IsLessOrEqual<C::BlockSize>,
102 Diff<M, C::BlockSize>: ArrayLength<u8>,
103{
104 fn encrypt(&mut self, mut data: &mut [u8]) {
105 let s = S::USIZE;
106 let pos = self.pos as usize;
107
108 if data.len() < s - pos {
109 let n = data.len();
110 xor_set1(data, &mut self.block[pos..pos + n]);
111 self.pos += n as u8;
112 return;
113 } else if pos != 0 {
114 let (l, r) = { data }.split_at_mut(s - pos);
115 data = r;
116 xor_set1(l, &mut self.block[pos..s]);
117 self.gen_block()
118 }
119
120 let mut iter = data.chunks_exact_mut(s);
121 for chunk in &mut iter {
122 xor_set1(chunk, &mut self.block[..s]);
123 self.gen_block();
124 }
125 let rem = iter.into_remainder();
126 xor_set1(rem, &mut self.block[..rem.len()]);
127 self.pos = rem.len() as u8;
128 }
129
130 fn decrypt(&mut self, mut data: &mut [u8]) {
131 let s = S::USIZE;
132 let pos = self.pos as usize;
133
134 if data.len() < s - pos {
135 let n = data.len();
136 xor_set2(data, &mut self.block[pos..pos + n]);
137 self.pos += n as u8;
138 return;
139 } else if pos != 0 {
140 let (l, r) = { data }.split_at_mut(s - pos);
141 data = r;
142 xor_set2(l, &mut self.block[pos..]);
143 self.gen_block()
144 }
145
146 let mut iter = data.chunks_exact_mut(s);
147 for chunk in &mut iter {
148 xor_set2(chunk, &mut self.block);
149 self.gen_block();
150 }
151 let rem = iter.into_remainder();
152 xor_set2(rem, &mut self.block[..rem.len()]);
153 self.pos = rem.len() as u8;
154 }
155}