// agent_id_handshake/protocol.rs
use std::collections::{HashSet, VecDeque};
use std::sync::{Mutex, MutexGuard, PoisonError};
use std::time::{SystemTime, UNIX_EPOCH};

use agent_id_core::{signing, Did, RootKey};
use chrono::Utc;

use crate::error::{HandshakeError, Result};
use crate::messages::{Challenge, CounterChallenge, CounterProof, Hello, Proof, ProofAccepted};

/// Maximum accepted distance, in milliseconds, between a peer-supplied timestamp and the local
/// clock: 5 minutes.
pub const DEFAULT_TIMESTAMP_TOLERANCE_MS: i64 = 5 * 60 * 1000;

/// Lifetime, in milliseconds, of a newly accepted session: 24 hours.
pub const DEFAULT_SESSION_DURATION_MS: i64 = 24 * 60 * 60 * 1000;

/// Replay cache: remembers nonces so that a second presentation of the same nonce is rejected.
///
/// Each nonce is remembered for `max_age_ms` milliseconds after it was first recorded. Expired
/// entries are evicted on later inserts. Memory is therefore bounded by the number of distinct
/// nonces seen within one window, instead of growing for the lifetime of the process.
///
/// Forgetting a nonce is only safe if a replay that old is rejected by some other check. The
/// default window is twice [`DEFAULT_TIMESTAMP_TOLERANCE_MS`]. A message accepted while its
/// timestamp was within ±tolerance of the clock falls outside that tolerance once 2×tolerance
/// has elapsed. This assumes the caller checks the timestamp of the same message that carried
/// the nonce, with the same tolerance.
pub struct NonceCache {
    state: Mutex<NonceWindow>,
    /// How long (ms) a recorded nonce is remembered.
    max_age_ms: i64,
}

/// Mutex-protected contents of a [`NonceCache`].
#[derive(Default)]
struct NonceWindow {
    /// Every nonce currently remembered, for O(1) membership checks.
    seen: HashSet<String>,
    /// The same nonces with the time (ms since the Unix epoch) they were recorded, oldest first.
    order: VecDeque<(i64, String)>,
}

impl NonceWindow {
    /// Drops every entry recorded strictly before `cutoff_ms`.
    ///
    /// The scan starts at the oldest end and stops at the first entry that is still live. An entry
    /// is only evicted once its own recorded time is before the cutoff. If the wall clock steps
    /// backwards, some entries may be kept longer than necessary, but none is dropped early.
    fn evict_before(&mut self, cutoff_ms: i64) {
        while let Some(&(recorded_ms, _)) = self.order.front() {
            if recorded_ms >= cutoff_ms {
                break;
            }
            if let Some((_, nonce)) = self.order.pop_front() {
                self.seen.remove(&nonce);
            }
        }
    }
}

impl NonceCache {
    /// Creates an empty cache that remembers each nonce for `max_age_ms` milliseconds.
    pub fn new(max_age_ms: i64) -> Self {
        Self {
            state: Mutex::new(NonceWindow::default()),
            max_age_ms,
        }
    }

    /// Records `nonce`. Returns `true` if it was not already remembered and `false` if it is a
    /// replay. Entries older than the window are evicted first.
    ///
    /// "Now" is read from the system wall clock (`SystemTime`).
    pub fn check_and_insert(&self, nonce: &str) -> bool {
        let now_ms = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, |since_epoch| {
                i64::try_from(since_epoch.as_millis()).unwrap_or(i64::MAX)
            });
        self.check_and_insert_at(nonce, now_ms)
    }

    /// Same as [`Self::check_and_insert`], but with an explicit current time (ms since the Unix epoch).
    fn check_and_insert_at(&self, nonce: &str, now_ms: i64) -> bool {
        let mut window = self.lock_window();
        window.evict_before(now_ms.saturating_sub(self.max_age_ms));
        // `HashSet::insert` reports whether the value was new. One lookup therefore both detects
        // the replay and records the nonce.
        if !window.seen.insert(nonce.to_owned()) {
            return false;
        }
        window.order.push_back((now_ms, nonce.to_owned()));
        true
    }

    /// Forgets every remembered nonce.
    pub fn clear(&self) {
        let mut window = self.lock_window();
        window.seen.clear();
        window.order.clear();
    }

    /// Locks the cache state.
    ///
    /// A poisoned lock (a panic elsewhere while it was held) is recovered rather than propagated.
    /// The data is only a replay cache, so recovering keeps replay checks working instead of
    /// panicking on every later call.
    fn lock_window(&self) -> MutexGuard<'_, NonceWindow> {
        self.state.lock().unwrap_or_else(PoisonError::into_inner)
    }
}

