blueprint_auth/
types.rs

1use serde::{Deserialize, Serialize};
2use std::collections::BTreeMap;
3
/// Names of the HTTP headers used throughout the authentication process.
pub mod headers {
    /// The standard `Authorization` header.
    pub const AUTHORIZATION: &str = "Authorization";

    /// Header carrying a service ID in `<main_id>[:<sub_id>]` form.
    pub const X_SERVICE_ID: &str = "X-Service-Id";

    /// Header carrying the tenant identifier.
    pub const X_TENANT_ID: &str = "X-Tenant-Id";

    /// Header carrying the tenant name.
    pub const X_TENANT_NAME: &str = "X-Tenant-Name";
}
11
12/// Represents the ID a service in the authentication process.
13///
14/// The `ServiceId` is a tuple of two `u64` values, which can be used to uniquely identify a service.
15/// The first `u64` represents the main service ID, while the second `u64` represents a sub-service or a specific instance of the service.
16#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
17pub struct ServiceId(pub u64, pub u64);
18
19impl From<(u64, u64)> for ServiceId {
20    fn from(value: (u64, u64)) -> Self {
21        ServiceId(value.0, value.1)
22    }
23}
24
25impl ServiceId {
26    /// Creates a new `ServiceId` instance with the given main service ID.
27    pub fn new(main: u64) -> Self {
28        ServiceId(main, 0)
29    }
30
31    /// Creates a new `ServiceId` instance with the given main service ID and sub-service ID.
32    pub fn with_subservice(self, sub: u64) -> Self {
33        ServiceId(self.0, sub)
34    }
35
36    /// The main service ID.
37    pub fn id(&self) -> u64 {
38        self.0
39    }
40
41    /// The sub-service ID.
42    pub fn sub_id(&self) -> u64 {
43        self.1
44    }
45
46    /// Checks if the `ServiceId` has a sub-service ID.
47    ///
48    /// Returns `true` if the sub-service ID is not zero, indicating that it is a specific instance of the service.
49    pub fn has_sub_id(&self) -> bool {
50        self.1 != 0
51    }
52
53    /// Converts the `ServiceId` to a big-endian byte array.
54    pub const fn to_be_bytes(&self) -> [u8; 16] {
55        let mut bytes = [0u8; 16];
56        let hi = self.0.to_be_bytes();
57        let lo = self.1.to_be_bytes();
58        let mut i = 0;
59        while i < 8 {
60            bytes[i] = hi[i];
61            bytes[i + 8] = lo[i];
62            i += 1;
63        }
64        bytes
65    }
66
67    /// Creates a `ServiceId` from a big-endian byte array.
68    pub const fn from_be_bytes(bytes: [u8; 16]) -> Self {
69        let mut hi = [0u8; 8];
70        let mut lo = [0u8; 8];
71        let mut i = 0;
72        while i < 8 {
73            hi[i] = bytes[i];
74            lo[i] = bytes[i + 8];
75            i += 1;
76        }
77        ServiceId(u64::from_be_bytes(hi), u64::from_be_bytes(lo))
78    }
79}
80
81#[derive(Debug, Clone, thiserror::Error)]
82pub enum ServiceIdParseError {
83    /// Error parsing the main or sub-service ID as a `u64`.
84    #[error(transparent)]
85    ParseInt(#[from] core::num::ParseIntError),
86    /// Error parsing the `ServiceId` from a string.
87    #[error("Invalid ServiceId format, expected <main_id>[:<sub_id>]")]
88    Malformed,
89}
90
91impl std::str::FromStr for ServiceId {
92    type Err = ServiceIdParseError;
93
94    fn from_str(s: &str) -> Result<Self, Self::Err> {
95        let mut parts = s.split(':');
96        if let Some(main_str) = parts.next() {
97            if let Some(sub_str) = parts.next() {
98                if parts.next().is_none() {
99                    let main = main_str.parse::<u64>()?;
100                    let sub = sub_str.parse::<u64>()?;
101                    return Ok(ServiceId(main, sub));
102                }
103            } else {
104                let main = main_str.parse::<u64>()?;
105                return Ok(ServiceId::new(main));
106            }
107        }
108        Err(ServiceIdParseError::Malformed)
109    }
110}
111
112impl core::fmt::Display for ServiceId {
113    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
114        if self.has_sub_id() {
115            write!(f, "{}:{}", self.0, self.1)
116        } else {
117            write!(f, "{}:0", self.0)
118        }
119    }
120}
121
122impl<S> axum::extract::FromRequestParts<S> for ServiceId
123where
124    S: Send + Sync,
125{
126    type Rejection = axum::response::Response;
127
128    async fn from_request_parts(
129        parts: &mut axum::http::request::Parts,
130        _state: &S,
131    ) -> Result<Self, Self::Rejection> {
132        use axum::http::StatusCode;
133        use axum::response::IntoResponse;
134
135        let header = match parts.headers.get(crate::types::headers::X_SERVICE_ID) {
136            Some(header) => header,
137            None => {
138                return Err((
139                    StatusCode::PRECONDITION_REQUIRED,
140                    "Missing X-Service-Id header",
141                )
142                    .into_response());
143            }
144        };
145
146        let header_str = match header.to_str() {
147            Ok(header_str) => header_str,
148            Err(_) => {
149                return Err((
150                    StatusCode::BAD_REQUEST,
151                    "Invalid X-Service-Id header; not a string",
152                )
153                    .into_response());
154            }
155        };
156
157        match header_str.parse::<ServiceId>() {
158            Ok(service_id) => Ok(service_id),
159            Err(_) => Err((
160                StatusCode::BAD_REQUEST,
161                "Invalid X-Service-Id header; not a valid ServiceId",
162            )
163                .into_response()),
164        }
165    }
166}
167
168/// Represents the different types of cryptographic keys used in the authentication process.
169#[derive(
170    Debug,
171    Clone,
172    Copy,
173    PartialEq,
174    Eq,
175    PartialOrd,
176    Ord,
177    Hash,
178    Serialize,
179    Deserialize,
180    prost::Enumeration,
181)]
182#[repr(i32)]
183pub enum KeyType {
184    Unknown = 0,
185    /// Ecdsa key type
186    Ecdsa = 1,
187    /// Sr25519 key type
188    Sr25519 = 2,
189}
190
/// Represents the challenge request sent from the client to the server to request a challenge.
///
/// Field order and serde attributes define the serialized format; change them with care.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChallengeRequest {
    /// The public key representing the user, serialized as a hex string
    /// (via `#[serde(with = "hex")]`).
    #[serde(with = "hex")]
    pub pub_key: Vec<u8>,
    /// The type of the public key
    pub key_type: KeyType,
}
200
/// Represents the challenge response sent from the server to the client after a successful challenge request.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChallengeResponse {
    /// The 32-byte challenge the server asks the user to sign, serialized as a hex string
    /// (via `#[serde(with = "hex")]`).
    #[serde(with = "hex")]
    pub challenge: [u8; 32],
    /// Expires at timestamp in milliseconds since epoch:
    /// the time when the challenge will expire and should not be used anymore.
    // NOTE(review): the other `expires_at` fields in this module are documented in seconds,
    // not milliseconds — confirm which unit the server actually emits here.
    pub expires_at: u64,
}
211
/// Represents the challenge solution sent from the client to the server after signing the challenge.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct VerifyChallengeRequest {
    /// The original challenge request (public key and key type) that the client sent to the
    /// server in the first step. It is `#[serde(flatten)]`-ed, so its fields appear at the top
    /// level of this object rather than nested.
    #[serde(flatten)]
    pub challenge_request: ChallengeRequest,
    /// The 32-byte challenge issued by the server and signed by the user, serialized as a
    /// hex string.
    #[serde(with = "hex")]
    pub challenge: [u8; 32],
    /// The 64-byte signature over the challenge produced by the client, serialized as a hex
    /// string.
    #[serde(with = "hex")]
    pub signature: [u8; 64],
    /// The timestamp in seconds since epoch at which the token will expire
    pub expires_at: u64,
    /// Additional headers to be forwarded to the upstream service.
    /// Defaults to empty when absent from the input, and is omitted from the output when empty.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub additional_headers: BTreeMap<String, String>,
}
230
/// Represents the response sent from the server to the client after verifying the challenge solution.
///
/// Serialized with serde's adjacently tagged representation (`tag = "status"`,
/// `content = "data"`): the variant name goes under the `status` key and the variant's
/// fields, if any, go under the `data` key.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "status", content = "data")]
pub enum VerifyChallengeResponse {
    /// The challenge was verified successfully - returns an API key
    Verified {
        /// The long-lived API key to be used for token exchange
        api_key: String,
        /// A UNIX timestamp in seconds since epoch at which the API key will expire
        expires_at: u64,
    },
    /// The challenge was not verified because the challenge has expired
    Expired,
    /// The challenge was not verified because the signature is invalid
    InvalidSignature,

