llm_stack/test_helpers.rs
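
//! Test helpers shared by the crate's unit tests: sample responses, messages,
//! usage stats, stream collectors, and a mock provider factory.
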
use std::collections::{HashMap, HashSet};

use futures::StreamExt;

use crate::chat::{ChatMessage, ChatResponse, ContentBlock, StopReason, ToolCall};
use crate::error::LlmError;
use crate::mock::MockProvider;
use crate::provider::{Capability, ProviderMetadata};
use crate::stream::{ChatStream, StreamEvent};
use crate::usage::Usage;

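/// Builds a text-only [`ChatResponse`] that ends the turn normally, using [`sample_usage`].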
pub fn sample_response(text: &str) -> ChatResponse {
    ChatResponse {
        content: vec![ContentBlock::Text(text.into())],
        usage: sample_usage(),
        stop_reason: StopReason::EndTurn,
        model: "test-model".into(),
        metadata: HashMap::new(),
    }
}

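/// Builds a [`ChatResponse`] whose content is exactly the given tool calls, stopping on `ToolUse`.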
pub fn sample_tool_response(calls: Vec<ToolCall>) -> ChatResponse {
    ChatResponse {
        content: calls.into_iter().map(ContentBlock::ToolCall).collect(),
        usage: sample_usage(),
        stop_reason: StopReason::ToolUse,
        model: "test-model".into(),
        metadata: HashMap::new(),
    }
}

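/// Builds a [`ChatResponse`] with a leading text block followed by the given tool calls.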
pub fn sample_tool_response_with_text(text: &str, calls: Vec<ToolCall>) -> ChatResponse {
    let mut content: Vec<ContentBlock> = vec![ContentBlock::Text(text.into())];
    content.extend(calls.into_iter().map(ContentBlock::ToolCall));
    ChatResponse {
        content,
        usage: sample_usage(),
        stop_reason: StopReason::ToolUse,
        model: "test-model".into(),
        metadata: HashMap::new(),
    }
}

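/// Returns a fixed [`Usage`] (100 input tokens, 50 output tokens, no reasoning or cache counts).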
pub fn sample_usage() -> Usage {
    Usage {
        input_tokens: 100,
        output_tokens: 50,
        reasoning_tokens: None,
        cache_read_tokens: None,
        cache_write_tokens: None,
    }
}

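/// Shorthand for [`ChatMessage::user`].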
pub fn user_msg(text: &str) -> ChatMessage {
    ChatMessage::user(text)
}

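/// Shorthand for [`ChatMessage::assistant`].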
pub fn assistant_msg(text: &str) -> ChatMessage {
    ChatMessage::assistant(text)
}

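/// Shorthand for [`ChatMessage::system`].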
pub fn system_msg(text: &str) -> ChatMessage {
    ChatMessage::system(text)
}

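/// Shorthand for [`ChatMessage::tool_result`].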
pub fn tool_result_msg(tool_call_id: &str, content: &str) -> ChatMessage {
    ChatMessage::tool_result(tool_call_id, content)
}

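/// Drains a [`ChatStream`] into a `Vec`, keeping each event as a `Result` so error cases can be asserted on.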
pub async fn collect_stream_results(stream: ChatStream) -> Vec<Result<StreamEvent, LlmError>> {
    stream.collect::<Vec<_>>().await
}

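/// Drains a [`ChatStream`] into a `Vec` of events, panicking if any item is an `Err`.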
pub async fn collect_stream(stream: ChatStream) -> Vec<StreamEvent> {
    stream
        .collect::<Vec<_>>()
        .await
        .into_iter()
        .map(|r| r.expect("stream event should be Ok"))
        .collect()
}

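/// Builds a [`MockProvider`] with the given provider name and model, a 128k context window, and the `Tools` capability.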
pub fn mock_for(provider_name: &str, model: &str) -> MockProvider {
    MockProvider::new(ProviderMetadata {
        name: provider_name.to_owned().into(),
        model: model.into(),
        context_window: 128_000,
        capabilities: HashSet::from([Capability::Tools]),
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::ChatRole;

    #[test]
    fn test_sample_response_is_valid() {
        let r = sample_response("hello");
        assert_eq!(r.content, vec![ContentBlock::Text("hello".into())]);
        assert_eq!(r.stop_reason, StopReason::EndTurn);
    }

    #[test]
    fn test_sample_tool_response() {
        let calls = vec![ToolCall {
            id: "tc_1".into(),
            name: "search".into(),
            arguments: serde_json::json!({"q": "rust"}),
        }];
        let r = sample_tool_response(calls);
        assert_eq!(r.stop_reason, StopReason::ToolUse);
        assert!(!r.content.is_empty());
    }

    #[test]
    fn test_sample_usage_fields() {
        let u = sample_usage();
        assert!(u.input_tokens > 0);
        assert!(u.output_tokens > 0);
    }

    #[test]
    fn test_helper_messages() {
        assert_eq!(user_msg("hi").role, ChatRole::User);
        assert_eq!(assistant_msg("hello").role, ChatRole::Assistant);
        assert_eq!(system_msg("be nice").role, ChatRole::System);
        assert_eq!(tool_result_msg("tc_1", "42").role, ChatRole::Tool);
    }

    #[tokio::test]
    async fn test_collect_stream_happy() {
        let events = vec![
            Ok(StreamEvent::TextDelta("hello".into())),
            Ok(StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            }),
        ];
        let stream: ChatStream = Box::pin(futures::stream::iter(events));
        let collected = collect_stream(stream).await;
        assert_eq!(collected.len(), 2);
    }

    #[tokio::test]
    async fn test_collect_stream_empty() {
        let events: Vec<Result<StreamEvent, LlmError>> = Vec::new();
        let stream: ChatStream = Box::pin(futures::stream::iter(events));
        let collected = collect_stream(stream).await;
        assert!(collected.is_empty());
    }

    #[tokio::test]
    async fn test_collect_stream_results_with_errors() {
        let events = vec![
            Ok(StreamEvent::TextDelta("hello".into())),
            Err(LlmError::Http {
                status: Some(http::StatusCode::INTERNAL_SERVER_ERROR),
                message: "server error".into(),
                retryable: true,
            }),
        ];
        let stream: ChatStream = Box::pin(futures::stream::iter(events));
        let collected = collect_stream_results(stream).await;
        assert_eq!(collected.len(), 2);
        assert!(collected[0].is_ok());
        assert!(collected[1].is_err());
    }

    #[test]
    fn test_mock_for_helper() {
        let mock = mock_for("anthropic", "claude-sonnet-4");
        let meta = crate::provider::Provider::metadata(&mock);
        assert_eq!(meta.name, "anthropic");
        assert_eq!(meta.model, "claude-sonnet-4");
    }

    #[test]
    fn test_mock_for_custom_name() {
        let mock = mock_for("my-custom-provider", "gpt-4");
        let meta = crate::provider::Provider::metadata(&mock);
        assert_eq!(meta.name, "my-custom-provider");
    }
}