use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use serde::de::DeserializeOwned;
use serde_json::Value;
use crate::backend::ModelInfo;
use crate::backend::client::{LLMClient, MediaFile};
use crate::backend::usage::{GenerateResult, MaterializeResult, TokenUsage};
use crate::error::{RStructorError, Result};
use crate::model::Instructor;
use crate::schema::SchemaType;
/// A single scripted outcome for a `MockClient` call.
///
/// `Text` is returned as-is by the `generate*` methods and is parsed as JSON (then
/// validated) by the `materialize*` methods. `Error` is returned to the caller.
/// When the default response is used it is cloned per call, and cloning maps
/// non-clonable error variants (`HttpError`, `JsonError`) to string-carrying variants.
#[derive(Debug)]
pub enum MockResponse {
    /// Raw response body.
    Text(String),
    /// Error handed back by the call that consumes this response.
    Error(RStructorError),
}
impl MockResponse {
    /// Serializes `value` with `serde_json` and wraps the resulting JSON text in a
    /// `Text` response.
    ///
    /// # Errors
    /// Returns `RStructorError::SerializationError` carrying the serializer's message
    /// when `value` cannot be encoded as JSON.
    pub fn json<T: serde::Serialize>(value: &T) -> Result<Self> {
        match serde_json::to_string(value) {
            Ok(encoded) => Ok(Self::Text(encoded)),
            Err(err) => Err(RStructorError::SerializationError(err.to_string())),
        }
    }

    /// Builds a `Text` response from anything convertible into a `String`.
    pub fn text(s: impl Into<String>) -> Self {
        let body: String = s.into();
        Self::Text(body)
    }

