use modkit_macros::domain_model;
use crate::config::EstimationBudgets;
/// Text placed in front of the thread summary when it is sent as a user message, so the
/// model knows the following text condenses earlier turns. Its length is also counted
/// when budgeting the summary.
const SUMMARY_PREAMBLE: &str = concat!(
    "This conversation has earlier messages that have been summarized. ",
    "The summary below covers the earlier portion of the conversation. ",
    "Recent messages follow after.\n\n"
);
use crate::domain::llm::{
ContentPart, ContextMessage, FileSearchFilter, LlmMessage, LlmTool, Role,
};
/// Token-accounting parameters for one request, consumed by `compute_available_budget`
/// and `assemble_context`.
#[domain_model]
#[derive(Debug, Clone)]
// Three independent surcharge toggles; an enum would not model them better.
#[allow(clippy::struct_excessive_bools)]
pub struct TokenBudget {
    /// Total size of the model's context window, in tokens.
    pub context_window: u32,
    /// Tokens reserved for the model's response; deducted from the context window.
    pub max_output_tokens_applied: i32,
    /// Estimation heuristics and surcharge sizes used for every deduction.
    pub budgets: EstimationBudgets,
    /// When set, `budgets.tool_surcharge_tokens` is also deducted.
    pub tools_enabled: bool,
    /// When set, `budgets.web_search_surcharge_tokens` is also deducted.
    pub web_search_enabled: bool,
    /// When set, `budgets.code_interpreter_surcharge_tokens` is also deducted.
    pub code_interpreter_enabled: bool,
}
/// Everything `assemble_context` needs to build one LLM request.
#[domain_model]
#[allow(clippy::struct_excessive_bools)]
pub struct ContextInput<'a> {
    /// Base system prompt; omitted from the system instructions when empty.
    pub system_prompt: &'a str,
    /// Guard text appended to the system instructions; never appended when web search is
    /// disabled or the guard is empty.
    pub web_search_guard: &'a str,
    /// Guard text appended to the system instructions; never appended when file search is
    /// disabled or the guard is empty.
    pub file_search_guard: &'a str,
    /// Optional summary of older turns, sent as a user message prefixed with
    /// `SUMMARY_PREAMBLE`. May be dropped when a token budget is set.
    pub thread_summary: Option<&'a str>,
    /// Prior conversation messages. Entries with `Role::System` are never forwarded.
    /// Under a token budget the trailing entries are the ones kept, so this is presumably
    /// ordered oldest-first — confirm against callers.
    pub recent_messages: &'a [ContextMessage],
    /// The current user turn; always the last message of the assembled request.
    pub user_message: &'a str,
    /// Attaches the web_search tool and enables `web_search_guard`.
    pub web_search_enabled: bool,
    /// Requests the file_search tool; the tool is only attached when `vector_store_ids`
    /// is non-empty.
    pub file_search_enabled: bool,
    /// Vector stores passed to the file_search tool.
    pub vector_store_ids: &'a [String],
    /// Optional filters forwarded to the file_search tool.
    pub file_search_filters: Option<FileSearchFilter>,
    /// Context size forwarded to the web_search tool.
    pub web_search_context_size: crate::domain::llm::WebSearchContextSize,
    /// Result cap forwarded to the file_search tool.
    pub file_search_max_num_results: u32,
    /// Files for the code_interpreter tool; the tool is attached only when non-empty.
    pub code_interpreter_file_ids: Vec<String>,
    /// Token budget; `None` disables truncation and includes everything.
    pub token_budget: Option<TokenBudget>,
    /// Images attached to the current user message; each is charged
    /// `image_token_budget` tokens under a budget.
    pub image_file_ids: &'a [String],
}
/// The request pieces produced by `assemble_context`.
#[domain_model]
pub struct AssembledContext {
    /// Joined system prompt and guards; `None` when every part was empty or disabled.
    pub system_instructions: Option<String>,
    /// Optional summary message, forwarded history, and finally the current user message.
    pub messages: Vec<LlmMessage>,
    /// Hosted tools to expose to the model.
    pub tools: Vec<LlmTool>,
    /// Estimated input tokens consumed under the budget; 0 when no budget was given.
    pub estimated_context_tokens: u64,
    /// True when at least one non-system history message was dropped to fit the budget.
    pub messages_truncated: bool,
}
/// Errors raised while assembling a request context.
#[domain_model]
#[derive(Debug)]
pub enum ContextAssemblyError {
    /// The token budget cannot accommodate what must be sent.
    ///
    /// From `compute_available_budget`: `required_tokens` is the total deduction and
    /// `available_tokens` the context window. From `assemble_context`: `required_tokens` is
    /// the mandatory-item estimate and `available_tokens` the budget left after deductions.
    BudgetExceeded {
        required_tokens: u64,
        available_tokens: u64,
    },
}
impl std::fmt::Display for ContextAssemblyError {
    /// Renders a human-readable description that includes both token counts.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum: the destructuring pattern is irrefutable. Adding a variant
        // will turn this into a compile error, prompting a message for it.
        let Self::BudgetExceeded {
            required_tokens,
            available_tokens,
        } = self;
        write!(
            f,
            "mandatory context items require {required_tokens} tokens but only {available_tokens} are available"
        )
    }
}

