use std::collections::HashMap;
use std::sync::Arc;
use axum::extract::{Path, Query};
use axum::{Extension, Json};
use axum::http::StatusCode;
use axum_core::response::IntoResponse;
use serde_json::Value;
use crate::core::base_structs::ClientsManager;
use crate::core::traits::Update;
use crate::server::subtypes::{AuthResult, QueryParams, VoiceflousionHeadersWrapper};
use crate::server::traits::{BotHandler, ServerClient};
pub(super) async fn main_endpoint<C: ServerClient>(
id: Path<String>,
mut params: Query<QueryParams>,
Json(body): Json<Value>,
headers: VoiceflousionHeadersWrapper,
Extension(clients): Extension<Arc<ClientsManager<C>>>,
Extension(optional_allowed_origins): Extension<Arc<Option<HashMap<&'static str, ()>>>>,
Extension(handler): Extension<Arc<dyn BotHandler<C>>>
) -> impl IntoResponse {
let client = match authenticate_request::<C>(id, &mut params, Some(&body), headers, clients.clone(), optional_allowed_origins.clone()).await{
AuthResult::Client(client) => client,
AuthResult::Response(response) => return response
};
let update = match deserialize_update::<C::ClientUpdate<'static>>(body) {
Ok(update) => update,
Err(err) => {
println!("Error deserializing update: {:?}", err);
return (StatusCode::OK, Json("Invalid update".to_string())).into_response();
}
};
if !client.client_base().is_active() {
println!("Client {} deactivated!", client.client_base().client_id());
return (StatusCode::OK, Json("Access to deactivated client".to_string())).into_response();
}
match handler(update, client).await {
Ok(_) => (StatusCode::OK, Json("Ok".to_string())).into_response(),
Err(_) => (StatusCode::OK, Json("Handler error".to_string())).into_response(),
}
}
/// GET counterpart of the main endpoint, used for endpoint authentication.
///
/// Runs the same `authenticate_request` pipeline with no body. A response produced by
/// that pipeline is returned as-is. It may be a rejection, or whatever the client's
/// `authenticate_server_client_request` hook returned. When the pipeline instead yields
/// the client, no GET-specific response was produced, and a fixed "not passed" message
/// is sent with `200 OK`.
pub(super) async fn get_auth_endpoint<C: ServerClient>(
    id: Path<String>,
    mut params: Query<QueryParams>,
    headers: VoiceflousionHeadersWrapper,
    Extension(clients): Extension<Arc<ClientsManager<C>>>,
    Extension(optional_allowed_origins): Extension<Arc<Option<HashMap<&'static str, ()>>>>,
) -> impl IntoResponse {
    let outcome = authenticate_request::<C>(
        id,
        &mut params,
        None,
        headers,
        clients,
        optional_allowed_origins,
    )
    .await;

    if let AuthResult::Response(early_response) = outcome {
        return early_response;
    }

    let message = "Endpoint authentication with GET response not passed".to_string();
    (StatusCode::OK, Json(message)).into_response()
}
/// Parses a raw JSON request body into the update type `U` via `Update::from_request_body`.
///
/// On failure the parse error is printed with its `Debug` form and `Err(StatusCode::OK)`
/// is returned. This follows the module's convention of answering every outcome with
/// `200 OK`.
fn deserialize_update<U: Update>(body: Value) -> Result<U, StatusCode> {
    U::from_request_body(body).map_err(|parse_error| {
        println!("Error: {:?}", &parse_error);
        StatusCode::OK
    })
}
/// Shared authentication pipeline for incoming server requests.
///
/// Checks, in order:
/// 1. If an allowed-origins map is configured, the request must carry an `Origin` header
///    whose value is a key of that map.
/// 2. The path `id` must resolve to a client registered in `clients`.
/// 3. If the client has a bot auth token configured, the query parameters must carry a
///    token equal to it.
/// 4. The client's own `authenticate_server_client_request` hook must not return a
///    response. It receives the extracted token, if one was required.
///
/// Returns `AuthResult::Client` when every check passes. Otherwise it returns
/// `AuthResult::Response` with the response to send back. Every rejection built here
/// uses `StatusCode::OK`.
async fn authenticate_request<C: ServerClient>(
    Path(id): Path<String>,
    Query(params): &mut Query<QueryParams>,
    body: Option<&Value>,
    headers: VoiceflousionHeadersWrapper,
    clients: Arc<ClientsManager<C>>,
    optional_allowed_origins: Arc<Option<HashMap<&'static str, ()>>>,
) -> AuthResult<C> {
    if let Some(allowed_origins) = &*optional_allowed_origins {
        let Some(origin) = headers.get_header_str("Origin") else {
            println!("Missing origin header in request! Client: {}\nAdd headers to request or turn off allowed origins in server!", id);
            return AuthResult::Response((StatusCode::OK, Json("Missing origin header in request! Add headers to request or turn off allowed origins in server!".to_string())).into_response());
        };
        if !allowed_origins.contains_key(origin) {
            println!("Unauthorized origin: {}", origin);
            return AuthResult::Response((StatusCode::OK, Json("Unauthorized origin".to_string())).into_response());
        }
    }

    let Some(client) = clients.get_client(&id).await else {
        println!("Invalid client id - {}", id);
        return AuthResult::Response((StatusCode::OK, Json("Invalid client id".to_string())).into_response());
    };

    // A token is only demanded from the query string when the client has one configured;
    // the extracted token (or `None`) is then forwarded to the client-specific hook.
    let bot_auth_token = match client.client_base().bot_auth_token().await {
        Some(expected_token) => {
            let Some(provided_token) = params.extract_bot_auth_token() else {
                println!("Missing token parameter");
                return AuthResult::Response((StatusCode::OK, Json("Missing token parameter".to_string())).into_response());
            };
            // The provided token comes from the untrusted request: compare without
            // short-circuiting so response timing does not leak matching prefixes.
            if !tokens_match(expected_token.token(), provided_token.token()) {
                println!("Unauthorized access to client {}", client.client_base().client_id());
                return AuthResult::Response((StatusCode::OK, Json("Unauthorized access".to_string())).into_response());
            }
            Some(provided_token)
        }
        None => None,
    };

    if let Some(response) =
        client.authenticate_server_client_request(headers, params, body, bot_auth_token)
    {
        return AuthResult::Response(response);
    }
    AuthResult::Client(client)
}

/// Compares two secret byte strings without stopping at the first differing byte.
///
/// Only the length is compared up front. For equal-length inputs every byte is examined,
/// so the elapsed time does not depend on where the first mismatch is. This is
/// best-effort: for a hard guarantee against optimizer changes, use a dedicated
/// constant-time crate.
fn tokens_match(expected: impl AsRef<[u8]>, provided: impl AsRef<[u8]>) -> bool {
    let (expected, provided) = (expected.as_ref(), provided.as_ref());
    expected.len() == provided.len()
        && expected
            .iter()
            .zip(provided)
            .fold(0u8, |diff, (a, b)| diff | (a ^ b))
            == 0
}