use std::pin::Pin;
use std::sync::Mutex;
use std::task::{Context, Poll};
use std::time::Duration;
use async_trait::async_trait;
use futures::Stream;
use futures::stream;
use crate::capabilities::{
Capabilities, ModelInfo, PromptCachingCapability, SystemPromptCapability, ToolUseCapability,
};
use crate::error::{Error, Result};
use crate::provider::Provider;
use crate::request::CompletionRequest;
use crate::response::{CompletionResponse, StopReason, Usage};
use crate::stream::{MessageStream, StreamEvent, StreamingContentType, StreamingDelta};
/// Scripted in-memory [`Provider`] for tests.
///
/// Results are queued ahead of time, either with the `enqueue_*` methods or
/// through [`MockProviderBuilder`]. They are handed out front-first, one per
/// `complete` / `stream` call. All state lives behind a `Mutex`, so every
/// configuration method only needs `&self`.
#[derive(Default)]
pub struct MockProvider {
    inner: Mutex<MockState>,
}
/// One scripted outcome for a single `Provider::stream` call.
enum MockEntry {
    /// `stream` returns `Ok` with a stream that yields these items in order, then ends.
    Events(Vec<Result<StreamEvent>>),
    /// `stream` itself returns `Err` with this error; no stream is produced.
    Error(Error),
    /// `stream` returns `Ok` with a [`SilentStream`], which never yields anything.
    Silent,
}
/// Mutable state shared by all `MockProvider` methods (kept behind its `Mutex`).
#[derive(Default)]
struct MockState {
    // Results for `complete`, consumed from the front (index 0) one per call.
    complete_queue: Vec<Result<CompletionResponse>>,
    // Outcomes for `stream`, consumed from the front (index 0) one per call.
    stream_queue: Vec<MockEntry>,
    // Override for `capabilities`; `None` falls back to `default_capabilities()`.
    capabilities: Option<Capabilities>,
    // Returned (cloned) by `list_models`.
    models: Vec<ModelInfo>,
}
impl MockProvider {
pub fn new() -> Self {
Self::default()
}
pub fn enqueue_complete(&self, resp: Result<CompletionResponse>) {
self.inner
.lock()
.expect("MockProvider lock poisoned")
.complete_queue
.push(resp);
}
pub fn enqueue_stream(&self, events: Vec<Result<StreamEvent>>) {
self.inner
.lock()
.expect("MockProvider lock poisoned")
.stream_queue
.push(MockEntry::Events(events));
}
pub fn enqueue_stream_error(&self, err: Error) {
self.inner
.lock()
.expect("MockProvider lock poisoned")
.stream_queue
.push(MockEntry::Error(err));
}
pub fn enqueue_silent_stream(&self) {
self.inner
.lock()
.expect("MockProvider lock poisoned")
.stream_queue
.push(MockEntry::Silent);
}
#[must_use]
pub fn builder() -> MockProviderBuilder {
MockProviderBuilder::default()
}
pub fn set_capabilities(&self, caps: Capabilities) {
self.inner
.lock()
.expect("MockProvider lock poisoned")
.capabilities = Some(caps);
}
pub fn set_models(&self, models: Vec<ModelInfo>) {
self.inner
.lock()
.expect("MockProvider lock poisoned")
.models = models;
}
#[must_use]
pub fn for_tests_with_models(ids: &[&str]) -> Self {
let p = Self::new();
let caps = default_capabilities();
let models: Vec<ModelInfo> = ids
.iter()
.map(|id| ModelInfo {
id: (*id).to_string(),
native_id: (*id).to_string(),
display_name: (*id).to_string(),
capabilities: caps,
})
.collect();
p.set_models(models);
p
}
}
/// Capability set reported when no override has been configured.
///
/// Limits are 100 000 input and 4 096 output tokens. Streaming, stop sequences,
/// basic tool use and a separate system-prompt field are supported. Every other
/// optional feature is switched off.
fn default_capabilities() -> Capabilities {
    Capabilities {
        // Token limits.
        max_input_tokens: 100_000,
        max_output_tokens: 4_096,
        // Supported features.
        streaming: true,
        stop_sequences: true,
        tool_use: ToolUseCapability::Basic,
        system_prompt: SystemPromptCapability::SeparateField,
        // Unsupported features.
        prompt_caching: PromptCachingCapability::None,
        vision: false,
        thinking: false,
        json_mode: false,
        top_k: false,
        refusal_field: false,
    }
}
#[async_trait]
impl Provider for MockProvider {
async fn complete(&self, _req: CompletionRequest) -> Result<CompletionResponse> {
let mut s = self.inner.lock().expect("MockProvider lock poisoned");
if s.complete_queue.is_empty() {
return Err(Error::InvalidRequest(
"MockProvider: complete queue empty".into(),
));
}
s.complete_queue.remove(0)
}
async fn stream(&self, _req: CompletionRequest) -> Result<MessageStream> {
let mut s = self.inner.lock().expect("MockProvider lock poisoned");
if s.stream_queue.is_empty() {
return Err(Error::InvalidRequest(
"MockProvider: stream queue empty".into(),
));
}
match s.stream_queue.remove(0) {
MockEntry::Error(e) => Err(e),
MockEntry::Events(events) => Ok(Box::pin(stream::iter(events))),
MockEntry::Silent => Ok(Box::pin(SilentStream)),
}
}
fn capabilities(&self, _model: &str) -> Capabilities {
self.inner
.lock()
.expect("MockProvider lock poisoned")
.capabilities
.unwrap_or_else(default_capabilities)
}
fn list_models(&self) -> Vec<ModelInfo> {
self.inner
.lock()
.expect("MockProvider lock poisoned")
.models
.clone()
}
fn name(&self) -> &'static str {
"mock"
}
}
/// A stream that stays pending forever: it never yields an item and never ends.
///
/// `poll_next` ignores the task context, so it never registers a waker. A task
/// awaiting this stream therefore only resumes through some other wake-up
/// source, such as a timeout raced against it.
struct SilentStream;