/// Standard error impl; there is no underlying source error.
impl std::error::Error for ContextAssemblyError {}
/// Builds the current user turn: plain text when there are no images, otherwise a
/// multi-part message with the text first, followed by one image part per file id in order.
fn build_user_message(text: &str, image_file_ids: &[String]) -> LlmMessage {
    if image_file_ids.is_empty() {
        return LlmMessage::user(text);
    }
    let text_part = ContentPart::Text {
        text: text.to_owned(),
    };
    let image_parts = image_file_ids.iter().map(|id| ContentPart::Image {
        file_id: id.clone(),
    });
    LlmMessage {
        role: Role::User,
        content: std::iter::once(text_part).chain(image_parts).collect(),
    }
}
pub fn compute_available_budget(budget: &TokenBudget) -> Result<u64, ContextAssemblyError> {
let tool_surcharge = if budget.tools_enabled {
u64::from(budget.budgets.tool_surcharge_tokens)
} else {
0
} + if budget.web_search_enabled {
u64::from(budget.budgets.web_search_surcharge_tokens)
} else {
0
} + if budget.code_interpreter_enabled {
u64::from(budget.budgets.code_interpreter_surcharge_tokens)
} else {
0
};
#[allow(clippy::cast_sign_loss)]
let deductions = budget.max_output_tokens_applied as u64
+ tool_surcharge
+ u64::from(budget.budgets.fixed_overhead_tokens);
let context_window = u64::from(budget.context_window);
if deductions >= context_window {
return Err(ContextAssemblyError::BudgetExceeded {
required_tokens: deductions,
available_tokens: context_window,
});
}
Ok(context_window - deductions)
}
/// Conservatively estimates the prompt tokens consumed by one context item of
/// `text_bytes` bytes.
///
/// The estimate is computed in three steps:
/// 1. Convert bytes to tokens by ceiling division by `bytes_per_token_conservative`, which
///    is treated as at least 1.
/// 2. Add the per-item `fixed_overhead_tokens`.
/// 3. Inflate the total by `safety_margin_pct` percent, rounding down.
#[must_use]
#[allow(clippy::integer_division)] // flooring after the margin is applied is intentional
pub fn estimate_item_tokens(text_bytes: u64, budgets: &EstimationBudgets) -> u64 {
    let bytes_per_token = u64::from(budgets.bytes_per_token_conservative).max(1);
    let overhead = u64::from(budgets.fixed_overhead_tokens);
    let margin_factor = 100 + u64::from(budgets.safety_margin_pct);
    let unscaled = overhead + text_bytes.div_ceil(bytes_per_token);
    unscaled * margin_factor / 100
}
pub fn assemble_context(
input: &ContextInput<'_>,
) -> Result<AssembledContext, ContextAssemblyError> {
let system_instructions = build_system_instructions(
input.system_prompt,
input.web_search_enabled,
input.web_search_guard,
input.file_search_enabled,
input.file_search_guard,
);
let mut tools = Vec::new();
if input.file_search_enabled && !input.vector_store_ids.is_empty() {
tools.push(LlmTool::FileSearch {
vector_store_ids: input.vector_store_ids.to_vec(),
filters: input.file_search_filters.clone(),
max_num_results: Some(input.file_search_max_num_results),
});
}
if input.web_search_enabled {
tools.push(LlmTool::WebSearch {
search_context_size: input.web_search_context_size,
});
}
if !input.code_interpreter_file_ids.is_empty() {
tools.push(LlmTool::CodeInterpreter {
file_ids: input.code_interpreter_file_ids.clone(),
});
}
if let Some(ref budget) = input.token_budget {
let available = compute_available_budget(budget)?;
let budgets = &budget.budgets;
let sys_tokens = system_instructions
.as_ref()
.map_or(0, |s| estimate_item_tokens(s.len() as u64, budgets));
let user_tokens = estimate_item_tokens(input.user_message.len() as u64, budgets);
let image_tokens = (input.image_file_ids.len() as u64)
.saturating_mul(u64::from(budgets.image_token_budget));
let mandatory = sys_tokens + user_tokens + image_tokens;
if mandatory > available {
return Err(ContextAssemblyError::BudgetExceeded {
required_tokens: mandatory,
available_tokens: available,
});
}
let mut remaining = available - mandatory;
let keep_summary = if let Some(summary) = input.thread_summary {
let cost =
estimate_item_tokens((summary.len() + SUMMARY_PREAMBLE.len()) as u64, budgets);
if cost <= remaining {
remaining -= cost;
true
} else {
false
}
} else {
false
};
let mut keep_from_index = input.recent_messages.len();
for (i, msg) in input.recent_messages.iter().enumerate().rev() {
if matches!(msg.role, Role::System) {
continue; }
let cost = estimate_item_tokens(msg.content.len() as u64, budgets);
if cost <= remaining {
remaining -= cost;
keep_from_index = i;
} else {
break;
}
}
let mut messages = Vec::new();
if keep_summary && let Some(summary) = input.thread_summary {
messages.push(LlmMessage::user(format!("{SUMMARY_PREAMBLE}{summary}")));
}
for msg in &input.recent_messages[keep_from_index..] {
match msg.role {
Role::User => messages.push(LlmMessage::user(&msg.content)),
Role::Assistant => messages.push(LlmMessage::assistant(&msg.content)),
Role::System => {}
}
}
messages.push(build_user_message(input.user_message, input.image_file_ids));
let estimated_context_tokens = available - remaining;
let messages_truncated = input.recent_messages[..keep_from_index]
.iter()
.any(|m| !matches!(m.role, Role::System));
Ok(AssembledContext {
system_instructions,
messages,
tools,
estimated_context_tokens,
messages_truncated,
})
} else {
let mut messages = Vec::new();
if let Some(summary) = input.thread_summary {
messages.push(LlmMessage::user(format!("{SUMMARY_PREAMBLE}{summary}")));
}
for msg in input.recent_messages {
match msg.role {
Role::User => messages.push(LlmMessage::user(&msg.content)),
Role::Assistant => messages.push(LlmMessage::assistant(&msg.content)),
Role::System => {}
}
}
messages.push(build_user_message(input.user_message, input.image_file_ids));
Ok(AssembledContext {
system_instructions,
messages,
tools,
estimated_context_tokens: 0,
messages_truncated: false,
})
}
}
/// Joins the system prompt and the enabled, non-empty guards with blank lines between them.
///
/// Order is: system prompt, web-search guard, then file-search guard. Returns `None` when
/// nothing survives the filtering.
fn build_system_instructions(
    system_prompt: &str,
    web_search_enabled: bool,
    web_search_guard: &str,
    file_search_enabled: bool,
    file_search_guard: &str,
) -> Option<String> {
    // The base prompt is always eligible; guards only when their feature is on.
    let candidates = [
        (true, system_prompt),
        (web_search_enabled, web_search_guard),
        (file_search_enabled, file_search_guard),
    ];
    let sections: Vec<&str> = candidates
        .iter()
        .filter(|&&(enabled, text)| enabled && !text.is_empty())
        .map(|&(_, text)| text)
        .collect();
    (!sections.is_empty()).then(|| sections.join("\n\n"))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::llm::WebSearchContextSize;

    /// Builds a history entry with the given role and text.
    fn make_message(role: Role, content: &str) -> ContextMessage {
        ContextMessage {
            role,
            content: content.to_owned(),
        }
    }

    /// Baseline input: only `user_message` is set. It has no system prompt, guards, summary,
    /// history, tools, images, or budget. Tests override individual fields with
    /// struct-update syntax (`..base_input("...")`).
    fn base_input(user_message: &str) -> ContextInput<'_> {
        ContextInput {
            system_prompt: "",
            web_search_guard: "",
            file_search_guard: "",
            thread_summary: None,
            recent_messages: &[],
            user_message,
            web_search_enabled: false,
            file_search_enabled: false,
            vector_store_ids: &[],
            file_search_filters: None,
            web_search_context_size: WebSearchContextSize::Low,
            file_search_max_num_results: 5,
            code_interpreter_file_ids: vec![],
            token_budget: None,
            image_file_ids: &[],
        }
    }

    #[test]
    fn empty_system_prompt_no_tools() {
        let result = assemble_context(&base_input("hello")).unwrap();
        assert!(result.system_instructions.is_none());
        assert!(result.tools.is_empty());
        assert_eq!(result.messages.len(), 1);
    }

    #[test]
    fn system_prompt_with_web_search_guard() {
        let result = assemble_context(&ContextInput {
            system_prompt: "You are helpful.",
            web_search_guard: "Use web_search only if needed.",
            web_search_enabled: true,
            ..base_input("hello")
        })
        .unwrap();
        let instructions = result.system_instructions.unwrap();
        assert!(instructions.contains("You are helpful."));
        assert!(instructions.contains("Use web_search only if needed."));
    }

    #[test]
    fn system_prompt_with_file_search_guard() {
        let stores = ["vs-1".to_owned()];
        let result = assemble_context(&ContextInput {
            system_prompt: "You are helpful.",
            file_search_guard: "Use file_search for documents.",
            file_search_enabled: true,
            vector_store_ids: &stores,
            ..base_input("hello")
        })
        .unwrap();
        let instructions = result.system_instructions.unwrap();
        assert!(instructions.contains("You are helpful."));
        assert!(instructions.contains("Use file_search for documents."));
    }

    #[test]
    fn both_guards_appended() {
        let stores = ["vs-1".to_owned()];
        let result = assemble_context(&ContextInput {
            system_prompt: "Base prompt.",
            web_search_guard: "web guard",
            file_search_guard: "file guard",
            web_search_enabled: true,
            file_search_enabled: true,
            vector_store_ids: &stores,
            ..base_input("hello")
        })
        .unwrap();
        let instructions = result.system_instructions.unwrap();
        assert!(instructions.contains("Base prompt."));
        assert!(instructions.contains("web guard"));
        assert!(instructions.contains("file guard"));
    }

    #[test]
    fn thread_summary_included_as_first_message() {
        let recent = vec![make_message(Role::User, "prior question")];
        let result = assemble_context(&ContextInput {
            thread_summary: Some("Summary of prior conversation."),
            recent_messages: &recent,
            ..base_input("new question")
        })
        .unwrap();
        // summary + prior question + new question
        assert_eq!(result.messages.len(), 3);
        let ContentPart::Text { text } = &result.messages[0].content[0] else {
            panic!("Expected text content");
        };
        assert!(text.contains("earlier messages that have been summarized"));
        assert!(text.contains("Summary of prior conversation."));
    }

    #[test]
    fn no_thread_summary_starts_with_recent() {
        let recent = vec![
            make_message(Role::User, "first"),
            make_message(Role::Assistant, "response"),
        ];
        let result = assemble_context(&ContextInput {
            recent_messages: &recent,
            ..base_input("second")
        })
        .unwrap();
        // first + response + second
        assert_eq!(result.messages.len(), 3);
    }

    #[test]
    fn system_role_skipped() {
        let recent = vec![
            make_message(Role::User, "hello"),
            make_message(Role::System, "system msg"),
            make_message(Role::Assistant, "hi"),
        ];
        let result = assemble_context(&ContextInput {
            recent_messages: &recent,
            ..base_input("bye")
        })
        .unwrap();
        assert_eq!(result.messages.len(), 3);
    }

    #[test]
    fn current_user_message_is_last() {
        let recent = vec![make_message(Role::Assistant, "prior")];
        let result = assemble_context(&ContextInput {
            recent_messages: &recent,
            ..base_input("current input")
        })
        .unwrap();
        let last = result.messages.last().unwrap();
        let ContentPart::Text { text } = &last.content[0] else {
            panic!("Expected text content");
        };
        assert_eq!(text, "current input");
    }

    #[test]
    fn tools_populated_correctly() {
        let stores = ["vs-123".to_owned()];
        let both = assemble_context(&ContextInput {
            web_search_enabled: true,
            file_search_enabled: true,
            vector_store_ids: &stores,
            web_search_context_size: WebSearchContextSize::High,
            file_search_max_num_results: 7,
            ..base_input("hello")
        })
        .unwrap();
        assert_eq!(both.tools.len(), 2);
        assert!(matches!(
            &both.tools[0],
            LlmTool::FileSearch {
                max_num_results: Some(7),
                ..
            }
        ));
        assert!(matches!(
            &both.tools[1],
            LlmTool::WebSearch {
                search_context_size: WebSearchContextSize::High
            }
        ));

        // File search without any vector store attaches no tool.
        let no_stores = assemble_context(&ContextInput {
            file_search_enabled: true,
            ..base_input("hello")
        })
        .unwrap();
        assert!(no_stores.tools.is_empty());

        let web_only = assemble_context(&ContextInput {
            web_search_enabled: true,
            web_search_context_size: WebSearchContextSize::Medium,
            ..base_input("hello")
        })
        .unwrap();
        assert_eq!(web_only.tools.len(), 1);
        assert!(matches!(
            &web_only.tools[0],
            LlmTool::WebSearch {
                search_context_size: WebSearchContextSize::Medium
            }
        ));
    }

    fn test_budgets() -> EstimationBudgets {
        EstimationBudgets {
            bytes_per_token_conservative: 4,
            fixed_overhead_tokens: 100,
            safety_margin_pct: 10,
            image_token_budget: 1000,
            tool_surcharge_tokens: 500,
            web_search_surcharge_tokens: 500,
            code_interpreter_surcharge_tokens: 1000,
            minimal_generation_floor: 128,
        }
    }

    fn test_budget(context_window: u32, max_output: i32) -> TokenBudget {
        TokenBudget {
            context_window,
            max_output_tokens_applied: max_output,
            budgets: test_budgets(),
            tools_enabled: false,
            web_search_enabled: false,
            code_interpreter_enabled: false,
        }
    }

    #[test]
    fn budget_no_tools() {
        let available = compute_available_budget(&test_budget(128_000, 4096)).unwrap();
        assert_eq!(available, 128_000 - 4096 - 100);
    }

    #[test]
    fn budget_with_tools() {
        let budget = TokenBudget {
            tools_enabled: true,
            web_search_enabled: true,
            ..test_budget(128_000, 4096)
        };
        let available = compute_available_budget(&budget).unwrap();
        assert_eq!(available, 128_000 - 4096 - 500 - 500 - 100);
    }

    #[test]
    fn budget_zero_context_window() {
        let result = compute_available_budget(&test_budget(0, 4096));
        assert!(matches!(
            result,
            Err(ContextAssemblyError::BudgetExceeded { .. })
        ));
    }

    #[test]
    fn item_estimation() {
        let budgets = test_budgets();
        assert_eq!(estimate_item_tokens(400, &budgets), 220);
        assert_eq!(estimate_item_tokens(0, &budgets), 110);
        assert_eq!(estimate_item_tokens(1, &budgets), 111);
    }

    #[test]
    fn truncation_drops_thread_summary() {
        let budgets = test_budgets();
        let sys_cost = estimate_item_tokens(10, &budgets); // "0123456789"
        let user_cost = estimate_item_tokens(5, &budgets); // "hello"
        let overhead = 4096_u64 + 100; // max output + fixed overhead
        // One token of slack beyond the mandatory items: far too little for the summary.
        let context_window = u32::try_from(overhead + sys_cost + user_cost + 1).unwrap();
        let result = assemble_context(&ContextInput {
            system_prompt: "0123456789",
            thread_summary: Some("A very long summary that should be dropped"),
            token_budget: Some(test_budget(context_window, 4096)),
            ..base_input("hello")
        })
        .unwrap();
        assert_eq!(result.messages.len(), 1);
    }

    #[test]
    fn truncation_drops_oldest_messages() {
        let budgets = test_budgets();
        let msg_cost = estimate_item_tokens(3, &budgets); // "msg"
        let user_cost = estimate_item_tokens(5, &budgets); // "hello"
        let overhead = 4096_u64 + 100;
        // Room for exactly one history message.
        let context_window = u32::try_from(overhead + user_cost + msg_cost).unwrap();
        let recent = vec![
            make_message(Role::User, "msg"),
            make_message(Role::Assistant, "msg"),
        ];
        let result = assemble_context(&ContextInput {
            recent_messages: &recent,
            token_budget: Some(test_budget(context_window, 4096)),
            ..base_input("hello")
        })
        .unwrap();
        assert_eq!(result.messages.len(), 2);
    }

    #[test]
    fn truncation_drops_summary_keeps_messages() {
        let budgets = test_budgets();
        let msg_cost = estimate_item_tokens(3, &budgets);
        let user_cost = estimate_item_tokens(5, &budgets);
        let big_summary = "X".repeat(2000);
        let summary_cost = estimate_item_tokens(
            (big_summary.len() + SUMMARY_PREAMBLE.len()) as u64,
            &budgets,
        );
        let overhead = 4096_u64 + 100;
        let context_window = u32::try_from(overhead + user_cost + 2 * msg_cost).unwrap();
        assert!(
            summary_cost > 2 * msg_cost,
            "summary should be more expensive than 2 messages for this test"
        );
        let recent = vec![
            make_message(Role::User, "msg"),
            make_message(Role::Assistant, "msg"),
        ];
        let result = assemble_context(&ContextInput {
            thread_summary: Some(big_summary.as_str()),
            recent_messages: &recent,
            token_budget: Some(test_budget(context_window, 4096)),
            ..base_input("hello")
        })
        .unwrap();
        assert_eq!(result.messages.len(), 3);
    }

    #[test]
    fn budget_exceeded_mandatory_too_large() {
        let huge_prompt = "A".repeat(100_000);
        let result = assemble_context(&ContextInput {
            system_prompt: &huge_prompt,
            token_budget: Some(test_budget(5000, 4096)),
            ..base_input("hello")
        });
        assert!(matches!(
            result,
            Err(ContextAssemblyError::BudgetExceeded { .. })
        ));
    }

    #[test]
    fn no_budget_includes_everything() {
        let recent = vec![
            make_message(Role::User, &"A".repeat(50_000)),
            make_message(Role::Assistant, &"B".repeat(50_000)),
        ];
        let result = assemble_context(&ContextInput {
            system_prompt: "sys",
            thread_summary: Some("summary"),
            recent_messages: &recent,
            ..base_input("hello")
        })
        .unwrap();
        assert_eq!(result.messages.len(), 4);
    }

    #[test]
    fn budget_exceeded_error_message() {
        let msg = ContextAssemblyError::BudgetExceeded {
            required_tokens: 50_000,
            available_tokens: 10_000,
        }
        .to_string();
        assert!(msg.contains("50000"));
        assert!(msg.contains("10000"));
    }

    #[test]
    fn code_interpreter_tool_added_when_enabled_with_file_ids() {
        let result = assemble_context(&ContextInput {
            code_interpreter_file_ids: vec!["file-abc123".to_owned()],
            ..base_input("analyze this")
        })
        .unwrap();
        assert_eq!(result.tools.len(), 1);
        assert!(matches!(
            &result.tools[0],
            LlmTool::CodeInterpreter { file_ids } if file_ids == &["file-abc123"]
        ));
    }

    #[test]
    fn code_interpreter_tool_not_added_when_no_file_ids() {
        let result = assemble_context(&base_input("analyze this")).unwrap();
        assert!(result.tools.is_empty());
    }

    #[test]
    fn budget_with_code_interpreter_surcharge() {
        let budget = TokenBudget {
            code_interpreter_enabled: true,
            ..test_budget(128_000, 4096)
        };
        let available = compute_available_budget(&budget).unwrap();
        assert_eq!(available, 128_000 - 4096 - 1000 - 100);
    }

    #[test]
    fn single_image_produces_image_content_part() {
        let images = vec!["file-abc".to_owned()];
        let result = assemble_context(&ContextInput {
            image_file_ids: &images,
            ..base_input("Describe this")
        })
        .unwrap();
        assert_eq!(result.messages.len(), 1);
        let msg = &result.messages[0];
        assert_eq!(msg.content.len(), 2);
        assert!(matches!(&msg.content[0], ContentPart::Text { text } if text == "Describe this"));
        assert!(matches!(&msg.content[1], ContentPart::Image { file_id } if file_id == "file-abc"));
    }

    #[test]
    fn multiple_images_produce_multiple_content_parts() {
        let images = vec!["file-1".to_owned(), "file-2".to_owned()];
        let result = assemble_context(&ContextInput {
            image_file_ids: &images,
            ..base_input("Compare these")
        })
        .unwrap();
        let msg = &result.messages[0];
        assert_eq!(msg.content.len(), 3);
        assert!(matches!(&msg.content[1], ContentPart::Image { file_id } if file_id == "file-1"));
        assert!(matches!(&msg.content[2], ContentPart::Image { file_id } if file_id == "file-2"));
    }

    #[test]
    fn no_images_produces_text_only() {
        let result = assemble_context(&base_input("hello")).unwrap();
        let msg = &result.messages[0];
        assert_eq!(msg.content.len(), 1);
        assert!(matches!(&msg.content[0], ContentPart::Text { .. }));
    }

    #[test]
    fn image_tokens_included_in_budget_mandatory() {
        let images = vec!["file-1".to_owned(), "file-2".to_owned()];
        let result = assemble_context(&ContextInput {
            token_budget: Some(test_budget(10_000, 4096)),
            image_file_ids: &images,
            ..base_input("hi")
        });
        assert!(result.is_ok());
    }

    #[test]
    fn image_tokens_cause_budget_exceeded() {
        let images = vec!["file-1".to_owned(), "file-2".to_owned()];
        let result = assemble_context(&ContextInput {
            token_budget: Some(test_budget(5100, 4096)),
            image_file_ids: &images,
            ..base_input("hi")
        });
        assert!(matches!(
            result,
            Err(ContextAssemblyError::BudgetExceeded { .. })
        ));
    }

    #[test]
    fn build_user_message_helper_text_only() {
        let msg = build_user_message("hello", &[]);
        assert_eq!(msg.content.len(), 1);
        assert!(matches!(&msg.content[0], ContentPart::Text { text } if text == "hello"));
    }

    #[test]
    fn build_user_message_helper_with_images() {
        let ids = vec!["f1".to_owned(), "f2".to_owned()];
        let msg = build_user_message("look", &ids);
        assert_eq!(msg.content.len(), 3);
        assert!(matches!(&msg.content[0], ContentPart::Text { text } if text == "look"));
        assert!(matches!(&msg.content[1], ContentPart::Image { file_id } if file_id == "f1"));
        assert!(matches!(&msg.content[2], ContentPart::Image { file_id } if file_id == "f2"));
    }
}