1use crate::{
2 device_pubkey_from_secret_bytes, kdf, random_secret_key_bytes, secret_key_from_bytes,
3 DevicePubkey, DomainError, ProtocolContext, Result, UnixSeconds, MAX_SKIP,
4};
5use base64::Engine;
6use nostr::nips::nip44::{self, Version};
7use nostr::PublicKey;
8use rand::rngs::OsRng;
9use rand::{CryptoRng, RngCore};
10use serde::{Deserialize, Serialize};
11use std::collections::BTreeMap;
12
/// Per-message ratchet header. It is serialized to JSON and then NIP-44 encrypted
/// (see `Session::plan_send`) before being placed in `MessageEnvelope::encrypted_header`.
///
/// The JSON field names are camelCase (`number`, `previousChainLength`, `nextPublicKey`).
/// Input using the snake_case names fails to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct Header {
    /// Index of this message within the sender's current sending chain.
    pub number: u32,
    /// Number of messages the sender sent on its previous sending chain.
    pub previous_chain_length: u32,
    /// The sender's next ratchet public key (`our_next_nostr_key` on the sending side).
    pub next_public_key: DevicePubkey,
}
20
/// A ratchet key pair stored in a serde-friendly form.
///
/// NOTE(review): `Debug` is derived, so formatting this value prints the private key.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SerializableKeyPair {
    /// Public half of the pair.
    pub public_key: DevicePubkey,
    /// Raw 32-byte secret key, serialized as a lowercase hex string.
    #[serde(with = "serde_bytes_array")]
    pub private_key: [u8; 32],
}
27
/// Message keys stashed for messages that have not arrived yet from one sender key.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub struct SkippedKeysEntry {
    /// Message number -> message key. Serialized as a map from decimal-string keys to
    /// hex-string values.
    #[serde(with = "serde_btreemap_u32_bytes")]
    pub message_keys: BTreeMap<u32, [u8; 32]>,
}
33
/// Complete, serializable state of one double-ratchet session.
///
/// NOTE(review): `Debug` is derived, so formatting this value prints the root key, chain
/// keys and private keys.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SessionState {
    /// Current root key (hex-encoded). Replaced on every ratchet step.
    #[serde(with = "serde_bytes_array")]
    pub root_key: [u8; 32],
    /// Peer key that preceded `their_next_nostr_public_key`; still accepted as a sender.
    pub their_current_nostr_public_key: Option<DevicePubkey>,
    /// Peer's most recently announced key. Outgoing headers are encrypted to it, and ratchet
    /// steps derive from it.
    pub their_next_nostr_public_key: Option<DevicePubkey>,
    /// Our key from before the last ratchet step, kept so headers still encrypted to it can be
    /// decrypted. Omitted from output when `None`; missing input deserializes to `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub our_previous_nostr_key: Option<SerializableKeyPair>,
    /// Our key used to sign and header-encrypt outgoing messages. It is `None` for a responder
    /// until its first ratchet step, and sending is refused while it is `None`.
    pub our_current_nostr_key: Option<SerializableKeyPair>,
    /// Our next key, advertised in outgoing headers as `next_public_key`.
    pub our_next_nostr_key: SerializableKeyPair,
    /// Receiving chain key (hex); `None` until a ratchet step establishes it.
    #[serde(default, with = "serde_option_bytes_array")]
    pub receiving_chain_key: Option<[u8; 32]>,
    /// Sending chain key (hex); `None` until established.
    #[serde(default, with = "serde_option_bytes_array")]
    pub sending_chain_key: Option<[u8; 32]>,
    /// Number of messages sent on the current sending chain.
    pub sending_chain_message_number: u32,
    /// Number of messages consumed on the current receiving chain.
    pub receiving_chain_message_number: u32,
    /// Length of our previous sending chain, reported to the peer in headers.
    pub previous_sending_chain_message_count: u32,
    /// Stashed message keys for messages not yet received, keyed by the sender key they are
    /// expected from.
    pub skipped_keys: BTreeMap<DevicePubkey, SkippedKeysEntry>,
}
53
/// One encrypted ratchet message as produced by `Session::plan_send` and consumed by
/// `Session::plan_receive`.
///
/// NOTE(review): unlike the key fields of `SessionState`, `signer_secret_key` has no
/// `serde(with = ...)` adapter, so it uses serde's default array encoding instead of hex.
/// Normalizing it would change the serialized format; confirm no persisted envelopes depend
/// on it before doing so.
/// NOTE(review): `Debug` is derived, so formatting this value prints `signer_secret_key`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct MessageEnvelope {
    /// Sender's current ratchet public key.
    pub sender: DevicePubkey,
    /// Secret key matching `sender`. Presumably used by the caller to sign the outer
    /// event — TODO confirm.
    pub signer_secret_key: [u8; 32],
    /// Creation time supplied by the caller of `plan_send`.
    pub created_at: UnixSeconds,
    /// NIP-44 encrypted JSON `Header`.
    pub encrypted_header: String,
    /// Base64 (standard alphabet) NIP-44 v2 ciphertext of the payload.
    pub ciphertext: String,
}
62
63#[derive(Debug, Clone)]
64pub struct SendPlan {
65 pub next_state: SessionState,
66 pub envelope: MessageEnvelope,
67 pub payload: Vec<u8>,
68}
69
70#[derive(Debug, Clone)]
71pub struct SendOutcome {
72 pub envelope: MessageEnvelope,
73 pub payload: Vec<u8>,
74}
75
76#[derive(Debug, Clone)]
77pub struct ReceivePlan {
78 pub next_state: SessionState,
79 pub payload: Vec<u8>,
80 pub sender: DevicePubkey,
81}
82
83#[derive(Debug, Clone)]
84pub struct ReceiveOutcome {
85 pub payload: Vec<u8>,
86 pub sender: DevicePubkey,
87}
88
89#[derive(Debug, Clone)]
90pub struct Session {
91 pub state: SessionState,
92 pub name: String,
93}
94
/// Which of our keys successfully decrypted an incoming header.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum HeaderDecryptionTarget {
    /// `our_current_nostr_key`.
    Current,
    /// `our_next_nostr_key`; receiving on it triggers a ratchet step.
    Next,
    /// `our_previous_nostr_key`.
    Previous,
}
101
102impl Session {
103 pub fn from_state(state: SessionState) -> Self {
104 Self {
105 state,
106 name: String::new(),
107 }
108 }
109
110 pub fn new(state: SessionState, name: String) -> Self {
111 Self { state, name }
112 }
113
114 pub fn init(
115 their_ephemeral_nostr_public_key: PublicKey,
116 our_ephemeral_nostr_private_key: [u8; 32],
117 is_initiator: bool,
118 shared_secret: [u8; 32],
119 _name: Option<String>,
120 ) -> Result<Self> {
121 let mut rng = OsRng;
122 let mut ctx = ProtocolContext::new(
123 UnixSeconds(
124 std::time::SystemTime::now()
125 .duration_since(std::time::UNIX_EPOCH)
126 .unwrap()
127 .as_secs(),
128 ),
129 &mut rng,
130 );
131 let peer = DevicePubkey::from_bytes(their_ephemeral_nostr_public_key.to_bytes());
132 let mut session = if is_initiator {
133 Self::new_initiator(
134 &mut ctx,
135 peer,
136 our_ephemeral_nostr_private_key,
137 shared_secret,
138 )
139 } else {
140 Self::new_responder(
141 &mut ctx,
142 peer,
143 our_ephemeral_nostr_private_key,
144 shared_secret,
145 )
146 }?;
147 session.name = _name.unwrap_or_default();
148 Ok(session)
149 }
150
151 pub fn new_initiator<R>(
152 ctx: &mut ProtocolContext<'_, R>,
153 their_ephemeral_public_key: DevicePubkey,
154 our_ephemeral_private_key: [u8; 32],
155 shared_secret: [u8; 32],
156 ) -> Result<Self>
157 where
158 R: RngCore + CryptoRng,
159 {
160 Self::init_with_context(
161 ctx,
162 their_ephemeral_public_key,
163 our_ephemeral_private_key,
164 true,
165 shared_secret,
166 )
167 }
168
169 pub fn new_responder<R>(
170 ctx: &mut ProtocolContext<'_, R>,
171 their_ephemeral_public_key: DevicePubkey,
172 our_ephemeral_private_key: [u8; 32],
173 shared_secret: [u8; 32],
174 ) -> Result<Self>
175 where
176 R: RngCore + CryptoRng,
177 {
178 Self::init_with_context(
179 ctx,
180 their_ephemeral_public_key,
181 our_ephemeral_private_key,
182 false,
183 shared_secret,
184 )
185 }
186
187 fn init_with_context<R>(
188 ctx: &mut ProtocolContext<'_, R>,
189 their_ephemeral_public_key: DevicePubkey,
190 our_ephemeral_private_key: [u8; 32],
191 is_initiator: bool,
192 shared_secret: [u8; 32],
193 ) -> Result<Self>
194 where
195 R: RngCore + CryptoRng,
196 {
197 let our_keys = nostr::Keys::new(secret_key_from_bytes(&our_ephemeral_private_key)?);
198 let our_next_private_key = random_secret_key_bytes(ctx.rng)?;
199 let our_next_keys = nostr::Keys::new(secret_key_from_bytes(&our_next_private_key)?);
200
201 let (root_key, sending_chain_key, our_current_nostr_key, our_next_nostr_key) =
202 if is_initiator {
203 let our_current_pubkey = DevicePubkey::from_nostr(our_keys.public_key());
204 let conversation_key = nip44::v2::ConversationKey::derive(
205 our_next_keys.secret_key(),
206 &their_ephemeral_public_key.to_nostr()?,
207 )?;
208 let kdf_outputs = kdf(&shared_secret, conversation_key.as_bytes(), 2);
209 (
210 kdf_outputs[0],
211 Some(kdf_outputs[1]),
212 Some(SerializableKeyPair {
213 public_key: our_current_pubkey,
214 private_key: our_ephemeral_private_key,
215 }),
216 SerializableKeyPair {
217 public_key: DevicePubkey::from_nostr(our_next_keys.public_key()),
218 private_key: our_next_private_key,
219 },
220 )
221 } else {
222 (
223 shared_secret,
224 None,
225 None,
226 SerializableKeyPair {
227 public_key: DevicePubkey::from_nostr(our_keys.public_key()),
228 private_key: our_ephemeral_private_key,
229 },
230 )
231 };
232
233 Ok(Self {
234 state: SessionState {
235 root_key,
236 their_current_nostr_public_key: None,
237 their_next_nostr_public_key: Some(their_ephemeral_public_key),
238 our_previous_nostr_key: None,
239 our_current_nostr_key,
240 our_next_nostr_key,
241 receiving_chain_key: None,
242 sending_chain_key,
243 sending_chain_message_number: 0,
244 receiving_chain_message_number: 0,
245 previous_sending_chain_message_count: 0,
246 skipped_keys: BTreeMap::new(),
247 },
248 name: String::new(),
249 })
250 }
251
252 pub fn can_send(&self) -> bool {
253 self.state.their_next_nostr_public_key.is_some()
254 && self.state.our_current_nostr_key.is_some()
255 }
256
257 pub fn matches_sender(&self, sender: DevicePubkey) -> bool {
258 self.state.their_current_nostr_public_key == Some(sender)
259 || self.state.their_next_nostr_public_key == Some(sender)
260 || self.state.skipped_keys.contains_key(&sender)
261 }
262
263 pub fn plan_send(&self, payload: &[u8], now: UnixSeconds) -> Result<SendPlan> {
264 if !self.can_send() {
265 return Err(DomainError::CannotSendYet.into());
266 }
267
268 let mut next_state = self.state.clone();
269 let (header, ciphertext) = ratchet_encrypt(&mut next_state, payload)?;
270 let our_current = self
271 .state
272 .our_current_nostr_key
273 .as_ref()
274 .ok_or(DomainError::SessionNotReady)?;
275 let our_secret = secret_key_from_bytes(&our_current.private_key)?;
276 let their_next = self
277 .state
278 .their_next_nostr_public_key
279 .ok_or(DomainError::SessionNotReady)?;
280 let encrypted_header = nip44::encrypt(
281 &our_secret,
282 &their_next.to_nostr()?,
283 &serde_json::to_string(&header)?,
284 Version::V2,
285 )?;
286
287 Ok(SendPlan {
288 next_state,
289 envelope: MessageEnvelope {
290 sender: our_current.public_key,
291 signer_secret_key: our_current.private_key,
292 created_at: now,
293 encrypted_header,
294 ciphertext,
295 },
296 payload: payload.to_vec(),
297 })
298 }
299
300 pub fn apply_send(&mut self, plan: SendPlan) -> SendOutcome {
301 self.state = plan.next_state;
302 SendOutcome {
303 envelope: plan.envelope,
304 payload: plan.payload,
305 }
306 }
307
308 pub fn plan_receive<R>(
309 &self,
310 ctx: &mut ProtocolContext<'_, R>,
311 envelope: &MessageEnvelope,
312 ) -> Result<ReceivePlan>
313 where
314 R: RngCore + CryptoRng,
315 {
316 if !self.matches_sender(envelope.sender) {
317 return Err(DomainError::UnexpectedSender.into());
318 }
319
320 let mut next_state = self.state.clone();
321 let previous_chain_sender = next_state
322 .their_current_nostr_public_key
323 .or(next_state.their_next_nostr_public_key);
324 let (header, decryption_target) =
325 decrypt_header(&next_state, &envelope.encrypted_header, envelope.sender)?;
326 let should_ratchet = decryption_target == HeaderDecryptionTarget::Next;
327
328 let expected_next = next_state.their_next_nostr_public_key;
329 if should_ratchet && expected_next != Some(header.next_public_key) {
330 next_state.their_current_nostr_public_key = next_state.their_next_nostr_public_key;
331 next_state.their_next_nostr_public_key = Some(header.next_public_key);
332 }
333
334 if should_ratchet {
335 if next_state.receiving_chain_key.is_some() {
336 let skipped_sender = previous_chain_sender.ok_or(DomainError::SessionNotReady)?;
337 skip_message_keys(
338 &mut next_state,
339 header.previous_chain_length,
340 skipped_sender,
341 )?;
342 }
343 ratchet_step(&mut next_state, ctx.rng)?;
344 }
345
346 let payload = ratchet_decrypt(
347 &mut next_state,
348 &header,
349 &envelope.ciphertext,
350 envelope.sender,
351 )?;
352
353 Ok(ReceivePlan {
354 next_state,
355 payload,
356 sender: envelope.sender,
357 })
358 }
359
360 pub fn apply_receive(&mut self, plan: ReceivePlan) -> ReceiveOutcome {
361 self.state = plan.next_state;
362 ReceiveOutcome {
363 payload: plan.payload,
364 sender: plan.sender,
365 }
366 }
367
368 pub fn close(&self) {}
369}
370
371fn ratchet_encrypt(state: &mut SessionState, plaintext: &[u8]) -> Result<(Header, String)> {
372 let sending_chain_key = state
373 .sending_chain_key
374 .ok_or(DomainError::SessionNotReady)?;
375
376 let kdf_outputs = kdf(&sending_chain_key, &[1u8], 2);
377 state.sending_chain_key = Some(kdf_outputs[0]);
378 let message_key = kdf_outputs[1];
379
380 let header = Header {
381 number: state.sending_chain_message_number,
382 next_public_key: state.our_next_nostr_key.public_key,
383 previous_chain_length: state.previous_sending_chain_message_count,
384 };
385
386 state.sending_chain_message_number += 1;
387
388 let conversation_key = nip44::v2::ConversationKey::new(message_key);
389 let encrypted_bytes = nip44::v2::encrypt_to_bytes(&conversation_key, plaintext)?;
390 let ciphertext = base64::engine::general_purpose::STANDARD.encode(encrypted_bytes);
391 Ok((header, ciphertext))
392}
393
394fn ratchet_decrypt(
395 state: &mut SessionState,
396 header: &Header,
397 ciphertext: &str,
398 sender: DevicePubkey,
399) -> Result<Vec<u8>> {
400 if let Some(plaintext) = try_skipped_message_keys(state, header, ciphertext, sender)? {
401 return Ok(plaintext);
402 }
403
404 if state.receiving_chain_key.is_none() {
405 return Err(DomainError::SessionNotReady.into());
406 }
407
408 skip_message_keys(state, header.number, sender)?;
409
410 let receiving_chain_key = state
411 .receiving_chain_key
412 .ok_or(DomainError::SessionNotReady)?;
413
414 let kdf_outputs = kdf(&receiving_chain_key, &[1u8], 2);
415 state.receiving_chain_key = Some(kdf_outputs[0]);
416 let message_key = kdf_outputs[1];
417 state.receiving_chain_message_number += 1;
418
419 let conversation_key = nip44::v2::ConversationKey::new(message_key);
420 let ciphertext_bytes = base64::engine::general_purpose::STANDARD
421 .decode(ciphertext)
422 .map_err(|e| crate::Error::Decryption(e.to_string()))?;
423
424 nip44::v2::decrypt_to_bytes(&conversation_key, &ciphertext_bytes).map_err(Into::into)
425}
426
/// Performs a Diffie-Hellman ratchet step. `Session::plan_receive` calls it when a header
/// decrypts with our next key.
///
/// Steps, in order:
/// 1. Record the length of the sending chain being closed and reset both message counters.
/// 2. Derive a conversation key from our *next* secret key and their *next* public key.
///    `kdf` over the root key gives an intermediate root (output 0) and the new receiving
///    chain key (output 1).
/// 3. Rotate our keys: current -> previous, next -> current, and a fresh key from `rng`
///    becomes next.
/// 4. Derive a conversation key from the fresh next key and their next public key.
///    `kdf` over the intermediate root gives the new root key (output 0) and the new
///    sending chain key (output 1).
///
/// # Errors
/// `SessionNotReady` if their next public key is unknown; key parsing and derivation errors
/// are propagated. On error `state` may be partially updated, so callers must only keep it
/// on success. `Session::plan_receive` guarantees this by working on a clone.
fn ratchet_step<R>(state: &mut SessionState, rng: &mut R) -> Result<()>
where
    R: RngCore + CryptoRng,
{
    // Step 1: close the current sending chain.
    state.previous_sending_chain_message_count = state.sending_chain_message_number;
    state.sending_chain_message_number = 0;
    state.receiving_chain_message_number = 0;

    let our_next_sk = secret_key_from_bytes(&state.our_next_nostr_key.private_key)?;
    let their_next_pk = state
        .their_next_nostr_public_key
        .ok_or(DomainError::SessionNotReady)?;

    // Step 2: new receiving chain. kdf_outputs[0] is kept as the intermediate root for step 4.
    let conversation_key1 =
        nip44::v2::ConversationKey::derive(&our_next_sk, &their_next_pk.to_nostr()?)?;
    let kdf_outputs = kdf(&state.root_key, conversation_key1.as_bytes(), 2);
    state.receiving_chain_key = Some(kdf_outputs[1]);
    // Step 3: rotate our keys.
    state.our_previous_nostr_key = state.our_current_nostr_key.clone();
    state.our_current_nostr_key = Some(state.our_next_nostr_key.clone());

    let our_next_private_key = random_secret_key_bytes(rng)?;
    state.our_next_nostr_key = SerializableKeyPair {
        public_key: device_pubkey_from_secret_bytes(&our_next_private_key)?,
        private_key: our_next_private_key,
    };

    // Step 4: new root key and sending chain from the freshly generated next key.
    let our_next_sk2 = secret_key_from_bytes(&our_next_private_key)?;
    let conversation_key2 =
        nip44::v2::ConversationKey::derive(&our_next_sk2, &their_next_pk.to_nostr()?)?;
    let kdf_outputs2 = kdf(&kdf_outputs[0], conversation_key2.as_bytes(), 2);
    state.root_key = kdf_outputs2[0];
    state.sending_chain_key = Some(kdf_outputs2[1]);
    Ok(())
}
461
462fn skip_message_keys(state: &mut SessionState, until: u32, sender: DevicePubkey) -> Result<()> {
463 if until <= state.receiving_chain_message_number {
464 return Ok(());
465 }
466
467 if (until - state.receiving_chain_message_number) as usize > MAX_SKIP {
468 return Err(DomainError::TooManySkippedMessages.into());
469 }
470
471 let entry = state.skipped_keys.entry(sender).or_default();
472
473 while state.receiving_chain_message_number < until {
474 let receiving_chain_key = state
475 .receiving_chain_key
476 .ok_or(DomainError::SessionNotReady)?;
477 let kdf_outputs = kdf(&receiving_chain_key, &[1u8], 2);
478 state.receiving_chain_key = Some(kdf_outputs[0]);
479 entry
480 .message_keys
481 .insert(state.receiving_chain_message_number, kdf_outputs[1]);
482 state.receiving_chain_message_number += 1;
483 }
484
485 prune_skipped_message_keys(&mut entry.message_keys);
486 Ok(())
487}
488
489fn try_skipped_message_keys(
490 state: &mut SessionState,
491 header: &Header,
492 ciphertext: &str,
493 sender: DevicePubkey,
494) -> Result<Option<Vec<u8>>> {
495 if let Some(entry) = state.skipped_keys.get_mut(&sender) {
496 if let Some(message_key) = entry.message_keys.remove(&header.number) {
497 let conversation_key = nip44::v2::ConversationKey::new(message_key);
498 let ciphertext_bytes = base64::engine::general_purpose::STANDARD
499 .decode(ciphertext)
500 .map_err(|e| crate::Error::Decryption(e.to_string()))?;
501 let plaintext = nip44::v2::decrypt_to_bytes(&conversation_key, &ciphertext_bytes)?;
502 if entry.message_keys.is_empty() {
503 state.skipped_keys.remove(&sender);
504 }
505 return Ok(Some(plaintext));
506 }
507 }
508
509 Ok(None)
510}
511
512fn decrypt_header(
513 state: &SessionState,
514 encrypted_header: &str,
515 sender: DevicePubkey,
516) -> Result<(Header, HeaderDecryptionTarget)> {
517 if let Some(current) = &state.our_current_nostr_key {
518 let current_sk = secret_key_from_bytes(¤t.private_key)?;
519 if let Ok(decrypted) = nip44::decrypt(¤t_sk, &sender.to_nostr()?, encrypted_header) {
520 let header: Header = serde_json::from_str(&decrypted)?;
521 return Ok((header, HeaderDecryptionTarget::Current));
522 }
523 }
524
525 let next_sk = secret_key_from_bytes(&state.our_next_nostr_key.private_key)?;
526 if let Ok(decrypted) = nip44::decrypt(&next_sk, &sender.to_nostr()?, encrypted_header) {
527 let header: Header = serde_json::from_str(&decrypted)?;
528 return Ok((header, HeaderDecryptionTarget::Next));
529 }
530
531 if let Some(previous) = &state.our_previous_nostr_key {
532 let previous_sk = secret_key_from_bytes(&previous.private_key)?;
533 if let Ok(decrypted) = nip44::decrypt(&previous_sk, &sender.to_nostr()?, encrypted_header) {
534 let header: Header = serde_json::from_str(&decrypted)?;
535 return Ok((header, HeaderDecryptionTarget::Previous));
536 }
537 }
538
539 Err(crate::Error::Parse("invalid header".to_string()))
540}
541
542fn prune_skipped_message_keys(map: &mut BTreeMap<u32, [u8; 32]>) {
543 while map.len() > MAX_SKIP {
544 let Some(first) = map.keys().next().copied() else {
545 break;
546 };
547 map.remove(&first);
548 }
549}
550
551mod serde_bytes_array {
552 use serde::{Deserialize, Deserializer, Serializer};
553
554 pub fn serialize<S>(bytes: &[u8; 32], serializer: S) -> Result<S::Ok, S::Error>
555 where
556 S: Serializer,
557 {
558 serializer.serialize_str(&hex::encode(bytes))
559 }
560
561 pub fn deserialize<'de, D>(deserializer: D) -> Result<[u8; 32], D::Error>
562 where
563 D: Deserializer<'de>,
564 {
565 let s = String::deserialize(deserializer)?;
566 super::decode_hex_32(&s).map_err(serde::de::Error::custom)
567 }
568}
569
570mod serde_option_bytes_array {
571 use serde::{Deserialize, Deserializer, Serializer};
572
573 pub fn serialize<S>(bytes: &Option<[u8; 32]>, serializer: S) -> Result<S::Ok, S::Error>
574 where
575 S: Serializer,
576 {
577 match bytes {
578 Some(b) => serializer.serialize_str(&hex::encode(b)),
579 None => serializer.serialize_none(),
580 }
581 }
582
583 pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<[u8; 32]>, D::Error>
584 where
585 D: Deserializer<'de>,
586 {
587 let opt: Option<String> = Option::deserialize(deserializer)?;
588 match opt {
589 Some(s) => super::decode_hex_32(&s)
590 .map(Some)
591 .map_err(serde::de::Error::custom),
592 None => Ok(None),
593 }
594 }
595}
596
597mod serde_btreemap_u32_bytes {
598 use serde::{Deserialize, Deserializer, Serialize, Serializer};
599 use std::collections::BTreeMap;
600
601 pub fn serialize<S>(map: &BTreeMap<u32, [u8; 32]>, serializer: S) -> Result<S::Ok, S::Error>
602 where
603 S: Serializer,
604 {
605 let string_map: BTreeMap<String, String> = map
606 .iter()
607 .map(|(k, v)| (k.to_string(), hex::encode(v)))
608 .collect();
609 string_map.serialize(serializer)
610 }
611
612 pub fn deserialize<'de, D>(deserializer: D) -> Result<BTreeMap<u32, [u8; 32]>, D::Error>
613 where
614 D: Deserializer<'de>,
615 {
616 let string_map: BTreeMap<String, String> = BTreeMap::deserialize(deserializer)?;
617 let mut out = BTreeMap::new();
618 for (k, v) in string_map {
619 let idx: u32 = k.parse().map_err(serde::de::Error::custom)?;
620 out.insert(
621 idx,
622 super::decode_hex_32(&v).map_err(serde::de::Error::custom)?,
623 );
624 }
625 Ok(out)
626 }
627}
628
629fn decode_hex_32(value: &str) -> std::result::Result<[u8; 32], String> {
630 let bytes = hex::decode(value).map_err(|e| e.to_string())?;
631 <[u8; 32]>::try_from(bytes.as_slice()).map_err(|_| "invalid 32-byte hex".to_string())
632}
633
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{rngs::StdRng, SeedableRng};

    /// Deterministic protocol context seeded with `seed`.
    ///
    /// The RNG is leaked so the context can hold a `'static` borrow, which is acceptable in
    /// short-lived test processes.
    fn context(seed: u64) -> ProtocolContext<'static, StdRng> {
        let rng: &'static mut StdRng = Box::leak(Box::new(StdRng::seed_from_u64(seed)));
        ProtocolContext::new(UnixSeconds(1_700_000_000), rng)
    }

    /// Builds an (initiator "alice", responder "bob") pair sharing `shared_secret`.
    ///
    /// Each side is initialised from its own seeded context.
    fn session_pair(
        alice_secret: [u8; 32],
        bob_secret: [u8; 32],
        shared_secret: [u8; 32],
        alice_seed: u64,
        bob_seed: u64,
    ) -> (Session, Session) {
        let alice_pub = device_pubkey_from_secret_bytes(&alice_secret).unwrap();
        let bob_pub = device_pubkey_from_secret_bytes(&bob_secret).unwrap();
        let alice =
            Session::new_initiator(&mut context(alice_seed), bob_pub, alice_secret, shared_secret)
                .unwrap();
        let bob =
            Session::new_responder(&mut context(bob_seed), alice_pub, bob_secret, shared_secret)
                .unwrap();
        (alice, bob)
    }

    #[test]
    fn header_json_uses_camel_case_wire_fields() {
        let header = Header {
            number: 3,
            previous_chain_length: 2,
            next_public_key: DevicePubkey::from_bytes([9u8; 32]),
        };

        let json = serde_json::to_value(&header).unwrap();
        let expected_next_key = serde_json::json!(header.next_public_key.to_string());
        assert_eq!(json["number"], serde_json::json!(3));
        assert_eq!(json["previousChainLength"], serde_json::json!(2));
        assert_eq!(json["nextPublicKey"], expected_next_key);
        for legacy_field in ["previous_chain_length", "next_public_key"] {
            assert!(json.get(legacy_field).is_none());
        }

        let round_tripped: Header = serde_json::from_value(json).unwrap();
        assert_eq!(round_tripped, header);
    }

    #[test]
    fn header_json_rejects_snake_case_wire_fields() {
        let legacy = serde_json::json!({
            "number": 3,
            "previous_chain_length": 2,
            "next_public_key": DevicePubkey::from_bytes([9u8; 32]).to_string(),
        });

        assert!(serde_json::from_value::<Header>(legacy).is_err());
    }

    #[test]
    fn plan_send_and_apply_receive_roundtrip() {
        let (alice, mut bob) = session_pair([1u8; 32], [2u8; 32], [7u8; 32], 1, 2);

        let payload = b"hello".to_vec();
        let send_plan = alice
            .plan_send(&payload, UnixSeconds(1_700_000_010))
            .unwrap();
        let sent = alice.clone().apply_send(send_plan);

        let mut recv_ctx = context(10);
        let receive_plan = bob.plan_receive(&mut recv_ctx, &sent.envelope).unwrap();
        let received = bob.apply_receive(receive_plan);
        assert_eq!(received.payload, payload);
    }

    #[test]
    fn plan_receive_does_not_mutate_original_session() {
        let (alice, bob) = session_pair([3u8; 32], [4u8; 32], [8u8; 32], 3, 4);
        let snapshot = bob.state.clone();

        let payload = b"typing".to_vec();
        let send_plan = alice
            .plan_send(&payload, UnixSeconds(1_700_000_011))
            .unwrap();

        let mut recv_ctx = context(13);
        bob.plan_receive(&mut recv_ctx, &send_plan.envelope)
            .unwrap();

        assert_eq!(bob.state, snapshot);
    }

    #[test]
    fn duplicate_receive_fails_without_corrupting_state() {
        let (alice, mut bob) = session_pair([5u8; 32], [6u8; 32], [9u8; 32], 5, 6);

        let payload = b"hello".to_vec();
        let send_plan = alice
            .plan_send(&payload, UnixSeconds(1_700_000_012))
            .unwrap();
        let envelope = alice.clone().apply_send(send_plan).envelope;

        let mut recv_ctx = context(15);
        let first = bob.plan_receive(&mut recv_ctx, &envelope).unwrap();
        bob.apply_receive(first);
        let committed = bob.state.clone();

        let mut replay_ctx = context(16);
        let replay = bob.plan_receive(&mut replay_ctx, &envelope);
        assert!(replay.is_err());
        assert_eq!(bob.state, committed);
    }

    #[test]
    fn invalid_sender_is_rejected() {
        let alice_secret = [7u8; 32];
        let bob_secret = [8u8; 32];
        let alice_pub = device_pubkey_from_secret_bytes(&alice_secret).unwrap();

        let bob = Session::new_responder(&mut context(7), alice_pub, bob_secret, [10u8; 32])
            .unwrap();

        // Bob's own key is not a sender his session accepts.
        let forged = MessageEnvelope {
            sender: device_pubkey_from_secret_bytes(&bob_secret).unwrap(),
            signer_secret_key: bob_secret,
            created_at: UnixSeconds(1),
            encrypted_header: "bad".to_string(),
            ciphertext: "bad".to_string(),
        };

        let mut recv_ctx = context(17);
        let err = bob.plan_receive(&mut recv_ctx, &forged).unwrap_err();
        assert!(matches!(
            err,
            crate::Error::Domain(DomainError::UnexpectedSender)
        ));
    }
}