talon-core 0.4.1

Core retrieval engine for Talon: hybrid search (BM25 + semantic + reranker), indexing, and graph-aware ranking over markdown corpora.
Documentation
#![allow(clippy::unwrap_used, clippy::expect_used)]

use serde_json::json;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};

use super::super::pre_filter::PreFilter;
use super::test_support::{cleanup, dummy_embed_response, insert_note, runtime, unique_db_path};
use super::*;
use crate::expansion::client::ExpansionClient;

use crate::store::open_database;

/// Builds a [`SearchHooks`] in which every callback appends a
/// `(stage, value)` entry to the shared `events` log, so a test can inspect
/// which stages fired and in what order.
///
/// Value recorded per stage:
/// - `strong_signal`: the reported top score multiplied by 1000 and cast to `u128`
/// - `expand_start`: milliseconds elapsed since `started`
/// - `expand_end` / `rerank_end`: the elapsed-milliseconds argument passed to the hook
/// - `embed_batch`: the batch size passed to the hook
/// - `rerank_start`: the candidate count passed to the hook
///
/// # Panics
///
/// A callback panics if the `events` mutex has been poisoned.
fn build_recording_hooks(
    events: &Arc<Mutex<Vec<(&'static str, u128)>>>,
    started: Instant,
) -> SearchHooks {
    /// Appends a single entry to the shared event log.
    fn record(log: &Mutex<Vec<(&'static str, u128)>>, stage: &'static str, value: u128) {
        log.lock().unwrap().push((stage, value));
    }

    // Every boxed callback is `move`, so each one gets its own handle to the log.
    let strong_signal_log = Arc::clone(events);
    let expand_start_log = Arc::clone(events);
    let expand_end_log = Arc::clone(events);
    let embed_batch_log = Arc::clone(events);
    let rerank_start_log = Arc::clone(events);
    let rerank_end_log = Arc::clone(events);

    SearchHooks {
        on_strong_signal: Some(Box::new(move |top_score| {
            // The score is only bucketed for the log; truncation is intended.
            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
            let scaled = (top_score * 1000.0) as u128;
            record(&strong_signal_log, "strong_signal", scaled);
        })),
        on_expand_start: Some(Box::new(move || {
            record(
                &expand_start_log,
                "expand_start",
                started.elapsed().as_millis(),
            );
        })),
        on_expand_end: Some(Box::new(move |elapsed_ms| {
            record(&expand_end_log, "expand_end", u128::from(elapsed_ms));
        })),
        on_embed_batch: Some(Box::new(move |batch_size| {
            record(&embed_batch_log, "embed_batch", batch_size as u128);
        })),
        on_rerank_start: Some(Box::new(move |candidate_count| {
            record(&rerank_start_log, "rerank_start", candidate_count as u128);
        })),
        on_rerank_end: Some(Box::new(move |elapsed_ms| {
            record(&rerank_end_log, "rerank_end", u128::from(elapsed_ms));
        })),
    }
}

/// Runs the hybrid pipeline against mocked `/embed`, `/chat/completions` and
/// `/rerank` endpoints with recording hooks installed, then checks that
/// `expand_end` is logged before `rerank_start` and that the `expand_start`,
/// `embed_batch` and `rerank_end` hooks fired as well.
#[test]
fn hooks_record_expand_before_rerank() {
    let rt = runtime();
    let server = rt.block_on(MockServer::start());

    // Canned replies for the three POST endpoints the pipeline talks to.
    let embed_reply = ResponseTemplate::new(200).set_body_json(dummy_embed_response());
    let expansion_reply = ResponseTemplate::new(200).set_body_json(json!({
        "choices": [{
            "message": {
                "content": "{\"queries\":[\"atomic notes\",\"zettelkasten\"]}"
            }
        }]
    }));
    let rerank_reply = ResponseTemplate::new(200).set_body_json(json!([
        {"index": 0, "score": 0.91}
    ]));

    for (route, reply) in [
        ("/embed", embed_reply),
        ("/chat/completions", expansion_reply),
        ("/rerank", rerank_reply),
    ] {
        rt.block_on(
            Mock::given(method("POST"))
                .and(path(route))
                .respond_with(reply)
                .mount(&server),
        );
    }

    // Seed a throwaway database with two notes for the query to hit.
    let db_path = unique_db_path();
    let conn = open_database(&db_path).unwrap();
    for (file, title, body) in [
        (
            "target.md",
            "Zettelkasten Method",
            "atomic notes for thinking and learning",
        ),
        (
            "related.md",
            "Atomic Notes",
            "small notes connected into a knowledge graph",
        ),
    ] {
        insert_note(&conn, file, title, body);
    }

    let events = Arc::new(Mutex::new(Vec::<(&'static str, u128)>::new()));
    let hooks = build_recording_hooks(&events, Instant::now());

    let (embedding, rerank) = test_support::test_clients(server.uri());
    let expansion = ExpansionClient::new(server.uri(), "test-model").unwrap();
    let opts = HybridPipelineOptions {
        limit: 10,
        candidate_limit: 40,
        fast: false,
        skip_expansion: false,
        queries: vec![],
        intent: None,
        hooks,
        pre_filter: PreFilter::none(),
        deadline_at: None,
    };

    let results = run_hybrid_pipeline(
        &conn,
        &embedding,
        &rerank,
        Some(&expansion),
        "atomic notes",
        &opts,
    );

    assert!(!results.is_empty(), "pipeline must still return results");

    // Only the stage names matter for the ordering checks below.
    let names: Vec<&str> = events
        .lock()
        .unwrap()
        .iter()
        .map(|&(stage, _)| stage)
        .collect();

    // Index of the first occurrence of a stage in the log, if it fired at all.
    let first_index = |stage: &str| names.iter().position(|&seen| seen == stage);
    let expand_end = first_index("expand_end").expect("expand_end should fire");
    let rerank_start = first_index("rerank_start").expect("rerank_start should fire");

    assert!(
        expand_end < rerank_start,
        "rerank_start must fire after expand_end; events={names:?}"
    );

    let remaining_stages = ["expand_start", "embed_batch", "rerank_end"];
    assert!(
        remaining_stages.iter().all(|stage| names.contains(stage)),
        "expected all hook stages to fire; events={names:?}"
    );

    // Close the connection before removing the database file.
    drop(conn);
    cleanup(&db_path);
}