use crate::{
boxed_stream::BoxedStream,
errors::{LanguageModelError, LanguageModelResult},
language_model::{LanguageModel, LanguageModelMetadata, LanguageModelStream},
LanguageModelInput, ModelResponse, PartialModelResponse,
};
use futures::{future::BoxFuture, stream};
use std::{collections::VecDeque, sync::Mutex};
/// A canned outcome that can be queued on the mock for a `generate` call.
///
/// When dequeued by `generate`, `Response` is returned as `Ok` and `Error`
/// is returned as `Err`.
pub enum MockGenerateResult {
/// A successful model response to hand back to the caller.
Response(ModelResponse),
/// An error to hand back to the caller instead of a response.
Error(LanguageModelError),
}
impl MockGenerateResult {
#[must_use]
pub fn response(response: ModelResponse) -> Self {
Self::Response(response)
}
#[must_use]
pub fn error(error: LanguageModelError) -> Self {
Self::Error(error)
}
}
impl From<ModelResponse> for MockGenerateResult {
fn from(response: ModelResponse) -> Self {
Self::response(response)
}
}
/// Converts a result into the matching queued outcome: `Ok` becomes
/// `Response`, `Err` becomes `Error`.
impl From<LanguageModelResult<ModelResponse>> for MockGenerateResult {
    fn from(result: LanguageModelResult<ModelResponse>) -> Self {
        // First argument handles `Err`, second handles `Ok`.
        result.map_or_else(Self::Error, Self::Response)
    }
}
/// A canned outcome that can be queued on the mock for a `stream` call.
///
/// When dequeued by `stream`, `Partials` is turned into a stream of the
/// contained items and `Error` is returned as `Err`.
pub enum MockStreamResult {
/// The partial responses the resulting stream will yield, in order.
Partials(Vec<PartialModelResponse>),
/// An error to hand back to the caller instead of a stream.
Error(LanguageModelError),
}
impl MockStreamResult {
#[must_use]
pub fn partials(partials: Vec<PartialModelResponse>) -> Self {
Self::Partials(partials)
}
#[must_use]
pub fn error(error: LanguageModelError) -> Self {
Self::Error(error)
}
}
impl From<Vec<PartialModelResponse>> for MockStreamResult {
fn from(partials: Vec<PartialModelResponse>) -> Self {
Self::partials(partials)
}
}
impl From<PartialModelResponse> for MockStreamResult {
fn from(partial: PartialModelResponse) -> Self {
Self::partials(vec![partial])
}
}
/// Converts a result into the matching queued outcome: `Ok` becomes
/// `Partials`, `Err` becomes `Error`.
impl From<LanguageModelResult<Vec<PartialModelResponse>>> for MockStreamResult {
    fn from(result: LanguageModelResult<Vec<PartialModelResponse>>) -> Self {
        // First argument handles `Err`, second handles `Ok`.
        result.map_or_else(Self::Error, Self::Partials)
    }
}
/// Mutable bookkeeping for the mock: queued outcomes and recorded inputs.
///
/// Kept behind a `Mutex` in `MockLanguageModel` so it can be updated
/// through `&self`.
#[derive(Default)]
struct MockLanguageModelState {
// Outcomes for `generate`; pushed at the back, consumed from the front.
mocked_generate_results: VecDeque<MockGenerateResult>,
// Outcomes for `stream`; pushed at the back, consumed from the front.
mocked_stream_results: VecDeque<MockStreamResult>,
// Inputs recorded by `generate`, in call order.
tracked_generate_inputs: Vec<LanguageModelInput>,
// Inputs recorded by `stream`, in call order.
tracked_stream_inputs: Vec<LanguageModelInput>,
}
impl MockLanguageModelState {
    /// Appends an outcome to the back of the `generate` queue.
    fn enqueue_generate_result(&mut self, result: MockGenerateResult) {
        let queue = &mut self.mocked_generate_results;
        queue.push_back(result);
    }

    /// Appends an outcome to the back of the `stream` queue.
    fn enqueue_stream_result(&mut self, result: MockStreamResult) {
        let queue = &mut self.mocked_stream_results;
        queue.push_back(result);
    }

