use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
/// HTTP header names used by this API's request handling.
pub mod headers {
    // ---- Application-specific `X-` headers ----

    /// Header whose value is parsed as a `ServiceId` (`<main_id>[:<sub_id>]`)
    /// by the `ServiceId` request extractor.
    pub const X_SERVICE_ID: &str = "X-Service-Id";
    /// Name of the `X-Tenant-Id` header.
    pub const X_TENANT_ID: &str = "X-Tenant-Id";
    /// Name of the `X-Tenant-Name` header.
    pub const X_TENANT_NAME: &str = "X-Tenant-Name";

    // ---- Standard HTTP headers ----

    /// Name of the standard `Authorization` header.
    pub const AUTHORIZATION: &str = "Authorization";
}
/// Identifier of a service: a main id (`.0`) plus a sub-service id (`.1`).
///
/// A sub id of `0` means "no sub-service" (see `ServiceId::new` and
/// `ServiceId::has_sub_id`). The derived `Ord` compares the main id first and
/// then the sub id, because that is the field order. The derived serde impls
/// treat it as a two-element tuple struct.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct ServiceId(pub u64, pub u64);
impl From<(u64, u64)> for ServiceId {
fn from(value: (u64, u64)) -> Self {
ServiceId(value.0, value.1)
}
}
impl ServiceId {
pub fn new(main: u64) -> Self {
ServiceId(main, 0)
}
pub fn with_subservice(self, sub: u64) -> Self {
ServiceId(self.0, sub)
}
pub fn id(&self) -> u64 {
self.0
}
pub fn sub_id(&self) -> u64 {
self.1
}
pub fn has_sub_id(&self) -> bool {
self.1 != 0
}
pub const fn to_be_bytes(&self) -> [u8; 16] {
let mut bytes = [0u8; 16];
let hi = self.0.to_be_bytes();
let lo = self.1.to_be_bytes();
let mut i = 0;
while i < 8 {
bytes[i] = hi[i];
bytes[i + 8] = lo[i];
i += 1;
}
bytes
}
pub const fn from_be_bytes(bytes: [u8; 16]) -> Self {
let mut hi = [0u8; 8];
let mut lo = [0u8; 8];
let mut i = 0;
while i < 8 {
hi[i] = bytes[i];
lo[i] = bytes[i + 8];
i += 1;
}
ServiceId(u64::from_be_bytes(hi), u64::from_be_bytes(lo))
}
}
/// Error returned when parsing a `ServiceId` from a string via `FromStr`.
#[derive(Debug, Clone, thiserror::Error)]
pub enum ServiceIdParseError {
/// A component (main or sub id) was not a valid `u64`. This includes an
/// empty component, e.g. `""` or `":7"`.
#[error(transparent)]
ParseInt(#[from] core::num::ParseIntError),
/// The input had more than one `:` separator.
#[error("Invalid ServiceId format, expected <main_id>[:<sub_id>]")]
Malformed,
}
/// Parses `"<main_id>"` or `"<main_id>:<sub_id>"`.
///
/// A bare main id yields a sub id of `0`. More than one `:` is rejected as
/// `Malformed` before any number is parsed. Otherwise the main id is parsed
/// before the sub id, so the first bad component produces the `ParseInt` error.
impl std::str::FromStr for ServiceId {
    type Err = ServiceIdParseError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.split_once(':') {
            None => Ok(ServiceId::new(s.parse::<u64>()?)),
            Some((_, rest)) if rest.contains(':') => Err(ServiceIdParseError::Malformed),
            Some((main, sub)) => Ok(ServiceId(main.parse::<u64>()?, sub.parse::<u64>()?)),
        }
    }
}
impl core::fmt::Display for ServiceId {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
if self.has_sub_id() {
write!(f, "{}:{}", self.0, self.1)
} else {
write!(f, "{}:0", self.0)
}
}
}
impl<S> axum::extract::FromRequestParts<S> for ServiceId
where
S: Send + Sync,
{
type Rejection = axum::response::Response;
async fn from_request_parts(
parts: &mut axum::http::request::Parts,
_state: &S,
) -> Result<Self, Self::Rejection> {
use axum::http::StatusCode;
use axum::response::IntoResponse;
let header = match parts.headers.get(crate::types::headers::X_SERVICE_ID) {
Some(header) => header,
None => {
return Err((
StatusCode::PRECONDITION_REQUIRED,
"Missing X-Service-Id header",
)
.into_response());
}
};
let header_str = match header.to_str() {
Ok(header_str) => header_str,
Err(_) => {
return Err((
StatusCode::BAD_REQUEST,
"Invalid X-Service-Id header; not a string",
)
.into_response());
}
};
match header_str.parse::<ServiceId>() {
Ok(service_id) => Ok(service_id),
Err(_) => Err((
StatusCode::BAD_REQUEST,
"Invalid X-Service-Id header; not a valid ServiceId",
)
.into_response()),
}
}
}
/// Key/signature scheme of a public key, e.g. `ChallengeRequest::key_type`.
///
/// The discriminants are fixed explicitly under `#[repr(i32)]`, and
/// `prost::Enumeration` is derived so the enum can be used as an integer-valued
/// protobuf enum. Note that the derived serde impls encode it by variant name,
/// not by its `i32` value.
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
PartialOrd,
Ord,
Hash,
Serialize,
Deserialize,
prost::Enumeration,
)]
#[repr(i32)]
pub enum KeyType {
/// Discriminant `0`; presumably the "unspecified" value. Confirm how callers treat it.
Unknown = 0,
/// Discriminant `1`.
Ecdsa = 1,
/// Discriminant `2`.
Sr25519 = 2,
/// Discriminant `3`.
Bn254Bls = 3,
}
/// Identifies a public key and its [`KeyType`]; the payload of a challenge request.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChallengeRequest {
/// Raw public-key bytes, (de)serialized as a hex string via the `hex` crate's
/// serde support. The length is not constrained by this type.
#[serde(with = "hex")]
pub pub_key: Vec<u8>,
/// Scheme the key belongs to.
pub key_type: KeyType,
}
/// A 32-byte challenge together with its expiry.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChallengeResponse {
/// The challenge bytes, (de)serialized as a hex string.
#[serde(with = "hex")]
pub challenge: [u8; 32],
/// Expiry of the challenge. NOTE(review): units/epoch are not visible here;
/// presumably a Unix timestamp. Confirm.
pub expires_at: u64,
}
/// Request to verify a signed challenge.
///
/// Because of `#[serde(flatten)]`, the fields of the embedded
/// [`ChallengeRequest`] (`pub_key`, `key_type`) appear at the top level of the
/// serialized object, next to the fields below.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct VerifyChallengeRequest {
#[serde(flatten)]
pub challenge_request: ChallengeRequest,
/// The 32-byte challenge, (de)serialized as a hex string.
#[serde(with = "hex")]
pub challenge: [u8; 32],
/// 64-byte signature, (de)serialized as a hex string.
/// NOTE(review): presumably a signature over `challenge` by `pub_key`. Confirm against the verifier.
#[serde(with = "hex")]
pub signature: [u8; 64],
/// Expiry value, same type as `ChallengeResponse::expires_at`.
pub expires_at: u64,
/// Extra header name/value pairs. This field is omitted when serializing if
/// the map is empty, and defaults to an empty map when absent on
/// deserialization. `BTreeMap` keeps key order deterministic.
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub additional_headers: BTreeMap<String, String>,
}
/// Outcome of handling a [`VerifyChallengeRequest`].
///
/// Serialized as an adjacently tagged enum: the variant name goes in a
/// `"status"` field and any payload in a `"data"` field, e.g.
/// `{"status": "Verified", "data": {"api_key": "...", "expires_at": 0}}`.
/// Unit variants carry no payload.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "status", content = "data")]
pub enum VerifyChallengeResponse {
/// Verification succeeded; carries an `api_key` and its `expires_at`.
Verified {
api_key: String,
expires_at: u64,
},
/// Presumably: the challenge's expiry had passed.
Expired,
/// Presumably: the signature did not verify.
InvalidSignature,
/// Presumably: the referenced service does not exist.
ServiceNotFound,
/// Presumably: the key is not allowed to access the service.
Unauthorized,
/// Catch-all failure with a human-readable `message`.
UnexpectedError {
message: String,
},
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_service_id_creation() {
        let service_id = ServiceId::new(42);
        assert_eq!(service_id.0, 42);
        assert_eq!(service_id.1, 0);
        let service_id = ServiceId::new(42).with_subservice(7);
        assert_eq!(service_id.0, 42);
        assert_eq!(service_id.1, 7);
        let service_id = ServiceId::from((42, 7));
        assert_eq!(service_id.0, 42);
        assert_eq!(service_id.1, 7);
    }

