from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import threading
import json
import sys
from typing import Dict, Optional
import traceback
from dataclasses import dataclass
# Serializes writes to stdout so JSON lines emitted by concurrent request
# threads are never interleaved.
stdout_lock = threading.Lock()


@dataclass(init=True, repr=True, eq=True)
class Args:
    """Generation options, named after the keys read from a request's "args".

    NOTE(review): stream_generation falls back to n_ctx=4096 when the key is
    absent, not the 2048 default declared here — confirm which is intended.
    """

    # Token budget for prompt + completion.
    n_ctx: int = 2048
    # Optional cap on completion tokens only; None means no separate cap.
    n_len: Optional[int] = None
    # Sampling temperature.
    temperature: float = 0.7
    # Top-k sampling cutoff.
    top_k: int = 40
    # Nucleus (top-p) sampling threshold.
    top_p: float = 0.9
    # Passed to transformers as `repetition_penalty`.
    repeat_penalty: float = 1.1
class ModelCache:
    """Process-wide cache of loaded model/tokenizer pairs, keyed by model name.

    Requests are served on separate threads, so loading happens under a lock:
    concurrent first requests for the same model wait for one load instead of
    each loading (and holding in memory) its own copy.
    """

    # model name -> {"model": ..., "tokenizer": ...}
    _cache: Dict[str, Dict] = {}
    # Held while a model is being loaded; cache hits do not take it.
    _load_lock = threading.Lock()

    @classmethod
    def get(cls, model_name: str) -> Optional[Dict]:
        """Return the cached entry for *model_name*, loading it on first use.

        Returns a dict with "model" and "tokenizer" keys, or None when loading
        fails. Failures are logged to stderr and not cached, so a later call
        retries the load.
        """
        entry = cls._cache.get(model_name)
        if entry is not None:
            return entry  # fast path: already loaded
        with cls._load_lock:
            # Re-check: another thread may have finished the load while we waited.
            entry = cls._cache.get(model_name)
            if entry is not None:
                return entry
            try:
                print(f"[加载模型] {model_name}", file=sys.stderr)
                model = AutoModelForCausalLM.from_pretrained(
                    model_name, device_map="auto", dtype="auto"
                )
                tokenizer = AutoTokenizer.from_pretrained(model_name)
            except Exception as e:
                print(f"[错误] 加载模型失败 {model_name}: {str(e)}", file=sys.stderr)
                return None
            entry = {
                "model": model,
                "tokenizer": tokenizer,
            }
            cls._cache[model_name] = entry
            print(f"[成功加载] {model_name}", file=sys.stderr)
            return entry

    @classmethod
    def remove(cls, model_name: str) -> None:
        """Drop *model_name* from the cache; a no-op if it is not cached."""
        # pop() with a default avoids the check-then-delete race between threads.
        cls._cache.pop(model_name, None)
def _emit(payload: dict, *, ensure_ascii: bool = True) -> None:
    """Write *payload* to stdout as a single JSON line and flush immediately.

    Holding ``stdout_lock`` keeps lines from concurrent request threads from
    interleaving; flushing makes each line visible to the reader right away.
    """
    with stdout_lock:
        print(json.dumps(payload, ensure_ascii=ensure_ascii))
        sys.stdout.flush()


def _generation_kwargs(args: dict, prompt_len: int) -> dict:
    """Map request ``args`` onto ``model.generate`` keyword arguments.

    ``n_ctx`` (default 4096) caps prompt + completion tokens. ``n_len``, when
    truthy, additionally caps the completion alone. Both limits are folded
    into one ``max_new_tokens``: transformers lets ``max_new_tokens`` override
    ``max_length`` when both are given, so passing both could let the total
    exceed ``n_ctx``.
    """
    budget = args.get("n_ctx", 4096) - prompt_len
    n_len = args.get("n_len")
    return {
        "max_new_tokens": min(n_len, budget) if n_len else budget,
        "temperature": args.get("temperature", 0.7),
        "top_k": args.get("top_k", 40),
        "top_p": args.get("top_p", 0.9),
        "repetition_penalty": args.get("repeat_penalty", 1.1),
        "do_sample": True,
    }


