use crate::error::AybError;
use crate::hosted_db::sandbox::build_daemon_command;
use crate::hosted_db::{QueryMode, QueryResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::canonicalize;
use std::path::PathBuf;
use std::process::Stdio;
use std::sync::Arc;
use tokio::io::BufReader;
use tokio::process::{Child, ChildStdin, ChildStdout};
use tokio::sync::Mutex;
/// A single query request as sent to a daemon process.
///
/// Instances are serialized to JSON and written to the daemon's stdin as one
/// newline-terminated line.
#[derive(Serialize, Deserialize, Debug)]
struct QueryRequest {
/// The query text to execute.
query: String,
/// The requested `QueryMode`, transmitted as its `i16` discriminant.
query_mode: i16,
}
/// Handle to one spawned daemon process together with its stdio pipes.
pub struct DaemonHandle {
/// The spawned daemon process.
child: Child,
/// Write end of the daemon's stdin; becomes `None` once the handle is shut down.
stdin: Option<ChildStdin>,
/// Buffered reader over the daemon's stdout, used to read one response line per query.
stdout: BufReader<ChildStdout>,
}
impl DaemonHandle {
    /// Sends one query to the daemon and returns its raw single-line reply.
    ///
    /// The request is serialized as a JSON object followed by a newline, then
    /// exactly one line is read back from the daemon's stdout. The returned
    /// string is the line as read, including its line terminator if present.
    ///
    /// # Errors
    ///
    /// Returns an error if stdin has already been closed by
    /// [`DaemonHandle::shut_down`], if the request cannot be serialized, if
    /// writing to or reading from the daemon fails, or if the daemon closed
    /// its stdout without sending a response.
    pub async fn execute_query(
        &mut self,
        query: &str,
        query_mode: QueryMode,
    ) -> Result<String, AybError> {
        use tokio::io::{AsyncBufReadExt, AsyncWriteExt};

        // Build the error lazily so the happy path does not allocate a message.
        let stdin = self.stdin.as_mut().ok_or_else(|| AybError::Other {
            message: "Daemon stdin has been closed".to_string(),
        })?;

        let request = QueryRequest {
            query: query.to_string(),
            query_mode: query_mode as i16,
        };

        // The protocol is line-delimited: one JSON document per line.
        let mut request_line = serde_json::to_string(&request)?;
        request_line.push('\n');
        stdin.write_all(request_line.as_bytes()).await?;
        stdin.flush().await?;

        let mut response_line = String::new();
        let bytes_read = self.stdout.read_line(&mut response_line).await?;
        if bytes_read == 0 {
            // Zero bytes means EOF: the daemon closed stdout (most likely it
            // exited). Report that directly instead of handing the caller an
            // empty string that would later fail to parse with a vague error.
            return Err(AybError::Other {
                message: "Daemon closed its stdout before sending a response".to_string(),
            });
        }
        Ok(response_line)
    }

    /// Shuts the daemon down: drops the stdin handle, then kills the child.
    ///
    /// Any error reported by the kill is deliberately ignored, so this is a
    /// best-effort operation. After this call [`DaemonHandle::execute_query`]
    /// fails because stdin is gone.
    pub async fn shut_down(&mut self) {
        // Dropping the taken stdin releases our write end of the pipe.
        drop(self.stdin.take());
        let _ = self.child.kill().await;
    }
}
/// Registry of running daemons, keyed by the canonicalized path of the database they serve.
pub struct DaemonRegistry {
/// Shared map from canonical database path to that database's daemon handle.
/// The outer `Arc` is what clones of the registry share; each handle has its
/// own mutex so that only one caller talks to a given daemon at a time.
daemons: Arc<Mutex<HashMap<PathBuf, Arc<Mutex<DaemonHandle>>>>>,
}
impl Default for DaemonRegistry {
fn default() -> Self {
Self::new()
}
}
impl DaemonRegistry {
    /// Creates an empty registry with no daemons.
    pub fn new() -> Self {
        Self {
            daemons: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Returns the daemon serving `db_path`, spawning a new one when needed.
    ///
    /// The path is canonicalized first so that different spellings of the same
    /// file share one daemon. A cached daemon whose process has already exited
    /// is discarded and replaced, so a crashed daemon does not leave the
    /// database permanently unusable.
    ///
    /// # Errors
    ///
    /// Returns an error if the path cannot be canonicalized or if a new daemon
    /// cannot be spawned.
    async fn get_or_create_daemon(
        &self,
        db_path: &PathBuf,
    ) -> Result<Arc<Mutex<DaemonHandle>>, AybError> {
        let canonical_path = canonicalize(db_path)?;
        let mut daemons = self.daemons.lock().await;

        let cached = daemons.get(&canonical_path).cloned();
        if let Some(daemon) = cached {
            // Only an idle handle can be probed without waiting. A handle that
            // is currently locked is being used by another caller, so it is
            // treated as alive and handed out as before.
            let has_exited = match daemon.try_lock() {
                Ok(mut handle) => matches!(handle.child.try_wait(), Ok(Some(_))),
                Err(_) => false,
            };
            if !has_exited {
                return Ok(daemon);
            }
            // The process is gone: forget the stale handle and respawn below.
            daemons.remove(&canonical_path);
        }

        let daemon_handle = self.spawn_daemon(&canonical_path).await?;
        let daemon_arc = Arc::new(Mutex::new(daemon_handle));
        daemons.insert(canonical_path, Arc::clone(&daemon_arc));
        Ok(daemon_arc)
    }

    /// Runs `query` against the database at `db_path` through its daemon.
    ///
    /// The daemon's handle stays locked for the whole request/response
    /// exchange, so queries against the same database are serialized.
    ///
    /// # Errors
    ///
    /// Returns an error if the daemon cannot be obtained, if communicating
    /// with it fails, or if its response is an error or cannot be parsed.
    pub async fn execute_query(
        &self,
        db_path: &PathBuf,
        query: &str,
        query_mode: QueryMode,
    ) -> Result<QueryResult, AybError> {
        let daemon_arc = self.get_or_create_daemon(db_path).await?;
        let mut daemon = daemon_arc.lock().await;
        let response = daemon.execute_query(query, query_mode).await?;
        parse_response(&response)
    }

    /// Spawns a daemon process for `db_path` with piped stdin and stdout.
    ///
    /// The daemon's stderr is discarded.
    ///
    /// # Errors
    ///
    /// Returns an error if the command cannot be built or spawned, or if the
    /// child's stdin or stdout pipe is unavailable.
    async fn spawn_daemon(&self, db_path: &PathBuf) -> Result<DaemonHandle, AybError> {
        let mut cmd = build_daemon_command(db_path)?;
        let mut child = cmd
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::null())
            .spawn()?;

        // Error values are built lazily so the success path does not allocate.
        let stdin = child.stdin.take().ok_or_else(|| AybError::Other {
            message: "Failed to get daemon stdin".to_string(),
        })?;
        let stdout = child.stdout.take().ok_or_else(|| AybError::Other {
            message: "Failed to get daemon stdout".to_string(),
        })?;

        Ok(DaemonHandle {
            child,
            stdin: Some(stdin),
            stdout: BufReader::new(stdout),
        })
    }

    /// Removes the daemon for `db_path` from the registry and tries to shut it down.
    ///
    /// Shutdown is best-effort: if the handle is currently locked by another
    /// caller, the daemon is only removed from the registry and is not killed
    /// here. Calling this for a database without a daemon is a no-op.
    ///
    /// # Errors
    ///
    /// Returns an error if `db_path` cannot be canonicalized; in that case the
    /// registry is left untouched.
    pub async fn shut_down_daemon(&self, db_path: &PathBuf) -> Result<(), AybError> {
        let canonical_path = canonicalize(db_path)?;
        let mut daemons = self.daemons.lock().await;
        if let Some(daemon_arc) = daemons.remove(&canonical_path) {
            // Never wait on a busy handle while holding the registry lock.
            if let Ok(mut daemon) = daemon_arc.try_lock() {
                daemon.shut_down().await;
            }
        }
        Ok(())
    }

    /// Empties the registry and tries to shut down every daemon it held.
    ///
    /// As with [`DaemonRegistry::shut_down_daemon`], a handle that is
    /// currently locked by another caller is dropped from the registry without
    /// being killed here.
    pub async fn shut_down_all(&self) {
        let mut daemons = self.daemons.lock().await;
        for (_path, daemon_arc) in daemons.drain() {
            if let Ok(mut daemon) = daemon_arc.try_lock() {
                daemon.shut_down().await;
            }
        }
    }
}
/// Interprets one raw response line from a daemon.
///
/// The line is first decoded as a `QueryResult`. If that fails it is decoded
/// as an `AybError`, which is then returned as the error. A line that matches
/// neither shape produces an `AybError::QueryError` quoting the raw response.
fn parse_response(response: &str) -> Result<QueryResult, AybError> {
    serde_json::from_str::<QueryResult>(response).map_err(|_| {
        // Not a result: either the daemon reported an error, or the line is garbage.
        serde_json::from_str::<AybError>(response).unwrap_or_else(|_| AybError::QueryError {
            message: format!("Invalid response: {response}"),
        })
    })
}
impl Clone for DaemonRegistry {
fn clone(&self) -> Self {
Self {
daemons: self.daemons.clone(),
}
}
}