1use crate::error::MokshaCoreError;
44use bitcoin_hashes::{sha256, Hash};
45use secp256k1::{All, PublicKey, Scalar, Secp256k1, SecretKey};
46use std::iter::once;
47
48#[derive(Clone, Debug)]
49pub struct Dhke {
50 secp: Secp256k1<All>,
51}
52
53impl Default for Dhke {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl Dhke {
60 pub fn new() -> Self {
61 Self {
62 secp: Secp256k1::new(),
63 }
64 }
65
66 fn get_hash(message: &[u8]) -> Vec<u8> {
67 let hash = sha256::Hash::hash(message);
68 hash.as_byte_array().to_vec()
69 }
70
71 fn hash_to_curve(message: &[u8]) -> PublicKey {
74 let mut point: Option<PublicKey> = None;
75 let mut msg_to_hash = message.to_vec();
76 while point.is_none() {
77 let hash = Self::get_hash(&msg_to_hash);
78 let input = &once(&0x02).chain(hash.iter()).cloned().collect::<Vec<u8>>();
79 PublicKey::from_slice(input).map_or_else(|_| msg_to_hash = hash, |p| point = Some(p))
80 }
81 point.unwrap()
82 }
83
84 pub fn step1_alice(
85 &self,
86 secret_msg: impl Into<String>,
87 blinding_factor: Option<&[u8]>,
88 ) -> Result<(PublicKey, SecretKey), MokshaCoreError> {
89 let mut rng = rand::thread_rng();
90
91 let y = Self::hash_to_curve(secret_msg.into().as_bytes());
92 let secret_key = match blinding_factor {
93 Some(f) => SecretKey::from_slice(f)?,
94 None => SecretKey::new(&mut rng),
95 };
96 let b = y.combine(&PublicKey::from_secret_key(&self.secp, &secret_key))?;
97 Ok((b, secret_key))
98 }
99
100 pub fn step2_bob(&self, b: PublicKey, a: &SecretKey) -> Result<PublicKey, MokshaCoreError> {
101 b.mul_tweak(&self.secp, &Scalar::from(*a))
102 .map_err(MokshaCoreError::Secp256k1Error)
103 }
104
105 pub fn step3_alice(
106 &self,
107 c_: PublicKey,
108 r: SecretKey,
109 a: PublicKey,
110 ) -> Result<PublicKey, MokshaCoreError> {
111 c_.combine(
112 &a.mul_tweak(&self.secp, &Scalar::from(r))
113 .map_err(MokshaCoreError::Secp256k1Error)?
114 .negate(&self.secp),
115 )
116 .map_err(MokshaCoreError::Secp256k1Error)
117 }
118
119 pub fn verify(
120 &self,
121 a: SecretKey,
122 c: PublicKey,
123 secret_msg: impl Into<String>,
124 ) -> Result<bool, MokshaCoreError> {
125 let y = Self::hash_to_curve(secret_msg.into().as_bytes());
126 Some(c == y.mul_tweak(&self.secp, &Scalar::from(a))?).ok_or(
127 MokshaCoreError::Secp256k1Error(secp256k1::Error::InvalidPublicKey),
128 )
129 }
130}
131
132pub fn public_key_from_hex(hex: &str) -> secp256k1::PublicKey {
133 use hex::FromHex;
134 let input_vec: Vec<u8> = Vec::from_hex(hex).expect("Invalid Hex String");
135 secp256k1::PublicKey::from_slice(&input_vec).expect("Invalid Public Key")
136}
137
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use crate::dhke::{public_key_from_hex, Dhke};
    use pretty_assertions::assert_eq;

    /// Decodes a hex fixture into raw bytes and reinterprets them as a UTF-8 `String`.
    fn hex_to_string(hex: &str) -> String {
        use hex::FromHex;
        let input_vec: Vec<u8> = Vec::from_hex(hex).expect("Invalid Hex String");
        String::from_utf8(input_vec).expect("Invalid UTF-8 String")
    }

    /// Parses a hex fixture into a secp256k1 *secret* key.
    fn sk_from_hex(hex: &str) -> secp256k1::SecretKey {
        secp256k1::SecretKey::from_str(hex).expect("Invalid SecretKey")
    }

    #[test]
    fn test_hash_to_curve_zero() -> anyhow::Result<()> {
        let input_str =
            hex_to_string("0000000000000000000000000000000000000000000000000000000000000000");
        let expected_result = "0266687aadf862bd776c8fc18b8e9f8e20089714856ee233b3902a591d0d5f2925";

        let pk = Dhke::hash_to_curve(input_str.as_bytes()).to_string();
        assert_eq!(pk, expected_result);
        Ok(())
    }

    #[test]
    fn test_hash_to_curve_zero_one() -> anyhow::Result<()> {
        let input_str =
            hex_to_string("0000000000000000000000000000000000000000000000000000000000000001");
        let expected_result = "02ec4916dd28fc4c10d78e287ca5d9cc51ee1ae73cbfde08c6b37324cbfaac8bc5";

        let pk = Dhke::hash_to_curve(input_str.as_bytes()).to_string();
        assert_eq!(pk, expected_result);
        Ok(())
    }

    #[test]
    fn test_hash_to_curve_iterate() -> anyhow::Result<()> {
        let input_str =
            hex_to_string("0000000000000000000000000000000000000000000000000000000000000002");
        let expected_result = "02076c988b353fcbb748178ecb286bc9d0b4acf474d4ba31ba62334e46c97c416a";

        let pk = Dhke::hash_to_curve(input_str.as_bytes()).to_string();
        assert_eq!(pk, expected_result);
        Ok(())
    }

    #[test]
    fn test_step1_alice() -> anyhow::Result<()> {
        let dhke = Dhke::new();
        let blinding_factor =
            hex_to_string("0000000000000000000000000000000000000000000000000000000000000001");
        let (pub_key, secret_key) =
            dhke.step1_alice("test_message", Some(blinding_factor.as_bytes()))?;
        let pub_key_str = pub_key.to_string();

        assert_eq!(
            pub_key_str,
            "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2"
        );

        assert_eq!(
            hex::encode(secret_key.secret_bytes()),
            "0000000000000000000000000000000000000000000000000000000000000001"
        );
        Ok(())
    }

    #[test]
    fn test_step2_bob() -> anyhow::Result<()> {
        let dhke = Dhke::new();
        let blinding_factor =
            hex_to_string("0000000000000000000000000000000000000000000000000000000000000001");
        let (pub_key, _) = dhke.step1_alice("test_message", Some(blinding_factor.as_bytes()))?;

        let a = sk_from_hex("0000000000000000000000000000000000000000000000000000000000000001");

        let c = dhke.step2_bob(pub_key, &a)?;
        let c_str = c.to_string();
        assert_eq!(
            "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2".to_string(),
            c_str
        );

        Ok(())
    }

    #[test]
    fn test_step3_alice() -> anyhow::Result<()> {
        let dhke = Dhke::new();
        let c_ = public_key_from_hex(
            "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2",
        );

        let r = sk_from_hex("0000000000000000000000000000000000000000000000000000000000000001");

        let a = public_key_from_hex(
            "020000000000000000000000000000000000000000000000000000000000000001",
        );

        let result = dhke.step3_alice(c_, r, a)?;
        assert_eq!(
            "03c724d7e6a5443b39ac8acf11f40420adc4f99a02e7cc1b57703d9391f6d129cd".to_string(),
            result.to_string()
        );
        Ok(())
    }

    #[test]
    // Uppercase names mirror the protocol notation (A, B_, C_, C).
    #[allow(non_snake_case)]
    fn test_verify() -> anyhow::Result<()> {
        let dhke = Dhke::new();

        // Bob's private key `a` and the matching public key `A`.
        let a = sk_from_hex("0000000000000000000000000000000000000000000000000000000000000001");
        let A = a.public_key(&dhke.secp);

        let blinding_factor =
            hex_to_string("0000000000000000000000000000000000000000000000000000000000000002");

        let secret_msg = "test";
        let (B_, r) = dhke.step1_alice(secret_msg, Some(blinding_factor.as_bytes()))?;
        let C_ = dhke.step2_bob(B_, &a)?;
        let C = dhke.step3_alice(C_, r, A)?;

        // The unblinded signature verifies against the original message...
        assert!(dhke.verify(a, C, secret_msg)?);
        // ...while unrelated points do not.
        assert!(!dhke.verify(a, C.combine(&C)?, secret_msg)?);
        assert!(!dhke.verify(a, A, secret_msg)?);
        Ok(())
    }
}