use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use crate::{
CapabilitySupport, ChatCapability, ChatProvider, ChatRequest, ChatResponse, ChatStream,
ChatStreamExt, ContentBlock, Error, ExtraMap, FinishReason, ProviderIdentity, ResponseMetadata,
Result, SingleResponseStream, StreamEvent, Usage,
};
/// Scripted [`ChatProvider`] for tests that replays queued [`MockResponse`]s in FIFO order.
///
/// Each `chat` call records a clone of its request and consumes the next queued response;
/// `chat_stream` wraps the `chat` result in a [`SingleResponseStream`]. The response queue and
/// the request log live behind an `Arc<Mutex<_>>`, so clones of a provider share them: a test
/// can keep one handle for inspection while the code under test owns another. Capability
/// overrides and the provider name are copied per clone.
#[derive(Debug, Clone)]
pub struct MockProvider {
    state: Arc<Mutex<MockChatState>>,
    chat_capabilities: HashMap<ChatCapability, CapabilitySupport>,
    provider_name: &'static str,
}

impl MockProvider {
    /// Builds a provider by applying `configure` to a fresh [`MockProviderBuilder`].
    #[must_use]
    pub fn build(configure: impl FnOnce(MockProviderBuilder) -> MockProviderBuilder) -> Self {
        configure(MockProviderBuilder::new()).build()
    }

    /// Creates a provider with no queued responses; queue them later with
    /// [`push_response`](Self::push_response).
    #[must_use]
    pub fn empty() -> Self {
        Self::new(std::iter::empty::<MockResponse>())
    }

    /// Creates a provider that answers successive `chat` calls with `responses`, in order.
    ///
    /// The provider is named `"mock"` and reports every chat capability as
    /// [`CapabilitySupport::Unknown`] until overrides are added. `chat` panics with
    /// `"MockProvider: no more responses configured"` once the queue is exhausted.
    #[must_use]
    pub fn new<I, R>(responses: I) -> Self
    where
        I: IntoIterator<Item = R>,
        R: Into<MockResponse>,
    {
        Self {
            state: Arc::new(Mutex::new(MockChatState {
                responses: responses.into_iter().map(Into::into).collect(),
                requests: Vec::new(),
            })),
            chat_capabilities: HashMap::new(),
            provider_name: "mock",
        }
    }

    /// Creates a provider whose only queued response is `response`.
    #[must_use]
    pub fn with_response(response: ChatResponse) -> Self {
        Self::new([response])
    }

    /// Creates a provider whose only queued outcome is the failure `error`.
    #[must_use]
    pub fn with_error(error: Error) -> Self {
        Self::new([MockResponse::Error(error)])
    }

    /// Creates a provider whose only queued response is a single text block.
    #[must_use]
    pub fn with_text(text: impl Into<String>) -> Self {
        Self::new([MockResponse::text(text)])
    }

    /// Creates a provider whose only queued response is a single tool call with finish reason
    /// [`FinishReason::ToolCalls`].
    #[must_use]
    pub fn with_tool_call(
        id: impl Into<String>,
        name: impl Into<String>,
        arguments: serde_json::Value,
    ) -> Self {
        Self::new([MockResponse::tool_call(id, name, arguments)])
    }

    /// Creates a provider that replays the given successful responses in order.
    #[must_use]
    pub fn from_responses<I>(responses: I) -> Self
    where
        I: IntoIterator<Item = ChatResponse>,
    {
        Self::new(responses)
    }

    /// Creates a provider scripted for one tool round trip: first a tool call, then a text
    /// response containing `final_text`.
    #[must_use]
    pub fn tool_round_trip(
        id: impl Into<String>,
        name: impl Into<String>,
        arguments: serde_json::Value,
        final_text: impl Into<String>,
    ) -> Self {
        Self::from_tool_round_trips([MockToolRoundTrip::single_tool_call_text(
            id, name, arguments, final_text,
        )])
    }

    /// Creates a provider that replays each round trip's two responses back to back, in order.
    #[must_use]
    pub fn from_tool_round_trips<I>(round_trips: I) -> Self
    where
        I: IntoIterator<Item = MockToolRoundTrip>,
    {
        Self::new(
            round_trips
                .into_iter()
                .flat_map(MockToolRoundTrip::into_responses),
        )
    }

    /// Overrides the reported support level for one capability.
    #[must_use]
    pub fn with_chat_capability(
        mut self,
        capability: ChatCapability,
        support: CapabilitySupport,
    ) -> Self {
        self.chat_capabilities.insert(capability, support);
        self
    }

    /// Overrides the reported support level for several capabilities; later entries win.
    #[must_use]
    pub fn with_chat_capabilities<I>(mut self, capabilities: I) -> Self
    where
        I: IntoIterator<Item = (ChatCapability, CapabilitySupport)>,
    {
        self.chat_capabilities.extend(capabilities);
        self
    }

    /// Marks every listed capability as [`CapabilitySupport::Supported`].
    #[must_use]
    pub fn with_supported_chat_capabilities<I>(mut self, capabilities: I) -> Self
    where
        I: IntoIterator<Item = ChatCapability>,
    {
        self.chat_capabilities.extend(
            capabilities
                .into_iter()
                .map(|capability| (capability, CapabilitySupport::Supported)),
        );
        self
    }

    /// Replaces the name reported through [`ProviderIdentity::provider_name`].
    #[must_use]
    pub fn with_provider_name(mut self, provider_name: &'static str) -> Self {
        self.provider_name = provider_name;
        self
    }

    /// Appends a response to the back of the shared queue.
    pub fn push_response<R>(&self, response: R)
    where
        R: Into<MockResponse>,
    {
        self.lock_state().responses.push_back(response.into());
    }

    /// Returns clones of every request recorded so far, oldest first.
    #[must_use]
    pub fn requests(&self) -> Vec<ChatRequest> {
        self.lock_state().requests.clone()
    }

