use std::sync::{Arc, Mutex};
use serde::{Deserialize, Serialize};
use crate::adapters::schemas::{ToolChoice, ToolsSchema};
/// One tool/function invocation attached to an assistant message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of this call; `Message::ToolResult::tool_call_id` is matched against it
    /// when a staged round is committed.
    pub id: String,
    /// Name of the function being called.
    pub function_name: String,
    /// Call arguments carried as a string.
    /// NOTE(review): presumably JSON-encoded (the tests use `"{}"`) — confirm with producers.
    pub arguments: String,
}
/// A single entry in the conversation history.
#[derive(Debug, Clone)]
pub enum Message {
    /// System-level instruction text.
    System {
        /// Text of the instruction.
        content: String,
    },
    /// Text supplied by the user.
    User {
        /// Text of the user turn.
        content: String,
    },
    /// Output attributed to the assistant: optional text and/or a set of tool calls.
    Assistant {
        /// Plain-text part of the reply, if any.
        content: Option<String>,
        /// Tool calls issued with this reply, if any.
        tool_calls: Option<Vec<ToolCall>>,
    },
    /// Result produced for one earlier tool call.
    ToolResult {
        /// Id of the [`ToolCall`] this result answers.
        tool_call_id: String,
        /// Result payload as text.
        content: String,
    },
}
/// Conversation state: committed history plus a private staging area for the round that is
/// currently in flight.
#[derive(Debug, Clone)]
pub struct LLMContext {
    /// Optional system prompt; `to_api_messages` emits it as a leading `Message::System`.
    pub system_prompt: Option<String>,
    /// Committed history, in order.
    pub messages: Vec<Message>,
    /// Tool schema stored alongside the conversation; not interpreted by the methods here.
    pub tools: Option<ToolsSchema>,
    /// Tool-choice setting stored alongside the conversation; not interpreted by the methods here.
    pub tool_choice: Option<ToolChoice>,
    /// Messages staged for the in-flight round; they join `messages` only through `commit`.
    staged: Vec<Message>,
    /// Turn counter, bumped by `begin_turn`.
    epoch: u64,
}
impl LLMContext {
    /// Share of the context window the history may occupy before trimming starts.
    const BUDGET_FRACTION: f64 = 0.8;
    /// Divisor turning a byte count into a rough token estimate.
    const BYTES_PER_TOKEN: usize = 4;
    /// Flat byte overhead charged per tool call on top of its name and arguments.
    const TOOL_CALL_OVERHEAD: usize = 20;

    /// Creates an empty context with no tools configured.
    pub fn new(system_prompt: Option<String>) -> Self {
        Self {
            system_prompt,
            messages: Vec::new(),
            tools: None,
            tool_choice: None,
            staged: Vec::new(),
            epoch: 0,
        }
    }

    /// Creates an empty context that carries a tool schema and an optional tool choice.
    pub fn with_tools(
        system_prompt: Option<String>,
        tools: ToolsSchema,
        tool_choice: Option<ToolChoice>,
    ) -> Self {
        Self {
            system_prompt,
            messages: Vec::new(),
            tools: Some(tools),
            tool_choice,
            staged: Vec::new(),
            epoch: 0,
        }
    }

    /// Appends `msg` directly to the committed history, bypassing staging.
    pub fn push_message(&mut self, msg: Message) {
        self.messages.push(msg);
    }

    /// Appends a user message to the committed history.
    pub fn add_user_message(&mut self, content: impl Into<String>) {
        self.push_message(Message::User {
            content: content.into(),
        });
    }

    /// Appends a text-only assistant message to the committed history.
    pub fn add_assistant_message(&mut self, content: impl Into<String>) {
        self.push_message(Message::Assistant {
            content: Some(content.into()),
            tool_calls: None,
        });
    }

    /// Appends an assistant message carrying tool calls to the committed history.
    ///
    /// No orphan check is performed on this path; only [`Self::commit`] runs one.
    pub fn add_assistant_tool_calls(
        &mut self,
        content: Option<String>,
        tool_calls: Vec<ToolCall>,
    ) {
        self.push_message(Message::Assistant {
            content,
            tool_calls: Some(tool_calls),
        });
    }

    /// Appends a tool result to the committed history.
    pub fn add_tool_result(
        &mut self,
        tool_call_id: impl Into<String>,
        content: impl Into<String>,
    ) {
        self.push_message(Message::ToolResult {
            tool_call_id: tool_call_id.into(),
            content: content.into(),
        });
    }

