from __future__ import annotations
import time
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, Optional
import sqlrite
from sqlrite_agent.sqlutil import q
# Default embedding dimensionality: the default ``dim`` argument of AgentDB,
# which sizes the VECTOR(...) columns created by the v1 schema.
DEFAULT_DIM = 384
# Version number written to the schema_version table once the v1 schema
# has been applied.
SCHEMA_VERSION = 1
@dataclass(frozen=True)
class Message:
    """Immutable record of one row read from the ``messages`` table.

    The field order mirrors the column order used by the SELECT statements
    in this module, so a result row can be unpacked positionally with
    ``Message(*row)``. The stored embedding is not part of this record.
    """

    # INTEGER PRIMARY KEY of the row.
    id: int
    # Identifier of the conversation the message belongs to.
    conversation_id: str
    # Role label stored with the message (free-form text in the schema).
    role: str
    # Message text.
    content: str
    # Integer timestamp; Unix epoch seconds when defaulted at insert time.
    ts: int
@dataclass(frozen=True)
class Summary:
    """Immutable record of one row read from the ``summaries`` table.

    The field order mirrors the column order of the summaries SELECT in this
    module, so a result row can be unpacked positionally with
    ``Summary(*row)``. The stored embedding is not part of this record.
    """

    # INTEGER PRIMARY KEY of the row.
    id: int
    # Identifier of the conversation the summary belongs to.
    conversation_id: str
    # Integer timestamps stored as the start and end of the summary.
    start_ts: int
    end_ts: int
    # Summary text.
    content: str
@dataclass(frozen=True)
class Fact:
    """Immutable record of one row read from the ``facts`` table.

    A fact is a (subject, predicate, object) triple. The field order mirrors
    the column order of the facts SELECTs in this module, so a result row can
    be unpacked positionally with ``Fact(*row)``.
    """

    # INTEGER PRIMARY KEY of the row.
    id: int
    # The three text parts of the triple.
    subject: str
    predicate: str
    object: str
    # Value of the nullable source_message_id column; None when not recorded.
    source_message_id: Optional[int]
    # Integer timestamp; Unix epoch seconds when defaulted at insert time.
    ts: int