    #[test]
    fn test_service_id_accessors() {
        let service_id = ServiceId(42, 7);
        assert_eq!(service_id.id(), 42);
        assert_eq!(service_id.sub_id(), 7);
        assert!(service_id.has_sub_id());
        let service_id = ServiceId(42, 0);
        assert!(!service_id.has_sub_id());
    }

    #[test]
    fn test_service_id_bytes_conversion() {
        let service_id = ServiceId(42, 7);
        let bytes = service_id.to_be_bytes();
        assert_eq!(bytes.len(), 16);
        let reconstructed = ServiceId::from_be_bytes(bytes);
        assert_eq!(reconstructed, service_id);
        let service_id = ServiceId(0xDEADBEEF, 0xCAFEBABE);
        let bytes = service_id.to_be_bytes();
        let reconstructed = ServiceId::from_be_bytes(bytes);
        assert_eq!(reconstructed, service_id);
        // Pin the exact layout: main id big-endian in bytes 0..8, sub id in 8..16.
        assert_eq!(
            ServiceId(1, 2).to_be_bytes(),
            [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2]
        );
    }

    #[test]
    fn test_service_id_parsing() {
        assert_eq!("42".parse::<ServiceId>().unwrap(), ServiceId(42, 0));
        assert_eq!("42:7".parse::<ServiceId>().unwrap(), ServiceId(42, 7));
        // The form produced by `Display` for ids without a sub-service.
        assert_eq!("42:0".parse::<ServiceId>().unwrap(), ServiceId(42, 0));
        let empty_result = "".parse::<ServiceId>();
        assert!(matches!(empty_result, Err(ServiceIdParseError::ParseInt(_))));
        assert!(matches!(
            "abc".parse::<ServiceId>(),
            Err(ServiceIdParseError::ParseInt(_))
        ));
        assert!(matches!(
            "42:7:9".parse::<ServiceId>(),
            Err(ServiceIdParseError::Malformed)
        ));
        assert!(matches!(
            "42:abc".parse::<ServiceId>(),
            Err(ServiceIdParseError::ParseInt(_))
        ));
        // Empty components are integer-parse failures, not `Malformed`.
        assert!(matches!(
            ":7".parse::<ServiceId>(),
            Err(ServiceIdParseError::ParseInt(_))
        ));
        assert!(matches!(
            "42:".parse::<ServiceId>(),
            Err(ServiceIdParseError::ParseInt(_))
        ));
    }

