use std::time::Duration;
use futures::StreamExt as _;
use tracing::info_span;
use zeph_common::memory::TokenCounting;
use zeph_common::{ContextFidelity, PlannedToolHint};
use zeph_llm::LlmProviderDyn;
use zeph_llm::provider::{EmbedFuture, Message, MessageMetadata, MessagePart, Role};
use crate::assembler::CORRECTIONS_PREFIX;
pub use zeph_config::FidelityConfig;
/// Embeds every eligible message concurrently and returns the resulting vectors
/// keyed by the message's index in `messages`.
///
/// A message is skipped when it is exempt (see `is_exempt`), has empty content,
/// or already carries an embedding in its metadata. When
/// `config.max_embed_input_tokens` is set, the text sent to `embed` is cut to
/// `4 * max_embed_input_tokens` bytes (rounded down to a char boundary).
///
/// Each embed call is bounded by a 30 second timeout; failures and timeouts are
/// logged and simply leave that index out of the returned map. At most
/// `config.embed_concurrency` calls are in flight at once (a value of `0` is
/// clamped to `1` with a warning).
pub async fn embed_prepass<F>(
    messages: &[Message],
    embed: &F,
    config: &FidelityConfig,
    inserted_count: usize,
) -> std::collections::HashMap<usize, Vec<f32>>
where
    F: Fn(&str) -> EmbedFuture + Send + Sync,
{
    let concurrency = match config.embed_concurrency {
        0 => {
            tracing::warn!(
                "embed_concurrency is 0, clamping to 1; set a positive value in [context.fidelity]"
            );
            1
        }
        limit => limit,
    };

    // Lazily yields `(index, text_to_embed)` for every message that still needs a vector.
    let pending = messages.iter().enumerate().filter_map(|(slot, msg)| {
        let needs_embedding = !is_exempt(msg, slot, inserted_count)
            && !msg.content.is_empty()
            && msg.metadata.embedding.is_none();
        if !needs_embedding {
            return None;
        }
        let text = config.max_embed_input_tokens.map_or_else(
            || msg.content.clone(),
            |cap| truncate_to_byte_limit(&msg.content, cap.saturating_mul(4)),
        );
        Some((slot, text))
    });

    futures::stream::iter(pending)
        .map(|(slot, text)| async move {
            match tokio::time::timeout(Duration::from_secs(30), embed(&text)).await {
                Ok(Ok(vector)) => Some((slot, vector)),
                Ok(Err(e)) => {
                    tracing::debug!(idx = slot, err = %e, "embed_prepass: embed failed, skipping");
                    None
                }
                Err(_) => {
                    tracing::warn!(idx = slot, "embed_prepass: embed timed out, skipping");
                    None
                }
            }
        })
        .buffer_unordered(concurrency)
        .filter_map(|outcome| async move { outcome })
        .collect()
        .await
}
/// Returns an owned copy of `s` limited to at most `max_bytes` bytes.
///
/// When `s` is longer than the limit the cut is moved down to the nearest
/// UTF-8 char boundary, so the result is always valid UTF-8 and never longer
/// than `max_bytes`.
fn truncate_to_byte_limit(s: &str, max_bytes: usize) -> String {
    let end = if s.len() > max_bytes {
        s.floor_char_boundary(max_bytes)
    } else {
        s.len()
    };
    s[..end].to_owned()
}
/// Per-message scoring result produced by `compute_scores` and consumed by
/// `apply_scores`.
struct FidelityScore {
    /// Weighted relevance score; `compute_scores` normalizes it by the active
    /// weight sum and clamps it to `[0.0, 1.0]`.
    score: f32,
    /// Fidelity level the message should be rendered at.
    level: ContextFidelity,
    /// Token count of the message content measured before any rewriting.
    /// Entries created by tool-pair atomicity for a previously unscored
    /// message carry `0` here.
    original_tokens: u32,
}
pub struct FidelityScorer;
impl FidelityScorer {
#[allow(clippy::too_many_arguments)]
pub async fn score_and_apply(
&self,
messages: &mut Vec<Message>,
query: &str,
planned_tools: &[PlannedToolHint],
config: &FidelityConfig,
tc: &dyn TokenCounting,
inserted_count: usize,
allow_upgrade: bool,
embed_provider: Option<&dyn LlmProviderDyn>,
compress_provider: Option<&dyn LlmProviderDyn>,
) {
if !config.enabled || messages.is_empty() {
return;
}
let query_embedding: Option<Vec<f32>> = if let (true, Some(p)) =
(config.semantic_scoring_provider.is_some(), embed_provider)
&& p.supports_embeddings()
{
let _span = info_span!("context.fidelity.embed_query").entered();
match tokio::time::timeout(Duration::from_secs(30), p.embed(query)).await {
Ok(Ok(v)) => Some(v),
Ok(Err(e)) => {
tracing::warn!(error = %e, "semantic scoring provider unavailable, falling back to keyword");
None
}
Err(_) => {
tracing::warn!("fidelity query embed timed out, falling back to keyword");
None
}
}
} else {
None
};
if let (Some(q_emb), Some(p)) = (&query_embedding, embed_provider) {
let n = messages.len();
let score_end = if n > config.max_scored_messages {
n.saturating_sub(config.exempt_tail_messages)
} else {
n
};
let concurrency = if config.embed_concurrency == 0 {
1
} else {
config.embed_concurrency
};
let _span = info_span!("context.fidelity.embed_prepass").entered();
let embeddings: std::collections::HashMap<usize, Vec<f32>> =
futures::stream::iter(messages[..score_end].iter().enumerate().filter_map(
|(i, msg)| {
if msg.metadata.embedding.is_none()
&& !is_exempt(msg, i, inserted_count)
&& !msg.content.is_empty()
{
let content = match config.max_embed_input_tokens {
Some(n) => {
truncate_to_byte_limit(&msg.content, n.saturating_mul(4))
}
None => msg.content.clone(),
};
Some((i, content))
} else {
None
}
},
))
.map(|(i, content)| async move {
let result =
tokio::time::timeout(Duration::from_secs(30), p.embed(&content)).await;
match result {
Ok(Ok(v)) => Some((i, v)),
Ok(Err(e)) => {
tracing::warn!(error = %e, "message embed failed, skipping");
None
}
Err(_) => {
tracing::warn!(idx = i, "fidelity message embed timed out, skipping");
None
}
}
})
.buffer_unordered(concurrency)
.filter_map(|opt| async move { opt })
.collect()
.await;
for (i, emb) in embeddings {
messages[i].metadata.embedding = Some(emb);
}
let _ = q_emb; }
let scores = compute_scores(
messages,
query,
planned_tools,
config,
tc,
inserted_count,
allow_upgrade,
query_embedding.as_deref(),
);
apply_scores(messages, &scores, config, tc, compress_provider).await;
let _merge_span = info_span!("context.fidelity.merge").entered();
let merged_count = merge_consecutive_placeholders(messages);
tracing::debug!(merged_count, "fidelity merge complete");
}
}
/// Computes a fidelity score and target level for every scorable message.
///
/// The returned vector has one slot per message. A slot stays `None` for
/// messages that are exempt (`is_exempt`) or that lie in the unscored tail
/// (the last `config.exempt_tail_messages` entries, only when the window is
/// larger than `config.max_scored_messages`), unless tool-pair atomicity
/// fills it afterwards.
///
/// The score is a weighted mean of up to four signals, clamped to `[0, 1]`:
/// * temporal – position inside the scored range (newest = 1.0),
/// * importance – `0.4` for messages carrying a tool result, otherwise the role weight,
/// * semantic – only when `query` has at least `config.min_query_length`
///   bytes; cosine similarity when both embeddings exist, keyword overlap otherwise,
/// * plan – only when `planned_tools` is non-empty.
///
/// Inactive signals are excluded from the normalizing weight sum. With
/// `allow_upgrade == false` an existing `fidelity_tag` is a floor: a
/// placeholder stays a placeholder and a compressed message is never raised to
/// full.
#[allow(clippy::too_many_arguments)]
fn compute_scores(
    messages: &[Message],
    query: &str,
    planned_tools: &[PlannedToolHint],
    config: &FidelityConfig,
    tc: &dyn TokenCounting,
    inserted_count: usize,
    allow_upgrade: bool,
    query_embedding: Option<&[f32]>,
) -> Vec<Option<FidelityScore>> {
    let total = messages.len();
    let score_end = if total > config.max_scored_messages {
        total.saturating_sub(config.exempt_tail_messages)
    } else {
        total
    };

    let semantic_active = query.len() >= config.min_query_length;
    let plan_active = !planned_tools.is_empty();

    let query_words: std::collections::HashSet<&str> = if semantic_active {
        query.split_whitespace().collect()
    } else {
        std::collections::HashSet::default()
    };

    // Normalizer: only the weights of signals that actually contribute.
    let mut weight_sum = config.w_temporal + config.w_importance;
    if semantic_active {
        weight_sum += config.w_semantic;
    }
    if plan_active {
        weight_sum += config.w_plan;
    }
    if weight_sum <= 0.0 {
        // Avoid dividing by zero (or a negative sum) when all weights are disabled.
        weight_sum = 1.0;
    }

    #[allow(clippy::cast_precision_loss)]
    let max_dist = score_end.saturating_sub(1) as f32;

    let mut scores: Vec<Option<FidelityScore>> =
        std::iter::repeat_with(|| None).take(total).collect();

    for (i, msg) in messages[..score_end].iter().enumerate() {
        if is_exempt(msg, i, inserted_count) {
            continue;
        }

        #[allow(clippy::cast_possible_truncation)]
        let original_tokens = tc.count_tokens(&msg.content) as u32;

        #[allow(clippy::cast_precision_loss)]
        let temporal = if max_dist > 0.0 {
            1.0 - (score_end - 1 - i) as f32 / max_dist
        } else {
            1.0
        };

        let has_tool_result = msg
            .parts
            .iter()
            .any(|p| matches!(p, MessagePart::ToolResult { .. }));
        let importance = if has_tool_result {
            0.4
        } else {
            role_weight(msg.role)
        };

        let semantic = match (
            semantic_active,
            query_embedding,
            msg.metadata.embedding.as_deref(),
        ) {
            (false, _, _) => 0.0,
            (true, Some(q_emb), Some(m_emb)) => semantic_overlap(m_emb, q_emb),
            (true, _, _) => keyword_overlap(&msg.content, &query_words),
        };

        let plan = if plan_active {
            plan_relevance(&msg.content, planned_tools)
        } else {
            0.0
        };

        let semantic_term = if semantic_active {
            config.w_semantic * semantic
        } else {
            0.0
        };
        let plan_term = if plan_active {
            config.w_plan * plan
        } else {
            0.0
        };
        let raw = config.w_temporal * temporal
            + config.w_importance * importance
            + semantic_term
            + plan_term;
        let score = (raw / weight_sum).clamp(0.0, 1.0);

        let candidate_level = score_to_level(score, config);
        // Without `allow_upgrade`, the existing tag caps how high the level may go.
        let level = match (allow_upgrade, msg.metadata.fidelity_tag) {
            (false, Some(ContextFidelity::Placeholder)) => ContextFidelity::Placeholder,
            (false, Some(ContextFidelity::Compressed))
                if candidate_level == ContextFidelity::Full =>
            {
                ContextFidelity::Compressed
            }
            _ => candidate_level,
        };

        scores[i] = Some(FidelityScore {
            score,
            level,
            original_tokens,
        });
    }

    apply_tool_pair_atomicity(messages, &mut scores, config);
    scores
}
/// Rewrites each scored message to its target fidelity level.
///
/// * `Compressed` – delegates to `render_compressed` and records how many
///   tokens the rewrite saved.
/// * `Placeholder` – delegates to `render_placeholder`.
/// * anything else – leaves the content alone and tags the message `Full`.
///
/// Messages whose score slot is `None` are not touched. `scores` is expected
/// to have one slot per message. The per-level counters and the saved-token
/// total are only used for the final debug log line.
async fn apply_scores(
    messages: &mut [Message],
    scores: &[Option<FidelityScore>],
    config: &FidelityConfig,
    tc: &dyn TokenCounting,
    provider: Option<&dyn LlmProviderDyn>,
) {
    debug_assert_eq!(
        messages.len(),
        scores.len(),
        "apply_scores needs exactly one score slot per message"
    );

    // The span is attached to the future rather than entered: an entered guard
    // held across the `.await` below would stay active while this task is
    // suspended and mis-attribute unrelated work to the span.
    let span = info_span!("context.fidelity.apply");
    tracing::Instrument::instrument(
        async move {
            let (mut full_count, mut compressed_count, mut placeholder_count, mut tokens_saved) =
                (0u32, 0u32, 0u32, 0u32);

            for (msg, slot) in messages.iter_mut().zip(scores) {
                let Some(fs) = slot else { continue };
                match fs.level {
                    ContextFidelity::Compressed => {
                        render_compressed(msg, config, tc, provider).await;
                        // Clamp instead of wrapping if the count does not fit in u32.
                        let new_tokens =
                            u32::try_from(tc.count_tokens(&msg.content)).unwrap_or(u32::MAX);
                        tokens_saved = tokens_saved
                            .saturating_add(fs.original_tokens.saturating_sub(new_tokens));
                        compressed_count += 1;
                    }
                    ContextFidelity::Placeholder => {
                        render_placeholder(msg, fs.score, fs.original_tokens);
                        placeholder_count += 1;
                    }
                    _ => {
                        msg.metadata.fidelity_tag = Some(ContextFidelity::Full);
                        full_count += 1;
                    }
                }
            }

            tracing::debug!(
                full_count,
                compressed_count,
                placeholder_count,
                tokens_saved,
                "fidelity apply complete"
            );
        },
        span,
    )
    .await;
}
/// Reports whether the message at `idx` must be left out of fidelity scoring.
///
/// A message is exempt when any of the following holds:
/// * it is a system message at index 0,
/// * its metadata marks it as focus-pinned,
/// * its content starts with `CORRECTIONS_PREFIX`,
/// * it is one of the `inserted_count` messages directly following index 0.
fn is_exempt(msg: &Message, idx: usize, inserted_count: usize) -> bool {
    let leading_system = idx == 0 && msg.role == Role::System;
    if leading_system || msg.metadata.focus_pinned {
        return true;
    }
    if msg.content.starts_with(CORRECTIONS_PREFIX) {
        return true;
    }
    idx >= 1 && idx < 1 + inserted_count
}
/// Base importance assigned to a message purely from its role:
/// assistant `0.6`, user `0.8`, system `1.0`.
fn role_weight(role: Role) -> f32 {
    match role {
        Role::Assistant => 0.6,
        Role::User => 0.8,
        Role::System => 1.0,
    }
}
/// Cosine similarity of two vectors, clamped to `[0.0, 1.0]`.
///
/// Returns `0.0` when the vectors differ in length, are empty, or either has a
/// zero norm. Negative similarities are clamped to `0.0`.
///
/// A NaN result (NaN components, or an `inf / inf` after overflow) is also
/// reported as `0.0`. `f32::clamp` passes NaN through unchanged, and a NaN
/// similarity would otherwise poison the whole message score.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    let similarity = dot / (norm_a * norm_b);
    if similarity.is_nan() {
        return 0.0;
    }
    similarity.clamp(0.0, 1.0)
}
/// Semantic relevance of a message to the query: the cosine similarity of the
/// two embeddings as computed by `cosine_similarity`.
fn semantic_overlap(msg_embedding: &[f32], query_embedding: &[f32]) -> f32 {
    cosine_similarity(msg_embedding, query_embedding)
}
/// Fraction of shared whitespace-separated words between `content` and
/// `query_words`, relative to the smaller of the two distinct-word sets.
///
/// Matching is exact (case-sensitive). Returns `0.0` when either side has no
/// words; the result is always within `[0.0, 1.0]`.
fn keyword_overlap(content: &str, query_words: &std::collections::HashSet<&str>) -> f32 {
    let content_words: std::collections::HashSet<&str> = content.split_whitespace().collect();
    let denominator = content_words.len().min(query_words.len());
    if denominator == 0 {
        return 0.0;
    }
    let shared = content_words
        .iter()
        .filter(|word| query_words.contains(**word))
        .count();
    #[allow(clippy::cast_precision_loss)]
    let ratio = shared as f32 / denominator as f32;
    ratio.clamp(0.0, 1.0)
}
/// Relevance of `content` to the planned tool calls, in `[0.0, 1.0]`.
///
/// Each hint contributes its keyword overlap with `content` (shared words
/// divided by the smaller distinct-word set), weighted by
/// `1 / max(distance_from_current, 1)` so nearer steps count more. A hint with
/// no comparable words contributes zero overlap but still counts toward the
/// total weight. Returns `0.0` when there are no hints.
fn plan_relevance(content: &str, planned_tools: &[PlannedToolHint]) -> f32 {
    if planned_tools.is_empty() {
        return 0.0;
    }
    let content_words: std::collections::HashSet<&str> = content.split_whitespace().collect();

    let (weighted_sum, weight_total) = planned_tools.iter().fold(
        (0.0f32, 0.0f32),
        |(overlap_acc, weight_acc), hint| {
            let weight = 1.0 / f32::from(hint.distance_from_current.max(1));
            let weight_acc = weight_acc + weight;

            let hint_words: std::collections::HashSet<&str> =
                hint.keywords.iter().map(String::as_str).collect();
            let denominator = content_words.len().min(hint_words.len());
            if denominator == 0 {
                return (overlap_acc, weight_acc);
            }

            #[allow(clippy::cast_precision_loss)]
            let overlap =
                content_words.intersection(&hint_words).count() as f32 / denominator as f32;
            (overlap_acc + weight * overlap.clamp(0.0, 1.0), weight_acc)
        },
    );

    if weight_total <= 0.0 {
        return 0.0;
    }
    (weighted_sum / weight_total).clamp(0.0, 1.0)
}
fn apply_tool_pair_atomicity(
messages: &[Message],
scores: &mut [Option<FidelityScore>],
config: &FidelityConfig,
) {
let mut tool_result_map: std::collections::HashMap<&str, usize> =
std::collections::HashMap::new();
for (i, msg) in messages.iter().enumerate() {
for part in &msg.parts {
if let MessagePart::ToolResult { tool_use_id, .. } = part {
tool_result_map.insert(tool_use_id.as_str(), i);
}
}
}
for (i, msg) in messages.iter().enumerate().rev() {
for part in &msg.parts {
if let MessagePart::ToolUse { id, .. } = part
&& let Some(&result_idx) = tool_result_map.get(id.as_str())
{
let score_a = scores[i].as_ref().map_or(1.0, |s| s.score);
let score_b = scores[result_idx].as_ref().map_or(1.0, |s| s.score);
let min_score = score_a.min(score_b);
let level_a = scores[i]
.as_ref()
.map_or(ContextFidelity::Full, |s| s.level);
let level_b = scores[result_idx]
.as_ref()
.map_or(ContextFidelity::Full, |s| s.level);
let float_level = score_to_level(min_score, config);
let min_level = more_restrictive(more_restrictive(level_a, level_b), float_level);
let tokens_a = scores[i].as_ref().map_or(0, |s| s.original_tokens);
let tokens_b = scores[result_idx].as_ref().map_or(0, |s| s.original_tokens);
scores[i] = Some(FidelityScore {
score: min_score,
level: min_level,
original_tokens: tokens_a,
});
scores[result_idx] = Some(FidelityScore {
score: min_score,
level: min_level,
original_tokens: tokens_b,
});
}
}
}
}
fn more_restrictive(a: ContextFidelity, b: ContextFidelity) -> ContextFidelity {
use ContextFidelity::{Compressed, Full, Placeholder};
match (a, b) {
(Placeholder, _) | (_, Placeholder) => Placeholder,
(Compressed, _) | (_, Compressed) => Compressed,
_ => Full,
}
}
/// Maps a score to a fidelity level using the configured thresholds:
/// at or above `full_threshold` is `Full`, at or above `compressed_threshold`
/// is `Compressed`, and anything lower is `Placeholder`.
fn score_to_level(score: f32, config: &FidelityConfig) -> ContextFidelity {
    if score >= config.full_threshold {
        return ContextFidelity::Full;
    }
    if score >= config.compressed_threshold {
        return ContextFidelity::Compressed;
    }
    ContextFidelity::Placeholder
}
/// Rewrites `msg` into its compressed form and tags it `Compressed`.
///
/// The compressed text is chosen from the first source that applies:
/// 1. a cached `deferred_summary` in the metadata (consumed by this call);
/// 2. an LLM summary, when `config.compress_provider` is set, a provider was
///    passed in, and the content exceeds twice `compressed_max_tokens`. The
///    input is first capped to `max_compress_input_tokens` (if set), the call
///    is limited to 30 s, and a successful summary is also stored back as
///    `deferred_summary`. On error or timeout the (possibly capped) content is
///    kept and only truncated below;
/// 3. otherwise the content itself, capped to `max_compress_input_tokens` when
///    no compress provider is available.
///
/// In every case the content is finally truncated to `compressed_max_tokens`
/// and all message parts are dropped.
async fn render_compressed(
    msg: &mut Message,
    config: &FidelityConfig,
    tc: &dyn TokenCounting,
    provider: Option<&dyn LlmProviderDyn>,
) {
    if let Some(summary) = msg.metadata.deferred_summary.take() {
        msg.content = summary;
    } else if let (true, Some(p)) = (config.compress_provider.is_some(), provider) {
        let input_tokens = tc.count_tokens(&msg.content);
        // Summarizing only pays off for content well above the target size.
        // `saturating_mul` keeps an extreme config value from overflowing.
        if input_tokens > config.compressed_max_tokens.saturating_mul(2) {
            if let Some(max_in) = config.max_compress_input_tokens {
                apply_input_cap(&mut msg.content, max_in);
            }
            let prompt = format!(
                "Summarize in {} tokens or fewer: {}",
                config.compressed_max_tokens, msg.content
            );
            let req = vec![Message {
                role: Role::User,
                content: prompt,
                parts: vec![],
                metadata: MessageMetadata::default(),
            }];
            let span = info_span!(
                "context.fidelity.compress_llm",
                input_tokens,
                cached = false,
            );
            // Attach the span to the future instead of holding an `enter()`
            // guard across the await, which would leave the span entered while
            // this task is suspended.
            let result = tracing::Instrument::instrument(
                tokio::time::timeout(Duration::from_secs(30), p.chat(&req)),
                span,
            )
            .await;
            match result {
                Ok(Ok(summary)) => {
                    msg.metadata.deferred_summary = Some(summary.clone());
                    msg.content = summary;
                }
                Ok(Err(e)) => {
                    tracing::debug!(error = %e, "compress_llm failed, falling back to truncation");
                }
                Err(_) => {
                    tracing::warn!("compress_llm timed out, falling back to truncation");
                }
            }
        }
    } else if let Some(max_in) = config.max_compress_input_tokens {
        apply_input_cap(&mut msg.content, max_in);
    }

    truncate_to_tokens(&mut msg.content, config.compressed_max_tokens, tc);
    msg.parts.clear();
    msg.metadata.fidelity_tag = Some(ContextFidelity::Compressed);
}
/// Truncates `content` in place to roughly `max_tokens` tokens, using a fixed
/// estimate of four bytes per token.
///
/// Content that already fits within `4 * max_tokens` bytes is left untouched;
/// otherwise it is cut at the nearest char boundary at or below that size.
pub fn apply_input_cap(content: &mut String, max_tokens: usize) {
    let byte_budget = max_tokens.saturating_mul(4);
    if content.len() <= byte_budget {
        return;
    }
    let cut = content.floor_char_boundary(byte_budget);
    content.truncate(cut);
}
/// Truncates `content` in place to the longest prefix that `tc` counts as at
/// most `max_tokens` tokens, always cutting on a UTF-8 char boundary.
///
/// Content that already fits is left untouched. The search is a bisection over
/// byte offsets and therefore assumes the token count never decreases as the
/// prefix grows.
fn truncate_to_tokens(content: &mut String, max_tokens: usize, tc: &dyn TokenCounting) {
    if tc.count_tokens(content) <= max_tokens {
        return;
    }
    // Invariant: the prefix ending at `lo` fits, the prefix ending at `hi`
    // does not; both offsets are char boundaries.
    let mut lo: usize = 0;
    let mut hi: usize = content.len();
    while hi - lo > 1 {
        let probe = lo + (hi - lo) / 2;

        // Snap the probe down to a char boundary. `lo` is a boundary, so this
        // stops at `lo` at the latest.
        let mut mid = probe;
        while !content.is_char_boundary(mid) {
            mid -= 1;
        }
        if mid == lo {
            // The probe fell inside the character that starts at `lo`. Giving
            // up here would throw away every boundary between the probe and
            // `hi`, so snap upwards instead (`hi` is a boundary, so this stops
            // there at the latest).
            mid = probe;
            while !content.is_char_boundary(mid) {
                mid += 1;
            }
            if mid == hi {
                // No char boundary lies strictly between `lo` and `hi`.
                break;
            }
        }

        if tc.count_tokens(&content[..mid]) <= max_tokens {
            lo = mid;
        } else {
            hi = mid;
        }
    }
    content.truncate(lo);
}
/// Replaces `msg` with a one-line placeholder and tags it `Placeholder`.
///
/// The placeholder text records the role, the original token count and the
/// score (two decimals). All message parts are dropped.
fn render_placeholder(msg: &mut Message, score: f32, original_tokens: u32) {
    let role_str = match msg.role {
        Role::Assistant => "assistant",
        Role::User => "user",
        Role::System => "system",
    };
    msg.parts.clear();
    msg.metadata.fidelity_tag = Some(ContextFidelity::Placeholder);
    msg.content = format!(
        "[placeholder: role={role_str}, original_tokens={original_tokens}, importance={score:.2}]"
    );
}
/// Collapses each run of two or more adjacent placeholder messages that share
/// a role into a single summary placeholder, and returns how many messages
/// were removed.
///
/// System messages are never merged. The surviving message is the first of
/// the run: it keeps its role and metadata, loses its parts, and its content
/// becomes a summary with the run length, the summed token counts and the mean
/// importance parsed from the individual placeholder texts.
fn merge_consecutive_placeholders(messages: &mut Vec<Message>) -> usize {
    let mut merged_count = 0usize;
    let mut i = 0;
    while i < messages.len() {
        let role = messages[i].role;
        if messages[i].metadata.fidelity_tag != Some(ContextFidelity::Placeholder)
            || role == Role::System
        {
            i += 1;
            continue;
        }

        // Length of the run starting at `i`; at least 1 because `messages[i]`
        // itself matches.
        let count = messages[i..]
            .iter()
            .take_while(|m| {
                m.metadata.fidelity_tag == Some(ContextFidelity::Placeholder) && m.role == role
            })
            .count();
        if count <= 1 {
            i += 1;
            continue;
        }
        let j = i + count;

        // Saturating: the counts are parsed from text, so never let a sum wrap.
        let (total_tokens, importance_sum) =
            messages[i..j]
                .iter()
                .fold((0u32, 0.0f32), |(tokens, importance), m| {
                    (
                        tokens.saturating_add(parse_placeholder_tokens(&m.content)),
                        importance + parse_placeholder_importance(&m.content),
                    )
                });
        #[allow(clippy::cast_precision_loss)]
        let avg_importance = importance_sum / count as f32;

        let role_str = match role {
            Role::System => "system",
            Role::User => "user",
            Role::Assistant => "assistant",
        };

        // Rewrite the head of the run in place and drop the rest: no clone of
        // the first message and a single shift of the tail.
        let head = &mut messages[i];
        head.content = format!(
            "[placeholder: {count} messages, role={role_str}, total_tokens={total_tokens}, avg_importance={avg_importance:.2}]"
        );
        head.parts = Vec::new();
        head.metadata.fidelity_tag = Some(ContextFidelity::Placeholder);
        messages.drain(i + 1..j);

        merged_count += count - 1;
        i += 1;
    }
    merged_count
}
/// Extracts the token count from a placeholder text.
///
/// Scans the comma-separated fields for the first `original_tokens=` or
/// `total_tokens=` entry whose value (ignoring trailing `]` and whitespace)
/// parses as `u32`. Returns `0` when no such field exists.
fn parse_placeholder_tokens(content: &str) -> u32 {
    content
        .split(',')
        .map(str::trim)
        .find_map(|field| {
            ["original_tokens=", "total_tokens="].iter().find_map(|key| {
                field
                    .strip_prefix(*key)?
                    .trim_end_matches(']')
                    .trim()
                    .parse::<u32>()
                    .ok()
            })
        })
        .unwrap_or(0)
}
/// Extracts the importance value from a placeholder text.
///
/// Scans the comma-separated fields for the first `importance=` or
/// `avg_importance=` entry whose value (ignoring trailing `]` and whitespace)
/// parses as `f32`. Returns `0.0` when no such field exists.
fn parse_placeholder_importance(content: &str) -> f32 {
    content
        .split(',')
        .map(str::trim)
        .find_map(|field| {
            ["importance=", "avg_importance="].iter().find_map(|key| {
                field
                    .strip_prefix(*key)?
                    .trim_end_matches(']')
                    .trim()
                    .parse::<f32>()
                    .ok()
            })
        })
        .unwrap_or(0.0)
}
#[cfg(test)]
mod tests {
use super::*;
use zeph_llm::provider::{Message, MessageMetadata, MessagePart, Role};
struct FixedTc(usize);
impl TokenCounting for FixedTc {
fn count_tokens(&self, text: &str) -> usize {
text.len() / self.0.max(1)
}
fn count_tool_schema_tokens(&self, _schema: &serde_json::Value) -> usize {
0
}
}
fn make_msg(role: Role, content: &str) -> Message {
Message {
role,
content: content.to_string(),
parts: vec![],
metadata: MessageMetadata::default(),
}
}
fn make_cfg() -> FidelityConfig {
FidelityConfig {
enabled: true,
w_semantic: 0.3,
w_temporal: 0.3,
w_importance: 0.2,
w_plan: 0.2,
full_threshold: 0.7,
compressed_threshold: 0.3,
compressed_max_tokens: 50,
regrade_threshold: 0.6,
min_query_length: 8,
max_scored_messages: 500,
exempt_tail_messages: 0,
compress_provider: None,
semantic_scoring_provider: None,
lookahead_depth: 3,
embed_concurrency: 32,
max_embed_input_tokens: None,
max_compress_input_tokens: None,
}
}
#[tokio::test]
async fn empty_window_no_change() {
let scorer = FidelityScorer;
let cfg = make_cfg();
let tc = FixedTc(4);
let mut messages: Vec<Message> = vec![];
scorer
.score_and_apply(
&mut messages,
"query text",
&[],
&cfg,
&tc,
0,
false,
None,
None,
)
.await;
assert!(messages.is_empty());
}
#[tokio::test]
async fn all_exempt_no_downgrade() {
let scorer = FidelityScorer;
let cfg = make_cfg();
let tc = FixedTc(4);
let mut messages = vec![
make_msg(Role::System, "system prompt"),
make_msg(Role::User, "memory context"),
];
scorer
.score_and_apply(&mut messages, "short", &[], &cfg, &tc, 1, false, None, None)
.await;
for msg in &messages {
assert!(
msg.metadata.fidelity_tag.is_none()
|| msg.metadata.fidelity_tag == Some(ContextFidelity::Full)
);
}
}
#[tokio::test]
async fn tool_pair_atomicity() {
let scorer = FidelityScorer;
let cfg = FidelityConfig {
full_threshold: 0.9,
compressed_threshold: 0.5,
..make_cfg()
};
let tc = FixedTc(4);
let tool_use_id = "abc123".to_string();
let mut tool_use_msg = make_msg(Role::Assistant, "calling tool");
tool_use_msg.parts = vec![MessagePart::ToolUse {
id: tool_use_id.clone(),
name: "shell".to_string(),
input: serde_json::json!({}),
}];
let mut tool_result_msg = make_msg(Role::User, "tool result body");
tool_result_msg.parts = vec![MessagePart::ToolResult {
tool_use_id: tool_use_id.clone(),
content: "result".to_string(),
is_error: false,
}];
let mut messages = vec![
make_msg(Role::System, "system"),
tool_use_msg,
tool_result_msg,
];
scorer
.score_and_apply(
&mut messages,
"completely unrelated query blah",
&[],
&cfg,
&tc,
0,
false,
None,
None,
)
.await;
let tag_a = messages[1].metadata.fidelity_tag;
let tag_b = messages[2].metadata.fidelity_tag;
assert_eq!(tag_a, tag_b, "tool pair must share fidelity level");
}
#[tokio::test]
async fn same_role_placeholder_merge() {
let scorer = FidelityScorer;
let cfg = FidelityConfig {
full_threshold: 2.0, compressed_threshold: 1.5, ..make_cfg()
};
let tc = FixedTc(4);
let mut messages: Vec<Message> = std::iter::once(make_msg(Role::System, "system"))
.chain((0..5).map(|i| make_msg(Role::Assistant, &format!("msg {i}"))))
.collect();
scorer
.score_and_apply(
&mut messages,
"some query here",
&[],
&cfg,
&tc,
0,
false,
None,
None,
)
.await;
assert_eq!(
messages.len(),
2,
"5 assistant placeholders must merge to 1"
);
assert!(messages[1].content.contains("5 messages"));
}
#[tokio::test]
async fn score_normalization_no_panic() {
let scorer = FidelityScorer;
let cfg = make_cfg();
let tc = FixedTc(4);
let mut messages = vec![
make_msg(Role::System, "system"),
make_msg(Role::User, "hello"),
make_msg(Role::Assistant, "world response"),
];
scorer
.score_and_apply(
&mut messages,
"hello world signal",
&[],
&cfg,
&tc,
0,
false,
None,
None,
)
.await;
for msg in &messages {
let _ = msg.metadata.fidelity_tag;
}
}
#[tokio::test]
async fn short_query_fallback() {
let scorer = FidelityScorer;
let cfg = FidelityConfig {
min_query_length: 8,
..make_cfg()
};
let tc = FixedTc(4);
let mut messages = vec![
make_msg(Role::System, "system"),
make_msg(Role::User, "test"),
];
scorer
.score_and_apply(&mut messages, "short", &[], &cfg, &tc, 0, false, None, None)
.await;
}
#[tokio::test]
async fn memory_first_bypass_is_callers_responsibility() {
let scorer = FidelityScorer;
let cfg = FidelityConfig {
enabled: false,
..make_cfg()
};
let tc = FixedTc(4);
let mut messages = vec![
make_msg(Role::System, "system prompt"),
make_msg(Role::User, "memory-injected context"),
make_msg(Role::Assistant, "response"),
];
let before: Vec<_> = messages.iter().map(|m| m.content.clone()).collect();
scorer
.score_and_apply(
&mut messages,
"some user query text here",
&[],
&cfg,
&tc,
2,
false,
None,
None,
)
.await;
for (msg, orig) in messages.iter().zip(&before) {
assert_eq!(msg.content, *orig, "content must be unchanged");
assert!(
msg.metadata.fidelity_tag.is_none(),
"no fidelity tag must be set"
);
}
}
#[tokio::test]
async fn enabled_false_guard() {
let scorer = FidelityScorer;
let cfg = FidelityConfig {
enabled: false,
..make_cfg()
};
let tc = FixedTc(4);
let mut messages = vec![
make_msg(Role::System, "system"),
make_msg(Role::User, "user message that would normally be scored"),
];
let original_contents: Vec<String> = messages.iter().map(|m| m.content.clone()).collect();
scorer
.score_and_apply(
&mut messages,
"query text here",
&[],
&cfg,
&tc,
0,
false,
None,
None,
)
.await;
for (msg, orig) in messages.iter().zip(&original_contents) {
assert_eq!(msg.content, *orig);
assert!(msg.metadata.fidelity_tag.is_none());
}
}
#[tokio::test]
async fn score_always_in_range() {
let scorer = FidelityScorer;
let cfg = FidelityConfig {
enabled: true,
w_semantic: 0.0,
w_temporal: 0.0,
w_importance: 0.0,
w_plan: 0.0,
full_threshold: 0.7,
compressed_threshold: 0.3,
compressed_max_tokens: 50,
regrade_threshold: 0.6,
min_query_length: 0,
max_scored_messages: 500,
exempt_tail_messages: 0,
compress_provider: None,
semantic_scoring_provider: None,
lookahead_depth: 3,
embed_concurrency: 32,
max_embed_input_tokens: None,
max_compress_input_tokens: None,
};
let tc = FixedTc(4);
let mut messages = vec![make_msg(Role::System, ""), make_msg(Role::User, "")];
scorer
.score_and_apply(&mut messages, "", &[], &cfg, &tc, 0, false, None, None)
.await;
}
#[tokio::test]
async fn placeholder_uses_tc_count_tokens() {
let scorer = FidelityScorer;
let cfg = FidelityConfig {
full_threshold: 2.0,
compressed_threshold: 1.5,
..make_cfg()
};
let tc = FixedTc(1); let mut messages = vec![
make_msg(Role::System, "system"),
make_msg(Role::User, "user message content for placeholder rendering"),
];
scorer
.score_and_apply(
&mut messages,
"some query text here",
&[],
&cfg,
&tc,
0,
false,
None,
None,
)
.await;
assert_eq!(
messages[1].metadata.fidelity_tag,
Some(ContextFidelity::Placeholder)
);
assert!(messages[1].content.starts_with("[placeholder:"));
}
#[tokio::test]
async fn exempt_tail_messages_large_window() {
let scorer = FidelityScorer;
let cfg = FidelityConfig {
full_threshold: 2.0,
compressed_threshold: 1.5,
max_scored_messages: 10,
exempt_tail_messages: 5,
..make_cfg()
};
let tc = FixedTc(4);
let mut messages: Vec<Message> = std::iter::once(make_msg(Role::System, "system prompt"))
.chain((1..15).map(|i| make_msg(Role::Assistant, &format!("assistant message {i}"))))
.chain((15..20).map(|i| {
let mut m = make_msg(Role::User, &format!("tail message {i}"));
m.metadata.focus_pinned = true;
m
}))
.collect();
scorer
.score_and_apply(
&mut messages,
"query text here long",
&[],
&cfg,
&tc,
0,
false,
None,
None,
)
.await;
let tail: Vec<_> = messages
.iter()
.filter(|m| m.metadata.focus_pinned)
.collect();
assert_eq!(
tail.len(),
5,
"all 5 tail messages must survive the merge pass"
);
for msg in &tail {
assert!(
msg.metadata.fidelity_tag.is_none(),
"tail message must have no fidelity_tag, got {:?}",
msg.metadata.fidelity_tag
);
}
}
#[tokio::test]
async fn exempt_tail_messages_small_window_no_effect() {
let scorer = FidelityScorer;
let cfg = FidelityConfig {
full_threshold: 2.0,
compressed_threshold: 1.5,
max_scored_messages: 10,
exempt_tail_messages: 5,
..make_cfg()
};
let tc = FixedTc(4);
let roles = [Role::User, Role::Assistant];
let mut messages: Vec<Message> = std::iter::once(make_msg(Role::System, "system prompt"))
.chain((1..8usize).map(|i| make_msg(roles[i % 2], &format!("message {i}"))))
.collect();
scorer
.score_and_apply(
&mut messages,
"query text here long",
&[],
&cfg,
&tc,
0,
false,
None,
None,
)
.await;
let untagged_count = messages[1..]
.iter()
.filter(|m| m.metadata.fidelity_tag.is_none())
.count();
assert_eq!(
untagged_count, 0,
"all non-system messages must be scored when n <= max_scored_messages"
);
}
#[tokio::test]
async fn compressed_uses_deferred_summary() {
let scorer = FidelityScorer;
let cfg = FidelityConfig {
full_threshold: 2.0, compressed_threshold: 0.0, compressed_max_tokens: 5,
..make_cfg()
};
let tc = FixedTc(4);
let mut msg_with_summary =
make_msg(Role::User, "original long content that would be truncated");
msg_with_summary.metadata.deferred_summary = Some("short summary".to_string());
let mut messages = vec![make_msg(Role::System, "system"), msg_with_summary];
scorer
.score_and_apply(
&mut messages,
"query text here long",
&[],
&cfg,
&tc,
0,
false,
None,
None,
)
.await;
assert_eq!(
messages[1].metadata.fidelity_tag,
Some(ContextFidelity::Compressed)
);
assert_eq!(messages[1].content, "short summary");
}
fn make_msg_with_fidelity(role: Role, content: &str, tag: Option<ContextFidelity>) -> Message {
let mut m = make_msg(role, content);
m.metadata.fidelity_tag = tag;
m
}
#[tokio::test]
async fn floor_prevents_compressed_upgrade_to_full() {
let scorer = FidelityScorer;
let cfg = FidelityConfig {
full_threshold: 0.0,
compressed_threshold: -1.0,
..make_cfg()
};
let tc = FixedTc(4);
let mut messages = vec![
make_msg(Role::System, "system"),
make_msg_with_fidelity(
Role::User,
"query text here long keyword",
Some(ContextFidelity::Compressed),
),
];
scorer
.score_and_apply(
&mut messages,
"query text here long keyword",
&[],
&cfg,
&tc,
0,
false,
None,
None,
)
.await;
assert_eq!(
messages[1].metadata.fidelity_tag,
Some(ContextFidelity::Compressed),
"Compressed floor must block upgrade to Full"
);
}
#[tokio::test]
async fn floor_prevents_placeholder_upgrade_to_full() {
let scorer = FidelityScorer;
let cfg = FidelityConfig {
full_threshold: 0.0,
compressed_threshold: -1.0,
..make_cfg()
};
let tc = FixedTc(4);
let mut messages = vec![
make_msg(Role::System, "system"),
make_msg_with_fidelity(
Role::User,
"query text here long keyword",
Some(ContextFidelity::Placeholder),
),
];
scorer
.score_and_apply(
&mut messages,
"query text here long keyword",
&[],
&cfg,
&tc,
0,
false,
None,
None,
)
.await;
assert_eq!(
messages[1].metadata.fidelity_tag,
Some(ContextFidelity::Placeholder),
"Placeholder floor must block upgrade to Full"
);
}
#[tokio::test]
async fn floor_prevents_placeholder_upgrade_to_compressed() {
let scorer = FidelityScorer;
let cfg = FidelityConfig {
full_threshold: 2.0,
compressed_threshold: 0.0,
..make_cfg()
};
let tc = FixedTc(4);
let mut messages = vec![
make_msg(Role::System, "system"),
make_msg_with_fidelity(
Role::User,
"message content",
Some(ContextFidelity::Placeholder),
),
];
scorer
.score_and_apply(
&mut messages,
"query text here long",
&[],
&cfg,
&tc,
0,
false,
None,
None,
)
.await;
assert_eq!(
messages[1].metadata.fidelity_tag,
Some(ContextFidelity::Placeholder),
"Placeholder floor must block upgrade to Compressed"
);
}
#[tokio::test]
async fn floor_allows_further_downgrade() {
let scorer = FidelityScorer;
let cfg = FidelityConfig {
full_threshold: 2.0,
compressed_threshold: 2.0, ..make_cfg()
};
let tc = FixedTc(4);
let mut messages = vec![
make_msg(Role::System, "system"),
make_msg_with_fidelity(
Role::User,
"some content",
Some(ContextFidelity::Compressed),
),
];
scorer
.score_and_apply(
&mut messages,
"query text here long",
&[],
&cfg,
&tc,
0,
false,
None,
None,
)
.await;
assert_eq!(
messages[1].metadata.fidelity_tag,
Some(ContextFidelity::Placeholder),
"downgrade from Compressed to Placeholder must be allowed"
);
}
#[tokio::test]
async fn floor_no_constraint_when_none() {
let scorer = FidelityScorer;
let cfg = FidelityConfig {
full_threshold: 0.0,
compressed_threshold: -1.0,
..make_cfg()
};
let tc = FixedTc(4);
let mut messages = vec![
make_msg(Role::System, "system"),
make_msg_with_fidelity(Role::User, "query text here long keyword", None),
];
scorer
.score_and_apply(
&mut messages,
"query text here long keyword",
&[],
&cfg,
&tc,
0,
false,
None,
None,
)
.await;
assert_eq!(
messages[1].metadata.fidelity_tag,
Some(ContextFidelity::Full),
"None tag must not constrain scoring"
);
}
#[tokio::test]
async fn allow_upgrade_bypasses_floor() {
let scorer = FidelityScorer;
let cfg = FidelityConfig {
full_threshold: 0.0,
compressed_threshold: -1.0,
..make_cfg()
};
let tc = FixedTc(4);
let mut messages = vec![
make_msg(Role::System, "system"),
make_msg_with_fidelity(
Role::User,
"query text here long keyword",
Some(ContextFidelity::Placeholder),
),
];
scorer
.score_and_apply(
&mut messages,
"query text here long keyword",
&[],
&cfg,
&tc,
0,
true,
None,
None,
)
.await;
assert_eq!(
messages[1].metadata.fidelity_tag,
Some(ContextFidelity::Full),
"allow_upgrade=true must bypass the Placeholder floor"
);
}
#[test]
fn truncate_no_op_below_limit() {
let tc = FixedTc(1); let mut s = "hello".to_string(); truncate_to_tokens(&mut s, 10, &tc);
assert_eq!(s, "hello");
}
#[test]
fn truncate_no_op_at_limit() {
let tc = FixedTc(1);
let mut s = "hello".to_string(); truncate_to_tokens(&mut s, 5, &tc);
assert_eq!(s, "hello");
}
#[test]
fn truncate_minimal_one_over_limit() {
let tc = FixedTc(1); let mut s = "abcdef".to_string(); truncate_to_tokens(&mut s, 5, &tc);
assert!(
tc.count_tokens(&s) <= 5,
"result must fit in 5 tokens, got {}",
tc.count_tokens(&s)
);
assert!(!s.is_empty(), "must keep prefix, not empty");
}
#[test]
fn truncate_preserves_90pct_of_limit() {
let tc = FixedTc(1);
let s_orig = "a".repeat(90);
let mut s = s_orig.clone();
truncate_to_tokens(&mut s, 100, &tc);
assert_eq!(s, s_orig, "90% of limit must not be truncated");
}
#[test]
fn truncate_empty_string_no_op() {
let tc = FixedTc(1);
let mut s = String::new();
truncate_to_tokens(&mut s, 5, &tc);
assert!(s.is_empty());
}
#[test]
fn truncate_max_tokens_zero_clears_content() {
let tc = FixedTc(1);
let mut s = "hello world".to_string();
truncate_to_tokens(&mut s, 0, &tc);
assert!(s.is_empty(), "max_tokens=0 must clear content");
}
#[test]
fn truncate_multibyte_stays_on_char_boundary() {
let tc = FixedTc(3);
let mut s = "日本語".to_string();
truncate_to_tokens(&mut s, 2, &tc);
assert!(
s.is_char_boundary(s.len()),
"result must be on a valid char boundary"
);
assert!(tc.count_tokens(&s) <= 2);
assert_eq!(s, "日本");
}
/// A tool-use / tool-result pair must end up with one shared fidelity tag.
///
/// The tool-result message is pre-tagged `Compressed`, the tool-use message is untagged, and
/// `score_and_apply` is called with upgrades disabled. Both messages are expected to come out
/// tagged `Compressed`.
#[tokio::test]
async fn mixed_fidelity_tool_pair_floor_plus_atomicity() {
    let scorer = FidelityScorer;
    let cfg = FidelityConfig {
        compressed_threshold: -1.0,
        full_threshold: 0.0,
        ..make_cfg()
    };
    let tc = FixedTc(4);

    // Both halves of the pair reference the same tool call id.
    let pair_id = String::from("tool-42");

    let mut call = make_msg_with_fidelity(Role::Assistant, "call tool", None);
    call.parts = vec![MessagePart::ToolUse {
        id: pair_id.clone(),
        name: "shell".to_string(),
        input: serde_json::json!({}),
    }];

    let mut outcome =
        make_msg_with_fidelity(Role::User, "result body", Some(ContextFidelity::Compressed));
    outcome.parts = vec![MessagePart::ToolResult {
        tool_use_id: pair_id.clone(),
        content: "output".to_string(),
        is_error: false,
    }];

    let mut messages = vec![make_msg(Role::System, "system"), call, outcome];
    scorer
        .score_and_apply(
            &mut messages,
            "query text here long",
            &[],
            &cfg,
            &tc,
            0,
            false,
            None,
            None,
        )
        .await;

    let call_tag = messages[1].metadata.fidelity_tag;
    let outcome_tag = messages[2].metadata.fidelity_tag;
    assert_eq!(
        call_tag, outcome_tag,
        "tool pair must share the same fidelity level"
    );
    assert_eq!(
        call_tag,
        Some(ContextFidelity::Compressed),
        "atomicity must bring the tool-use down to the tool-result floor"
    );
}
/// With a compress provider supplied, the provider's chat reply must be stored as the
/// message's `deferred_summary`, and the content must fit within `compressed_max_tokens`.
#[tokio::test]
async fn compress_llm_path_stores_deferred_summary() {
    use zeph_llm::LlmError;
    use zeph_llm::provider::ChatStream;

    /// Provider stub: `chat` always answers "summary text"; streaming and embedding fail.
    #[derive(Debug)]
    struct MockProvider;
    impl zeph_llm::provider::LlmProvider for MockProvider {
        async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
            Ok("summary text".to_string())
        }
        async fn chat_stream(&self, _messages: &[Message]) -> Result<ChatStream, LlmError> {
            Err(LlmError::Unavailable)
        }
        fn supports_streaming(&self) -> bool {
            false
        }
        async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
            Err(LlmError::EmbedUnsupported {
                provider: "mock".into(),
            })
        }
        fn supports_embeddings(&self) -> bool {
            false
        }
        fn name(&self) -> &'static str {
            "mock"
        }
    }

    let scorer = FidelityScorer;
    let cfg = FidelityConfig {
        enabled: true,
        compress_provider: Some("mock".to_string()),
        compressed_max_tokens: 5,
        compressed_threshold: 0.0,
        full_threshold: 2.0,
        ..make_cfg()
    };
    let tc = FixedTc(1);

    let body = "a".repeat(50);
    let mut messages = vec![
        make_msg(Role::System, "system"),
        make_msg(Role::User, &body),
    ];
    let provider = MockProvider;
    scorer
        .score_and_apply(
            &mut messages,
            "some query text here",
            &[],
            &cfg,
            &tc,
            0,
            false,
            None,
            Some(&provider),
        )
        .await;

    let target = &messages[1];
    assert_eq!(
        target.metadata.fidelity_tag,
        Some(ContextFidelity::Compressed),
    );
    assert!(
        tc.count_tokens(&target.content) <= 5,
        "content must be capped to compressed_max_tokens after LLM summary"
    );
    assert_eq!(
        target.metadata.deferred_summary.as_deref(),
        Some("summary text"),
    );
}
/// With `compress_provider` named in the config but no provider passed to `score_and_apply`,
/// the message must still be tagged `Compressed`, cut to at most 5 bytes, and left without a
/// `deferred_summary`.
#[tokio::test]
async fn compress_llm_skipped_when_provider_none() {
    let scorer = FidelityScorer;
    let cfg = FidelityConfig {
        enabled: true,
        compress_provider: Some("mock".to_string()),
        compressed_max_tokens: 5,
        compressed_threshold: 0.0,
        full_threshold: 2.0,
        ..make_cfg()
    };
    let tc = FixedTc(1);

    let body = "a".repeat(50);
    let mut messages = vec![
        make_msg(Role::System, "system"),
        make_msg(Role::User, &body),
    ];
    scorer
        .score_and_apply(
            &mut messages,
            "some query text here",
            &[],
            &cfg,
            &tc,
            0,
            false,
            None,
            None,
        )
        .await;

    let target = &messages[1];
    assert_eq!(
        target.metadata.fidelity_tag,
        Some(ContextFidelity::Compressed),
    );
    assert!(
        target.metadata.deferred_summary.is_none(),
        "deferred_summary must not be populated via truncation path"
    );
    let byte_len = target.content.len();
    assert!(
        byte_len <= 5,
        "content must be truncated, got len={}",
        byte_len
    );
}
/// A vector compared with itself must score 1.0 (within 1e-6).
#[test]
fn cosine_similarity_identical() {
    let unit = [1.0f32, 0.0, 0.0];
    let similarity = cosine_similarity(&unit, &unit);
    assert!((similarity - 1.0).abs() < 1e-6);
}
/// Two perpendicular unit vectors must score 0 (within 1e-6).
#[test]
fn cosine_similarity_orthogonal() {
    let x_axis = [1.0f32, 0.0, 0.0];
    let y_axis = [0.0f32, 1.0, 0.0];
    let similarity = cosine_similarity(&x_axis, &y_axis);
    assert!(similarity.abs() < 1e-6);
}
/// Comparing an all-zero vector with a unit vector must score 0.
#[test]
fn cosine_similarity_zero_vector() {
    let zero = [0.0f32, 0.0, 0.0];
    let unit = [1.0f32, 0.0, 0.0];
    let similarity = cosine_similarity(&zero, &unit);
    assert!(similarity.abs() < f32::EPSILON);
}
/// Two empty slices must score 0.
#[test]
fn cosine_similarity_empty() {
    let nothing: [f32; 0] = [];
    let similarity = cosine_similarity(&nothing, &nothing);
    assert!(similarity.abs() < f32::EPSILON);
}
/// Vectors of different lengths (2 vs 3) must score 0.
#[test]
fn cosine_similarity_dimension_mismatch() {
    let two_dim = [1.0f32, 0.0];
    let three_dim = [1.0f32, 0.0, 0.0];
    let similarity = cosine_similarity(&two_dim, &three_dim);
    assert!(similarity.abs() < f32::EPSILON);
}
/// With an embedding provider supplied, every non-system message must receive an embedding,
/// and the stored embeddings of the two cat-related messages must be closer to the query
/// direction `[1, 0, 0]` than the embedding of the unrelated stock message.
#[tokio::test]
async fn semantic_scoring_higher_for_similar_messages() {
    use zeph_llm::LlmError;
    use zeph_llm::provider::ChatStream;

    /// Provider stub with keyword-driven embeddings; chat and streaming always fail.
    #[derive(Debug)]
    struct EmbedMockProvider;
    impl zeph_llm::provider::LlmProvider for EmbedMockProvider {
        async fn chat(&self, _: &[Message]) -> Result<String, LlmError> {
            Err(LlmError::Unavailable)
        }
        async fn chat_stream(&self, _: &[Message]) -> Result<ChatStream, LlmError> {
            Err(LlmError::Unavailable)
        }
        fn supports_streaming(&self) -> bool {
            false
        }
        async fn embed(&self, text: &str) -> Result<Vec<f32>, LlmError> {
            // "feline"/"rug" take precedence over "cat"/"mat"; anything else maps to the z axis.
            let paraphrase = text.contains("feline") || text.contains("rug");
            let literal = text.contains("cat") || text.contains("mat");
            let embedding = if paraphrase {
                vec![0.9f32, 0.1, 0.0]
            } else if literal {
                vec![1.0f32, 0.0, 0.0]
            } else {
                vec![0.0f32, 0.0, 1.0]
            };
            Ok(embedding)
        }
        fn supports_embeddings(&self) -> bool {
            true
        }
        fn name(&self) -> &'static str {
            "embed-mock"
        }
    }

    let provider = EmbedMockProvider;
    let scorer = FidelityScorer;
    let cfg = FidelityConfig {
        enabled: true,
        semantic_scoring_provider: Some("embed-mock".to_string()),
        full_threshold: 0.0,
        compressed_threshold: 0.0,
        w_semantic: 1.0,
        w_temporal: 0.0,
        w_importance: 0.0,
        w_plan: 0.0,
        ..make_cfg()
    };
    let tc = FixedTc(4);

    let mut messages = vec![
        make_msg(Role::System, "system"),
        make_msg(Role::User, "The cat is on the mat"),
        make_msg(Role::User, "A feline rests on the rug"),
        make_msg(Role::User, "Stock prices fell today"),
    ];
    scorer
        .score_and_apply(
            &mut messages,
            "cat mat",
            &[],
            &cfg,
            &tc,
            0,
            false,
            Some(&provider),
            None,
        )
        .await;

    let expectations = [
        (1usize, "cat message must have embedding"),
        (2usize, "feline message must have embedding"),
        (3usize, "stock message must have embedding"),
    ];
    for (idx, why) in expectations {
        assert!(messages[idx].metadata.embedding.is_some(), "{}", why);
    }

    let query_emb = [1.0f32, 0.0, 0.0];
    let cat_emb = messages[1].metadata.embedding.as_deref().unwrap();
    let feline_emb = messages[2].metadata.embedding.as_deref().unwrap();
    let stock_emb = messages[3].metadata.embedding.as_deref().unwrap();

    let cat_sim = cosine_similarity(cat_emb, &query_emb);
    let feline_sim = cosine_similarity(feline_emb, &query_emb);
    let stock_sim = cosine_similarity(stock_emb, &query_emb);
    assert!(
        cat_sim > stock_sim,
        "cat message must be more similar to query than stock message"
    );
    assert!(
        feline_sim > stock_sim,
        "feline message must be more similar to query than stock message"
    );
}
/// Without an embedding provider, no message may carry an embedding afterwards, yet every
/// non-system message must still receive a fidelity tag.
#[tokio::test]
async fn semantic_scoring_falls_back_to_keyword_when_provider_none() {
    let scorer = FidelityScorer;
    let cfg = FidelityConfig {
        semantic_scoring_provider: None,
        enabled: true,
        ..make_cfg()
    };
    let tc = FixedTc(4);

    let mut messages = vec![
        make_msg(Role::System, "system"),
        make_msg(Role::User, "cat mat keyword test"),
        make_msg(Role::User, "something unrelated here"),
    ];
    scorer
        .score_and_apply(
            &mut messages,
            "query text here long",
            &[],
            &cfg,
            &tc,
            0,
            false,
            None,
            None,
        )
        .await;

    assert!(
        messages.iter().all(|m| m.metadata.embedding.is_none()),
        "no embedding must be computed when provider is None"
    );
    // Index 0 is the system message, which the original test also leaves out of this check.
    assert!(
        messages
            .iter()
            .skip(1)
            .all(|m| m.metadata.fidelity_tag.is_some()),
        "all non-system messages must be scored via keyword path"
    );
}
/// With only `w_plan` weighted, a message containing the planned tool's keywords
/// ("cargo", "build") must reach `Full` fidelity and an unrelated message must not.
#[tokio::test]
async fn w_plan_produces_nonzero_score_for_matching_message() {
    use zeph_common::PlannedToolHint;

    let scorer = FidelityScorer;
    let cfg = FidelityConfig {
        w_plan: 1.0,
        w_semantic: 0.0,
        w_temporal: 0.0,
        w_importance: 0.0,
        full_threshold: 0.5,
        compressed_threshold: 0.1,
        min_query_length: 100,
        ..make_cfg()
    };
    let tc = FixedTc(4);

    let keywords = vec!["cargo".to_string(), "build".to_string()];
    let hint = PlannedToolHint::new("shell", keywords, 1);

    let mut messages = vec![
        make_msg(Role::System, "system prompt"),
        make_msg(Role::User, "run cargo build to compile"),
        make_msg(Role::User, "what is the weather today"),
    ];
    scorer
        .score_and_apply(
            &mut messages,
            "q",
            &[hint],
            &cfg,
            &tc,
            0,
            false,
            None,
            None,
        )
        .await;

    let matching_tag = messages[1].metadata.fidelity_tag;
    let unrelated_tag = messages[2].metadata.fidelity_tag;
    assert_eq!(
        matching_tag,
        Some(ContextFidelity::Full),
        "message matching planned tool keywords must reach Full fidelity"
    );
    assert_ne!(
        unrelated_tag,
        Some(ContextFidelity::Full),
        "message with no keyword overlap must not reach Full fidelity via w_plan"
    );
}
/// Input shorter than the byte limit is returned unchanged.
#[test]
fn truncate_to_byte_limit_no_op_when_short() {
    let out = truncate_to_byte_limit("hello", 10);
    assert_eq!(out, "hello");
}
/// Input whose length equals the byte limit is returned unchanged.
#[test]
fn truncate_to_byte_limit_exact_limit_no_op() {
    let out = truncate_to_byte_limit("hello", 5);
    assert_eq!(out, "hello");
}
/// ASCII input longer than the limit is cut to exactly the first `max_bytes` bytes.
#[test]
fn truncate_to_byte_limit_over_limit() {
    let out = truncate_to_byte_limit("abcdefgh", 5);
    assert_eq!(out.len(), 5);
    assert_eq!(out, "abcde");
}
/// `truncate_to_byte_limit` must never split a multi-byte character.
///
/// "日本語" is three 3-byte characters (9 bytes). A limit of 6 already sits on a character
/// boundary, so on its own it cannot show that the limit is rounded down. The table therefore
/// also covers limits that fall inside a character, plus the zero and not-truncated edges.
#[test]
fn truncate_to_byte_limit_multibyte_boundary() {
    let text = "日本語";
    // (byte limit, expected prefix)
    let cases = [
        (0usize, ""),
        (1, ""),
        (2, ""),
        (3, "日"),
        (4, "日"),
        (5, "日"),
        (6, "日本"),
        (7, "日本"),
        (8, "日本"),
        (9, "日本語"),
        (10, "日本語"),
    ];
    for (limit, expected) in cases {
        let out = truncate_to_byte_limit(text, limit);
        assert!(out.len() <= limit, "limit = {}: result exceeds limit", limit);
        assert_eq!(out, expected, "limit = {}", limit);
    }
}
/// "hello" with a cap of 10 must be left unchanged.
#[test]
fn apply_input_cap_no_op_below_limit() {
    let mut text = String::from("hello");
    apply_input_cap(&mut text, 10);
    assert_eq!(text, "hello");
}
/// A cap of 1 applied to "abcdefgh" must leave "abcd".
#[test]
fn apply_input_cap_truncates_over_limit() {
    let mut text = String::from("abcdefgh");
    apply_input_cap(&mut text, 1);
    assert_eq!(text, "abcd");
}
/// A cap of 1 applied to "日本語" must leave the single whole character "日".
#[test]
fn apply_input_cap_multibyte() {
    let mut text = String::from("日本語");
    apply_input_cap(&mut text, 1);
    let end = text.len();
    assert!(text.is_char_boundary(end));
    assert_eq!(text, "日");
}
/// The system message at index 0 must get no embedding; the user and assistant messages
/// must each map to the vector returned by the embed closure.
#[tokio::test]
async fn embed_prepass_returns_embeddings_for_non_exempt() {
    let cfg = FidelityConfig::default();
    let messages = vec![
        make_msg(Role::System, "system prompt"),
        make_msg(Role::User, "user message"),
        make_msg(Role::Assistant, "assistant reply"),
    ];
    let embed =
        |_input: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0f32, 2.0, 3.0]) }) };

    let embedded = embed_prepass(&messages, &embed, &cfg, 0).await;

    assert!(!embedded.contains_key(&0));
    for idx in [1usize, 2] {
        assert_eq!(embedded[&idx], vec![1.0, 2.0, 3.0]);
    }
}
/// A message with empty content must not appear in the prepass result.
#[tokio::test]
async fn embed_prepass_skips_empty_content() {
    let cfg = FidelityConfig::default();
    let messages = vec![
        make_msg(Role::System, "system"),
        make_msg(Role::User, ""),
    ];
    let embed = |_input: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0f32]) }) };

    let embedded = embed_prepass(&messages, &embed, &cfg, 0).await;

    assert!(!embedded.contains_key(&1), "empty content must be skipped");
}
/// With `inserted_count = 1`, the system message and the message right after it must be
/// skipped, while the following user message must be embedded.
#[tokio::test]
async fn embed_prepass_skips_inserted_memory() {
    let cfg = FidelityConfig::default();
    let messages = vec![
        make_msg(Role::System, "system"),
        make_msg(Role::User, "injected memory"),
        make_msg(Role::User, "real user message"),
    ];
    let embed = |_input: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0f32]) }) };

    let embedded = embed_prepass(&messages, &embed, &cfg, 1).await;

    assert!(!embedded.contains_key(&0), "system is exempt");
    assert!(!embedded.contains_key(&1), "inserted memory is exempt");
    assert!(
        embedded.contains_key(&2),
        "real user message must be embedded"
    );
}
/// When the embed closure returns an error, the prepass must return an empty map.
#[tokio::test]
async fn embed_prepass_silently_skips_errors() {
    let cfg = FidelityConfig::default();
    let messages = vec![
        make_msg(Role::System, "system"),
        make_msg(Role::User, "user"),
    ];
    // Every embed attempt fails with `EmbedUnsupported`.
    let embed = |_input: &str| -> EmbedFuture {
        Box::pin(async {
            Err(zeph_llm::LlmError::EmbedUnsupported {
                provider: "mock".to_string(),
            })
        })
    };

    let embedded = embed_prepass(&messages, &embed, &cfg, 0).await;

    assert!(embedded.is_empty(), "errors must be silently skipped");
}
/// With `max_embed_input_tokens = Some(1)`, a 100-byte message must reach the embed closure
/// as 4 bytes.
#[tokio::test]
async fn embed_prepass_truncates_content_when_cap_set() {
    use std::sync::{Arc, Mutex};

    let cfg = FidelityConfig {
        max_embed_input_tokens: Some(1),
        ..FidelityConfig::default()
    };
    let oversized = "a".repeat(100);
    let messages = vec![
        make_msg(Role::System, "system"),
        make_msg(Role::User, &oversized),
    ];

    // Records the byte length of the most recent text handed to the embed closure.
    let observed = Arc::new(Mutex::new(0usize));
    let recorder = Arc::clone(&observed);
    let embed = move |text: &str| -> EmbedFuture {
        let byte_len = text.len();
        let slot = Arc::clone(&recorder);
        Box::pin(async move {
            *slot.lock().unwrap() = byte_len;
            Ok(vec![1.0f32])
        })
    };

    embed_prepass(&messages, &embed, &cfg, 0).await;

    assert_eq!(
        *observed.lock().unwrap(),
        4,
        "content must be truncated to max_embed_input_tokens * 4 bytes"
    );
}
#[tokio::test]
async fn embed_prepass_concurrency_zero_clamped_to_one() {
let messages = vec![
make_msg(Role::System, "system"),
make_msg(Role::User, "user message"),
];
let cfg = FidelityConfig {
embed_concurrency: 0,
..FidelityConfig::default()
};
let embed = |_text: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0f32]) }) };
let result = embed_prepass(&messages, &embed, &cfg, 0).await;
assert!(
result.contains_key(&1),
"result must be produced even with concurrency=0"
);
}
/// A message that already carries an embedding must be skipped; one without must be embedded.
#[tokio::test]
async fn embed_prepass_skips_cached_embeddings() {
    let cfg = FidelityConfig::default();

    let mut cached = make_msg(Role::User, "already embedded");
    cached.metadata.embedding = Some(vec![9.0f32]);
    let messages = vec![
        make_msg(Role::System, "system"),
        cached,
        make_msg(Role::User, "needs embedding"),
    ];
    let embed = |_input: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0f32]) }) };

    let embedded = embed_prepass(&messages, &embed, &cfg, 0).await;

    assert!(
        !embedded.contains_key(&1),
        "message with cached embedding must be skipped"
    );
    assert!(
        embedded.contains_key(&2),
        "message without embedding must be processed"
    );
}
/// An embed call that sleeps for 45 seconds must be dropped from the result.
///
/// The runtime starts with time paused (`start_paused = true`).
#[tokio::test(start_paused = true)]
async fn embed_prepass_timeout_skips_message() {
    let cfg = FidelityConfig::default();
    let messages = vec![
        make_msg(Role::System, "system"),
        make_msg(Role::User, "user"),
    ];
    let embed = |_input: &str| -> EmbedFuture {
        Box::pin(async {
            let delay = Duration::from_secs(45);
            tokio::time::sleep(delay).await;
            Ok(vec![1.0f32])
        })
    };

    let embedded = embed_prepass(&messages, &embed, &cfg, 0).await;

    assert!(embedded.is_empty(), "timed-out embed must be skipped");
}
/// Default config: `embed_concurrency` is 32 and both input-token caps are unset.
#[test]
fn fidelity_config_new_fields_defaults() {
    let defaults = FidelityConfig::default();
    assert_eq!(defaults.embed_concurrency, 32);
    assert!(defaults.max_embed_input_tokens.is_none());
    assert!(defaults.max_compress_input_tokens.is_none());
}
/// Explicitly set values for the three fields must be readable back unchanged.
#[test]
fn fidelity_config_new_fields_custom() {
    let custom = FidelityConfig {
        max_compress_input_tokens: Some(1024),
        max_embed_input_tokens: Some(512),
        embed_concurrency: 8,
        ..FidelityConfig::default()
    };
    assert_eq!(custom.embed_concurrency, 8);
    assert_eq!(custom.max_embed_input_tokens, Some(512));
    assert_eq!(custom.max_compress_input_tokens, Some(1024));
}
/// A pre-existing `deferred_summary` ("ten chars!") with `compressed_max_tokens = 3` must
/// leave the message tagged `Compressed` with content of at most 3 tokens.
#[tokio::test]
async fn render_compressed_truncates_oversized_deferred_summary() {
    let scorer = FidelityScorer;
    let cfg = FidelityConfig {
        compressed_max_tokens: 3,
        compressed_threshold: 0.0,
        full_threshold: 2.0,
        ..make_cfg()
    };
    let tc = FixedTc(1);

    let mut summarized = make_msg(Role::User, "original long content");
    summarized.metadata.deferred_summary = Some(String::from("ten chars!"));
    let mut messages = vec![make_msg(Role::System, "sys"), summarized];

    scorer
        .score_and_apply(
            &mut messages,
            "query text here long",
            &[],
            &cfg,
            &tc,
            0,
            false,
            None,
            None,
        )
        .await;

    let rendered = &messages[1];
    assert_eq!(
        rendered.metadata.fidelity_tag,
        Some(ContextFidelity::Compressed)
    );
    let tokens = tc.count_tokens(&rendered.content);
    assert!(
        tokens <= 3,
        "deferred_summary result must be truncated to compressed_max_tokens"
    );
}
/// `max_compress_input_tokens = Some(2)` must cap a 20-byte message to 8 bytes, while a
/// message that already has a `deferred_summary` must end up with that summary as content.
#[tokio::test]
async fn render_compressed_applies_max_compress_input_tokens() {
    let scorer = FidelityScorer;
    let cfg = FidelityConfig {
        max_compress_input_tokens: Some(2),
        compressed_max_tokens: 100,
        compressed_threshold: 0.0,
        full_threshold: 2.0,
        ..make_cfg()
    };
    let tc = FixedTc(1);

    let body = "a".repeat(20);
    let plain = make_msg(Role::User, &body);
    // Same message, but with a summary already attached.
    let mut summarized = plain.clone();
    summarized.metadata.deferred_summary = Some(String::from("short"));

    // Case 1: no deferred summary, so the input cap applies.
    let mut messages = vec![make_msg(Role::System, "sys"), plain];
    scorer
        .score_and_apply(
            &mut messages,
            "query text here long",
            &[],
            &cfg,
            &tc,
            0,
            false,
            None,
            None,
        )
        .await;
    let capped = &messages[1];
    assert_eq!(
        capped.metadata.fidelity_tag,
        Some(ContextFidelity::Compressed)
    );
    assert_eq!(
        capped.content.len(),
        8,
        "content must be capped to max_compress_input_tokens * 4 bytes"
    );

    // Case 2: deferred summary present, so the summary text is used as-is.
    let mut messages2 = vec![make_msg(Role::System, "sys"), summarized];
    scorer
        .score_and_apply(
            &mut messages2,
            "query text here long",
            &[],
            &cfg,
            &tc,
            0,
            false,
            None,
            None,
        )
        .await;
    assert_eq!(
        messages2[1].content, "short",
        "deferred_summary must bypass input cap"
    );
}
}