use std::sync::{Arc, Mutex};
use serde::{Deserialize, Serialize};
use crate::adapters::schemas::{ToolChoice, ToolsSchema};
/// A single tool (function) invocation carried by an assistant message.
///
/// Instances are stored in the `tool_calls` field of [`Message::Assistant`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Identifier of this call.
/// NOTE(review): presumably the value echoed back in the `tool_call_id` of
/// the matching `Message::ToolResult` — confirm against the adapters.
pub id: String,
/// Name of the function being invoked.
pub function_name: String,
/// Call arguments, kept as a plain string.
/// NOTE(review): presumably JSON-encoded; the type does not enforce this —
/// confirm with whatever produces these values.
pub arguments: String,
}
/// One entry of a conversation, tagged by who produced it.
#[derive(Debug, Clone)]
pub enum Message {
/// System-level instruction text. `LLMContext::to_api_messages` emits one of
/// these from the context's `system_prompt` when it is set.
System { content: String },
/// Text supplied by the user. `LLMContext::trim_to_context_budget` uses
/// these entries as the boundaries of the spans it removes.
User { content: String },
/// Output produced by the assistant: optional text and/or tool calls.
Assistant {
/// Plain-text part of the reply; `None` when there is no text.
content: Option<String>,
/// Tool invocations attached to the reply; `None` when there are none.
tool_calls: Option<Vec<ToolCall>>,
},
/// Output of a tool execution.
ToolResult {
/// Identifier of the tool call this result belongs to.
/// NOTE(review): presumably matches `ToolCall::id` — confirm with callers.
tool_call_id: String,
/// Textual result of the tool execution.
content: String,
},
}
/// Conversation state for an LLM exchange: optional system prompt, message
/// history, and optional tool configuration.
#[derive(Debug, Clone)]
pub struct LLMContext {
/// Optional system prompt. Kept separate from `messages`;
/// `to_api_messages` prepends it as a `Message::System` when present.
pub system_prompt: Option<String>,
/// Conversation history in insertion order.
pub messages: Vec<Message>,
/// Tool definitions, if any. Stored as given; not interpreted in this file.
pub tools: Option<ToolsSchema>,
/// Tool-selection setting, if any. Stored as given; not interpreted in this
/// file.
pub tool_choice: Option<ToolChoice>,
}
impl LLMContext {
    /// Creates an empty context with no tools configured.
    ///
    /// `system_prompt` is stored separately from the message history and is
    /// only turned into a [`Message::System`] by
    /// [`LLMContext::to_api_messages`].
    pub fn new(system_prompt: Option<String>) -> Self {
        Self {
            system_prompt,
            messages: Vec::new(),
            tools: None,
            tool_choice: None,
        }
    }

    /// Creates an empty context that carries a tool schema and an optional
    /// tool-choice setting.
    pub fn with_tools(
        system_prompt: Option<String>,
        tools: ToolsSchema,
        tool_choice: Option<ToolChoice>,
    ) -> Self {
        Self {
            system_prompt,
            messages: Vec::new(),
            tools: Some(tools),
            tool_choice,
        }
    }

    /// Appends an arbitrary message to the end of the history.
    pub fn push_message(&mut self, msg: Message) {
        self.messages.push(msg);
    }

    /// Appends a [`Message::User`] with the given text.
    pub fn add_user_message(&mut self, content: impl Into<String>) {
        self.push_message(Message::User {
            content: content.into(),
        });
    }

    /// Appends a text-only [`Message::Assistant`] (no tool calls).
    pub fn add_assistant_message(&mut self, content: impl Into<String>) {
        self.push_message(Message::Assistant {
            content: Some(content.into()),
            tool_calls: None,
        });
    }

    /// Appends a [`Message::Assistant`] that carries tool calls and,
    /// optionally, accompanying text.
    ///
    /// The vector is stored as-is: an empty `tool_calls` becomes
    /// `Some(vec![])`, not `None`.
    pub fn add_assistant_tool_calls(
        &mut self,
        content: Option<String>,
        tool_calls: Vec<ToolCall>,
    ) {
        self.push_message(Message::Assistant {
            content,
            tool_calls: Some(tool_calls),
        });
    }

    /// Appends a [`Message::ToolResult`] for the call identified by
    /// `tool_call_id`.
    pub fn add_tool_result(
        &mut self,
        tool_call_id: impl Into<String>,
        content: impl Into<String>,
    ) {
        self.push_message(Message::ToolResult {
            tool_call_id: tool_call_id.into(),
            content: content.into(),
        });
    }

