use std::collections::HashMap;
use std::sync::Arc;
use std::sync::RwLock;
use axum::extract::State;
use axum::http::{header, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use crate::server::AppState;
/// Connection settings for an OAuth 2.0 token-introspection endpoint.
///
/// Consulted by `AuthState::check_oauth_token` for bearer tokens the local store does not know.
/// Cloned out of its lock on every check so no lock guard is held across the HTTP request.
#[cfg(feature = "oauth")]
#[derive(Clone)]
pub(crate) struct OAuthIntrospect {
/// Endpoint URL; the token is POSTed here as a form field named `token`.
pub url: String,
/// Client identifier, sent as the HTTP Basic auth username.
pub client_id: String,
/// Client secret, sent as the HTTP Basic auth password.
pub client_secret: String,
/// HTTP client used for introspection requests; timeouts and other settings are whatever
/// the code constructing this struct configured.
pub client: reqwest::Client,
}
/// Outcome of checking a bearer token against the local store (`AuthState::check_and_use`).
#[derive(Debug, PartialEq)]
pub enum TokenStatus {
/// The token is registered and was accepted; one use was consumed if it has a limit.
Valid,
/// The token was revoked, ran out of uses, or was registered with zero uses.
Exhausted,
/// The token was never registered locally and never revoked. `bearer_auth_check` falls
/// back to OAuth introspection for this case when the `oauth` feature is enabled.
Unknown,
}
/// Bearer-token store with optional per-token use limits, safe to share across threads.
///
/// A token lives in at most one of `tokens` (still usable) and `exhausted` (revoked or out
/// of uses). The methods on this type move a token between the two sets together, so that
/// invariant holds.
pub struct AuthState {
/// Usable tokens mapped to their remaining uses; `None` means unlimited.
tokens: RwLock<HashMap<String, Option<u64>>>,
/// Tokens that were revoked or ran out of uses.
exhausted: RwLock<std::collections::HashSet<String>>,
/// Optional introspection endpoint for tokens unknown to the local store.
#[cfg(feature = "oauth")]
oauth_introspect: RwLock<Option<OAuthIntrospect>>,
}
impl Default for AuthState {
fn default() -> Self {
Self::new()
}
}
impl AuthState {
    /// Creates an empty store: no registered tokens, no exhausted tokens and (with the
    /// `oauth` feature) no introspection endpoint.
    pub fn new() -> Self {
        Self {
            tokens: RwLock::new(HashMap::new()),
            exhausted: RwLock::new(std::collections::HashSet::new()),
            #[cfg(feature = "oauth")]
            oauth_introspect: RwLock::new(None),
        }
    }

    /// Registers `token`, allowing `max_uses` successful checks (`None` means unlimited).
    ///
    /// Re-adding a token replaces its previous limit and clears any earlier exhaustion or
    /// revocation, so the token becomes usable again.
    pub fn add_token(&self, token: &str, max_uses: Option<u64>) {
        // Both write locks are held so the two sets change together. Lock order is
        // `tokens` then `exhausted`, the same order `check_and_use` and `revoke` use.
        // A poisoned lock is recovered (`into_inner`) rather than propagated as a panic.
        let mut tokens = self.tokens.write().unwrap_or_else(|e| e.into_inner());
        let mut exhausted = self.exhausted.write().unwrap_or_else(|e| e.into_inner());
        exhausted.remove(token);
        tokens.insert(token.to_string(), max_uses);
    }

    /// Checks `token` against the local store and consumes one use if it is limited.
    ///
    /// Returns:
    /// - [`TokenStatus::Valid`] for an unlimited token, or for a limited one with uses left.
    ///   The final use moves the token to the exhausted set.
    /// - [`TokenStatus::Exhausted`] for a token that was revoked, ran out of uses, or was
    ///   registered with zero uses.
    /// - [`TokenStatus::Unknown`] for a token this store has never seen.
    pub fn check_and_use(&self, token: &str) -> TokenStatus {
        // Fast path: reject exhausted tokens without taking the `tokens` write lock. The read
        // guard is a temporary of the `if` condition, so it is released before the write below.
        if self
            .exhausted
            .read()
            .unwrap_or_else(|e| e.into_inner())
            .contains(token)
        {
            return TokenStatus::Exhausted;
        }
        let mut tokens = self.tokens.write().unwrap_or_else(|e| e.into_inner());
        match tokens.get_mut(token) {
            Some(Some(remaining)) if *remaining > 0 => {
                *remaining -= 1;
                if *remaining == 0 {
                    // Last use: retire the token so later checks report `Exhausted`.
                    tokens.remove(token);
                    self.exhausted
                        .write()
                        .unwrap_or_else(|e| e.into_inner())
                        .insert(token.to_string());
                }
                TokenStatus::Valid
            }
            // Only reachable for a token registered with `Some(0)`. A limit that counts down
            // to zero is removed from `tokens` in the arm above.
            Some(Some(_)) => {
                tokens.remove(token);
                self.exhausted
                    .write()
                    .unwrap_or_else(|e| e.into_inner())
                    .insert(token.to_string());
                TokenStatus::Exhausted
            }
            Some(None) => TokenStatus::Valid,
            None => {
                // Re-check under the `tokens` lock. Another thread may have exhausted or
                // revoked the token after the fast-path check. Reporting `Unknown` then
                // would let the middleware fall back to OAuth for a locally retired token.
                if self
                    .exhausted
                    .read()
                    .unwrap_or_else(|e| e.into_inner())
                    .contains(token)
                {
                    TokenStatus::Exhausted
                } else {
                    TokenStatus::Unknown
                }
            }
        }
    }

    /// Revokes `token`: removes it from the usable set and marks it exhausted.
    ///
    /// This also works for tokens that were never added. A later `add_token` makes the token
    /// usable again.
    pub fn revoke(&self, token: &str) {
        // Same lock order as `add_token`: `tokens` then `exhausted`.
        let mut tokens = self.tokens.write().unwrap_or_else(|e| e.into_inner());
        let mut exhausted = self.exhausted.write().unwrap_or_else(|e| e.into_inner());
        tokens.remove(token);
        exhausted.insert(token.to_string());
    }

    /// Sets (or replaces) the introspection endpoint consulted by `check_oauth_token`.
    #[cfg(feature = "oauth")]
    pub(crate) fn set_oauth_introspect(&self, config: OAuthIntrospect) {
        *self
            .oauth_introspect
            .write()
            .unwrap_or_else(|e| e.into_inner()) = Some(config);
    }

    /// Asks the configured introspection endpoint whether `token` is active.
    ///
    /// Sends an RFC 7662-style request: a form field `token`, with the client credentials as
    /// HTTP Basic auth. Returns `true` only when the endpoint answers with a 2xx status and a
    /// JSON body whose `active` member is `true`.
    ///
    /// Returns `false` in these cases:
    /// - no endpoint is configured;
    /// - the request fails;
    /// - the endpoint returns a non-2xx status (for example, rejected client credentials);
    /// - the body is not JSON.
    ///
    /// The failure cases other than "not configured" are logged to stderr.
    #[cfg(feature = "oauth")]
    pub(crate) async fn check_oauth_token(&self, token: &str) -> bool {
        // Clone the config inside a scope so the std (blocking) lock guard is dropped before
        // any `.await`; holding it across an await point would make this future non-`Send`.
        let config = {
            let guard = self
                .oauth_introspect
                .read()
                .unwrap_or_else(|e| e.into_inner());
            match guard.as_ref() {
                Some(c) => c.clone(),
                None => return false,
            }
        };
        let response = match config
            .client
            .post(&config.url)
            .basic_auth(&config.client_id, Some(&config.client_secret))
            .form(&[("token", token)])
            .send()
            .await
        {
            Ok(r) => r,
            Err(e) => {
                eprintln!("[llmposter] OAuth introspect request failed: {e}");
                return false;
            }
        };
        // An error status means the introspection itself failed (e.g. bad client credentials).
        // Never trust its body, and log the status so misconfiguration is visible.
        let status = response.status();
        if !status.is_success() {
            eprintln!("[llmposter] OAuth introspect: endpoint returned HTTP {status}");
            return false;
        }
        match response.json::<serde_json::Value>().await {
            Ok(body) => body
                .get("active")
                .and_then(|v| v.as_bool())
                .unwrap_or(false),
            Err(e) => {
                eprintln!(
                    "[llmposter] OAuth introspect: failed to parse response body as JSON: {e}"
                );
                false
            }
        }
    }
}
/// Axum middleware that enforces bearer-token auth on the LLM API routes.
///
/// A request passes through untouched when `state.auth` is `None`, or when its path is
/// outside `/v1/` and `/v1beta/`. Otherwise the `Authorization: Bearer <token>` credential
/// must be accepted in one of two ways:
/// - by `AuthState::check_and_use`, which consumes one use of a limited token; or
/// - for tokens the store reports as `Unknown` and with the `oauth` feature enabled, by
///   `AuthState::check_oauth_token`.
///
/// Any other request gets the provider-shaped 401 built by `auth_error_response`.
pub(crate) async fn bearer_auth_check(
    State(state): State<Arc<AppState>>,
    request: axum::extract::Request,
    next: Next,
) -> Response {
    let Some(auth) = &state.auth else {
        return next.run(request).await;
    };
    // Borrowed, not copied: every branch that moves `request` into `next.run` stops using
    // `path`, so the borrow checker accepts this without allocating per request.
    let path = request.uri().path();
    if !(path.starts_with("/v1/") || path.starts_with("/v1beta/")) {
        return next.run(request).await;
    }
    let token = request
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(parse_bearer_token);
    let is_valid = match token {
        None => false,
        Some(t) => match auth.check_and_use(t) {
            TokenStatus::Valid => true,
            TokenStatus::Exhausted => false,
            TokenStatus::Unknown => {
                #[cfg(feature = "oauth")]
                {
                    auth.check_oauth_token(t).await
                }
                #[cfg(not(feature = "oauth"))]
                false
            }
        },
    };
    if is_valid {
        next.run(request).await
    } else {
        auth_error_response(path)
    }
}

/// Extracts the credential from an `Authorization` header value that uses the `Bearer` scheme.
///
/// The scheme name is matched case-insensitively. All spaces after it are skipped, because
/// RFC 6750 allows `1*SP` between the scheme and the token. Returns `None` for any other
/// scheme or for an empty credential.
fn parse_bearer_token(value: &str) -> Option<&str> {
    let (scheme, rest) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("bearer") {
        return None;
    }
    let token = rest.trim_start_matches(' ');
    (!token.is_empty()).then_some(token)
}
/// Builds the `401 Unauthorized` response sent when bearer authentication fails.
///
/// The JSON body imitates the error format of the API the path belongs to:
/// - the Anthropic envelope for `/v1/messages…`;
/// - the Google-style `UNAUTHENTICATED` status for `/v1beta/models…`;
/// - the OpenAI `invalid_api_key` shape for every other path.
///
/// The response always carries `Content-Type: application/json` and
/// `WWW-Authenticate: Bearer realm="api"`.
fn auth_error_response(path: &str) -> Response {
    const MESSAGE: &str = "Invalid bearer token";
    // Key order inside each object is kept as-is so the serialized body is unchanged even
    // when serde_json preserves insertion order.
    let payload = match path {
        p if p.starts_with("/v1/messages") => serde_json::json!({
            "type": "error",
            "error": {
                "type": "authentication_error",
                "message": MESSAGE
            }
        }),
        p if p.starts_with("/v1beta/models") => serde_json::json!({
            "error": {
                "code": 401,
                "message": MESSAGE,
                "status": "UNAUTHENTICATED"
            }
        }),
        _ => serde_json::json!({
            "error": {
                "message": MESSAGE,
                "type": "authentication_error",
                "param": null,
                "code": "invalid_api_key"
            }
        }),
    };
    let headers = [
        (header::CONTENT_TYPE, "application/json"),
        (header::WWW_AUTHENTICATE, "Bearer realm=\"api\""),
    ];
    (StatusCode::UNAUTHORIZED, headers, payload.to_string()).into_response()
}
/// Unit tests for the token store and, with the `oauth` feature, the introspection fallback.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn should_accept_valid_token() {
        let state = AuthState::new();
        state.add_token("tok-1", None);
        assert_eq!(state.check_and_use("tok-1"), TokenStatus::Valid);
    }

    #[test]
    fn should_reject_unknown_token() {
        let state = AuthState::new();
        assert_eq!(state.check_and_use("unknown"), TokenStatus::Unknown);
    }

    #[test]
    fn should_keep_unknown_token_unknown_across_checks() {
        // Checking an unregistered token must not record anything about it.
        let state = AuthState::new();
        assert_eq!(state.check_and_use("nope"), TokenStatus::Unknown);
        assert_eq!(state.check_and_use("nope"), TokenStatus::Unknown);
    }

    #[test]
    fn should_expire_after_n_uses() {
        let state = AuthState::new();
        state.add_token("tok-1", Some(2));
        assert_eq!(state.check_and_use("tok-1"), TokenStatus::Valid);
        assert_eq!(state.check_and_use("tok-1"), TokenStatus::Valid);
        assert_eq!(state.check_and_use("tok-1"), TokenStatus::Exhausted);
    }

    #[test]
    fn should_remove_revoked_token() {
        let state = AuthState::new();
        state.add_token("tok-1", None);
        state.revoke("tok-1");
        assert_eq!(state.check_and_use("tok-1"), TokenStatus::Exhausted);
    }

    #[test]
    fn should_mark_never_added_token_exhausted_on_revoke() {
        let state = AuthState::new();
        state.revoke("ghost");
        assert_eq!(state.check_and_use("ghost"), TokenStatus::Exhausted);
    }

    #[test]
    fn should_accept_unlimited_token_many_times() {
        let state = AuthState::new();
        state.add_token("unlimited", None);
        for _ in 0..100 {
            assert_eq!(state.check_and_use("unlimited"), TokenStatus::Valid);
        }
    }

    #[test]
    fn should_support_default_trait() {
        let state = AuthState::default();
        state.add_token("tok", None);
        assert_eq!(state.check_and_use("tok"), TokenStatus::Valid);
    }

    #[test]
    fn should_reject_zero_use_token() {
        let state = AuthState::new();
        state.add_token("zero", Some(0));
        assert_eq!(state.check_and_use("zero"), TokenStatus::Exhausted);
    }

    #[test]
    fn should_allow_re_add_after_revoke() {
        let state = AuthState::new();
        state.add_token("tok", None);
        state.revoke("tok");
        assert_eq!(state.check_and_use("tok"), TokenStatus::Exhausted);
        state.add_token("tok", None);
        assert_eq!(state.check_and_use("tok"), TokenStatus::Valid);
    }

    #[test]
    fn should_allow_re_add_after_exhaustion() {
        let state = AuthState::new();
        state.add_token("tok", Some(1));
        assert_eq!(state.check_and_use("tok"), TokenStatus::Valid);
        assert_eq!(state.check_and_use("tok"), TokenStatus::Exhausted);
        state.add_token("tok", Some(2));
        assert_eq!(state.check_and_use("tok"), TokenStatus::Valid);
        assert_eq!(state.check_and_use("tok"), TokenStatus::Valid);
        assert_eq!(state.check_and_use("tok"), TokenStatus::Exhausted);
    }

    #[test]
    fn should_replace_use_limit_when_re_added() {
        // The second `add_token` overwrites the first limit instead of adding to it.
        let state = AuthState::new();
        state.add_token("tok", Some(1));
        state.add_token("tok", Some(3));
        for _ in 0..3 {
            assert_eq!(state.check_and_use("tok"), TokenStatus::Valid);
        }
        assert_eq!(state.check_and_use("tok"), TokenStatus::Exhausted);
    }

    #[test]
    fn should_lift_limit_when_re_added_as_unlimited() {
        let state = AuthState::new();
        state.add_token("tok", Some(1));
        state.add_token("tok", None);
        for _ in 0..5 {
            assert_eq!(state.check_and_use("tok"), TokenStatus::Valid);
        }
    }

    #[cfg(feature = "oauth")]
    #[tokio::test]
    async fn should_return_false_when_introspect_unreachable() {
        let state = AuthState::new();
        state.set_oauth_introspect(OAuthIntrospect {
            url: "http://127.0.0.1:1/introspect".to_string(),
            client_id: "test".to_string(),
            client_secret: "secret".to_string(),
            client: reqwest::Client::builder()
                .timeout(std::time::Duration::from_millis(100))
                .build()
                .unwrap(),
        });
        assert!(!state.check_oauth_token("any").await);
    }

    #[cfg(feature = "oauth")]
    #[tokio::test]
    async fn should_return_false_when_introspect_not_configured() {
        let state = AuthState::new();
        assert!(!state.check_oauth_token("any").await);
    }

    #[cfg(feature = "oauth")]
    #[tokio::test]
    async fn should_return_false_when_introspect_returns_non_json() {
        use axum::{routing::post, Router};
        let app = Router::new().route("/introspect", post(|| async { "this is not json" }));
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        let state = AuthState::new();
        state.set_oauth_introspect(OAuthIntrospect {
            url: format!("http://127.0.0.1:{}/introspect", port),
            client_id: "test".to_string(),
            client_secret: "secret".to_string(),
            client: reqwest::Client::new(),
        });
        assert!(!state.check_oauth_token("any").await);
    }
}