use rusqlite::Connection;
use time::OffsetDateTime;
use crate::TalonError;
use crate::inference::{EmbedChunkedDataItem, EmbeddingClient, InferenceError};
use super::diagnostics::{
EmbedDiagnostics, EmbedRunContext, align_embedding_dimensions, mark_note_chunks_failed,
};
use super::pending::{NoteWithChunks, get_pending_chunks};
use super::persist::{first_non_empty_batch, persist_chunk_vector};
/// Options controlling a single embedding pass.
///
/// `Default::default()` and [`EmbedPassOptions::defaults`] produce the same value, so callers
/// using struct-update syntax (`..Default::default()`) get the real default model names instead
/// of empty strings.
#[derive(Debug, Clone)]
pub struct EmbedPassOptions {
    /// Forwarded to `get_pending_chunks`; presumably re-selects chunks that already have
    /// vectors — TODO confirm against `pending.rs`.
    pub force: bool,
    /// Forwarded to `get_pending_chunks`; presumably limits the pass to these vault paths
    /// (empty = no restriction) — TODO confirm.
    pub restrict_paths: Vec<String>,
    /// Model name for single-chunk embeddings.
    /// NOTE(review): `run_embed_pass` does not read this; the persisted model name comes from
    /// `EmbeddingClient::chunk_model`.
    pub chunk_embedding_model: String,
    /// Model name for multi-chunk (document) embeddings.
    /// NOTE(review): `run_embed_pass` does not read this; the persisted model name comes from
    /// `EmbeddingClient::document_model`.
    pub document_embedding_model: String,
}

impl EmbedPassOptions {
    /// Returns the standard options: no forcing, no path restriction, and the `embed` /
    /// `embed_chunked` model names.
    #[must_use]
    pub fn defaults() -> Self {
        Self {
            force: false,
            restrict_paths: Vec::new(),
            chunk_embedding_model: "embed".to_string(),
            document_embedding_model: "embed_chunked".to_string(),
        }
    }
}

