1use openssl_macros::corresponds;
2
3use crate::cvt;
4use crate::error::ErrorStack;
5use crate::foreign_types::ForeignTypeRef;
6use crate::hash::MessageDigest;
7
// Declares `HmacCtx`, a wrapper type over the raw `ffi::HMAC_CTX`, by way of the crate's
// `foreign_type_and_impl_send_sync!` macro.
//
// NOTE(review): `HmacCtxRef`, which has an `impl` further down in this file, is not declared
// here explicitly; presumably the macro generates it alongside `HmacCtx` (and, going by the
// macro's name, also provides `Send`/`Sync` impls) — confirm against the macro definition.
foreign_type_and_impl_send_sync! {
    // The underlying C type being wrapped.
    type CType = ffi::HMAC_CTX;
    // The function handed to the macro as the destructor for the wrapped pointer.
    fn drop = ffi::HMAC_CTX_free;

    pub struct HmacCtx;
}
14
15impl HmacCtxRef {
16 #[corresponds(HMAC_Init_ex)]
19 pub fn init(&mut self, key: &[u8], md: &MessageDigest) -> Result<(), ErrorStack> {
20 ffi::init();
21
22 unsafe {
23 cvt(ffi::HMAC_Init_ex(
24 self.as_ptr(),
25 key.as_ptr().cast(),
26 key.len(),
27 md.as_ptr(),
28 core::ptr::null_mut(),
30 ))
31 }
32 }
33}
34
/// An in-progress, streaming HMAC computation.
///
/// Holds a raw `HMAC_CTX` pointer that is obtained from `HMAC_CTX_new` in [`Hmac::init`] and
/// released with `HMAC_CTX_free` when the value is dropped.
pub struct Hmac(*mut ffi::HMAC_CTX);
37
38impl Hmac {
39 pub fn init(key: &[u8], md: &MessageDigest) -> Result<Hmac, ErrorStack> {
41 ffi::init();
42
43 let ctx = unsafe {
44 let ctx = ffi::HMAC_CTX_new();
45 cvt(ffi::HMAC_Init_ex(
46 ctx,
47 key.as_ptr().cast(),
48 key.len(),
49 md.as_ptr(),
50 core::ptr::null_mut(),
52 ))?;
53 ctx
54 };
55
56 Ok(Hmac(ctx))
57 }
58
59 pub fn update(&mut self, data: &[u8]) -> Result<(), ErrorStack> {
61 unsafe { cvt(ffi::HMAC_Update(self.0, data.as_ptr().cast(), data.len())) }
62 }
63
64 pub fn finalize(self) -> Result<Vec<u8>, ErrorStack> {
66 let out_len = unsafe { ffi::HMAC_size(self.0) };
67 let mut out = vec![0; out_len];
68 unsafe {
69 cvt(ffi::HMAC_Final(
70 self.0,
71 out.as_mut_ptr().cast(),
72 core::ptr::null_mut(),
74 ))?;
75 }
76 Ok(out)
77 }
78}
79
80impl Drop for Hmac {
81 fn drop(&mut self) {
82 unsafe { ffi::HMAC_CTX_free(self.0) }
83 }
84}
85
#[cfg(test)]
mod tests {
    use crate::hash;

    use super::*;

    /// Feeds a fixed message to [`Hmac`] chunk by chunk and checks that the resulting tag
    /// equals the one-shot `hash::hmac` over the concatenated message.
    ///
    /// `N` must equal the digest size of `md`; the key is `N` zero bytes.
    fn streaming_matches_one_shot<const N: usize>(md: MessageDigest) {
        assert_eq!(N, md.size());

        let key = vec![0u8; N];
        // Includes an empty chunk and a run of zero bytes on purpose.
        let chunks: [&[u8]; 5] = [b"hello", b"world!", b"", &[0; 23], b"fella guy"];
        let message = chunks.concat();

        let mut hmac = Hmac::init(&key, &md).unwrap();
        for &chunk in &chunks {
            hmac.update(chunk).unwrap();
        }
        let streamed = hmac.finalize().unwrap();

        let one_shot = hash::hmac::<N>(md, &key, &message).unwrap();
        assert_eq!(streamed, one_shot);
    }

    #[test]
    fn test_sha1() {
        streaming_matches_one_shot::<20>(MessageDigest::sha1());
    }

    #[test]
    fn test_sha256() {
        streaming_matches_one_shot::<32>(MessageDigest::sha256());
    }

    #[test]
    fn test_sha384() {
        streaming_matches_one_shot::<48>(MessageDigest::sha384());
    }

    #[test]
    fn test_sha512() {
        streaming_matches_one_shot::<64>(MessageDigest::sha512());
    }
}