khive-pack-knowledge 0.2.10

Knowledge verb pack — lore corpus (atoms/domains), TF-IDF retrieval, concept registration
Documentation
//! Section embedding pass (ADR-051).
//!
//! Embeds `knowledge_sections` into the inline `embedding` column with the
//! default embedder. Vectors are unit-normalised little-endian `f32` so a stored
//! dot product equals cosine similarity. Embed text is breadcrumb-enriched:
//! `atom_name \n heading \n\n content`.

use khive_runtime::{KhiveRuntime, NamespaceToken, RuntimeError};
use khive_storage::types::{SqlStatement, SqlValue};

use super::util::{now_us, row_str, sql_err, EMBED_BATCH, MAX_EMBED_BYTES};

/// Rescales `v` in place to unit Euclidean length.
///
/// Vectors whose norm does not exceed `1e-12` (including the empty slice and
/// an all-zero vector) are left untouched so we never divide by ~zero.
fn unit_normalize(v: &mut [f32]) {
    let mut sum_sq = 0.0f32;
    for &component in v.iter() {
        sum_sq += component * component;
    }
    let magnitude = sum_sq.sqrt();
    // `>` (rather than a negated `<=`) also rejects a NaN magnitude.
    if magnitude > 1e-12 {
        v.iter_mut().for_each(|component| *component /= magnitude);
    }
}

/// Serialises `v` as a flat byte buffer: four little-endian bytes per `f32`,
/// in slice order.
fn f32_to_le_bytes(v: &[f32]) -> Vec<u8> {
    v.iter().flat_map(|value| value.to_le_bytes()).collect()
}

fn truncate_bytes(t: &str) -> String {
    if t.len() <= MAX_EMBED_BYTES {
        return t.to_string();
    }
    let mut end = MAX_EMBED_BYTES;
    while !t.is_char_boundary(end) {
        end -= 1;
    }
    t[..end].to_string()
}