    /// Forgets every recorded input; queued outcomes are left untouched.
    fn reset(&mut self) {
        self.tracked_stream_inputs.clear();
        self.tracked_generate_inputs.clear();
    }

    /// Drops all queued outcomes and all recorded inputs.
    fn restore(&mut self) {
        self.reset();
        self.mocked_stream_results.clear();
        self.mocked_generate_results.clear();
    }
}
/// A `LanguageModel` implementation that replays queued outcomes and records
/// the inputs it was called with.
pub struct MockLanguageModel {
// Value returned by `provider()`; also passed to the `Invariant` error
// raised when a queue is empty.
provider: &'static str,
// Value returned (cloned) by `model_id()`.
model_id: String,
// Value returned by reference from `metadata()`.
metadata: Option<LanguageModelMetadata>,
// Queues and recorded inputs; the mutex allows mutation through `&self`.
state: Mutex<MockLanguageModelState>,
}
impl Default for MockLanguageModel {
fn default() -> Self {
Self {
provider: "mock",
model_id: "mock-model".to_string(),
metadata: None,
state: Mutex::new(MockLanguageModelState::default()),
}
}
}
impl MockLanguageModel {
    /// Creates a mock with the default configuration.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Overrides the provider name reported by the mock.
    pub fn set_provider(&mut self, provider: &'static str) {
        self.provider = provider;
    }

    /// Overrides the model id reported by the mock.
    pub fn set_model_id<S>(&mut self, model_id: S)
    where
        S: Into<String>,
    {
        let owned: String = model_id.into();
        self.model_id = owned;
    }

    /// Overrides the metadata reported by the mock (`None` clears it).
    pub fn set_metadata(&mut self, metadata: Option<LanguageModelMetadata>) {
        self.metadata = metadata;
    }

    /// Queues several `generate` outcomes, preserving iteration order.
    ///
    /// Returns `&Self` so calls can be chained.
    ///
    /// # Panics
    ///
    /// Panics if the internal state mutex is poisoned.
    pub fn enqueue_generate_results<I>(&self, results: I) -> &Self
    where
        I: IntoIterator<Item = MockGenerateResult>,
    {
        {
            // The guard is released at the end of this scope.
            let mut guard = self.state.lock().expect("mock state poisoned");
            results
                .into_iter()
                .for_each(|queued| guard.enqueue_generate_result(queued));
        }
        self
    }

    /// Queues a single `generate` outcome built from anything convertible
    /// into a [`MockGenerateResult`].
    ///
    /// # Panics
    ///
    /// Panics if the internal state mutex is poisoned.
    pub fn enqueue_generate<R>(&self, result: R) -> &Self
    where
        R: Into<MockGenerateResult>,
    {
        let queued: MockGenerateResult = result.into();
        self.enqueue_generate_results(std::iter::once(queued))
    }

    /// Queues several `stream` outcomes, preserving iteration order.
    ///
    /// Returns `&Self` so calls can be chained.
    ///
    /// # Panics
    ///
    /// Panics if the internal state mutex is poisoned.
    pub fn enqueue_stream_results<I>(&self, results: I) -> &Self
    where
        I: IntoIterator<Item = MockStreamResult>,
    {
        {
            // The guard is released at the end of this scope.
            let mut guard = self.state.lock().expect("mock state poisoned");
            results
                .into_iter()
                .for_each(|queued| guard.enqueue_stream_result(queued));
        }
        self
    }

    /// Queues a single `stream` outcome built from anything convertible
    /// into a [`MockStreamResult`].
    ///
    /// # Panics
    ///
    /// Panics if the internal state mutex is poisoned.
    pub fn enqueue_stream<R>(&self, result: R) -> &Self
    where
        R: Into<MockStreamResult>,
    {
        let queued: MockStreamResult = result.into();
        self.enqueue_stream_results(std::iter::once(queued))
    }

    /// Returns a copy of the inputs recorded by `generate`, in call order.
    ///
    /// # Panics
    ///
    /// Panics if the internal state mutex is poisoned.
    pub fn tracked_generate_inputs(&self) -> Vec<LanguageModelInput> {
        let guard = self.state.lock().expect("mock state poisoned");
        let snapshot = guard.tracked_generate_inputs.clone();
        snapshot
    }