    /// Returns a clone of the most recently recorded request, if any.
    #[must_use]
    pub fn last_request(&self) -> Option<ChatRequest> {
        self.lock_state().requests.last().cloned()
    }

    /// Returns how many `chat` calls have been recorded, including one that panicked on an
    /// exhausted queue (the request is logged before the queue is checked).
    #[must_use]
    pub fn call_count(&self) -> usize {
        self.lock_state().requests.len()
    }

    /// Returns how many queued responses have not been consumed yet.
    #[must_use]
    pub fn pending_responses(&self) -> usize {
        self.lock_state().responses.len()
    }

    /// Locks the shared state, recovering the guard if a previous holder panicked.
    ///
    /// `chat` panics while holding this lock when the queue is exhausted, which poisons the
    /// mutex. The queue and request log are still consistent at that point, so ignoring the
    /// poison keeps the inspection helpers usable for diagnosing the failure.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, MockChatState> {
        self.state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
impl ProviderIdentity for MockProvider {
    fn provider_name(&self) -> &'static str {
        self.provider_name
    }
}

impl ChatProvider for MockProvider {
    type Stream = SingleResponseStream;

    /// Records `request`, then resolves the next queued [`MockResponse`], sleeping through any
    /// [`MockResponse::Delayed`] wrappers.
    ///
    /// # Panics
    ///
    /// Panics with `"MockProvider: no more responses configured"` when the queue is empty.
    async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse> {
        // The guard lives only inside this block, so it is released before any delay is awaited.
        let next = {
            // An earlier exhaustion panic fires while this lock is held and poisons it, but the
            // queue and request log stay consistent. Recover the guard so later calls report
            // the real problem instead of an unrelated `PoisonError`.
            let mut state = self
                .state
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            state.take_next_response(request)
        };
        next.into_result().await
    }

    /// Delegates to [`chat`](Self::chat) and wraps the response in a single-response stream.
    /// Errors from `chat` are returned before any stream is created.
    async fn chat_stream(&self, request: &ChatRequest) -> Result<Self::Stream> {
        let response = self.chat(request).await?;
        Ok(response.into())
    }

    fn chat_capability(&self, _model: &str, capability: ChatCapability) -> CapabilitySupport {
        // Capabilities without an explicit override are reported as unknown.
        self.chat_capabilities
            .get(&capability)
            .copied()
            .unwrap_or(CapabilitySupport::Unknown)
    }
}

#[cfg(feature = "extract")]
impl crate::ExtractExt for MockProvider {}
/// One scripted outcome for a [`MockProvider`] `chat` call.
#[derive(Debug)]
pub enum MockResponse {
    /// Resolve the call successfully with this response.
    Success(ChatResponse),
    /// Fail the call with this error.
    Error(Error),
    /// Sleep for the duration, then resolve the wrapped outcome (delays may be nested).
    Delayed(Duration, Box<MockResponse>),
}

impl MockResponse {
    /// A successful response holding one text block, finishing with [`FinishReason::Stop`].
    #[must_use]
    pub fn text(text: impl Into<String>) -> Self {
        Self::Success(ChatResponseBuilder::new().text(text).build())
    }

    /// A successful response holding a reasoning block followed by a text block.
    #[must_use]
    pub fn reasoning_text(reasoning: impl Into<String>, text: impl Into<String>) -> Self {
        let response = ChatResponseBuilder::new()
            .reasoning(reasoning)
            .text(text)
            .build();
        Self::Success(response)
    }

    /// A successful response holding one tool call, finishing with
    /// [`FinishReason::ToolCalls`].
    #[must_use]
    pub fn tool_call(
        id: impl Into<String>,
        name: impl Into<String>,
        arguments: serde_json::Value,
    ) -> Self {
        Self::Success(build_tool_call_response(id, name, arguments))
    }

    /// Wraps `response` so that it resolves only after `duration` has elapsed.
    #[must_use]
    pub fn delayed(duration: Duration, response: impl Into<MockResponse>) -> Self {
        let inner: Box<MockResponse> = Box::new(response.into());
        Self::Delayed(duration, inner)
    }

    /// Sleeps through every `Delayed` layer and converts the innermost outcome into a `Result`.
    async fn into_result(self) -> Result<ChatResponse> {
        let mut current = self;
        while let Self::Delayed(pause, inner) = current {
            futures_timer::Delay::new(pause).await;
            current = *inner;
        }
        match current {
            Self::Success(response) => Ok(response),
            Self::Error(error) => Err(error),
            Self::Delayed(..) => unreachable!("the loop above unwraps every delay"),
        }
    }
}

impl From<ChatResponse> for MockResponse {
    fn from(response: ChatResponse) -> Self {
        Self::Success(response)
    }
}

impl From<Error> for MockResponse {
    fn from(error: Error) -> Self {
        Self::Error(error)
    }
}

impl From<Result<ChatResponse>> for MockResponse {
    /// `Ok` becomes [`MockResponse::Success`]; `Err` becomes [`MockResponse::Error`].
    fn from(result: Result<ChatResponse>) -> Self {
        result.map_or_else(Self::Error, Self::Success)
    }
}
/// A scripted tool round trip: the response to the initial turn (usually a tool call) and the
/// follow-up response returned after the tool result is sent back.
#[derive(Debug, Clone)]
pub struct MockToolRoundTrip {
    first: ChatResponse,
    followup: ChatResponse,
}

impl MockToolRoundTrip {
    /// Pairs two arbitrary responses into a round trip.
    #[must_use]
    pub fn new(first: ChatResponse, followup: ChatResponse) -> Self {
        Self { first, followup }
    }

    /// A round trip whose first response is a single tool call (finish reason
    /// [`FinishReason::ToolCalls`]) followed by `followup`.
    #[must_use]
    pub fn single_tool_call(
        id: impl Into<String>,
        name: impl Into<String>,
        arguments: serde_json::Value,
        followup: ChatResponse,
    ) -> Self {
        let first = build_tool_call_response(id, name, arguments);
        Self::new(first, followup)
    }

