cyfs_base/crypto/
private_key.rs

1use crate::*;
2
3use generic_array::GenericArray;
4use libc::memcpy;
5use rand::{thread_rng, Rng};
6use rsa::PublicKeyParts;
7use std::{os::raw::c_void, str::FromStr};
8
// Key-type codes used in the binary key encoding.
pub(crate) const KEY_TYPE_RSA: u8 = 0;
pub(crate) const KEY_TYPE_RSA2048: u8 = 1;
pub(crate) const KEY_TYPE_RSA3072: u8 = 2;
pub(crate) const KEY_TYPE_SECP256K1: u8 = 5;

// Supported RSA modulus sizes, in bits.
pub(crate) const RSA_KEY_BITS: usize = 1024;
pub(crate) const RSA2048_KEY_BITS: usize = 2048;
pub(crate) const RSA3072_KEY_BITS: usize = 3072;

// The same modulus sizes expressed in bytes (bits / 8).
pub(crate) const RSA_KEY_BYTES: usize = RSA_KEY_BITS / 8;
pub(crate) const RSA2048_KEY_BYTES: usize = RSA2048_KEY_BITS / 8;
pub(crate) const RSA3072_KEY_BYTES: usize = RSA3072_KEY_BITS / 8;
24
/// Algorithm family of a [`PrivateKey`].
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum PrivateKeyType {
    /// RSA key.
    Rsa,
    /// secp256k1 elliptic-curve key.
    Secp256k1,
}

impl PrivateKeyType {
    /// Lowercase textual name of the key type (`"rsa"` / `"secp256k1"`).
    pub fn as_str(&self) -> &str {
        match self {
            PrivateKeyType::Rsa => "rsa",
            PrivateKeyType::Secp256k1 => "secp256k1",
        }
    }
}

/// RSA is the default key type.
impl Default for PrivateKeyType {
    fn default() -> Self {
        PrivateKeyType::Rsa
    }
}