impl Stream for SilentStream {
    type Item = Result<StreamEvent>;

    fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Intentionally never `Ready`: staying silent is the whole purpose.
        Poll::Pending
    }
}
/// Fluent builder for a [`MockProvider`] with pre-scripted `stream` outcomes.
///
/// Only the stream queue, capability override and model list can be set here.
/// The built provider's `complete` queue starts empty.
#[derive(Default)]
pub struct MockProviderBuilder {
    // Stream outcomes, in the order `stream` will return them.
    entries: Vec<MockEntry>,
    // Capability override handed to the built provider.
    capabilities: Option<Capabilities>,
    // Model list handed to the built provider.
    models: Vec<ModelInfo>,
}
impl MockProviderBuilder {
#[must_use]
pub fn with_capabilities(mut self, caps: Capabilities) -> Self {
self.capabilities = Some(caps);
self
}
#[must_use]
pub fn with_models(mut self, models: Vec<ModelInfo>) -> Self {
self.models = models;
self
}
#[must_use]
pub fn with_response_max_tokens(self, output_tokens: u32) -> Self {
self.with_response_stop_reason(StopReason::MaxTokens, "")
.with_output_tokens(output_tokens)
}
#[must_use]
pub fn with_response_end_turn(self, text: &str) -> Self {
self.with_response_stop_reason(StopReason::EndTurn, text)
}
#[must_use]
pub fn with_response_stop_reason(mut self, stop: StopReason, text: &str) -> Self {
let events = build_text_events(text, stop, 1);
self.entries.push(MockEntry::Events(events));
self
}
#[must_use]
pub fn with_output_tokens(mut self, output_tokens: u32) -> Self {
if let Some(MockEntry::Events(events)) = self.entries.last_mut() {
for evt in events.iter_mut() {
if let Ok(StreamEvent::MessageDelta {
usage_delta: Some(u),
..
}) = evt
{
u.output_tokens = output_tokens;
}
}
}
self
}
#[must_use]
pub fn with_error_once(mut self, err: Error) -> Self {
self.entries.push(MockEntry::Error(err));
self
}
#[must_use]
pub fn with_silent_stream(mut self, _min_silence: Duration) -> Self {
self.entries.push(MockEntry::Silent);
self
}
#[must_use]
pub fn build(self) -> MockProvider {
let provider = MockProvider::new();
{
let mut s = provider.inner.lock().expect("MockProvider lock poisoned");
s.stream_queue = self.entries;
s.capabilities = self.capabilities;
s.models = self.models;
}
provider
}
}
/// Builds the event sequence for a mock message with a single text block.
///
/// The sequence, with every item wrapped in `Ok`, is:
/// 1. `MessageStart`
/// 2. `ContentBlockStart` (block 0, text)
/// 3. one `Delta` carrying `text` (skipped when `text` is empty)
/// 4. `ContentBlockStop`
/// 5. a `MessageDelta` with `stop` and a usage of 1 input / `output_tokens`
///    output tokens
/// 6. `MessageStop`
fn build_text_events(text: &str, stop: StopReason, output_tokens: u32) -> Vec<Result<StreamEvent>> {
    // Five fixed events plus the optional text delta.
    const MAX_EVENTS: usize = 6;

    let usage = Usage {
        input_tokens: 1,
        output_tokens,
        cache_creation_input_tokens: None,
        cache_read_input_tokens: None,
    };
    let text_delta = if text.is_empty() {
        None
    } else {
        Some(StreamEvent::Delta {
            index: 0,
            delta: StreamingDelta::Text(text.to_owned()),
        })
    };

    let mut events: Vec<Result<StreamEvent>> = Vec::with_capacity(MAX_EVENTS);
    events.push(Ok(StreamEvent::MessageStart {
        id: "msg_mock".into(),
        model: "mock".into(),
    }));
    events.push(Ok(StreamEvent::ContentBlockStart {
        index: 0,
        content_type: StreamingContentType::Text,
    }));
    events.extend(text_delta.map(Ok));
    events.push(Ok(StreamEvent::ContentBlockStop { index: 0 }));
    events.push(Ok(StreamEvent::MessageDelta {
        stop_reason: Some(stop),
        usage_delta: Some(usage),
    }));
    events.push(Ok(StreamEvent::MessageStop));
    events
}