1use crate::DgcContainer;
2use ciborium::{
3 ser::into_writer,
4 value::{Integer, Value},
5};
6use std::iter::FromIterator;
7use std::{
8 convert::{TryFrom, TryInto},
9 ops::Not,
10};
11use thiserror::Error;
12
// CBOR tags that may wrap the incoming message.

/// CBOR tag marking a COSE_Sign1 structure.
const COSE_SIGN1_CBOR_TAG: u64 = 18;
/// CBOR tag marking a CBOR Web Token; accepted as an optional outermost wrapper.
const CBOR_WEB_TOKEN_TAG: u64 = 61;

// COSE header labels that are extracted into `CwtHeader`.

/// Header label of the signature algorithm (`alg`).
const COSE_HEADER_KEY_ALG: i128 = 1;
/// Header label of the key identifier (`kid`).
const COSE_HEADER_KEY_KID: i128 = 4;

// COSE algorithm identifiers recognised by `EcAlg`.

/// ES256: ECDSA using SHA-256.
const COSE_ES256: i128 = -7;
/// PS256: RSASSA-PSS using SHA-256.
const COSE_PS256: i128 = -37;
21
22#[derive(Error, Debug)]
25pub enum CwtParseError {
26 #[error("Cannot parse the data as CBOR: {0}")]
28 CborError(#[from] ciborium::de::Error<std::io::Error>),
29 #[error("The root value is not a tag")]
31 InvalidRootValue,
32 #[error(
34 "Expected COSE_SIGN1_CBOR_TAG ({}) or CBOR_WEB_TOKEN_TAG ({}). Found: {0}",
35 COSE_SIGN1_CBOR_TAG,
36 CBOR_WEB_TOKEN_TAG
37 )]
38 InvalidTag(u64),
39 #[error("The main CBOR object is not an array")]
41 InvalidParts,
42 #[error("The main CBOR array does not contain 4 parts. {0} parts found")]
44 InvalidPartsCount(usize),
45 #[error("The unprotected header section is not a CBOR map or an emtpy sequence of bytes")]
47 MalformedUnProtectedHeader,
48 #[error("The protected header section is not a binary string")]
50 ProtectedHeaderNotBinary,
51 #[error("The protected header section is not valid CBOR-encoded data")]
53 ProtectedHeaderNotValidCbor,
54 #[error("The protected header section does not contain key-value pairs")]
56 ProtectedHeaderNotMap,
57 #[error("The payload section is not a binary string")]
59 PayloadNotBinary,
60 #[error("Cannot deserialize payload: {0}")]
62 InvalidPayload(#[source] ciborium::de::Error<std::io::Error>),
63 #[error("The signature section is not a binary string")]
65 SignatureNotBinary,
66}
67
/// Signature algorithm announced in the COSE `alg` header parameter.
///
/// Despite the name, not every variant is an elliptic-curve scheme: `Ps256` is RSA-PSS.
#[derive(Debug, PartialEq, Eq)]
pub enum EcAlg {
    /// ES256: ECDSA using SHA-256 (COSE identifier -7).
    Es256,
    /// PS256: RSASSA-PSS using SHA-256 (COSE identifier -37).
    Ps256,
    /// Any other algorithm identifier, carrying the raw integer value.
    Unknown(i128),
}
95
96impl From<Integer> for EcAlg {
97 fn from(i: Integer) -> Self {
98 let u: i128 = i.into();
99 match u {
100 COSE_ES256 => EcAlg::Es256,
101 COSE_PS256 => EcAlg::Ps256,
102 _ => EcAlg::Unknown(u),
103 }
104 }
105}
106
107#[derive(Debug)]
115pub struct CwtHeader {
116 pub kid: Option<Vec<u8>>,
118 pub alg: Option<EcAlg>,
120}
121
122impl CwtHeader {
123 fn new() -> Self {
124 Self {
125 kid: None,
126 alg: None,
127 }
128 }
129
130 fn kid(&mut self, kid: Vec<u8>) {
131 self.kid = Some(kid);
132 }
133
134 fn alg(&mut self, alg: EcAlg) {
135 self.alg = Some(alg);
136 }
137}
138
139impl FromIterator<(Value, Value)> for CwtHeader {
140 fn from_iter<T: IntoIterator<Item = (Value, Value)>>(iter: T) -> Self {
141 let mut header = CwtHeader::new();
143 for (key, val) in iter {
145 if let Value::Integer(k) = key {
146 let k: i128 = k.into();
147 if k == COSE_HEADER_KEY_KID {
148 if let Value::Bytes(kid) = val {
150 header.kid(kid);
151 }
152 } else if k == COSE_HEADER_KEY_ALG {
153 if let Value::Integer(raw_alg) = val {
155 let alg: EcAlg = raw_alg.into();
156 header.alg(alg);
157 }
158 }
159 }
160 }
161 header
162 }
163}
164
165#[derive(Debug)]
171pub struct Cwt {
172 header_protected_raw: Vec<u8>,
173 payload_raw: Vec<u8>,
174 pub header: CwtHeader,
178 pub payload: DgcContainer,
180 pub signature: Vec<u8>,
182}
183
184impl Cwt {
185 pub fn make_sig_structure(&self) -> Vec<u8> {
188 let sig_structure_cbor = Value::Array(vec![
189 Value::Text(String::from("Signature1")), Value::Bytes(self.header_protected_raw.clone()), Value::Bytes(vec![]), Value::Bytes(self.payload_raw.clone()),
193 ]);
194 let mut sig_structure: Vec<u8> = vec![];
195 into_writer(&sig_structure_cbor, &mut sig_structure).unwrap();
196 sig_structure
197 }
198}
199
200trait ValueExt: Sized {
203 fn into_tag(self) -> Result<(u64, Box<Value>), Self>;
204 fn into_array(self) -> Result<Vec<Value>, Self>;
205 fn into_bytes(self) -> Result<Vec<u8>, Self>;
206}
207
208impl ValueExt for Value {
209 fn into_tag(self) -> Result<(u64, Box<Value>), Self> {
210 match self {
211 Self::Tag(tag, content) => Ok((tag, content)),
212 _ => Err(self),
213 }
214 }
215
216 fn into_array(self) -> Result<Vec<Value>, Self> {
217 match self {
218 Self::Array(array) => Ok(array),
219 _ => Err(self),
220 }
221 }
222
223 fn into_bytes(self) -> Result<Vec<u8>, Self> {
224 match self {
225 Self::Bytes(bytes) => Ok(bytes),
226 _ => Err(self),
227 }
228 }
229}
230
231impl TryFrom<&[u8]> for Cwt {
232 type Error = CwtParseError;
233
234 fn try_from(data: &[u8]) -> Result<Self, Self::Error> {
235 use CwtParseError::*;
236
237 let cwt_content = match ciborium::de::from_reader(data)? {
238 Value::Tag(tag_id, content) if tag_id == CBOR_WEB_TOKEN_TAG => *content,
239 cwt => cwt,
240 };
241 let cwt_content = match cwt_content.into_tag() {
242 Ok((COSE_SIGN1_CBOR_TAG, content)) => *content,
243 Ok((tag_id, _)) => return Err(InvalidTag(tag_id)),
244 Err(cwt) => cwt,
245 };
246
247 let parts = cwt_content.into_array().map_err(|_| InvalidParts)?;
248
249 let parts_len = parts.len();
250 let [header_protected_raw, unprotected_header, payload_raw, signature]: [Value; 4] =
251 parts.try_into().map_err(|_| InvalidPartsCount(parts_len))?;
252
253 let header_protected_raw = header_protected_raw
254 .into_bytes()
255 .map_err(|_| ProtectedHeaderNotBinary)?;
256 let payload_raw = payload_raw.into_bytes().map_err(|_| PayloadNotBinary)?;
257 let signature = signature.into_bytes().map_err(|_| SignatureNotBinary)?;
258
259 let unprotected_header = match unprotected_header {
261 Value::Map(values) => Some(values),
262 Value::Bytes(values) if values.is_empty() => Some(Vec::new()),
263 _ => None,
264 }
265 .ok_or(MalformedUnProtectedHeader)?;
266
267 let protected_header_values = header_protected_raw
271 .is_empty()
272 .not()
273 .then(|| {
274 let value = ciborium::de::from_reader(header_protected_raw.as_slice())
275 .map_err(|_| ProtectedHeaderNotValidCbor)?;
276
277 match value {
278 Value::Map(map) => Ok(map),
279 _ => Err(ProtectedHeaderNotMap),
280 }
281 })
282 .transpose()?
283 .unwrap_or_default();
284
285 let header: CwtHeader = unprotected_header
287 .into_iter()
288 .chain(protected_header_values)
289 .collect();
290
291 let payload: DgcContainer =
292 ciborium::de::from_reader(payload_raw.as_slice()).map_err(InvalidPayload)?;
293
294 Ok(Cwt {
295 header_protected_raw,
296 payload_raw,
297 header,
298 payload,
299 signature,
300 })
301 }
302}
303
304impl TryFrom<Vec<u8>> for Cwt {
305 type Error = CwtParseError;
306
307 fn try_from(data: Vec<u8>) -> Result<Self, Self::Error> {
308 data.as_slice().try_into()
309 }
310}
311
#[cfg(test)]
mod tests {
    use std::convert::{TryFrom, TryInto};