    /// Like [`single_tool_call`](Self::single_tool_call), with a plain text follow-up.
    #[must_use]
    pub fn single_tool_call_text(
        id: impl Into<String>,
        name: impl Into<String>,
        arguments: serde_json::Value,
        followup_text: impl Into<String>,
    ) -> Self {
        let followup = ChatResponseBuilder::new().text(followup_text).build();
        Self::single_tool_call(id, name, arguments, followup)
    }

    /// Splits the round trip into its two responses, first turn first.
    fn into_responses(self) -> [ChatResponse; 2] {
        let Self { first, followup } = self;
        [first, followup]
    }
}
/// Fluent builder for [`MockProvider`].
///
/// A fresh builder (from [`new`](Self::new) or [`Default`]) has no queued responses, no
/// capability overrides, and the provider name `"mock"`.
// `Default` is implemented by hand below: deriving it would leave `provider_name` empty
// instead of `"mock"`, so `default().build()` would disagree with `new().build()`.
#[derive(Debug)]
pub struct MockProviderBuilder {
    responses: Vec<MockResponse>,
    chat_capabilities: HashMap<ChatCapability, CapabilitySupport>,
    provider_name: &'static str,
}

impl MockProviderBuilder {
    /// Creates an empty builder whose provider will be named `"mock"`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            responses: Vec::new(),
            chat_capabilities: HashMap::new(),
            provider_name: "mock",
        }
    }

    /// Queues one scripted outcome.
    #[must_use]
    pub fn response<R>(mut self, response: R) -> Self
    where
        R: Into<MockResponse>,
    {
        self.responses.push(response.into());
        self
    }

    /// Queues a successful response holding one text block.
    #[must_use]
    pub fn text(self, text: impl Into<String>) -> Self {
        self.response(MockResponse::text(text))
    }

    /// Queues a failure.
    #[must_use]
    pub fn error(self, error: Error) -> Self {
        self.response(error)
    }

    /// Queues a successful single-tool-call response.
    #[must_use]
    pub fn tool_call(
        self,
        id: impl Into<String>,
        name: impl Into<String>,
        arguments: serde_json::Value,
    ) -> Self {
        self.response(MockResponse::tool_call(id, name, arguments))
    }

    /// Queues several scripted outcomes, in iteration order.
    #[must_use]
    pub fn responses<I, R>(mut self, responses: I) -> Self
    where
        I: IntoIterator<Item = R>,
        R: Into<MockResponse>,
    {
        self.responses.extend(responses.into_iter().map(Into::into));
        self
    }

    /// Overrides the reported support level for one capability.
    #[must_use]
    pub fn chat_capability(
        mut self,
        capability: ChatCapability,
        support: CapabilitySupport,
    ) -> Self {
        self.chat_capabilities.insert(capability, support);
        self
    }

    /// Overrides the reported support level for several capabilities; later entries win.
    #[must_use]
    pub fn chat_capabilities<I>(mut self, capabilities: I) -> Self
    where
        I: IntoIterator<Item = (ChatCapability, CapabilitySupport)>,
    {
        self.chat_capabilities.extend(capabilities);
        self
    }

    /// Marks every listed capability as [`CapabilitySupport::Supported`].
    #[must_use]
    pub fn supported_chat_capabilities<I>(mut self, capabilities: I) -> Self
    where
        I: IntoIterator<Item = ChatCapability>,
    {
        self.chat_capabilities.extend(
            capabilities
                .into_iter()
                .map(|capability| (capability, CapabilitySupport::Supported)),
        );
        self
    }

    /// Sets the name the built provider reports.
    #[must_use]
    pub fn provider_name(mut self, provider_name: &'static str) -> Self {
        self.provider_name = provider_name;
        self
    }

    /// Builds the provider with the queued responses, capability overrides, and name.
    #[must_use]
    pub fn build(self) -> MockProvider {
        MockProvider::new(self.responses)
            .with_chat_capabilities(self.chat_capabilities)
            .with_provider_name(self.provider_name)
    }
}

impl Default for MockProviderBuilder {
    /// Same as [`MockProviderBuilder::new`], including the `"mock"` provider name.
    fn default() -> Self {
        Self::new()
    }
}
/// Shared mutable state behind a [`MockProvider`]: the pending response queue and the log of
/// every request received.
#[derive(Debug, Default)]
struct MockChatState {
    responses: VecDeque<MockResponse>,
    requests: Vec<ChatRequest>,
}

