use serde::{Deserialize, Serialize};
use crate::pagination::PaginationLinks;
use crate::rate_limit::{RateLimitHeaders, RateLimitResource, RateLimitState};
use crate::token::TokenScope;
use crate::user::UserId;
/// Identity and authorization information for the current request.
///
/// Built with [`AuthContext::anonymous`] or [`AuthContext::authenticated`].
/// NOTE(review): `authenticated` mirrors `user_id.is_some()` for contexts built
/// through those constructors, but the fields are public, so the two can drift
/// if a context is assembled or mutated by hand.
#[derive(Debug, Clone)]
pub struct AuthContext {
    /// Id of the authenticated user; `None` for anonymous contexts.
    pub user_id: Option<UserId>,
    /// Name of the authenticated user; `None` for anonymous contexts.
    pub username: Option<String>,
    /// Token scopes held by the caller; empty for anonymous contexts.
    pub scopes: Vec<TokenScope>,
    /// `true` only for contexts built via `authenticated`; `has_scope` checks it first.
    pub authenticated: bool,
    /// Client IP; used as the rate-limit key when there is no `user_id`.
    pub client_ip: String,
}
impl AuthContext {
    /// Creates the context for a request without credentials.
    ///
    /// Only the client IP is recorded. There is no user and no scopes, so
    /// [`has_scope`](Self::has_scope) always returns `false`.
    pub fn anonymous(client_ip: String) -> Self {
        Self {
            client_ip,
            authenticated: false,
            scopes: Vec::new(),
            username: None,
            user_id: None,
        }
    }

    /// Creates the context for a request authenticated as `user_id` / `username`,
    /// holding the given token `scopes`.
    pub fn authenticated(
        user_id: UserId,
        username: String,
        scopes: Vec<TokenScope>,
        client_ip: String,
    ) -> Self {
        Self {
            client_ip,
            authenticated: true,
            scopes,
            username: Some(username),
            user_id: Some(user_id),
        }
    }

    /// Reports whether the caller may act with `scope`.
    ///
    /// Unauthenticated contexts never hold any scope. Holding
    /// `TokenScope::Admin` grants every scope.
    pub fn has_scope(&self, scope: TokenScope) -> bool {
        self.authenticated
            && (self.scopes.contains(&TokenScope::Admin) || self.scopes.contains(&scope))
    }

    /// Key under which this caller is rate limited.
    ///
    /// Returns `user:<id>` when a user id is known, otherwise `ip:<client_ip>`.
    pub fn rate_limit_key(&self) -> String {
        match self.user_id {
            Some(id) => format!("user:{}", id),
            None => format!("ip:{}", self.client_ip),
        }
    }
}
/// Serializable error body returned by the API.
///
/// Optional fields are left out of the serialized output when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorResponse {
    /// Human-readable description of the error.
    pub message: String,
    /// Link to documentation about the error, if one applies.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub documentation_url: Option<String>,
    /// Individual field failures; set by `ErrorResponse::validation`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub errors: Option<Vec<ValidationError>>,
}
impl ErrorResponse {
    /// Builds an error that carries only a message.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self {
            message,
            documentation_url: None,
            errors: None,
        }
    }

    /// Builds an error whose body also links to documentation at `docs_url`.
    pub fn with_docs(message: impl Into<String>, docs_url: impl Into<String>) -> Self {
        let base = Self::new(message);
        Self {
            documentation_url: Some(docs_url.into()),
            ..base
        }
    }

    /// Builds an error listing individual field validation failures.
    pub fn validation(message: impl Into<String>, errors: Vec<ValidationError>) -> Self {
        let base = Self::new(message);
        Self {
            errors: Some(errors),
            ..base
        }
    }

    /// Error with the message `"Not Found"`.
    pub fn not_found() -> Self {
        Self::new("Not Found")
    }

    /// Error with the message `"Bad credentials"`.
    pub fn bad_credentials() -> Self {
        Self::new("Bad credentials")
    }

    /// Error with the message `"Forbidden"`.
    pub fn forbidden() -> Self {
        Self::new("Forbidden")
    }

    /// Rate-limit error that states when the limit resets and links to the
    /// rate-limiting documentation.
    ///
    /// `reset` is interpolated into the message as-is. It is presumably a
    /// Unix timestamp; confirm against the rate limiter.
    pub fn rate_limited(reset: u64) -> Self {
        let message = format!("API rate limit exceeded. Rate limit will reset at {}", reset);
        Self::with_docs(message, "https://docs.guts.network/rest/rate-limiting")
    }
}
/// One validation failure: which field of which resource was rejected, and why.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationError {
    /// Resource type being validated (e.g. `"User"`).
    pub resource: String,
    /// Name of the rejected field.
    pub field: String,
    /// Machine-readable failure category.
    pub code: ValidationErrorCode,
    /// Optional human-readable detail; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub message: Option<String>,
}
impl ValidationError {
    /// Creates a validation error with no human-readable message.
    pub fn new(
        resource: impl Into<String>,
        field: impl Into<String>,
        code: ValidationErrorCode,
    ) -> Self {
        let resource = resource.into();
        let field = field.into();
        Self {
            resource,
            field,
            code,
            message: None,
        }
    }

