use crate::api::AppState;
use crate::api::schemas::gateway::{TicketResponse, WsParams};
use axum::{
extract::{Query, State, ws::WebSocketUpgrade},
http::Extensions,
response::IntoResponse,
};
use tower_http::request_id::RequestId;
pub(crate) async fn generate_ticket(
auth_user: crate::api::middleware::AuthUser,
State(state): State<AppState>,
) -> Result<impl IntoResponse, crate::error::AppError> {
let device_id = auth_user
.device_id
.ok_or_else(|| crate::error::AppError::Forbidden("Device-scoped token required".to_string()))?;
let ticket = uuid::Uuid::new_v4().to_string();
state.ws_ticket_cache.set(&ticket, device_id.to_string().as_bytes()).await.map_err(|e| {
tracing::error!(error = %e, "Failed to cache websocket ticket");
crate::error::AppError::InternalMsg("Failed to generate ticket".to_string())
})?;
Ok((axum::http::StatusCode::CREATED, axum::Json(TicketResponse { ticket })))
}
pub(crate) async fn websocket_handler(
ws: WebSocketUpgrade,
Query(params): Query<WsParams>,
extensions: Extensions,
State(state): State<AppState>,
) -> impl IntoResponse {
let request_id = extensions
.get::<RequestId>()
.map_or_else(|| "unknown".to_string(), |id| id.header_value().to_str().unwrap_or_default().to_string());
let device_id_res = match state.ws_ticket_cache.get(¶ms.ticket).await {
Ok(Some(bytes)) => match String::from_utf8(bytes) {
Ok(id_str) => match uuid::Uuid::parse_str(&id_str) {
Ok(id) => {
let _ = state.ws_ticket_cache.delete(¶ms.ticket).await;
Ok(id)
}
Err(_) => Err("Invalid device ID format in cache".to_string()),
},
Err(_) => Err("Invalid UTF-8 in cache".to_string()),
},
Ok(None) => Err("Ticket not found or expired".to_string()),
Err(e) => {
tracing::error!(error = %e, "Failed to read ticket from cache");
Err("Internal server error".to_string())
}
};
match device_id_res {
Ok(device_id) => ws.on_upgrade(move |socket| {
let service = state.gateway_service.clone();
let shutdown = state.shutdown_rx.clone();
async move {
service.handle_socket(socket, device_id, request_id, shutdown).await;
}
}),
Err(e) => {
tracing::warn!(error = %e, "WebSocket handshake failed: invalid ticket");
axum::http::StatusCode::UNAUTHORIZED.into_response()
}
}
}