from __future__ import annotations
import asyncio
import sys
import time
from pathlib import Path
from typing import AsyncIterator, Callable
import grpc
from grpc_tools import protoc
import structlog
from claw_vector_svc.config import Settings
from claw_vector_svc.embedder import EmbedderService
from claw_vector_svc.metrics import record_embed_request
# Module-level structured logger, named after this module.
log = structlog.get_logger(__name__)
# "proto" directory located one level above the directory containing this file.
PROTO_ROOT = Path(__file__).resolve().parents[1] / "proto"
# Service definition that the generated Python stubs are built from.
PROTO_FILE = PROTO_ROOT / "vector.proto"
def _ensure_proto_generated() -> None:
    """Regenerate the Python gRPC stubs for ``vector.proto`` unless current.

    The stubs count as current when both generated modules exist and neither
    is older than the ``.proto`` file. Otherwise ``grpc_tools.protoc`` is run
    in-process, writing its output into ``PROTO_ROOT``.

    Raises:
        RuntimeError: If protoc returns a non-zero status.
    """
    stubs = (PROTO_ROOT / "vector_pb2.py", PROTO_ROOT / "vector_pb2_grpc.py")
    if all(stub.exists() for stub in stubs):
        proto_mtime = PROTO_FILE.stat().st_mtime
        if all(stub.stat().st_mtime >= proto_mtime for stub in stubs):
            return

    protoc_args = [
        "grpc_tools.protoc",
        f"-I{PROTO_ROOT}",
        f"--python_out={PROTO_ROOT}",
        f"--grpc_python_out={PROTO_ROOT}",
        str(PROTO_FILE),
    ]
    if protoc.main(protoc_args) != 0:
        raise RuntimeError("failed to generate Python gRPC stubs from vector.proto")
def _import_proto():
    """Import the generated ``vector_pb2`` and ``vector_pb2_grpc`` modules.

    Ensures the stubs are up to date first, then puts ``PROTO_ROOT`` on
    ``sys.path`` because the stubs are imported by their bare module names.

    Returns:
        A ``(vector_pb2, vector_pb2_grpc)`` tuple of imported modules.

    Raises:
        RuntimeError: If stub generation fails.
    """
    import importlib

    _ensure_proto_generated()
    proto_path = str(PROTO_ROOT)
    if proto_path not in sys.path:
        sys.path.insert(0, proto_path)
    # The stubs may have been written moments ago; drop the import system's
    # cached directory listings so freshly created files are found.
    importlib.invalidate_caches()

    import vector_pb2 as pb2
    import vector_pb2_grpc as pb2_grpc

    return pb2, pb2_grpc