    /// The challenge was not verified because the service ID is not found
    ServiceNotFound,

    /// The challenge was not verified because the service ID is not authorized
    Unauthorized,

    /// An unexpected error occurred during verification
    UnexpectedError {
        /// The error message
        message: String,
    },
}
259
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_service_id_creation() {
        // Create with just main ID
        let service_id = ServiceId::new(42);
        assert_eq!(service_id.0, 42);
        assert_eq!(service_id.1, 0);

        // Create with main ID and add subservice
        let service_id = ServiceId::new(42).with_subservice(7);
        assert_eq!(service_id.0, 42);
        assert_eq!(service_id.1, 7);

        // Create from tuple
        let service_id = ServiceId::from((42, 7));
        assert_eq!(service_id.0, 42);
        assert_eq!(service_id.1, 7);
    }

    #[test]
    fn test_service_id_accessors() {
        let service_id = ServiceId(42, 7);

        assert_eq!(service_id.id(), 42);
        assert_eq!(service_id.sub_id(), 7);
        assert!(service_id.has_sub_id());

        let service_id = ServiceId(42, 0);
        assert!(!service_id.has_sub_id());
    }

    #[test]
    fn test_service_id_bytes_conversion() {
        let service_id = ServiceId(42, 7);

        let bytes = service_id.to_be_bytes();
        assert_eq!(bytes.len(), 16);
        // Main ID in bytes 0..8, sub ID in bytes 8..16, both big-endian.
        assert_eq!(bytes, [0, 0, 0, 0, 0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 7]);

        let reconstructed = ServiceId::from_be_bytes(bytes);
        assert_eq!(reconstructed, service_id);

        // Test with different values
        let service_id = ServiceId(0xDEADBEEF, 0xCAFEBABE);
        let bytes = service_id.to_be_bytes();
        let reconstructed = ServiceId::from_be_bytes(bytes);
        assert_eq!(reconstructed, service_id);

        // Extremes survive the round trip.
        let service_id = ServiceId(u64::MAX, 0);
        assert_eq!(ServiceId::from_be_bytes(service_id.to_be_bytes()), service_id);
    }

    #[test]
    fn test_service_id_parsing() {
        // Valid formats
        assert_eq!("42".parse::<ServiceId>().unwrap(), ServiceId(42, 0));
        assert_eq!("42:7".parse::<ServiceId>().unwrap(), ServiceId(42, 7));
        assert_eq!(
            "18446744073709551615:18446744073709551615"
                .parse::<ServiceId>()
                .unwrap(),
            ServiceId(u64::MAX, u64::MAX)
        );

        // Invalid formats: empty components fail integer parsing.
        assert!(matches!(
            "".parse::<ServiceId>(),
            Err(ServiceIdParseError::ParseInt(_))
        ));
        assert!(matches!(
            "42:".parse::<ServiceId>(),
            Err(ServiceIdParseError::ParseInt(_))
        ));
        assert!(matches!(
            ":7".parse::<ServiceId>(),
            Err(ServiceIdParseError::ParseInt(_))
        ));

        assert!(matches!(
            "abc".parse::<ServiceId>(),
            Err(ServiceIdParseError::ParseInt(_))
        ));
        assert!(matches!(
            "42:7:9".parse::<ServiceId>(),
            Err(ServiceIdParseError::Malformed)
        ));
        assert!(matches!(
            "42:abc".parse::<ServiceId>(),
            Err(ServiceIdParseError::ParseInt(_))
        ));

        // One past u64::MAX overflows.
        assert!(matches!(
            "18446744073709551616".parse::<ServiceId>(),
            Err(ServiceIdParseError::ParseInt(_))
        ));
    }

    #[test]
    fn test_service_id_display() {
        assert_eq!(ServiceId(42, 0).to_string(), "42:0");
        assert_eq!(ServiceId(42, 7).to_string(), "42:7");
    }

    #[test]
    fn test_service_id_display_parse_round_trip() {
        for service_id in [
            ServiceId(0, 0),
            ServiceId(42, 0),
            ServiceId(42, 7),
            ServiceId(u64::MAX, u64::MAX),
        ] {
            assert_eq!(
                service_id.to_string().parse::<ServiceId>().unwrap(),
                service_id
            );
        }
    }

    #[test]
    fn test_key_type_conversion() {
        // Test KeyType to i32 conversion (as used in the ServiceOwnerModel)
        assert_eq!(KeyType::Unknown as i32, 0);
        assert_eq!(KeyType::Ecdsa as i32, 1);
        assert_eq!(KeyType::Sr25519 as i32, 2);

        // Test i32 to KeyType conversion (using transmute for simplicity in tests)
        // SAFETY: `KeyType` is `#[repr(i32)]` and `1` is the explicit discriminant of
        // `KeyType::Ecdsa`, so this bit pattern is a valid `KeyType` value.
        let key_type: KeyType = unsafe { std::mem::transmute(1i32) };
        assert_eq!(key_type, KeyType::Ecdsa);
    }

    #[test]
    fn test_headers_constants() {
        assert_eq!(headers::AUTHORIZATION, "Authorization");
        assert_eq!(headers::X_SERVICE_ID, "X-Service-Id");
        assert_eq!(headers::X_TENANT_ID, "X-Tenant-Id");
        assert_eq!(headers::X_TENANT_NAME, "X-Tenant-Name");
    }
}