Skip to main content

llm_stack/
test_helpers.rs

//! Pre-built helpers for testing code that uses `llm-core` types.
//!
//! Available when the `test-utils` feature is enabled, allowing
//! downstream crates to reuse these utilities in their own test
//! suites. Also compiled during `#[cfg(test)]` for this crate's
//! own tests. Provides sample responses, message shorthands, stream
//! collectors, and a quick [`MockProvider`] factory.
8
9use std::collections::{HashMap, HashSet};
10
11use futures::StreamExt;
12
13use crate::chat::{ChatMessage, ChatResponse, ContentBlock, StopReason, ToolCall};
14use crate::error::LlmError;
15use crate::mock::MockProvider;
16use crate::provider::{Capability, ProviderMetadata};
17use crate::stream::{ChatStream, StreamEvent};
18use crate::usage::Usage;
19
20/// Builds a [`ChatResponse`] with a single text block and default usage.
21pub fn sample_response(text: &str) -> ChatResponse {
22    ChatResponse {
23        content: vec![ContentBlock::Text(text.into())],
24        usage: sample_usage(),
25        stop_reason: StopReason::EndTurn,
26        model: "test-model".into(),
27        metadata: HashMap::new(),
28    }
29}
30
31/// Builds a [`ChatResponse`] containing the given tool calls.
32pub fn sample_tool_response(calls: Vec<ToolCall>) -> ChatResponse {
33    ChatResponse {
34        content: calls.into_iter().map(ContentBlock::ToolCall).collect(),
35        usage: sample_usage(),
36        stop_reason: StopReason::ToolUse,
37        model: "test-model".into(),
38        metadata: HashMap::new(),
39    }
40}
41
42/// Returns a [`Usage`] with 100 input / 50 output tokens.
43pub fn sample_usage() -> Usage {
44    Usage {
45        input_tokens: 100,
46        output_tokens: 50,
47        reasoning_tokens: None,
48        cache_read_tokens: None,
49        cache_write_tokens: None,
50    }
51}
52
53/// Shorthand for [`ChatMessage::user`].
54pub fn user_msg(text: &str) -> ChatMessage {
55    ChatMessage::user(text)
56}
57
58/// Shorthand for [`ChatMessage::assistant`].
59pub fn assistant_msg(text: &str) -> ChatMessage {
60    ChatMessage::assistant(text)
61}
62
63/// Shorthand for [`ChatMessage::system`].
64pub fn system_msg(text: &str) -> ChatMessage {
65    ChatMessage::system(text)
66}
67
68/// Shorthand for [`ChatMessage::tool_result`].
69pub fn tool_result_msg(tool_call_id: &str, content: &str) -> ChatMessage {
70    ChatMessage::tool_result(tool_call_id, content)
71}
72
73/// Collect stream events, returning results including errors.
74pub async fn collect_stream_results(stream: ChatStream) -> Vec<Result<StreamEvent, LlmError>> {
75    stream.collect::<Vec<_>>().await
76}
77
78/// Collect stream events, panicking on any error.
79/// Use `collect_stream_results` when testing error scenarios.
80pub async fn collect_stream(stream: ChatStream) -> Vec<StreamEvent> {
81    stream
82        .collect::<Vec<_>>()
83        .await
84        .into_iter()
85        .map(|r| r.expect("stream event should be Ok"))
86        .collect()
87}
88
89/// Creates a [`MockProvider`] with the given name, model, and [`Capability::Tools`].
90pub fn mock_for(provider_name: &str, model: &str) -> MockProvider {
91    MockProvider::new(ProviderMetadata {
92        name: provider_name.to_owned().into(),
93        model: model.into(),
94        context_window: 128_000,
95        capabilities: HashSet::from([Capability::Tools]),
96    })
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::ChatRole;

    /// `sample_response` wraps the text in one text block and ends the turn.
    #[test]
    fn test_sample_response_is_valid() {
        let response = sample_response("hello");
        assert_eq!(response.content, vec![ContentBlock::Text("hello".into())]);
        assert_eq!(response.stop_reason, StopReason::EndTurn);
    }

    /// `sample_tool_response` reports a tool-use stop and non-empty content.
    #[test]
    fn test_sample_tool_response() {
        let search_call = ToolCall {
            id: "tc_1".into(),
            name: "search".into(),
            arguments: serde_json::json!({"q": "rust"}),
        };
        let response = sample_tool_response(vec![search_call]);
        assert_eq!(response.stop_reason, StopReason::ToolUse);
        assert!(!response.content.is_empty());
    }

    /// `sample_usage` reports non-zero input and output token counts.
    #[test]
    fn test_sample_usage_fields() {
        let usage = sample_usage();
        assert!(usage.input_tokens > 0);
        assert!(usage.output_tokens > 0);
    }

    /// Each message shorthand produces a message with the matching role.
    #[test]
    fn test_helper_messages() {
        let user = user_msg("hi");
        assert_eq!(user.role, ChatRole::User);

        let assistant = assistant_msg("hello");
        assert_eq!(assistant.role, ChatRole::Assistant);

        let system = system_msg("be nice");
        assert_eq!(system.role, ChatRole::System);

        let tool = tool_result_msg("tc_1", "42");
        assert_eq!(tool.role, ChatRole::Tool);
    }

    /// An all-`Ok` stream is returned in full by `collect_stream`.
    #[tokio::test]
    async fn test_collect_stream_happy() {
        let items = vec![
            Ok(StreamEvent::TextDelta("hello".into())),
            Ok(StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            }),
        ];
        let stream: ChatStream = Box::pin(futures::stream::iter(items));
        let events = collect_stream(stream).await;
        assert_eq!(events.len(), 2);
    }

    /// An empty stream collects to an empty vector.
    #[tokio::test]
    async fn test_collect_stream_empty() {
        let no_items: Vec<Result<StreamEvent, LlmError>> = Vec::new();
        let stream: ChatStream = Box::pin(futures::stream::iter(no_items));
        let events = collect_stream(stream).await;
        assert!(events.is_empty());
    }

    /// `collect_stream_results` keeps error items instead of panicking.
    #[tokio::test]
    async fn test_collect_stream_results_with_errors() {
        let server_error = LlmError::Http {
            status: Some(http::StatusCode::INTERNAL_SERVER_ERROR),
            message: "server error".into(),
            retryable: true,
        };
        let items = vec![
            Ok(StreamEvent::TextDelta("hello".into())),
            Err(server_error),
        ];
        let stream: ChatStream = Box::pin(futures::stream::iter(items));
        let results = collect_stream_results(stream).await;
        assert_eq!(results.len(), 2);
        assert!(results[0].is_ok());
        assert!(results[1].is_err());
    }

    /// `mock_for` propagates both the provider name and the model.
    #[test]
    fn test_mock_for_helper() {
        let provider = mock_for("anthropic", "claude-sonnet-4");
        let metadata = crate::provider::Provider::metadata(&provider);
        assert_eq!(metadata.name, "anthropic");
        assert_eq!(metadata.model, "claude-sonnet-4");
    }

    /// `mock_for` accepts arbitrary provider names.
    #[test]
    fn test_mock_for_custom_name() {
        let provider = mock_for("my-custom-provider", "gpt-4");
        let metadata = crate::provider::Provider::metadata(&provider);
        assert_eq!(metadata.name, "my-custom-provider");
    }
}