impl MockChatState {
    /// Logs `request`, then pops the front of the response queue.
    ///
    /// The request is logged even when the queue turns out to be empty.
    ///
    /// # Panics
    ///
    /// Panics if no responses remain.
    fn take_next_response(&mut self, request: &ChatRequest) -> MockResponse {
        self.requests.push(request.clone());
        let Some(next) = self.responses.pop_front() else {
            panic!("MockProvider: no more responses configured");
        };
        next
    }
}
#[derive(Debug, Clone)]
pub struct MockStreamingProvider {
state: Arc<Mutex<MockStreamingState>>,
chat_capabilities: HashMap<ChatCapability, CapabilitySupport>,
provider_name: &'static str,
}
impl MockStreamingProvider {
#[must_use]
pub fn build(
configure: impl FnOnce(MockStreamingProviderBuilder) -> MockStreamingProviderBuilder,
) -> Self {
configure(MockStreamingProviderBuilder::new()).build()
}
#[must_use]
pub fn empty() -> Self {
Self::new(std::iter::empty::<Vec<MockStreamEvent>>())
}
#[must_use]
pub fn new<I, S, E>(streams: I) -> Self
where
I: IntoIterator<Item = S>,
S: IntoIterator<Item = E>,
E: Into<MockStreamEvent>,
{
Self {
state: Arc::new(Mutex::new(MockStreamingState {
streams: streams
.into_iter()
.map(|stream| stream.into_iter().map(Into::into).collect())
.collect(),
requests: Vec::new(),
})),
chat_capabilities: HashMap::from([
(ChatCapability::Streaming, CapabilitySupport::Supported),
(
ChatCapability::NativeStreaming,
CapabilitySupport::Supported,
),
]),
provider_name: "mock_stream",
}
}
#[must_use]
pub fn with_stream<S, E>(stream: S) -> Self
where
S: IntoIterator<Item = E>,
E: Into<MockStreamEvent>,
{
Self::new([stream])
}
#[must_use]
pub fn with_text(text: impl Into<String>) -> Self {
Self::with_stream(MockStreamEvent::text_response(text))
}
#[must_use]
pub fn with_tool_call(
id: impl Into<String>,
name: impl Into<String>,
arguments: serde_json::Value,
) -> Self {
Self::with_stream(MockStreamEvent::tool_call_response(id, name, arguments))
}
#[must_use]
pub fn from_response(response: ChatResponse) -> Self {
Self::with_stream(MockStreamEvent::from_response(response))
}
#[must_use]
pub fn from_responses<I>(responses: I) -> Self
where
I: IntoIterator<Item = ChatResponse>,
{
Self::new(responses.into_iter().map(MockStreamEvent::from_response))
}
#[must_use]
pub fn with_chat_capability(
mut self,
capability: ChatCapability,
support: CapabilitySupport,
) -> Self {
self.chat_capabilities.insert(capability, support);
self
}
#[must_use]
pub fn with_chat_capabilities<I>(mut self, capabilities: I) -> Self
where
I: IntoIterator<Item = (ChatCapability, CapabilitySupport)>,
{
self.chat_capabilities.extend(capabilities);
self
}
#[must_use]
pub fn with_supported_chat_capabilities<I>(mut self, capabilities: I) -> Self
where
I: IntoIterator<Item = ChatCapability>,
{
for capability in capabilities {
self.chat_capabilities
.insert(capability, CapabilitySupport::Supported);
}
self
}
#[must_use]
pub fn with_provider_name(mut self, provider_name: &'static str) -> Self {
self.provider_name = provider_name;
self
}
pub fn push_stream<S, E>(&self, stream: S)
where
S: IntoIterator<Item = E>,
E: Into<MockStreamEvent>,
{
self.state
.lock()
.unwrap()
.streams
.push_back(stream.into_iter().map(Into::into).collect());
}
#[must_use]
pub fn requests(&self) -> Vec<ChatRequest> {
self.state.lock().unwrap().requests.clone()
}
#[must_use]
pub fn last_request(&self) -> Option<ChatRequest> {
self.state.lock().unwrap().requests.last().cloned()
}
#[must_use]
pub fn call_count(&self) -> usize {
self.state.lock().unwrap().requests.len()
}
#[must_use]
pub fn pending_streams(&self) -> usize {
self.state.lock().unwrap().streams.len()
}
}
impl ProviderIdentity for MockStreamingProvider {
    fn provider_name(&self) -> &'static str {
        self.provider_name
    }
}

impl ChatProvider for MockStreamingProvider {
    type Stream = ChatStream;

    /// Streams the next transcript and collects it into a single response.
    async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse> {
        self.chat_stream(request).await?.collect_response().await
    }

    /// Records `request` and returns a stream over the next queued transcript.
    ///
    /// `Event` entries are yielded as `Ok`, `Error` entries as `Err`, and `Delay` entries
    /// sleep without yielding anything. The stream ends when the transcript is exhausted.
    ///
    /// # Panics
    ///
    /// Panics with `"MockStreamingProvider: no more streams configured"` when the queue is
    /// empty.
    async fn chat_stream(&self, request: &ChatRequest) -> Result<Self::Stream> {
        // The guard is confined to this block, so it is released before the stream is built.
        let script = {
            // An earlier exhaustion panic fires while this lock is held and poisons it, but the
            // queue and request log stay consistent. Recover the guard so later calls report
            // the real problem instead of an unrelated `PoisonError`.
            let mut state = self
                .state
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            state.take_next_stream(request)
        };
        let stream = futures_util::stream::unfold(script, |mut pending| async move {
            // Keep draining delays until an item can be yielded or the transcript ends.
            loop {
                match pending.pop_front() {
                    Some(MockStreamEvent::Delay(duration)) => {
                        futures_timer::Delay::new(duration).await;
                    }
                    Some(MockStreamEvent::Event(event)) => return Some((Ok(event), pending)),
                    Some(MockStreamEvent::Error(error)) => return Some((Err(error), pending)),
                    None => return None,
                }
            }
        });
        Ok(Box::pin(stream) as ChatStream)
    }

    fn chat_capability(&self, _model: &str, capability: ChatCapability) -> CapabilitySupport {
        // Capabilities without an explicit override are reported as unknown.
        self.chat_capabilities
            .get(&capability)
            .copied()
            .unwrap_or(CapabilitySupport::Unknown)
    }
}

#[cfg(feature = "extract")]
impl crate::ExtractExt for MockStreamingProvider {}
/// One entry in a scripted [`MockStreamingProvider`] transcript.
#[derive(Debug)]
pub enum MockStreamEvent {
    /// Yield this event as an `Ok` stream item.
    Event(StreamEvent),
    /// Yield this error as an `Err` stream item.
    Error(Error),
    /// Sleep for this long before continuing with the transcript; yields nothing.
    Delay(Duration),
}

impl MockStreamEvent {
    /// A pause of `duration` inside a transcript.
    #[must_use]
    pub fn delayed(duration: Duration) -> Self {
        Self::Delay(duration)
    }

    /// Expands `response` into the event transcript produced by its `stream_events()`.
    #[must_use]
    pub fn from_response(response: ChatResponse) -> Vec<Self> {
        response_to_mock_stream(response)
    }

    /// A transcript for a response holding one text block.
    #[must_use]
    pub fn text_response(text: impl Into<String>) -> Vec<Self> {
        let response = ChatResponseBuilder::new().text(text).build();
        response_to_mock_stream(response)
    }

