heroforge-core 0.2.2

Pure Rust core library for reading and writing Fossil SCM repositories
Documentation
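//! Socket Server for Heroforge Daemon
//!
//! This module provides a Unix socket server that accepts Rhai scripts
//! and streams their output back to clients.
//!
//! Wire protocol: a client sends the script as plain text lines and ends it
//! with a line consisting only of `=` characters (at least two, e.g. `===`),
//! or by closing its side of the connection. The server replies with `ok`,
//! streams the script's output line by line, writes an `ERROR: ...` line if
//! execution failed, and finishes with a `=====` line. An empty script is
//! answered with `error`, `Empty script`, `=====` instead.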

use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::{UnixListener, UnixStream};
use tokio::sync::RwLock;

use super::engine::ForgeEngine;
use super::error::RhaiResult;

/// Default socket path: `<home>/hero/var/heroforge.sock`.
///
/// # Panics
///
/// Panics with "No home directory" if the user's home directory cannot be
/// determined.
pub fn default_socket_path() -> PathBuf {
    dirs::home_dir()
        .map(|home_dir| home_dir.join("hero/var/heroforge.sock"))
        .expect("No home directory")
}

/// Create the directory that will contain `socket_path`, including any missing
/// ancestors. A path with no parent component (e.g. `/`) is left untouched.
///
/// # Errors
///
/// Propagates any error from [`std::fs::create_dir_all`].
fn ensure_socket_dir(socket_path: &Path) -> std::io::Result<()> {
    match socket_path.parent() {
        Some(dir) => std::fs::create_dir_all(dir),
        None => Ok(()),
    }
}

/// Return `true` when `line`, ignoring surrounding whitespace, is a run of at
/// least two `=` characters — the marker that ends a submitted script.
fn is_execute_marker(line: &str) -> bool {
    let body = line.trim();
    // Any non-'=' char (including multi-byte UTF-8) contains a byte != b'='.
    body.len() > 1 && !body.bytes().any(|b| b != b'=')
}

/// Socket server for handling Rhai script execution.
///
/// Owns the Unix socket path and a single [`ForgeEngine`] that is shared by
/// every accepted connection.
pub struct SocketServer {
    /// Filesystem path of the Unix socket. Any existing file here is removed
    /// before binding, after a graceful shutdown, and when the server is dropped.
    socket_path: PathBuf,
    /// Engine shared across connection tasks; each connection takes a read
    /// lock on it to run its script.
    engine: Arc<RwLock<ForgeEngine>>,
}

impl SocketServer {
    /// Create a new socket server.
    pub fn new(socket_path: PathBuf) -> RhaiResult<Self> {
        let engine = ForgeEngine::new().map_err(|e| {
            super::error::RhaiError::OperationError(format!("Failed to create engine: {}", e))
        })?;

        Ok(Self {
            socket_path,
            engine: Arc::new(RwLock::new(engine)),
        })
    }

    /// Create with default socket path.
    pub fn with_default_path() -> RhaiResult<Self> {
        Self::new(default_socket_path())
    }

    /// Get the socket path.
    pub fn socket_path(&self) -> &Path {
        &self.socket_path
    }

    /// Remove existing socket file if present.
    fn cleanup_socket(&self) -> std::io::Result<()> {
        if self.socket_path.exists() {
            std::fs::remove_file(&self.socket_path)?;
        }
        Ok(())
    }

    /// Run the server (blocking).
    pub async fn run(&self) -> RhaiResult<()> {
        ensure_socket_dir(&self.socket_path)?;
        self.cleanup_socket()?;

        let listener = UnixListener::bind(&self.socket_path).map_err(|e| {
            super::error::RhaiError::SocketError(format!(
                "Failed to bind socket {}: {}",
                self.socket_path.display(),
                e
            ))
        })?;

        eprintln!(
            "Heroforge daemon listening on {}",
            self.socket_path.display()
        );

        loop {
            match listener.accept().await {
                Ok((stream, _addr)) => {
                    let engine = Arc::clone(&self.engine);
                    tokio::spawn(async move {
                        if let Err(e) = handle_connection(stream, engine).await {
                            eprintln!("Connection error: {}", e);
                        }
                    });
                }
                Err(e) => {
                    eprintln!("Accept error: {}", e);
                }
            }
        }
    }

    /// Run the server with graceful shutdown.
    pub async fn run_with_shutdown(
        &self,
        mut shutdown: tokio::sync::broadcast::Receiver<()>,
    ) -> RhaiResult<()> {
        ensure_socket_dir(&self.socket_path)?;
        self.cleanup_socket()?;

        let listener = UnixListener::bind(&self.socket_path).map_err(|e| {
            super::error::RhaiError::SocketError(format!(
                "Failed to bind socket {}: {}",
                self.socket_path.display(),
                e
            ))
        })?;

        eprintln!(
            "Heroforge daemon listening on {}",
            self.socket_path.display()
        );

        loop {
            tokio::select! {
                result = listener.accept() => {
                    match result {
                        Ok((stream, _addr)) => {
                            let engine = Arc::clone(&self.engine);
                            tokio::spawn(async move {
                                if let Err(e) = handle_connection(stream, engine).await {
                                    eprintln!("Connection error: {}", e);
                                }
                            });
                        }
                        Err(e) => {
                            eprintln!("Accept error: {}", e);
                        }
                    }
                }
                _ = shutdown.recv() => {
                    eprintln!("Shutting down server...");
                    break;
                }
            }
        }

        self.cleanup_socket()?;
        Ok(())
    }
}

impl Drop for SocketServer {
    /// Best-effort removal of the socket file when the server goes away.
    fn drop(&mut self) {
        // `drop` cannot report failures, so any removal error is discarded.
        let _ignored = self.cleanup_socket();
    }
}

/// Handle a single client connection.
///
/// The script is read line by line until a line of only `=` characters (see
/// [`is_execute_marker`]) or EOF.
///
/// - An empty script is answered with `error\nEmpty script\n=====\n`.
/// - Otherwise the server writes `ok\n`, then streams every line the engine
///   emits through its output sender, then an optional `ERROR: ...` line, and
///   finally the `=====` delimiter.
///
/// # Errors
///
/// Returns any I/O error from reading the script or writing the reply.
async fn handle_connection(
    stream: UnixStream,
    engine: Arc<RwLock<ForgeEngine>>,
) -> std::io::Result<()> {
    let (reader, mut writer) = stream.into_split();
    let mut reader = BufReader::new(reader);

    // Read script until we see the execute marker (===) or EOF.
    let mut script = String::new();
    let mut line = String::new();
    loop {
        line.clear();
        if reader.read_line(&mut line).await? == 0 || is_execute_marker(&line) {
            break;
        }
        script.push_str(&line);
    }

    if script.trim().is_empty() {
        writer.write_all(b"error\nEmpty script\n=====\n").await?;
        return Ok(());
    }

    // Send "ok" header
    writer.write_all(b"ok\n").await?;
    writer.flush().await?;

    // Channel carrying script output from the blocking executor to the socket.
    let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(100);

    // Forward output as it arrives; stops when every sender is dropped or the
    // client goes away.
    let writer = Arc::new(tokio::sync::Mutex::new(writer));
    let writer_clone = Arc::clone(&writer);
    let output_task = tokio::spawn(async move {
        while let Some(output) = rx.recv().await {
            let mut w = writer_clone.lock().await;
            if w.write_all(output.as_bytes()).await.is_err() {
                break;
            }
            let _ = w.flush().await;
        }
    });

    // The error is converted to a String inside spawn_blocking so the result is Send.
    let result: Result<Result<(), String>, _> = tokio::task::spawn_blocking(move || {
        /// Calls `clear_output_sender` when dropped, including during a panic.
        ///
        /// Without this, a panicking script leaves the registered sender (and its
        /// `tx` clone) alive. Then `rx` never closes and `output_task.await`
        /// below hangs forever. Assumes `clear_output_sender` drops the stored
        /// closure — TODO confirm in `engine`.
        struct ClearOutputSender;
        impl Drop for ClearOutputSender {
            fn drop(&mut self) {
                super::engine::clear_output_sender();
            }
        }

        // Set up streaming callback
        let tx_clone = tx.clone();
        super::engine::set_output_sender(move |s: &str| {
            let _ = tx_clone.blocking_send(format!("{}\n", s));
        });
        let clear_guard = ClearOutputSender;

        // We are on a blocking-pool thread, so tokio's blocking lock is allowed.
        let engine_guard = engine.blocking_read();
        let result = engine_guard.run(&script);

        // Clean up in the same order as before: clear the sender, then close `tx`.
        drop(clear_guard);
        drop(tx);

        result.map_err(|e| e.to_string())
    })
    .await;

    // Wait until all streamed output has been written.
    let _ = output_task.await;

    // Write result/error and delimiter
    let mut w = writer.lock().await;
    match result {
        Ok(Ok(())) => {}
        Ok(Err(e)) => {
            w.write_all(format!("ERROR: {}\n", e).as_bytes()).await?;
        }
        Err(e) => {
            w.write_all(format!("ERROR: Task panicked: {}\n", e).as_bytes())
                .await?;
        }
    }

    w.write_all(b"=====\n").await?;
    w.flush().await?;

    Ok(())
}

/// Check if the daemon is running by pinging the socket.
pub async fn ping(socket_path: &Path) -> bool {
    if !socket_path.exists() {
        return false;
    }

    match UnixStream::connect(socket_path).await {
        Ok(stream) => {
            let (_, mut writer) = stream.into_split();
            // Send a simple ping script
            let result = writer.write_all(b"true\n===\n").await;
            result.is_ok()
        }
        Err(_) => false,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_execute_marker() {
        assert!(is_execute_marker("=="));
        assert!(is_execute_marker("==="));
        assert!(is_execute_marker("====="));
        assert!(is_execute_marker("  ===  "));
        // Lines produced by `read_line` keep their terminator.
        assert!(is_execute_marker("===\n"));
        assert!(is_execute_marker("=====\r\n"));
        assert!(!is_execute_marker("="));
        assert!(!is_execute_marker(""));
        assert!(!is_execute_marker("\n"));
        assert!(!is_execute_marker("abc"));
        assert!(!is_execute_marker("=a="));
        assert!(!is_execute_marker("== =="));
        // A Rhai equality expression must not terminate the script.
        assert!(!is_execute_marker("x == y"));
    }

    #[test]
    fn test_ensure_socket_dir_creates_parents() {
        let base = std::env::temp_dir()
            .join(format!("heroforge_socket_dir_test_{}", std::process::id()));
        let sock = base.join("nested/dir/heroforge.sock");

        ensure_socket_dir(&sock).unwrap();
        assert!(base.join("nested/dir").is_dir());
        // Only the parent is created, never the socket path itself.
        assert!(!sock.exists());
        // Idempotent when the directory already exists.
        ensure_socket_dir(&sock).unwrap();

        std::fs::remove_dir_all(&base).unwrap();
    }

    #[test]
    fn test_default_socket_path() {
        let path = default_socket_path();
        assert!(path.ends_with("hero/var/heroforge.sock"));
    }
}