    /// Creates a validation error that also carries an explanatory `message`.
    pub fn with_message(
        resource: impl Into<String>,
        field: impl Into<String>,
        code: ValidationErrorCode,
        message: impl Into<String>,
    ) -> Self {
        let mut error = Self::new(resource, field, code);
        error.message = Some(message.into());
        error
    }
}
/// Machine-readable category of a [`ValidationError`].
///
/// Serialized and deserialized with `snake_case` names, e.g. `MissingField`
/// as `"missing_field"` and `AlreadyExists` as `"already_exists"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ValidationErrorCode {
    Missing,
    MissingField,
    Invalid,
    AlreadyExists,
    NotUnique,
    TooLong,
    TooShort,
    /// NOTE(review): presumably paired with `ValidationError::message` to explain
    /// the failure; confirm with callers.
    Custom,
}
/// Optional headers to attach to an API response, filled in via the
/// builder-style `with_*` methods.
///
/// Each field holds one header's value. `None` presumably means the header is
/// not emitted; the code that writes these out is not in this file.
#[derive(Debug, Clone, Default)]
pub struct ResponseHeaders {
    /// Rate-limit header values, built from a `RateLimitState`.
    pub rate_limit: Option<RateLimitHeaders>,
    /// `Link` header value built from pagination links.
    pub link: Option<String>,
    /// `ETag` value; `with_etag` stores it already wrapped in double quotes.
    pub etag: Option<String>,
    /// `Last-Modified` value; no builder method shown here sets it.
    pub last_modified: Option<String>,
    /// `Cache-Control` value.
    pub cache_control: Option<String>,
}
impl ResponseHeaders {
    /// Creates a header set with every header unset.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the rate-limit headers from the given limiter state.
    pub fn with_rate_limit(self, state: &RateLimitState) -> Self {
        Self {
            rate_limit: Some(RateLimitHeaders::from(state)),
            ..self
        }
    }

    /// Sets the `Link` header from pagination links.
    ///
    /// Any previous value is overwritten, even when `links` produces no header value.
    pub fn with_pagination(self, links: &PaginationLinks) -> Self {
        Self {
            link: links.to_header_value(),
            ..self
        }
    }

    /// Sets the `ETag` header, wrapping `etag` in double quotes.
    ///
    /// Quoting is unconditional. Pass the bare tag, not an already-quoted or
    /// `W/`-prefixed value.
    pub fn with_etag(self, etag: impl Into<String>) -> Self {
        let quoted = format!("\"{}\"", etag.into());
        Self {
            etag: Some(quoted),
            ..self
        }
    }

    /// Sets the `Cache-Control` header to an arbitrary value.
    pub fn with_cache_control(self, value: impl Into<String>) -> Self {
        Self {
            cache_control: Some(value.into()),
            ..self
        }
    }