    /// Starts a new turn: bumps the epoch (wrapping on overflow), discards anything still
    /// staged, and returns the new epoch.
    pub fn begin_turn(&mut self) -> u64 {
        self.epoch = self.epoch.wrapping_add(1);
        self.staged.clear();
        self.epoch
    }

    /// Returns the current turn epoch.
    pub fn epoch(&self) -> u64 {
        self.epoch
    }

    /// Returns how many messages are currently staged.
    pub fn staged_len(&self) -> usize {
        self.staged.len()
    }

    /// Stages `msg`; it stays out of the committed history until [`Self::commit`].
    pub fn stage_message(&mut self, msg: Message) {
        self.staged.push(msg);
    }

    /// Stages an assistant message carrying tool calls.
    pub fn stage_assistant_tool_calls(
        &mut self,
        content: Option<String>,
        tool_calls: Vec<ToolCall>,
    ) {
        self.stage_message(Message::Assistant {
            content,
            tool_calls: Some(tool_calls),
        });
    }

    /// Stages a tool result.
    pub fn stage_tool_result(
        &mut self,
        tool_call_id: impl Into<String>,
        content: impl Into<String>,
    ) {
        self.stage_message(Message::ToolResult {
            tool_call_id: tool_call_id.into(),
            content: content.into(),
        });
    }

    /// Moves the staged messages into the committed history and returns how many were added.
    ///
    /// Before the move, assistant messages with unanswered tool calls — and tool results
    /// that do not belong to a surviving assistant message — are removed, so the returned
    /// count can be smaller than [`Self::staged_len`] was (possibly `0`).
    pub fn commit(&mut self) -> usize {
        if self.staged.is_empty() {
            return 0;
        }
        // Repair in place and `append`, which drains `staged` while keeping its buffer.
        Self::repair_orphan_tool_calls(&mut self.staged);
        let committed = self.staged.len();
        self.messages.append(&mut self.staged);
        committed
    }

    /// Discards every staged message, leaving the committed history untouched.
    pub fn rollback(&mut self) {
        if self.staged.is_empty() {
            return;
        }
        log::debug!(
            "LLMContext: rolling back {} staged message(s)",
            self.staged.len()
        );
        self.staged.clear();
    }

    /// Removes inconsistent tool traffic from a staged batch.
    ///
    /// * An assistant message with tool calls is kept only if every one of its call ids has
    ///   a `ToolResult` somewhere in `staged`; otherwise it is dropped with a warning.
    /// * A `ToolResult` is kept only if its id belongs to a kept assistant message.
    /// * Every other message is kept.
    ///
    /// Only `staged` is consulted: a result answering a call that lives solely in the
    /// committed history is treated as an orphan and dropped.
    fn repair_orphan_tool_calls(staged: &mut Vec<Message>) {
        use std::collections::HashSet;

        // Pass 1 (read-only). Scoped so the borrowed `answered` set is gone before `staged`
        // is mutated below.
        let (verdicts, kept_call_ids) = {
            let answered: HashSet<&str> = staged
                .iter()
                .filter_map(|m| match m {
                    Message::ToolResult { tool_call_id, .. } => Some(tool_call_id.as_str()),
                    _ => None,
                })
                .collect();

            let mut kept_call_ids: HashSet<String> = HashSet::new();
            let mut verdicts: Vec<bool> = Vec::with_capacity(staged.len());
            for m in staged.iter() {
                let keep = match m {
                    Message::Assistant { tool_calls: Some(calls), .. } => {
                        let fully_answered =
                            calls.iter().all(|call| answered.contains(call.id.as_str()));
                        if fully_answered {
                            kept_call_ids.extend(calls.iter().map(|call| call.id.clone()));
                        } else {
                            log::warn!(
                                "LLMContext: dropping orphaned assistant tool_calls at commit \
                                 (unanswered tool call)"
                            );
                        }
                        fully_answered
                    }
                    _ => true,
                };
                verdicts.push(keep);
            }
            (verdicts, kept_call_ids)
        };

        // Pass 2. `Vec::retain` visits each element exactly once in the original order, so
        // the verdict iterator stays aligned with the messages.
        let mut verdicts = verdicts.into_iter();
        staged.retain(|m| {
            let keep = verdicts
                .next()
                .expect("one verdict was recorded per staged message");
            match m {
                Message::ToolResult { tool_call_id, .. } => {
                    keep && kept_call_ids.contains(tool_call_id.as_str())
                }
                _ => keep,
            }
        });
    }

