use std::collections::HashMap;
use std::sync::Arc;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
/// Capability flags a client may advertise.
///
/// NOTE(review): field semantics are inferred from names only (they look like
/// MCP client capabilities) — confirm against the protocol definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientCapabilities {
/// Whether the client advertises the `sampling` capability.
pub sampling: bool,
/// Whether the client advertises the `roots` capability.
pub roots: bool,
/// Whether the client advertises the `elicitation` capability.
pub elicitation: bool,
/// Maximum number of concurrent requests for this client.
/// NOTE(review): whether 0 means "unlimited" is not defined here — confirm.
pub max_concurrent_requests: usize,
/// Named experimental feature flags, keyed by feature name.
pub experimental: HashMap<String, bool>,
}
/// Identity of a client together with the source it was derived from.
///
/// Each payload-carrying variant holds the identifier string; the variant
/// records which request attribute produced it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ClientId {
/// Taken from the `x-client-id` HTTP header.
Header(String),
/// Derived from an `Authorization: Bearer` token: either the client id
/// registered for that token, or the raw token when it is unregistered.
Token(String),
/// Value of a `session_id` / `sessionid` cookie.
Session(String),
/// Value of the `client_id` query parameter.
QueryParam(String),
/// Hashed `User-Agent` fingerprint of the form `ua_<hex>`.
UserAgent(String),
/// No identifying information was available.
Anonymous,
}
impl ClientId {
    /// Borrows the identifier string carried by this id.
    ///
    /// Every variant except [`ClientId::Anonymous`] wraps a `String`, which is
    /// returned as-is; `Anonymous` has no payload and yields `"anonymous"`.
    #[must_use]
    pub fn as_str(&self) -> &str {
        match self {
            Self::Anonymous => "anonymous",
            Self::Header(value)
            | Self::Token(value)
            | Self::Session(value)
            | Self::QueryParam(value)
            | Self::UserAgent(value) => value.as_str(),
        }
    }

    /// Reports whether this id came from a credential-style source.
    ///
    /// Returns `true` exactly for [`ClientId::Token`] and [`ClientId::Session`].
    /// This inspects the variant only; it does not itself validate the token
    /// or session value.
    #[must_use]
    pub const fn is_authenticated(&self) -> bool {
        match self {
            Self::Token(_) | Self::Session(_) => true,
            Self::Header(_) | Self::QueryParam(_) | Self::UserAgent(_) | Self::Anonymous => false,
        }
    }

    /// Returns a fixed lowercase label naming how this id was obtained,
    /// e.g. `"bearer_token"` for [`ClientId::Token`] or `"session_cookie"`
    /// for [`ClientId::Session`].
    #[must_use]
    pub const fn auth_method(&self) -> &'static str {
        match self {
            Self::Anonymous => "anonymous",
            Self::UserAgent(_) => "user_agent",
            Self::QueryParam(_) => "query_param",
            Self::Session(_) => "session_cookie",
            Self::Token(_) => "bearer_token",
            Self::Header(_) => "header",
        }
    }
}
/// Per-client connection state: identity, timing, request counting and
/// optional capability/metadata payloads.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientSession {
/// Identifier of the client owning this session.
pub client_id: String,
/// Human-readable client name; set by `authenticate`.
pub client_name: Option<String>,
/// Creation time of the session (UTC).
pub connected_at: DateTime<Utc>,
/// Time of the most recent `update_activity` call (initially `connected_at`).
pub last_activity: DateTime<Utc>,
/// Number of `update_activity` calls so far.
pub request_count: usize,
/// Free-form transport label supplied at construction.
pub transport_type: String,
/// Set to `true` by `authenticate`; never reset by this type.
pub authenticated: bool,
/// Raw capability JSON as provided to `set_capabilities`.
pub capabilities: Option<serde_json::Value>,
/// Arbitrary extra data attached to the session.
pub metadata: HashMap<String, serde_json::Value>,
}
impl ClientSession {
    /// Creates a fresh, unauthenticated session for `client_id` over `transport_type`.
    ///
    /// `connected_at` and `last_activity` share a single `Utc::now()` reading,
    /// the request counter starts at zero, and no name, capabilities or
    /// metadata are attached.
    #[must_use]
    pub fn new(client_id: String, transport_type: String) -> Self {
        let created = Utc::now();
        Self {
            client_id,
            transport_type,
            client_name: None,
            authenticated: false,
            capabilities: None,
            metadata: HashMap::new(),
            request_count: 0,
            connected_at: created,
            last_activity: created,
        }
    }