    /// Sets `Cache-Control: private, max-age=60, s-maxage=60`.
    ///
    /// NOTE(review): despite the name, this does not send `no-cache`. It allows
    /// caching for 60 seconds. Left unchanged because callers may rely on it.
    pub fn no_cache(self) -> Self {
        self.with_cache_control("private, max-age=60, s-maxage=60")
    }
}
/// Parses an HTTP `Authorization` header value into an [`AuthorizationValue`].
///
/// The scheme is matched case-insensitively, as RFC 7235 §2.1 requires. It may
/// be separated from the credentials by any ASCII whitespace:
/// - `Bearer <token>` gives [`AuthorizationValue::Bearer`]
/// - `token <token>` gives [`AuthorizationValue::Token`]
/// - `Basic <base64("user:pass")>` gives [`AuthorizationValue::Basic`]
///
/// Returns `None` in three cases: the scheme is unknown, the credentials are
/// missing, or a `Basic` payload does not decode to UTF-8 `user:pass`.
pub fn parse_authorization_header(header: &str) -> Option<AuthorizationValue> {
    let (scheme, credentials) = header
        .trim()
        .split_once(|c: char| c.is_ascii_whitespace())?;
    let credentials = credentials.trim();
    if credentials.is_empty() {
        return None;
    }

    if scheme.eq_ignore_ascii_case("Bearer") {
        Some(AuthorizationValue::Bearer(credentials.to_string()))
    } else if scheme.eq_ignore_ascii_case("token") {
        Some(AuthorizationValue::Token(credentials.to_string()))
    } else if scheme.eq_ignore_ascii_case("Basic") {
        let (username, password) = decode_basic_auth(credentials)?;
        Some(AuthorizationValue::Basic { username, password })
    } else {
        None
    }
}
/// A credential extracted from an `Authorization` header.
///
/// `Debug` is implemented by hand and redacts the secret parts (tokens and
/// passwords), so logging a value with `{:?}` cannot leak credentials. Only the
/// scheme and, for `Basic`, the username are shown.
#[derive(Clone)]
pub enum AuthorizationValue {
    /// Token supplied with the `Bearer` scheme.
    Bearer(String),
    /// Token supplied with the `token` scheme.
    Token(String),
    /// Decoded `Basic` credentials.
    Basic { username: String, password: String },
}

impl std::fmt::Debug for AuthorizationValue {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the secret itself; the scheme is enough for diagnostics.
        const REDACTED: &str = "[REDACTED]";
        match self {
            Self::Bearer(_) => f.debug_tuple("Bearer").field(&REDACTED).finish(),
            Self::Token(_) => f.debug_tuple("Token").field(&REDACTED).finish(),
            Self::Basic { username, .. } => f
                .debug_struct("Basic")
                .field("username", username)
                .field("password", &REDACTED)
                .finish(),
        }
    }
}
impl AuthorizationValue {
    /// Returns the API token carried by this credential, if any.
    ///
    /// `Bearer` and `Token` always carry one. For `Basic`, the password counts
    /// as a token only when it starts with `guts_`; any other password yields `None`.
    pub fn token(&self) -> Option<&str> {
        match self {
            Self::Bearer(token) | Self::Token(token) => Some(token.as_str()),
            Self::Basic { password, .. } => {
                Some(password.as_str()).filter(|p| p.starts_with("guts_"))
            }
        }
    }

