use async_trait::async_trait;
use cloudllm::client_wrapper;
use cloudllm::client_wrapper::{ClientWrapper, Message, Role, TokenUsage, ToolDefinition};
use cloudllm::cloudllm::llm_session;
use cloudllm::cloudllm::llm_session::estimate_message_token_count;
use cloudllm::LLMSession;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Test double for a `ClientWrapper` that answers every request with a fixed
/// reply and records how many messages it was handed.
struct MockClient {
    /// Text returned as the assistant reply for every request.
    response_content: String,
    /// Length of the `messages` slice seen by the most recent `send_message` call.
    last_message_count: Mutex<usize>,
    /// Token usage exposed through `usage_slot`; `None` until a test presets it
    /// or the first `send_message` call fills it in.
    usage: Mutex<Option<TokenUsage>>,
}
impl MockClient {
fn new(response_content: String) -> Self {
Self {
usage: Mutex::new(None),
response_content,
last_message_count: Mutex::new(0),
}
}
async fn get_last_message_count(&self) -> usize {
*self.last_message_count.lock().await
}
async fn set_usage(&self, input: usize, output: usize, total: usize) {
let mut usage = self.usage.lock().await;
*usage = Some(client_wrapper::TokenUsage {
input_tokens: input,
output_tokens: output,
total_tokens: total,
});
}
}
#[async_trait]
impl ClientWrapper for MockClient {
    /// Records the size of `messages`, fills in an estimated usage if none is
    /// stored yet, and returns the canned assistant reply. `_tools` is ignored.
    async fn send_message(
        &self,
        messages: &[Message],
        _tools: Option<Vec<ToolDefinition>>,
    ) -> Result<Message, Box<dyn std::error::Error>> {
        // Remember how many messages were transmitted; the guard stays alive
        // until the end of the call.
        let mut seen = self.last_message_count.lock().await;
        *seen = messages.len();

        let input_tokens = messages
            .iter()
            .map(|msg| estimate_message_token_count(msg))
            .sum::<usize>();

        let mut slot = self.usage.lock().await;

        // Built after the last await point so the reply is never held across one.
        let reply = Message {
            role: Role::Assistant,
            content: self.response_content.clone().into(),
            tool_calls: vec![],
        };
        let output_tokens = estimate_message_token_count(&reply);
        let estimated = TokenUsage {
            input_tokens,
            output_tokens,
            total_tokens: input_tokens + output_tokens,
        };

        // A usage value that is already stored (e.g. preset via `set_usage`)
        // takes precedence over the estimate.
        if slot.is_none() {
            *slot = Some(estimated);
        }

        Ok(reply)
    }

    /// Fixed model identifier reported by the mock.
    fn model_name(&self) -> &str {
        "mock-model"
    }

    /// Fixed provider identifier reported by the mock.
    fn provider_name(&self) -> &str {
        "mock"
    }

