//! hexput-runtime 0.1.3
//!
//! WebSocket runtime server for Hexput AST processing.
use crate::error::RuntimeError;
use crate::messages::{FunctionCallResponse, FunctionExistsResponse, WebSocketMessage, WebSocketRequest, WebSocketResponse};
use futures_util::{SinkExt, StreamExt};
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, Mutex as TokioMutex, oneshot};
use tokio_tungstenite::tungstenite::Message;
use tracing::{debug, error, info};
use std::collections::HashMap;
use serde_json::json;

/// Configuration for `run_server`.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Address to listen on, parsed with `SocketAddr`'s `FromStr` implementation
    /// (for example `"127.0.0.1:8080"`).
    pub address: String,
}

pub async fn run_server(config: ServerConfig) -> Result<(), RuntimeError> {
    let addr = config.address.parse::<SocketAddr>().map_err(|_| {
        std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid server address")
    })?;

    let listener = TcpListener::bind(&addr).await?;
    info!("WebSocket server listening on: {}", addr);

    let active_connections = Arc::new(TokioMutex::new(0));

    while let Ok((stream, peer_addr)) = listener.accept().await {
        info!("New connection from: {}", peer_addr);

        let connections = active_connections.clone();

        {
            let mut count = connections.lock().await;
            *count += 1;
            info!("Active connections: {}", *count);
        }

        tokio::spawn(async move {
            match handle_connection(stream, peer_addr).await {
                Ok(_) => info!("Connection from {} closed gracefully", peer_addr),
                Err(e) => error!("Error handling connection from {}: {}", peer_addr, e),
            }

            let mut count = connections.lock().await;
            *count -= 1;
            info!("Connection closed. Active connections: {}", *count);
        });
    }

    Ok(())
}

/// A command for a connection's single writer task.
enum SenderMessage {
    /// Stop the writer loop; the sink is then closed.
    Close,
    /// Reply to a ping with a pong that carries this payload.
    Pong(Vec<u8>),
    /// Send this string as a text frame.
    Text(String),
}

async fn handle_connection(stream: TcpStream, peer_addr: SocketAddr) -> Result<(), RuntimeError> {
    debug!("Starting WebSocket handshake with: {}", peer_addr);
    let ws_stream = tokio_tungstenite::accept_async(stream).await?;
    info!("WebSocket connection established with: {}", peer_addr);

    let (ws_sender, mut ws_receiver) = ws_stream.split();
    
    let (sender_tx, mut sender_rx) = mpsc::channel::<SenderMessage>(100);
    
    let function_calls = Arc::new(Mutex::new(HashMap::<String, oneshot::Sender<FunctionCallResponse>>::new()));
    let function_validations = Arc::new(Mutex::new(HashMap::<String, oneshot::Sender<FunctionExistsResponse>>::new()));
    
    let sender_task = tokio::spawn(async move {
        let mut sender = ws_sender;
        
        while let Some(msg) = sender_rx.recv().await {
            match msg {
                SenderMessage::Text(text) => {
                    if let Err(e) = sender.send(Message::Text(text)).await {
                        error!("Error sending message: {}", e);
                        break;
                    }
                },
                SenderMessage::Pong(data) => {
                    if let Err(e) = sender.send(Message::Pong(data)).await {
                        error!("Error sending pong: {}", e);
                        break;
                    }
                },
                SenderMessage::Close => {
                    break;
                }
            }
        }
        
        let _ = sender.close().await;
    });
    
    let welcome_sender = sender_tx.clone();
    
    if let Err(e) = welcome_sender.send(SenderMessage::Text(
        r#"{"type":"connection","status":"connected"}"#.to_string()
    )).await {
        error!("Failed to send welcome message: {}", e);
        return Err(RuntimeError::ConnectionError("Failed to send welcome message".to_string()));
    }

    let mut task_set: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();

    let create_message_sender = |tx: mpsc::Sender<SenderMessage>| {
        move |message: String| -> futures_util::future::BoxFuture<'static, Result<(), RuntimeError>> {
            let sender = tx.clone();
            Box::pin(async move {
                sender.send(SenderMessage::Text(message)).await
                    .map_err(|_| RuntimeError::ConnectionError("Failed to send message".to_string()))
            })
        }
    };

    while let Some(msg) = ws_receiver.next().await {
        match msg {
            Ok(Message::Text(text)) => {
                debug!("Received text message from {}: {}", peer_addr, text);
                
                let function_calls_clone = function_calls.clone();
                let function_validations_clone = function_validations.clone();
                let sender_clone = sender_tx.clone();
                let message_sender = create_message_sender(sender_clone.clone());
                
                match serde_json::from_str::<WebSocketMessage>(&text) {
                    Ok(WebSocketMessage::FunctionResponse(response)) => {
                        debug!("Received function response for ID: {}", response.id);
                        
                        if let Err(e) = handle_function_response_directly(response, function_calls_clone).await {
                            error!("Error processing function response: {}", e);
                        }
                    },
                    Ok(WebSocketMessage::FunctionExistsResponse(response)) => {
                        debug!("Received function exists response for ID: {}, function exists: {}", response.id, response.exists);
                        
                        if let Err(e) = handle_function_exists_response(response, function_validations_clone).await {
                            error!("Error processing function exists response: {}", e);
                        }
                    },
                    Ok(WebSocketMessage::Request(request)) => {
                        debug!("Processing request with ID: {}", request.id);
                        let req_id = request.id.clone();
                        
                        task_set.spawn(async move {
                            match process_request(request, function_calls_clone, function_validations_clone, message_sender).await {
                                Ok(_) => debug!("Request {} processed successfully", req_id),
                                Err(e) => {
                                    error!("Error processing request {}: {}", req_id, e);
                                    
                                    let error_response = WebSocketResponse {
                                        id: req_id,
                                        success: false,
                                        result: None,
                                        error: Some(format!("Internal error: {}", e)),
                                    };
                                    
                                    if let Ok(json) = serde_json::to_string(&error_response) {
                                        if let Err(send_err) = sender_clone.send(SenderMessage::Text(json)).await {
                                            error!("Failed to send error response: {}", send_err);
                                        }
                                    }
                                }
                            }
                        });
                    },
                    Ok(WebSocketMessage::Unknown(value)) => {
                        error!("Received unknown message type: {}", value);
                        
                        let error_msg = json!({
                            "error": "Unknown message format",
                            "details": value
                        }).to_string();
                        
                        if let Err(e) = sender_clone.send(SenderMessage::Text(error_msg)).await {
                            error!("Failed to send error message: {}", e);
                        }
                    },
                    Err(e) => {
                        error!("Failed to parse message: {}", e);
                        
                        let error_msg = json!({
                            "error": "Failed to parse message",
                            "details": e.to_string()
                        }).to_string();
                        
                        if let Err(e) = sender_clone.send(SenderMessage::Text(error_msg)).await {
                            error!("Failed to send error message: {}", e);
                        }
                    }
                }
            },
            Ok(Message::Ping(data)) => {
                debug!("Received ping from {}", peer_addr);
                let pong_sender = sender_tx.clone();
                
                if let Err(e) = pong_sender.send(SenderMessage::Pong(data)).await {
                    error!("Error sending pong to {}: {}", peer_addr, e);
                }
            },
            Ok(Message::Close(_)) => {
                info!("Received close message from {}", peer_addr);
                break;
            },
            Err(e) => {
                error!("Error reading message from {}: {}", peer_addr, e);
                break;
            },
            _ => {
                debug!("Received other message type from {}", peer_addr);
            }
        }
    }

    let _ = sender_tx.send(SenderMessage::Close).await;
    
    if let Err(e) = sender_task.await {
        error!("Error awaiting sender task: {}", e);
    }

    debug!("Cleaning up tasks for connection {}", peer_addr);
    while task_set.join_next().await.is_some() { }

    info!("Closing connection with: {}", peer_addr);
    Ok(())
}