    /// Records one request: refreshes `last_activity` and bumps `request_count`.
    pub fn update_activity(&mut self) {
        let now = Utc::now();
        self.last_activity = now;
        self.request_count += 1;
    }

    /// Marks the session authenticated and stores `client_name`.
    ///
    /// The stored name is replaced unconditionally, so passing `None` clears
    /// any previously set name.
    pub fn authenticate(&mut self, client_name: Option<String>) {
        self.client_name = client_name;
        self.authenticated = true;
    }

    /// Stores (or replaces) the raw capability JSON for this session.
    pub fn set_capabilities(&mut self, capabilities: serde_json::Value) {
        let stored = Some(capabilities);
        self.capabilities = stored;
    }

    /// Time elapsed between connection and the last recorded activity.
    ///
    /// This is not "time since connect": it does not advance until
    /// `update_activity` is called.
    #[must_use]
    pub fn session_duration(&self) -> chrono::Duration {
        self.last_activity.signed_duration_since(self.connected_at)
    }

    /// Returns `true` when strictly more than `idle_threshold` has passed
    /// since the last recorded activity.
    #[must_use]
    pub fn is_idle(&self, idle_threshold: chrono::Duration) -> bool {
        let idle_for = Utc::now().signed_duration_since(self.last_activity);
        idle_for > idle_threshold
    }
}
/// Resolves a [`ClientId`] from HTTP headers and/or query parameters, using
/// an in-memory bearer-token registry.
///
/// NOTE(review): the registry is wrapped in `Arc`, but this type derives only
/// `Debug` (not `Clone`), so nothing visible here shares that `Arc` — confirm
/// whether sharing is intended.
#[derive(Debug)]
pub struct ClientIdExtractor {
/// Registry mapping bearer token -> client id.
auth_tokens: Arc<dashmap::DashMap<String, String>>,
}
impl ClientIdExtractor {
    /// Creates an extractor with an empty bearer-token registry.
    #[must_use]
    pub fn new() -> Self {
        Self {
            auth_tokens: Arc::new(dashmap::DashMap::new()),
        }
    }

    /// Maps `token` to `client_id`, so that a matching `Authorization: Bearer`
    /// header resolves to `client_id` instead of the raw token.
    ///
    /// Registering an already-known token replaces its previous client id.
    pub fn register_token(&self, token: String, client_id: String) {
        self.auth_tokens.insert(token, client_id);
    }

    /// Removes `token` from the registry; a no-op if it was never registered.
    pub fn revoke_token(&self, token: &str) {
        self.auth_tokens.remove(token);
    }

    /// Returns a snapshot of every registered `(token, client_id)` pair.
    ///
    /// Order follows the underlying map's iteration order and is unspecified.
    #[must_use]
    pub fn list_tokens(&self) -> Vec<(String, String)> {
        self.auth_tokens
            .iter()
            .map(|entry| (entry.key().clone(), entry.value().clone()))
            .collect()
    }