    /// Builds an `Error` response that is handed back to the caller that consumes it.
    pub fn error(err: RStructorError) -> Self {
        Self::Error(err)
    }
}
impl From<&str> for MockResponse {
fn from(s: &str) -> Self {
MockResponse::Text(s.to_string())
}
}
impl From<String> for MockResponse {
fn from(s: String) -> Self {
MockResponse::Text(s)
}
}
fn clone_error(e: &RStructorError) -> RStructorError {
match e {
RStructorError::ApiError { provider, kind } => RStructorError::ApiError {
provider: provider.clone(),
kind: kind.clone(),
},
RStructorError::ValidationError(s) => RStructorError::ValidationError(s.clone()),
RStructorError::SchemaError(s) => RStructorError::SchemaError(s.clone()),
RStructorError::SerializationError(s) => RStructorError::SerializationError(s.clone()),
RStructorError::Timeout => RStructorError::Timeout,
RStructorError::Unsupported(s) => RStructorError::Unsupported(s.clone()),
#[cfg(feature = "_client")]
RStructorError::HttpError(_) => RStructorError::Unsupported(e.to_string()),
RStructorError::JsonError(_) => RStructorError::SerializationError(e.to_string()),
}
}
impl Clone for MockResponse {
fn clone(&self) -> Self {
match self {
MockResponse::Text(s) => MockResponse::Text(s.clone()),
MockResponse::Error(e) => MockResponse::Error(clone_error(e)),
}
}
}
/// Identifies which client method produced a [`RecordedRequest`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestKind {
    /// `LLMClient::materialize`.
    Materialize,
    /// `LLMClient::materialize_with_metadata`.
    MaterializeWithMetadata,
    /// `LLMClient::materialize_with_media`.
    MaterializeWithMedia,
    /// `LLMClient::generate`.
    Generate,
    /// `LLMClient::generate_with_media`.
    GenerateWithMedia,
    /// `LLMClient::generate_with_metadata`.
    GenerateWithMetadata,
    /// `LLMClient::list_models` (recorded with an empty prompt).
    ListModels,
    /// `LLMClient::generate_stream`.
    #[cfg(feature = "streaming")]
    GenerateStream,
    /// `LLMClient::materialize_stream`.
    #[cfg(feature = "streaming")]
    MaterializeStream,
    /// `LLMClient::materialize_iter`.
    #[cfg(feature = "streaming")]
    MaterializeIter,
    /// `ToolRunner::run_tool_loop`.
    #[cfg(feature = "tools")]
    RunToolLoop,
}
/// Owned snapshot of one call made against a `MockClient`, kept in its request log.
#[derive(Debug, Clone)]
pub struct RecordedRequest {
    /// Which client method was called.
    pub kind: RequestKind,
    /// Prompt text as passed by the caller (empty for `ListModels`).
    pub prompt: String,
    /// JSON schema of the target type for `materialize*` calls; `None` for text
    /// generation, model listing and tool loops.
    pub schema: Option<Value>,
    /// Schema name reported by the target type for `materialize*` calls, when it
    /// provides one; `None` otherwise.
    pub schema_name: Option<String>,
    /// Media passed to `*_with_media` calls and tool loops; empty otherwise.
    pub media: Vec<MediaFile>,
    /// Names of the tools in the toolbox passed to `run_tool_loop`; empty otherwise.
    #[cfg(feature = "tools")]
    pub tool_names: Vec<String>,
}
/// Borrowed view of an in-flight request, passed to responder closures.
///
/// Unlike [`RecordedRequest`] it borrows from the caller's arguments, so building it
/// allocates nothing. It derives `Debug` so responders can log it, and `Copy` because
/// every field is a reference or a `Copy` value.
#[derive(Debug, Clone, Copy)]
pub struct MockRequestView<'a> {
    /// Which client method was called.
    pub kind: RequestKind,
    /// Prompt text as passed by the caller (empty for `ListModels`).
    pub prompt: &'a str,
    /// JSON schema of the target type for `materialize*` calls; `None` otherwise.
    pub schema: Option<&'a Value>,
    /// Schema name of the target type for `materialize*` calls, when available.
    pub schema_name: Option<&'a str>,
    /// Media attached to the call; empty when none was passed.
    pub media: &'a [MediaFile],
    /// Tool names from the toolbox passed to `run_tool_loop`; empty otherwise.
    #[cfg(feature = "tools")]
    pub tool_names: &'a [String],
}
impl<'a> MockRequestView<'a> {
fn bare(kind: RequestKind, prompt: &'a str) -> Self {
Self {
kind,
prompt,
schema: None,
schema_name: None,
media: &[],
#[cfg(feature = "tools")]
tool_names: &[],
}
}
fn to_recorded(&self) -> RecordedRequest {
RecordedRequest {
kind: self.kind,
prompt: self.prompt.to_string(),
schema: self.schema.cloned(),
schema_name: self.schema_name.map(str::to_string),
media: self.media.to_vec(),
#[cfg(feature = "tools")]
tool_names: self.tool_names.to_vec(),
}
}
}
/// Closure installed by `MockClient::with_responder`. It is shown each request and may
/// return a response; returning `None` defers to the queue and then the default response.
type Responder = Box<dyn Fn(&MockRequestView) -> Option<MockResponse> + Send + Sync>;
/// State shared by every clone of a `MockClient`; each field sits behind its own lock.
struct MockState {
    /// FIFO of scripted responses; one is popped per pick when the responder declines.
    queue: Mutex<VecDeque<MockResponse>>,
    /// Optional closure consulted before the queue on every pick.
    responder: Mutex<Option<Responder>>,
    /// Every request seen, in call order.
    log: Mutex<Vec<RecordedRequest>>,
    /// Returned (cloned) by `list_models`.
    models: Mutex<Vec<ModelInfo>>,
    /// Cloned and used when the responder declines and the queue is empty.
    default_response: Mutex<MockResponse>,
    /// Attached to `materialize_with_metadata` / `generate_with_metadata` results.
    default_usage: Mutex<Option<TokenUsage>>,
    /// Extra attempts the non-streaming `materialize*` methods make after a
    /// response fails to parse or validate.
    retries: Mutex<usize>,
    /// `(tool name, JSON arguments)` pairs replayed by `run_tool_loop`.
    #[cfg(feature = "tools")]
    tool_script: Mutex<VecDeque<(String, Value)>>,
}
impl Default for MockState {
fn default() -> Self {
Self {
queue: Mutex::new(VecDeque::new()),
responder: Mutex::new(None),
log: Mutex::new(Vec::new()),
models: Mutex::new(vec![ModelInfo {
id: "mock-model".to_string(),
name: Some("Mock Model".to_string()),
description: Some("In-memory mock model".to_string()),
}]),
default_response: Mutex::new(MockResponse::Error(RStructorError::Unsupported(
"MockClient: no scripted response configured (use .with_response/.with_responder/.with_default_response)"
.to_string(),
))),
default_usage: Mutex::new(None),
retries: Mutex::new(0),
#[cfg(feature = "tools")]
tool_script: Mutex::new(VecDeque::new()),
}
}
}
/// In-memory [`LLMClient`] test double that returns scripted responses and records
/// every request.
///
/// Each call picks its response in this order:
/// 1. the responder closure, if one is set and it returns `Some`;
/// 2. the FIFO queue;
/// 3. the default response, an `Unsupported` error unless overridden.
///
/// `Clone` is cheap: clones share the same state through an `Arc`, so scripting or
/// inspecting one clone affects all of them.
#[derive(Clone)]
pub struct MockClient {
    /// State shared by all clones.
    inner: Arc<MockState>,
}
impl Default for MockClient {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for MockClient {
    /// Summarizes the mock by two counts: responses still queued and requests recorded.
    /// Each lock is taken just long enough to read its length.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let queued = self.inner.queue.lock().unwrap().len();
        let recorded = self.inner.log.lock().unwrap().len();
        f.debug_struct("MockClient")
            .field("queued_responses", &queued)
            .field("recorded_requests", &recorded)
            .finish()
    }
}
impl MockClient {
#[must_use]
pub fn new() -> Self {
Self {
inner: Arc::new(MockState::default()),
}
}
#[must_use]
pub fn with_response(self, resp: impl Into<MockResponse>) -> Self {
self.inner.queue.lock().unwrap().push_back(resp.into());
self
}
#[must_use]
pub fn with_responses<I>(self, resps: I) -> Self
where
I: IntoIterator,
I::Item: Into<MockResponse>,
{
let mut q = self.inner.queue.lock().unwrap();
for r in resps {
q.push_back(r.into());
}
drop(q);
self
}
pub fn with_json<T: serde::Serialize>(self, value: &T) -> Result<Self> {
let resp = MockResponse::json(value)?;
Ok(self.with_response(resp))
}
#[must_use]
pub fn with_error(self, err: RStructorError) -> Self {
self.with_response(MockResponse::Error(err))
}
pub fn push_response(&self, resp: impl Into<MockResponse>) {
self.inner.queue.lock().unwrap().push_back(resp.into());
}
pub fn push_error(&self, err: RStructorError) {
self.push_response(MockResponse::Error(err));
}
#[must_use]
pub fn with_responder<F>(self, f: F) -> Self
where
F: Fn(&MockRequestView) -> Option<MockResponse> + Send + Sync + 'static,
{
*self.inner.responder.lock().unwrap() = Some(Box::new(f));
self
}
#[must_use]
pub fn with_models(self, models: Vec<ModelInfo>) -> Self {
*self.inner.models.lock().unwrap() = models;
self
}
#[must_use]
pub fn with_default_response(self, resp: impl Into<MockResponse>) -> Self {
*self.inner.default_response.lock().unwrap() = resp.into();
self
}
#[must_use]
pub fn with_usage(self, usage: TokenUsage) -> Self {
*self.inner.default_usage.lock().unwrap() = Some(usage);
self
}
#[must_use]
pub fn with_retries(self, n: usize) -> Self {
*self.inner.retries.lock().unwrap() = n;
self
}
#[cfg(feature = "tools")]
#[must_use]
pub fn with_tool_script<I>(self, calls: I) -> Self
where
I: IntoIterator<Item = (String, Value)>,
{
let mut s = self.inner.tool_script.lock().unwrap();
for c in calls {
s.push_back(c);
}
drop(s);
self
}
#[must_use]
pub fn requests(&self) -> Vec<RecordedRequest> {
self.inner.log.lock().unwrap().clone()
}
#[must_use]
pub fn request_count(&self) -> usize {
self.inner.log.lock().unwrap().len()
}
#[must_use]
pub fn last_request(&self) -> Option<RecordedRequest> {
self.inner.log.lock().unwrap().last().cloned()
}
#[must_use]
pub fn responses_exhausted(&self) -> bool {
self.inner.queue.lock().unwrap().is_empty()
}
pub fn clear_requests(&self) {
self.inner.log.lock().unwrap().clear();
}
fn record(&self, view: &MockRequestView) {
self.inner.log.lock().unwrap().push(view.to_recorded());
}
fn pick_response(&self, view: &MockRequestView) -> MockResponse {
{
let guard = self.inner.responder.lock().unwrap();
if let Some(f) = guard.as_ref()
&& let Some(r) = f(view)
{
return r;
}
}
if let Some(r) = self.inner.queue.lock().unwrap().pop_front() {
return r;
}
self.inner.default_response.lock().unwrap().clone()
}
fn resolve_materialize<T>(&self, view: &MockRequestView) -> Result<T>
where
T: Instructor + DeserializeOwned,
{
let attempts = 1 + *self.inner.retries.lock().unwrap();
let mut last_err: Option<RStructorError> = None;
for _ in 0..attempts {
match self.pick_response(view) {
MockResponse::Text(s) => match parse_and_validate::<T>(&s) {
Ok(v) => return Ok(v),
Err(e) => last_err = Some(e),
},
MockResponse::Error(e) => return Err(e),
}
}
Err(last_err.unwrap_or_else(|| {
RStructorError::Unsupported("MockClient: no scripted response configured".to_string())
}))
}
}
/// Deserializes `raw` as `T` and then runs `T`'s `validate` hook.
///
/// # Errors
/// Returns `RStructorError::ValidationError` if `raw` is not valid JSON for `T`, and
/// propagates any error produced by `validate`.
fn parse_and_validate<T>(raw: &str) -> Result<T>
where
    T: Instructor + DeserializeOwned,
{
    let parsed = match serde_json::from_str::<T>(raw) {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(RStructorError::ValidationError(format!(
                "Failed to parse response as JSON: {e}\nPartial JSON: {raw}"
            )));
        }
    };
    parsed.validate()?;
    Ok(parsed)
}
#[async_trait]
impl LLMClient for MockClient {
    /// Records a `Materialize` request (prompt, schema, schema name) and resolves it
    /// by parsing and validating the picked response as `T`, with retries.
    async fn materialize<T>(&self, prompt: &str) -> Result<T>
    where
        T: Instructor + DeserializeOwned + Send + 'static,
    {
        let schema = <T as SchemaType>::schema().to_json();
        let schema_name = <T as SchemaType>::schema_name();
        let mut view = MockRequestView::bare(RequestKind::Materialize, prompt);
        view.schema = Some(&schema);
        view.schema_name = schema_name.as_deref();
        self.record(&view);
        self.resolve_materialize::<T>(&view)
    }