impl Default for EmbedPassOptions {
    /// Same as [`EmbedPassOptions::defaults`]; a derived `Default` would leave the model names
    /// empty.
    fn default() -> Self {
        Self::defaults()
    }
}
/// Summary of one embedding pass, as returned by [`run_embed_pass`].
#[derive(Debug, Clone, Default)]
pub struct EmbedPassStats {
/// Notes attempted in the pass. In this module the run context's counter is bumped once per
/// note (not per chunk); assumes `EmbedDiagnostics` copies it through unchanged — TODO confirm.
pub processed: u32,
/// Notes whose vectors were persisted successfully (same per-note counting as `processed`).
pub succeeded: u32,
/// Notes recorded as failed via `fail_note` (same per-note counting as `processed`).
pub failed: u32,
/// Whether the run context flagged an embedding-dimension mismatch during the pass.
pub dimension_mismatch: bool,
/// When built via `From<EmbedDiagnostics>`, this is `Some(DIMENSION_MISMATCH_REMEDIATION)`
/// exactly when `dimension_mismatch` is true, otherwise `None`.
pub remediation: Option<String>,
/// Diagnostic messages carried over from `EmbedDiagnostics`; presumably one entry per recorded
/// failure, tagged with the note's vault path — TODO confirm format in `diagnostics.rs`.
pub diagnostics: Vec<String>,
}
/// User-facing remediation hint stored in [`EmbedPassStats::remediation`] when a pass reports a
/// dimension mismatch.
pub const DIMENSION_MISMATCH_REMEDIATION: &str = "embedding model changed mid-pass — run `talon embed --force` to drop the existing vec_chunks index and re-embed every chunk at the new dimensionality";
impl From<EmbedDiagnostics> for EmbedPassStats {
fn from(value: EmbedDiagnostics) -> Self {
let remediation = value
.dimension_mismatch
.then(|| DIMENSION_MISMATCH_REMEDIATION.to_string());
Self {
processed: value.processed,
succeeded: value.succeeded,
failed: value.failed,
dimension_mismatch: value.dimension_mismatch,
remediation,
diagnostics: value.diagnostics,
}
}
}
/// Runs one embedding pass over every note that has pending chunks.
///
/// Pending notes are loaded with `get_pending_chunks`, which receives `options.force` and
/// `options.restrict_paths`. Each note is handled in a single request:
/// - single-chunk notes go through the plain `embed` endpoint;
/// - multi-chunk notes go through `embed_chunked`;
/// - chunk-less notes are counted as processed without issuing any request.
///
/// Per-note failures are counted and recorded in the returned stats instead of aborting the
/// pass.
///
/// # Errors
///
/// Returns an error in two cases, and the pass stops at that note:
/// - the pending chunks cannot be loaded;
/// - an embedding request fails in a way `fatal_endpoint_failure` classifies as an
///   unavailable or misconfigured endpoint.
pub fn run_embed_pass(
    conn: &Connection,
    client: &EmbeddingClient,
    options: &EmbedPassOptions,
) -> Result<EmbedPassStats, TalonError> {
    let pending = get_pending_chunks(conn, options.force, &options.restrict_paths)?;
    let mut ctx = EmbedRunContext::default();
    for note in &pending {
        // `embed_single_chunk` treats a chunk-less note as a no-op. Routing empty notes there
        // avoids sending an empty text list to the chunked endpoint.
        if note.chunks.len() <= 1 {
            embed_single_chunk(conn, client, options, note, &mut ctx)?;
        } else {
            embed_multi_chunk(conn, client, options, note, &mut ctx)?;
        }
    }
    Ok(ctx.snapshot().into())
}
/// Current UTC wall-clock time as milliseconds since the Unix epoch.
///
/// If the value does not fit in an `i64`, it saturates to `i64::MAX`.
fn now_ms() -> i64 {
    const NANOS_PER_MILLI: i128 = 1_000_000;
    let millis = OffsetDateTime::now_utc().unix_timestamp_nanos() / NANOS_PER_MILLI;
    i64::try_from(millis).unwrap_or(i64::MAX)
}
/// Records a per-note embedding failure.
///
/// This function:
/// - increments the run's `failed` counter;
/// - records `detail` as a diagnostic for the note's vault path;
/// - asks `mark_note_chunks_failed` to flag the note's chunks.
///
/// An error from that last step is logged, not returned. A single note therefore cannot abort
/// the pass from here.
fn fail_note(conn: &Connection, note: &NoteWithChunks, ctx: &mut EmbedRunContext, detail: &str) {
    ctx.failed = ctx.failed.saturating_add(1);
    ctx.record_diagnostic(&note.vault_path, detail);
    if let Err(err) = mark_note_chunks_failed(conn, note) {
        tracing::error!(
            target: "talon::embed",
            vault_path = note.vault_path,
            error = %err,
            "could not mark chunks failed"
        );
    }
}
/// Renders an inference error as the text used in diagnostics and error messages (its
/// `Display` output).
fn format_inference_failure(err: &InferenceError) -> String {
    format!("{err}")
}
/// Decides whether an inference error should abort the whole pass.
///
/// Only `InferenceError::Http` qualifies, and only when its status is one of:
/// - missing (`None`) — presumably no HTTP response was received; confirm in `inference`;
/// - 404, 405 or 501.
///
/// In those cases this returns an `InvalidInput` error pointing at `embedding.base_url`.
/// Every other error yields `None`, so the caller fails only the current note.
fn fatal_endpoint_failure(err: &InferenceError) -> Option<TalonError> {
    let endpoint_unusable = matches!(
        err,
        InferenceError::Http {
            status: None | Some(404 | 405 | 501),
            ..
        }
    );

    endpoint_unusable.then(|| TalonError::InvalidInput {
        field: "embedding.base_url",
        message: format!(
            "embedding endpoint unavailable or misconfigured: {}; check embedding.base_url and the sidecar embedding routes",
            format_inference_failure(err)
        ),
    })
}
/// Embeds a note's only chunk through the plain `embed` endpoint and persists the vector.
///
/// The note is always counted as processed. The following per-note problems are recorded with
/// `fail_note`, and the function still returns `Ok(())` so the pass continues:
/// - a non-fatal inference error;
/// - an empty response or a zero-length vector;
/// - a dimension alignment error or a dimension mismatch;
/// - a persistence error.
///
/// # Errors
///
/// Returns the error built by `fatal_endpoint_failure` when the request fails because the
/// endpoint itself is unavailable or misconfigured.
fn embed_single_chunk(
    conn: &Connection,
    client: &EmbeddingClient,
    _options: &EmbedPassOptions,
    note: &NoteWithChunks,
    ctx: &mut EmbedRunContext,
) -> Result<(), TalonError> {
    ctx.processed = ctx.processed.saturating_add(1);

    // Defensive: a chunk-less note is counted as processed but otherwise ignored.
    let Some(chunk) = note.chunks.first() else {
        return Ok(());
    };

    let rows = match client.embed(std::slice::from_ref(&chunk.embedding_text)) {
        Ok(rows) => rows,
        Err(err) => {
            return match fatal_endpoint_failure(&err) {
                Some(fatal) => Err(fatal),
                None => {
                    fail_note(conn, note, ctx, &format_inference_failure(&err));
                    Ok(())
                }
            };
        }
    };

    let Some(vector) = rows.into_iter().next() else {
        fail_note(conn, note, ctx, "sidecar returned no embedding rows");
        return Ok(());
    };

    // A zero-length vector, or one too long to describe with a u32, is unusable.
    let Some(dims) = u32::try_from(vector.len()).ok().filter(|&len| len > 0) else {
        fail_note(conn, note, ctx, "sidecar returned empty embedding vector");
        return Ok(());
    };

    if let Err(err) = align_embedding_dimensions(conn, ctx, dims) {
        fail_note(conn, note, ctx, &err.to_string());
        return Ok(());
    }

    if ctx.dimension_mismatch {
        let expected = ctx.current_dimensions.unwrap_or(0);
        let detail = format!(
            "embedding dimension mismatch (expected {expected}, got {dims}); semantic search disabled — run `talon embed --force` to rebuild at the new dimensionality"
        );
        fail_note(conn, note, ctx, &detail);
        return Ok(());
    }

    match persist_chunk_vector(
        conn,
        chunk.chunk_id,
        client.chunk_model(),
        dims,
        now_ms(),
        &vector,
    ) {
        Ok(_) => ctx.succeeded = ctx.succeeded.saturating_add(1),
        Err(err) => fail_note(conn, note, ctx, &err.to_string()),
    }
    Ok(())
}
/// Embeds all chunks of a multi-chunk note with one `embed_chunked` request and persists the
/// resulting vectors.
///
/// The note is always counted as processed. The following per-note problems are recorded with
/// `fail_note`, and the function still returns `Ok(())` so the pass continues:
/// - a non-fatal inference error;
/// - no usable batch in the response;
/// - a dimension alignment error or a dimension mismatch;
/// - a persistence error.
///
/// # Errors
///
/// Returns the error built by `fatal_endpoint_failure` when the request fails because the
/// endpoint itself is unavailable or misconfigured.
fn embed_multi_chunk(
    conn: &Connection,
    client: &EmbeddingClient,
    _options: &EmbedPassOptions,
    note: &NoteWithChunks,
    ctx: &mut EmbedRunContext,
) -> Result<(), TalonError> {
    ctx.processed = ctx.processed.saturating_add(1);

    // `embed_chunked` takes a slice of per-document text lists. This call sends only this note,
    // with its chunk texts in chunk order.
    let document: Vec<String> = note
        .chunks
        .iter()
        .map(|chunk| chunk.embedding_text.clone())
        .collect();

    let response = match client.embed_chunked(&[document]) {
        Ok(response) => response,
        Err(err) => {
            return match fatal_endpoint_failure(&err) {
                Some(fatal) => Err(fatal),
                None => {
                    fail_note(conn, note, ctx, &format_inference_failure(&err));
                    Ok(())
                }
            };
        }
    };

    let Some((dims, batch)) = first_non_empty_batch(&response) else {
        fail_note(
            conn,
            note,
            ctx,
            "sidecar returned no usable chunked embeddings",
        );
        return Ok(());
    };

    if let Err(err) = align_embedding_dimensions(conn, ctx, dims) {
        fail_note(conn, note, ctx, &err.to_string());
        return Ok(());
    }

    if ctx.dimension_mismatch {
        let expected = ctx.current_dimensions.unwrap_or(0);
        let detail = format!(
            "embedding dimension mismatch (expected {expected}, got {dims}); semantic search disabled — run `talon embed --force` to rebuild at the new dimensionality"
        );
        fail_note(conn, note, ctx, &detail);
        return Ok(());
    }

    match persist_multi_chunk(conn, client, note, batch, dims) {
        Ok(()) => ctx.succeeded = ctx.succeeded.saturating_add(1),
        Err(err) => fail_note(conn, note, ctx, &err.to_string()),
    }
    Ok(())
}
/// Persists one vector per chunk of `note`, taken positionally from `batch.embeddings`.
///
/// The whole batch is validated before anything is written:
/// - the number of embeddings must equal the number of chunks;
/// - every embedding must have exactly `dims` components.
///
/// A malformed response is therefore rejected up front, instead of persisting a prefix of the
/// note's chunks or storing a vector under the wrong `dims` label. All vectors share one
/// timestamp and are tagged with `client.document_model()`.
///
/// # Errors
///
/// Returns `TalonError::Internal` when either validation fails. Any error from
/// `persist_chunk_vector` is propagated as-is; chunks written before that error are not rolled
/// back here.
fn persist_multi_chunk(
    conn: &Connection,
    client: &EmbeddingClient,
    note: &NoteWithChunks,
    batch: &EmbedChunkedDataItem,
    dims: u32,
) -> Result<(), TalonError> {
    if batch.embeddings.len() != note.chunks.len() {
        return Err(TalonError::Internal {
            message: format!(
                "chunked response length {got} != note chunks {expected}",
                got = batch.embeddings.len(),
                expected = note.chunks.len()
            ),
        });
    }

    // Reject ragged batches before the first write. A vector whose length differs from `dims`
    // would otherwise be stored mislabelled, or fail partway through the loop below.
    if let Some((index, embedding)) = batch
        .embeddings
        .iter()
        .enumerate()
        .find(|(_, embedding)| u32::try_from(embedding.len()).ok() != Some(dims))
    {
        return Err(TalonError::Internal {
            message: format!(
                "chunked embedding {index} has {got} dimensions, expected {dims}",
                got = embedding.len()
            ),
        });
    }

    let now = now_ms();
    for (chunk, embedding) in note.chunks.iter().zip(batch.embeddings.iter()) {
        persist_chunk_vector(
            conn,
            chunk.chunk_id,
            client.document_model(),
            dims,
            now,
            embedding,
        )?;
    }
    Ok(())
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests;
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod wiremock_tests;