    /// Derives a [`ClientId`] from HTTP request headers.
    ///
    /// Sources are tried in priority order; the first one present wins:
    /// 1. `x-client-id` header -> [`ClientId::Header`]
    /// 2. `authorization: Bearer <token>` -> [`ClientId::Token`] carrying the
    ///    registered client id, or the raw token when it is not registered
    /// 3. `session_id` / `sessionid` cookie -> [`ClientId::Session`]
    /// 4. `user-agent` -> [`ClientId::UserAgent`] with a hashed fingerprint
    /// 5. otherwise [`ClientId::Anonymous`]
    ///
    /// Header names are looked up exactly as written above (lowercase).
    /// NOTE(review): presumably callers lower-case header names before calling,
    /// since HTTP header names are case-insensitive — confirm at call sites.
    #[must_use]
    pub fn extract_from_http_headers(&self, headers: &HashMap<String, String>) -> ClientId {
        if let Some(client_id) = headers.get("x-client-id") {
            return ClientId::Header(client_id.clone());
        }
        // An `authorization` header with a non-Bearer scheme falls through.
        if let Some(token) = headers
            .get("authorization")
            .and_then(|auth| auth.strip_prefix("Bearer "))
        {
            return ClientId::Token(self.resolve_token(token));
        }
        if let Some(session_id) = headers
            .get("cookie")
            .map(String::as_str)
            .and_then(Self::session_cookie)
        {
            return ClientId::Session(session_id.to_string());
        }
        if let Some(user_agent) = headers.get("user-agent") {
            return ClientId::UserAgent(Self::user_agent_fingerprint(user_agent));
        }
        ClientId::Anonymous
    }

    /// Returns the client id registered for `token`, or `token` itself when unregistered.
    ///
    /// NOTE(review): unregistered tokens are passed through unchanged and become
    /// `ClientId::Token`, which `ClientId::is_authenticated` reports as `true`;
    /// the raw secret also becomes the client id string. Confirm that is intended.
    fn resolve_token(&self, token: &str) -> String {
        // Keyed lookup: O(1) and read-locks only the shard holding `token`,
        // instead of scanning (and locking) every shard. The entry guard is
        // dropped inside the closure, before this returns.
        self.auth_tokens
            .get(token)
            .map_or_else(|| token.to_string(), |entry| entry.value().clone())
    }

    /// Returns the value of the first `session_id` or `sessionid` cookie in a
    /// `Cookie` header value, if any.
    fn session_cookie(cookie: &str) -> Option<&str> {
        cookie
            .split(';')
            .filter_map(|pair| pair.trim().split_once('='))
            .find(|(name, _)| matches!(*name, "session_id" | "sessionid"))
            .map(|(_, value)| value)
    }

    /// Hashes a `User-Agent` string into a `ua_<hex>` pseudo-identifier.
    ///
    /// NOTE(review): std documents `DefaultHasher`'s algorithm as unspecified and
    /// subject to change between releases, so these ids are not guaranteed to be
    /// stable across toolchain upgrades; avoid persisting them.
    fn user_agent_fingerprint(user_agent: &str) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        user_agent.hash(&mut hasher);
        format!("ua_{:x}", hasher.finish())
    }

    /// Returns [`ClientId::QueryParam`] for a `client_id` query parameter, if present.
    #[must_use]
    pub fn extract_from_query(&self, query_params: &HashMap<String, String>) -> Option<ClientId> {
        query_params
            .get("client_id")
            .map(|client_id| ClientId::QueryParam(client_id.clone()))
    }

    /// Resolves a [`ClientId`] from whichever request parts are available.
    ///
    /// A `client_id` query parameter takes precedence over all headers. If none
    /// is found, the headers are consulted (see
    /// [`Self::extract_from_http_headers`]). With neither, the result is
    /// [`ClientId::Anonymous`].
    ///
    /// NOTE(review): because the query parameter wins, a client-supplied
    /// `client_id` overrides even a registered bearer token — confirm this
    /// precedence is intended.
    #[must_use]
    pub fn extract_client_id(
        &self,
        headers: Option<&HashMap<String, String>>,
        query_params: Option<&HashMap<String, String>>,
    ) -> ClientId {
        if let Some(client_id) = query_params.and_then(|params| self.extract_from_query(params)) {
            return client_id;
        }
        headers.map_or(ClientId::Anonymous, |headers| {
            self.extract_from_http_headers(headers)
        })
    }
}
impl Default for ClientIdExtractor {
fn default() -> Self {
Self::new()
}
}