pake_cpace/
lib.rs

1#![no_std]
2#![forbid(unsafe_code)]
3
4use core::fmt;
5use curve25519_dalek::{
6    ristretto::{CompressedRistretto, RistrettoPoint},
7    scalar::Scalar,
8    traits::IsIdentity,
9};
10use getrandom::getrandom;
11use hmac_sha512::{Hash, BLOCKBYTES};
12
/// Length in bytes of the random session identifier chosen by the initiator.
pub const SESSION_ID_BYTES: usize = 16;
/// Length in bytes of a compressed Ristretto255 point as carried in the packets.
const POINT_BYTES: usize = 32;
/// Length of the initiator's first message: session id followed by its public point.
pub const STEP1_PACKET_BYTES: usize = SESSION_ID_BYTES + POINT_BYTES;
/// Length of the responder's reply: its public point only.
pub const STEP2_PACKET_BYTES: usize = POINT_BYTES;
/// Length in bytes of each of the two derived keys (`k1` and `k2`).
pub const SHARED_KEY_BYTES: usize = 32;

/// Domain-separation prefix hashed in front of the password when deriving the generator.
const DSI1: &str = "CPaceRistretto255-1";
/// Domain-separation prefix hashed in front of the transcript when deriving the shared keys.
const DSI2: &str = "CPaceRistretto255-2";
20
21#[derive(Debug)]
22pub enum Error {
23    Overflow(&'static str),
24    Random(getrandom::Error),
25    InvalidPublicKey,
26}
27
28impl fmt::Display for Error {
29    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
30        write!(f, "{:?}", &self)
31    }
32}
33
34impl From<getrandom::Error> for Error {
35    fn from(e: getrandom::Error) -> Self {
36        Error::Random(e)
37    }
38}
39
40#[derive(Debug, Copy, Clone)]
41pub struct SharedKeys {
42    pub k1: [u8; SHARED_KEY_BYTES],
43    pub k2: [u8; SHARED_KEY_BYTES],
44}
45
46#[derive(Debug, Clone)]
47pub struct CPace {
48    session_id: [u8; SESSION_ID_BYTES],
49    p: RistrettoPoint,
50    r: Scalar,
51}
52
53pub struct Step1Out {
54    ctx: CPace,
55    step1_packet: [u8; STEP1_PACKET_BYTES],
56}
57
58impl Step1Out {
59    pub fn packet(&self) -> [u8; STEP1_PACKET_BYTES] {
60        self.step1_packet
61    }
62
63    pub fn step3(&self, step2_packet: &[u8; STEP2_PACKET_BYTES]) -> Result<SharedKeys, Error> {
64        self.ctx.step3(step2_packet)
65    }
66}
67
68pub struct Step2Out {
69    shared_keys: SharedKeys,
70    step2_packet: [u8; STEP2_PACKET_BYTES],
71}
72
73impl Step2Out {
74    pub fn shared_keys(&self) -> SharedKeys {
75        self.shared_keys
76    }
77
78    pub fn packet(&self) -> [u8; STEP2_PACKET_BYTES] {
79        self.step2_packet
80    }
81}
82
83impl CPace {
84    fn new<T: AsRef<[u8]>>(
85        session_id: [u8; SESSION_ID_BYTES],
86        password: &str,
87        id_a: &str,
88        id_b: &str,
89        ad: Option<T>,
90    ) -> Result<Self, Error> {
91        if id_a.len() > 0xff || id_b.len() > 0xff {
92            return Err(Error::Overflow(
93                "Identifiers must be at most 255 bytes long",
94            ));
95        }
96        let zpad = [0u8; BLOCKBYTES];
97        let pad_len = zpad.len().wrapping_sub(DSI1.len() + password.len()) & (zpad.len() - 1);
98        let mut st = Hash::new();
99        st.update(DSI1);
100        st.update(password);
101        st.update(&zpad[..pad_len]);
102        st.update(session_id);
103        st.update([id_a.len() as u8]);
104        st.update(id_a);
105        st.update([id_b.len() as u8]);
106        st.update(id_b);
107        st.update(ad.as_ref().map(|ad| ad.as_ref()).unwrap_or_default());
108        let h = st.finalize();
109        let mut p = RistrettoPoint::from_uniform_bytes(&h);
110        let mut r = [0u8; 64];
111        getrandom(&mut r)?;
112        let r = Scalar::from_bytes_mod_order_wide(&r);
113        p *= r;
114        Ok(CPace { session_id, p, r })
115    }
116
117    fn finalize(
118        &self,
119        op: RistrettoPoint,
120        ya: RistrettoPoint,
121        yb: RistrettoPoint,
122    ) -> Result<SharedKeys, Error> {
123        if op.is_identity() {
124            return Err(Error::InvalidPublicKey);
125        }
126        let p = op * self.r;
127        let mut st = Hash::new();
128        st.update(DSI2);
129        st.update(p.compress().as_bytes());
130        st.update(ya.compress().as_bytes());
131        st.update(yb.compress().as_bytes());
132        let h = st.finalize();
133        let (mut k1, mut k2) = ([0u8; SHARED_KEY_BYTES], [0u8; SHARED_KEY_BYTES]);
134        k1.copy_from_slice(&h[..SHARED_KEY_BYTES]);
135        k2.copy_from_slice(&h[SHARED_KEY_BYTES..]);
136        Ok(SharedKeys { k1, k2 })
137    }
138
139    pub fn step1<T: AsRef<[u8]>>(
140        password: &str,
141        id_a: &str,
142        id_b: &str,
143        ad: Option<T>,
144    ) -> Result<Step1Out, Error> {
145        let mut session_id = [0u8; SESSION_ID_BYTES];
146        getrandom(&mut session_id)?;
147        let ctx = CPace::new(session_id, password, id_a, id_b, ad)?;
148        let mut step1_packet = [0u8; STEP1_PACKET_BYTES];
149        step1_packet[..SESSION_ID_BYTES].copy_from_slice(&ctx.session_id);
150        step1_packet[SESSION_ID_BYTES..].copy_from_slice(ctx.p.compress().as_bytes());
151        Ok(Step1Out { ctx, step1_packet })
152    }
153
154    pub fn step2<T: AsRef<[u8]>>(
155        step1_packet: &[u8; STEP1_PACKET_BYTES],
156        password: &str,
157        id_a: &str,
158        id_b: &str,
159        ad: Option<T>,
160    ) -> Result<Step2Out, Error> {
161        let mut session_id = [0u8; SESSION_ID_BYTES];
162        session_id.copy_from_slice(&step1_packet[..SESSION_ID_BYTES]);
163        let ya = &step1_packet[SESSION_ID_BYTES..];
164        let ctx = CPace::new(session_id, password, id_a, id_b, ad)?;
165        let mut step2_packet = [0u8; STEP2_PACKET_BYTES];
166        step2_packet.copy_from_slice(ctx.p.compress().as_bytes());
167        let ya = CompressedRistretto::from_slice(ya)
168            .map_err(|_| Error::InvalidPublicKey)?
169            .decompress()
170            .ok_or(Error::InvalidPublicKey)?;
171        let shared_keys = ctx.finalize(ya, ya, ctx.p)?;
172        Ok(Step2Out {
173            shared_keys,
174            step2_packet,
175        })
176    }
177
178    pub fn step3(&self, step2_packet: &[u8; STEP2_PACKET_BYTES]) -> Result<SharedKeys, Error> {
179        let yb = CompressedRistretto::from_slice(step2_packet)
180            .map_err(|_| Error::InvalidPublicKey)?
181            .decompress()
182            .ok_or(Error::InvalidPublicKey)?;
183        self.finalize(yb, self.p, yb)
184    }
185}
186
/// Round trip: both parties with matching inputs derive identical keys.
#[test]
fn test_cpace() {
    let client = CPace::step1("password", "client", "server", Some("ad")).unwrap();

    let step2 = CPace::step2(&client.packet(), "password", "client", "server", Some("ad")).unwrap();

    let shared_keys = client.step3(&step2.packet()).unwrap();

    assert_eq!(shared_keys.k1, step2.shared_keys().k1);
    assert_eq!(shared_keys.k2, step2.shared_keys().k2);
}