    /// Returns a copy of the inputs recorded by `stream`, in call order.
    ///
    /// # Panics
    ///
    /// Panics if the internal state mutex is poisoned.
    pub fn tracked_stream_inputs(&self) -> Vec<LanguageModelInput> {
        let guard = self.state.lock().expect("mock state poisoned");
        let snapshot = guard.tracked_stream_inputs.clone();
        snapshot
    }

    /// Clears the recorded inputs while keeping any queued outcomes.
    ///
    /// # Panics
    ///
    /// Panics if the internal state mutex is poisoned.
    pub fn reset(&self) {
        self.state.lock().expect("mock state poisoned").reset();
    }

    /// Clears both the queued outcomes and the recorded inputs.
    ///
    /// # Panics
    ///
    /// Panics if the internal state mutex is poisoned.
    pub fn restore(&self) {
        self.state.lock().expect("mock state poisoned").restore();
    }
}
impl LanguageModel for MockLanguageModel {
    /// Returns the configured provider name.
    fn provider(&self) -> &'static str {
        self.provider
    }

    /// Returns a copy of the configured model id.
    fn model_id(&self) -> String {
        self.model_id.clone()
    }

    /// Returns the configured metadata, if any.
    fn metadata(&self) -> Option<&LanguageModelMetadata> {
        self.metadata.as_ref()
    }

    /// Records `input`, then replays the next queued `generate` outcome.
    ///
    /// # Errors
    ///
    /// Returns the queued error if the next outcome is
    /// `MockGenerateResult::Error`, or `LanguageModelError::Invariant` if no
    /// outcome has been queued.
    ///
    /// # Panics
    ///
    /// Panics if the internal state mutex is poisoned.
    fn generate(
        &self,
        input: LanguageModelInput,
    ) -> BoxFuture<'_, LanguageModelResult<ModelResponse>> {
        Box::pin(async move {
            let mut state = self.state.lock().expect("mock state poisoned");
            // Track before dequeuing so that a call made with an empty queue
            // is still visible to the test. `input` is owned and not used
            // again, so it is moved rather than cloned.
            state.tracked_generate_inputs.push(input);
            let result = state.mocked_generate_results.pop_front().ok_or_else(|| {
                LanguageModelError::Invariant(
                    self.provider,
                    "no mocked generate results available".into(),
                )
            })?;
            match result {
                MockGenerateResult::Response(response) => Ok(response),
                MockGenerateResult::Error(error) => Err(error),
            }
        })
    }

    /// Records `input`, then replays the next queued `stream` outcome.
    ///
    /// # Errors
    ///
    /// Returns the queued error if the next outcome is
    /// `MockStreamResult::Error`, or `LanguageModelError::Invariant` if no
    /// outcome has been queued.
    ///
    /// # Panics
    ///
    /// Panics if the internal state mutex is poisoned.
    fn stream(
        &self,
        input: LanguageModelInput,
    ) -> BoxFuture<'_, LanguageModelResult<LanguageModelStream>> {
        Box::pin(async move {
            let mut state = self.state.lock().expect("mock state poisoned");
            // Track before dequeuing, matching `generate`: the `?` below
            // returns early on an empty queue, and the call must still be
            // recorded in that case.
            state.tracked_stream_inputs.push(input);
            let result = state.mocked_stream_results.pop_front().ok_or_else(|| {
                LanguageModelError::Invariant(
                    self.provider,
                    "no mocked stream results available".into(),
                )
            })?;
            match result {
                MockStreamResult::Error(error) => Err(error),
                MockStreamResult::Partials(partials) => Ok(stream_from_partials(partials)),
            }
        })
    }
}
/// Wraps the given partial responses in a boxed stream that yields each one
/// as `Ok`, in order.
fn stream_from_partials(partials: Vec<PartialModelResponse>) -> LanguageModelStream {
    let ready_items = partials.into_iter().map(Ok);
    BoxedStream::from_stream(stream::iter(ready_items))
}