    /// Returns the committed history as an owned list, preceded by the system prompt (as a
    /// `Message::System`) when one is set. Staged messages are not included.
    pub fn to_api_messages(&self) -> Vec<Message> {
        let prompt_slots = usize::from(self.system_prompt.is_some());
        let mut out = Vec::with_capacity(self.messages.len() + prompt_slots);
        if let Some(prompt) = &self.system_prompt {
            out.push(Message::System {
                content: prompt.clone(),
            });
        }
        out.extend_from_slice(&self.messages);
        out
    }

    /// Rough token estimate for the system prompt plus the committed history: total byte
    /// length divided by four (integer division). Staged messages are not counted.
    pub fn estimate_tokens(&self) -> usize {
        self.estimate_bytes() / Self::BYTES_PER_TOKEN
    }

    /// Byte total behind [`Self::estimate_tokens`]: system prompt plus every committed message.
    fn estimate_bytes(&self) -> usize {
        let prompt = self.system_prompt.as_deref().map_or(0, str::len);
        let history: usize = self.messages.iter().map(Self::message_bytes).sum();
        prompt + history
    }

    /// Byte weight of one message: its text, plus — for each tool call — the function name,
    /// the arguments, and a flat overhead. Ids are not counted.
    fn message_bytes(msg: &Message) -> usize {
        match msg {
            Message::System { content } | Message::User { content } => content.len(),
            Message::ToolResult { content, .. } => content.len(),
            Message::Assistant { content, tool_calls } => {
                let text = content.as_deref().map_or(0, str::len);
                let calls = tool_calls.as_deref().map_or(0, |calls| {
                    calls
                        .iter()
                        .map(|call| {
                            call.function_name.len()
                                + call.arguments.len()
                                + Self::TOOL_CALL_OVERHEAD
                        })
                        .sum::<usize>()
                });
                text + calls
            }
        }
    }

    /// Index range `[first user message, second user message)` in the committed history, or
    /// `None` when fewer than two user messages exist.
    fn oldest_turn_range(&self) -> Option<std::ops::Range<usize>> {
        let mut user_indices = self
            .messages
            .iter()
            .enumerate()
            .filter(|(_, m)| matches!(m, Message::User { .. }))
            .map(|(index, _)| index);
        let start = user_indices.next()?;
        let end = user_indices.next()?;
        Some(start..end)
    }