def stream_generation(req_id: str, model_name: str, prompt: str, args: Optional[dict]):
    """Generate a completion for *prompt* and stream it to stdout as JSON lines.

    Emits ``{"req_id", "token"}`` for each non-empty decoded chunk, then
    exactly one terminal line: ``{"req_id", "done": true}`` on success, or
    ``{"req_id", "error", ...}`` when the model cannot be loaded, the prompt
    leaves no token budget, or generation raises. Never raises itself.

    Args:
        req_id: Client-supplied id echoed on every output line.
        model_name: Name passed to ``ModelCache.get``.
        prompt: Raw text fed to the tokenizer (no chat template is applied).
        args: Generation options (see ``_generation_kwargs``); None means
            all defaults.
    """
    try:
        # The request loop may forward None when the client omitted "args".
        args = args or {}
        model_entry = ModelCache.get(model_name)
        if not model_entry:
            _emit({"req_id": req_id, "error": f"Model '{model_name}' not found"})
            return
        model = model_entry["model"]
        tokenizer = model_entry["tokenizer"]

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[-1]
        gen_kwargs = _generation_kwargs(args, prompt_len)
        if gen_kwargs["max_new_tokens"] <= 0:
            _emit(
                {
                    "req_id": req_id,
                    "error": (
                        f"No room to generate: prompt is {prompt_len} tokens "
                        "(check n_ctx / n_len)"
                    ),
                }
            )
            return

        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
        # Filled by the worker thread so its failure can be reported to the client.
        gen_failures: list = []

        def generate():
            try:
                model.generate(**inputs, **gen_kwargs, streamer=streamer)
            except Exception as e:
                gen_failures.append((str(e), traceback.format_exc()))
                print(f"[生成错误] {req_id}: {e}", file=sys.stderr)
            finally:
                # Always signal end-of-stream so the consumer loop below exits,
                # even when generate() raised before producing anything.
                streamer.end()

        thread = threading.Thread(target=generate)
        thread.start()
        for new_text in streamer:
            if new_text:
                _emit({"req_id": req_id, "token": new_text}, ensure_ascii=False)
        thread.join(timeout=2)
        if thread.is_alive():
            print(f"[警告] 生成线程未正常退出: {req_id}", file=sys.stderr)

        if gen_failures:
            message, error_tb = gen_failures[0]
            _emit({"req_id": req_id, "error": message, "traceback": error_tb})
        else:
            _emit({"req_id": req_id, "done": True})
    except Exception as e:
        _emit({"req_id": req_id, "error": str(e), "traceback": traceback.format_exc()})
def _write_response(payload: dict) -> None:
    """Print *payload* as one JSON line on stdout under ``stdout_lock`` and flush."""
    with stdout_lock:
        print(json.dumps(payload))
        sys.stdout.flush()


print("[系统] LLM Daemon 启动,等待请求...", file=sys.stderr)
sys.stderr.flush()
try:
    # One JSON request per stdin line. Each accepted request is served on its
    # own thread so a long generation does not block reading the next line.
    for line in sys.stdin:
        line = line.strip()
        if not line:
            continue
        try:
            request = json.loads(line)
        except json.JSONDecodeError as e:
            _write_response({"req_id": "system", "error": f"Invalid JSON: {str(e)}"})
            continue
        # Valid JSON that is not an object (list, number, string) has no .get();
        # reject it here instead of letting AttributeError escape and stop the daemon.
        if not isinstance(request, dict):
            _write_response(
                {"req_id": "system", "error": "Request must be a JSON object"}
            )
            continue
        req_id = request.get("req_id", "unknown")
        model_name = request.get("model", "Qwen/Qwen3-0.6B")
        prompt = request.get("prompt", "")
        # Normalize a missing/null "args" to {}; stream_generation calls .get() on it.
        args = request.get("args") or {}
        if not prompt:
            _write_response(
                {
                    "req_id": req_id,
                    "error": "Empty prompt",
                    "prompt": prompt,
                }
            )
            continue
        threading.Thread(
            target=stream_generation,
            kwargs={
                "req_id": req_id,
                "model_name": model_name,
                "prompt": prompt,
                "args": args,
            },
        ).start()
except (KeyboardInterrupt, EOFError):
    print("[系统] 收到退出信号,关闭 daemon...", file=sys.stderr)
    sys.stderr.flush()