use mnm_core::types::{Chunk, ChunkStatus};
use pgvector::Vector;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use time::OffsetDateTime;
use uuid::Uuid;
use crate::error::{Result, StoreError};
/// Document-level metadata attached to a chunk in [`ChunkWithContext`].
///
/// Populated from the `document` columns selected by the context queries in
/// this module. Optional fields that are `None` are omitted when serializing.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DocumentSummary {
/// Id of the document; filled from the chunk's `document_id`.
pub id: Uuid,
/// Value of the `document.source_path` column.
pub source_path: String,
/// Value of the `document.published_url` column, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub published_url: Option<String>,
/// Value of the `document.source_url` column, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub source_url: Option<String>,
/// Value of the `document.language` column, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub language: Option<String>,
/// Document kind, decoded from the string stored in `document.kind`.
pub kind: mnm_core::types::DocumentKind,
/// Value of the `document.provenance` column as untyped JSON.
pub provenance: serde_json::Value,
}
/// Source-level metadata attached to a chunk in [`ChunkWithContext`].
///
/// Populated from the `source` row reached through the chunk's source version.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SourceSummary {
/// Value of the `source.slug` column.
pub slug: String,
/// Value of the `source.display_name` column.
pub display_name: String,
}
/// A chunk together with summaries of the document and source it belongs to.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChunkWithContext {
/// The chunk itself; its fields are flattened into the enclosing object
/// when (de)serializing.
#[serde(flatten)]
pub chunk: Chunk,
/// Summary of the document the chunk was joined to.
pub document: DocumentSummary,
/// Summary of the source the chunk's source version belongs to.
pub source: SourceSummary,
}
/// Column values for a new `chunk` row, consumed by [`insert`].
///
/// Text and path fields are borrowed for `'a`; the embedding vectors are owned
/// because `insert` converts them into pgvector `Vector` values.
#[derive(Debug, Clone)]
pub struct NewChunk<'a> {
/// Value for the `source_version_id` column.
pub source_version_id: Uuid,
/// Value for the `document_id` column.
pub document_id: Uuid,
/// Value for the `node_id` column.
pub node_id: Uuid,
/// Value for the `chunk_index` column.
pub chunk_index: i32,
/// Value for the `total_chunks` column.
pub total_chunks: i32,
/// Value for the `content` column.
pub content: &'a str,
/// Value for the `content_hash` column.
pub content_hash: &'a str,
/// Value for the `embedding` column; `None` is bound as SQL NULL.
pub embedding: Option<Vec<f32>>,
/// Value for the `embedding_model_id` column.
pub embedding_model_id: Uuid,
/// Value for the `code_embedding` column; `None` is bound as SQL NULL.
pub code_embedding: Option<Vec<f32>>,
/// Value for the `heading_path` column, bound as a string array.
pub heading_path: &'a [String],
/// Value for the `symbol_path` column, bound as JSON.
pub symbol_path: &'a [mnm_core::types::SymbolSegment],
/// Value for the `start_byte` column.
pub start_byte: i32,
/// Value for the `end_byte` column.
pub end_byte: i32,
/// Value for the `token_count` column.
pub token_count: i32,
/// Status to store; written as `ready`, `embed_failed` or `deprecated`.
pub status: ChunkStatus,
}
/// Inserts a single row into the `chunk` table and returns its generated id.
///
/// Embeddings are bound as pgvector values (NULL when absent), the symbol
/// path is bound as JSON, and the status is written as its lowercase label.
///
/// # Errors
///
/// Returns an error if the `INSERT ... RETURNING id` statement fails.
pub async fn insert(pool: &PgPool, c: NewChunk<'_>) -> Result<Uuid> {
    // Textual label stored in the `status` column for each variant.
    let status_label = match c.status {
        ChunkStatus::Deprecated => "deprecated",
        ChunkStatus::EmbedFailed => "embed_failed",
        ChunkStatus::Ready => "ready",
    };

    // Convert the owned payloads up front so the bind chain stays flat.
    let embedding = c.embedding.map(Vector::from);
    let code_embedding = c.code_embedding.map(Vector::from);
    let symbol_path = sqlx::types::Json(c.symbol_path);

    let (new_id,): (Uuid,) = sqlx::query_as(
        "INSERT INTO chunk ( \
         source_version_id, document_id, node_id, chunk_index, total_chunks, content, \
         content_hash, embedding, embedding_model_id, code_embedding, heading_path, \
         symbol_path, start_byte, end_byte, token_count, status \
         ) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16) RETURNING id",
    )
    .bind(c.source_version_id)
    .bind(c.document_id)
    .bind(c.node_id)
    .bind(c.chunk_index)
    .bind(c.total_chunks)
    .bind(c.content)
    .bind(c.content_hash)
    .bind(embedding)
    .bind(c.embedding_model_id)
    .bind(code_embedding)
    .bind(c.heading_path)
    .bind(symbol_path)
    .bind(c.start_byte)
    .bind(c.end_byte)
    .bind(c.token_count)
    .bind(status_label)
    .fetch_one(pool)
    .await?;

    Ok(new_id)
}
/// Minimal projection of a chunk whose status is `embed_failed`, as returned
/// by [`list_embed_failed_batch`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EmbedFailedChunk {
/// Id of the chunk row.
pub id: Uuid,
/// Value of the chunk's `content` column.
pub content: String,
/// Value of the chunk's `embedding_model_id` column.
pub embedding_model_id: Uuid,
}
/// Lists up to `limit` chunks with status `embed_failed` for one embedding
/// model, oldest `created_at` first.
///
/// When `source_version_filter` is `Some`, only chunks of that source version
/// are returned; `None` disables the filter.
///
/// # Errors
///
/// Returns an error if the query fails.
pub async fn list_embed_failed_batch(
    pool: &PgPool,
    model_id: Uuid,
    source_version_filter: Option<Uuid>,
    limit: i64,
) -> Result<Vec<EmbedFailedChunk>> {
    let fetched = sqlx::query_as::<_, (Uuid, String, Uuid)>(
        "SELECT id, content, embedding_model_id FROM chunk \
         WHERE status = 'embed_failed' AND embedding_model_id = $1 \
         AND ($2::uuid IS NULL OR source_version_id = $2::uuid) \
         ORDER BY created_at ASC \
         LIMIT $3",
    )
    .bind(model_id)
    .bind(source_version_filter)
    .bind(limit)
    .fetch_all(pool)
    .await?;

    let mut batch = Vec::with_capacity(fetched.len());
    for (id, content, embedding_model_id) in fetched {
        batch.push(EmbedFailedChunk {
            id,
            content,
            embedding_model_id,
        });
    }
    Ok(batch)
}
/// Stores an embedding for a chunk and flips its status to `ready`.
///
/// The update only applies while the chunk's status is still `embed_failed`.
/// Returns `true` when exactly one row was updated and `false` otherwise
/// (for example when the id is unknown or the chunk is no longer
/// `embed_failed`).
///
/// # Errors
///
/// Returns an error if the `UPDATE` statement fails.
pub async fn set_embedding(pool: &PgPool, id: Uuid, vector: Vec<f32>) -> Result<bool> {
    let embedding = Vector::from(vector);
    let outcome = sqlx::query(
        "UPDATE chunk SET embedding = $1, status = 'ready' \
         WHERE id = $2 AND status = 'embed_failed'",
    )
    .bind(embedding)
    .bind(id)
    .execute(pool)
    .await?;

    let updated_rows = outcome.rows_affected();
    Ok(updated_rows == 1)
}
pub async fn get_by_id_ready(pool: &PgPool, id: Uuid) -> Result<Chunk> {
let row = sqlx::query_as::<_, ChunkRow>(
"SELECT id, source_version_id, document_id, node_id, chunk_index, total_chunks, \
content, content_hash, embedding_model_id, heading_path, symbol_path, \
start_byte, end_byte, token_count, status, created_at \
FROM chunk WHERE id = $1 AND status <> 'embed_failed'",
)
.bind(id)
.fetch_one(pool)
.await?;
row.try_into()
}
pub async fn get_by_id_admin(pool: &PgPool, id: Uuid) -> Result<Chunk> {
let row = sqlx::query_as::<_, ChunkRow>(
"SELECT id, source_version_id, document_id, node_id, chunk_index, total_chunks, \
content, content_hash, embedding_model_id, heading_path, symbol_path, \
start_byte, end_byte, token_count, status, created_at \
FROM chunk WHERE id = $1",
)
.bind(id)
.fetch_one(pool)
.await?;
row.try_into()
}
/// Chunk data returned by [`list_for_carry_forward`].
///
/// Holds the stored columns of a chunk, including both embedding vectors,
/// without the row id, ownership ids (source version, document, node) or
/// creation timestamp.
#[derive(Debug, Clone, PartialEq)]
pub struct CarryForwardChunk {
/// Value of the `content` column.
pub content: String,
/// Value of the `content_hash` column.
pub content_hash: String,
/// Value of the `embedding` column; `None` when the column is NULL.
pub embedding: Option<Vec<f32>>,
/// Value of the `embedding_model_id` column.
pub embedding_model_id: Uuid,
/// Value of the `code_embedding` column; `None` when the column is NULL.
pub code_embedding: Option<Vec<f32>>,
/// Value of the `heading_path` column.
pub heading_path: Vec<String>,
/// Decoded JSON value of the `symbol_path` column.
pub symbol_path: Vec<mnm_core::types::SymbolSegment>,
/// Value of the `chunk_index` column.
pub chunk_index: i32,
/// Value of the `total_chunks` column.
pub total_chunks: i32,
/// Value of the `start_byte` column.
pub start_byte: i32,
/// Value of the `end_byte` column.
pub end_byte: i32,
/// Value of the `token_count` column.
pub token_count: i32,
/// Status decoded from the string stored in the `status` column.
pub status: ChunkStatus,
}
pub async fn list_for_carry_forward(
pool: &PgPool,
document_id: Uuid,
) -> Result<Vec<CarryForwardChunk>> {
let rows = sqlx::query_as::<_, CarryForwardRow>(
"SELECT content, content_hash, embedding, embedding_model_id, code_embedding, \
heading_path, symbol_path, chunk_index, total_chunks, start_byte, end_byte, \
token_count, status \
FROM chunk WHERE document_id = $1 ORDER BY chunk_index",
)
.bind(document_id)
.fetch_all(pool)
.await?;
rows.into_iter().map(TryInto::try_into).collect()
}
pub async fn list_next(pool: &PgPool, anchor: Uuid, count: usize) -> Result<Vec<ChunkWithContext>> {
let count = i64::try_from(count.clamp(1, 100)).unwrap_or(5);
let rows = sqlx::query_as::<_, ChunkWithContextRow>(
"WITH a AS (SELECT document_id, chunk_index FROM chunk WHERE id = $1) \
SELECT \
c.id, c.source_version_id, c.document_id, c.node_id, c.chunk_index, c.total_chunks, \
c.content, c.content_hash, c.embedding_model_id, c.heading_path, c.symbol_path, \
c.start_byte, c.end_byte, c.token_count, c.status, c.created_at, \
d.source_path AS d_source_path, d.published_url AS d_published_url, \
d.source_url AS d_source_url, d.language AS d_language, d.kind AS d_kind, \
d.provenance AS d_provenance, \
s.slug AS s_slug, s.display_name AS s_display_name \
FROM chunk c \
JOIN document d ON c.document_id = d.id \
JOIN source_version sv ON c.source_version_id = sv.id \
JOIN source s ON sv.source_id = s.id, a \
WHERE c.document_id = a.document_id \
AND c.chunk_index > a.chunk_index \
AND c.status <> 'embed_failed' \
ORDER BY c.chunk_index ASC \
LIMIT $2",
)
.bind(anchor)
.bind(count)
.fetch_all(pool)
.await?;
rows.into_iter().map(TryInto::try_into).collect()
}
pub async fn list_prev(pool: &PgPool, anchor: Uuid, count: usize) -> Result<Vec<ChunkWithContext>> {
let count = i64::try_from(count.clamp(1, 100)).unwrap_or(5);
let mut rows = sqlx::query_as::<_, ChunkWithContextRow>(
"WITH a AS (SELECT document_id, chunk_index FROM chunk WHERE id = $1) \
SELECT \
c.id, c.source_version_id, c.document_id, c.node_id, c.chunk_index, c.total_chunks, \
c.content, c.content_hash, c.embedding_model_id, c.heading_path, c.symbol_path, \
c.start_byte, c.end_byte, c.token_count, c.status, c.created_at, \
d.source_path AS d_source_path, d.published_url AS d_published_url, \
d.source_url AS d_source_url, d.language AS d_language, d.kind AS d_kind, \
d.provenance AS d_provenance, \
s.slug AS s_slug, s.display_name AS s_display_name \
FROM chunk c \
JOIN document d ON c.document_id = d.id \
JOIN source_version sv ON c.source_version_id = sv.id \
JOIN source s ON sv.source_id = s.id, a \
WHERE c.document_id = a.document_id \
AND c.chunk_index < a.chunk_index \
AND c.status <> 'embed_failed' \
ORDER BY c.chunk_index DESC \
LIMIT $2",
)
.bind(anchor)
.bind(count)
.fetch_all(pool)
.await?;
rows.reverse();
rows.into_iter().map(TryInto::try_into).collect()
}
pub async fn get_with_context(pool: &PgPool, id: Uuid) -> Result<ChunkWithContext> {
let row = sqlx::query_as::<_, ChunkWithContextRow>(
"SELECT \
c.id, c.source_version_id, c.document_id, c.node_id, c.chunk_index, c.total_chunks, \
c.content, c.content_hash, c.embedding_model_id, c.heading_path, c.symbol_path, \
c.start_byte, c.end_byte, c.token_count, c.status, c.created_at, \
d.source_path AS d_source_path, d.published_url AS d_published_url, \
d.source_url AS d_source_url, d.language AS d_language, d.kind AS d_kind, \
d.provenance AS d_provenance, \
s.slug AS s_slug, s.display_name AS s_display_name \
FROM chunk c \
JOIN document d ON c.document_id = d.id \
JOIN source_version sv ON c.source_version_id = sv.id \
JOIN source s ON sv.source_id = s.id \
WHERE c.id = $1 AND c.status <> 'embed_failed'",
)
.bind(id)
.fetch_optional(pool)
.await?
.ok_or(StoreError::NotFound)?;
row.try_into()
}
/// Fetches several chunks by id, each with its document and source summaries.
///
/// Ids that do not exist, or whose chunk has status `embed_failed`, are
/// silently omitted. The query has no `ORDER BY`, so the result order is
/// unspecified and need not match the order of `ids`.
///
/// An empty `ids` slice returns an empty vector without contacting the
/// database, since `= ANY` over an empty array can never match a row.
///
/// # Errors
///
/// Returns an error if the query fails or a row cannot be converted.
pub async fn get_many_with_context(pool: &PgPool, ids: &[Uuid]) -> Result<Vec<ChunkWithContext>> {
    // Nothing to look up: skip the round trip entirely.
    if ids.is_empty() {
        return Ok(Vec::new());
    }

    let rows = sqlx::query_as::<_, ChunkWithContextRow>(
        "SELECT \
         c.id, c.source_version_id, c.document_id, c.node_id, c.chunk_index, c.total_chunks, \
         c.content, c.content_hash, c.embedding_model_id, c.heading_path, c.symbol_path, \
         c.start_byte, c.end_byte, c.token_count, c.status, c.created_at, \
         d.source_path AS d_source_path, d.published_url AS d_published_url, \
         d.source_url AS d_source_url, d.language AS d_language, d.kind AS d_kind, \
         d.provenance AS d_provenance, \
         s.slug AS s_slug, s.display_name AS s_display_name \
         FROM chunk c \
         JOIN document d ON c.document_id = d.id \
         JOIN source_version sv ON c.source_version_id = sv.id \
         JOIN source s ON sv.source_id = s.id \
         WHERE c.id = ANY($1) AND c.status <> 'embed_failed'",
    )
    .bind(ids)
    .fetch_all(pool)
    .await?;

    rows.into_iter().map(TryInto::try_into).collect()
}
/// Raw row shape for the chunk-with-context queries in this module.
///
/// Field names must match the selected column names/aliases: unprefixed
/// fields come from `chunk`, `d_*` fields from `document`, and `s_*` fields
/// from `source`. Converted into [`ChunkWithContext`] via `TryFrom`.
#[derive(sqlx::FromRow)]
struct ChunkWithContextRow {
id: Uuid,
source_version_id: Uuid,
document_id: Uuid,
node_id: Uuid,
chunk_index: i32,
total_chunks: i32,
content: String,
content_hash: String,
embedding_model_id: Uuid,
heading_path: Vec<String>,
// Stored as JSON; decoded into symbol segments by sqlx.
symbol_path: sqlx::types::Json<Vec<mnm_core::types::SymbolSegment>>,
start_byte: i32,
end_byte: i32,
token_count: i32,
// Kept as the raw string; parsed into `ChunkStatus` during conversion.
status: String,
created_at: time::OffsetDateTime,
d_source_path: String,
d_published_url: Option<String>,
d_source_url: Option<String>,
d_language: Option<String>,
// Kept as the raw string; parsed into `DocumentKind` during conversion.
d_kind: String,
d_provenance: serde_json::Value,
s_slug: String,
s_display_name: String,
}
impl TryFrom<ChunkWithContextRow> for ChunkWithContext {
type Error = StoreError;
fn try_from(r: ChunkWithContextRow) -> Result<Self> {
let status: ChunkStatus = serde_json::from_value(serde_json::Value::String(r.status))
.map_err(|e| StoreError::Json(e.to_string()))?;
let doc_kind: mnm_core::types::DocumentKind =
serde_json::from_value(serde_json::Value::String(r.d_kind))
.map_err(|e| StoreError::Json(e.to_string()))?;
let chunk = Chunk {
id: r.id,
source_version_id: r.source_version_id,
document_id: r.document_id,
node_id: r.node_id,
chunk_index: r.chunk_index,
total_chunks: r.total_chunks,
content: r.content,
content_hash: r.content_hash,
embedding_model_id: r.embedding_model_id,
heading_path: r.heading_path,
symbol_path: r.symbol_path.0,
start_byte: r.start_byte,
end_byte: r.end_byte,
token_count: r.token_count,
status,
created_at: r.created_at,
};
Ok(Self {
chunk,
document: DocumentSummary {
id: r.document_id,
source_path: r.d_source_path,
published_url: r.d_published_url,
source_url: r.d_source_url,
language: r.d_language,
kind: doc_kind,
provenance: r.d_provenance,
},
source: SourceSummary {
slug: r.s_slug,
display_name: r.s_display_name,
},
})
}
}
/// Raw row shape for the query in `list_for_carry_forward`.
///
/// Field names must match the selected `chunk` column names. Converted into
/// [`CarryForwardChunk`] via `TryFrom`.
#[derive(sqlx::FromRow)]
struct CarryForwardRow {
content: String,
content_hash: String,
// NULL in the database maps to `None`.
embedding: Option<Vector>,
embedding_model_id: Uuid,
// NULL in the database maps to `None`.
code_embedding: Option<Vector>,
heading_path: Vec<String>,
// Stored as JSON; decoded into symbol segments by sqlx.
symbol_path: sqlx::types::Json<Vec<mnm_core::types::SymbolSegment>>,
chunk_index: i32,
total_chunks: i32,
start_byte: i32,
end_byte: i32,
token_count: i32,
// Kept as the raw string; parsed into `ChunkStatus` during conversion.
status: String,
}
/// Converts a raw carry-forward row into a [`CarryForwardChunk`].
///
/// The `status` string is decoded through serde; a value that does not
/// deserialize yields `StoreError::Json`.
impl TryFrom<CarryForwardRow> for CarryForwardChunk {
    type Error = crate::error::StoreError;