    /// Drops the oldest turns until the estimated size fits within 80% of
    /// `context_window_tokens`.
    ///
    /// Each step removes everything from the first user message up to (not including) the
    /// next user message, so the most recent user turn and anything preceding the first user
    /// message are never removed. When no such span remains and the estimate is still over
    /// budget, a warning is logged and trimming stops.
    pub fn trim_to_context_budget(&mut self, context_window_tokens: usize) {
        let budget = (context_window_tokens as f64 * Self::BUDGET_FRACTION) as usize;
        // Keep a running byte total instead of re-scanning the whole history every round.
        let mut bytes = self.estimate_bytes();
        while bytes / Self::BYTES_PER_TOKEN > budget {
            match self.oldest_turn_range() {
                Some(turn) => {
                    let dropped = turn.len();
                    let freed: usize = self
                        .messages
                        .drain(turn)
                        .map(|m| Self::message_bytes(&m))
                        .sum();
                    bytes = bytes.saturating_sub(freed);
                    // Report the budget actually enforced, not the raw window size.
                    log::warn!(
                        "LLMContext: trimmed {} messages to fit {}-token budget",
                        dropped,
                        budget
                    );
                }
                None => {
                    log::warn!(
                        "LLMContext: context near limit ({} estimated tokens) but cannot safely trim further",
                        bytes / Self::BYTES_PER_TOKEN
                    );
                    break;
                }
            }
        }
    }
}
/// Builds a new [`LLMContext`] from `system_prompt` and wraps it in `Arc<Mutex<_>>` so it
/// can be shared between owners.
pub fn shared_context(system_prompt: Option<String>) -> Arc<Mutex<LLMContext>> {
    let context = LLMContext::new(system_prompt);
    Arc::new(Mutex::new(context))
}
/// Builds a new tool-enabled [`LLMContext`] via [`LLMContext::with_tools`] and wraps it in
/// `Arc<Mutex<_>>` so it can be shared between owners.
pub fn shared_context_with_tools(
    system_prompt: Option<String>,
    tools: ToolsSchema,
    tool_choice: Option<ToolChoice>,
) -> Arc<Mutex<LLMContext>> {
    let context = LLMContext::with_tools(system_prompt, tools, tool_choice);
    Arc::new(Mutex::new(context))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tool call with the given id and function name and `"{}"` as arguments.
    fn tc(id: &str, name: &str) -> ToolCall {
        ToolCall {
            id: id.to_owned(),
            function_name: name.to_owned(),
            arguments: String::from("{}"),
        }
    }

    /// Collects the text of every assistant message that has content and no tool calls.
    fn assistant_text(messages: &[Message]) -> Vec<&str> {
        let mut texts = Vec::new();
        for message in messages {
            if let Message::Assistant { content: Some(text), tool_calls: None } = message {
                texts.push(text.as_str());
            }
        }
        texts
    }

    /// Staged messages must not show up in the API view before `commit`.
    #[test]
    fn staged_is_invisible_until_commit() {
        let mut ctx = LLMContext::new(None);
        ctx.add_user_message("hello");

        ctx.stage_assistant_tool_calls(None, vec![tc("call_1", "lookup")]);
        ctx.stage_tool_result("call_1", "ok");

        assert_eq!(ctx.staged_len(), 2);
        assert_eq!(ctx.to_api_messages().len(), 1);
    }

    /// A complete call/result round is appended to the history in order.
    #[test]
    fn commit_splices_full_round() {
        let mut ctx = LLMContext::new(None);
        ctx.add_user_message("status of 4471?");
        ctx.stage_assistant_tool_calls(None, vec![tc("call_1", "lookup")]);
        ctx.stage_tool_result("call_1", "shipped");

        let n = ctx.commit();

        assert_eq!(n, 2);
        assert_eq!(ctx.staged_len(), 0);
        assert_eq!(ctx.messages.len(), 3);
        assert!(matches!(ctx.messages[1], Message::Assistant { tool_calls: Some(_), .. }));
        assert!(matches!(ctx.messages[2], Message::ToolResult { .. }));
    }

    /// `rollback` throws the staged round away and leaves the history alone.
    #[test]
    fn rollback_discards_orphaned_round() {
        let mut ctx = LLMContext::new(None);
        ctx.add_user_message("status of 4471?");
        ctx.stage_assistant_tool_calls(None, vec![tc("call_1", "lookup")]);

        ctx.rollback();

        assert_eq!(ctx.staged_len(), 0);
        assert_eq!(ctx.messages.len(), 1);
        assert!(matches!(ctx.messages[0], Message::User { .. }));
    }

    /// A round with one unanswered call is dropped whole, including the result that did arrive.
    #[test]
    fn commit_drops_orphan_tool_calls_for_consistency() {
        let mut ctx = LLMContext::new(None);
        ctx.stage_assistant_tool_calls(None, vec![tc("call_1", "a"), tc("call_2", "b")]);
        ctx.stage_tool_result("call_1", "done");

        let n = ctx.commit();

        assert_eq!(n, 0, "orphaned round must be dropped entirely");
        assert!(ctx.messages.is_empty());
    }

    /// Assistant messages without tool calls pass through `commit` untouched.
    #[test]
    fn commit_keeps_plain_text_assistant() {
        let mut ctx = LLMContext::new(None);
        ctx.stage_message(Message::Assistant {
            content: Some("hi there".into()),
            tool_calls: None,
        });

        assert_eq!(ctx.commit(), 1);
        assert_eq!(assistant_text(&ctx.messages), vec!["hi there"]);
    }

    /// `begin_turn` increments the epoch and clears whatever a previous round left staged.
    #[test]
    fn begin_turn_bumps_epoch_and_clears_stale_staged() {
        let mut ctx = LLMContext::new(None);
        let e0 = ctx.epoch();
        ctx.stage_assistant_tool_calls(None, vec![tc("call_1", "lookup")]);

        let e1 = ctx.begin_turn();

        assert_eq!(e1, e0 + 1);
        assert_eq!(ctx.staged_len(), 0, "begin_turn discards a prior interrupted round");
    }
}