    /// A transcript for a response holding a reasoning block followed by a text block.
    #[must_use]
    pub fn reasoning_response(reasoning: impl Into<String>, text: impl Into<String>) -> Vec<Self> {
        let response = ChatResponseBuilder::new()
            .reasoning(reasoning)
            .text(text)
            .build();
        response_to_mock_stream(response)
    }

    /// A transcript for a single tool call finishing with [`FinishReason::ToolCalls`].
    #[must_use]
    pub fn tool_call_response(
        id: impl Into<String>,
        name: impl Into<String>,
        arguments: serde_json::Value,
    ) -> Vec<Self> {
        let response = build_tool_call_response(id, name, arguments);
        response_to_mock_stream(response)
    }
}

impl From<StreamEvent> for MockStreamEvent {
    fn from(event: StreamEvent) -> Self {
        Self::Event(event)
    }
}

impl From<Error> for MockStreamEvent {
    fn from(error: Error) -> Self {
        Self::Error(error)
    }
}

impl From<Result<StreamEvent>> for MockStreamEvent {
    /// `Ok` becomes [`MockStreamEvent::Event`]; `Err` becomes [`MockStreamEvent::Error`].
    fn from(result: Result<StreamEvent>) -> Self {
        result.map_or_else(Self::Error, Self::Event)
    }
}
/// Fluent builder for [`MockStreamingProvider`].
///
/// A fresh builder (from [`new`](Self::new) or [`Default`]) has no queued transcripts. It
/// reports [`ChatCapability::Streaming`] and [`ChatCapability::NativeStreaming`] as supported
/// and uses the provider name `"mock_stream"`.
// `Default` is implemented by hand below: deriving it would start with an empty capability map
// and an empty `provider_name`, so `default().build()` would disagree with `new().build()`.
#[derive(Debug)]
pub struct MockStreamingProviderBuilder {
    streams: Vec<Vec<MockStreamEvent>>,
    chat_capabilities: HashMap<ChatCapability, CapabilitySupport>,
    provider_name: &'static str,
}

impl MockStreamingProviderBuilder {
    /// Creates an empty builder with the streaming defaults described on the type.
    #[must_use]
    pub fn new() -> Self {
        Self {
            streams: Vec::new(),
            chat_capabilities: HashMap::from([
                (ChatCapability::Streaming, CapabilitySupport::Supported),
                (
                    ChatCapability::NativeStreaming,
                    CapabilitySupport::Supported,
                ),
            ]),
            provider_name: "mock_stream",
        }
    }

    /// Queues one transcript.
    #[must_use]
    pub fn stream<S, E>(mut self, stream: S) -> Self
    where
        S: IntoIterator<Item = E>,
        E: Into<MockStreamEvent>,
    {
        self.streams
            .push(stream.into_iter().map(Into::into).collect());
        self
    }

    /// Queues a transcript for a response holding one text block.
    #[must_use]
    pub fn text(self, text: impl Into<String>) -> Self {
        self.stream(MockStreamEvent::text_response(text))
    }

    /// Queues a transcript for a single tool call.
    #[must_use]
    pub fn tool_call(
        self,
        id: impl Into<String>,
        name: impl Into<String>,
        arguments: serde_json::Value,
    ) -> Self {
        self.stream(MockStreamEvent::tool_call_response(id, name, arguments))
    }

    /// Overrides the reported support level for one capability.
    #[must_use]
    pub fn chat_capability(
        mut self,
        capability: ChatCapability,
        support: CapabilitySupport,
    ) -> Self {
        self.chat_capabilities.insert(capability, support);
        self
    }

    /// Overrides the reported support level for several capabilities; later entries win.
    #[must_use]
    pub fn chat_capabilities<I>(mut self, capabilities: I) -> Self
    where
        I: IntoIterator<Item = (ChatCapability, CapabilitySupport)>,
    {
        self.chat_capabilities.extend(capabilities);
        self
    }

    /// Marks every listed capability as [`CapabilitySupport::Supported`].
    #[must_use]
    pub fn supported_chat_capabilities<I>(mut self, capabilities: I) -> Self
    where
        I: IntoIterator<Item = ChatCapability>,
    {
        self.chat_capabilities.extend(
            capabilities
                .into_iter()
                .map(|capability| (capability, CapabilitySupport::Supported)),
        );
        self
    }

    /// Sets the name the built provider reports.
    #[must_use]
    pub fn provider_name(mut self, provider_name: &'static str) -> Self {
        self.provider_name = provider_name;
        self
    }

    /// Builds the provider with the queued transcripts, capability overrides, and name.
    #[must_use]
    pub fn build(self) -> MockStreamingProvider {
        MockStreamingProvider::new(self.streams)
            .with_chat_capabilities(self.chat_capabilities)
            .with_provider_name(self.provider_name)
    }
}

impl Default for MockStreamingProviderBuilder {
    /// Same as [`MockStreamingProviderBuilder::new`], including the streaming capabilities and
    /// the `"mock_stream"` provider name.
    fn default() -> Self {
        Self::new()
    }
}
/// Shared mutable state behind a [`MockStreamingProvider`]: the pending transcript queue and
/// the log of every request received.
#[derive(Debug, Default)]
struct MockStreamingState {
    streams: VecDeque<VecDeque<MockStreamEvent>>,
    requests: Vec<ChatRequest>,
}