    fn try_from(r: CarryForwardRow) -> std::result::Result<Self, Self::Error> {
        let status: ChunkStatus = serde_json::from_value(serde_json::Value::String(r.status))
            .map_err(|e| crate::error::StoreError::Json(e.to_string()))?;
        Ok(Self {
            content: r.content,
            content_hash: r.content_hash,
            // The row is owned, so move the buffer out of the pgvector
            // wrapper instead of cloning it with `to_vec()`.
            embedding: r.embedding.map(Vec::<f32>::from),
            embedding_model_id: r.embedding_model_id,
            code_embedding: r.code_embedding.map(Vec::<f32>::from),
            heading_path: r.heading_path,
            symbol_path: r.symbol_path.0,
            chunk_index: r.chunk_index,
            total_chunks: r.total_chunks,
            start_byte: r.start_byte,
            end_byte: r.end_byte,
            token_count: r.token_count,
            status,
        })
    }
}
/// Raw row shape for the plain chunk lookups (`get_by_id_ready`,
/// `get_by_id_admin`).
///
/// Field names must match the selected `chunk` column names. Embedding
/// columns are not part of this projection. Converted into `Chunk` via
/// `TryFrom`.
#[derive(sqlx::FromRow)]
struct ChunkRow {
id: Uuid,
source_version_id: Uuid,
document_id: Uuid,
node_id: Uuid,
chunk_index: i32,
total_chunks: i32,
content: String,
content_hash: String,
embedding_model_id: Uuid,
heading_path: Vec<String>,
// Stored as JSON; decoded into symbol segments by sqlx.
symbol_path: sqlx::types::Json<Vec<mnm_core::types::SymbolSegment>>,
start_byte: i32,
end_byte: i32,
token_count: i32,
// Kept as the raw string; parsed into `ChunkStatus` during conversion.
status: String,
created_at: OffsetDateTime,
}
impl TryFrom<ChunkRow> for Chunk {
type Error = crate::error::StoreError;
fn try_from(r: ChunkRow) -> std::result::Result<Self, Self::Error> {
let status: ChunkStatus = serde_json::from_value(serde_json::Value::String(r.status))
.map_err(|e| crate::error::StoreError::Json(e.to_string()))?;
Ok(Self {
id: r.id,
source_version_id: r.source_version_id,
document_id: r.document_id,
node_id: r.node_id,
chunk_index: r.chunk_index,
total_chunks: r.total_chunks,
content: r.content,
content_hash: r.content_hash,
embedding_model_id: r.embedding_model_id,
heading_path: r.heading_path,
symbol_path: r.symbol_path.0,
start_byte: r.start_byte,
end_byte: r.end_byte,
token_count: r.token_count,
status,
created_at: r.created_at,
})
}
}
/// Returns the decoded `symbol_path` of the chunk with the given id.
///
/// No status filter is applied.
///
/// # Errors
///
/// Returns an error if the query fails (including when no row matches,
/// since the lookup uses `fetch_one`) or the JSON cannot be decoded.
pub async fn symbol_path_of(
    pool: &PgPool,
    id: Uuid,
) -> Result<Vec<mnm_core::types::SymbolSegment>> {
    let (sqlx::types::Json(segments),): (
        sqlx::types::Json<Vec<mnm_core::types::SymbolSegment>>,
    ) = sqlx::query_as("SELECT symbol_path FROM chunk WHERE id = $1")
        .bind(id)
        .fetch_one(pool)
        .await?;
    Ok(segments)
}