/// A responder using a different password must not end up with the initiator's keys.
#[test]
fn test_cpace_password_mismatch() {
    let client = CPace::step1("password", "client", "server", Some("ad")).unwrap();
    let step2 = CPace::step2(
        &client.packet(),
        "wrong password",
        "client",
        "server",
        Some("ad"),
    )
    .unwrap();
    let shared_keys = client.step3(&step2.packet()).unwrap();

    assert_ne!(shared_keys.k1, step2.shared_keys().k1);
    assert_ne!(shared_keys.k2, step2.shared_keys().k2);
}

/// Identifiers are length-prefixed with one byte: 255 bytes is the maximum accepted.
#[test]
fn test_cpace_identifier_length_limit() {
    let too_long = [b'x'; 256];
    let too_long = core::str::from_utf8(&too_long).unwrap();
    assert!(matches!(
        CPace::step1("password", too_long, "server", None::<&str>),
        Err(Error::Overflow(_))
    ));
    assert!(matches!(
        CPace::step1("password", "client", too_long, None::<&str>),
        Err(Error::Overflow(_))
    ));

    let at_limit = [b'x'; 255];
    let at_limit = core::str::from_utf8(&at_limit).unwrap();
    assert!(CPace::step1("password", at_limit, at_limit, None::<&str>).is_ok());
}

/// Identity and undecodable points from the peer must be rejected in both directions.
#[test]
fn test_cpace_rejects_invalid_points() {
    let client = CPace::step1("password", "client", "server", Some("ad")).unwrap();

    // All-zero bytes encode the identity element.
    assert!(matches!(
        client.step3(&[0u8; STEP2_PACKET_BYTES]),
        Err(Error::InvalidPublicKey)
    ));
    // Non-canonical encoding: fails to decompress.
    assert!(matches!(
        client.step3(&[0xffu8; STEP2_PACKET_BYTES]),
        Err(Error::InvalidPublicKey)
    ));

    let mut packet = client.packet();
    packet[SESSION_ID_BYTES..].fill(0);
    assert!(matches!(
        CPace::step2(&packet, "password", "client", "server", Some("ad")),
        Err(Error::InvalidPublicKey)
    ));
}