Skip to main content

llm_stack/
test_helpers.rs

1//! Pre-built helpers for testing code that uses `llm-core` types.
2//!
3//! Available when the `test-utils` feature is enabled, allowing
4//! downstream crates to reuse these utilities in their own test
5//! suites. Also compiled during `#[cfg(test)]` for this crate's
6//! own tests. Provides sample responses, message shorthands, stream
7//! collectors, and a quick [`MockProvider`] factory.
8
9use std::collections::{HashMap, HashSet};
10
11use futures::StreamExt;
12
13use crate::chat::{ChatMessage, ChatResponse, ContentBlock, StopReason, ToolCall};
14use crate::error::LlmError;
15use crate::mock::MockProvider;
16use crate::provider::{Capability, ProviderMetadata};
17use crate::stream::{ChatStream, StreamEvent};
18use crate::usage::Usage;
19
20/// Builds a [`ChatResponse`] with a single text block and default usage.
21pub fn sample_response(text: &str) -> ChatResponse {
22    ChatResponse {
23        content: vec![ContentBlock::Text(text.into())],
24        usage: sample_usage(),
25        stop_reason: StopReason::EndTurn,
26        model: "test-model".into(),
27        metadata: HashMap::new(),
28    }
29}
30
31/// Builds a [`ChatResponse`] containing the given tool calls.
32pub fn sample_tool_response(calls: Vec<ToolCall>) -> ChatResponse {
33    ChatResponse {
34        content: calls.into_iter().map(ContentBlock::ToolCall).collect(),
35        usage: sample_usage(),
36        stop_reason: StopReason::ToolUse,
37        model: "test-model".into(),
38        metadata: HashMap::new(),
39    }
40}
41
42/// Builds a [`ChatResponse`] containing tool calls AND a text block.
43///
44/// Simulates an LLM that says something ("I'll help with that") alongside
45/// requesting tool calls.
46pub fn sample_tool_response_with_text(text: &str, calls: Vec<ToolCall>) -> ChatResponse {
47    let mut content: Vec<ContentBlock> = vec![ContentBlock::Text(text.into())];
48    content.extend(calls.into_iter().map(ContentBlock::ToolCall));
49    ChatResponse {
50        content,
51        usage: sample_usage(),
52        stop_reason: StopReason::ToolUse,
53        model: "test-model".into(),
54        metadata: HashMap::new(),
55    }
56}
57
58/// Returns a [`Usage`] with 100 input / 50 output tokens.
59pub fn sample_usage() -> Usage {
60    Usage {
61        input_tokens: 100,
62        output_tokens: 50,
63        reasoning_tokens: None,
64        cache_read_tokens: None,
65        cache_write_tokens: None,
66    }
67}
68
69/// Shorthand for [`ChatMessage::user`].
70pub fn user_msg(text: &str) -> ChatMessage {
71    ChatMessage::user(text)
72}
73
74/// Shorthand for [`ChatMessage::assistant`].
75pub fn assistant_msg(text: &str) -> ChatMessage {
76    ChatMessage::assistant(text)
77}
78
79/// Shorthand for [`ChatMessage::system`].
80pub fn system_msg(text: &str) -> ChatMessage {
81    ChatMessage::system(text)
82}
83
84/// Shorthand for [`ChatMessage::tool_result`].
85pub fn tool_result_msg(tool_call_id: &str, content: &str) -> ChatMessage {
86    ChatMessage::tool_result(tool_call_id, content)
87}
88
89/// Collect stream events, returning results including errors.
90pub async fn collect_stream_results(stream: ChatStream) -> Vec<Result<StreamEvent, LlmError>> {
91    stream.collect::<Vec<_>>().await
92}
93
94/// Collect stream events, panicking on any error.
95/// Use `collect_stream_results` when testing error scenarios.
96pub async fn collect_stream(stream: ChatStream) -> Vec<StreamEvent> {
97    stream
98        .collect::<Vec<_>>()
99        .await
100        .into_iter()
101        .map(|r| r.expect("stream event should be Ok"))
102        .collect()
103}
104
105/// Creates a [`MockProvider`] with the given name, model, and [`Capability::Tools`].
106pub fn mock_for(provider_name: &str, model: &str) -> MockProvider {
107    MockProvider::new(ProviderMetadata {
108        name: provider_name.to_owned().into(),
109        model: model.into(),
110        context_window: 128_000,
111        capabilities: HashSet::from([Capability::Tools]),
112    })
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::ChatRole;

    /// Builds a `search` tool call with a fixed `{"q": "rust"}` payload.
    ///
    /// `id` is `&'static str` to mirror the string literals used previously,
    /// so `.into()` converts exactly as before.
    fn search_call(id: &'static str) -> ToolCall {
        ToolCall {
            id: id.into(),
            name: "search".into(),
            arguments: serde_json::json!({"q": "rust"}),
        }
    }

    #[test]
    fn test_sample_response_is_valid() {
        let r = sample_response("hello");
        assert_eq!(r.content, vec![ContentBlock::Text("hello".into())]);
        assert_eq!(r.stop_reason, StopReason::EndTurn);
    }

    #[test]
    fn test_sample_tool_response() {
        let r = sample_tool_response(vec![search_call("tc_1")]);
        assert_eq!(r.stop_reason, StopReason::ToolUse);
        // Pin the exact content rather than just "non-empty".
        assert_eq!(r.content, vec![ContentBlock::ToolCall(search_call("tc_1"))]);
    }

    #[test]
    fn test_sample_tool_response_with_text() {
        let r = sample_tool_response_with_text(
            "I'll help with that",
            vec![search_call("tc_1"), search_call("tc_2")],
        );
        assert_eq!(r.stop_reason, StopReason::ToolUse);
        // The text block must come first, followed by the calls in order.
        assert_eq!(
            r.content,
            vec![
                ContentBlock::Text("I'll help with that".into()),
                ContentBlock::ToolCall(search_call("tc_1")),
                ContentBlock::ToolCall(search_call("tc_2")),
            ]
        );
    }

    #[test]
    fn test_sample_usage_fields() {
        let u = sample_usage();
        assert_eq!(u.input_tokens, 100);
        assert_eq!(u.output_tokens, 50);
        assert!(u.reasoning_tokens.is_none());
        assert!(u.cache_read_tokens.is_none());
        assert!(u.cache_write_tokens.is_none());
    }

    #[test]
    fn test_helper_messages() {
        assert_eq!(user_msg("hi").role, ChatRole::User);
        assert_eq!(assistant_msg("hello").role, ChatRole::Assistant);
        assert_eq!(system_msg("be nice").role, ChatRole::System);
        assert_eq!(tool_result_msg("tc_1", "42").role, ChatRole::Tool);
    }

    #[tokio::test]
    async fn test_collect_stream_happy() {
        let events = vec![
            Ok(StreamEvent::TextDelta("hello".into())),
            Ok(StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            }),
        ];
        let stream: ChatStream = Box::pin(futures::stream::iter(events));
        let collected = collect_stream(stream).await;
        assert_eq!(collected.len(), 2);
    }

    #[tokio::test]
    async fn test_collect_stream_empty() {
        let stream: ChatStream = Box::pin(futures::stream::iter(Vec::<
            Result<StreamEvent, LlmError>,
        >::new()));
        let collected = collect_stream(stream).await;
        assert!(collected.is_empty());
    }

    #[tokio::test]
    #[should_panic(expected = "stream event should be Ok")]
    async fn test_collect_stream_panics_on_error() {
        let events = vec![
            Ok(StreamEvent::TextDelta("hello".into())),
            Err(LlmError::Http {
                status: Some(http::StatusCode::INTERNAL_SERVER_ERROR),
                message: "server error".into(),
                retryable: true,
            }),
        ];
        let stream: ChatStream = Box::pin(futures::stream::iter(events));
        collect_stream(stream).await;
    }

    #[tokio::test]
    async fn test_collect_stream_results_with_errors() {
        let events = vec![
            Ok(StreamEvent::TextDelta("hello".into())),
            Err(LlmError::Http {
                status: Some(http::StatusCode::INTERNAL_SERVER_ERROR),
                message: "server error".into(),
                retryable: true,
            }),
        ];
        let stream: ChatStream = Box::pin(futures::stream::iter(events));
        let collected = collect_stream_results(stream).await;
        assert_eq!(collected.len(), 2);
        assert!(collected[0].is_ok());
        assert!(collected[1].is_err());
    }

    #[test]
    fn test_mock_for_helper() {
        let mock = mock_for("anthropic", "claude-sonnet-4");
        let meta = crate::provider::Provider::metadata(&mock);
        assert_eq!(meta.name, "anthropic");
        assert_eq!(meta.model, "claude-sonnet-4");
        assert_eq!(meta.context_window, 128_000);
        assert!(meta.capabilities.contains(&Capability::Tools));
    }

    #[test]
    fn test_mock_for_custom_name() {
        let mock = mock_for("my-custom-provider", "gpt-4");
        let meta = crate::provider::Provider::metadata(&mock);
        assert_eq!(meta.name, "my-custom-provider");
    }
}
207}