    use super::*;

    /// A COSE_Sign1 message (tag 18) with a DGC payload, hex-encoded.
    const RAW_HEX_COSE_DATA: &str = "d2844da204481c10ebbbc49f78310126a0590111a4041a61657980061a6162d90001624145390103a101a4617481a862736374323032312d31302d30395431323a30333a31325a627474684c50363436342d3462746376416c686f736e204f6e6520446179205375726765727962636f624145626369782955524e3a555643493a56313a41453a384b5354305248303537484938584b57334d384b324e41443036626973781f4d696e6973747279206f66204865616c746820262050726576656e74696f6e6274676938343035333930303662747269323630343135303030636e616da463666e7465424c414b4562666e65424c414b4563676e7466414c53544f4e62676e66414c53544f4e6376657265312e332e3063646f626a313939302d30312d3031584034fc1cee3c4875c18350d24ccd24dd67ce1bda84f5db6b26b4b8a97c8336e159294859924afa7894a45a5af07a8cf536a36be67912d79f5a93540b86bb7377fb";

    /// The Sig_structure expected for `RAW_HEX_COSE_DATA`, hex-encoded.
    const EXPECTED_SIG_STRUCTURE: &str = "846a5369676e6174757265314da204481c10ebbbc49f7831012640590111a4041a61657980061a6162d90001624145390103a101a4617481a862736374323032312d31302d30395431323a30333a31325a627474684c50363436342d3462746376416c686f736e204f6e6520446179205375726765727962636f624145626369782955524e3a555643493a56313a41453a384b5354305248303537484938584b57334d384b324e41443036626973781f4d696e6973747279206f66204865616c746820262050726576656e74696f6e6274676938343035333930303662747269323630343135303030636e616da463666e7465424c414b4562666e65424c414b4563676e7466414c53544f4e62676e66414c53544f4e6376657265312e332e3063646f626a313939302d30312d3031";