    /// Hands out the mutex holding the stored usage.
    fn usage_slot(&self) -> Option<&Mutex<Option<TokenUsage>>> {
        Some(&self.usage)
    }
}
#[tokio::test]
async fn test_token_caching() {
let mock_client = Arc::new(MockClient::new("Response".to_string()));
let mut session = LLMSession::new(mock_client.clone(), "System prompt".to_string(), 1000);
let user_message = "Hello, this is a test message";
mock_client.set_usage(100, 50, 150).await;
let _ = session
.send_message(Role::User, user_message.to_string(), None)
.await;
assert_eq!(session.get_conversation_history().len(), 2); assert_eq!(session.get_cached_token_counts().len(), 2);
let expected_user_tokens = llm_session::estimate_message_token_count(&Message {
role: Role::User,
content: user_message.to_string().into(),
tool_calls: vec![],
});
let expected_response_tokens = llm_session::estimate_message_token_count(&Message {
role: Role::Assistant,
content: "Response".to_string().into(),
tool_calls: vec![],
});
assert_eq!(session.get_cached_token_counts()[0], expected_user_tokens);
assert_eq!(
session.get_cached_token_counts()[1],
expected_response_tokens
);
}
/// With a small limit passed to the session, a second exchange should leave
/// fewer than four history entries, and the cached token counts must stay the
/// same length as the history.
#[tokio::test]
async fn test_token_caching_with_trimming() {
    let client = Arc::new(MockClient::new("Response".to_string()));
    let mut session = LLMSession::new(client.clone(), "System prompt".to_string(), 100);

    // First exchange: nothing is expected to be dropped yet.
    client.set_usage(50, 25, 75).await;
    let _ = session
        .send_message(Role::User, "First message".to_string(), None)
        .await;
    assert_eq!(session.get_conversation_history().len(), 2);
    assert_eq!(session.get_cached_token_counts().len(), 2);

    // Second exchange: the preset total (120) is above the limit of 100.
    client.set_usage(80, 40, 120).await;
    let _ = session
        .send_message(Role::User, "Second message".to_string(), None)
        .await;

    assert!(session.get_conversation_history().len() < 4);
    let cached_len = session.get_cached_token_counts().len();
    assert_eq!(session.get_conversation_history().len(), cached_len);
}
/// Pins the estimator's output for a short word, a short sentence and the
/// empty string.
#[test]
fn test_estimate_token_count() {
    let single_word = llm_session::estimate_token_count("test");
    let sentence = llm_session::estimate_token_count("this is a longer test");
    let empty = llm_session::estimate_token_count("");

    assert_eq!(single_word, 1);
    assert_eq!(sentence, 5);
    // The empty string is still expected to count as one token.
    assert_eq!(empty, 1);
}
/// Pins the estimated token count of a short user message.
#[test]
fn test_estimate_message_token_count() {
    let tokens = estimate_message_token_count(&Message {
        role: Role::User,
        content: "test message".to_string().into(),
        tool_calls: vec![],
    });

    assert_eq!(tokens, 4);
}
/// Sends three short messages followed by a large one through a session
/// created with a small limit (20), then checks that the final request handed
/// to the client contained fewer messages than the untrimmed conversation.
#[tokio::test]
async fn test_pre_transmission_trimming() {
    let client = Arc::new(MockClient::new("Response".to_string()));
    let mut session = LLMSession::new(client.clone(), "System".to_string(), 20);

    // Build up some history with short exchanges.
    for text in ["Msg1", "Msg2", "Msg3"] {
        let _ = session
            .send_message(Role::User, text.to_string(), None)
            .await;
    }

    // A 40-character message to push the conversation past the limit.
    let large_msg = "0123456789012345678901234567890123456789";
    let _ = session
        .send_message(Role::User, large_msg.to_string(), None)
        .await;

    let message_count = client.get_last_message_count().await;
    assert!(
        message_count > 0,
        "Should have sent at least the system prompt and new message"
    );
    // Untrimmed, the fourth request would carry the system prompt, four user
    // messages and the three earlier replies; the fourth reply does not exist
    // yet when the request is transmitted.
    assert!(
        message_count < 6,
        "Should have trimmed some messages (system + 4 user + 3 assistant = 8 total before trim)"
    );
    assert!(
        !session.get_conversation_history().is_empty(),
        "Conversation history should not be empty"
    );
}
/// With a generous limit (10000) the second request should carry every
/// message so far: four in total.
#[tokio::test]
async fn test_no_trimming_when_under_limit() {
    let mock = Arc::new(MockClient::new("OK".to_string()));
    let mut session = LLMSession::new(mock.clone(), "System".to_string(), 10000);

    for greeting in ["Hi", "Hello"] {
        let _ = session
            .send_message(Role::User, greeting.to_string(), None)
            .await;
    }

    let transmitted = mock.get_last_message_count().await;
    assert_eq!(
        transmitted, 4,
        "Should have sent all messages without trimming"
    );
}
/// Sends three messages in a row and checks that each request handed to the
/// client grows by two messages (2, then 4, then 6).
#[tokio::test]
async fn test_request_buffer_reuse() {
    let mock = Arc::new(MockClient::new("Response".to_string()));
    let mut session = LLMSession::new(
        mock.clone() as Arc<dyn ClientWrapper>,
        "System prompt".to_string(),
        10_000,
    );

    // Each row is the user text to send and the request size expected for it.
    let steps = [("First", 2), ("Second", 4), ("Third", 6)];
    for (text, expected_count) in steps {
        let _ = session
            .send_message(Role::User, text.to_string(), None)
            .await;
        let observed = mock.get_last_message_count().await;
        assert_eq!(observed, expected_count);
    }
}