"""tllama 0.1.1

Lightweight Local LLM Inference Engine

Documentation
"""
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import threading
import json
import sys
from typing import Dict, Optional
import traceback
from dataclasses import dataclass

# ==============================
# Model cache class
# ==============================

# Held around every print-and-flush to stdout so that JSON lines written by
# concurrent request threads are not interleaved.
stdout_lock = threading.Lock()


@dataclass
class Args:
    """Generation settings and their default values.

    The field names match the keys that ``stream_generation`` reads from a
    request's ``args`` dict.

    NOTE(review): nothing in the visible code instantiates this class, and
    ``stream_generation`` falls back to 4096 for ``n_ctx`` rather than the
    2048 declared here -- confirm which default is intended.
    """

    n_ctx: int = 2048  # overall length limit
    n_len: Optional[int] = None  # limit on newly generated tokens; None = no explicit limit
    temperature: float = 0.7
    top_k: int = 40
    top_p: float = 0.9
    repeat_penalty: float = 1.1  # forwarded to generate() as ``repetition_penalty``


class ModelCache:
    """Process-wide cache of loaded models and tokenizers, keyed by model name.

    Each entry is a dict of the form ``{"model": ..., "tokenizer": ...}``.
    The cache is shared by every request thread, so loading is serialized
    with a lock: without it, two simultaneous first requests for the same
    model would each load their own copy.
    """

    _cache: Dict[str, Dict] = {}
    # Serializes model loading. Lookups of already-cached models do not take it.
    _lock = threading.Lock()

    @classmethod
    def get(cls, model_name: str) -> Optional[Dict]:
        """Return the cache entry for ``model_name``, loading it on first use.

        Args:
            model_name: Identifier passed to ``from_pretrained``.

        Returns:
            The ``{"model", "tokenizer"}`` entry, or ``None`` if loading
            failed (the failure is logged to stderr and nothing is cached).
        """
        # Fast path: hold a local reference so a concurrent ``remove`` cannot
        # turn a successful membership check into a KeyError.
        entry = cls._cache.get(model_name)
        if entry is not None:
            return entry

        with cls._lock:
            # Re-check: another thread may have finished loading this model
            # while we were waiting for the lock.
            entry = cls._cache.get(model_name)
            if entry is not None:
                return entry

            try:
                print(f"[加载模型] {model_name}", file=sys.stderr)
                model = AutoModelForCausalLM.from_pretrained(
                    model_name, device_map="auto", dtype="auto"
                )
                tokenizer = AutoTokenizer.from_pretrained(model_name)
            except Exception as e:
                print(f"[错误] 加载模型失败 {model_name}: {str(e)}", file=sys.stderr)
                return None

            entry = {"model": model, "tokenizer": tokenizer}
            cls._cache[model_name] = entry
            print(f"[成功加载] {model_name}", file=sys.stderr)
            return entry

    @classmethod
    def remove(cls, model_name: str):
        """Drop ``model_name`` from the cache; a no-op if it is not cached."""
        cls._cache.pop(model_name, None)


# ==============================
# Streaming generation function (one thread per request)
# ==============================


def _emit_json(payload: dict, ensure_ascii: bool = True) -> None:
    """Print ``payload`` as one JSON line on stdout, holding the stdout lock."""
    with stdout_lock:
        print(json.dumps(payload, ensure_ascii=ensure_ascii))
        sys.stdout.flush()


def _build_generation_kwargs(args: dict, streamer) -> dict:
    """Map a request's ``args`` onto keyword arguments for ``model.generate``."""
    n_ctx = args.get("n_ctx")
    if n_ctx is None:
        n_ctx = 4096
    # A missing, null or zero ``n_len`` means "no explicit new-token limit".
    n_len = args.get("n_len") or None

    kwargs = {
        "temperature": args.get("temperature", 0.7),
        "top_k": args.get("top_k", 40),
        "top_p": args.get("top_p", 0.9),
        "repetition_penalty": args.get("repeat_penalty", 1.1),
        "do_sample": True,
        "streamer": streamer,
    }

    # Pass exactly one length limit. When both are supplied, generate() warns
    # and lets ``max_new_tokens`` take precedence, so sending only that one
    # keeps the effective limit while avoiding the warning. An ``n_len``
    # larger than ``n_ctx`` is ignored in favour of the ``n_ctx`` cap.
    if n_len is not None and n_len <= n_ctx:
        kwargs["max_new_tokens"] = n_len
    else:
        kwargs["max_length"] = n_ctx
    return kwargs


