llm_stack/
test_helpers.rs1use std::collections::{HashMap, HashSet};
10
11use futures::StreamExt;
12
13use crate::chat::{ChatMessage, ChatResponse, ContentBlock, StopReason, ToolCall};
14use crate::error::LlmError;
15use crate::mock::MockProvider;
16use crate::provider::{Capability, ProviderMetadata};
17use crate::stream::{ChatStream, StreamEvent};
18use crate::usage::Usage;
19
20pub fn sample_response(text: &str) -> ChatResponse {
22 ChatResponse {
23 content: vec![ContentBlock::Text(text.into())],
24 usage: sample_usage(),
25 stop_reason: StopReason::EndTurn,
26 model: "test-model".into(),
27 metadata: HashMap::new(),
28 }
29}
30
31pub fn sample_tool_response(calls: Vec<ToolCall>) -> ChatResponse {
33 ChatResponse {
34 content: calls.into_iter().map(ContentBlock::ToolCall).collect(),
35 usage: sample_usage(),
36 stop_reason: StopReason::ToolUse,
37 model: "test-model".into(),
38 metadata: HashMap::new(),
39 }
40}
41
42pub fn sample_usage() -> Usage {
44 Usage {
45 input_tokens: 100,
46 output_tokens: 50,
47 reasoning_tokens: None,
48 cache_read_tokens: None,
49 cache_write_tokens: None,
50 }
51}
52
53pub fn user_msg(text: &str) -> ChatMessage {
55 ChatMessage::user(text)
56}
57
58pub fn assistant_msg(text: &str) -> ChatMessage {
60 ChatMessage::assistant(text)
61}
62
63pub fn system_msg(text: &str) -> ChatMessage {
65 ChatMessage::system(text)
66}
67
68pub fn tool_result_msg(tool_call_id: &str, content: &str) -> ChatMessage {
70 ChatMessage::tool_result(tool_call_id, content)
71}
72
73pub async fn collect_stream_results(stream: ChatStream) -> Vec<Result<StreamEvent, LlmError>> {
75 stream.collect::<Vec<_>>().await
76}
77
78pub async fn collect_stream(stream: ChatStream) -> Vec<StreamEvent> {
81 stream
82 .collect::<Vec<_>>()
83 .await
84 .into_iter()
85 .map(|r| r.expect("stream event should be Ok"))
86 .collect()
87}
88
89pub fn mock_for(provider_name: &str, model: &str) -> MockProvider {
91 MockProvider::new(ProviderMetadata {
92 name: provider_name.to_owned().into(),
93 model: model.into(),
94 context_window: 128_000,
95 capabilities: HashSet::from([Capability::Tools]),
96 })
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::ChatRole;

    /// The tool call shared by the tool-response tests.
    ///
    /// It is built by a function so the expected value can be constructed
    /// independently of the input, which avoids needing `ToolCall: Clone`.
    fn search_call() -> ToolCall {
        ToolCall {
            id: "tc_1".into(),
            name: "search".into(),
            arguments: serde_json::json!({"q": "rust"}),
        }
    }

    /// A retryable HTTP 500 error used to inject failures into test streams.
    fn server_error() -> LlmError {
        LlmError::Http {
            status: Some(http::StatusCode::INTERNAL_SERVER_ERROR),
            message: "server error".into(),
            retryable: true,
        }
    }

    /// Wraps `events` in a `ChatStream` that yields them in order, then ends.
    fn stream_of(events: Vec<Result<StreamEvent, LlmError>>) -> ChatStream {
        Box::pin(futures::stream::iter(events))
    }

    #[test]
    fn test_sample_response_is_valid() {
        let r = sample_response("hello");
        assert_eq!(r.content, vec![ContentBlock::Text("hello".into())]);
        assert_eq!(r.stop_reason, StopReason::EndTurn);
        assert!(r.metadata.is_empty());
    }

    #[test]
    fn test_sample_tool_response() {
        let r = sample_tool_response(vec![search_call()]);
        assert_eq!(r.stop_reason, StopReason::ToolUse);
        // Pin the exact content (one ToolCall block carrying the call
        // unchanged), not merely that something was produced.
        assert_eq!(r.content, vec![ContentBlock::ToolCall(search_call())]);
    }

    #[test]
    fn test_sample_tool_response_without_calls() {
        let r = sample_tool_response(Vec::new());
        assert_eq!(r.stop_reason, StopReason::ToolUse);
        assert!(r.content.is_empty());
    }

    #[test]
    fn test_sample_usage_fields() {
        let u = sample_usage();
        assert_eq!(u.input_tokens, 100);
        assert_eq!(u.output_tokens, 50);
        assert!(u.reasoning_tokens.is_none());
        assert!(u.cache_read_tokens.is_none());
        assert!(u.cache_write_tokens.is_none());
    }

    #[test]
    fn test_helper_messages() {
        assert_eq!(user_msg("hi").role, ChatRole::User);
        assert_eq!(assistant_msg("hello").role, ChatRole::Assistant);
        assert_eq!(system_msg("be nice").role, ChatRole::System);
        assert_eq!(tool_result_msg("tc_1", "42").role, ChatRole::Tool);
    }

    #[tokio::test]
    async fn test_collect_stream_happy() {
        let stream = stream_of(vec![
            Ok(StreamEvent::TextDelta("hello".into())),
            Ok(StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            }),
        ]);
        let collected = collect_stream(stream).await;
        assert_eq!(collected.len(), 2);
        // Check the events themselves, in order, not just how many there are.
        assert!(matches!(&collected[0], StreamEvent::TextDelta(t) if t == "hello"));
        assert!(matches!(
            &collected[1],
            StreamEvent::Done {
                stop_reason: StopReason::EndTurn
            }
        ));
    }

    #[tokio::test]
    async fn test_collect_stream_empty() {
        let collected = collect_stream(stream_of(Vec::new())).await;
        assert!(collected.is_empty());
    }

    #[tokio::test]
    #[should_panic(expected = "stream event should be Ok")]
    async fn test_collect_stream_panics_on_error() {
        let stream = stream_of(vec![
            Ok(StreamEvent::TextDelta("hello".into())),
            Err(server_error()),
        ]);
        collect_stream(stream).await;
    }

    #[tokio::test]
    async fn test_collect_stream_results_with_errors() {
        let stream = stream_of(vec![
            Ok(StreamEvent::TextDelta("hello".into())),
            Err(server_error()),
        ]);
        let collected = collect_stream_results(stream).await;
        assert_eq!(collected.len(), 2);
        assert!(collected[0].is_ok());
        assert!(matches!(
            collected[1],
            Err(LlmError::Http {
                retryable: true,
                ..
            })
        ));
    }

    #[test]
    fn test_mock_for_helper() {
        let mock = mock_for("anthropic", "claude-sonnet-4");
        let meta = crate::provider::Provider::metadata(&mock);
        assert_eq!(meta.name, "anthropic");
        assert_eq!(meta.model, "claude-sonnet-4");
        assert_eq!(meta.context_window, 128_000);
        assert!(meta.capabilities.contains(&Capability::Tools));
    }

    #[test]
    fn test_mock_for_custom_name() {
        let mock = mock_for("my-custom-provider", "gpt-4");
        let meta = crate::provider::Provider::metadata(&mock);
        assert_eq!(meta.name, "my-custom-provider");
        assert_eq!(meta.model, "gpt-4");
    }
}