    /// Like `materialize`, but records the attached media under `MaterializeWithMedia`.
    async fn materialize_with_media<T>(&self, prompt: &str, media: &[MediaFile]) -> Result<T>
    where
        T: Instructor + DeserializeOwned + Send + 'static,
    {
        let schema = <T as SchemaType>::schema().to_json();
        let schema_name = <T as SchemaType>::schema_name();
        let mut view = MockRequestView::bare(RequestKind::MaterializeWithMedia, prompt);
        view.schema = Some(&schema);
        view.schema_name = schema_name.as_deref();
        view.media = media;
        self.record(&view);
        self.resolve_materialize::<T>(&view)
    }

    /// Like `materialize`, but also returns the configured usage (if any).
    async fn materialize_with_metadata<T>(&self, prompt: &str) -> Result<MaterializeResult<T>>
    where
        T: Instructor + DeserializeOwned + Send + 'static,
    {
        let schema = <T as SchemaType>::schema().to_json();
        let schema_name = <T as SchemaType>::schema_name();
        let mut view = MockRequestView::bare(RequestKind::MaterializeWithMetadata, prompt);
        view.schema = Some(&schema);
        view.schema_name = schema_name.as_deref();
        self.record(&view);
        let data = self.resolve_materialize::<T>(&view)?;
        let usage = self.inner.default_usage.lock().unwrap().clone();
        Ok(MaterializeResult { data, usage })
    }