class AgentDB:
    """Store for an agent's messages, summaries and facts on one connection.

    Constructing an instance opens ``path`` with ``sqlrite.connect`` and runs
    the schema migration, so the tables and indexes exist before any other
    method is used.

    All SQL is assembled as text. Values are rendered through ``q`` from
    ``sqlrite_agent.sqlutil`` (presumably a literal-quoting helper -- verify
    against that module) and row limits are coerced with ``int``. Table names
    passed to ``count`` are interpolated verbatim and must come from trusted
    code.
    """

    # Shared SELECT prefixes. The column order matches the field order of the
    # corresponding dataclass so rows can be unpacked positionally.
    _MESSAGE_SELECT = "SELECT id, conversation_id, role, content, ts FROM messages "
    _FACT_SELECT = (
        "SELECT id, subject, predicate, object, source_message_id, ts FROM facts "
    )

    def __init__(self, path: str, *, dim: int = DEFAULT_DIM) -> None:
        """Open the database at ``path`` and make sure the schema exists.

        Args:
            path: Location handed to ``sqlrite.connect``.
            dim: Dimensionality used for the ``VECTOR`` embedding columns
                when the schema is first created.
        """
        self.path = path
        self.dim = dim
        self._conn = sqlrite.connect(path)
        self._migrate()

    # ------------------------------------------------------------------
    # Schema management
    # ------------------------------------------------------------------

    def _migrate(self) -> None:
        """Create the schema when the recorded version is below 1.

        Any ``sqlrite.SQLRiteError`` raised while reading ``schema_version``
        is taken to mean the table does not exist yet; it is then created and
        the version is treated as 0.
        """
        cur = self._conn.cursor()
        try:
            cur.execute("SELECT version FROM schema_version")
            row = cur.fetchone()
            current = int(row[0]) if row else 0
        except sqlrite.SQLRiteError:
            cur.execute("CREATE TABLE schema_version (version INTEGER PRIMARY KEY)")
            current = 0
        if current < 1:
            self._apply_v1()
            # Record the version only after a migration actually ran.
            # ``version`` is the primary key, so inserting it again on every
            # open of an existing database would collide.
            cur.execute(
                f"INSERT INTO schema_version (version) VALUES ({SCHEMA_VERSION})"
            )

    def _apply_v1(self) -> None:
        """Create the v1 tables plus their HNSW and FTS indexes."""
        cur = self._conn.cursor()
        dim = self.dim
        cur.execute(
            f"""
            CREATE TABLE messages (
                id INTEGER PRIMARY KEY,
                conversation_id TEXT NOT NULL,
                role TEXT NOT NULL,
                content TEXT NOT NULL,
                ts INTEGER NOT NULL,
                embedding VECTOR({dim})
            )
            """
        )
        cur.execute(
            f"""
            CREATE TABLE summaries (
                id INTEGER PRIMARY KEY,
                conversation_id TEXT NOT NULL,
                start_ts INTEGER NOT NULL,
                end_ts INTEGER NOT NULL,
                content TEXT NOT NULL,
                embedding VECTOR({dim})
            )
            """
        )
        cur.execute(
            """
            CREATE TABLE facts (
                id INTEGER PRIMARY KEY,
                subject TEXT NOT NULL,
                predicate TEXT NOT NULL,
                object TEXT NOT NULL,
                source_message_id INTEGER,
                ts INTEGER NOT NULL
            )
            """
        )
        cur.execute("CREATE INDEX idx_messages_emb ON messages USING hnsw (embedding)")
        cur.execute("CREATE INDEX idx_summaries_emb ON summaries USING hnsw (embedding)")
        cur.execute("CREATE INDEX idx_messages_fts ON messages USING fts (content)")
        cur.execute("CREATE INDEX idx_summaries_fts ON summaries USING fts (content)")

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _execute(self, sql: str):
        """Run ``sql`` on a fresh cursor and return that cursor."""
        cur = self._conn.cursor()
        cur.execute(sql)
        return cur

    def _fetch_rows(self, sql: str) -> list:
        """Run ``sql`` and return every result row as a list."""
        return list(self._execute(sql).fetchall())

    @staticmethod
    def _conversation_filter(keyword: str, conversation_id: Optional[str]) -> str:
        """Build the optional ``conversation_id`` predicate.

        Args:
            keyword: SQL keyword introducing the predicate (``WHERE``/``AND``).
            conversation_id: Conversation to restrict to, or ``None``.

        Returns:
            ``"<keyword> conversation_id = <literal> "`` (with a trailing
            space), or an empty string when ``conversation_id`` is ``None``.
        """
        if conversation_id is None:
            return ""
        return f"{keyword} conversation_id = {q(conversation_id)} "

    def _last_rowid(self, table: str) -> int:
        """Return the largest ``id`` in ``table``, or -1 when it is empty.

        Used directly after an INSERT to report the new row's id.
        NOTE(review): this equals the inserted row's id only if ids grow
        monotonically and no other writer inserts in between -- confirm.
        ``table`` is interpolated verbatim; pass trusted names only.
        """
        cur = self._execute(f"SELECT id FROM {table} ORDER BY id DESC LIMIT 1")
        row = cur.fetchone()
        return int(row[0]) if row else -1

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    def insert_message(
        self,
        *,
        conversation_id: str,
        role: str,
        content: str,
        embedding: list[float],
        ts: Optional[int] = None,
    ) -> int:
        """Insert one message and return its id.

        Args:
            conversation_id: Conversation the message belongs to.
            role: Role label stored with the message.
            content: Message text.
            embedding: Vector stored in the ``embedding`` column.
            ts: Timestamp to store; defaults to the current Unix time in
                seconds. An explicit ``0`` is stored as given.

        Returns:
            The id reported by ``_last_rowid`` for ``messages``.
        """
        # ``is None`` instead of ``ts or ...`` so a legitimate 0 is kept.
        if ts is None:
            ts = int(time.time())
        sql = (
            "INSERT INTO messages (conversation_id, role, content, ts, embedding) "
            f"VALUES ({q(conversation_id)}, {q(role)}, {q(content)}, {q(ts)}, {q(embedding)})"
        )
        self._execute(sql)
        return self._last_rowid("messages")

    def insert_summary(
        self,
        *,
        conversation_id: str,
        start_ts: int,
        end_ts: int,
        content: str,
        embedding: list[float],
    ) -> int:
        """Insert one summary and return its id.

        Args:
            conversation_id: Conversation the summary belongs to.
            start_ts: Start timestamp stored with the summary.
            end_ts: End timestamp stored with the summary.
            content: Summary text.
            embedding: Vector stored in the ``embedding`` column.

        Returns:
            The id reported by ``_last_rowid`` for ``summaries``.
        """
        sql = (
            "INSERT INTO summaries (conversation_id, start_ts, end_ts, content, embedding) "
            f"VALUES ({q(conversation_id)}, {q(start_ts)}, {q(end_ts)}, {q(content)}, {q(embedding)})"
        )
        self._execute(sql)
        return self._last_rowid("summaries")

    def insert_fact(
        self,
        *,
        subject: str,
        predicate: str,
        object_: str,
        source_message_id: Optional[int] = None,
        ts: Optional[int] = None,
    ) -> int:
        """Insert one (subject, predicate, object) fact and return its id.

        Args:
            subject: Subject of the triple.
            predicate: Predicate of the triple.
            object_: Object of the triple (stored in the ``object`` column).
            source_message_id: Id of the originating message; stored as SQL
                ``NULL`` when ``None``.
            ts: Timestamp to store; defaults to the current Unix time in
                seconds. An explicit ``0`` is stored as given.

        Returns:
            The id reported by ``_last_rowid`` for ``facts``.
        """
        # ``is None`` instead of ``ts or ...`` so a legitimate 0 is kept.
        if ts is None:
            ts = int(time.time())
        src = q(source_message_id) if source_message_id is not None else "NULL"
        sql = (
            "INSERT INTO facts (subject, predicate, object, source_message_id, ts) "
            f"VALUES ({q(subject)}, {q(predicate)}, {q(object_)}, {src}, {q(ts)})"
        )
        self._execute(sql)
        return self._last_rowid("facts")

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    def recent_messages(self, *, conversation_id: str, limit: int) -> list[Message]:
        """Return the last ``limit`` messages of a conversation, oldest first.

        The query selects the rows with the highest ids, then the result is
        reversed so the returned list is in ascending id order.
        """
        sql = (
            f"{self._MESSAGE_SELECT}"
            f"WHERE conversation_id = {q(conversation_id)} "
            f"ORDER BY id DESC LIMIT {int(limit)}"
        )
        rows = self._fetch_rows(sql)
        return [Message(*r) for r in reversed(rows)]

    def vector_search_messages(
        self,
        *,
        embedding: list[float],
        k: int,
        conversation_id: Optional[str] = None,
    ) -> list[Message]:
        """Return up to ``k`` messages ordered by ``vec_distance_cosine``.

        Args:
            embedding: Query vector compared against the stored embeddings.
            k: Maximum number of rows to return.
            conversation_id: Restrict to this conversation; ``None`` searches
                all conversations.
        """
        where = self._conversation_filter("WHERE", conversation_id)
        sql = (
            f"{self._MESSAGE_SELECT}{where}"
            f"ORDER BY vec_distance_cosine(embedding, {q(embedding)}) "
            f"LIMIT {int(k)}"
        )
        return [Message(*r) for r in self._fetch_rows(sql)]

    def vector_search_summaries(
        self,
        *,
        embedding: list[float],
        k: int,
        conversation_id: Optional[str] = None,
    ) -> list[Summary]:
        """Return up to ``k`` summaries ordered by ``vec_distance_cosine``.

        Args:
            embedding: Query vector compared against the stored embeddings.
            k: Maximum number of rows to return.
            conversation_id: Restrict to this conversation; ``None`` searches
                all conversations.
        """
        where = self._conversation_filter("WHERE", conversation_id)
        sql = (
            "SELECT id, conversation_id, start_ts, end_ts, content FROM summaries "
            f"{where}"
            f"ORDER BY vec_distance_cosine(embedding, {q(embedding)}) "
            f"LIMIT {int(k)}"
        )
        return [Summary(*r) for r in self._fetch_rows(sql)]

    def lexical_search_messages(
        self,
        *,
        keywords: list[str],
        k: int,
        conversation_id: Optional[str] = None,
    ) -> list[Message]:
        """Return up to ``k`` messages matching ``keywords`` via full-text search.

        The keywords are joined with spaces into one query string, matched
        with ``fts_match`` and ordered by descending ``bm25_score``.

        Args:
            keywords: Search terms; an empty list returns ``[]`` without
                querying.
            k: Maximum number of rows to return.
            conversation_id: Restrict to this conversation; ``None`` searches
                all conversations.
        """
        if not keywords:
            return []
        match = q(" ".join(keywords))
        # ``None`` alone means "no filter", matching the vector searches; an
        # empty-string id is a real (if unusual) value and is filtered on.
        conv_clause = self._conversation_filter("AND", conversation_id)
        sql = (
            f"{self._MESSAGE_SELECT}"
            f"WHERE fts_match(content, {match}) {conv_clause}"
            f"ORDER BY bm25_score(content, {match}) DESC LIMIT {int(k)}"
        )
        return [Message(*r) for r in self._fetch_rows(sql)]

    def search_facts(self, *, keywords: list[str], k: int = 20) -> list[Fact]:
        """Return up to ``k`` facts whose text contains any keyword, newest first.

        Each keyword is matched as a ``LIKE '%kw%'`` substring against the
        subject, predicate and object columns; all tests are OR-ed together.
        NOTE(review): ``%`` and ``_`` inside a keyword are not escaped and
        therefore act as LIKE wildcards.

        Args:
            keywords: Search terms; an empty list returns ``[]`` without
                querying.
            k: Maximum number of rows to return.
        """
        if not keywords:
            return []
        clauses = []
        for kw in keywords:
            pattern = q("%" + kw + "%")
            clauses.append(
                f"subject LIKE {pattern} OR "
                f"predicate LIKE {pattern} OR "
                f"object LIKE {pattern}"
            )
        keyword_clauses = " OR ".join(clauses)
        sql = (
            f"{self._FACT_SELECT}"
            f"WHERE {keyword_clauses} "
            f"ORDER BY ts DESC LIMIT {int(k)}"
        )
        return [Fact(*r) for r in self._fetch_rows(sql)]

    def all_facts(self, limit: int = 100) -> list[Fact]:
        """Return up to ``limit`` facts ordered by descending ``ts``."""
        sql = f"{self._FACT_SELECT}ORDER BY ts DESC LIMIT {int(limit)}"
        return [Fact(*r) for r in self._fetch_rows(sql)]

    def messages_in_window(
        self, *, conversation_id: str, start_ts: int, end_ts: int
    ) -> list[Message]:
        """Return a conversation's messages with ``start_ts <= ts <= end_ts``.

        Both bounds are inclusive; rows come back in ascending id order.
        """
        sql = (
            f"{self._MESSAGE_SELECT}"
            f"WHERE conversation_id = {q(conversation_id)} "
            f"AND ts >= {q(start_ts)} AND ts <= {q(end_ts)} "
            "ORDER BY id ASC"
        )
        return [Message(*r) for r in self._fetch_rows(sql)]

    def count(self, table: str) -> int:
        """Return the number of rows in ``table`` (0 if no row comes back).

        NOTE(security): ``table`` is interpolated into the SQL text verbatim.
        Only pass names that originate from trusted code, never user input.
        """
        row = self._execute(f"SELECT COUNT(*) FROM {table}").fetchone()
        return int(row[0]) if row else 0

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def close(self) -> None:
        """Close the underlying connection."""
        self._conn.close()

    @contextmanager
    def transaction(self) -> Iterator[None]:
        """Run the ``with`` body inside a ``BEGIN`` ... commit/rollback pair.

        Commits when the body finishes normally. On any exception, including
        ``KeyboardInterrupt`` and ``SystemExit``, the transaction is rolled
        back and the exception is re-raised unchanged.
        """
        cur = self._conn.cursor()
        cur.execute("BEGIN")
        try:
            yield
            self._conn.commit()
        except BaseException:
            # BaseException, not Exception: an interrupt inside the body must
            # not leave the transaction open on the shared connection.
            self._conn.rollback()
            raise