rusty-relay-server 0.6.1

The HTTP server for rusty-relay
use crate::{error::HttpError, state::AppState, util::generate_id};
use axum::{
    extract::{
        State, WebSocketUpgrade,
        ws::{Message, WebSocket},
    },
    http::HeaderMap,
    response::IntoResponse,
};
use rusty_relay_messages::RelayMessage;
use std::sync::Arc;
use tokio::time;
use tokio_stream::StreamExt;
use tracing::{debug, error, info};

/// Authenticates a relay client and upgrades its request to a WebSocket.
///
/// The `PRIVATE-TOKEN` request header is compared against `state.connect_token()`.
/// On a match a fresh 12-character client id is generated and the connection is handed
/// to `handle_ws`. Otherwise the request is rejected with:
///
/// * `HttpError::Unauthorized` when the header is missing or the token does not match,
/// * `HttpError::BadRequest` when the header value cannot be read as a string.
#[tracing::instrument(skip(ws, state))]
pub async fn connect_handler(
    ws: WebSocketUpgrade,
    headers: HeaderMap,
    state: State<Arc<AppState>>,
) -> impl IntoResponse {
    // Guard 1: the header must be present.
    let Some(header_value) = headers.get("PRIVATE-TOKEN") else {
        debug!("❌ client did not provide token");
        return HttpError::Unauthorized("Connection token is missing from header".to_string())
            .into_response();
    };

    // Guard 2: the header value must be representable as a string.
    let Ok(provided_token) = header_value.to_str() else {
        debug!("❌ client provided invalid token format");
        return HttpError::BadRequest("Connection token has invalid format".to_string())
            .into_response();
    };

    // Guard 3: the token must be the configured one.
    if provided_token != state.connect_token() {
        debug!("❌ client provided invalid token");
        return HttpError::Unauthorized("Connection token is invalid".to_string())
            .into_response();
    }

    let client_id = generate_id(12);
    info!(client_id, "👨 client connected");
    ws.on_upgrade(move |socket| handle_ws(socket, client_id, state))
}

/// Drives one relay client's WebSocket connection until it ends.
///
/// The client is first told its id via a `RelayMessage::ClientId` text frame. It is then
/// registered with `state`, and the function loops over three event sources:
///
/// * a timer ticking every `state.ping_interval()`, which sends a WebSocket ping,
/// * relay messages received from the channel returned by `state.add_client`, which are
///   serialized to JSON and forwarded to the client,
/// * frames coming from the client; a `RelayMessage::ProxyResponse` is passed to the
///   sender that `state.remove_proxy_request` returns for its `request_id`, if any.
///
/// The loop stops on a close frame, a socket error, the end of the socket stream, or a
/// failed send/serialization. Afterwards the client is removed from `state`.
#[tracing::instrument(skip(socket, state))]
async fn handle_ws(mut socket: WebSocket, client_id: String, state: State<Arc<AppState>>) {
    // Announce the assigned id before registering, so a failure here needs no cleanup.
    if let Ok(msg) = serde_json::to_string(&RelayMessage::ClientId(client_id.clone())) {
        if socket.send(Message::Text(msg.into())).await.is_err() {
            error!("failed to send message to client");
            return;
        }
    } else {
        error!("failed to serialize into JSON");
        return;
    }

    let mut rx_relay = state.add_client(&client_id).await;
    let mut ping_interval = time::interval(state.ping_interval());

    loop {
        tokio::select! {
            _ = ping_interval.tick() => {
                if socket.send(Message::Ping(Vec::new().into())).await.is_err() {
                    error!("failed to send ping to client");
                    break;
                }
            }
            // A failed `recv` does not match the pattern; `select!` then ignores this
            // branch for the current iteration and waits on the others.
            Ok(relay_message) = rx_relay.recv() => {
                if let Ok(msg) = serde_json::to_string(&relay_message) {
                    if socket
                        .send(Message::Text(msg.into()))
                        .await
                        .is_err()
                    {
                        error!("failed to send message to client");
                        break;
                    }
                } else {
                    error!("failed to serialize into JSON");
                    break;
                }
            }
            // Bind the whole `Option` instead of pattern-matching `Some(..)` in the
            // branch head: a `None` (stream ended) would otherwise just disable the
            // branch and leave the client registered until a later ping send fails.
            incoming = socket.next() => {
                let Some(result) = incoming else {
                    debug!("websocket stream ended");
                    break;
                };
                match result {
                    Ok(Message::Close(_)) => {
                        debug!("received websocket close message");
                        break;
                    }
                    Ok(Message::Text(message)) => {
                        if let Ok(RelayMessage::ProxyResponse { request_id, body, headers, status }) =
                            serde_json::from_slice::<RelayMessage>(message.as_bytes())
                        {
                            // Only used for logging; non-UTF-8 bodies are shown as a placeholder.
                            let body_str = std::str::from_utf8(&body).unwrap_or("<binary>");
                            debug!(
                                request_id,
                                status,
                                ?headers,
                                body_str,
                                "received proxy response from client"
                            );
                            if let Some(tx) = state.remove_proxy_request(&request_id).await {
                                // The send result is deliberately ignored
                                // (presumably the waiting side may already be gone).
                                let _ = tx.send(RelayMessage::ProxyResponse { request_id, body, headers, status });
                            }
                        } else {
                            error!("failed to deserialize from bytes: {}", message);
                        }
                    }
                    // Other frame kinds carry nothing to relay.
                    Ok(_) => {},
                    Err(err) => {
                        debug!("received websocket error: {}", err);
                        break;
                    }
                }
            }
        }
    }

    state.remove_client(&client_id).await;

    info!("👨 client disconnected");
}