/// Prints the same text as [`PrivateKeyType::as_str`].
impl std::fmt::Display for PrivateKeyType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
51
52impl FromStr for PrivateKeyType {
53    type Err = BuckyError;
54    fn from_str(s: &str) -> Result<Self, Self::Err> {
55        Ok(match s {
56            "rsa" => Self::Rsa,
57            "secp256k1" => Self::Secp256k1,
58             _ => {
59                let msg = format!("unknown PrivateKey type: {}", s);
60                warn!("{}", msg);
61                return Err(BuckyError::new(BuckyErrorCode::InvalidData, msg))
62             }
63        })
64    }
65}
66
67#[derive(Clone, Eq, PartialEq)]
68pub enum PrivateKey {
69    Rsa(rsa::RSAPrivateKey),
70    Secp256k1(::secp256k1::SecretKey),
71}
72
73// 避免私钥被日志打印出来
74impl std::fmt::Debug for PrivateKey {
75    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
76        write!(f, "[Protected PrivateKey]")
77    }
78}
79impl std::fmt::Display for PrivateKey {
80    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
81        write!(f, "[Protected PrivateKey]")
82    }
83}
84
85pub const CYFS_PRIVTAE_KEY_DEFAULT_RSA_BITS: usize = 1024;
86
87impl PrivateKey {
88    pub fn key_type(&self) -> PrivateKeyType {
89        match *self {
90            Self::Rsa(_) => PrivateKeyType::Rsa,
91            Self::Secp256k1(_) => PrivateKeyType::Secp256k1,
92        }
93    }
94
95    fn check_bits(bits: usize) -> BuckyResult<()> {
96        match bits {
97            RSA_KEY_BITS | RSA2048_KEY_BITS | RSA3072_KEY_BITS=> {
98                Ok(())
99            }
100            _ => {
101                let msg = format!("unsupport rsa key bits: {}", bits);
102                error!("{}", msg);
103                Err(BuckyError::new(BuckyErrorCode::UnSupport, msg))
104            }
105        }
106    }
107    // 生成rsa密钥的相关接口
108    pub fn generate_rsa(bits: usize) -> Result<Self, BuckyError> {
109        Self::check_bits(bits)?;
110
111        let mut rng = thread_rng();
112        Self::generate_rsa_by_rng(&mut rng, bits)
113    }
114
115    pub fn generate_rsa_by_rng<R: Rng>(rng: &mut R, bits: usize) -> Result<Self, BuckyError> {
116        Self::check_bits(bits)?;
117
118        match rsa::RSAPrivateKey::new(rng, bits) {
119            Ok(rsa) => Ok(Self::Rsa(rsa)),
120            Err(e) => Err(BuckyError::from(e)),
121        }
122    }
123
124    // 生成secp256k1密钥的相关接口
125    pub fn generate_secp256k1() -> Result<Self, BuckyError> {
126        let mut rng = thread_rng();
127        Self::generate_secp256k1_by_rng(&mut rng)
128    }
129
130    pub fn generate_secp256k1_by_rng<R: Rng>(rng: &mut R) -> Result<Self, BuckyError> {
131        let key = ::secp256k1::SecretKey::random(rng);
132        Ok(Self::Secp256k1(key))
133    }
134
135    pub fn generate_by_rng<R: Rng>(rng: &mut R, bits: Option<usize>, pt: PrivateKeyType) -> BuckyResult<Self> {
136        match pt {
137            PrivateKeyType::Rsa => Self::generate_rsa_by_rng(rng, bits.unwrap_or(CYFS_PRIVTAE_KEY_DEFAULT_RSA_BITS)),
138            PrivateKeyType::Secp256k1 => Self::generate_secp256k1_by_rng(rng)
139        }
140    }
141
142    pub fn public(&self) -> PublicKey {
143        match self {
144            Self::Rsa(private_key) => PublicKey::Rsa(private_key.to_public_key()),
145            Self::Secp256k1(private_key) => {
146                PublicKey::Secp256k1(::secp256k1::PublicKey::from_secret_key(private_key))
147            }
148        }
149    }
150
151    pub fn sign(&self, data: &[u8], sign_source: SignatureSource) -> BuckyResult<Signature> {
152        let create_time = bucky_time_now();
153
154        // 签名必须也包含签名的时刻,这个时刻是敏感的不可修改
155        let mut data_new = data.to_vec();
156        data_new.resize(data.len() + create_time.raw_measure(&None).unwrap(), 0);
157        create_time
158            .raw_encode(&mut data_new.as_mut_slice()[data.len()..], &None)?;
159
160        let sign = match self {
161            Self::Rsa(private_key) => {
162                let hash = hash_data(&data_new);
163                let sign = private_key
164                    .sign(
165                        rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_256)),
166                        &hash.as_slice(),
167                    )?;
168
169                assert_eq!(sign.len(), private_key.size());
170                let sign_data = match private_key.size() {
171                    RSA_KEY_BYTES => {
172                        let mut sign_array: [u32; 32] = [0; 32];
173                        unsafe {
174                            memcpy(
175                                sign_array.as_mut_ptr() as *mut c_void,
176                                sign.as_ptr() as *const c_void,
177                                sign.len(),
178                            )
179                        };
180                        SignData::Rsa1024(GenericArray::from(sign_array))
181                    }
182                    RSA2048_KEY_BYTES => {
183                        let mut sign_array: [u32; 64] = [0; 64];
184                        unsafe {
185                            memcpy(
186                                sign_array.as_mut_ptr() as *mut c_void,
187                                sign.as_ptr() as *const c_void,
188                                sign.len(),
189                            )
190                        };
191                        SignData::Rsa2048(*GenericArray::from_slice(&sign_array))
192                    }
193                    RSA3072_KEY_BYTES => {
194                        let mut sign_array: [u32; 96] = [0; 96];
195                        unsafe {
196                            memcpy(
197                                sign_array.as_mut_ptr() as *mut c_void,
198                                sign.as_ptr() as *const c_void,
199                                sign.len(),
200                            )
201                        };
202                        SignData::Rsa3072(*GenericArray::from_slice(&sign_array))
203                    }
204
205                    len @ _ =>  {
206                        let msg = format!("unsupport rsa key length! {}", len);
207                        error!("{}", msg);
208                        return Err(BuckyError::new(BuckyErrorCode::UnSupport, msg));
209                    }
210                };
211
212                Signature::new(sign_source, 0, create_time, sign_data)
213            }
214
215            Self::Secp256k1(private_key) => {
216                let hash = hash_data(&data_new);
217                assert_eq!(HashValue::len(), ::secp256k1::util::MESSAGE_SIZE);
218                let ctx = ::secp256k1::Message::parse(hash.as_slice().try_into().unwrap());
219
220                let (signature, _) = ::secp256k1::sign(&ctx, &private_key);
221                let sign_buf = signature.serialize();
222
223                let mut sign_array: [u32; 16] = [0; 16];
224                unsafe {
225                    memcpy(
226                        sign_array.as_mut_ptr() as *mut c_void,
227                        sign_buf.as_ptr() as *const c_void,
228                        sign_buf.len(),
229                    )
230                };
231                let sign_data = SignData::Ecc(GenericArray::from(sign_array));
232                Signature::new(sign_source, 0, create_time, sign_data)
233            }
234        };
235
236        Ok(sign)
237    }
238
239    pub fn decrypt(&self, input: &[u8], output: &mut [u8]) -> BuckyResult<usize> {
240        let buf = self.decrypt_data(input)?;
241        if output.len() < buf.len() {
242            let msg = format!(
243                "rsa decrypt error, except={}, got={}",
244                buf.len(),
245                output.len()
246            );
247            error!("{}", msg);
248
249            Err(BuckyError::new(BuckyErrorCode::InvalidFormat, msg))
250        } else {
251            output[..buf.len()].copy_from_slice(buf.as_slice());
252            Ok(buf.len())
253        }
254    }
255
256    pub fn decrypt_data(&self, input: &[u8]) -> BuckyResult<Vec<u8>> {
257        match self {
258            Self::Rsa(private_key) => {
259                let buf = private_key
260                    .decrypt(rsa::PaddingScheme::PKCS1v15Encrypt, input)
261                    .map_err(|e| BuckyError::from(e))?;
262                Ok(buf)
263            }
264
265            Self::Secp256k1(_) => {
266                // 目前secp256k1的非对称加解密只支持交换aes_key时候使用
267                let msg = format!("direct decyrpt with private key of secp256 not support!");
268                error!("{}", msg);
269                Err(BuckyError::new(BuckyErrorCode::NotSupport, msg))
270            }
271        }
272    }
273
274    pub fn decrypt_aeskey<'d>(&self, input: &'d [u8], output: &mut [u8]) -> BuckyResult<(&'d [u8], usize)> {
275        let (input, data) = self.decrypt_aeskey_data(input)?;
276        if output.len() < data.len() {
277            let msg = format!(
278                "not enough buffer for decrypt aeskey result, except={}, got={}",
279                data.len(),
280                output.len()
281            );
282            error!("{}", msg);
283
284            return Err(BuckyError::new(BuckyErrorCode::InvalidParam, msg));
285        }
286
287        output[..data.len()].copy_from_slice(&data);
288
289        Ok((input, data.len()))
290    }
291
292    pub fn decrypt_aeskey_data<'d>(&self, input: &'d [u8]) -> BuckyResult<(&'d [u8], Vec<u8>)> {
293        match self {
294            Self::Rsa(_) => {
295                let key_size = self.public().key_size();
296                if input.len() < key_size {
297                    let msg = format!(
298                        "not enough buffer for RSA private key, except={}, got={}",
299                        key_size,
300                        input.len()
301                    );
302                    error!("{}", msg);
303
304                    return Err(BuckyError::new(BuckyErrorCode::InvalidFormat, msg));
305                }
306
307                let buf = self.decrypt_data(&input[..key_size])?;
308
309                Ok((&input[key_size..], buf))
310            },
311
312            Self::Secp256k1(private_key) => {
313                if input.len() < ::secp256k1::util::COMPRESSED_PUBLIC_KEY_SIZE {
314                    let msg = format!(
315                        "not enough buffer for secp256k1 private key, except={}, got={}",
316                        ::secp256k1::util::COMPRESSED_PUBLIC_KEY_SIZE,
317                        input.len()
318                    );
319                    error!("{}", msg);
320
321                    return Err(BuckyError::new(BuckyErrorCode::InvalidFormat, msg));
322                }
323
324                let ephemeral_pk = ::secp256k1::PublicKey::parse_slice(
325                    &input[..::secp256k1::util::COMPRESSED_PUBLIC_KEY_SIZE],
326                    Some(::secp256k1::PublicKeyFormat::Compressed),
327                )
328                .map_err(|e| {
329                    let msg = format!("parse secp256k1 public key error: {}", e);
330                    error!("{}", msg);
331
332                    BuckyError::new(BuckyErrorCode::InvalidFormat, msg)
333                })?;
334                let aes_key = ::cyfs_ecies::utils::decapsulate(&ephemeral_pk, &private_key);
335                
336                Ok((&input[::secp256k1::util::COMPRESSED_PUBLIC_KEY_SIZE..], aes_key.into()))
337            }
338        }
339    }
340}
341
342impl RawEncode for PrivateKey {
343    fn raw_measure(&self, _purpose: &Option<RawEncodePurpose>) -> Result<usize, BuckyError> {
344        // 这里直接输出正确长度先,然后看如何优化
345        match self {
346            Self::Rsa(pk) => {
347                let spki_der = rsa_export::pkcs1::private_key(pk)?;
348                Ok(spki_der.len() + 3)
349            }
350            Self::Secp256k1(_) => Ok(::secp256k1::util::SECRET_KEY_SIZE + 1),
351        }
352    }
353
354    fn raw_encode<'a>(
355        &self,
356        buf: &'a mut [u8],
357        purpose: &Option<RawEncodePurpose>,
358    ) -> Result<&'a mut [u8], BuckyError> {
359        let size = self.raw_measure(purpose).unwrap();
360        if buf.len() < size {
361            return Err(BuckyError::new(
362                BuckyErrorCode::OutOfLimit,
363                "[raw_encode] not enough buffer for privake key for private_key",
364            ));
365        }
366
367        match self {
368            Self::Rsa(pk) => {
369                let spki_der = rsa_export::pkcs1::private_key(pk)?;
370                let mut buf = KEY_TYPE_RSA.raw_encode(buf, purpose)?;
371                buf = (spki_der.len() as u16).raw_encode(buf, purpose)?;
372                buf[..spki_der.len()].copy_from_slice(&spki_der.as_slice());
373                Ok(&mut buf[spki_der.len()..])
374            }
375            Self::Secp256k1(pk) => {
376                let buf = KEY_TYPE_SECP256K1.raw_encode(buf, purpose)?;
377
378                // 由于长度固定,所以我们这里不需要额外存储一个长度信息了
379                let key_buf = pk.serialize();
380                buf[..::secp256k1::util::SECRET_KEY_SIZE].copy_from_slice(&key_buf);
381                Ok(&mut buf[::secp256k1::util::SECRET_KEY_SIZE..])
382            }
383        }
384    }
385}
386
387impl<'de> RawDecode<'de> for PrivateKey {
388    fn raw_decode(buf: &'de [u8]) -> Result<(Self, &'de [u8]), BuckyError> {
389        if buf.len() < 1 {
390            return Err(BuckyError::new(
391                BuckyErrorCode::OutOfLimit,
392                "not enough buffer for PrivateKey",
393            ));
394        }
395        let (type_code, buf) = u8::raw_decode(buf)?;
396        match type_code {
397            KEY_TYPE_RSA => {
398                let (len, buf) = u16::raw_decode(buf)?;
399                if buf.len() < len as usize {
400                    return Err(BuckyError::new(
401                        BuckyErrorCode::OutOfLimit,
402                        "not enough buffer for rsa privateKey",
403                    ));
404                }
405                let der = &buf[..len as usize];
406                let private_key = rsa::RSAPrivateKey::from_pkcs1(der)?;
407                Ok((PrivateKey::Rsa(private_key), &buf[len as usize..]))
408            }
409            KEY_TYPE_SECP256K1 => {
410                if buf.len() < ::secp256k1::util::SECRET_KEY_SIZE {
411                    return Err(BuckyError::new(
412                        BuckyErrorCode::OutOfLimit,
413                        "not enough buffer for secp256k1 privateKey",
414                    ));
415                }
416
417                match ::secp256k1::SecretKey::parse_slice(
418                    &buf[..::secp256k1::util::SECRET_KEY_SIZE],
419                ) {
420                    Ok(private_key) => Ok((
421                        PrivateKey::Secp256k1(private_key),
422                        &buf[::secp256k1::util::SECRET_KEY_SIZE..],
423                    )),
424                    Err(e) => {
425                        let msg = format!("parse secp256k1 private key error: {}", e);
426                        error!("{}", e);
427
428                        Err(BuckyError::new(BuckyErrorCode::InvalidFormat, msg))
429                    }
430                }
431            }
432            _ => Err(BuckyError::new(
433                BuckyErrorCode::InvalidData,
434                &format!("invalid private key type code {}", buf[0]),
435            )),
436        }
437    }
438}
439
#[cfg(test)]
mod test {
    use crate::{PrivateKey, RawConvertTo, RawDecode, RawFrom, Signature, SignatureSource};

