use std::io::Cursor;
use sha2::{Digest, Sha256};
use tiny_http::{Header, Request, Response, StatusCode};
use crate::types::{json_header, ErrorResponse, RecallFileSection};
/// Builds a `401 Unauthorized` response whose body is the JSON-serialized
/// [`ErrorResponse`] carrying `error_msg`.
///
/// Headers, in order: the CORS headers for `req` (from
/// [`ReqCtx::cors_headers_for`]), a `WWW-Authenticate: Bearer realm="uteke"`
/// challenge, and the header produced by `json_header()`.
pub fn unauthorized_response(
    ctx: &ReqCtx,
    req: &Request,
    error_msg: &str,
) -> Response<Cursor<Vec<u8>>> {
    let challenge = Header::from_bytes("WWW-Authenticate", "Bearer realm=\"uteke\"").unwrap();

    let mut headers = ctx.cors_headers_for(req);
    headers.push(challenge);
    headers.push(json_header());

    let payload = ErrorResponse {
        error: error_msg.to_owned(),
    };
    let bytes = serde_json::to_string(&payload).unwrap().into_bytes();

    Response::new(StatusCode::from(401), headers, Cursor::new(bytes), None, None)
}
/// Authenticates `req` against the bearer tokens configured in `ctx`.
///
/// Returns `Ok(AuthResult::Disabled)` when neither an admin nor a read-only
/// token hash is configured. Otherwise the request must carry an
/// `Authorization` header made of exactly two whitespace-separated fields: the
/// scheme `Bearer` (compared ASCII case-insensitively) and the token. The
/// token's SHA-256 digest is compared with [`constant_time_eq_digest`], first
/// against the admin hash and then against the read-only hash. The first match
/// determines the [`ApiRole`].
///
/// # Errors
///
/// Returns a ready-to-send 401 response (built by [`unauthorized_response`])
/// when the header is missing, malformed, uses another scheme, or its token
/// matches no configured hash.
pub fn check_auth(req: &Request, ctx: &ReqCtx) -> Result<AuthResult, Response<Cursor<Vec<u8>>>> {
    let auth_configured = ctx.auth_token_hash.is_some() || ctx.read_only_token_hash.is_some();
    if !auth_configured {
        return Ok(AuthResult::Disabled);
    }

    let header = match req
        .headers()
        .iter()
        .find(|h| h.field.equiv("Authorization"))
    {
        Some(header) => header,
        None => {
            return Err(unauthorized_response(
                ctx,
                req,
                "Authentication required. Provide Authorization: Bearer ***",
            ))
        }
    };

    // Exactly two fields are accepted: the scheme and the token.
    let mut fields = header.value.as_str().split_whitespace();
    let (scheme, token) = match (fields.next(), fields.next(), fields.next()) {
        (Some(scheme), Some(token), None) => (scheme, token),
        _ => {
            return Err(unauthorized_response(
                ctx,
                req,
                "Invalid auth header format. Use: Authorization: Bearer ***",
            ))
        }
    };
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return Err(unauthorized_response(
            ctx,
            req,
            "Invalid auth scheme. Use: Authorization: Bearer ***",
        ));
    }

    let digest: [u8; 32] = Sha256::digest(token.as_bytes()).into();

    // Admin is tried first, so a token configured for both roles resolves to Admin.
    let candidates = [
        (ctx.auth_token_hash.as_ref(), ApiRole::Admin),
        (ctx.read_only_token_hash.as_ref(), ApiRole::ReadOnly),
    ];
    candidates
        .iter()
        .find_map(|&(expected, role)| {
            expected
                .filter(|stored| constant_time_eq_digest(&digest, stored))
                .map(|_| role)
        })
        .map(AuthResult::Authenticated)
        .ok_or_else(|| unauthorized_response(ctx, req, "Invalid or expired token"))
}
/// Returns `true` when the two 32-byte digests are identical.
///
/// Every byte pair is XOR-ed and the results are OR-accumulated, so the source
/// contains no data-dependent early exit. This is intended to avoid leaking,
/// through timing, how many leading bytes matched. (The source cannot by
/// itself guarantee the optimizer keeps this shape.)
pub fn constant_time_eq_digest(a: &[u8; 32], b: &[u8; 32]) -> bool {
    let difference = a
        .iter()
        .zip(b.iter())
        .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
    difference == 0
}
/// Access level granted to an authenticated caller, determined by which
/// configured token the caller's bearer token matched in `check_auth`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApiRole {
    /// The token matched `ReqCtx::auth_token_hash`.
    Admin,
    /// The token matched `ReqCtx::read_only_token_hash`.
    ReadOnly,
}
pub enum AuthResult {
Disabled,
Authenticated(ApiRole),
}
/// Request-handling context: authentication token digests, the CORS origin
/// allow-list, and recall configuration.
#[derive(Clone)]
pub struct ReqCtx {
    // SHA-256 digest of the admin bearer token; a match in `check_auth` grants
    // `ApiRole::Admin`. When this and `read_only_token_hash` are both `None`,
    // `check_auth` returns `AuthResult::Disabled`.
    pub auth_token_hash: Option<[u8; 32]>,
    // SHA-256 digest of the read-only bearer token; a match in `check_auth`
    // grants `ApiRole::ReadOnly` (checked after the admin digest).
    pub read_only_token_hash: Option<[u8; 32]>,
    // Allowed CORS origins. Empty means any origin (`*`). Otherwise the request's
    // `Origin` header is echoed back only on an exact string match.
    pub cors_origins: Vec<String>,
    // Optional recall settings; not read by the code visible in this file.
    pub recall_config: Option<RecallFileSection>,
}
impl ReqCtx {
    /// Methods advertised in `Access-Control-Allow-Methods`.
    const CORS_ALLOWED_METHODS: &'static str = "GET, POST, DELETE, OPTIONS";

