use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use super::*;
use crate::error::{Result, TinyAgentsError};
use crate::harness::context::{RunConfig, RunContext};
use crate::harness::events::{AgentEvent, RecordingListener};
use crate::harness::message::{AssistantMessage, ContentBlock, Message, UserMessage};
use crate::harness::model::{ModelRequest, ModelResponse, PromptSegment, SegmentRole};
use crate::harness::summarization::{SummarizationPolicy, TrimStrategy};
use crate::harness::usage::Usage;
/// Builds a fresh `RunContext` for a run named `"test-run"` carrying unit `()` state.
fn ctx() -> RunContext {
    let config = RunConfig::new("test-run");
    RunContext::new(config, ())
}
/// Wraps `text` in a user message consisting of a single text content block.
fn user(text: &str) -> Message {
    let block = ContentBlock::Text(text.to_owned());
    Message::User(UserMessage {
        content: vec![block],
    })
}
/// Builds a model response whose assistant message is the single text block `"ok"`.
///
/// `usage` is attached at the response level only; the assistant message's own
/// `usage`, as well as `finish_reason`, `raw` and `resolved_model`, are left `None`.
fn response_with_usage(usage: Usage) -> ModelResponse {
    let message = AssistantMessage {
        id: None,
        content: vec![ContentBlock::Text(String::from("ok"))],
        tool_calls: Vec::new(),
        usage: None,
    };
    ModelResponse {
        message,
        usage: Some(usage),
        finish_reason: None,
        raw: None,
        resolved_model: None,
    }
}
/// Builds a prompt segment with an owned copy of `id` and the given role and cacheability.
fn segment(id: &str, role: SegmentRole, cacheable: bool) -> PromptSegment {
    let id = id.to_owned();
    PromptSegment {
        id,
        role,
        cacheable,
    }
}
/// Test middleware that appends `"<label>:before"` / `"<label>:after"` entries to a
/// shared log, so a test can observe the order in which the stack fires its hooks.
struct OrderRecorder {
    /// Prefix for log entries; also reported as the middleware name.
    label: &'static str,
    /// Log shared between every recorder in a test.
    log: Arc<Mutex<Vec<String>>>,
}

impl OrderRecorder {
    /// Appends one entry to the shared log.
    fn record(&self, entry: String) {
        self.log.lock().unwrap().push(entry);
    }
}

#[async_trait]
impl Middleware<()> for OrderRecorder {
    fn name(&self) -> &str {
        self.label
    }

    async fn before_model(
        &self,
        _run_ctx: &mut RunContext,
        _unit_state: &(),
        _model_request: &mut ModelRequest,
    ) -> Result<()> {
        self.record(format!("{}:before", self.label));
        Ok(())
    }

    async fn after_model(
        &self,
        _run_ctx: &mut RunContext,
        _unit_state: &(),
        _model_response: &mut ModelResponse,
    ) -> Result<()> {
        self.record(format!("{}:after", self.label));
        Ok(())
    }
}
/// Test middleware whose `before_model` hook always fails with
/// `TinyAgentsError::Middleware("boom")`.
struct FailingMiddleware;

#[async_trait]
impl Middleware<()> for FailingMiddleware {
    fn name(&self) -> &str {
        "failing"
    }

