use async_trait::async_trait;
use crate::types::agent_state::AgentState;
use crate::types::message::Message;
/// Pluggable hook that prepares the message list before it is used by the
/// agent, with access to a context-window size and the agent state.
///
/// Implementors must be `Send + Sync`, so a manager can be shared as an
/// `Arc<dyn ContextManager>`; the trait is object safe thanks to
/// `#[async_trait]`.
#[async_trait]
pub trait ContextManager: Send + Sync {
    /// Prepares `messages` (which may be modified in place) given
    /// `context_window` and mutable access to `state`.
    ///
    /// NOTE(review): `context_window` is presumably the model's limit in
    /// tokens, comparable to [`ContextManager::estimate_tokens`] — confirm
    /// against implementors.
    async fn prepare(
        &self,
        messages: &mut Vec<Message>,
        context_window: usize,
        state: &mut AgentState,
    );

    /// Returns a rough token estimate for `messages`.
    ///
    /// The default charges each message `content.len() / 4` tokens. This is a
    /// UTF-8 *byte* count with integer division, so remainders are dropped.
    /// It adds one extra token per message as overhead. An empty slice
    /// yields `0`.
    fn estimate_tokens(&self, messages: &[Message]) -> usize {
        let mut total = 0;
        for message in messages {
            let body_tokens = message.content.len() / 4;
            // `+ 1`: fixed per-message overhead, so even empty content counts.
            total += body_tokens + 1;
        }
        total
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::message::MessageRole;
    use std::sync::Arc;

    /// Minimal `ContextManager` whose `prepare` does nothing. It is shared by
    /// every test so that the trait's object safety and its default methods
    /// can be exercised without duplicating an impl per test.
    struct NoopContextManager;

    #[async_trait]
    impl ContextManager for NoopContextManager {
        async fn prepare(
            &self,
            _messages: &mut Vec<Message>,
            _context_window: usize,
            _state: &mut AgentState,
        ) {
        }
    }

    /// Builds a `Message` with the given role and content and no tool call id.
    fn message(role: MessageRole, content: String) -> Message {
        Message {
            role,
            content,
            tool_call_id: None,
        }
    }

    #[test]
    fn test_context_manager_is_object_safe() {
        let _: Arc<dyn ContextManager> = Arc::new(NoopContextManager);
    }

    #[test]
    fn test_default_estimate_tokens() {
        let messages = vec![
            message(MessageRole::User, "a".repeat(400)),
            message(MessageRole::Assistant, "b".repeat(800)),
        ];
        assert_eq!(
            NoopContextManager.estimate_tokens(&messages),
            302,
            "4-chars ≈ 1-token: (400/4+1) + (800/4+1) = 302"
        );
    }

    #[test]
    fn test_default_estimate_tokens_empty_slice_is_zero() {
        assert_eq!(NoopContextManager.estimate_tokens(&[]), 0);
    }

    #[test]
    fn test_default_estimate_tokens_truncates_and_counts_overhead() {
        // 7 bytes -> 7/4 = 1 (integer division) + 1 overhead = 2;
        // empty content -> 0 + 1 overhead = 1; total 3.
        let messages = vec![
            message(MessageRole::User, "x".repeat(7)),
            message(MessageRole::Assistant, String::new()),
        ];
        assert_eq!(NoopContextManager.estimate_tokens(&messages), 3);
    }
}