use std::sync::Arc;
use tokio::task_local;
use juncture::llm::{ChatModel, MockChatModel};
use juncture_core::pregel::{
BudgetConfig, BudgetReportError, BudgetTracker, try_report_model_call,
};
use juncture_core::state::messages::Message;
#[test]
fn test_budget_report_outside_context_returns_error() {
    // This test never installs a tracker, so the free-function reporter
    // must refuse the call with `BudgetReportError::NoTracker`.
    let outcome = try_report_model_call(100, 200);
    let rejected_for_missing_tracker = matches!(outcome, Err(BudgetReportError::NoTracker));
    assert!(rejected_for_missing_tracker);
}
#[test]
fn test_mock_chat_model_sets_usage() {
    // NOTE(review): despite its name, this test only checks that the mock
    // returns non-empty content; it does not inspect any usage metadata on
    // the response. Confirm whether a usage assertion was intended.
    let mock = MockChatModel::new("gpt-4").with_response("Test response");
    let prompt = vec![Message::human("Hello")];

    // Drive the async `invoke` to completion on the current thread.
    let pending = mock.invoke(&prompt, None);
    let reply = futures::executor::block_on(pending).expect("Invoke should succeed");

    let text = reply.content_text();
    assert!(!text.is_empty(), "Response should have content");
}
#[test]
fn test_budget_tracker_reports_model_calls() {
    let tracker = BudgetTracker::new(BudgetConfig::new().with_max_tokens(1000));

    // Two reported calls: (100 + 200) + (50 + 150) = 500 tokens in total.
    for (input_tokens, output_tokens) in [(100, 200), (50, 150)] {
        tracker.report_model_call(input_tokens, output_tokens);
    }

    assert_eq!(tracker.current_usage().tokens_used, 500);
}
#[test]
fn test_budget_tracker_enforces_limits() {
    let config = BudgetConfig::new().with_max_tokens(100);
    let tracker = BudgetTracker::new(config);

    // Nothing has been reported yet, so no limit can be exceeded.
    assert!(tracker.check().is_none());

    // 60 + 50 = 110 tokens, which is over the 100-token cap.
    tracker.report_model_call(60, 50);

    // Bind the reason unconditionally: if `check()` returned `None`, the test
    // fails here rather than silently skipping the message assertion.
    let reason = tracker
        .check()
        .expect("check() should report a violation once max_tokens is exceeded");
    assert_eq!(reason.to_string(), "Token budget exceeded: 110 > 100");
}
#[test]
fn test_task_local_budget_tracker_scope() {
    // A task-local declared only for this test. It exercises tokio's
    // `task_local!` scoping with a shared `BudgetTracker`, not any task-local
    // that juncture itself may define.
    task_local! {
        static TEST_TRACKER: Arc<BudgetTracker>;
    }
    let config = BudgetConfig::new().with_max_tokens(1000);
    let tracker = Arc::new(BudgetTracker::new(config));
    let rt = tokio::runtime::Runtime::new().expect("failed to build tokio runtime");

    // Inside the scope, the task-local resolves to the tracker we installed.
    rt.block_on(async {
        TEST_TRACKER
            .scope(Arc::clone(&tracker), async {
                let tokens = TEST_TRACKER
                    .try_with(|t| {
                        t.report_model_call(100, 200);
                        t.current_usage().tokens_used
                    })
                    .expect("tracker should be accessible inside scope");
                assert_eq!(tokens, 300);
            })
            .await;
    });

    // The scoped value was a clone of the same `Arc`, so usage recorded
    // through the task-local must be visible via the original handle.
    assert_eq!(tracker.current_usage().tokens_used, 300);

    // Once the scope has ended, the task-local is no longer set.
    rt.block_on(async {
        let result = TEST_TRACKER.try_with(|_| ());
        assert!(
            result.is_err(),
            "Should return error when accessed outside scope"
        );
    });
}