1use serde::{Deserialize, Serialize};
2use std::collections::BTreeMap;
3
/// HTTP header names used by this API.
pub mod headers {
    /// The standard `Authorization` request header.
    pub const AUTHORIZATION: &str = "Authorization";

    /// Header carrying the target service id; the `ServiceId` request extractor reads
    /// it and parses it as `<main_id>[:<sub_id>]`.
    pub const X_SERVICE_ID: &str = "X-Service-Id";

    /// Header carrying a tenant identifier (not read anywhere in this module).
    pub const X_TENANT_ID: &str = "X-Tenant-Id";

    /// Header carrying a tenant display name (not read anywhere in this module).
    pub const X_TENANT_NAME: &str = "X-Tenant-Name";
}
11
/// Identifier of a service as a `(main id, sub-service id)` pair.
///
/// A sub-service id of `0` means "no sub-service": [`ServiceId::new`] sets it to `0` and
/// [`ServiceId::has_sub_id`] treats only non-zero values as present.
///
/// The derived ordering compares the main id first, then the sub-service id. With serde's
/// default tuple-struct representation this (de)serializes as a two-element sequence
/// (e.g. `[42, 7]` in JSON); the text form used by `Display`/`FromStr` is `"42:7"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct ServiceId(pub u64, pub u64);
18
19impl From<(u64, u64)> for ServiceId {
20 fn from(value: (u64, u64)) -> Self {
21 ServiceId(value.0, value.1)
22 }
23}
24
25impl ServiceId {
26 pub fn new(main: u64) -> Self {
28 ServiceId(main, 0)
29 }
30
31 pub fn with_subservice(self, sub: u64) -> Self {
33 ServiceId(self.0, sub)
34 }
35
36 pub fn id(&self) -> u64 {
38 self.0
39 }
40
41 pub fn sub_id(&self) -> u64 {
43 self.1
44 }
45
46 pub fn has_sub_id(&self) -> bool {
50 self.1 != 0
51 }
52
53 pub const fn to_be_bytes(&self) -> [u8; 16] {
55 let mut bytes = [0u8; 16];
56 let hi = self.0.to_be_bytes();
57 let lo = self.1.to_be_bytes();
58 let mut i = 0;
59 while i < 8 {
60 bytes[i] = hi[i];
61 bytes[i + 8] = lo[i];
62 i += 1;
63 }
64 bytes
65 }
66
67 pub const fn from_be_bytes(bytes: [u8; 16]) -> Self {
69 let mut hi = [0u8; 8];
70 let mut lo = [0u8; 8];
71 let mut i = 0;
72 while i < 8 {
73 hi[i] = bytes[i];
74 lo[i] = bytes[i + 8];
75 i += 1;
76 }
77 ServiceId(u64::from_be_bytes(hi), u64::from_be_bytes(lo))
78 }
79}
80
81#[derive(Debug, Clone, thiserror::Error)]
82pub enum ServiceIdParseError {
83 #[error(transparent)]
85 ParseInt(#[from] core::num::ParseIntError),
86 #[error("Invalid ServiceId format, expected <main_id>[:<sub_id>]")]
88 Malformed,
89}
90
91impl std::str::FromStr for ServiceId {
92 type Err = ServiceIdParseError;
93
94 fn from_str(s: &str) -> Result<Self, Self::Err> {
95 let mut parts = s.split(':');
96 if let Some(main_str) = parts.next() {
97 if let Some(sub_str) = parts.next() {
98 if parts.next().is_none() {
99 let main = main_str.parse::<u64>()?;
100 let sub = sub_str.parse::<u64>()?;
101 return Ok(ServiceId(main, sub));
102 }
103 } else {
104 let main = main_str.parse::<u64>()?;
105 return Ok(ServiceId::new(main));
106 }
107 }
108 Err(ServiceIdParseError::Malformed)
109 }
110}
111
112impl core::fmt::Display for ServiceId {
113 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
114 if self.has_sub_id() {
115 write!(f, "{}:{}", self.0, self.1)
116 } else {
117 write!(f, "{}:0", self.0)
118 }
119 }
120}
121
122impl<S> axum::extract::FromRequestParts<S> for ServiceId
123where
124 S: Send + Sync,
125{
126 type Rejection = axum::response::Response;
127
128 async fn from_request_parts(
129 parts: &mut axum::http::request::Parts,
130 _state: &S,
131 ) -> Result<Self, Self::Rejection> {
132 use axum::http::StatusCode;
133 use axum::response::IntoResponse;
134
135 let header = match parts.headers.get(crate::types::headers::X_SERVICE_ID) {
136 Some(header) => header,
137 None => {
138 return Err((
139 StatusCode::PRECONDITION_REQUIRED,
140 "Missing X-Service-Id header",
141 )
142 .into_response());
143 }
144 };
145
146 let header_str = match header.to_str() {
147 Ok(header_str) => header_str,
148 Err(_) => {
149 return Err((
150 StatusCode::BAD_REQUEST,
151 "Invalid X-Service-Id header; not a string",
152 )
153 .into_response());
154 }
155 };
156
157 match header_str.parse::<ServiceId>() {
158 Ok(service_id) => Ok(service_id),
159 Err(_) => Err((
160 StatusCode::BAD_REQUEST,
161 "Invalid X-Service-Id header; not a valid ServiceId",
162 )
163 .into_response()),
164 }
165 }
166}
167
/// Algorithm of a public key, as carried in [`ChallengeRequest::key_type`].
///
/// Serde (de)serializes this enum by variant name (e.g. `"Ecdsa"`). The explicit `i32`
/// discriminants are the values used by the `prost::Enumeration` derive.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize,
    prost::Enumeration,
)]
// Fixes the in-memory representation to i32 so the discriminants below are the variant values.
#[repr(i32)]
pub enum KeyType {
    /// Discriminant `0`; presumably the "unset/unknown" placeholder — confirm intended use.
    Unknown = 0,
    /// ECDSA key (per the variant name).
    Ecdsa = 1,
    /// Sr25519 key (per the variant name).
    Sr25519 = 2,
    /// BLS key, presumably on the BN254 curve (inferred from the name — confirm).
    Bn254Bls = 3,
}
192
/// Request for a challenge bound to a public key.
///
/// NOTE(review): presumably answered with a [`ChallengeResponse`]; the handler is not in
/// this file — confirm. Its fields are also embedded (flattened) in
/// [`VerifyChallengeRequest`].
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChallengeRequest {
    /// Raw public key bytes; serialized as a hex string via the `hex` crate.
    #[serde(with = "hex")]
    pub pub_key: Vec<u8>,
    /// Algorithm of `pub_key`.
    pub key_type: KeyType,
}
202
/// A challenge issued to a client.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChallengeResponse {
    /// 32-byte challenge; serialized as a hex string via the `hex` crate. Presumably
    /// signed by the client and echoed back in a `VerifyChallengeRequest` — confirm.
    #[serde(with = "hex")]
    pub challenge: [u8; 32],
    /// Expiry of the challenge. Units are not visible here (presumably a Unix timestamp)
    /// — confirm.
    pub expires_at: u64,
}
213
/// Client's answer to a challenge: the original key info plus the challenge and a signature.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct VerifyChallengeRequest {
    /// Key info. `flatten` puts its fields (`pub_key`, `key_type`) at the top level of
    /// the serialized object rather than under a nested `challenge_request` key.
    #[serde(flatten)]
    pub challenge_request: ChallengeRequest,
    /// The challenge being answered; serialized as a hex string.
    #[serde(with = "hex")]
    pub challenge: [u8; 32],
    /// Signature, fixed at exactly 64 bytes; serialized as a hex string.
    /// NOTE(review): every `KeyType` must fit a 64-byte signature (e.g. recoverable ECDSA
    /// signatures are 65 bytes) — confirm.
    #[serde(with = "hex")]
    pub signature: [u8; 64],
    /// Presumably the `expires_at` from the matching `ChallengeResponse` — confirm.
    pub expires_at: u64,
    /// Extra header name/value pairs. Defaults to empty when absent from the input and is
    /// omitted from serialized output when empty. `BTreeMap` keeps keys in sorted order.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub additional_headers: BTreeMap<String, String>,
}
232
/// Outcome of verifying a challenge.
///
/// Adjacently tagged: serialized as `{"status": "<Variant>", "data": {...}}`, where
/// `data` appears only for variants with fields (e.g. `{"status": "Expired"}`).
/// The variant descriptions below follow the variant names; the code that selects them
/// is not in this file.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "status", content = "data")]
pub enum VerifyChallengeResponse {
    /// Verification succeeded and an API key was issued.
    Verified {
        /// Issued API key. Presumably sent later in the `Authorization` header — confirm.
        api_key: String,
        /// Expiry of `api_key`. Units are not visible here (presumably a Unix timestamp)
        /// — confirm.
        expires_at: u64,
    },
    /// The challenge has expired.
    Expired,
    /// The signature did not verify.
    InvalidSignature,

