use crate::{
completion::{CompletionError, GetTokenUsage, Usage},
streaming::{RawStreamingChoice, RawStreamingToolCall, ToolCallDeltaContent},
};
use serde::{Deserialize, Serialize};
/// Response payload carried by [`MockStreamEvent::FinalResponse`] events.
///
/// It only holds optional token usage, which is exposed through the
/// [`GetTokenUsage`] implementation.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct MockResponse {
    /// Token usage reported by this response; `None` when no usage was supplied
    /// (e.g. when built via `MockResponse::new` or `Default`).
    usage: Option<Usage>,
}
impl MockResponse {
    /// Creates a response that reports no token usage.
    pub fn new() -> Self {
        // The derived `Default` leaves `usage` as `None`.
        Self::default()
    }

    /// Creates a response that reports exactly `usage`.
    pub fn with_usage(usage: Usage) -> Self {
        let usage = Some(usage);
        Self { usage }
    }

    /// Creates a response whose usage has only `total_tokens` set; every other
    /// field keeps whatever value `Usage::new` produces.
    pub fn with_total_tokens(total_tokens: u64) -> Self {
        let usage = {
            let mut tally = Usage::new();
            tally.total_tokens = total_tokens;
            tally
        };
        Self::with_usage(usage)
    }
}
impl GetTokenUsage for MockResponse {
    /// Returns a copy of the stored usage, or `None` if none was configured.
    fn token_usage(&self) -> Option<Usage> {
        let Self { usage } = self;
        *usage
    }
}
/// One scripted event in a mock completion stream.
///
/// Each event is converted by `MockStreamEvent::into_raw_choice` into either a
/// [`RawStreamingChoice`] or, for the `Error` variant, a [`CompletionError`].
#[derive(Clone, Debug)]
pub enum MockStreamEvent {
    /// A chunk of assistant text; converted to `RawStreamingChoice::Message`.
    Text(String),
    /// A complete tool call; converted to `RawStreamingChoice::ToolCall`.
    ToolCall {
        /// Tool call id, passed as the first argument to `RawStreamingToolCall::new`.
        id: String,
        /// Name of the tool being called.
        name: String,
        /// JSON arguments for the tool call.
        arguments: serde_json::Value,
        /// Optional call id; when `Some`, it is applied with
        /// `RawStreamingToolCall::with_call_id` during conversion.
        call_id: Option<String>,
    },
    /// An incremental piece of a tool call (e.g. a name or argument fragment);
    /// converted field-for-field to `RawStreamingChoice::ToolCallDelta`.
    ToolCallDelta {
        id: String,
        internal_call_id: String,
        content: ToolCallDeltaContent,
    },
    /// A message identifier; converted to `RawStreamingChoice::MessageId`.
    MessageId(String),
    /// The final response; converted to `RawStreamingChoice::FinalResponse`.
    FinalResponse(MockResponse),
    /// An injected failure; conversion yields `Err` via `MockError::into_completion_error`.
    Error(MockError),
}
use super::completion::MockError;
impl MockStreamEvent {
    /// Builds a `Text` event carrying `text`.
    pub fn text(text: impl Into<String>) -> Self {
        let chunk: String = text.into();
        Self::Text(chunk)
    }

    /// Builds a complete `ToolCall` event with no `call_id`.
    ///
    /// Chain [`MockStreamEvent::with_call_id`] to attach one.
    pub fn tool_call(
        id: impl Into<String>,
        name: impl Into<String>,
        arguments: serde_json::Value,
    ) -> Self {
        let (id, name) = (id.into(), name.into());
        Self::ToolCall {
            id,
            name,
            arguments,
            call_id: None,
        }
    }

    /// Sets the `call_id` of a `ToolCall` event, replacing any previous value.
    ///
    /// On every other variant this is a silent no-op: `self` is returned
    /// unchanged and `call_id` is never converted.
    pub fn with_call_id(self, call_id: impl Into<String>) -> Self {
        match self {
            Self::ToolCall {
                id,
                name,
                arguments,
                ..
            } => Self::ToolCall {
                id,
                name,
                arguments,
                call_id: Some(call_id.into()),
            },
            untouched => untouched,
        }
    }

    /// Builds a `ToolCallDelta` event whose content is a tool-name fragment.
    pub fn tool_call_name_delta(
        id: impl Into<String>,
        internal_call_id: impl Into<String>,
        name: impl Into<String>,
    ) -> Self {
        let id = id.into();
        let internal_call_id = internal_call_id.into();
        let content = ToolCallDeltaContent::Name(name.into());
        Self::ToolCallDelta {
            id,
            internal_call_id,
            content,
        }
    }

    /// Builds a `ToolCallDelta` event whose content is an arguments fragment.
    pub fn tool_call_arguments_delta(
        id: impl Into<String>,
        internal_call_id: impl Into<String>,
        arguments: impl Into<String>,
    ) -> Self {
        let id = id.into();
        let internal_call_id = internal_call_id.into();
        let content = ToolCallDeltaContent::Delta(arguments.into());
        Self::ToolCallDelta {
            id,
            internal_call_id,
            content,
        }
    }

    /// Builds a `MessageId` event.
    pub fn message_id(id: impl Into<String>) -> Self {
        let id: String = id.into();
        Self::MessageId(id)
    }

    /// Builds a `FinalResponse` event reporting `usage`.
    pub fn final_response(usage: Usage) -> Self {
        let response = MockResponse::with_usage(usage);
        Self::FinalResponse(response)
    }

    /// Builds a `FinalResponse` event reporting `Usage::new()`.
    pub fn final_response_with_default_usage() -> Self {
        Self::final_response(Usage::new())
    }

    /// Builds a `FinalResponse` event whose usage has only `total_tokens` set.
    pub fn final_response_with_total_tokens(total_tokens: u64) -> Self {
        let response = MockResponse::with_total_tokens(total_tokens);
        Self::FinalResponse(response)
    }

    /// Builds an `Error` event from a provider error message.
    pub fn error(message: impl Into<String>) -> Self {
        let failure = MockError::provider(message);
        Self::Error(failure)
    }

    /// Converts this scripted event into the raw streaming representation.
    ///
    /// # Errors
    ///
    /// Returns the [`CompletionError`] produced by `MockError::into_completion_error`
    /// when the event is `Error`; every other variant maps to `Ok`.
    pub(crate) fn into_raw_choice(
        self,
    ) -> Result<RawStreamingChoice<MockResponse>, CompletionError> {
        let choice = match self {
            Self::Error(error) => return Err(error.into_completion_error()),
            Self::Text(text) => RawStreamingChoice::Message(text),
            Self::ToolCall {
                id,
                name,
                arguments,
                call_id,
            } => {
                let base = RawStreamingToolCall::new(id, name, arguments);
                // Only override the call id when one was scripted.
                let tool_call = match call_id {
                    Some(call_id) => base.with_call_id(call_id),
                    None => base,
                };
                RawStreamingChoice::ToolCall(tool_call)
            }
            Self::ToolCallDelta {
                id,
                internal_call_id,
                content,
            } => RawStreamingChoice::ToolCallDelta {
                id,
                internal_call_id,
                content,
            },
            Self::MessageId(id) => RawStreamingChoice::MessageId(id),
            Self::FinalResponse(response) => RawStreamingChoice::FinalResponse(response),
        };
        Ok(choice)
    }
}