    /// RSA sizes exercised by the tests (all the sizes `generate_rsa` accepts).
    const RSA_BITS: [usize; 3] = [1024, 2048, 3072];

    /// Fixed message signed by the signature tests.
    const MSG: &[u8] = b"112233445566778899";

    #[test]
    fn private_key() {
        check_sign_and_codec(PrivateKey::generate_secp256k1().unwrap());
        for &bits in RSA_BITS.iter() {
            check_sign_and_codec(PrivateKey::generate_rsa(bits).unwrap());
        }
    }

    /// Signs `MSG` with `key` and checks three things:
    /// - the signature verifies;
    /// - it still verifies after the key is encoded and decoded;
    /// - the signature itself survives an encode/decode round trip.
    fn check_sign_and_codec(key: PrivateKey) {
        let sign = key.sign(MSG, SignatureSource::RefIndex(0)).unwrap();
        assert!(key.public().verify(MSG, &sign));

        let encoded = key.to_vec().unwrap();
        let (decoded, rest) = PrivateKey::raw_decode(&encoded).unwrap();
        assert!(rest.is_empty());
        assert!(decoded.public().verify(MSG, &sign));

        let sign_buf = sign.to_vec().unwrap();
        let sign_copy = Signature::clone_from_slice(&sign_buf).unwrap();
        assert_eq!(sign, sign_copy);
    }

