1use crate::algs;
115use crate::cose_struct;
116use crate::errors::{CoseError, CoseField, CoseResult, CoseResultWithRet};
117use crate::headers::{CoseHeader, COUNTER_SIG};
118use crate::keys;
119use cbor::{Decoder, Encoder};
120use std::io::Cursor;
121
122#[derive(Clone)]
124pub struct CoseAgent {
125 pub header: CoseHeader,
127 pub payload: Vec<u8>,
129 pub(crate) ph_bstr: Vec<u8>,
130 pub pub_key: Vec<u8>,
132 pub s_key: Vec<u8>,
134 pub(crate) context: String,
135 pub(crate) crv: Option<i32>,
136 pub(crate) key_ops: Vec<i32>,
137 pub(crate) base_iv: Option<Vec<u8>>,
138 pub(crate) enc: bool,
139}
140const KEY_OPS_SKEY: [i32; 8] = [
141 keys::KEY_OPS_DERIVE_BITS,
142 keys::KEY_OPS_DERIVE,
143 keys::KEY_OPS_DECRYPT,
144 keys::KEY_OPS_ENCRYPT,
145 keys::KEY_OPS_WRAP,
146 keys::KEY_OPS_UNWRAP,
147 keys::KEY_OPS_MAC_VERIFY,
148 keys::KEY_OPS_MAC,
149];
150
/// Number of elements in the CBOR array written by `CoseAgent::encode`:
/// the protected header bstr, the unprotected header and the payload bstr.
const SIZE: usize = 3;
152
153impl CoseAgent {
154 pub fn new() -> CoseAgent {
156 CoseAgent {
157 header: CoseHeader::new(),
158 payload: Vec::new(),
159 ph_bstr: Vec::new(),
160 pub_key: Vec::new(),
161 key_ops: Vec::new(),
162 s_key: Vec::new(),
163 crv: None,
164 base_iv: None,
165 context: String::new(),
166 enc: false,
167 }
168 }
169
170 pub fn new_counter_sig() -> CoseAgent {
172 CoseAgent {
173 header: CoseHeader::new(),
174 payload: Vec::new(),
175 ph_bstr: Vec::new(),
176 pub_key: Vec::new(),
177 key_ops: Vec::new(),
178 s_key: Vec::new(),
179 crv: None,
180 base_iv: None,
181 context: cose_struct::COUNTER_SIGNATURE.to_string(),
182 enc: false,
183 }
184 }
185
186 pub fn add_header(&mut self, header: CoseHeader) {
188 self.header = header;
189 }
190
191 pub fn key(&mut self, key: &keys::CoseKey) -> CoseResult {
193 let alg = self.header.alg.ok_or(CoseError::Missing(CoseField::Alg))?;
194 key.verify_kty()?;
195 if algs::ECDH_ALGS.contains(&alg) {
196 if !keys::ECDH_KTY.contains(key.kty.as_ref().ok_or(CoseError::Missing(CoseField::Kty))?)
197 {
198 return Err(CoseError::Invalid(CoseField::Kty));
199 }
200 if key.alg.is_some() && key.alg.unwrap() != alg {
201 return Err(CoseError::AlgMismatch());
202 }
203 } else if (alg != algs::DIRECT
204 && !algs::A_KW.contains(&alg)
205 && !algs::RSA_OAEP.contains(&alg))
206 && key.alg.is_some()
207 && key.alg.unwrap() != alg
208 {
209 return Err(CoseError::AlgMismatch());
210 }
211 if algs::SIGNING_ALGS.contains(&alg) {
212 if key.key_ops.contains(&keys::KEY_OPS_SIGN) {
213 self.s_key = key.get_s_key()?;
214 }
215 if key.key_ops.contains(&keys::KEY_OPS_VERIFY) {
216 self.pub_key = key.get_pub_key()?;
217 }
218 if key.key_ops.is_empty() {
219 self.s_key = match key.get_s_key() {
220 Ok(v) => v,
221 Err(_) => Vec::new(),
222 };
223 self.pub_key = match key.get_pub_key() {
224 Ok(v) => v,
225 Err(_) => Vec::new(),
226 };
227 }
228 } else if algs::KEY_DISTRIBUTION_ALGS.contains(&alg) || algs::ENCRYPT_ALGS.contains(&alg) {
229 if KEY_OPS_SKEY.iter().any(|i| key.key_ops.contains(i)) {
230 self.s_key = key.get_s_key()?;
231 }
232 if key.key_ops.is_empty() {
233 self.s_key = match key.get_s_key() {
234 Ok(v) => v,
235 Err(_) => Vec::new(),
236 };
237 }
238 if (algs::ECDH_ALGS.contains(&alg) || algs::OAEP_ALGS.contains(&alg))
239 && key.key_ops.is_empty()
240 {
241 self.pub_key = key.get_pub_key()?;
242 }
243 }
244 self.crv = key.crv;
245 self.base_iv = key.base_iv.clone();
246 self.key_ops = key.key_ops.clone();
247 Ok(())
248 }
249
250 pub(crate) fn sign(
251 &mut self,
252 content: &Vec<u8>,
253 external_aad: &Vec<u8>,
254 body_protected: &Vec<u8>,
255 ) -> CoseResult {
256 if !self.key_ops.is_empty() && !self.key_ops.contains(&keys::KEY_OPS_SIGN) {
257 return Err(CoseError::Invalid(CoseField::KeyOp));
258 }
259 self.ph_bstr = self.header.get_protected_bstr(false)?;
260 self.payload = cose_struct::gen_sig(
261 &self.s_key,
262 &self.header.alg.ok_or(CoseError::Missing(CoseField::Alg))?,
263 &self.crv,
264 &external_aad,
265 &self.context,
266 &body_protected,
267 &self.ph_bstr,
268 &content,
269 )?;
270 Ok(())
271 }
272 pub(crate) fn verify(
273 &self,
274 content: &Vec<u8>,
275 external_aad: &Vec<u8>,
276 body_protected: &Vec<u8>,
277 ) -> CoseResultWithRet<bool> {
278 if !self.key_ops.is_empty() && !self.key_ops.contains(&keys::KEY_OPS_VERIFY) {
279 return Err(CoseError::Invalid(CoseField::KeyOp));
280 }
281 Ok(cose_struct::verify_sig(
282 &self.pub_key,
283 &self.header.alg.ok_or(CoseError::Missing(CoseField::Alg))?,
284 &self.crv,
285 &external_aad,
286 &self.context,
287 &body_protected,
288 &self.ph_bstr,
289 &content,
290 &self.payload,
291 )?)
292 }
293
294 pub fn add_signature(&mut self, signature: Vec<u8>) -> CoseResult {
300 if self.context != cose_struct::COUNTER_SIGNATURE {
301 return Err(CoseError::InvalidContext(self.context.clone()));
302 }
303 self.payload = signature;
304 Ok(())
305 }
306
307 pub(crate) fn get_sign_content(
308 &mut self,
309 content: &Vec<u8>,
310 external_aad: &Vec<u8>,
311 body_protected: &Vec<u8>,
312 ) -> CoseResultWithRet<Vec<u8>> {
313 if self.context != cose_struct::COUNTER_SIGNATURE {
314 return Err(CoseError::InvalidContext(self.context.clone()));
315 }
316 self.ph_bstr = self.header.get_protected_bstr(false)?;
317 cose_struct::get_to_sign(
318 &external_aad,
319 cose_struct::COUNTER_SIGNATURE,
320 &body_protected,
321 &self.ph_bstr,
322 &content,
323 )
324 }
325
326 pub fn counter_sig(
328 &self,
329 external_aad: Option<Vec<u8>>,
330 counter: &mut CoseAgent,
331 ) -> CoseResult {
332 if !self.enc {
333 Err(CoseError::Missing(CoseField::Payload))
334 } else {
335 let aead = match external_aad {
336 None => Vec::new(),
337 Some(v) => v,
338 };
339 counter.sign(&self.payload, &aead, &self.ph_bstr)?;
340 Ok(())
341 }
342 }
343
344 pub fn get_to_sign(
349 &self,
350 external_aad: Option<Vec<u8>>,
351 counter: &mut CoseAgent,
352 ) -> CoseResultWithRet<Vec<u8>> {
353 if !self.enc {
354 Err(CoseError::Missing(CoseField::Payload))
355 } else {
356 let aead = match external_aad {
357 None => Vec::new(),
358 Some(v) => v,
359 };
360 counter.get_sign_content(&self.payload, &aead, &self.ph_bstr)
361 }
362 }
363
364 pub fn get_to_verify(
369 &mut self,
370 external_aad: Option<Vec<u8>>,
371 counter: &usize,
372 ) -> CoseResultWithRet<Vec<u8>> {
373 if !self.enc {
374 Err(CoseError::Missing(CoseField::Payload))
375 } else {
376 let aead = match external_aad {
377 None => Vec::new(),
378 Some(v) => v,
379 };
380 self.header.counters[*counter].get_sign_content(&self.payload, &aead, &self.ph_bstr)
381 }
382 }
383
384 pub fn counters_verify(&mut self, external_aad: Option<Vec<u8>>, counter: usize) -> CoseResult {
386 if !self.enc {
387 Err(CoseError::Missing(CoseField::Payload))
388 } else {
389 let aead = match external_aad {
390 None => Vec::new(),
391 Some(v) => v,
392 };
393 if self.header.counters[counter].verify(&self.payload, &aead, &self.ph_bstr)? {
394 Ok(())
395 } else {
396 Err(CoseError::Invalid(CoseField::CounterSignature))
397 }
398 }
399 }
400
401 pub fn add_counter_sig(&mut self, counter: CoseAgent) -> CoseResult {
404 if !algs::SIGNING_ALGS.contains(
405 &counter
406 .header
407 .alg
408 .ok_or(CoseError::Missing(CoseField::Alg))?,
409 ) {
410 return Err(CoseError::Invalid(CoseField::Alg));
411 }
412 if counter.context != cose_struct::COUNTER_SIGNATURE {
413 return Err(CoseError::InvalidContext(counter.context));
414 }
415 if self.header.unprotected.contains(&COUNTER_SIG) {
416 self.header.counters.push(counter);
417 Ok(())
418 } else {
419 self.header.counters.push(counter);
420 self.header.remove_label(COUNTER_SIG);
421 self.header.unprotected.push(COUNTER_SIG);
422 Ok(())
423 }
424 }
425
426 pub(crate) fn derive_key(
427 &mut self,
428 cek: &Vec<u8>,
429 size: usize,
430 sender: bool,
431 true_alg: &i32,
432 ) -> CoseResultWithRet<Vec<u8>> {
433 if self.ph_bstr.is_empty() {
434 self.ph_bstr = self.header.get_protected_bstr(false)?;
435 }
436 let alg = self
437 .header
438 .alg
439 .as_ref()
440 .ok_or(CoseError::Missing(CoseField::Alg))?;
441 if algs::A_KW.contains(alg) {
442 if sender {
443 self.payload = algs::aes_key_wrap(&self.s_key, size, &cek)?;
444 } else {
445 return Ok(algs::aes_key_unwrap(&self.s_key, size, &cek)?);
446 }
447 return Ok(cek.to_vec());
448 } else if algs::RSA_OAEP.contains(alg) {
449 if sender {
450 self.payload = algs::rsa_oaep_enc(&self.pub_key, &cek, alg)?;
451 } else {
452 return Ok(algs::rsa_oaep_dec(&self.s_key, size, &cek, alg)?);
453 }
454 return Ok(cek.to_vec());
455 } else if algs::D_HA.contains(alg) || algs::D_HS.contains(alg) {
456 let mut kdf_context = cose_struct::gen_kdf(
457 true_alg,
458 &self.header.party_u_identity,
459 &self.header.party_u_nonce,
460 &self.header.party_u_other,
461 &self.header.party_v_identity,
462 &self.header.party_v_nonce,
463 &self.header.party_v_other,
464 size as u16 * 8,
465 &self.ph_bstr,
466 &self.header.pub_other,
467 &self.header.priv_info,
468 )?;
469 return Ok(algs::hkdf(
470 size,
471 &self.s_key,
472 self.header.salt.as_ref(),
473 &mut kdf_context,
474 self.header.alg.unwrap(),
475 )?);
476 } else if algs::ECDH_H.contains(alg) || algs::ECDH_A.contains(alg) {
477 let (receiver_key, sender_key, crv_rec, crv_send);
478 if sender {
479 if self.pub_key.is_empty() {
480 return Err(CoseError::MissingKey());
481 }
482 receiver_key = self.pub_key.clone();
483 if !self.header.x5_private.is_empty() {
484 sender_key = self.header.x5_private.clone();
485 crv_send = None;
486 } else {
487 sender_key = self.header.ecdh_key.get_s_key()?;
488 crv_send = Some(self.header.ecdh_key.crv.unwrap());
489 }
490 crv_rec = Some(self.crv.unwrap());
491 } else {
492 if self.s_key.is_empty() {
493 return Err(CoseError::MissingKey());
494 }
495 if self.header.x5chain_sender.is_some() {
496 algs::verify_chain(self.header.x5chain_sender.as_ref().unwrap())?;
497 receiver_key = self.header.x5chain_sender.as_ref().unwrap()[0].clone();
498 crv_rec = None;
499 } else {
500 receiver_key = self.header.ecdh_key.get_pub_key()?;
501 crv_rec = Some(self.crv.unwrap());
502 }
503 sender_key = self.s_key.clone();
504 crv_send = Some(self.crv.unwrap());
505 }
506 let shared = algs::ecdh_derive_key(crv_rec, crv_send, &receiver_key, &sender_key)?;
507
508 if algs::ECDH_H.contains(alg) {
509 let mut kdf_context = cose_struct::gen_kdf(
510 true_alg,
511 &self.header.party_u_identity,
512 &self.header.party_u_nonce,
513 &self.header.party_u_other,
514 &self.header.party_v_identity,
515 &self.header.party_v_nonce,
516 &self.header.party_v_other,
517 size as u16 * 8,
518 &self.ph_bstr,
519 &self.header.pub_other,
520 &self.header.priv_info,
521 )?;
522 return Ok(algs::hkdf(
523 size,
524 &shared,
525 self.header.salt.as_ref(),
526 &mut kdf_context,
527 self.header.alg.unwrap(),
528 )?);
529 } else {
530 let size_akw = algs::get_cek_size(&alg)?;
531
532 let alg_akw;
533 if [algs::ECDH_ES_A128KW, algs::ECDH_SS_A128KW].contains(alg) {
534 alg_akw = algs::A128KW;
535 } else if [algs::ECDH_ES_A192KW, algs::ECDH_SS_A192KW].contains(alg) {
536 alg_akw = algs::A192KW;
537 } else {
538 alg_akw = algs::A256KW;
539 }
540
541 let mut kdf_context = cose_struct::gen_kdf(
542 &alg_akw,
543 &self.header.party_u_identity,
544 &self.header.party_u_nonce,
545 &self.header.party_u_other,
546 &self.header.party_v_identity,
547 &self.header.party_v_nonce,
548 &self.header.party_v_other,
549 size_akw as u16 * 8,
550 &self.ph_bstr,
551 &self.header.pub_other,
552 &self.header.priv_info,
553 )?;
554 let kek = algs::hkdf(
555 size_akw,
556 &shared,
557 self.header.salt.as_ref(),
558 &mut kdf_context,
559 self.header.alg.unwrap(),
560 )?;
561 if sender {
562 self.payload = algs::aes_key_wrap(&kek, size, &cek)?;
563 } else {
564 return Ok(algs::aes_key_unwrap(&kek, size, &cek)?);
565 }
566 return Ok(cek.to_vec());
567 }
568 } else {
569 return Err(CoseError::Invalid(CoseField::Alg));
570 }
571 }
572
573 pub(crate) fn decode(&mut self, d: &mut Decoder<Cursor<Vec<u8>>>) -> CoseResult {
574 if !self.ph_bstr.is_empty() {
575 self.header.decode_protected_bstr(&self.ph_bstr)?;
576 }
577 self.header
578 .decode_unprotected(d, self.context == cose_struct::COUNTER_SIGNATURE)?;
579 self.payload = d.bytes()?;
580 self.header.labels_found = Vec::new();
581 Ok(())
582 }
583
584 pub(crate) fn encode(&mut self, e: &mut Encoder<Vec<u8>>) -> CoseResult {
585 e.array(SIZE)?;
586 e.bytes(&self.ph_bstr)?;
587 self.header.encode_unprotected(e)?;
588 e.bytes(&self.payload)?;
589 self.header.labels_found = Vec::new();
590 Ok(())
591 }
592}