impl Default for NonceCache {
    /// A cache whose window is twice [`DEFAULT_TIMESTAMP_TOLERANCE_MS`] (10 minutes).
    fn default() -> Self {
        Self::new(DEFAULT_TIMESTAMP_TOLERANCE_MS * 2)
    }
}

55pub struct Verifier {
57 pub my_did: Did,
58 pub timestamp_tolerance_ms: i64,
59 pub nonce_cache: NonceCache,
60}
61
62impl Verifier {
63 pub fn new(my_did: Did) -> Self {
64 Self {
65 my_did,
66 timestamp_tolerance_ms: DEFAULT_TIMESTAMP_TOLERANCE_MS,
67 nonce_cache: NonceCache::default(),
68 }
69 }
70
71 pub fn handle_hello(&self, hello: &Hello) -> Result<Challenge> {
73 self.verify_timestamp(hello.timestamp)?;
75
76 if hello.version != "1.0" {
78 return Err(HandshakeError::UnsupportedVersion(hello.version.clone()));
79 }
80
81 Ok(Challenge::new(self.my_did.to_string(), hello.did.clone()))
83 }
84
85 pub fn verify_proof(&self, proof: &Proof, original_challenge: &Challenge) -> Result<()> {
87 let expected_hash = hash_challenge(original_challenge)?;
89 if proof.challenge_hash != expected_hash {
90 return Err(HandshakeError::InvalidSignature);
91 }
92
93 if let Some(ref counter) = proof.counter_challenge {
95 self.verify_timestamp(counter.timestamp)?;
96
97 if counter.audience != self.my_did.to_string() {
99 return Err(HandshakeError::AudienceMismatch {
100 expected: self.my_did.to_string(),
101 got: counter.audience.clone(),
102 });
103 }
104
105 if !self.nonce_cache.check_and_insert(&counter.nonce) {
107 return Err(HandshakeError::NonceReplay);
108 }
109 }
110
111 let responder_did: Did = proof.responder_did.parse()?;
113
114 let public_key = responder_did.public_key()?;
117
118 let sig_bytes =
120 base64::Engine::decode(&base64::engine::general_purpose::STANDARD, &proof.signature)
121 .map_err(|_| HandshakeError::InvalidSignature)?;
122
123 let signature = ed25519_dalek::Signature::from_bytes(
124 &sig_bytes
125 .try_into()
126 .map_err(|_| HandshakeError::InvalidSignature)?,
127 );
128
129 agent_id_core::keys::verify(&public_key, proof.challenge_hash.as_bytes(), &signature)?;
130
131 Ok(())
132 }
133
134 pub fn accept_proof(&self, proof: &Proof, my_key: &RootKey) -> Result<ProofAccepted> {
136 let counter_challenge = proof
137 .counter_challenge
138 .as_ref()
139 .ok_or_else(|| HandshakeError::MissingField("counter_challenge".to_string()))?;
140
141 let counter_proof = sign_counter_proof(counter_challenge, my_key)?;
142
143 Ok(ProofAccepted {
144 type_: "ProofAccepted".to_string(),
145 version: "1.0".to_string(),
146 session_id: uuid::Uuid::now_v7().to_string(),
147 counter_proof,
148 session_expires_at: Utc::now().timestamp_millis() + DEFAULT_SESSION_DURATION_MS,
149 })
150 }
151
152 fn verify_timestamp(&self, timestamp: i64) -> Result<()> {
153 let now = Utc::now().timestamp_millis();
154 let diff = (now - timestamp).abs();
155
156 if diff > self.timestamp_tolerance_ms {
157 return Err(HandshakeError::TimestampOutOfRange);
158 }
159
160 Ok(())
161 }
162}
163
164pub fn hash_challenge(challenge: &Challenge) -> Result<String> {
166 let hash = signing::hash(challenge)?;
167 Ok(format!("sha256:{}", hex::encode(hash)))
168}
169
170pub fn hash_counter_challenge(counter: &CounterChallenge) -> Result<String> {
172 let hash = signing::hash(counter)?;
173 Ok(format!("sha256:{}", hex::encode(hash)))
174}
175
176pub fn sign_proof(
178 challenge: &Challenge,
179 my_did: &Did,
180 my_key: &RootKey,
181 counter_audience: Option<String>,
182) -> Result<Proof> {
183 let challenge_hash = hash_challenge(challenge)?;
184
185 let signature = my_key.sign(challenge_hash.as_bytes());
187 let sig_b64 = base64::Engine::encode(
188 &base64::engine::general_purpose::STANDARD,
189 signature.to_bytes(),
190 );
191
192 let mut proof = Proof::new(
193 challenge_hash,
194 my_did.to_string(),
195 format!("{}#root", my_did),
196 );
197 proof.signature = sig_b64;
198
199 if let Some(audience) = counter_audience {
201 proof = proof.with_counter_challenge(CounterChallenge::new(audience));
202 }
203
204 Ok(proof)
205}
206
207pub fn sign_counter_proof(counter: &CounterChallenge, my_key: &RootKey) -> Result<CounterProof> {
209 let challenge_hash = hash_counter_challenge(counter)?;
210
211 let signature = my_key.sign(challenge_hash.as_bytes());
212 let sig_b64 = base64::Engine::encode(
213 &base64::engine::general_purpose::STANDARD,
214 signature.to_bytes(),
215 );
216
217 let my_did = my_key.did();
218
219 Ok(CounterProof {
220 challenge_hash,
221 responder_did: my_did.to_string(),
222 signing_key: format!("{}#root", my_did),
223 signature: sig_b64,
224 })
225}
226
227pub fn verify_counter_proof(
229 counter_proof: &CounterProof,
230 original_counter_challenge: &CounterChallenge,
231) -> Result<()> {
232 let expected_hash = hash_counter_challenge(original_counter_challenge)?;
234 if counter_proof.challenge_hash != expected_hash {
235 return Err(HandshakeError::InvalidSignature);
236 }
237
238 let responder_did: Did = counter_proof.responder_did.parse()?;
240 let public_key = responder_did.public_key()?;
241
242 let sig_bytes = base64::Engine::decode(
243 &base64::engine::general_purpose::STANDARD,
244 &counter_proof.signature,
245 )
246 .map_err(|_| HandshakeError::InvalidSignature)?;
247
248 let signature = ed25519_dalek::Signature::from_bytes(
249 &sig_bytes
250 .try_into()
251 .map_err(|_| HandshakeError::InvalidSignature)?,
252 );
253
254 agent_id_core::keys::verify(
255 &public_key,
256 counter_proof.challenge_hash.as_bytes(),
257 &signature,
258 )?;
259
260 Ok(())
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates a fresh root key together with the DID derived from it.
    fn new_agent() -> (RootKey, Did) {
        let key = RootKey::generate();
        let did = key.did();
        (key, did)
    }

    #[test]
    fn test_full_handshake() {
        let (initiator_key, initiator_did) = new_agent();
        let (responder_key, responder_did) = new_agent();

        // Hello -> Challenge: the responder issues a challenge addressed to the initiator.
        let responder = Verifier::new(responder_did.clone());
        let challenge = responder
            .handle_hello(&Hello::new(initiator_did.to_string()))
            .unwrap();
        assert_eq!(challenge.issuer, responder_did.to_string());
        assert_eq!(challenge.audience, initiator_did.to_string());

        // Challenge -> Proof: the initiator signs and attaches a counter-challenge.
        let proof = sign_proof(
            &challenge,
            &initiator_did,
            &initiator_key,
            Some(responder_did.to_string()),
        )
        .unwrap();
        assert!(!proof.signature.is_empty());
        let counter_challenge = proof
            .counter_challenge
            .as_ref()
            .expect("a counter-challenge was requested");

        // Proof -> ProofAccepted: the responder verifies, then answers the counter-challenge.
        responder.verify_proof(&proof, &challenge).unwrap();
        let accepted = responder.accept_proof(&proof, &responder_key).unwrap();
        assert!(!accepted.session_id.is_empty());
        assert!(!accepted.counter_proof.signature.is_empty());

        // The initiator checks the responder's counter-proof.
        verify_counter_proof(&accepted.counter_proof, counter_challenge).unwrap();
    }

    #[test]
    fn test_replay_protection() {
        let (initiator_key, initiator_did) = new_agent();
        let (_, responder_did) = new_agent();

        let responder = Verifier::new(responder_did.clone());
        let challenge = responder
            .handle_hello(&Hello::new(initiator_did.to_string()))
            .unwrap();
        let proof = sign_proof(
            &challenge,
            &initiator_did,
            &initiator_key,
            Some(responder_did.to_string()),
        )
        .unwrap();

        responder.verify_proof(&proof, &challenge).unwrap();

        // Presenting the very same proof again must trip the nonce cache.
        let replayed = responder.verify_proof(&proof, &challenge);
        assert!(matches!(replayed, Err(HandshakeError::NonceReplay)));
    }

    #[test]
    fn test_nonce_cache() {
        let cache = NonceCache::default();

        for nonce in ["nonce1", "nonce2"] {
            assert!(cache.check_and_insert(nonce), "first sighting of {nonce} must be accepted");
        }
        for nonce in ["nonce1", "nonce2"] {
            assert!(!cache.check_and_insert(nonce), "{nonce} must be rejected as a replay");
        }

        assert!(cache.check_and_insert("nonce3"));
    }
}