def stream_generation(
    req_id: str, model_name: str, prompt: str, args: Optional[dict]
):
    """Generate a completion for ``prompt`` and stream it to stdout as JSON lines.

    Output protocol (one JSON object per line, every object carries ``req_id``):
      * ``{"token": <text>}`` for each non-empty chunk produced by the streamer;
      * ``{"done": true}`` once generation finished successfully;
      * ``{"error": <message>}`` (optionally with ``"traceback"``) on failure,
        in which case no ``done`` message follows.

    Args:
        req_id: Caller-chosen identifier echoed back in every message.
        model_name: Model to fetch from ``ModelCache``.
        prompt: Text to tokenize and feed to the model.
        args: Optional generation settings (``n_ctx``, ``n_len``,
            ``temperature``, ``top_k``, ``top_p``, ``repeat_penalty``).
            ``None`` is treated as an empty dict so requests without an
            ``args`` field use the defaults.
    """
    try:
        args = args or {}

        model_entry = ModelCache.get(model_name)
        if not model_entry:
            _emit_json({"req_id": req_id, "error": f"Model '{model_name}' not found"})
            return

        model = model_entry["model"]
        tokenizer = model_entry["tokenizer"]

        # Encode the prompt and prepare the streamer that yields decoded text.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
        generation_kwargs = _build_generation_kwargs(args, streamer)

        # Filled by the worker thread so a failure can be reported to the
        # client instead of being followed by a misleading "done" message.
        generation_errors = []

        def generate():
            try:
                model.generate(**inputs, **generation_kwargs)
            except Exception as e:
                print(f"[生成错误] {req_id}: {e}", file=sys.stderr)
                generation_errors.append(e)
            finally:
                # Always signal end-of-stream; otherwise the consumer loop
                # below would block forever after a failed generate().
                streamer.end()

        thread = threading.Thread(target=generate)
        thread.start()

        # Forward text to the client as soon as the streamer yields it.
        for new_text in streamer:
            if new_text:
                _emit_json({"req_id": req_id, "token": new_text}, ensure_ascii=False)

        thread.join(timeout=2)
        if thread.is_alive():
            print(f"[警告] 生成线程未正常退出: {req_id}", file=sys.stderr)

        if generation_errors:
            _emit_json({"req_id": req_id, "error": str(generation_errors[0])})
            return

        # Completion signal.
        _emit_json({"req_id": req_id, "done": True})

    except Exception as e:
        _emit_json(
            {"req_id": req_id, "error": str(e), "traceback": traceback.format_exc()}
        )


# ==============================
# Main loop: read JSON requests from stdin
# ==============================

print("[系统] LLM Daemon 启动,等待请求...", file=sys.stderr)
sys.stderr.flush()

try:
    # One request per line: a JSON object with optional "req_id", "model",
    # "prompt" and "args" fields.
    for line in sys.stdin:
        line = line.strip()
        if not line:
            continue

        try:
            request = json.loads(line)
        except json.JSONDecodeError as e:
            with stdout_lock:
                print(
                    json.dumps({"req_id": "system", "error": f"Invalid JSON: {str(e)}"})
                )
                sys.stdout.flush()
            continue

        # Valid JSON that is not an object (e.g. a list or a bare string) has
        # no .get(); reject it here instead of letting the daemon crash.
        if not isinstance(request, dict):
            with stdout_lock:
                print(
                    json.dumps(
                        {"req_id": "system", "error": "Request must be a JSON object"}
                    )
                )
                sys.stdout.flush()
            continue

        req_id = request.get("req_id", "unknown")
        model_name = request.get("model", "Qwen/Qwen3-0.6B")
        prompt = request.get("prompt", "")
        # A missing or null "args" becomes an empty dict so the worker can
        # fall back to its defaults.
        args = request.get("args") or {}

        if not prompt:
            with stdout_lock:
                print(
                    json.dumps(
                        {
                            "req_id": req_id,
                            "error": "Empty prompt",
                            "prompt": prompt,
                        }
                    )
                )
                sys.stdout.flush()
            continue

        # Handle each request on its own thread so requests run concurrently.
        t = threading.Thread(
            target=stream_generation,
            kwargs={
                "req_id": req_id,
                "model_name": model_name,
                "prompt": prompt,
                "args": args,
            },
        )
        t.start()

except (KeyboardInterrupt, EOFError):
    print("[系统] 收到退出信号,关闭 daemon...", file=sys.stderr)
    sys.stderr.flush()