# Resolve the stubs at import time: the servicer class below subclasses a
# base class taken from the generated vector_pb2_grpc module.
vector_pb2, vector_pb2_grpc = _import_proto()
class EmbeddingServicer(vector_pb2_grpc.EmbeddingServiceServicer):
    """Async gRPC servicer that exposes the embedder.

    Every RPC checks the caller's API key first. All RPCs except ``Health``
    additionally abort with ``UNAVAILABLE`` until the embedder reports ready
    and the warm-up callback returns true.
    """

    def __init__(
        self,
        embedder: EmbedderService,
        allowed_keys: set[str],
        is_warmup_complete: Callable[[], bool],
    ) -> None:
        """Store the collaborators used by the RPC handlers.

        Args:
            embedder: Service that computes the embeddings.
            allowed_keys: Accepted API keys; an empty set disables the
                authentication check.
            is_warmup_complete: Callback returning whether warm-up is done.
        """
        self._embedder = embedder
        self._allowed_keys = allowed_keys
        self._is_warmup_complete = is_warmup_complete

    async def _authorize(self, context) -> None:
        """Abort with ``UNAUTHENTICATED`` unless the call carries a known key.

        The key is read from the ``authorization`` metadata entry, with an
        optional ``Bearer `` prefix removed and surrounding whitespace
        stripped.
        """
        if not self._allowed_keys:
            return
        metadata = dict(context.invocation_metadata())
        raw = metadata.get("authorization", "")
        api_key = raw.removeprefix("Bearer ").strip()
        # NOTE(review): set membership is not a constant-time comparison;
        # consider hmac.compare_digest if key-timing leaks are a concern.
        if api_key not in self._allowed_keys:
            await context.abort(grpc.StatusCode.UNAUTHENTICATED, "invalid API key")

    async def _ensure_ready(self, context) -> None:
        """Abort with ``UNAVAILABLE`` while the model is not ready or warm."""
        if not self._embedder.is_ready or not self._is_warmup_complete():
            await context.abort(grpc.StatusCode.UNAVAILABLE, "embedding model is not ready")

    async def _embed_request(self, request):
        """Embed one request's texts and build its ``EmbedResponse``.

        Exactly one metric is recorded per request: ``"ok"`` only once the
        response has been built, ``"error"`` if anything fails before that.
        Failures are re-raised so the calling RPC can log and abort.
        """
        started = time.monotonic()
        texts = list(request.texts)
        try:
            # The embed call is blocking, so run it off the event loop.
            vectors = await asyncio.to_thread(
                self._embedder.embed,
                texts,
                request.normalize,
            )
            elapsed = time.monotonic() - started
            response = vector_pb2.EmbedResponse(
                vectors=[
                    vector_pb2.EmbedVector(values=vector.tolist(), dimensions=len(vector))
                    for vector in vectors
                ],
                model_name=self._embedder.model_name,
                latency_ms=int(elapsed * 1000),
            )
        except Exception:
            record_embed_request("grpc", "error", len(texts), time.monotonic() - started)
            raise
        record_embed_request("grpc", "ok", len(texts), elapsed)
        return response

    async def Embed(self, request, context):
        """Unary RPC: embed ``request.texts`` and return the vectors.

        Any failure while embedding or building the response is logged and
        reported to the client as ``INTERNAL``.
        """
        await self._authorize(context)
        await self._ensure_ready(context)
        try:
            return await self._embed_request(request)
        except Exception as exc:
            log.error("embed failed", error=str(exc))
            await context.abort(grpc.StatusCode.INTERNAL, str(exc))

    async def Health(self, request, context):
        """Report readiness, model name and model load time.

        Requires authorization but deliberately skips the readiness gate so
        that a not-yet-ready state can be reported instead of rejected.
        """
        await self._authorize(context)
        return vector_pb2.HealthResponse(
            ready=self._embedder.is_ready and self._is_warmup_complete(),
            model_name=self._embedder.model_name,
            model_load_time_ms=self._embedder.load_time_ms,
        )

    async def ModelInfo(self, request, context):
        """Return model name, dimensions, max sequence length and device."""
        await self._authorize(context)
        await self._ensure_ready(context)
        # NOTE(review): reads the embedder's private ``_settings``; prefer a
        # public accessor if EmbedderService offers one.
        return vector_pb2.ModelInfoResponse(
            model_name=self._embedder.model_name,
            dimensions=self._embedder.dimensions,
            max_sequence_length=self._embedder._settings.max_sequence_length,
            device=self._embedder._settings.device,
        )

    async def EmbedStream(self, request_iterator, context) -> AsyncIterator[object]:
        """Streaming RPC: yield one ``EmbedResponse`` per incoming request.

        Authorization and readiness are checked once, before the first
        request is consumed. A failing request is handled like in ``Embed``:
        it is logged and the stream is aborted with ``INTERNAL``.
        """
        await self._authorize(context)
        await self._ensure_ready(context)
        async for req in request_iterator:
            try:
                response = await self._embed_request(req)
            except Exception as exc:
                log.error("embed stream failed", error=str(exc))
                await context.abort(grpc.StatusCode.INTERNAL, str(exc))
                # abort() is expected to raise; this return only guards
                # against using ``response`` unbound if it ever does not.
                return
            # Yield outside the try so exceptions raised at the yield point
            # are not mistaken for embedding failures.
            yield response
async def start_grpc_server(
    embedder: EmbedderService,
    settings: Settings,
    allowed_keys: set[str],
    is_warmup_complete: Callable[[], bool],
) -> grpc.aio.Server:
    """Create, bind and start the asyncio gRPC server for the embedder.

    Args:
        embedder: Service that computes the embeddings.
        settings: Supplies ``grpc_host`` and ``grpc_port`` for the listener.
        allowed_keys: Accepted API keys, passed through to the servicer.
        is_warmup_complete: Callback returning whether warm-up is done.

    Returns:
        The started server; the caller is responsible for stopping it.

    Raises:
        RuntimeError: If the listen address could not be bound.
    """
    server = grpc.aio.server()
    vector_pb2_grpc.add_EmbeddingServiceServicer_to_server(
        EmbeddingServicer(embedder, allowed_keys, is_warmup_complete),
        server,
    )
    address = f"{settings.grpc_host}:{settings.grpc_port}"
    # NOTE(review): the port is insecure (no TLS), as in the original code.
    bound_port = server.add_insecure_port(address)
    # A result of 0 is treated as a bind failure. Whether grpcio returns 0 or
    # raises on failure appears to vary by version -- confirm against the
    # pinned release. Failing here avoids a "started" server that listens
    # on nothing.
    if bound_port == 0:
        raise RuntimeError(f"failed to bind gRPC server to {address}")
    await server.start()
    # Log the bound port too: it differs from the configured one when the
    # configured port is 0 (OS-assigned).
    log.info("gRPC server started", address=address, port=bound_port)
    return server