/// Embed sections in `token`'s namespace into `knowledge_sections.embedding`.
///
/// With `drop_existing`, every section is re-embedded; otherwise only sections
/// whose `embedding` is currently NULL are filled.
///
/// `batch_size` is the SQL page size and is clamped to `1..=1000`; each page is
/// handed to the embedder in chunks of `EMBED_BATCH` texts, each text capped by
/// `truncate_bytes`. `on_progress`, when supplied, is called with
/// `(indexed, total)` once before the first page and again after every page.
///
/// Returns `(indexed, skipped, failed)`. Genuine skips (blank section text) go
/// to `skipped`; embed errors, vector-count mismatches, failed UPDATEs and rows
/// whose `id` cannot be read go to `failed` (fail-closed contract). Every row
/// selected is therefore counted in exactly one of the three buckets.
///
/// Returns `(0, 0, 0)` without querying when the runtime has no default
/// embedder configured.
///
/// # Errors
///
/// Fails if a SQL reader or writer cannot be obtained, or if the count query
/// or a page query fails. Embedder failures and per-row UPDATE failures are
/// logged and counted in `failed` rather than propagated.
pub(crate) async fn embed_sections(
    runtime: &KhiveRuntime,
    token: &NamespaceToken,
    drop_existing: bool,
    batch_size: usize,
    on_progress: Option<&(dyn Fn(u64, u64) + Send + Sync)>,
) -> Result<(usize, usize, usize), RuntimeError> {
    if runtime.default_embedder_name().is_empty() {
        return Ok((0, 0, 0));
    }
    let ns = token.namespace().as_str().to_owned();
    let sql = runtime.sql();
    // Clamped to 1..=1000, so the conversions between i64 and usize below are lossless.
    let page = batch_size.clamp(1, 1000) as i64;

    let total: u64 = {
        let filter = if drop_existing {
            ""
        } else {
            " AND embedding IS NULL"
        };
        let mut reader = sql
            .reader()
            .await
            .map_err(|e| sql_err("section count reader", e))?;
        let row = reader
            .query_row(SqlStatement {
                sql: format!(
                    "SELECT count(*) AS cnt FROM knowledge_sections \
                     WHERE namespace = ?1{filter}"
                ),
                params: vec![SqlValue::Text(ns.clone())],
                label: None,
            })
            .await
            .map_err(|e| sql_err("section count", e))?;
        match row {
            Some(r) => match r.get("cnt") {
                // A count is never negative; clamp anyway so the cast cannot wrap.
                Some(SqlValue::Integer(n)) => (*n).max(0) as u64,
                _ => 0,
            },
            None => 0,
        }
    };

    if let Some(cb) = on_progress {
        cb(0, total);
    }

    // The page query does not change between iterations, so build it once.
    let page_filter = if drop_existing {
        ""
    } else {
        " AND s.embedding IS NULL"
    };
    let page_query = format!(
        "SELECT s.id AS id, s.heading AS heading, s.content AS content, \
                a.name AS atom_name \
         FROM knowledge_sections s \
         JOIN knowledge_atoms a ON a.id = s.atom_id \
         WHERE s.namespace = ?1{page_filter} \
         ORDER BY s.id LIMIT ?2 OFFSET ?3"
    );

    let mut indexed = 0usize;
    let mut skipped = 0usize;
    let mut failed = 0usize;
    let mut offset = 0i64;

    loop {
        let skipped_before = skipped;
        let failed_before = failed;

        let mut reader = sql
            .reader()
            .await
            .map_err(|e| sql_err("section index reader", e))?;
        let rows = reader
            .query_all(SqlStatement {
                sql: page_query.clone(),
                params: vec![
                    SqlValue::Text(ns.clone()),
                    SqlValue::Integer(page),
                    SqlValue::Integer(offset),
                ],
                label: None,
            })
            .await
            .map_err(|e| sql_err("section index page", e))?;
        let n = rows.len();
        if n == 0 {
            break;
        }

        // (section id, breadcrumb-enriched embed text) for every embeddable row.
        let mut staged: Vec<(String, String)> = Vec::with_capacity(n);
        for r in &rows {
            let Some(id) = row_str(r, "id") else {
                // Without an id the row cannot be updated and stays un-embedded.
                // Count it as failed: otherwise it would be invisible in the
                // returned totals and, in keep-existing mode, never stepped
                // over by the offset update below (a full page of such rows
                // would be re-selected forever).
                tracing::warn!("section row has no readable id; counting as failed");
                failed += 1;
                continue;
            };
            let heading = row_str(r, "heading").unwrap_or_default();
            let content = row_str(r, "content").unwrap_or_default();
            let atom_name = row_str(r, "atom_name").unwrap_or_default();
            let text = format!("{atom_name}\n{heading}\n\n{content}");
            if text.trim().is_empty() {
                skipped += 1;
                continue;
            }
            staged.push((id, text));
        }

        for chunk in staged.chunks(EMBED_BATCH) {
            let texts: Vec<String> = chunk.iter().map(|(_, t)| truncate_bytes(t)).collect();
            let embeddings = match runtime.embed_document_batch(&texts).await {
                Ok(e) if e.len() == chunk.len() => e,
                Ok(_) => {
                    tracing::warn!(
                        batch = chunk.len(),
                        "section embed_batch returned wrong vector count; counting as failed"
                    );
                    failed += chunk.len();
                    continue;
                }
                Err(e) => {
                    tracing::warn!(error = %e, batch = chunk.len(), "section embed_batch failed; counting as failed");
                    failed += chunk.len();
                    continue;
                }
            };
            let mut writer = sql
                .writer()
                .await
                .map_err(|e| sql_err("section index writer", e))?;
            // One timestamp per chunk: every row written here shares `updated_at`.
            let now = now_us();
            for ((id, _), mut emb) in chunk.iter().zip(embeddings) {
                // Stored unit-length so a dot product of two stored vectors is
                // their cosine similarity.
                unit_normalize(&mut emb);
                if let Err(e) = writer
                    .execute(SqlStatement {
                        sql: "UPDATE knowledge_sections SET embedding = ?1, updated_at = ?2 \
                              WHERE id = ?3"
                            .into(),
                        params: vec![
                            SqlValue::Blob(f32_to_le_bytes(&emb)),
                            SqlValue::Integer(now),
                            SqlValue::Text(id.clone()),
                        ],
                        label: None,
                    })
                    .await
                {
                    tracing::warn!(id = %id, error = %e, "section embedding UPDATE failed; counting as failed");
                    failed += 1;
                } else {
                    indexed += 1;
                }
            }
        }

        if let Some(cb) = on_progress {
            cb(indexed as u64, total);
        }

        // A short page means the result set is exhausted.
        if n < page as usize {
            break;
        }
        // drop_existing: stable result set → paginate by full page.
        // keep-existing: embedded rows leave the `embedding IS NULL` set, so only
        // the rows still NULL after this pass (failed + skipped) need to be
        // stepped over; advancing by exactly that count attempts each row once
        // and guarantees termination.
        offset += if drop_existing {
            n as i64
        } else {
            (skipped - skipped_before + failed - failed_before) as i64
        };
    }

    Ok((indexed, skipped, failed))
}