    /// Records a `Generate` request and returns the picked text, or the picked error.
    async fn generate(&self, prompt: &str) -> Result<String> {
        let view = MockRequestView::bare(RequestKind::Generate, prompt);
        self.record(&view);
        match self.pick_response(&view) {
            MockResponse::Text(s) => Ok(s),
            MockResponse::Error(e) => Err(e),
        }
    }

    /// Like `generate`, but records the attached media under `GenerateWithMedia`.
    async fn generate_with_media(&self, prompt: &str, media: &[MediaFile]) -> Result<String> {
        let mut view = MockRequestView::bare(RequestKind::GenerateWithMedia, prompt);
        view.media = media;
        self.record(&view);
        match self.pick_response(&view) {
            MockResponse::Text(s) => Ok(s),
            MockResponse::Error(e) => Err(e),
        }
    }

    /// Like `generate`, but also returns the configured usage (if any).
    async fn generate_with_metadata(&self, prompt: &str) -> Result<GenerateResult> {
        let view = MockRequestView::bare(RequestKind::GenerateWithMetadata, prompt);
        self.record(&view);
        let text = match self.pick_response(&view) {
            MockResponse::Text(s) => s,
            MockResponse::Error(e) => return Err(e),
        };
        let usage = self.inner.default_usage.lock().unwrap().clone();
        Ok(GenerateResult { text, usage })
    }