    /// Returns the full message list to send: the system prompt (if any) as a
    /// leading [`Message::System`], followed by a clone of the history.
    pub fn to_api_messages(&self) -> Vec<Message> {
        let system = self.system_prompt.as_ref().map(|prompt| Message::System {
            content: prompt.clone(),
        });
        // Size the output once up front and clone element-by-element, rather
        // than cloning the whole history into a temporary Vec first.
        let mut result =
            Vec::with_capacity(self.messages.len() + usize::from(system.is_some()));
        result.extend(system);
        result.extend(self.messages.iter().cloned());
        result
    }

    /// Returns a rough token estimate for the system prompt plus history.
    ///
    /// The estimate is the total UTF-8 byte length divided by four (integer
    /// division). Tool calls contribute their function name and arguments plus
    /// a fixed allowance each; tool-call ids, tool-result ids, and the tool
    /// schema are not counted.
    pub fn estimate_tokens(&self) -> usize {
        let prompt_len = self.system_prompt.as_deref().map_or(0, str::len);
        let history_len: usize = self.messages.iter().map(Self::message_len).sum();
        (prompt_len + history_len) / 4
    }

    /// Byte-length contribution of one message to [`LLMContext::estimate_tokens`].
    fn message_len(msg: &Message) -> usize {
        // Fixed per-call allowance added on top of name + arguments.
        // NOTE(review): presumably covers the id and structural overhead.
        const TOOL_CALL_OVERHEAD: usize = 20;

        match msg {
            Message::System { content }
            | Message::User { content }
            | Message::ToolResult { content, .. } => content.len(),
            Message::Assistant {
                content,
                tool_calls,
            } => {
                let text_len = content.as_deref().map_or(0, str::len);
                let calls_len = tool_calls.as_deref().map_or(0, |calls| {
                    calls
                        .iter()
                        .map(|call| {
                            call.function_name.len() + call.arguments.len() + TOOL_CALL_OVERHEAD
                        })
                        .sum::<usize>()
                });
                text_len + calls_len
            }
        }
    }

    /// Drops the oldest user-initiated turns until the estimated token count
    /// fits within 80% of `context_window_tokens`.
    ///
    /// Each step removes everything from the first [`Message::User`] up to
    /// (but not including) the next one, so a user message is removed together
    /// with whatever followed it. Messages located before the first user
    /// message are never removed, and the most recent user message is always
    /// kept. When fewer than two user messages remain, trimming stops even if
    /// the estimate is still over budget.
    ///
    /// A `log::warn!` is emitted for every removal step and once if trimming
    /// has to give up.
    pub fn trim_to_context_budget(&mut self, context_window_tokens: usize) {
        // Only 80% of the window is targeted.
        // NOTE(review): presumably headroom for the model's reply — confirm.
        let budget = (context_window_tokens as f64 * 0.8) as usize;

        while self.estimate_tokens() > budget {
            // Indices of the first two user messages; scoped so the shared
            // borrow of `messages` ends before the drain below.
            let (first_user, next_user) = {
                let mut user_indices =
                    self.messages
                        .iter()
                        .enumerate()
                        .filter_map(|(idx, msg)| {
                            if matches!(msg, Message::User { .. }) {
                                Some(idx)
                            } else {
                                None
                            }
                        });
                (user_indices.next(), user_indices.next())
            };

            match (first_user, next_user) {
                (Some(start), Some(end)) => {
                    let dropped = end - start;
                    self.messages.drain(start..end);
                    // Report the budget actually being enforced, not the raw
                    // context window size.
                    log::warn!(
                        "LLMContext: trimmed {} messages to fit {}-token budget",
                        dropped,
                        budget
                    );
                }
                _ => {
                    log::warn!(
                        "LLMContext: context near limit ({} estimated tokens) but cannot safely trim further",
                        self.estimate_tokens()
                    );
                    break;
                }
            }
        }
    }
}
/// Builds a tool-less [`LLMContext`] and wraps it in `Arc<Mutex<_>>` so it can
/// be shared and mutated through multiple handles.
pub fn shared_context(system_prompt: Option<String>) -> Arc<Mutex<LLMContext>> {
    let context = LLMContext::new(system_prompt);
    let guarded = Mutex::new(context);
    Arc::new(guarded)
}
/// Builds an [`LLMContext`] configured with `tools` (and an optional
/// `tool_choice`) and wraps it in `Arc<Mutex<_>>` so it can be shared and
/// mutated through multiple handles.
pub fn shared_context_with_tools(
    system_prompt: Option<String>,
    tools: ToolsSchema,
    tool_choice: Option<ToolChoice>,
) -> Arc<Mutex<LLMContext>> {
    let context = LLMContext::with_tools(system_prompt, tools, tool_choice);
    let guarded = Mutex::new(context);
    Arc::new(guarded)
}