impl MockStreamingState {
    /// Logs `request`, then pops the front transcript.
    ///
    /// The request is logged even when the queue turns out to be empty.
    ///
    /// # Panics
    ///
    /// Panics if no transcripts remain.
    fn take_next_stream(&mut self, request: &ChatRequest) -> VecDeque<MockStreamEvent> {
        self.requests.push(request.clone());
        let Some(next) = self.streams.pop_front() else {
            panic!("MockStreamingProvider: no more streams configured");
        };
        next
    }
}
/// Fluent builder for hand-written [`ChatResponse`] values.
///
/// A fresh builder has no content, a finish reason of [`FinishReason::Stop`], no usage, model,
/// or id, and empty [`ResponseMetadata`]. Content blocks appear in the built response in the
/// order the builder methods were called.
#[derive(Debug, Clone)]
pub struct ChatResponseBuilder {
    content: Vec<ContentBlock>,
    finish_reason: Option<FinishReason>,
    usage: Option<Usage>,
    model: Option<String>,
    id: Option<String>,
    metadata: ResponseMetadata,
}

impl ChatResponseBuilder {
    /// Creates a builder with the defaults described on [`ChatResponseBuilder`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            content: Vec::new(),
            finish_reason: Some(FinishReason::Stop),
            usage: None,
            model: None,
            id: None,
            metadata: ResponseMetadata::new(),
        }
    }

    /// Appends one content block; every content-adding method goes through here.
    fn push_block(mut self, block: ContentBlock) -> Self {
        self.content.push(block);
        self
    }

    /// Appends a text block.
    #[must_use]
    pub fn text(self, text: impl Into<String>) -> Self {
        self.push_block(ContentBlock::Text { text: text.into() })
    }

    /// Appends a tool-call block; `arguments` is stored as its `to_string()` serialization.
    #[must_use]
    pub fn tool_call(
        self,
        id: impl Into<String>,
        name: impl Into<String>,
        arguments: serde_json::Value,
    ) -> Self {
        let block = ContentBlock::ToolCall {
            id: id.into(),
            name: name.into(),
            arguments: arguments.to_string(),
        };
        self.push_block(block)
    }

    /// Appends a reasoning block without a signature.
    #[must_use]
    pub fn reasoning(self, text: impl Into<String>) -> Self {
        self.push_block(ContentBlock::Reasoning {
            text: text.into(),
            signature: None,
        })
    }

    /// Appends a reasoning block carrying `signature`.
    #[must_use]
    pub fn reasoning_with_signature(
        self,
        text: impl Into<String>,
        signature: impl Into<String>,
    ) -> Self {
        self.push_block(ContentBlock::Reasoning {
            text: text.into(),
            signature: Some(signature.into()),
        })
    }

    /// Appends a provider-specific block identified by `type_name` with raw `data`.
    #[must_use]
    pub fn other(self, type_name: impl Into<String>, data: ExtraMap) -> Self {
        self.push_block(ContentBlock::Other {
            type_name: type_name.into(),
            data,
        })
    }

    /// Replaces the finish reason (default [`FinishReason::Stop`]).
    #[must_use]
    pub fn finish_reason(mut self, reason: FinishReason) -> Self {
        self.finish_reason = Some(reason);
        self
    }

    /// Sets usage with only input and output token counts; other fields use `Usage::default()`.
    #[must_use]
    pub fn usage(self, input_tokens: u64, output_tokens: u64) -> Self {
        self.usage_value(Usage {
            input_tokens: Some(input_tokens),
            output_tokens: Some(output_tokens),
            ..Default::default()
        })
    }

    /// Sets a fully specified usage value.
    #[must_use]
    pub fn usage_value(mut self, usage: Usage) -> Self {
        self.usage = Some(usage);
        self
    }

    /// Sets the model name.
    #[must_use]
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets the response id.
    #[must_use]
    pub fn id(mut self, id: impl Into<String>) -> Self {
        self.id = Some(id.into());
        self
    }

    /// Replaces the response metadata.
    #[must_use]
    pub fn metadata(mut self, metadata: ResponseMetadata) -> Self {
        self.metadata = metadata;
        self
    }

    /// Consumes the builder and produces the response.
    #[must_use]
    pub fn build(self) -> ChatResponse {
        let Self {
            content,
            finish_reason,
            usage,
            model,
            id,
            metadata,
        } = self;
        ChatResponse {
            content,
            finish_reason,
            usage,
            model,
            id,
            metadata,
        }
    }
}