    /// Whether `Authorization` may be listed in `Access-Control-Allow-Headers`.
    ///
    /// It is withheld when authentication is enabled while CORS is wide open
    /// (`cors_origins` empty), so authenticated cross-origin browser use needs an
    /// explicit origin allow-list. Either token kind counts as "authentication
    /// enabled", matching the condition `check_auth` uses. (Previously only the
    /// admin token was considered.)
    fn cors_allows_authorization_header(&self) -> bool {
        let auth_enabled = self.auth_token_hash.is_some() || self.read_only_token_hash.is_some();
        !(auth_enabled && self.cors_origins.is_empty())
    }

    /// Assembles the CORS header set for an already-resolved, non-empty `origin`.
    ///
    /// When a specific origin (not `*`) is echoed back, the response differs per
    /// requesting origin. `Vary: Origin` is therefore appended so shared HTTP
    /// caches do not serve one origin's response to another.
    fn build_cors_headers(origin: String, allow_headers: &str) -> Vec<Header> {
        let echoes_request_origin = origin != "*";
        let mut headers = vec![
            Header::from_bytes("Access-Control-Allow-Origin", origin).unwrap(),
            Header::from_bytes("Access-Control-Allow-Methods", Self::CORS_ALLOWED_METHODS)
                .unwrap(),
            Header::from_bytes("Access-Control-Allow-Headers", allow_headers).unwrap(),
        ];
        if echoes_request_origin {
            headers.push(Header::from_bytes("Vary", "Origin").unwrap());
        }
        headers
    }

    /// Chooses the value for `Access-Control-Allow-Origin`.
    ///
    /// Returns `"*"` when no allow-list is configured. Returns the request's
    /// `Origin` header when it exactly matches an allow-listed origin. Otherwise
    /// returns an empty string, which callers treat as "send no CORS headers".
    pub fn resolve_origin_for(&self, req: &Request) -> String {
        if self.cors_origins.is_empty() {
            return "*".to_string();
        }
        if let Some(origin_header) = req.headers().iter().find(|h| h.field.equiv("Origin")) {
            let origin = origin_header.value.as_str();
            if self.cors_origins.iter().any(|o| o == origin) {
                return origin.to_string();
            }
        }
        String::new()
    }

    /// CORS headers for a regular (non-preflight) response.
    ///
    /// Empty when the request's origin is not allowed.
    pub fn cors_headers_for(&self, req: &Request) -> Vec<Header> {
        let origin = self.resolve_origin_for(req);
        if origin.is_empty() {
            return vec![];
        }
        let allowed_headers = if self.cors_allows_authorization_header() {
            "Content-Type, Authorization"
        } else {
            "Content-Type"
        };
        Self::build_cors_headers(origin, allowed_headers)
    }

    /// CORS headers for an `OPTIONS` preflight response.
    ///
    /// Empty when the request's origin is not allowed. `Access-Control-Allow-Headers`
    /// echoes back only those entries of the request's
    /// `Access-Control-Request-Headers` that appear (ASCII case-insensitively) in
    /// the server's allow-list. Its value is empty when the request lists none.
    pub fn preflight_headers(&self, req: &Request) -> Vec<Header> {
        let origin = self.resolve_origin_for(req);
        if origin.is_empty() {
            return vec![];
        }
        let allowed_headers_set: &[&str] = if self.cors_allows_authorization_header() {
            &[
                "Content-Type",
                "Authorization",
                "Accept",
                "X-Requested-With",
            ]
        } else {
            &["Content-Type", "Accept", "X-Requested-With"]
        };
        let allow_headers = req
            .headers()
            .iter()
            .find(|h| h.field.equiv("Access-Control-Request-Headers"))
            .map(|h| {
                h.value
                    .as_str()
                    .split(',')
                    .map(|s| s.trim())
                    .filter(|s| {
                        allowed_headers_set
                            .iter()
                            .any(|a| a.eq_ignore_ascii_case(s))
                    })
                    .collect::<Vec<_>>()
                    .join(", ")
            })
            .unwrap_or_default();
        Self::build_cors_headers(origin, &allow_headers)
    }

    /// JSON error response with the given HTTP `status`.
    ///
    /// The body is the serialized `ErrorResponse { error: msg }`. Headers are the
    /// request's CORS headers followed by `json_header()`.
    pub fn error_response_for(
        &self,
        req: &Request,
        status: u16,
        msg: impl Into<String>,
    ) -> Response<Cursor<Vec<u8>>> {
        let body = ErrorResponse { error: msg.into() };
        let data = serde_json::to_string(&body).unwrap();
        let mut headers = self.cors_headers_for(req);
        headers.push(json_header());
        Response::new(
            StatusCode::from(status),
            headers,
            Cursor::new(data.into_bytes()),
            None,
            None,
        )
    }

    /// `200 OK` JSON response carrying `body`.
    ///
    /// If `body` cannot be serialized, this is a server-side fault. It is
    /// reported as a `500` error response rather than an error document sent
    /// with a `200` status.
    pub fn ok_response_for<T: serde::Serialize>(
        &self,
        req: &Request,
        body: &T,
    ) -> Response<Cursor<Vec<u8>>> {
        let data = match serde_json::to_string(body) {
            Ok(data) => data,
            Err(e) => return self.error_response_for(req, 500, e.to_string()),
        };
        let mut headers = self.cors_headers_for(req);
        headers.push(json_header());
        Response::new(
            StatusCode::from(200),
            headers,
            Cursor::new(data.into_bytes()),
            None,
            None,
        )
    }
}