use serde_json::Value;
use tt_tokenize::estimate_tokens;
use uuid::Uuid;
use crate::embed::EmbeddingClient;
use crate::error::RetrievalError;
use crate::search::top_k;
use crate::store::RetrievalStore;
use crate::tags;
/// Similarity floor handed to `top_k` for a tag that does not carry its own
/// `min_similarity` value.
pub const DEFAULT_MIN_SIMILARITY: f32 = 0.6;
/// Provider name passed to `estimate_tokens` for every token estimate made in
/// this module (embedding-query cost as well as original/replacement sizes).
const EMBEDDING_PROVIDER: &str = "openai";
/// Summary of one `substitute_in_messages` run.
///
/// All token figures are estimates produced by `estimate_tokens`; they are
/// signed so that a net loss can be represented.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct SubstitutionReport {
    /// Number of tag spans that were replaced by retrieved text.
    pub substitutions: u32,
    /// Number of tags left untouched because retrieval returned no hits at or
    /// above the similarity floor.
    pub low_confidence_skips: u32,
    /// Number of tags left untouched because the retrieved text was not
    /// estimated to be smaller than the span it would replace.
    pub size_increase_skips: u32,
    /// Sum of the (positive) per-substitution token reductions.
    pub gross_tokens_saved: i64,
    /// Sum of the estimated token cost of every embedding query issued.
    pub embedding_tokens_cost: i64,
    /// `gross_tokens_saved - embedding_tokens_cost`; may be zero or negative.
    pub tokens_saved_estimate: i64,
}
/// Replaces retrieval tag spans inside chat message contents with retrieved
/// chunks whenever that is estimated to shrink the message.
///
/// For every element of `messages` that has a string `"content"` field
/// containing at least one tag (as reported by `tags::parse`):
///
/// 1. The text with all tag spans removed is embedded once via `embedder`;
///    its estimated token count is added to the embedding cost.
/// 2. For each tag, `top_k` is queried with the tag's corpus, `k`, and
///    similarity floor (the tag's own `min_similarity`, or
///    [`DEFAULT_MIN_SIMILARITY`]), restricted to `embedder.model`.
/// 3. If there are no hits the span is kept and counted as a
///    low-confidence skip. Otherwise the hit texts are joined with
///    `"\n\n---\n\n"`; the span is replaced only when the joined text is
///    estimated to be strictly smaller than the original span, else it is kept
///    and counted as a size-increase skip.
/// 4. The message's `"content"` is overwritten with the rebuilt string.
///
/// Messages without a string `"content"` or without tags are left untouched
/// and incur no embedding call.
///
/// # Errors
///
/// Propagates any error from `tags::parse`, `embedder.embed`, or `top_k`.
/// Messages processed before the failing one keep their rewritten content;
/// the failing message itself is not modified.
///
/// # Panics
///
/// String slicing panics if a tag span is out of bounds, out of order, or
/// not on a UTF-8 character boundary. NOTE(review): this relies on
/// `tags::parse` returning ascending, non-overlapping byte spans — confirm.
pub async fn substitute_in_messages(
    messages: &mut [Value],
    org_id: Uuid,
    store: &dyn RetrievalStore,
    embedder: &EmbeddingClient,
) -> Result<SubstitutionReport, RetrievalError> {
    // Separator inserted between retrieved chunks when several hits replace one span.
    const CHUNK_SEPARATOR: &str = "\n\n---\n\n";

    let mut substitutions = 0u32;
    let mut low_confidence_skips = 0u32;
    let mut size_increase_skips = 0u32;
    let mut gross_saved: i64 = 0;
    let mut embedding_cost: i64 = 0;

    for msg in messages.iter_mut() {
        let Some(content) = msg.get_mut("content") else {
            continue;
        };

        // The rebuilt text is computed in its own scope so that every borrow of
        // the current content ends before the content is overwritten; this
        // avoids copying the whole message text up front.
        let rebuilt = {
            let Some(text) = content.as_str() else {
                continue;
            };
            let parsed = tags::parse(text)?;
            if parsed.is_empty() {
                continue;
            }

            // The embedding query is the message text with every tag span removed.
            let mut without_tags = String::with_capacity(text.len());
            let mut last = 0;
            for t in &parsed {
                without_tags.push_str(&text[last..t.span.0]);
                last = t.span.1;
            }
            without_tags.push_str(&text[last..]);

            // One embedding per message, shared by all of its tags.
            let query_emb = embedder.embed(&without_tags).await?;
            embedding_cost += estimate_tokens(EMBEDDING_PROVIDER, &without_tags) as i64;

            let mut new_text = String::with_capacity(text.len());
            let mut cursor = 0;
            for t in &parsed {
                // Copy the untouched text between the previous tag and this one.
                new_text.push_str(&text[cursor..t.span.0]);
                cursor = t.span.1;

                let original_span = &text[t.span.0..t.span.1];
                let floor = t.min_similarity.unwrap_or(DEFAULT_MIN_SIMILARITY);
                let hits = top_k(
                    store,
                    org_id,
                    &t.corpus,
                    &query_emb,
                    t.k as usize,
                    floor,
                    &embedder.model,
                )
                .await?;

                if hits.is_empty() {
                    // Nothing cleared the floor: keep the span exactly as written.
                    new_text.push_str(original_span);
                    low_confidence_skips += 1;
                    continue;
                }

                let mut replacement = String::new();
                for (idx, hit) in hits.iter().enumerate() {
                    if idx > 0 {
                        replacement.push_str(CHUNK_SEPARATOR);
                    }
                    replacement.push_str(&hit.text);
                }

                let orig_tokens = estimate_tokens(EMBEDDING_PROVIDER, original_span) as i64;
                let repl_tokens = estimate_tokens(EMBEDDING_PROVIDER, &replacement) as i64;
                let delta = orig_tokens - repl_tokens;
                if delta <= 0 {
                    // Substituting would not shrink the message; keep the original.
                    new_text.push_str(original_span);
                    size_increase_skips += 1;
                } else {
                    gross_saved += delta;
                    new_text.push_str(&replacement);
                    substitutions += 1;
                }
            }
            new_text.push_str(&text[cursor..]);
            new_text
        };

        *content = Value::String(rebuilt);
    }

    Ok(SubstitutionReport {
        substitutions,
        low_confidence_skips,
        size_increase_skips,
        gross_tokens_saved: gross_saved,
        embedding_tokens_cost: embedding_cost,
        // Net figure: savings from substitutions minus what the embedding queries cost.
        tokens_saved_estimate: gross_saved - embedding_cost,
    })
}
#[cfg(test)]
mod tests {
use super::*;
use crate::store::memory::MemoryStore;
use crate::types::Chunk;
use httpmock::prelude::*;
use serde_json::json;
/// Registers a `POST /v1/embeddings` mock on `server` that answers 200 with
/// `emb` as its single embedding, and returns an [`EmbeddingClient`] aimed at
/// that server (API key `"k"`, model `"x"`).
async fn mock_embedder(server: &MockServer, emb: Vec<f64>) -> EmbeddingClient {
    let _registered = server
        .mock_async(|when, then| {
            when.method(POST).path("/v1/embeddings");
            then.status(200)
                .json_body(json!({ "data": [{ "embedding": emb }] }));
        })
        .await;

    let base_url = server.base_url();
    let http = reqwest::Client::new();
    EmbeddingClient {
        api_key: "k".into(),
        base_url,
        model: "x".into(),
        http,
    }
}
fn chunk(org: uuid::Uuid, corpus: &str, emb: Vec<f32>, text: &str) -> Chunk {
Chunk {
id: uuid::Uuid::new_v4(),
org_id: org,
corpus: corpus.into(),
doc_id: uuid::Uuid::new_v4(),
chunk_idx: 0,
text: text.into(),
embedding: emb,
embedding_model: "x".into(),
metadata: json!({}),
}
}
/// A stored chunk whose embedding ([0, 1]) points away from the query
/// embedding ([1, 0]) is expected not to clear the default floor, so the
/// message must come back byte-for-byte unchanged.
#[tokio::test]
async fn low_similarity_leaves_payload_intact() {
    let org = Uuid::new_v4();
    let store = MemoryStore::new();
    let unrelated = chunk(org, "docs", vec![0.0, 1.0], "IrrelevantChunk");
    store.insert(unrelated).await.unwrap();

    let server = MockServer::start_async().await;
    let embedder = mock_embedder(&server, vec![1.0, 0.0]).await;

    let original =
        r#"Hello <retrievable corpus="docs" k="1">original payload</retrievable> world"#;
    let mut messages = vec![json!({ "role": "user", "content": original })];

    let outcome = substitute_in_messages(&mut messages, org, &store, &embedder).await;
    let report = outcome.unwrap();

    assert_eq!(report.substitutions, 0);
    assert_eq!(report.low_confidence_skips, 1);
    assert_eq!(report.gross_tokens_saved, 0);
    assert!(
        report.tokens_saved_estimate <= 0,
        "net savings must be <= 0 when nothing was substituted (embedding cost > 0)"
    );

    let after = messages[0]["content"].as_str().unwrap();
    assert_eq!(
        after, original,
        "content must be unchanged when no chunk clears the floor"
    );
}
/// A stored chunk whose embedding equals the query embedding must replace the
/// tagged payload in the message content.
#[tokio::test]
async fn high_similarity_substitutes_payload() {
    let org = Uuid::new_v4();
    let store = MemoryStore::new();
    let matching = chunk(org, "docs", vec![1.0, 0.0], "Retrieved-A");
    store.insert(matching).await.unwrap();

    let server = MockServer::start_async().await;
    let embedder = mock_embedder(&server, vec![1.0, 0.0]).await;

    let prompt =
        r#"Summarize <retrievable corpus="docs" k="1">raw payload</retrievable> please."#;
    let mut messages = vec![json!({
        "role": "user",
        "content": prompt
    })];

    let outcome = substitute_in_messages(&mut messages, org, &store, &embedder).await;
    let report = outcome.unwrap();

    assert_eq!(report.substitutions, 1);
    assert_eq!(report.low_confidence_skips, 0);

    let rewritten = messages[0]["content"].as_str().unwrap();
    assert!(
        rewritten.contains("Retrieved-A"),
        "retrieved chunk must appear in content"
    );
    assert!(
        !rewritten.contains("raw payload"),
        "original payload must be replaced"
    );
}
/// A tag-level `min_similarity="0.8"` must take precedence over the module
/// default: the stored chunk sits at 45° to the query embedding and is
/// expected to be rejected by the stricter per-tag floor.
/// NOTE(review): presumably similarity is cosine (≈0.707 here) — confirm in `top_k`.
#[tokio::test]
async fn per_tag_min_similarity_override() {
    let org = Uuid::new_v4();
    let store = MemoryStore::new();

    // Unit vector halfway between the two axes.
    let component = f32::sqrt(2.0) / 2.0;
    let diagonal = chunk(org, "docs", vec![component, component], "MidChunk");
    store.insert(diagonal).await.unwrap();

    let server = MockServer::start_async().await;
    let embedder = mock_embedder(&server, vec![1.0, 0.0]).await;

    let original =
        r#"Q: <retrievable corpus="docs" k="1" min_similarity="0.8">fallback</retrievable>"#;
    let mut messages = vec![json!({ "role": "user", "content": original })];

    let outcome = substitute_in_messages(&mut messages, org, &store, &embedder).await;
    let report = outcome.unwrap();

    assert_eq!(
        report.substitutions, 0,
        "chunk below per-tag floor must not substitute"
    );
    assert_eq!(report.low_confidence_skips, 1);

    let after = messages[0]["content"].as_str().unwrap();
    assert_eq!(
        after, original,
        "content must be unchanged when per-tag floor is not met"
    );
}
/// With two tags in one message — one whose corpus holds a matching chunk and
/// one whose corpus holds a non-matching chunk — exactly one span is
/// substituted, one is skipped, and gross savings are positive.
#[tokio::test]
async fn tokens_saved_only_for_substituted_spans() {
    let org = Uuid::new_v4();
    let store = MemoryStore::new();
    let fixtures = [
        chunk(org, "good", vec![1.0, 0.0], "Short"),
        chunk(org, "bad", vec![0.0, 1.0], "IrrelevantChunk"),
    ];
    for fixture in fixtures {
        store.insert(fixture).await.unwrap();
    }

    let server = MockServer::start_async().await;
    let embedder = mock_embedder(&server, vec![1.0, 0.0]).await;

    let substituted_part = r#"A <retrievable corpus="good" k="1">a very long original payload text here</retrievable>"#;
    let preserved_part = r#" and B <retrievable corpus="bad" k="1">another long payload that must stay</retrievable>."#;
    let content = [substituted_part, preserved_part].concat();
    let mut messages = vec![json!({
        "role": "user",
        "content": content
    })];

    let outcome = substitute_in_messages(&mut messages, org, &store, &embedder).await;
    let report = outcome.unwrap();

    assert_eq!(report.substitutions, 1);
    assert_eq!(report.low_confidence_skips, 1);
    assert!(
        report.gross_tokens_saved > 0,
        "expected positive gross token savings from substituted span"
    );
}
/// The embedding query cost must be tracked and subtracted from gross
/// savings: a long surrounding sentence around a tiny tag is expected to make
/// the net estimate non-positive.
#[tokio::test]
async fn net_savings_subtracts_embedding_cost() {
    let org = Uuid::new_v4();
    let store = MemoryStore::new();
    let tiny_hit = chunk(org, "docs", vec![1.0, 0.0], "ok");
    store.insert(tiny_hit).await.unwrap();

    let server = MockServer::start_async().await;
    let embedder = mock_embedder(&server, vec![1.0, 0.0]).await;

    let long_query = "This is a fairly long surrounding context sentence to ensure the embedding query has a non-trivial token cost that will exceed any tiny gross savings.";
    let content = format!(r#"{long_query} <retrievable corpus="docs" k="1">x</retrievable>"#);
    let mut messages = vec![json!({ "role": "user", "content": content })];

    let outcome = substitute_in_messages(&mut messages, org, &store, &embedder).await;
    let report = outcome.unwrap();

    let cost = report.embedding_tokens_cost;
    assert!(cost > 0, "embedding cost must be tracked (got {})", cost);

    let net = report.tokens_saved_estimate;
    assert!(
        net <= 0,
        "net savings must be <= 0 when embedding cost dominates (got {})",
        net
    );
}
/// When the best-matching chunk is much longer than the tagged span, the
/// span must be kept, the skip recorded in `size_increase_skips`, and no
/// gross savings reported.
#[tokio::test]
async fn larger_replacement_is_skipped() {
    let big_chunk = "This is a very long retrieved chunk that contains many many tokens and is definitely much larger than the tiny original placeholder text.";

    let org = Uuid::new_v4();
    let store = MemoryStore::new();
    let oversized = chunk(org, "docs", vec![1.0, 0.0], big_chunk);
    store.insert(oversized).await.unwrap();

    let server = MockServer::start_async().await;
    let embedder = mock_embedder(&server, vec![1.0, 0.0]).await;

    let original = r#"Q: <retrievable corpus="docs" k="1">tiny</retrievable>"#;
    let mut messages = vec![json!({ "role": "user", "content": original })];

    let outcome = substitute_in_messages(&mut messages, org, &store, &embedder).await;
    let report = outcome.unwrap();

    assert_eq!(
        report.substitutions, 0,
        "larger replacement must not be counted as a substitution"
    );
    assert_eq!(
        report.size_increase_skips, 1,
        "larger replacement must be counted in size_increase_skips"
    );
    assert_eq!(
        report.gross_tokens_saved, 0,
        "no gross savings when replacement is larger"
    );

    let after = messages[0]["content"].as_str().unwrap();
    assert!(
        after.contains("tiny"),
        "original payload must be preserved when replacement is larger"
    );
    assert!(
        !after.contains(big_chunk),
        "large replacement must not be spliced in"
    );
}
/// Guards against the token estimate silently degrading to the
/// character-count heuristic: for "café" the two estimates must differ, and
/// the tokenizer estimate must be positive.
#[test]
fn estimate_uses_tokenizer_not_char_div_4() {
    let sample = "café";

    // Heuristic first, tokenizer second; the two calls are independent.
    let char_div_4 = tt_tokenize::char_count_estimate(sample);
    let tokenizer_estimate = estimate_tokens(EMBEDDING_PROVIDER, sample);

    assert_ne!(
        tokenizer_estimate, char_div_4,
        "tiktoken estimate ({tokenizer_estimate}) must differ from chars/4 heuristic ({char_div_4}) for \"café\""
    );
    assert!(
        tokenizer_estimate > 0,
        "tokenizer must return > 0 for non-empty text"
    );
}
#[tokio::test]
async fn substitution_replaces_payload_with_top_k_chunks() {
let emb_server = MockServer::start_async().await;
let embedder = mock_embedder(&emb_server, vec![1.0, 0.0]).await;
let store = MemoryStore::new();
let org = Uuid::new_v4();
store
.insert(Chunk {
id: Uuid::new_v4(),
org_id: org,
corpus: "docs".into(),
doc_id: Uuid::new_v4(),
chunk_idx: 0,
text: "Retrieved-A".into(),
embedding: vec![1.0, 0.0],
embedding_model: "x".into(),
metadata: json!({}),
})
.await
.unwrap();
let mut messages = vec![json!({
"role": "user",
"content": "Summarize <retrievable corpus=\"docs\" k=\"1\">raw payload that the LLM never sees</retrievable> for the team."
})];
let report = substitute_in_messages(&mut messages, org, &store, &embedder)
.await
.unwrap();
assert_eq!(report.substitutions, 1);
let new_content = messages[0]["content"].as_str().unwrap();
assert!(new_content.contains("Retrieved-A"));
assert!(!new_content.contains("raw payload"));
}
#[tokio::test]
async fn cross_model_chunk_is_not_retrieved() {
let server = MockServer::start_async().await;
let embedder = mock_embedder(&server, vec![1.0, 0.0]).await; let store = MemoryStore::new();
let org = uuid::Uuid::new_v4();
store
.insert(Chunk {
id: uuid::Uuid::new_v4(),
org_id: org,
corpus: "docs".into(),
doc_id: uuid::Uuid::new_v4(),
chunk_idx: 0,
text: "would-be-retrieved".into(),
embedding: vec![1.0, 0.0],
embedding_model: "other".into(),
metadata: json!({}),
})
.await
.unwrap();
let mut messages = vec![json!({
"role": "user",
"content": r#"<retrievable corpus="docs">original-payload</retrievable>"#
})];
let report = substitute_in_messages(&mut messages, org, &store, &embedder)
.await
.unwrap();
assert_eq!(report.substitutions, 0);
assert_eq!(report.low_confidence_skips, 1);
assert!(messages[0]["content"]
.as_str()
.unwrap()
.contains("original-payload"));
}
}