    /// The referenced service was not found.
    ServiceNotFound,

    /// The caller is not authorized.
    Unauthorized,

    /// Any other failure.
    UnexpectedError {
        /// Human-readable description of the failure.
        message: String,
    },
}
261
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_service_id_creation() {
        let service_id = ServiceId::new(42);
        assert_eq!(service_id.0, 42);
        assert_eq!(service_id.1, 0);

        let service_id = ServiceId::new(42).with_subservice(7);
        assert_eq!(service_id.0, 42);
        assert_eq!(service_id.1, 7);

        let service_id = ServiceId::from((42, 7));
        assert_eq!(service_id.0, 42);
        assert_eq!(service_id.1, 7);
    }

    #[test]
    fn test_service_id_accessors() {
        let service_id = ServiceId(42, 7);

        assert_eq!(service_id.id(), 42);
        assert_eq!(service_id.sub_id(), 7);
        assert!(service_id.has_sub_id());

        let service_id = ServiceId(42, 0);
        assert!(!service_id.has_sub_id());
    }

    #[test]
    fn test_service_id_bytes_conversion() {
        let service_id = ServiceId(42, 7);

        let bytes = service_id.to_be_bytes();
        assert_eq!(bytes.len(), 16);

        let reconstructed = ServiceId::from_be_bytes(bytes);
        assert_eq!(reconstructed, service_id);

        let service_id = ServiceId(0xDEADBEEF, 0xCAFEBABE);
        let bytes = service_id.to_be_bytes();
        let reconstructed = ServiceId::from_be_bytes(bytes);
        assert_eq!(reconstructed, service_id);

        // Pin the byte layout, not just the round trip: main id in bytes 0..8,
        // sub-service id in bytes 8..16, both big-endian.
        assert_eq!(
            ServiceId(1, 2).to_be_bytes(),
            [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2]
        );
        let extreme = ServiceId(u64::MAX, 0);
        assert_eq!(ServiceId::from_be_bytes(extreme.to_be_bytes()), extreme);
    }

    #[test]
    fn test_service_id_parsing() {
        assert_eq!("42".parse::<ServiceId>().unwrap(), ServiceId(42, 0));
        assert_eq!("42:7".parse::<ServiceId>().unwrap(), ServiceId(42, 7));

        let empty_result = "".parse::<ServiceId>();
        assert!(empty_result.is_err());

        assert!(matches!(
            "abc".parse::<ServiceId>(),
            Err(ServiceIdParseError::ParseInt(_))
        ));
        assert!(matches!(
            "42:7:9".parse::<ServiceId>(),
            Err(ServiceIdParseError::Malformed)
        ));
        assert!(matches!(
            "42:abc".parse::<ServiceId>(),
            Err(ServiceIdParseError::ParseInt(_))
        ));
        // Empty components are integer-parse failures, not format errors.
        assert!(matches!(
            "42:".parse::<ServiceId>(),
            Err(ServiceIdParseError::ParseInt(_))
        ));
        assert!(matches!(
            ":7".parse::<ServiceId>(),
            Err(ServiceIdParseError::ParseInt(_))
        ));
    }

    #[test]
    fn test_service_id_display() {
        assert_eq!(ServiceId(42, 0).to_string(), "42:0");
        assert_eq!(ServiceId(42, 7).to_string(), "42:7");
    }

    #[test]
    fn test_service_id_display_parse_round_trip() {
        for id in [ServiceId(42, 0), ServiceId(42, 7), ServiceId(u64::MAX, u64::MAX)] {
            assert_eq!(id.to_string().parse::<ServiceId>().unwrap(), id);
        }
    }

    #[test]
    fn test_key_type_conversion() {
        assert_eq!(KeyType::Unknown as i32, 0);
        assert_eq!(KeyType::Ecdsa as i32, 1);
        assert_eq!(KeyType::Sr25519 as i32, 2);
        assert_eq!(KeyType::Bn254Bls as i32, 3);

        // Use the safe `TryFrom<i32>` generated by `prost::Enumeration`. Transmuting an
        // integer into the enum would be undefined behavior for out-of-range values.
        assert_eq!(KeyType::try_from(1i32).ok(), Some(KeyType::Ecdsa));
        assert_eq!(KeyType::try_from(3i32).ok(), Some(KeyType::Bn254Bls));
        assert!(KeyType::try_from(4i32).is_err());
    }

    #[test]
    fn test_headers_constants() {
        assert_eq!(headers::AUTHORIZATION, "Authorization");
        assert_eq!(headers::X_SERVICE_ID, "X-Service-Id");
        assert_eq!(headers::X_TENANT_ID, "X-Tenant-Id");
        assert_eq!(headers::X_TENANT_NAME, "X-Tenant-Name");
    }
}