Skip to main content

ic_oss_types/
cose.rs

1use candid::{CandidType, Principal};
2use cose2::{cwt::Claims, iana, tag, CoseMap, Error as CoseError, Label, Sign1Message, Value};
3use ed25519_dalek::{Signature, VerifyingKey};
4use k256::{ecdsa, ecdsa::signature::hazmat::PrehashVerifier};
5use num_traits::ToPrimitive;
6use serde::{Deserialize, Serialize};
7use serde_bytes::{ByteArray, ByteBuf};
8use sha2::Digest;
9
10pub use cose2;
11pub use iana::{AlgorithmES256K as ES256K, AlgorithmEdDSA as EdDSA};
12
13const CLOCK_SKEW: i64 = 5 * 60; // 5 minutes
14const ALG_ED25519: i64 = EdDSA;
15const ALG_SECP256K1: i64 = ES256K;
16
17const SCOPE_NAME: i64 = iana::CWTClaimScope;
18
19pub static BUCKET_TOKEN_AAD: &[u8] = b"ic_oss_bucket";
20
21#[derive(CandidType, Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
22pub struct Token {
23    pub subject: Principal,
24    pub audience: Principal,
25    pub policies: String,
26}
27
28impl Token {
29    pub fn from_sign1(
30        sign1_token: &[u8],
31        secp256k1_pub_keys: &[ByteBuf],
32        ed25519_pub_keys: &[ByteArray<32>],
33        aad: &[u8],
34        now_sec: i64,
35    ) -> Result<Self, String> {
36        let cs1 = Sign1Message::from_slice(sign1_token)
37            .map_err(|err| format!("invalid COSE sign1 token: {}", err))?;
38
39        let tbs_data = Sign1Message::to_be_signed(
40            cs1.protected_raw(),
41            aad,
42            cs1.payload.as_deref().unwrap_or_default(),
43        )
44        .map_err(|err| format!("invalid COSE signing input: {}", err))?;
45        match cs1
46            .protected
47            .alg()
48            .map_err(|err| format!("invalid COSE header: {}", err))?
49        {
50            Some(Label::Int(ALG_SECP256K1)) => {
51                Self::secp256k1_verify(secp256k1_pub_keys, &tbs_data, cs1.signature())?;
52            }
53            Some(Label::Int(ALG_ED25519)) => {
54                Self::ed25519_verify(ed25519_pub_keys, &tbs_data, cs1.signature())?;
55            }
56            alg => {
57                Err(format!("unsupported algorithm: {:?}", alg))?;
58            }
59        }
60
61        Self::from_cwt_bytes(cs1.payload.as_deref().unwrap_or_default(), now_sec)
62    }
63
64    pub fn to_cwt(self, now_sec: i64, expiration_sec: i64) -> Claims {
65        let now = to_cwt_timestamp(now_sec);
66        let expiration = to_cwt_timestamp(now_sec.saturating_add(expiration_sec));
67        let mut extra = CoseMap::new();
68        extra.insert(SCOPE_NAME, self.policies);
69
70        Claims {
71            issuer: None,
72            subject: Some(self.subject.to_text()),
73            audience: Some(self.audience.to_text()),
74            expiration: Some(expiration),
75            not_before: Some(now),
76            issued_at: Some(now),
77            cwt_id: None,
78            extra,
79        }
80    }
81
82    fn secp256k1_verify(
83        pub_keys: &[ByteBuf],
84        tbs_data: &[u8],
85        signature: &[u8],
86    ) -> Result<(), String> {
87        let keys: Vec<ecdsa::VerifyingKey> = pub_keys
88            .iter()
89            .map(|key| {
90                ecdsa::VerifyingKey::from_sec1_bytes(key)
91                    .map_err(|_| "invalid verifying key".to_string())
92            })
93            .collect::<Result<_, _>>()?;
94        let sig = ecdsa::Signature::try_from(signature).map_err(|_| "invalid signature")?;
95        let digest = sha256(tbs_data);
96        match keys
97            .iter()
98            .any(|key| key.verify_prehash(digest.as_slice(), &sig).is_ok())
99        {
100            true => Ok(()),
101            false => Err("signature verification failed".to_string()),
102        }
103    }
104
105    fn ed25519_verify(
106        pub_keys: &[ByteArray<32>],
107        tbs_data: &[u8],
108        signature: &[u8],
109    ) -> Result<(), String> {
110        let keys: Vec<VerifyingKey> = pub_keys
111            .iter()
112            .map(|key| {
113                VerifyingKey::from_bytes(key).map_err(|_| "invalid verifying key".to_string())
114            })
115            .collect::<Result<_, _>>()?;
116        let sig = Signature::from_slice(signature).map_err(|_| "invalid signature")?;
117
118        match keys
119            .iter()
120            .any(|key| key.verify_strict(tbs_data, &sig).is_ok())
121        {
122            true => Ok(()),
123            false => Err("signature verification failed".to_string()),
124        }
125    }
126
127    fn from_cwt_bytes(data: &[u8], now_sec: i64) -> Result<Self, String> {
128        let claims = claims_from_slice(data).map_err(|err| format!("invalid claims: {}", err))?;
129        if let Some(exp) = timestamp_claim(&claims, iana::CWTClaimExp)? {
130            if exp < now_sec - CLOCK_SKEW {
131                return Err("token expired".to_string());
132            }
133        }
134        if let Some(nbf) = timestamp_claim(&claims, iana::CWTClaimNbf)? {
135            if nbf > now_sec + CLOCK_SKEW {
136                return Err("token not yet valid".to_string());
137            }
138        }
139        Self::try_from(claims)
140    }
141}
142
143/// algorithm: EdDSA | ES256K
144pub fn cose_sign1(cs: Claims, alg: i64, key_id: Option<Vec<u8>>) -> Result<Sign1Message, String> {
145    let tagged_payload = cs.to_vec().map_err(|err| err.to_string())?;
146    let payload = tag::skip_tag(tag::CWT_PREFIX, &tagged_payload).to_vec();
147    let mut msg = Sign1Message::new(Some(payload));
148    msg.protected.set_alg(Label::Int(alg));
149    if let Some(key_id) = key_id {
150        msg.unprotected.set_kid(key_id);
151    }
152    Ok(msg)
153}
154
155pub fn cose_sign1_to_vec(sign1: &Sign1Message) -> Result<Vec<u8>, CoseError> {
156    let encoded = sign1.to_vec()?;
157    Ok(tag::skip_tag(tag::SIGN1_PREFIX, &encoded).to_vec())
158}
159
160impl TryFrom<CoseMap> for Token {
161    type Error = String;
162
163    fn try_from(claims: CoseMap) -> Result<Self, Self::Error> {
164        let scope = claims
165            .get_text(SCOPE_NAME)
166            .map_err(|_| "invalid scope text")?
167            .ok_or("missing scope")?;
168        let subject = claims
169            .get_text(iana::CWTClaimSub)
170            .map_err(|_| "invalid subject text")?
171            .ok_or("missing subject")?;
172        let audience = claims
173            .get_text(iana::CWTClaimAud)
174            .map_err(|_| "invalid audience text")?
175            .ok_or("missing audience")?;
176
177        Ok(Token {
178            subject: Principal::from_text(subject)
179                .map_err(|err| format!("invalid subject: {}", err))?,
180            audience: Principal::from_text(audience)
181                .map_err(|err| format!("invalid audience: {}", err))?,
182            policies: scope.to_string(),
183        })
184    }
185}
186
187pub fn sha256(data: &[u8]) -> [u8; 32] {
188    let mut hasher = sha2::Sha256::new();
189    hasher.update(data);
190    hasher.finalize().into()
191}
192
/// Converts Unix seconds to an unsigned CWT timestamp, clamping negative values to 0.
fn to_cwt_timestamp(value: i64) -> u64 {
    // After clamping to >= 0, `unsigned_abs` is a lossless conversion to u64.
    value.max(0).unsigned_abs()
}
196
197fn claims_from_slice(data: &[u8]) -> Result<CoseMap, CoseError> {
198    let data = tag::skip_tag(tag::CBOR_SELF_PREFIX, data);
199    let data = tag::skip_tag(tag::CWT_PREFIX, data);
200    CoseMap::from_slice(data)
201}
202
203fn timestamp_claim(claims: &CoseMap, key: i64) -> Result<Option<i64>, String> {
204    match claims.get(key) {
205        None => Ok(None),
206        Some(Value::Integer(value)) => i64::try_from(*value)
207            .map(Some)
208            .map_err(|_| "invalid timestamp integer".to_string()),
209        Some(Value::Float(value)) => Ok(Some(value.to_i64().unwrap_or_default())),
210        Some(_) => Err("invalid timestamp".to_string()),
211    }
212}
213
#[cfg(test)]
mod test {
    use super::*;
    use crate::permission::{Operation, Permission, Policies, Policy, Resource, Resources};
    use ed25519_dalek::Signer;

