1use serde::{Deserialize, Serialize};
2use std::collections::BTreeMap;
3
pub mod headers {
    //! HTTP header names, spelled in their canonical capitalisation.

    /// Standard `Authorization` request header.
    pub const AUTHORIZATION: &'static str = "Authorization";

    /// Custom header from which a `ServiceId` is extracted
    /// (textual form `<main_id>[:<sub_id>]`).
    pub const X_SERVICE_ID: &'static str = "X-Service-Id";

    /// Custom tenant-id header. NOTE(review): value format is not defined here.
    pub const X_TENANT_ID: &'static str = "X-Tenant-Id";

    /// Custom tenant-name header. NOTE(review): value format is not defined here.
    pub const X_TENANT_NAME: &'static str = "X-Tenant-Name";
}
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
17pub struct ServiceId(pub u64, pub u64);
18
19impl From<(u64, u64)> for ServiceId {
20 fn from(value: (u64, u64)) -> Self {
21 ServiceId(value.0, value.1)
22 }
23}
24
25impl ServiceId {
26 pub fn new(main: u64) -> Self {
28 ServiceId(main, 0)
29 }
30
31 pub fn with_subservice(self, sub: u64) -> Self {
33 ServiceId(self.0, sub)
34 }
35
36 pub fn id(&self) -> u64 {
38 self.0
39 }
40
41 pub fn sub_id(&self) -> u64 {
43 self.1
44 }
45
46 pub fn has_sub_id(&self) -> bool {
50 self.1 != 0
51 }
52
53 pub const fn to_be_bytes(&self) -> [u8; 16] {
55 let mut bytes = [0u8; 16];
56 let hi = self.0.to_be_bytes();
57 let lo = self.1.to_be_bytes();
58 let mut i = 0;
59 while i < 8 {
60 bytes[i] = hi[i];
61 bytes[i + 8] = lo[i];
62 i += 1;
63 }
64 bytes
65 }
66
67 pub const fn from_be_bytes(bytes: [u8; 16]) -> Self {
69 let mut hi = [0u8; 8];
70 let mut lo = [0u8; 8];
71 let mut i = 0;
72 while i < 8 {
73 hi[i] = bytes[i];
74 lo[i] = bytes[i + 8];
75 i += 1;
76 }
77 ServiceId(u64::from_be_bytes(hi), u64::from_be_bytes(lo))
78 }
79}
80
81#[derive(Debug, Clone, thiserror::Error)]
82pub enum ServiceIdParseError {
83 #[error(transparent)]
85 ParseInt(#[from] core::num::ParseIntError),
86 #[error("Invalid ServiceId format, expected <main_id>[:<sub_id>]")]
88 Malformed,
89}
90
91impl std::str::FromStr for ServiceId {
92 type Err = ServiceIdParseError;
93
94 fn from_str(s: &str) -> Result<Self, Self::Err> {
95 let mut parts = s.split(':');
96 if let Some(main_str) = parts.next() {
97 if let Some(sub_str) = parts.next() {
98 if parts.next().is_none() {
99 let main = main_str.parse::<u64>()?;
100 let sub = sub_str.parse::<u64>()?;
101 return Ok(ServiceId(main, sub));
102 }
103 } else {
104 let main = main_str.parse::<u64>()?;
105 return Ok(ServiceId::new(main));
106 }
107 }
108 Err(ServiceIdParseError::Malformed)
109 }
110}
111
112impl core::fmt::Display for ServiceId {
113 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
114 if self.has_sub_id() {
115 write!(f, "{}:{}", self.0, self.1)
116 } else {
117 write!(f, "{}:0", self.0)
118 }
119 }
120}
121
122impl<S> axum::extract::FromRequestParts<S> for ServiceId
123where
124 S: Send + Sync,
125{
126 type Rejection = axum::response::Response;
127
128 async fn from_request_parts(
129 parts: &mut axum::http::request::Parts,
130 _state: &S,
131 ) -> Result<Self, Self::Rejection> {
132 use axum::http::StatusCode;
133 use axum::response::IntoResponse;
134
135 let header = match parts.headers.get(crate::types::headers::X_SERVICE_ID) {
136 Some(header) => header,
137 None => {
138 return Err((
139 StatusCode::PRECONDITION_REQUIRED,
140 "Missing X-Service-Id header",
141 )
142 .into_response());
143 }
144 };
145
146 let header_str = match header.to_str() {
147 Ok(header_str) => header_str,
148 Err(_) => {
149 return Err((
150 StatusCode::BAD_REQUEST,
151 "Invalid X-Service-Id header; not a string",
152 )
153 .into_response());
154 }
155 };
156
157 match header_str.parse::<ServiceId>() {
158 Ok(service_id) => Ok(service_id),
159 Err(_) => Err((
160 StatusCode::BAD_REQUEST,
161 "Invalid X-Service-Id header; not a valid ServiceId",
162 )
163 .into_response()),
164 }
165 }
166}
167
168#[derive(
170 Debug,
171 Clone,
172 Copy,
173 PartialEq,
174 Eq,
175 PartialOrd,
176 Ord,
177 Hash,
178 Serialize,
179 Deserialize,
180 prost::Enumeration,
181)]
182#[repr(i32)]
183pub enum KeyType {
184 Unknown = 0,
185 Ecdsa = 1,
187 Sr25519 = 2,
189}
190
191#[derive(Serialize, Deserialize, Debug, Clone)]
193pub struct ChallengeRequest {
194 #[serde(with = "hex")]
196 pub pub_key: Vec<u8>,
197 pub key_type: KeyType,
199}
200
201#[derive(Serialize, Deserialize, Debug, Clone)]
203pub struct ChallengeResponse {
204 #[serde(with = "hex")]
206 pub challenge: [u8; 32],
207 pub expires_at: u64,
210}
211
212#[derive(Serialize, Deserialize, Debug, Clone)]
214pub struct VerifyChallengeRequest {
215 #[serde(flatten)]
217 pub challenge_request: ChallengeRequest,
218 #[serde(with = "hex")]
220 pub challenge: [u8; 32],
221 #[serde(with = "hex")]
223 pub signature: [u8; 64],
224 pub expires_at: u64,
226 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
228 pub additional_headers: BTreeMap<String, String>,
229}
230
231#[derive(Serialize, Deserialize, Debug, Clone)]
233#[serde(tag = "status", content = "data")]
234pub enum VerifyChallengeResponse {
235 Verified {
237 api_key: String,
239 expires_at: u64,
241 },
242 Expired,
244 InvalidSignature,
246
247 ServiceNotFound,
249
250 Unauthorized,
252
253 UnexpectedError {
255 message: String,
257 },
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_service_id_creation() {
        let service_id = ServiceId::new(42);
        assert_eq!(service_id.0, 42);
        assert_eq!(service_id.1, 0);

        let service_id = ServiceId::new(42).with_subservice(7);
        assert_eq!(service_id.0, 42);
        assert_eq!(service_id.1, 7);

        let service_id = ServiceId::from((42, 7));
        assert_eq!(service_id.0, 42);
        assert_eq!(service_id.1, 7);
    }

    #[test]
    fn test_service_id_accessors() {
        let service_id = ServiceId(42, 7);

        assert_eq!(service_id.id(), 42);
        assert_eq!(service_id.sub_id(), 7);
        assert!(service_id.has_sub_id());

        let service_id = ServiceId(42, 0);
        assert!(!service_id.has_sub_id());
    }

    #[test]
    fn test_service_id_bytes_conversion() {
        let service_id = ServiceId(42, 7);

        let bytes = service_id.to_be_bytes();
        assert_eq!(bytes.len(), 16);

        let reconstructed = ServiceId::from_be_bytes(bytes);
        assert_eq!(reconstructed, service_id);

        let service_id = ServiceId(0xDEADBEEF, 0xCAFEBABE);
        let bytes = service_id.to_be_bytes();
        let reconstructed = ServiceId::from_be_bytes(bytes);
        assert_eq!(reconstructed, service_id);

        // Pin the byte layout: main id big-endian in bytes 0..8, sub id in 8..16.
        assert_eq!(
            ServiceId(1, 2).to_be_bytes(),
            [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2]
        );
    }

    #[test]
    fn test_service_id_parsing() {
        assert_eq!("42".parse::<ServiceId>().unwrap(), ServiceId(42, 0));
        assert_eq!("42:7".parse::<ServiceId>().unwrap(), ServiceId(42, 7));

        let empty_result = "".parse::<ServiceId>();
        assert!(empty_result.is_err());

        assert!(matches!(
            "abc".parse::<ServiceId>(),
            Err(ServiceIdParseError::ParseInt(_))
        ));
        assert!(matches!(
            "42:7:9".parse::<ServiceId>(),
            Err(ServiceIdParseError::Malformed)
        ));
        assert!(matches!(
            "42:abc".parse::<ServiceId>(),
            Err(ServiceIdParseError::ParseInt(_))
        ));
        // An empty sub-id component is a parse error, not an implicit `0`.
        assert!(matches!(
            "42:".parse::<ServiceId>(),
            Err(ServiceIdParseError::ParseInt(_))
        ));
    }

    #[test]
    fn test_service_id_display() {
        assert_eq!(ServiceId(42, 0).to_string(), "42:0");
        assert_eq!(ServiceId(42, 7).to_string(), "42:7");
    }

    #[test]
    fn test_service_id_display_parse_roundtrip() {
        for id in [
            ServiceId(0, 0),
            ServiceId(42, 0),
            ServiceId(42, 7),
            ServiceId(u64::MAX, u64::MAX),
        ] {
            assert_eq!(id.to_string().parse::<ServiceId>().unwrap(), id);
        }
    }

    #[test]
    fn test_key_type_conversion() {
        assert_eq!(KeyType::Unknown as i32, 0);
        assert_eq!(KeyType::Ecdsa as i32, 1);
        assert_eq!(KeyType::Sr25519 as i32, 2);

        // Use the checked conversion generated by `prost::Enumeration` rather than
        // `mem::transmute`: the latter needs `unsafe`, does not exercise any real
        // conversion API, and would be UB for an out-of-range discriminant.
        assert_eq!(KeyType::try_from(1i32).unwrap(), KeyType::Ecdsa);
        assert!(KeyType::try_from(3i32).is_err());
    }

    #[test]
    fn test_headers_constants() {
        assert_eq!(headers::AUTHORIZATION, "Authorization");
        assert_eq!(headers::X_SERVICE_ID, "X-Service-Id");
        assert_eq!(headers::X_TENANT_ID, "X-Tenant-Id");
        assert_eq!(headers::X_TENANT_NAME, "X-Tenant-Name");
    }
}