from __future__ import annotations
import hashlib
import math
import os
import re
from typing import Protocol
# Default embedding dimensionality: the default `dim` of HashEmbedder,
# OpenAIEmbedder and build_embedder.
DEFAULT_DIM = 384
# Matches maximal runs of ASCII letters and digits. HashEmbedder.embed applies
# it to lower-cased text to split the input into tokens.
_TOKEN_RE = re.compile(r"[A-Za-z0-9]+")
class Embedder(Protocol):
    """Structural interface shared by the embedding backends in this module.

    Any object exposing an integer ``dim`` attribute and an ``embed`` method
    with the signature below satisfies this protocol; no inheritance needed.
    """

    # Dimensionality the embedder is configured to produce.
    dim: int

    def embed(self, text: str) -> list[float]:
        """Return the embedding of *text* as a list of floats."""
        ...
class HashEmbedder:
    """Dependency-free embedder that hashes tokens into a fixed-size vector.

    Each token's MD5 digest selects one bucket of the vector and a sign
    (+1 or -1); the signed counts are then scaled to unit Euclidean length.
    The same text always produces the same vector.
    """

    def __init__(self, dim: int = DEFAULT_DIM) -> None:
        """Create an embedder producing vectors of length *dim*.

        Raises:
            ValueError: if *dim* is smaller than 1. Without this check a zero
                or negative *dim* would only fail later, inside ``embed``,
                with a ``ZeroDivisionError`` or ``IndexError``.
        """
        if dim < 1:
            raise ValueError(f"dim must be a positive integer, got {dim!r}")
        self.dim = dim

    def embed(self, text: str) -> list[float]:
        """Return the hashed embedding of *text*.

        The text is lower-cased and split into runs of ASCII letters/digits.
        Text with no such tokens yields an all-zero vector.
        """
        vec = [0.0] * self.dim
        for tok in _TOKEN_RE.findall(text.lower()):
            digest = hashlib.md5(tok.encode("utf-8")).digest()
            # First four digest bytes pick the bucket; the low bit of the
            # fifth byte picks the sign.
            bucket = int.from_bytes(digest[:4], "little") % self.dim
            vec[bucket] += -1.0 if digest[4] & 1 else 1.0
        norm = math.sqrt(sum(v * v for v in vec))
        # The norm is zero when there were no tokens, or when opposite-signed
        # tokens cancelled in every bucket; return the zero vector as is.
        if norm > 0:
            return [v / norm for v in vec]
        return vec
class OpenAIEmbedder:
    """Embedder that requests vectors from the OpenAI embeddings endpoint.

    The ``openai`` package is imported inside ``__init__`` so that this module
    can be imported without it installed.
    """

    def __init__(
        self,
        *,
        dim: int = DEFAULT_DIM,
        api_key: str | None = None,
        model: str = "text-embedding-3-small",
    ) -> None:
        """Create the API client.

        Args:
            dim: value sent as ``dimensions`` with every embedding request.
            api_key: API key; when ``None`` or empty, the ``OPENAI_API_KEY``
                environment variable is used instead.
            model: embedding model name sent with every request. The request
                always includes ``dimensions``, which OpenAI documents as
                supported by text-embedding-3 and later models only, so pick
                a model that accepts it.

        Raises:
            RuntimeError: if the ``openai`` package is not installed.
        """
        try:
            from openai import OpenAI
        except ImportError as e:
            raise RuntimeError(
                "install the 'openai' extra to use OpenAIEmbedder: "
                "`pip install 'sqlrite-agent[openai]'`"
            ) from e
        self.dim = dim
        self.model = model
        # Reference to the client class, kept on the instance; nothing else
        # in this class reads it.
        self._OpenAI = OpenAI
        self._client = OpenAI(api_key=api_key or os.environ.get("OPENAI_API_KEY"))

    def embed(self, text: str) -> list[float]:
        """Return the embedding of *text* from the first item of the response."""
        resp = self._client.embeddings.create(
            model=self.model,
            input=text,
            dimensions=self.dim,
        )
        return list(resp.data[0].embedding)
class LocalEmbedder:
    """Embedder backed by a locally loaded ``sentence_transformers`` model.

    The ``sentence_transformers`` package is imported inside ``__init__`` so
    that this module can be imported without it installed.
    """

    def __init__(self, *, model_name: str = "sentence-transformers/all-MiniLM-L6-v2") -> None:
        """Load *model_name* and record the dimension the model reports.

        Raises:
            RuntimeError: if the ``sentence_transformers`` package is not
                installed.
        """
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError as e:
            raise RuntimeError(
                "install the 'local-embeddings' extra to use LocalEmbedder: "
                "`pip install 'sqlrite-agent[local-embeddings]'`"
            ) from e
        model = SentenceTransformer(model_name)
        self._model = model
        # The vector size comes from the model, not from a caller argument.
        self.dim = model.get_sentence_embedding_dimension()

    def embed(self, text: str) -> list[float]:
        """Encode *text* and return the result as plain Python floats."""
        # normalize_embeddings=True asks the model for normalized output.
        encoded = self._model.encode(text, normalize_embeddings=True)
        return list(map(float, encoded))
def build_embedder(name: str, *, dim: int = DEFAULT_DIM) -> Embedder:
    """Construct the embedder selected by *name* (matched case-insensitively).

    Args:
        name: one of ``"hash"``, ``"openai"`` or ``"local"``.
        dim: forwarded to the hash and OpenAI embedders. The local embedder
            is built with its defaults and does not receive *dim*.

    Raises:
        ValueError: if *name* is not one of the known embedder names.
    """
    name = name.lower()
    # Factories are only called for the selected name, so an unused backend
    # is never constructed.
    factories = {
        "hash": lambda: HashEmbedder(dim=dim),
        "openai": lambda: OpenAIEmbedder(dim=dim),
        "local": lambda: LocalEmbedder(),
    }
    factory = factories.get(name)
    if factory is None:
        raise ValueError(f"unknown embedder: {name!r} (expected 'hash', 'openai', or 'local')")
    return factory()