    #[test]
    fn test_service_id_display() {
        assert_eq!(ServiceId(42, 0).to_string(), "42:0");
        assert_eq!(ServiceId(42, 7).to_string(), "42:7");
    }

    #[test]
    fn test_service_id_round_trips() {
        for id in [
            ServiceId(0, 0),
            ServiceId(42, 0),
            ServiceId(42, 7),
            ServiceId(u64::MAX, u64::MAX),
        ] {
            assert_eq!(id.to_string().parse::<ServiceId>().unwrap(), id);
            assert_eq!(ServiceId::from_be_bytes(id.to_be_bytes()), id);
        }
    }

    #[test]
    fn test_key_type_conversion() {
        assert_eq!(KeyType::Unknown as i32, 0);
        assert_eq!(KeyType::Ecdsa as i32, 1);
        assert_eq!(KeyType::Sr25519 as i32, 2);
        assert_eq!(KeyType::Bn254Bls as i32, 3);
        // SAFETY: `KeyType` is `#[repr(i32)]` and `1` is the explicit
        // discriminant of `KeyType::Ecdsa`, so this bit pattern is a valid value.
        // NOTE(review): prefer prost's safe `KeyType::try_from(i32)` once the prost
        // version in use (>= 0.12) is confirmed.
        let key_type: KeyType = unsafe { std::mem::transmute(1i32) };
        assert_eq!(key_type, KeyType::Ecdsa);
        // SAFETY: `3` is the explicit discriminant of `KeyType::Bn254Bls`.
        let key_type: KeyType = unsafe { std::mem::transmute(3i32) };
        assert_eq!(key_type, KeyType::Bn254Bls);
    }

    #[test]
    fn test_headers_constants() {
        assert_eq!(headers::AUTHORIZATION, "Authorization");
        assert_eq!(headers::X_SERVICE_ID, "X-Service-Id");
        assert_eq!(headers::X_TENANT_ID, "X-Tenant-Id");
        assert_eq!(headers::X_TENANT_NAME, "X-Tenant-Name");
    }
}