impl Default for ChatResponseBuilder {
    /// Same as [`ChatResponseBuilder::new`].
    fn default() -> Self {
        ChatResponseBuilder::new()
    }
}
/// Builds a response whose only content is one tool call, with finish reason
/// [`FinishReason::ToolCalls`].
fn build_tool_call_response(
    id: impl Into<String>,
    name: impl Into<String>,
    arguments: serde_json::Value,
) -> ChatResponse {
    let builder = ChatResponseBuilder::new().finish_reason(FinishReason::ToolCalls);
    builder.tool_call(id, name, arguments).build()
}
/// Converts every event from `response.stream_events()` into a scripted
/// [`MockStreamEvent::Event`], preserving order.
fn response_to_mock_stream(response: ChatResponse) -> Vec<MockStreamEvent> {
    let events = response.stream_events();
    events.map(MockStreamEvent::from).collect()
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::{ChatRequest, Message, ResponseMetadataType, StreamBlockType};

    use super::*;

    /// Model name used by every request these tests send.
    const MODEL: &str = "mock-model";

    /// Typed metadata entry used to check that builder metadata survives `build()`.
    #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize)]
    struct DemoMetadata {
        request_id: String,
    }

    impl ResponseMetadataType for DemoMetadata {
        const KEY: &'static str = "demo";
    }

    /// Builds a request for [`MODEL`] holding a single user message.
    fn user_request(message: &str) -> ChatRequest {
        ChatRequest::new(MODEL).message(Message::user(message))
    }

    /// Asserts that a recorded request targets `model` and holds exactly one user `message`.
    fn assert_recorded_request(request: &ChatRequest, model: &str, message: &str) {
        assert_eq!(request.model, model);
        assert_eq!(request.messages, vec![Message::user(message)]);
    }

    /// Asserts that `recorded` holds `expected_len` copies of the "hi" request for [`MODEL`].
    fn assert_all_hi(recorded: &[ChatRequest], expected_len: usize) {
        assert_eq!(recorded.len(), expected_len);
        for entry in recorded {
            assert_recorded_request(entry, MODEL, "hi");
        }
    }

    #[tokio::test]
    async fn mock_chat_provider_records_requests_and_returns_responses() {
        let provider = MockProvider::new(
            ["first", "second"].map(|text| ChatResponseBuilder::new().text(text).build()),
        );
        let request = user_request("hi");
        for expected in ["first", "second"] {
            assert_eq!(
                provider.chat(&request).await.unwrap().text(),
                Some(expected.into())
            );
        }
        assert_eq!(provider.call_count(), 2);
        assert_all_hi(&provider.requests(), 2);
        assert_recorded_request(&provider.last_request().unwrap(), MODEL, "hi");
        assert_eq!(provider.pending_responses(), 0);
    }

    #[tokio::test]
    async fn mock_chat_provider_empty_can_be_queued_incrementally() {
        let provider = MockProvider::empty();
        provider.push_response(MockResponse::text("first"));
        provider.push_response(ChatResponseBuilder::new().text("second").build());
        let request = user_request("hi");
        assert_eq!(provider.pending_responses(), 2);
        for expected in ["first", "second"] {
            assert_eq!(
                provider.chat(&request).await.unwrap().text(),
                Some(expected.into())
            );
        }
        assert_all_hi(&provider.requests(), 2);
    }

    #[tokio::test]
    async fn mock_chat_provider_supports_delayed_errors() {
        let provider = MockProvider::new([MockResponse::delayed(
            Duration::from_millis(1),
            Error::Timeout("slow".into()),
        )]);
        let failure = provider.chat(&user_request("hi")).await.unwrap_err();
        assert!(matches!(failure, Error::Timeout(message) if message == "slow"));
    }

    #[test]
    fn mock_response_tool_call_sets_tool_finish_reason() {
        match MockResponse::tool_call("call_1", "search", serde_json::json!({ "q": "rust" })) {
            MockResponse::Success(response) => {
                assert!(response.has_tool_calls());
                assert_eq!(response.finish_reason, Some(FinishReason::ToolCalls));
            }
            other => panic!("expected success response, got {other:?}"),
        }
    }

    #[test]
    fn mock_response_reasoning_text_builds_reasoning_and_text() {
        match MockResponse::reasoning_text("thinking", "done") {
            MockResponse::Success(response) => {
                assert_eq!(response.reasoning_text(), Some("thinking".into()));
                assert_eq!(response.text(), Some("done".into()));
                assert_eq!(response.finish_reason, Some(FinishReason::Stop));
            }
            other => panic!("expected success response, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn mock_chat_provider_tool_round_trip_returns_tool_then_text() {
        let provider = MockProvider::tool_round_trip(
            "call_1",
            "lookup_weather",
            serde_json::json!({ "city": "San Francisco" }),
            "Cool and foggy.",
        );
        let request = user_request("weather?");
        let tool_turn = provider.chat(&request).await.unwrap();
        let answer_turn = provider.chat(&request).await.unwrap();
        assert!(tool_turn.has_tool_calls());
        assert_eq!(tool_turn.finish_reason, Some(FinishReason::ToolCalls));
        assert_eq!(answer_turn.text(), Some("Cool and foggy.".into()));
    }

    #[tokio::test]
    async fn mock_chat_provider_from_tool_round_trips_flattens_multiple_turns() {
        let provider = MockProvider::from_tool_round_trips([
            MockToolRoundTrip::single_tool_call_text(
                "call_weather_1",
                "lookup_weather",
                serde_json::json!({ "city": "San Francisco" }),
                "Weather answer",
            ),
            MockToolRoundTrip::single_tool_call_text(
                "call_time_1",
                "lookup_time",
                serde_json::json!({ "city": "London" }),
                "Time answer",
            ),
        ]);
        let request = user_request("hi");
        // Each round trip yields a tool-call turn followed by its text answer.
        for answer in ["Weather answer", "Time answer"] {
            assert!(provider.chat(&request).await.unwrap().has_tool_calls());
            assert_eq!(
                provider.chat(&request).await.unwrap().text(),
                Some(answer.into())
            );
        }
    }

    #[tokio::test]
    async fn mock_streaming_provider_collects_streams() {
        let transcript: Vec<MockStreamEvent> = vec![
            StreamEvent::ResponseStart {
                id: Some("resp_1".into()),
                model: Some("mock-model".into()),
            }
            .into(),
            StreamEvent::BlockStart {
                index: 0,
                block_type: StreamBlockType::Text,
                id: None,
                name: None,
                type_name: None,
                data: None,
            }
            .into(),
            MockStreamEvent::delayed(Duration::from_millis(1)),
            StreamEvent::TextDelta {
                index: 0,
                text: "hello".into(),
            }
            .into(),
            StreamEvent::BlockStop { index: 0 }.into(),
            StreamEvent::ResponseMetadata {
                finish_reason: Some(FinishReason::Stop),
                usage: Some(Usage {
                    input_tokens: Some(3),
                    output_tokens: Some(1),
                    ..Default::default()
                }),
                usage_mode: crate::UsageMetadataMode::Snapshot,
                id: Some("resp_1".into()),
                model: Some("mock-model".into()),
                metadata: crate::ExtraMap::new(),
            }
            .into(),
            StreamEvent::ResponseStop.into(),
        ];
        let provider = MockStreamingProvider::with_stream(transcript);
        let collected = provider.chat(&user_request("hi")).await.unwrap();
        assert_eq!(collected.text(), Some("hello".into()));
        assert_eq!(provider.call_count(), 1);
        assert_all_hi(&provider.requests(), 1);
        assert_recorded_request(&provider.last_request().unwrap(), MODEL, "hi");
        assert_eq!(provider.pending_streams(), 0);
    }

    #[tokio::test]
    async fn mock_streaming_provider_empty_can_be_queued_incrementally() {
        let provider = MockStreamingProvider::empty();
        for text in ["first", "second"] {
            provider.push_stream(MockStreamEvent::text_response(text));
        }
        let request = user_request("hi");
        assert_eq!(provider.pending_streams(), 2);
        for expected in ["first", "second"] {
            assert_eq!(
                provider.chat(&request).await.unwrap().text(),
                Some(expected.into())
            );
        }
        assert_all_hi(&provider.requests(), 2);
    }

    #[tokio::test]
    async fn mock_stream_event_reasoning_response_builds_collectable_transcript() {
        let provider = MockStreamingProvider::with_stream(MockStreamEvent::reasoning_response(
            "thinking",
            "streamed hello",
        ));
        let collected = provider.chat(&user_request("hi")).await.unwrap();
        assert_eq!(collected.reasoning_text(), Some("thinking".into()));
        assert_eq!(collected.text(), Some("streamed hello".into()));
    }

    #[tokio::test]
    async fn mock_stream_event_text_response_builds_collectable_transcript() {
        let transcript = MockStreamEvent::text_response("streamed hello");
        let provider = MockStreamingProvider::with_stream(transcript);
        let collected = provider.chat(&user_request("hi")).await.unwrap();
        assert_eq!(collected.text(), Some("streamed hello".into()));
    }

    #[tokio::test]
    async fn mock_streaming_provider_with_tool_call_convenience_sets_tool_finish_reason() {
        let provider = MockStreamingProvider::with_tool_call(
            "call_1",
            "search",
            serde_json::json!({ "q": "rust" }),
        );
        let collected = provider.chat(&user_request("search")).await.unwrap();
        assert!(collected.has_tool_calls());
        assert_eq!(collected.finish_reason, Some(FinishReason::ToolCalls));
    }

    #[tokio::test]
    async fn mock_chat_provider_chat_stream_propagates_chat_errors() {
        let provider = MockProvider::with_error(Error::Timeout("slow".to_string()));
        let outcome = provider.chat_stream(&user_request("hi")).await;
        match outcome {
            Err(Error::Timeout(message)) => assert_eq!(message, "slow"),
            Err(other) => panic!("expected timeout error, got {other:?}"),
            Ok(_) => panic!("expected chat_stream to return an error"),
        }
    }

    #[tokio::test]
    async fn mock_streaming_provider_from_responses_replays_multiple_transcripts() {
        let provider = MockStreamingProvider::from_responses(
            ["first", "second"].map(|text| ChatResponseBuilder::new().text(text).build()),
        );
        let request = user_request("hi");
        for expected in ["first", "second"] {
            assert_eq!(
                provider.chat(&request).await.unwrap().text(),
                Some(expected.into())
            );
        }
    }

    #[test]
    fn mock_streaming_provider_defaults_to_native_streaming_supported() {
        let provider = MockStreamingProvider::new(vec![vec![
            MockStreamEvent::from(StreamEvent::ResponseStart {
                id: None,
                model: None,
            }),
            MockStreamEvent::from(StreamEvent::ResponseStop),
        ]]);
        let support = provider.chat_capability("test", ChatCapability::NativeStreaming);
        assert_eq!(support, CapabilitySupport::Supported);
    }

    #[test]
    fn mock_provider_defaults_to_native_streaming_unknown() {
        let provider = MockProvider::with_response(ChatResponseBuilder::new().text("test").build());
        let support = provider.chat_capability("test", ChatCapability::NativeStreaming);
        assert_eq!(support, CapabilitySupport::Unknown);
    }

    #[tokio::test]
    async fn mock_streaming_provider_from_response_replays_normalized_stream() {
        let original = ChatResponseBuilder::new()
            .reasoning("thinking")
            .text("done")
            .usage(7, 2)
            .model("mock-model")
            .id("resp_stream_1")
            .build();
        let provider = MockStreamingProvider::from_response(original);
        let collected = provider.chat(&user_request("hi")).await.unwrap();
        assert_eq!(collected.reasoning_text(), Some("thinking".into()));
        assert_eq!(collected.text(), Some("done".into()));
        assert_eq!(collected.id.as_deref(), Some("resp_stream_1"));
    }

    #[test]
    fn response_builder_builds_tool_calls_and_usage() {
        let mut response_metadata = ResponseMetadata::new();
        response_metadata.insert(DemoMetadata {
            request_id: "req_123".into(),
        });
        let mut citation_data = ExtraMap::new();
        citation_data.insert("url".into(), serde_json::json!("https://example.com"));
        let built = ChatResponseBuilder::new()
            .text("Working on it. ")
            .reasoning_with_signature("Thinking", "sig_123")
            .tool_call("call_1", "search", serde_json::json!({ "q": "rust" }))
            .other("citation", citation_data)
            .finish_reason(FinishReason::ToolCalls)
            .usage_value(Usage {
                input_tokens: Some(10),
                output_tokens: Some(5),
                total_tokens: Some(15),
                reasoning_tokens: Some(2),
                ..Default::default()
            })
            .model("mock")
            .id("resp_1")
            .metadata(response_metadata)
            .build();
        assert_eq!(built.text(), Some("Working on it. ".into()));
        assert_eq!(built.reasoning_text(), Some("Thinking".into()));
        assert!(built.has_tool_calls());
        assert_eq!(built.finish_reason, Some(FinishReason::ToolCalls));
        assert_eq!(built.usage.as_ref().unwrap().input_tokens, Some(10));
        // Blocks keep builder order: text, reasoning, tool call, other.
        assert!(matches!(
            built.content[1],
            ContentBlock::Reasoning { ref signature, .. } if signature.as_deref() == Some("sig_123")
        ));
        assert!(matches!(
            built.content[3],
            ContentBlock::Other { ref type_name, ref data }
                if type_name == "citation"
                    && data.get("url") == Some(&serde_json::json!("https://example.com"))
        ));
        assert_eq!(
            built.metadata.get::<DemoMetadata>(),
            Some(&DemoMetadata {
                request_id: "req_123".into()
            })
        );
    }
}