/// Processes one client request.
///
/// Delegates to `handle_request` and hands its result back unchanged. On `Err`, the
/// spawned task in `handle_connection` sends the client a failure response.
async fn process_request(
    request: WebSocketRequest,
    function_calls: Arc<Mutex<HashMap<String, oneshot::Sender<FunctionCallResponse>>>>,
    function_validations: Arc<Mutex<HashMap<String, oneshot::Sender<FunctionExistsResponse>>>>,
    message_sender: impl Fn(String) -> futures_util::future::BoxFuture<'static, Result<(), RuntimeError>> + Send + Clone + 'static,
) -> Result<(), RuntimeError> {
    handle_request(request, function_calls, function_validations, message_sender).await
}

async fn handle_function_response_directly(
    response: FunctionCallResponse,
    function_calls: Arc<Mutex<HashMap<String, oneshot::Sender<FunctionCallResponse>>>>,
) -> Result<(), RuntimeError> {
    debug!("Processing function response for call ID: {}", response.id);
    
    let sender = {
        let mut calls = function_calls.lock().unwrap();
        calls.remove(&response.id)
    };
    
    if let Some(sender) = sender {
        if sender.send(response).is_err() {
            error!("Failed to send response through channel - receiver likely dropped");
        } else {
            debug!("Successfully sent function response through channel");
        }
    } else {
        error!("Received function response for unknown call ID: {}", response.id);
    }
    
    Ok(())
}

async fn handle_function_exists_response(
    response: crate::messages::FunctionExistsResponse,
    function_validations: Arc<Mutex<HashMap<String, oneshot::Sender<FunctionExistsResponse>>>>,
) -> Result<(), RuntimeError> {
    debug!("Processing function exists response for ID: {}", response.id);
    
    let sender = {
        let mut validations = function_validations.lock().unwrap();
        validations.remove(&response.id)
    };
    
    if let Some(sender) = sender {
        if sender.send(response).is_err() {
            error!("Failed to send function exists response through channel - receiver likely dropped");
        } else {
            debug!("Successfully sent function exists response through channel");
        }
    } else {
        error!("Received function exists response for unknown ID: {}", response.id);
    }
    
    Ok(())
}

/// Runs `crate::handler::handle_request` and forwards its final output to the client.
///
/// The handler gets its own clone of `message_sender`, so it can emit messages while it
/// runs (presumably outbound function calls; confirm in the handler module). A non-empty
/// result is then sent through `message_sender`; an empty result sends nothing.
///
/// # Errors
///
/// Propagates handler errors and any failure to queue the final message.
async fn handle_request(
    request: WebSocketRequest,
    function_calls: Arc<Mutex<HashMap<String, oneshot::Sender<FunctionCallResponse>>>>,
    function_validations: Arc<Mutex<HashMap<String, oneshot::Sender<FunctionExistsResponse>>>>,
    message_sender: impl Fn(String) -> futures_util::future::BoxFuture<'static, Result<(), RuntimeError>> + Send + Clone + 'static,
) -> Result<(), RuntimeError> {
    let output = crate::handler::handle_request(
        request,
        function_calls,
        function_validations,
        message_sender.clone(),
    )
    .await?;

    if output.is_empty() {
        return Ok(());
    }

    message_sender(output).await
}