use super::EngineBackend;
use crate::{discover::Model, engine::EngineConfig};
use anyhow::Result;
use lazy_static::lazy_static;
use serde_json::Value;
use serde_json::json;
use std::collections::HashMap;
use std::io::BufRead;
use std::io::{BufReader, Write};
use std::process::{ChildStdin, Command, Stdio};
use std::sync::{Arc, Mutex};
use std::thread;
use tempfile::NamedTempFile;
use uuid::Uuid;
/// Boxed, sendable callback that is invoked with each JSON message the Python
/// process emits for a given request id.
type ResponseCallback = Box<dyn FnMut(Value) + Send>;
lazy_static! {
    /// Process-wide Python backend, started lazily on first access.
    ///
    /// # Panics
    ///
    /// The first access panics if [`PythonBackend::new`] fails. The failure is
    /// printed to stderr and is also carried in the panic message so that panic
    /// hooks and logs see the actual cause.
    pub static ref PYTHON_BACKEND: Mutex<PythonBackend> = {
        match PythonBackend::new() {
            Ok(backend) => Mutex::new(backend),
            Err(e) => {
                eprintln!("[FATAL] Can't start Python backend:");
                eprintln!("错误: {}", e);
                // `{:#}` prints the whole anyhow context chain on one line.
                panic!("Can't start Python backend: {:#}", e);
            }
        }
    };
}
/// Handle to the Python child process started by `PythonBackend::new`.
pub struct PythonBackend {
/// The child's stdin behind a mutex; requests are written to it as one JSON
/// document per line.
stdin: Arc<Mutex<ChildStdin>>,
/// Callbacks registered per request, keyed by the request id (`req_id`).
response_senders: Arc<Mutex<HashMap<String, ResponseCallback>>>,
}
impl PythonBackend {
pub fn new() -> Result<Self> {
let mut tmpfile = NamedTempFile::new()?;
write!(tmpfile, "{}", include_str!("../assets/hf_daemon.py"))?;
let (_file, path) = tmpfile.keep()?;
let mut child = Command::new("python")
.arg(&path)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped()) .spawn()
.map_err(|e| {
anyhow::anyhow!("Failed to start Python process: {}. Make sure Python is installed and in PATH.", e)
})?;
let stdin = Arc::new(Mutex::new(child.stdin.take().unwrap()));
let stdout = child.stdout.take().unwrap();
let stderr = child.stderr.take().unwrap();
thread::spawn(move || {
let reader = BufReader::new(stderr);
for line in reader.lines() {
match line {
Ok(line) => eprintln!("[Python STDERR] {}", line),
Err(e) => eprintln!("[PythonBackend] Can't read stderr: {}", e),
}
}
});
let response_senders: Arc<Mutex<HashMap<String, Box<dyn FnMut(Value) + Send + 'static>>>> =
Arc::new(Mutex::new(HashMap::new()));
let response_senders_clone = Arc::clone(&response_senders);
thread::spawn(move || {
let reader = BufReader::new(stdout);
for line in reader.lines() {
let line = match line {
Ok(l) => l,
Err(e) => {
eprintln!("[PythonBackend] 读取 stdout 失败: {}", e);
continue;
}
};
match serde_json::from_str::<Value>(&line) {
Ok(json) => {
if let Some(id) = json["req_id"].as_str() {
let mut senders = match response_senders_clone.lock() {
Ok(guard) => guard,
Err(_) => {
eprintln!("[PythonBackend] 回调锁被污染");
return;
}
};
if json.get("done").is_some() {
senders.remove(id);
continue;
}
if let Some(sender) = senders.get_mut(id) {
sender(json.clone());
}
}
}
Err(e) => {
eprintln!("[PythonBackend] JSON Parse Fault: {}: {}", e, line);
}
}
}
});
let response_senders_for_wait = Arc::clone(&response_senders);
thread::spawn(move || {
let status = match child.wait() {
Ok(s) => s,
Err(e) => {
eprintln!("[PythonBackend] 等待子进程失败: {}", e);
return;
}
};
if !status.success() {
eprintln!("[PythonBackend] Python 进程异常退出,状态: {}", status);
std::process::exit(1);
} else {
eprintln!("[PythonBackend] Python 进程正常退出");
}
let mut senders = match response_senders_for_wait.lock() {
Ok(guard) => guard,
Err(_) => return,
};
senders.clear();
});
Ok(PythonBackend {
stdin,
response_senders,
})
}
pub fn infer_with_callback<F>(
&self,
model_name: &str,
prompt: &str,
args: &EngineConfig,
callback: F,
) -> Result<String>
where
F: FnMut(Value) + Send + 'static,
{
let req_id = Uuid::new_v4().to_string();
{
let mut senders = self
.response_senders
.lock()
.map_err(|e| anyhow::anyhow!("锁冲突: {:?}", e))?;
senders.insert(req_id.clone(), Box::new(callback));
}
let request = json!({
"req_id": req_id,
"model": model_name,
"prompt": prompt,
"args": args,
});
{
let mut stdin = self
.stdin
.lock()
.map_err(|e| anyhow::anyhow!("stdin 锁失败: {:?}", e))?;
writeln!(stdin, "{}", serde_json::to_string(&request)?)?;
stdin.flush()?; }
Ok(req_id)
}
}
/// Engine backend that submits inference requests to the shared
/// `PYTHON_BACKEND` process.
pub struct TransformersEngine {
/// Model description; its `path` is sent as the model name of each request.
model_info: Model,
/// Engine configuration used when `infer` is called without an override.
args: EngineConfig,
}
impl EngineBackend for TransformersEngine {
fn new(args: &EngineConfig, model_info: &Model) -> Result<Self> {
Ok(Self {
model_info: model_info.clone(),
args: args.clone(),
})
}
fn infer(
&self,
prompt: &str,
option: Option<&EngineConfig>,
callback: Option<Box<dyn FnMut(String) + Send>>,
) -> Result<String> {
let args = option.unwrap_or(&self.args);
let model_path = self
.model_info
.path
.to_str()
.ok_or_else(|| anyhow::anyhow!("模型路径包含非 UTF-8 字符"))?;
let backend = PYTHON_BACKEND
.lock()
.map_err(|e| anyhow::anyhow!("PythonBackend 锁被污染: {:?}", e))?;
let shared_callback: Arc<Mutex<Option<Box<dyn FnMut(String) + Send>>>> =
Arc::new(Mutex::new(callback));
let closure_callback = {
let shared_callback = Arc::clone(&shared_callback);
move |json: Value| {
let token = json["token"].as_str().unwrap_or_default();
let mut guard = shared_callback.lock().unwrap();
if let Some(ref mut cb) = *guard {
cb(token.to_string());
}
}
};
let req_id = backend.infer_with_callback(model_path, prompt, args, closure_callback)?;
Ok(req_id)
}
fn get_model_info(&self) -> Model {
self.model_info.clone()
}
}