use std::io::Write;
use std::path::{Path, PathBuf};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::UnixStream;
use super::error::{RhaiError, RhaiResult};
use super::socket_server::default_socket_path;
/// Returns `true` when `line` is an end-of-response marker.
///
/// After trimming surrounding whitespace (including the trailing newline that
/// `read_line` leaves in place), the line must consist of two or more `=`
/// characters and nothing else.
fn is_response_end(line: &str) -> bool {
    let marker = line.trim();
    let long_enough = marker.len() > 1;
    // `=` is ASCII, so "every byte is `=`" is the same as "every char is `=`".
    long_enough && marker.bytes().all(|b| b == b'=')
}
/// Client that sends scripts to a server listening on a Unix domain socket.
#[derive(Debug, Clone)]
pub struct SocketClient {
    /// Filesystem path of the server's Unix socket.
    socket_path: PathBuf,
}
impl SocketClient {
    /// Creates a client that connects to the socket at `socket_path`.
    pub fn new(socket_path: PathBuf) -> Self {
        Self { socket_path }
    }

    /// Creates a client for the path returned by `default_socket_path()`.
    pub fn with_default_path() -> Self {
        Self::new(default_socket_path())
    }

    /// Returns the Unix socket path this client connects to.
    pub fn socket_path(&self) -> &Path {
        &self.socket_path
    }

    /// Delegates to `socket_server::ping` for this client's socket path and
    /// returns its result (presumably a liveness check — see that function).
    pub async fn ping(&self) -> bool {
        super::socket_server::ping(&self.socket_path).await
    }

    /// Sends `script` to the server and echoes each output line to stdout as it
    /// arrives.
    ///
    /// # Errors
    ///
    /// - `RhaiError::SocketError` if connecting, sending, or reading fails, or if
    ///   the server does not reply with an `ok` status line.
    /// - `RhaiError::ScriptError` if any output line starts with `ERROR: `; the
    ///   message comes from the last such line. All lines are still printed.
    pub async fn send_script_streaming(&self, script: &str) -> RhaiResult<()> {
        self.exchange(script, |line| {
            print!("{}", line);
            // Flush per line so output shows up while the script is still running;
            // a failed flush is deliberately ignored (best-effort echo).
            std::io::stdout().flush().ok();
        })
        .await
    }

    /// Sends `script` to the server and returns the collected output.
    ///
    /// The returned string holds every output line, including its line
    /// terminator, up to but not including the end-of-response marker.
    ///
    /// # Errors
    ///
    /// Same as `send_script_streaming`. On a script error the collected output
    /// is discarded and only the error message is returned.
    pub async fn send_script(&self, script: &str) -> RhaiResult<String> {
        let mut output = String::new();
        self.exchange(script, |line| output.push_str(line)).await?;
        Ok(output)
    }

    /// Performs one request/response round trip over a fresh connection and
    /// passes each output line (terminator included) to `on_line`.
    ///
    /// Request as written here:
    /// - the script text;
    /// - a `\n` if the script did not already end with one;
    /// - a `===` delimiter line.
    ///
    /// Reply as read here:
    /// - a status line that must trim to `ok`;
    /// - output lines, until a line accepted by `is_response_end` or until the
    ///   connection closes.
    ///
    /// NOTE(review): a connection closed before the end marker is treated the
    /// same as a normal end, so a truncated reply is not detected. Confirm
    /// whether the server always sends the marker.
    async fn exchange<F>(&self, script: &str, mut on_line: F) -> RhaiResult<()>
    where
        F: FnMut(&str),
    {
        let stream = UnixStream::connect(&self.socket_path).await.map_err(|e| {
            RhaiError::SocketError(format!(
                "Failed to connect to {}: {}",
                self.socket_path.display(),
                e
            ))
        })?;
        let (reader, mut writer) = stream.into_split();
        let mut reader = BufReader::new(reader);

        writer
            .write_all(script.as_bytes())
            .await
            .map_err(|e| RhaiError::SocketError(format!("Failed to send script: {}", e)))?;
        if !script.ends_with('\n') {
            // The delimiter must sit on its own line.
            writer
                .write_all(b"\n")
                .await
                .map_err(|e| RhaiError::SocketError(format!("Failed to send newline: {}", e)))?;
        }
        writer
            .write_all(b"===\n")
            .await
            .map_err(|e| RhaiError::SocketError(format!("Failed to send delimiter: {}", e)))?;
        writer
            .flush()
            .await
            .map_err(|e| RhaiError::SocketError(format!("Failed to flush: {}", e)))?;

        let mut status_line = String::new();
        let status_bytes = reader
            .read_line(&mut status_line)
            .await
            .map_err(|e| RhaiError::SocketError(format!("Failed to read status: {}", e)))?;
        if status_bytes == 0 {
            // Without this check the caller would see a puzzling empty
            // "Unexpected status: " message.
            return Err(RhaiError::SocketError(
                "Connection closed before a status line was received".to_string(),
            ));
        }
        let status = status_line.trim();
        if status != "ok" {
            return Err(RhaiError::SocketError(format!(
                "Unexpected status: {}",
                status
            )));
        }

        // The last `ERROR: ` line wins, matching the previous behavior.
        let mut last_error: Option<String> = None;
        let mut line = String::new();
        loop {
            line.clear();
            let bytes_read = reader
                .read_line(&mut line)
                .await
                .map_err(|e| RhaiError::SocketError(format!("Failed to read output: {}", e)))?;
            // EOF and an explicit end marker both terminate the response.
            if bytes_read == 0 || is_response_end(&line) {
                break;
            }
            if let Some(msg) = line.strip_prefix("ERROR: ") {
                last_error = Some(msg.trim().to_string());
            }
            on_line(&line);
        }

        match last_error {
            Some(msg) => Err(RhaiError::ScriptError(msg)),
            None => Ok(()),
        }
    }
}
/// Convenience wrapper that builds a one-off [`SocketClient`] for `socket_path`.
///
/// It sends `script` and returns the collected output; see
/// `SocketClient::send_script` for the error cases.
pub async fn send_script(socket_path: &Path, script: &str) -> RhaiResult<String> {
    let client = SocketClient::new(socket_path.to_owned());
    client.send_script(script).await
}
/// Convenience wrapper that builds a one-off [`SocketClient`] for `socket_path`.
///
/// It sends `script` and streams the output to stdout; see
/// `SocketClient::send_script_streaming` for the error cases.
pub async fn send_script_streaming(socket_path: &Path, script: &str) -> RhaiResult<()> {
    let client = SocketClient::new(socket_path.to_owned());
    client.send_script_streaming(script).await
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_response_end() {
        assert!(is_response_end("====="));
        assert!(is_response_end("==="));
        assert!(is_response_end(" ===== "));
        assert!(!is_response_end("="));
        assert!(!is_response_end("abc"));
    }

    #[test]
    fn test_is_response_end_boundaries() {
        // Minimum accepted length.
        assert!(is_response_end("=="));
        // Lines arrive from `read_line` with their terminator still attached.
        assert!(is_response_end("===\n"));
        assert!(is_response_end("===\r\n"));
        // Empty or whitespace-only lines are not markers.
        assert!(!is_response_end(""));
        assert!(!is_response_end("   \n"));
        // Anything besides `=` inside the trimmed line disqualifies it.
        assert!(!is_response_end("= ="));
        assert!(!is_response_end("==x"));
        assert!(!is_response_end("x=="));
    }
}