    #[test]
    fn crypto() {
        for &bits in RSA_BITS.iter() {
            let key = PrivateKey::generate_rsa(bits).unwrap();
            let (aes_key, data) = key.public().gen_aeskey_and_encrypt().unwrap();
            let (rest, decrypted) = key.decrypt_aeskey_data(&data).unwrap();
            assert_eq!(rest.len(), 0);
            assert_eq!(aes_key.as_slice(), decrypted);
        }

        let key = PrivateKey::generate_secp256k1().unwrap();
        let (aes_key, mut data) = key.public().gen_aeskey_and_encrypt().unwrap();
        println!("secp256k1 aes_key encrypt len={}", data.len());
        let (rest, decrypted) = key.decrypt_aeskey_data(&data).unwrap();
        assert_eq!(rest.len(), 0);
        assert_eq!(aes_key.as_slice(), decrypted);

        // Bytes trailing the encrypted key must be handed back untouched.
        let encrypt_len = data.len();
        data.resize(1024, 0);
        let mut output = vec![0; 48];
        let (rest, size) = key.decrypt_aeskey(&data, &mut output).unwrap();
        assert_eq!(rest.len(), 1024 - encrypt_len);
        assert_eq!(aes_key.as_slice(), &output[..size]);
    }

    #[test]
    fn crypto_unaligned() {
        let key = PrivateKey::generate_rsa(1024).unwrap();

        let plain = "test data".as_bytes();
        let cipher = key.public().encrypt_data(plain).unwrap();
        println!("len={}", cipher.len());

        let mut output = vec![0; 48];
        let (_rest, size) = key.decrypt_aeskey(&cipher, &mut output).unwrap();
        assert_eq!(size, plain.len());
        assert_eq!(&output[..plain.len()], plain);
    }
}