1use std::collections::HashMap;
2use crate::error::{QVError, QVResult};
3
/// A set of string-valued claims keyed by claim name.
///
/// The inner map is public. [`Claims::encode`] / [`Claims::decode`] serialize it as a
/// MessagePack fixmap, so at most 15 claims can be encoded.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Claims(pub HashMap<String, String>);
13
14impl Claims {
15 pub fn new() -> Self { Claims(HashMap::new()) }
16
17 pub fn insert(&mut self, key: impl Into<String>, value: impl Into<String>) {
18 self.0.insert(key.into(), value.into());
19 }
20
21 pub fn get(&self, key: &str) -> Option<&str> {
22 self.0.get(key).map(String::as_str)
23 }
24
25 pub fn require(&self, key: &str) -> QVResult<&str> {
26 self.get(key).ok_or_else(|| QVError::MissingClaim(key.to_string()))
27 }
28
29 pub fn encode(&self) -> QVResult<Vec<u8>> {
31 if self.0.len() > 15 {
32 return Err(QVError::SerializationError("too many claims (max 15)".into()));
33 }
34 let mut out = Vec::new();
35 out.push(0x80 | self.0.len() as u8); for (k, v) in &self.0 {
37 encode_str(&mut out, k)?;
38 encode_str(&mut out, v)?;
39 }
40 Ok(out)
41 }
42
43 pub fn decode(data: &[u8]) -> QVResult<Self> {
45 if data.is_empty() {
46 return Err(QVError::BufferTooShort { need: 1, have: 0 });
47 }
48 let first = data[0];
49 if first & 0xF0 != 0x80 {
50 return Err(QVError::SerializationError("expected fixmap".into()));
51 }
52 let n = (first & 0x0F) as usize;
53 let mut pos = 1;
54 let mut map = HashMap::new();
55 for _ in 0..n {
56 let (k, adv) = decode_str(&data[pos..])?;
57 pos += adv;
58 let (v, adv) = decode_str(&data[pos..])?;
59 pos += adv;
60 map.insert(k, v);
61 }
62 Ok(Claims(map))
63 }
64}
65
66fn encode_str(out: &mut Vec<u8>, s: &str) -> QVResult<()> {
67 let b = s.as_bytes();
68 if b.len() <= 31 {
69 out.push(0xA0 | b.len() as u8);
70 } else if b.len() <= 255 {
71 out.push(0xd9);
72 out.push(b.len() as u8);
73 } else {
74 return Err(QVError::SerializationError("claim string too long (max 255)".into()));
75 }
76 out.extend_from_slice(b);
77 Ok(())
78}
79
80fn decode_str(data: &[u8]) -> QVResult<(String, usize)> {
81 if data.is_empty() {
82 return Err(QVError::BufferTooShort { need: 1, have: 0 });
83 }
84 let (len, header) = if data[0] & 0xE0 == 0xA0 {
85 ((data[0] & 0x1F) as usize, 1)
86 } else if data[0] == 0xd9 {
87 if data.len() < 2 {
88 return Err(QVError::BufferTooShort { need: 2, have: data.len() });
89 }
90 (data[1] as usize, 2)
91 } else {
92 return Err(QVError::SerializationError(format!("unexpected msgpack byte {:#04x}", data[0])));
93 };
94 if data.len() < header + len {
95 return Err(QVError::BufferTooShort { need: header + len, have: data.len() });
96 }
97 let s = std::str::from_utf8(&data[header..header + len])
98 .map_err(|e| QVError::SerializationError(e.to_string()))?
99 .to_string();
100 Ok((s, header + len))
101}