    /// Returns the username supplied with `Basic` credentials, or `None` for
    /// the token-only schemes.
    pub fn username(&self) -> Option<&str> {
        match self {
            Self::Basic { username, .. } => Some(username.as_str()),
            Self::Bearer(_) | Self::Token(_) => None,
        }
    }
}
/// Decodes a `Basic` credential payload, i.e. `base64("username:password")`.
///
/// Splits at the first `:`, so the password may itself contain colons.
/// Returns `None` in three cases: `base64_decode` rejects the payload, the
/// bytes are not UTF-8, or there is no `:`.
fn decode_basic_auth(encoded: &str) -> Option<(String, String)> {
    let raw = base64_decode(encoded)?;
    let credentials = std::str::from_utf8(&raw).ok()?;
    credentials
        .split_once(':')
        .map(|(user, pass)| (user.to_owned(), pass.to_owned()))
}
/// Decodes standard, padded base64 (RFC 4648 §4 alphabet: `A-Z a-z 0-9 + /`).
///
/// Surrounding whitespace is ignored. Returns `None` when any of these hold:
/// - the input is empty;
/// - its length is not a multiple of four;
/// - it contains a byte outside the alphabet;
/// - `=` padding appears anywhere other than the last one or two positions of
///   the final four-character group.
fn base64_decode(input: &str) -> Option<Vec<u8>> {
    /// Maps one alphabet character to its 6-bit value; `=` is handled by the caller.
    fn sextet(c: u8) -> Option<u8> {
        match c {
            b'A'..=b'Z' => Some(c - b'A'),
            b'a'..=b'z' => Some(c - b'a' + 26),
            b'0'..=b'9' => Some(c - b'0' + 52),
            b'+' => Some(62),
            b'/' => Some(63),
            _ => None,
        }
    }

    let input = input.trim().as_bytes();
    if input.is_empty() {
        return None;
    }
    let groups = input.chunks_exact(4);
    if !groups.remainder().is_empty() {
        return None;
    }
    let last_group = groups.len() - 1;
    let mut out = Vec::with_capacity(input.len() / 4 * 3);
    for (index, group) in groups.enumerate() {
        // Padding is only legal as "xx==" or "xxx=" in the final group. A stray
        // '=' anywhere else fails `sextet` below instead of decoding as zero.
        let padding = match (group[2], group[3]) {
            (b'=', b'=') => 2,
            (_, b'=') => 1,
            _ => 0,
        };
        if padding > 0 && index != last_group {
            return None;
        }
        let a = sextet(group[0])?;
        let b = sextet(group[1])?;
        out.push((a << 2) | (b >> 4));
        if padding == 2 {
            break;
        }
        let c = sextet(group[2])?;
        out.push((b << 4) | (c >> 2));
        if padding == 1 {
            break;
        }
        let d = sextet(group[3])?;
        out.push((c << 6) | d);
    }
    Some(out)
}
/// Classifies a request path into the rate-limit bucket it is charged against.
///
/// - Paths under `/git/` are charged to [`RateLimitResource::Git`].
/// - Paths with a `search` segment are charged to [`RateLimitResource::Search`].
/// - Paths with a `graphql` segment are charged to [`RateLimitResource::Graphql`].
/// - Everything else is charged to [`RateLimitResource::Core`].
///
/// Matching uses whole `/`-separated segments and ignores any `?query` or
/// `#fragment`. User-chosen names such as a repository called `graphql-server`
/// or `searchlight` therefore do not move a request into the wrong bucket.
pub fn resource_from_path(path: &str) -> RateLimitResource {
    // `split` always yields at least one item; the fallback is never taken.
    let path = path.split(['?', '#']).next().unwrap_or(path);
    // Check git first so repository names inside git paths cannot re-bucket them.
    if path.starts_with("/git/") {
        return RateLimitResource::Git;
    }
    let has_segment = |name: &str| path.split('/').any(|segment| segment == name);
    if has_segment("search") {
        RateLimitResource::Search
    } else if has_segment("graphql") {
        RateLimitResource::Graphql
    } else {
        RateLimitResource::Core
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_context_anonymous() {
        let ctx = AuthContext::anonymous("127.0.0.1".to_string());
        assert!(!ctx.authenticated);
        assert!(ctx.user_id.is_none());
        assert!(!ctx.has_scope(TokenScope::RepoRead));
    }

    #[test]
    fn test_anonymous_ignores_scopes() {
        // `has_scope` checks `authenticated` first, so scopes pushed onto an
        // anonymous context (the fields are public) grant nothing, not even Admin.
        let mut ctx = AuthContext::anonymous("127.0.0.1".to_string());
        ctx.scopes.push(TokenScope::Admin);
        assert!(!ctx.has_scope(TokenScope::RepoRead));
    }

    #[test]
    fn test_auth_context_authenticated() {
        let ctx = AuthContext::authenticated(
            1,
            "alice".to_string(),
            vec![TokenScope::RepoRead],
            "127.0.0.1".to_string(),
        );
        assert!(ctx.authenticated);
        assert_eq!(ctx.user_id, Some(1));
        assert!(ctx.has_scope(TokenScope::RepoRead));
        assert!(!ctx.has_scope(TokenScope::RepoWrite));
    }

    #[test]
    fn test_admin_scope_implies_every_scope() {
        let ctx = AuthContext::authenticated(
            7,
            "root".to_string(),
            vec![TokenScope::Admin],
            "127.0.0.1".to_string(),
        );
        assert!(ctx.has_scope(TokenScope::RepoRead));
        assert!(ctx.has_scope(TokenScope::RepoWrite));
    }

    #[test]
    fn test_rate_limit_key() {
        let anon = AuthContext::anonymous("10.0.0.1".to_string());
        assert_eq!(anon.rate_limit_key(), "ip:10.0.0.1");
        let auth =
            AuthContext::authenticated(42, "bob".to_string(), vec![], "10.0.0.1".to_string());
        assert_eq!(auth.rate_limit_key(), "user:42");
    }

    #[test]
    fn test_parse_authorization_bearer() {
        let auth = parse_authorization_header("Bearer guts_abc12345_secret").unwrap();
        match auth {
            AuthorizationValue::Bearer(token) => {
                assert_eq!(token, "guts_abc12345_secret");
            }
            _ => panic!("Expected Bearer"),
        }
    }

    #[test]
    fn test_parse_authorization_token() {
        let auth = parse_authorization_header("token guts_abc12345_secret").unwrap();
        match auth {
            AuthorizationValue::Token(token) => {
                assert_eq!(token, "guts_abc12345_secret");
            }
            _ => panic!("Expected Token"),
        }
    }

    #[test]
    fn test_parse_authorization_basic() {
        let auth = parse_authorization_header("Basic dXNlcjpwYXNz").unwrap();
        match auth {
            AuthorizationValue::Basic { username, password } => {
                assert_eq!(username, "user");
                assert_eq!(password, "pass");
            }
            _ => panic!("Expected Basic"),
        }
    }

    #[test]
    fn test_parse_authorization_basic_password_with_colon() {
        // Payload is "user:p:ass"; only the first colon separates the fields.
        let auth = parse_authorization_header("Basic dXNlcjpwOmFzcw==").unwrap();
        match auth {
            AuthorizationValue::Basic { username, password } => {
                assert_eq!(username, "user");
                assert_eq!(password, "p:ass");
            }
            _ => panic!("Expected Basic"),
        }
    }

    #[test]
    fn test_parse_authorization_rejects_invalid() {
        assert!(parse_authorization_header("").is_none());
        assert!(parse_authorization_header("Bearer").is_none());
        assert!(parse_authorization_header("Digest abc").is_none());
        assert!(parse_authorization_header("Basic !!!!").is_none());
        // Payload is "user" with no colon separator.
        assert!(parse_authorization_header("Basic dXNlcg==").is_none());
    }

    #[test]
    fn test_authorization_value_accessors() {
        let bearer = AuthorizationValue::Bearer("guts_x".to_string());
        assert_eq!(bearer.token(), Some("guts_x"));
        assert_eq!(bearer.username(), None);

        let basic_token = AuthorizationValue::Basic {
            username: "alice".to_string(),
            password: "guts_abc".to_string(),
        };
        assert_eq!(basic_token.token(), Some("guts_abc"));
        assert_eq!(basic_token.username(), Some("alice"));

        let basic_password = AuthorizationValue::Basic {
            username: "alice".to_string(),
            password: "hunter2".to_string(),
        };
        assert_eq!(basic_password.token(), None);
    }

    #[test]
    fn test_base64_decode() {
        assert_eq!(base64_decode("dXNlcjpwYXNz"), Some(b"user:pass".to_vec()));
        assert_eq!(base64_decode("YQ=="), Some(b"a".to_vec()));
        assert_eq!(base64_decode("YWI="), Some(b"ab".to_vec()));
        assert!(base64_decode("").is_none());
        assert!(base64_decode("YWJ").is_none());
        assert!(base64_decode("YW*j").is_none());
    }

    #[test]
    fn test_error_response() {
        let err = ErrorResponse::not_found();
        assert_eq!(err.message, "Not Found");
        assert!(err.errors.is_none());
    }

    #[test]
    fn test_rate_limited_error() {
        let err = ErrorResponse::rate_limited(1_700_000_000);
        assert_eq!(
            err.message,
            "API rate limit exceeded. Rate limit will reset at 1700000000"
        );
        assert_eq!(
            err.documentation_url.as_deref(),
            Some("https://docs.guts.network/rest/rate-limiting")
        );
        assert!(err.errors.is_none());
    }

    #[test]
    fn test_validation_error() {
        let err = ValidationError::new("User", "username", ValidationErrorCode::AlreadyExists);
        assert_eq!(err.resource, "User");
        assert_eq!(err.field, "username");
        assert!(err.message.is_none());
    }

    #[test]
    fn test_validation_error_with_message() {
        let err = ValidationError::with_message(
            "Repository",
            "name",
            ValidationErrorCode::TooLong,
            "name is too long",
        );
        assert_eq!(err.resource, "Repository");
        assert_eq!(err.field, "name");
        assert_eq!(err.message.as_deref(), Some("name is too long"));
    }

    #[test]
    fn test_response_headers_etag_is_quoted() {
        let headers = ResponseHeaders::new().with_etag("abc123");
        assert_eq!(headers.etag.as_deref(), Some("\"abc123\""));
    }

    #[test]
    fn test_resource_from_path() {
        assert_eq!(
            resource_from_path("/api/search/repositories"),
            RateLimitResource::Search
        );
        assert_eq!(
            resource_from_path("/git/owner/repo/info/refs"),
            RateLimitResource::Git
        );
        assert_eq!(
            resource_from_path("/api/repos/owner/repo"),
            RateLimitResource::Core
        );
    }
}