// blueprint_auth/types.rs

1use serde::{Deserialize, Serialize};
2use std::collections::BTreeMap;
3
/// Names of the HTTP headers used in the authentication process.
pub mod headers {
    /// Name of the `Authorization` header.
    pub const AUTHORIZATION: &str = "Authorization";
    /// Name of the `X-Service-Id` header.
    pub const X_SERVICE_ID: &str = "X-Service-Id";
    /// Name of the `X-Tenant-Id` header.
    pub const X_TENANT_ID: &str = "X-Tenant-Id";
    /// Name of the `X-Tenant-Name` header.
    pub const X_TENANT_NAME: &str = "X-Tenant-Name";
}
11
12/// Represents the ID a service in the authentication process.
13///
14/// The `ServiceId` is a tuple of two `u64` values, which can be used to uniquely identify a service.
15/// The first `u64` represents the main service ID, while the second `u64` represents a sub-service or a specific instance of the service.
16#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
17pub struct ServiceId(pub u64, pub u64);
18
19impl From<(u64, u64)> for ServiceId {
20    fn from(value: (u64, u64)) -> Self {
21        ServiceId(value.0, value.1)
22    }
23}
24
25impl ServiceId {
26    /// Creates a new `ServiceId` instance with the given main service ID.
27    pub fn new(main: u64) -> Self {
28        ServiceId(main, 0)
29    }
30
31    /// Creates a new `ServiceId` instance with the given main service ID and sub-service ID.
32    pub fn with_subservice(self, sub: u64) -> Self {
33        ServiceId(self.0, sub)
34    }
35
36    /// The main service ID.
37    pub fn id(&self) -> u64 {
38        self.0
39    }
40
41    /// The sub-service ID.
42    pub fn sub_id(&self) -> u64 {
43        self.1
44    }
45
46    /// Checks if the `ServiceId` has a sub-service ID.
47    ///
48    /// Returns `true` if the sub-service ID is not zero, indicating that it is a specific instance of the service.
49    pub fn has_sub_id(&self) -> bool {
50        self.1 != 0
51    }
52
53    /// Converts the `ServiceId` to a big-endian byte array.
54    pub const fn to_be_bytes(&self) -> [u8; 16] {
55        let mut bytes = [0u8; 16];
56        let hi = self.0.to_be_bytes();
57        let lo = self.1.to_be_bytes();
58        let mut i = 0;
59        while i < 8 {
60            bytes[i] = hi[i];
61            bytes[i + 8] = lo[i];
62            i += 1;
63        }
64        bytes
65    }
66
67    /// Creates a `ServiceId` from a big-endian byte array.
68    pub const fn from_be_bytes(bytes: [u8; 16]) -> Self {
69        let mut hi = [0u8; 8];
70        let mut lo = [0u8; 8];
71        let mut i = 0;
72        while i < 8 {
73            hi[i] = bytes[i];
74            lo[i] = bytes[i + 8];
75            i += 1;
76        }
77        ServiceId(u64::from_be_bytes(hi), u64::from_be_bytes(lo))
78    }
79}
80
81#[derive(Debug, Clone, thiserror::Error)]
82pub enum ServiceIdParseError {
83    /// Error parsing the main or sub-service ID as a `u64`.
84    #[error(transparent)]
85    ParseInt(#[from] core::num::ParseIntError),
86    /// Error parsing the `ServiceId` from a string.
87    #[error("Invalid ServiceId format, expected <main_id>[:<sub_id>]")]
88    Malformed,
89}
90
91impl std::str::FromStr for ServiceId {
92    type Err = ServiceIdParseError;
93
94    fn from_str(s: &str) -> Result<Self, Self::Err> {
95        let mut parts = s.split(':');
96        if let Some(main_str) = parts.next() {
97            if let Some(sub_str) = parts.next() {
98                if parts.next().is_none() {
99                    let main = main_str.parse::<u64>()?;
100                    let sub = sub_str.parse::<u64>()?;
101                    return Ok(ServiceId(main, sub));
102                }
103            } else {
104                let main = main_str.parse::<u64>()?;
105                return Ok(ServiceId::new(main));
106            }
107        }
108        Err(ServiceIdParseError::Malformed)
109    }
110}
111
112impl core::fmt::Display for ServiceId {
113    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
114        if self.has_sub_id() {
115            write!(f, "{}:{}", self.0, self.1)
116        } else {
117            write!(f, "{}:0", self.0)
118        }
119    }
120}
121
122impl<S> axum::extract::FromRequestParts<S> for ServiceId
123where
124    S: Send + Sync,
125{
126    type Rejection = axum::response::Response;
127
128    async fn from_request_parts(
129        parts: &mut axum::http::request::Parts,
130        _state: &S,
131    ) -> Result<Self, Self::Rejection> {
132        use axum::http::StatusCode;
133        use axum::response::IntoResponse;
134
135        let header = match parts.headers.get(crate::types::headers::X_SERVICE_ID) {
136            Some(header) => header,
137            None => {
138                return Err((
139                    StatusCode::PRECONDITION_REQUIRED,
140                    "Missing X-Service-Id header",
141                )
142                    .into_response());
143            }
144        };
145
146        let header_str = match header.to_str() {
147            Ok(header_str) => header_str,
148            Err(_) => {
149                return Err((
150                    StatusCode::BAD_REQUEST,
151                    "Invalid X-Service-Id header; not a string",
152                )
153                    .into_response());
154            }
155        };
156
157        match header_str.parse::<ServiceId>() {
158            Ok(service_id) => Ok(service_id),
159            Err(_) => Err((
160                StatusCode::BAD_REQUEST,
161                "Invalid X-Service-Id header; not a valid ServiceId",
162            )
163                .into_response()),
164        }
165    }
166}
167
168/// Represents the different types of cryptographic keys used in the authentication process.
169#[derive(
170    Debug,
171    Clone,
172    Copy,
173    PartialEq,
174    Eq,
175    PartialOrd,
176    Ord,
177    Hash,
178    Serialize,
179    Deserialize,
180    prost::Enumeration,
181)]
182#[repr(i32)]
183pub enum KeyType {
184    Unknown = 0,
185    /// ECDSA (secp256k1) key type
186    Ecdsa = 1,
187    /// Sr25519 (Schnorrkel) key type
188    Sr25519 = 2,
189    /// BN254 BLS key type
190    Bn254Bls = 3,
191}
192
193/// Represents the challenge request sent from the client to the server to request a challenge.
194#[derive(Serialize, Deserialize, Debug, Clone)]
195pub struct ChallengeRequest {
196    /// The public key representing the user in hex format
197    #[serde(with = "hex")]
198    pub pub_key: Vec<u8>,
199    /// The type of the public key
200    pub key_type: KeyType,
201}
202
203/// Represents the challenge response sent from the server to the client after a successful challenge request.
204#[derive(Serialize, Deserialize, Debug, Clone)]
205pub struct ChallengeResponse {
206    /// The challenge string sent from the server to the client to be signed by the user
207    #[serde(with = "hex")]
208    pub challenge: [u8; 32],
209    /// Expires at timestamp in milliseconds since epoch
210    /// the time when the challenge will expire and should not be used anymore
211    pub expires_at: u64,
212}
213
214/// Represents the challenge solution sent from the client to the server after signing the challenge string.
215#[derive(Serialize, Deserialize, Debug, Clone)]
216pub struct VerifyChallengeRequest {
217    /// The original challenge request sent from the server to the client in the first step
218    #[serde(flatten)]
219    pub challenge_request: ChallengeRequest,
220    /// The challenge string sent from the server to the client to be signed by the user
221    #[serde(with = "hex")]
222    pub challenge: [u8; 32],
223    /// The signed challenge string sent from the client to the server
224    #[serde(with = "hex")]
225    pub signature: [u8; 64],
226    /// The timestamp in seconds since epoch at which the token will expire
227    pub expires_at: u64,
228    /// Additional headers to be forwarded to the upstream service
229    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
230    pub additional_headers: BTreeMap<String, String>,
231}
232
233/// Represents the response sent from the server to the client after verifying the challenge solution.
234#[derive(Serialize, Deserialize, Debug, Clone)]
235#[serde(tag = "status", content = "data")]
236pub enum VerifyChallengeResponse {
237    /// The challenge was verified successfully - returns an API key
238    Verified {
239        /// The long-lived API key to be used for token exchange
240        api_key: String,
241        /// A UNIX timestamp in seconds since epoch at which the API key will expire
242        expires_at: u64,
243    },
244    /// The challenge was not verified because the challenge has expired
245    Expired,
246    /// The challenge was not verified because the signature is invalid
247    InvalidSignature,
248
249    /// The challenge was not verified because the service ID is not found
250    ServiceNotFound,
251
252    /// The challenge was not verified because the service ID is not authorized
253    Unauthorized,
254
255    /// An unexpected error occurred during verification
256    UnexpectedError {
257        /// The error message
258        message: String,
259    },
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction through `new`, `with_subservice` and `From<(u64, u64)>`.
    #[test]
    fn test_service_id_creation() {
        // Create with just main ID
        let service_id = ServiceId::new(42);
        assert_eq!(service_id.0, 42);
        assert_eq!(service_id.1, 0);

        // Create with main ID and add subservice
        let service_id = ServiceId::new(42).with_subservice(7);
        assert_eq!(service_id.0, 42);
        assert_eq!(service_id.1, 7);

        // Create from tuple
        let service_id = ServiceId::from((42, 7));
        assert_eq!(service_id.0, 42);
        assert_eq!(service_id.1, 7);
    }

    /// Accessors report the stored fields and whether a sub-service ID is set.
    #[test]
    fn test_service_id_accessors() {
        let service_id = ServiceId(42, 7);

        assert_eq!(service_id.id(), 42);
        assert_eq!(service_id.sub_id(), 7);
        assert!(service_id.has_sub_id());

        let service_id = ServiceId(42, 0);
        assert!(!service_id.has_sub_id());
    }

    /// `to_be_bytes` and `from_be_bytes` are inverses of each other.
    #[test]
    fn test_service_id_bytes_conversion() {
        let service_id = ServiceId(42, 7);

        let bytes = service_id.to_be_bytes();
        assert_eq!(bytes.len(), 16);

        let reconstructed = ServiceId::from_be_bytes(bytes);
        assert_eq!(reconstructed, service_id);

        // Test with different values
        let service_id = ServiceId(0xDEADBEEF, 0xCAFEBABE);
        let bytes = service_id.to_be_bytes();
        let reconstructed = ServiceId::from_be_bytes(bytes);
        assert_eq!(reconstructed, service_id);
    }

    /// `FromStr` accepts `<main>` and `<main>:<sub>` and rejects everything else.
    #[test]
    fn test_service_id_parsing() {
        // Valid formats
        assert_eq!("42".parse::<ServiceId>().unwrap(), ServiceId(42, 0));
        assert_eq!("42:7".parse::<ServiceId>().unwrap(), ServiceId(42, 7));

        // Invalid formats: the empty string is a single, non-numeric component
        assert!(matches!(
            "".parse::<ServiceId>(),
            Err(ServiceIdParseError::ParseInt(_))
        ));
        assert!(matches!(
            "abc".parse::<ServiceId>(),
            Err(ServiceIdParseError::ParseInt(_))
        ));
        assert!(matches!(
            "42:7:9".parse::<ServiceId>(),
            Err(ServiceIdParseError::Malformed)
        ));
        assert!(matches!(
            "42:abc".parse::<ServiceId>(),
            Err(ServiceIdParseError::ParseInt(_))
        ));
    }

    /// `Display` always renders both components.
    #[test]
    fn test_service_id_display() {
        assert_eq!(ServiceId(42, 0).to_string(), "42:0");
        assert_eq!(ServiceId(42, 7).to_string(), "42:7");
    }

    /// `KeyType` converts to and from its `i32` discriminant.
    #[test]
    fn test_key_type_conversion() {
        // Test KeyType to i32 conversion (as used in the ServiceOwnerModel)
        assert_eq!(KeyType::Unknown as i32, 0);
        assert_eq!(KeyType::Ecdsa as i32, 1);
        assert_eq!(KeyType::Sr25519 as i32, 2);
        assert_eq!(KeyType::Bn254Bls as i32, 3);

        // Test i32 to KeyType conversion through the safe `TryFrom<i32>` impl generated by
        // `prost::Enumeration`, instead of an unchecked `transmute`.
        assert_eq!(KeyType::try_from(1i32).ok(), Some(KeyType::Ecdsa));
        assert_eq!(KeyType::try_from(3i32).ok(), Some(KeyType::Bn254Bls));

        // A value that matches no variant must be rejected rather than reinterpreted.
        assert!(KeyType::try_from(42i32).is_err());
    }

    /// The header-name constants keep their expected values.
    #[test]
    fn test_headers_constants() {
        assert_eq!(headers::AUTHORIZATION, "Authorization");
        assert_eq!(headers::X_SERVICE_ID, "X-Service-Id");
        assert_eq!(headers::X_TENANT_ID, "X-Tenant-Id");
        assert_eq!(headers::X_TENANT_NAME, "X-Tenant-Name");
    }
}