    #[test]
    fn it_parses_cose_data() {
        let expected_kid: Vec<u8> = vec![28, 16, 235, 187, 196, 159, 120, 49];
        let expected_alg = EcAlg::Es256;
        let raw_cose_data = hex::decode(RAW_HEX_COSE_DATA).unwrap();

        let cwt: Cwt = raw_cose_data.as_slice().try_into().unwrap();

        assert_eq!(Some(expected_kid), cwt.header.kid);
        assert_eq!(Some(expected_alg), cwt.header.alg);
        assert_eq!(
            EXPECTED_SIG_STRUCTURE,
            hex::encode(cwt.make_sig_structure())
        );
    }

    #[test]
    fn it_accepts_a_cwt_tag_wrapper() {
        // 0xd8 0x3d is CBOR tag 61 (CWT) wrapping the same COSE_Sign1 message.
        let wrapped = hex::decode(format!("d83d{}", RAW_HEX_COSE_DATA)).unwrap();

        let cwt = Cwt::try_from(wrapped).unwrap();

        assert_eq!(Some(EcAlg::Es256), cwt.header.alg);
        assert_eq!(
            EXPECTED_SIG_STRUCTURE,
            hex::encode(cwt.make_sig_structure())
        );
    }

    #[test]
    fn it_rejects_empty_input() {
        let empty: &[u8] = &[];
        let err = Cwt::try_from(empty).unwrap_err();
        assert!(matches!(err, CwtParseError::CborError(_)));
    }

    #[test]
    fn it_rejects_a_non_array_root() {
        // 0x01 is the CBOR unsigned integer 1.
        let err = Cwt::try_from(&[0x01_u8][..]).unwrap_err();
        assert!(matches!(err, CwtParseError::InvalidParts));
    }

    #[test]
    fn it_rejects_an_unexpected_tag() {
        // 0xd3 is CBOR tag 19, wrapping an empty array (0x80).
        let err = Cwt::try_from(&[0xd3_u8, 0x80][..]).unwrap_err();
        assert!(matches!(err, CwtParseError::InvalidTag(19)));
    }

    #[test]
    fn it_rejects_the_wrong_number_of_parts() {
        // A three-item array: empty byte string, empty map, empty byte string.
        let err = Cwt::try_from(&[0x83_u8, 0x40, 0xa0, 0x40][..]).unwrap_err();
        assert!(matches!(err, CwtParseError::InvalidPartsCount(3)));
    }
}