    /// Streams the picked text as a single chunk (or yields the picked error).
    ///
    /// The request is recorded and the response picked when the stream is created,
    /// not when it is first polled.
    #[cfg(feature = "streaming")]
    fn generate_stream<'a>(&'a self, prompt: &'a str) -> crate::backend::streaming::TextStream<'a>
    where
        Self: Sync,
    {
        let view = MockRequestView::bare(RequestKind::GenerateStream, prompt);
        self.record(&view);
        let resp = self.pick_response(&view);
        Box::pin(async_stream::try_stream! {
            let s = match resp {
                MockResponse::Text(s) => s,
                MockResponse::Error(e) => Err(e)?,
            };
            yield s;
        })
    }

    /// Yields the whole picked response as one `Partial` JSON snapshot, then the
    /// parsed and validated `Complete` value. No retries are attempted.
    ///
    /// The request is recorded and the response picked when the stream is created.
    #[cfg(feature = "streaming")]
    fn materialize_stream<'a, T>(
        &'a self,
        prompt: &'a str,
    ) -> crate::backend::streaming::ObjectStream<'a, T>
    where
        T: Instructor + DeserializeOwned + Send + 'static,
        Self: Sync,
    {
        use crate::backend::streaming::StreamedObject;
        let schema = <T as SchemaType>::schema().to_json();
        let schema_name = <T as SchemaType>::schema_name();
        let mut view = MockRequestView::bare(RequestKind::MaterializeStream, prompt);
        view.schema = Some(&schema);
        view.schema_name = schema_name.as_deref();
        self.record(&view);
        let resp = self.pick_response(&view);
        Box::pin(async_stream::try_stream! {
            let s = match resp {
                MockResponse::Text(s) => s,
                MockResponse::Error(e) => Err(e)?,
            };
            let snapshot: Value = serde_json::from_str(&s).map_err(|e| {
                RStructorError::ValidationError(format!(
                    "Failed to parse response as JSON: {e}\nPartial JSON: {s}"
                ))
            })?;
            yield StreamedObject::Partial(snapshot);
            let value: T = parse_and_validate::<T>(&s)?;
            yield StreamedObject::Complete(value);
        })
    }

    /// Parses the picked response as either a bare JSON array or an object with an
    /// `items` array, then yields each element after passing it through `finalize_item`.
    ///
    /// The request is recorded and the response picked when the stream is created.
    #[cfg(feature = "streaming")]
    fn materialize_iter<'a, T>(
        &'a self,
        prompt: &'a str,
    ) -> crate::backend::streaming::ItemStream<'a, T>
    where
        T: Instructor + DeserializeOwned + Send + 'static,
        Self: Sync,
    {
        let schema = <T as SchemaType>::schema().to_json();
        let schema_name = <T as SchemaType>::schema_name();
        let mut view = MockRequestView::bare(RequestKind::MaterializeIter, prompt);
        view.schema = Some(&schema);
        view.schema_name = schema_name.as_deref();
        self.record(&view);
        let resp = self.pick_response(&view);
        Box::pin(async_stream::try_stream! {
            let s = match resp {
                MockResponse::Text(s) => s,
                MockResponse::Error(e) => Err(e)?,
            };
            let root: Value = serde_json::from_str(&s).map_err(|e| {
                RStructorError::ValidationError(format!(
                    "Failed to parse response as JSON: {e}\nPartial JSON: {s}"
                ))
            })?;
            // `root` is owned, so move the array out instead of cloning every element.
            let found: Option<Vec<Value>> = match root {
                Value::Array(arr) => Some(arr),
                Value::Object(mut map) => match map.remove("items") {
                    Some(Value::Array(arr)) => Some(arr),
                    _ => None,
                },
                _ => None,
            };
            let items: Vec<Value> = match found {
                Some(items) => items,
                None => Err(RStructorError::ValidationError(
                    "MockClient::materialize_iter expects a JSON array or {\"items\": [...]}"
                        .to_string(),
                ))?,
            };
            for item in items {
                let value: T = crate::backend::streaming::finalize_item::<T>(item)?;
                yield value;
            }
        })
    }

    /// Always succeeds with a fresh `MockClient::new()`; no environment is read.
    fn from_env() -> Result<Self>
    where
        Self: Sized,
    {
        Ok(Self::new())
    }

    /// Records a `ListModels` request (with an empty prompt) and returns a copy of the
    /// configured model list. The response queue is not consulted.
    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        let view = MockRequestView::bare(RequestKind::ListModels, "");
        self.record(&view);
        Ok(self.inner.models.lock().unwrap().clone())
    }
}
#[cfg(feature = "tools")]
#[async_trait]
impl crate::backend::tools::ToolRunner for MockClient {
    /// Replays the scripted tool calls against `toolbox`, then returns the next picked
    /// response as the loop's final answer.
    ///
    /// The request is recorded together with the toolbox's tool names. The whole tool
    /// script is drained up front, and each `(name, args)` pair is invoked in order;
    /// tool outputs are discarded. `_system` and `_max_iterations` are ignored.
    ///
    /// # Errors
    /// - Propagates the first tool invocation error.
    /// - Returns `Unsupported` when a script entry names a tool missing from `toolbox`.
    /// - Returns the picked error response, if the picked response is an error.
    async fn run_tool_loop(
        &self,
        _system: Option<&str>,
        prompt: &str,
        media: &[MediaFile],
        toolbox: &crate::backend::tools::Toolbox,
        _max_iterations: usize,
    ) -> Result<String> {
        let available = toolbox.tool_names();
        let mut view = MockRequestView::bare(RequestKind::RunToolLoop, prompt);
        view.media = media;
        view.tool_names = &available;
        self.record(&view);

        // The lock guard is a temporary of this statement, so it is released
        // before any `.await` below.
        let pending: Vec<(String, Value)> =
            self.inner.tool_script.lock().unwrap().drain(..).collect();
        for (name, args) in pending {
            let Some(tool) = toolbox.get(&name) else {
                return Err(RStructorError::Unsupported(format!(
                    "MockClient tool script referenced unknown tool: {name}"
                )));
            };
            tool.invoke_json(args).await?;
        }

        match self.pick_response(&view) {
            MockResponse::Error(failure) => Err(failure),
            MockResponse::Text(answer) => Ok(answer),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Instructor;
    use serde::{Deserialize, Serialize};

    /// Fixture type with a custom validator, so tests can reach the validation path.
    #[derive(Instructor, Serialize, Deserialize, Debug, PartialEq)]
    #[llm(validate = "validate_movie")]
    struct Movie {
        title: String,
        year: u16,
    }

    /// Rejects any year before 1888, which lets tests trigger a `ValidationError`
    /// from syntactically valid JSON.
    fn validate_movie(m: &Movie) -> Result<()> {
        if m.year < 1888 {
            return Err(RStructorError::ValidationError("year too early".into()));
        }
        Ok(())
    }

    #[tokio::test]
    async fn materialize_returns_scripted_json() {
        let client = MockClient::new().with_response(r#"{"title":"Inception","year":2010}"#);
        let movie: Movie = client.materialize("p").await.unwrap();
        assert_eq!(
            movie,
            Movie {
                title: "Inception".into(),
                year: 2010
            }
        );
    }

    #[tokio::test]
    async fn materialize_runs_validate_and_fails() {
        let client = MockClient::new().with_response(r#"{"title":"X","year":1700}"#);
        let err = client.materialize::<Movie>("p").await.unwrap_err();
        assert!(matches!(err, RStructorError::ValidationError(_)));
    }

    #[tokio::test]
    async fn bad_json_is_validation_error() {
        let client = MockClient::new().with_response("not json");
        let err = client.materialize::<Movie>("p").await.unwrap_err();
        assert!(matches!(err, RStructorError::ValidationError(_)));
    }

    #[tokio::test]
    async fn retries_consume_next_response() {
        let client = MockClient::new()
            .with_response(r#"{"title":"X","year":1700}"#)
            .with_response(r#"{"title":"Dune","year":2021}"#)
            .with_retries(1);
        let movie: Movie = client.materialize("p").await.unwrap();
        assert_eq!(movie.year, 2021);
    }

    #[tokio::test]
    async fn error_response_is_not_retried() {
        let client = MockClient::new()
            .with_error(RStructorError::Timeout)
            .with_response(r#"{"title":"Dune","year":2021}"#)
            .with_retries(1);
        let err = client.materialize::<Movie>("p").await.unwrap_err();
        assert!(matches!(err, RStructorError::Timeout));
        // The valid response behind the error must still be queued.
        assert!(!client.responses_exhausted());
    }

    #[tokio::test]
    async fn records_prompt_and_schema() {
        let client = MockClient::new().with_response(r#"{"title":"A","year":2000}"#);
        let _: Movie = client.materialize("the prompt").await.unwrap();
        let req = client.last_request().unwrap();
        assert_eq!(req.kind, RequestKind::Materialize);
        assert_eq!(req.prompt, "the prompt");
        assert_eq!(req.schema_name.as_deref(), Some("Movie"));
        assert!(req.schema.is_some());
        assert_eq!(client.request_count(), 1);
    }

    #[tokio::test]
    async fn responder_closure_branches_on_prompt() {
        let client = MockClient::new().with_responder(|req| {
            if req.prompt.contains("movie") {
                Some(MockResponse::text(r#"{"title":"Matrix","year":1999}"#))
            } else {
                None
            }
        });
        let movie: Movie = client.materialize("a movie please").await.unwrap();
        assert_eq!(movie.title, "Matrix");
    }

    #[tokio::test]
    async fn responder_none_falls_back_to_queue() {
        let client = MockClient::new()
            .with_responder(|_| None)
            .with_response("queued");
        assert_eq!(client.generate("p").await.unwrap(), "queued");
    }

    #[tokio::test]
    async fn error_response_returned_verbatim() {
        let err = RStructorError::api_error("OpenAI", crate::ApiErrorKind::AuthenticationFailed);
        let client = MockClient::new().with_error(err);
        let got = client.generate("p").await.unwrap_err();
        assert_eq!(
            got,
            RStructorError::api_error("OpenAI", crate::ApiErrorKind::AuthenticationFailed)
        );
    }

    #[tokio::test]
    async fn clone_shares_state() {
        let client = MockClient::new();
        let clone = client.clone();
        clone.push_response(r#"{"title":"Shared","year":2020}"#);
        let movie: Movie = client.materialize("p").await.unwrap();
        assert_eq!(movie.title, "Shared");
    }

    #[tokio::test]
    async fn default_after_exhaustion() {
        let client = MockClient::new().with_response(r#"{"title":"A","year":2000}"#);
        let _: Movie = client.materialize("p").await.unwrap();
        assert!(client.responses_exhausted());
        let err = client.materialize::<Movie>("p").await.unwrap_err();
        assert!(matches!(err, RStructorError::Unsupported(_)));
    }

    #[tokio::test]
    async fn from_env_needs_no_key() {
        assert!(MockClient::from_env().is_ok());
    }

    #[tokio::test]
    async fn metadata_carries_usage() {
        let client = MockClient::new()
            .with_response(r#"{"title":"A","year":2000}"#)
            .with_usage(TokenUsage::new("mock-model", 10, 20));
        let result = client
            .materialize_with_metadata::<Movie>("p")
            .await
            .unwrap();
        assert_eq!(result.usage.unwrap().total_tokens(), 30);
    }

    #[tokio::test]
    async fn generate_with_metadata_carries_usage() {
        let client = MockClient::new()
            .with_response("hello")
            .with_usage(TokenUsage::new("mock-model", 1, 2));
        let result = client.generate_with_metadata("p").await.unwrap();
        assert_eq!(result.text, "hello");
        assert_eq!(result.usage.unwrap().total_tokens(), 3);
        assert_eq!(
            client.last_request().unwrap().kind,
            RequestKind::GenerateWithMetadata
        );
    }

    #[tokio::test]
    async fn list_models_returns_default_mock_model() {
        let client = MockClient::new();
        let models = client.list_models().await.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].id, "mock-model");
        assert_eq!(client.last_request().unwrap().kind, RequestKind::ListModels);
    }

    #[tokio::test]
    async fn clear_requests_empties_log() {
        let client = MockClient::new().with_default_response("x");
        client.generate("a").await.unwrap();
        client.generate("b").await.unwrap();
        assert_eq!(client.request_count(), 2);
        client.clear_requests();
        assert_eq!(client.request_count(), 0);
        assert!(client.last_request().is_none());
    }
}