    /// Round-trips a token through `to_cwt` -> EdDSA-signed COSE_Sign1 -> `from_sign1`
    /// and checks that the decoded token equals the original.
    #[test]
    fn test_ed25519_token() {
        let signer = ed25519_dalek::SigningKey::from_bytes(&[8u8; 32]);
        let verifying_key_bytes = signer.verifying_key().to_bytes();

        let bucket_read = Policy {
            permission: Permission {
                resource: Resource::Bucket,
                operation: Operation::Read,
                constraint: Some(Resource::All),
            },
            resources: Resources::from([]),
        };
        let folder_all = Policy {
            permission: Permission {
                resource: Resource::Folder,
                operation: Operation::All,
                constraint: None,
            },
            resources: Resources::from(["1".to_string()]),
        };
        let policies = Policies::from([bucket_read, folder_all]);

        let expected = Token {
            subject: Principal::from_text(
                "z7wjp-v6fe3-kksu5-26f64-dedtw-j7ndj-57onx-qga6c-et5e3-njx53-tae",
            )
            .unwrap(),
            audience: Principal::from_text("mmrxu-fqaaa-aaaap-ahhna-cai").unwrap(),
            policies: policies.to_string(),
        };
        println!("token: {:?}", &expected);

        let now_sec = 1720676064;
        let mut sign1 = cose_sign1(expected.clone().to_cwt(now_sec, 3600), EdDSA, None).unwrap();
        let tbs_data = sign1
            .prepare_signature(None, None, Some(BUCKET_TOKEN_AAD))
            .unwrap();
        let signature = signer.sign(&tbs_data).to_bytes();
        sign1.set_signature(signature.to_vec()).unwrap();
        let encoded = cose_sign1_to_vec(&sign1).unwrap();
        println!("principal: {:?}", &Principal::anonymous().to_text());
        println!("pub_key: {:?}", &verifying_key_bytes);
        println!("sign1_token: {:?}", &encoded);

        let decoded = Token::from_sign1(
            &encoded,
            &[],
            &[verifying_key_bytes.into()],
            BUCKET_TOKEN_AAD,
            now_sec,
        )
        .unwrap();
        assert_eq!(expected, decoded);
    }
}