    async fn before_model(
        &self,
        _run_ctx: &mut RunContext,
        _unit_state: &(),
        _model_request: &mut ModelRequest,
    ) -> Result<()> {
        let failure = TinyAgentsError::Middleware(String::from("boom"));
        Err(failure)
    }
}
#[tokio::test]
async fn before_runs_forward_after_runs_reverse() {
let log = Arc::new(Mutex::new(Vec::new()));
let mut stack: MiddlewareStack<()> = MiddlewareStack::new();
stack.push(Arc::new(OrderRecorder {
label: "a",
log: log.clone(),
}));
stack.push(Arc::new(OrderRecorder {
label: "b",
log: log.clone(),
}));
let mut c = ctx();
let mut request = ModelRequest::default();
let mut response = response_with_usage(Usage::new(1, 1));
stack
.run_before_model(&mut c, &(), &mut request)
.await
.unwrap();
stack
.run_after_model(&mut c, &(), &mut response)
.await
.unwrap();
let order = log.lock().unwrap().clone();
assert_eq!(
order,
vec!["a:before", "b:before", "b:after", "a:after"],
"before runs in registration order, after runs reversed"
);
}
#[tokio::test]
async fn error_short_circuits_and_invokes_on_error() {
let logging = Arc::new(LoggingMiddleware::new());
let mut stack: MiddlewareStack<()> = MiddlewareStack::new();
stack.push(logging.clone());
stack.push(Arc::new(FailingMiddleware));
let never = Arc::new(LoggingMiddleware::with_label("never"));
stack.push(never.clone());
let mut c = ctx();
let mut request = ModelRequest::default();
let result = stack.run_before_model(&mut c, &(), &mut request).await;
assert!(matches!(result, Err(TinyAgentsError::Middleware(_))));
assert_eq!(logging.counts().on_error, 1);
assert_eq!(logging.counts().before_model, 1);
assert_eq!(never.counts().before_model, 0);
}
/// Running a single middleware emits a `MiddlewareStarted` event and then a
/// `MiddlewareCompleted` event, both carrying the middleware's name.
#[tokio::test]
async fn emits_started_and_completed_events() {
    let recorder = Arc::new(RecordingListener::new());
    let mut context = ctx();
    context.events.subscribe(recorder.clone());

    let mut stack = MiddlewareStack::<()>::new();
    stack.push(Arc::new(LoggingMiddleware::new()));

    let mut request = ModelRequest::default();
    stack
        .run_before_model(&mut context, &(), &mut request)
        .await
        .unwrap();

    let expected = vec![
        AgentEvent::MiddlewareStarted {
            name: "logging".to_string(),
        },
        AgentEvent::MiddlewareCompleted {
            name: "logging".to_string(),
        },
    ];
    let observed: Vec<AgentEvent> = recorder
        .events()
        .into_iter()
        .map(|record| record.event)
        .collect();
    assert_eq!(observed, expected);
}
#[tokio::test]
async fn message_trim_middleware_shrinks_request() {
let mw = MessageTrimMiddleware::new(TrimStrategy::KeepLast(1));
let mut stack: MiddlewareStack<()> = MiddlewareStack::new();
stack.push(Arc::new(mw));
let mut request = ModelRequest {
messages: vec![user("one"), user("two"), user("three")],
..Default::default()
};
let mut c = ctx();
stack
.run_before_model(&mut c, &(), &mut request)
.await
.unwrap();
assert_eq!(request.messages.len(), 1);
assert_eq!(request.messages[0], user("three"));
}
#[tokio::test]
async fn context_compression_is_noop_below_window_threshold() {
let policy = SummarizationPolicy::default()
.with_context_window(1000)
.with_threshold_fraction(0.9);
let mw = Arc::new(ContextCompressionMiddleware::new(policy));
let mut stack: MiddlewareStack<()> = MiddlewareStack::new();
stack.push(mw.clone());
let recorder = Arc::new(RecordingListener::new());
let mut c = ctx();
c.events.subscribe(recorder.clone());
let before = vec![user("one"), user("two"), user("three")];
let mut request = ModelRequest {
messages: before.clone(),
..Default::default()
};
stack
.run_before_model(&mut c, &(), &mut request)
.await
.unwrap();
assert_eq!(request.messages, before);
assert!(mw.records().is_empty());
let events: Vec<AgentEvent> = recorder.events().into_iter().map(|r| r.event).collect();
assert!(
!events
.iter()
.any(|e| matches!(e, AgentEvent::Compressed { .. })),
);
}
#[tokio::test]
async fn context_compression_compresses_at_or_above_threshold() {
let policy = SummarizationPolicy {
keep_last: 1,
..SummarizationPolicy::default()
}
.with_context_window(100)
.with_threshold_fraction(0.5);
let mw = Arc::new(ContextCompressionMiddleware::new(policy));
let mut stack: MiddlewareStack<()> = MiddlewareStack::new();
stack.push(mw.clone());
let recorder = Arc::new(RecordingListener::new());
let mut c = ctx();
c.events.subscribe(recorder.clone());
let big = "a".repeat(200);
let mut request = ModelRequest {
messages: vec![
user(&format!("{big}-1")),
user(&format!("{big}-2")),
user(&format!("{big}-3")),
],
..Default::default()
};
stack
.run_before_model(&mut c, &(), &mut request)
.await
.unwrap();
assert_eq!(request.messages.len(), 2);
assert!(matches!(request.messages[0], Message::System(_)));
assert_eq!(request.messages[1].text(), format!("{big}-3"));
let records = mw.records();
assert_eq!(records.len(), 1);
assert_eq!(records[0].provenance.source_ids, vec!["msg-0", "msg-1"]);
assert!(records[0].provenance.original_token_estimate > 0);
let compressed: Vec<(u64, u64)> = recorder
.events()
.into_iter()
.filter_map(|r| match r.event {
AgentEvent::Compressed {
from_tokens,
to_tokens,
} => Some((from_tokens, to_tokens)),
_ => None,
})
.collect();
assert_eq!(compressed.len(), 1);
assert!(compressed[0].0 > 0);
assert!(compressed[0].1 > 0);
}
#[tokio::test]
async fn context_compression_none_window_falls_back_to_trigger_tokens() {
let policy = SummarizationPolicy {
trigger_tokens: 2,
keep_last: 1,
..SummarizationPolicy::default()
};
assert_eq!(policy.context_window, None);
let mw = Arc::new(ContextCompressionMiddleware::new(policy));
let mut stack: MiddlewareStack<()> = MiddlewareStack::new();
stack.push(mw.clone());
let mut c = ctx();
let mut request = ModelRequest {
messages: vec![user("aaaaaaaaaaaaaaaa"), user("bbbbbbbbbbbbbbbb")],
..Default::default()
};
stack
.run_before_model(&mut c, &(), &mut request)
.await
.unwrap();
assert_eq!(request.messages.len(), 2);
assert!(matches!(request.messages[0], Message::System(_)));
assert_eq!(request.messages[1].text(), "bbbbbbbbbbbbbbbb");
assert_eq!(mw.records().len(), 1);
}
#[tokio::test]
async fn usage_accounting_accumulates_across_calls() {
let mw = Arc::new(UsageAccountingMiddleware::new());
let mut stack: MiddlewareStack<()> = MiddlewareStack::new();
stack.push(mw.clone());
let mut c = ctx();
let mut r1 = response_with_usage(Usage::new(10, 5));
let mut r2 = response_with_usage(Usage::new(3, 2));
stack.run_after_model(&mut c, &(), &mut r1).await.unwrap();
stack.run_after_model(&mut c, &(), &mut r2).await.unwrap();
let totals = mw.totals();
assert_eq!(totals.calls, 2);
assert_eq!(totals.usage.input_tokens, 13);
assert_eq!(totals.usage.output_tokens, 7);
assert_eq!(totals.usage.total_tokens, 20);
}
#[tokio::test]
async fn prompt_cache_guard_detects_prefix_change() {
let mw = Arc::new(PromptCacheGuardMiddleware::new());
let mut stack: MiddlewareStack<()> = MiddlewareStack::new();
stack.push(mw.clone());
let mut c = ctx();
let mut req1 = ModelRequest {
cache_segments: vec![segment("sys", SegmentRole::System, true)],
..Default::default()
};
stack
.run_before_model(&mut c, &(), &mut req1)
.await
.unwrap();
assert!(mw.layout_events().is_empty(), "no prior layout to compare");
let mut req2 = ModelRequest {
cache_segments: vec![segment("sys2", SegmentRole::System, true)],
..Default::default()
};
stack
.run_before_model(&mut c, &(), &mut req2)
.await
.unwrap();
let events = mw.layout_events();
assert_eq!(events.len(), 1);
assert!(events[0].changed_prefix);
assert_eq!(events[0].segment_ids_before, vec!["sys".to_string()]);
assert_eq!(events[0].segment_ids_after, vec!["sys2".to_string()]);
}
/// `AgentRun::text` is `None` until `final_response` is set, and then returns that
/// response's text (`"ok"` for `response_with_usage`).
///
/// Nothing here is awaited, so this is a plain synchronous `#[test]`. It does not need
/// to pay for a Tokio runtime.
#[test]
fn agent_run_text_reflects_final_response() {
    let mut run = AgentRun::new();
    assert_eq!(run.text(), None);

    run.final_response = Some(response_with_usage(Usage::new(1, 1)));
    assert_eq!(run.text(), Some("ok".to_string()));
}