use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::{UnixListener, UnixStream};
use tokio::sync::RwLock;
use super::engine::ForgeEngine;
use super::error::RhaiResult;
/// Default location of the daemon socket: `$HOME/hero/var/heroforge.sock`.
///
/// # Panics
///
/// Panics with "No home directory" if the current user's home directory
/// cannot be determined.
pub fn default_socket_path() -> PathBuf {
    const RELATIVE_SOCKET: &str = "hero/var/heroforge.sock";
    dirs::home_dir()
        .expect("No home directory")
        .join(RELATIVE_SOCKET)
}
/// Creates every missing ancestor directory of `socket_path`.
///
/// The socket file itself is not created. A path with no parent (e.g. `/`)
/// is accepted as-is.
///
/// # Errors
///
/// Propagates any error from [`std::fs::create_dir_all`].
fn ensure_socket_dir(socket_path: &Path) -> std::io::Result<()> {
    match socket_path.parent() {
        Some(dir) => std::fs::create_dir_all(dir),
        None => Ok(()),
    }
}
/// Returns `true` when `line`, ignoring surrounding whitespace, consists of
/// two or more `=` characters and nothing else.
///
/// Clients send such a line after a script to ask the server to run it.
fn is_execute_marker(line: &str) -> bool {
    let body = line.trim();
    // '=' is ASCII, so no UTF-8 continuation byte can ever equal it; checking
    // bytes is therefore equivalent to checking chars, and byte length equals
    // char count whenever the check passes.
    body.len() > 1 && !body.bytes().any(|b| b != b'=')
}
pub struct SocketServer {
socket_path: PathBuf,
engine: Arc<RwLock<ForgeEngine>>,
}
impl SocketServer {
pub fn new(socket_path: PathBuf) -> RhaiResult<Self> {
let engine = ForgeEngine::new().map_err(|e| {
super::error::RhaiError::OperationError(format!("Failed to create engine: {}", e))
})?;
Ok(Self {
socket_path,
engine: Arc::new(RwLock::new(engine)),
})
}
pub fn with_default_path() -> RhaiResult<Self> {
Self::new(default_socket_path())
}
pub fn socket_path(&self) -> &Path {
&self.socket_path
}
fn cleanup_socket(&self) -> std::io::Result<()> {
if self.socket_path.exists() {
std::fs::remove_file(&self.socket_path)?;
}
Ok(())
}
pub async fn run(&self) -> RhaiResult<()> {
ensure_socket_dir(&self.socket_path)?;
self.cleanup_socket()?;
let listener = UnixListener::bind(&self.socket_path).map_err(|e| {
super::error::RhaiError::SocketError(format!(
"Failed to bind socket {}: {}",
self.socket_path.display(),
e
))
})?;
eprintln!(
"Heroforge daemon listening on {}",
self.socket_path.display()
);
loop {
match listener.accept().await {
Ok((stream, _addr)) => {
let engine = Arc::clone(&self.engine);
tokio::spawn(async move {
if let Err(e) = handle_connection(stream, engine).await {
eprintln!("Connection error: {}", e);
}
});
}
Err(e) => {
eprintln!("Accept error: {}", e);
}
}
}
}
pub async fn run_with_shutdown(
&self,
mut shutdown: tokio::sync::broadcast::Receiver<()>,
) -> RhaiResult<()> {
ensure_socket_dir(&self.socket_path)?;
self.cleanup_socket()?;
let listener = UnixListener::bind(&self.socket_path).map_err(|e| {
super::error::RhaiError::SocketError(format!(
"Failed to bind socket {}: {}",
self.socket_path.display(),
e
))
})?;
eprintln!(
"Heroforge daemon listening on {}",
self.socket_path.display()
);
loop {
tokio::select! {
result = listener.accept() => {
match result {
Ok((stream, _addr)) => {
let engine = Arc::clone(&self.engine);
tokio::spawn(async move {
if let Err(e) = handle_connection(stream, engine).await {
eprintln!("Connection error: {}", e);
}
});
}
Err(e) => {
eprintln!("Accept error: {}", e);
}
}
}
_ = shutdown.recv() => {
eprintln!("Shutting down server...");
break;
}
}
}
self.cleanup_socket()?;
Ok(())
}
}
impl Drop for SocketServer {
fn drop(&mut self) {
let _ = self.cleanup_socket();
}
}
async fn handle_connection(
stream: UnixStream,
engine: Arc<RwLock<ForgeEngine>>,
) -> std::io::Result<()> {
let (reader, mut writer) = stream.into_split();
let mut reader = BufReader::new(reader);
let mut script = String::new();
let mut line = String::new();
loop {
line.clear();
let bytes_read = reader.read_line(&mut line).await?;
if bytes_read == 0 {
break;
}
if is_execute_marker(&line) {
break;
}
script.push_str(&line);
}
if script.trim().is_empty() {
writer.write_all(b"error\nEmpty script\n=====\n").await?;
return Ok(());
}
writer.write_all(b"ok\n").await?;
writer.flush().await?;
let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(100);
let writer = Arc::new(tokio::sync::Mutex::new(writer));
let writer_clone = Arc::clone(&writer);
let output_task = tokio::spawn(async move {
while let Some(output) = rx.recv().await {
let mut w = writer_clone.lock().await;
if w.write_all(output.as_bytes()).await.is_err() {
break;
}
let _ = w.flush().await;
}
});
let script_clone = script.clone();
let result: Result<Result<(), String>, _> = tokio::task::spawn_blocking(move || {
let tx_clone = tx.clone();
super::engine::set_output_sender(move |s: &str| {
let _ = tx_clone.blocking_send(format!("{}\n", s));
});
let engine_guard = futures::executor::block_on(engine.read());
let result = engine_guard.run(&script_clone);
super::engine::clear_output_sender();
drop(tx);
result.map_err(|e| e.to_string())
})
.await;
let _ = output_task.await;
let mut w = writer.lock().await;
match result {
Ok(Ok(())) => {
}
Ok(Err(e)) => {
w.write_all(format!("ERROR: {}\n", e).as_bytes()).await?;
}
Err(e) => {
w.write_all(format!("ERROR: Task panicked: {}\n", e).as_bytes())
.await?;
}
}
w.write_all(b"=====\n").await?;
w.flush().await?;
Ok(())
}
/// Probes the daemon at `socket_path`.
///
/// Returns `true` if the path exists, a connection can be opened, and the
/// probe script `true` (followed by the `===` execute marker) can be written.
/// Returns `false` otherwise.
///
/// NOTE(review): the reply is never read, so this only shows that something
/// is accepting connections, not that scripts actually run. Because the
/// connection is dropped right after writing, the server may log a
/// "Connection error" when it writes its response. Consider reading the `ok`
/// line under a timeout.
pub async fn ping(socket_path: &Path) -> bool {
    if !socket_path.exists() {
        return false;
    }
    let Ok(stream) = UnixStream::connect(socket_path).await else {
        return false;
    };
    let (_, mut writer) = stream.into_split();
    writer.write_all(b"true\n===\n").await.is_ok()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_execute_marker() {
        assert!(is_execute_marker("=="));
        assert!(is_execute_marker("==="));
        assert!(is_execute_marker("====="));
        assert!(is_execute_marker(" === "));
        // `read_line` keeps the line terminator, so the real input has one.
        assert!(is_execute_marker("===\n"));
        assert!(is_execute_marker("=====\r\n"));
        assert!(!is_execute_marker(""));
        assert!(!is_execute_marker("\n"));
        assert!(!is_execute_marker("="));
        assert!(!is_execute_marker("=\n"));
        assert!(!is_execute_marker("abc"));
        assert!(!is_execute_marker("=a="));
        assert!(!is_execute_marker("== =="));
    }

    #[test]
    fn test_ensure_socket_dir_creates_parent_only() {
        let base = std::env::temp_dir()
            .join(format!("heroforge-sock-test-{}", std::process::id()));
        let sock = base.join("nested").join("heroforge.sock");
        ensure_socket_dir(&sock).unwrap();
        assert!(base.join("nested").is_dir());
        assert!(!sock.exists());
        // Creating an existing directory tree again must succeed.
        ensure_socket_dir(&sock).unwrap();
        let _ = std::fs::remove_dir_all(&base);
    }

    #[test]
    fn test_default_socket_path() {
        let path = default_socket_path();
        assert!(path.to_string_lossy().contains("hero/var/heroforge.sock"));
    }
}