use crate::{FierrosError, FierrosResult};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// Author of a [`Message`] in a chat transcript.
///
/// The enum carries no data, so it is `Copy` (pass it by value, no `.clone()` needed) and
/// `Hash` (usable as a map/set key). With serde's default enum representation each
/// variant serializes as its bare name (e.g. `"User"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MessageRole {
    System,
    User,
    Assistant,
    Tool,
}
/// One turn of a chat conversation: the author's role plus the text of the turn.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct Message {
    /// Who authored this turn.
    pub role: MessageRole,
    /// Text of the turn, stored verbatim.
    pub content: String,
}
impl Message {
pub fn system(content: impl Into<String>) -> Self {
Self {
role: MessageRole::System,
content: content.into(),
}
}
pub fn user(content: impl Into<String>) -> Self {
Self {
role: MessageRole::User,
content: content.into(),
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: MessageRole::Assistant,
content: content.into(),
}
}
}
/// Input to [`Llm::complete`]: the conversation so far plus generation settings.
///
/// Only `PartialEq` (not `Eq`) is derived because `temperature` is an `f32`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct CompletionRequest {
    /// Conversation turns, in the order they are sent.
    pub messages: Vec<Message>,
    /// Sampling temperature. `CompletionRequest::from_user` sets it to `0.0`.
    pub temperature: f32,
    /// Optional cap on generated tokens. `None` presumably defers to the backend's
    /// default (confirm per `Llm` implementation).
    pub max_tokens: Option<u32>,
}
impl CompletionRequest {
    /// Builds a single-turn request containing one user message.
    ///
    /// Defaults: `temperature` is `0.0` and `max_tokens` is `None`.
    pub fn from_user(content: impl Into<String>) -> Self {
        let messages = vec![Message::user(content)];
        Self {
            messages,
            max_tokens: None,
            temperature: 0.0,
        }
    }
}
/// Token accounting reported for a single completion.
///
/// `Default` yields zero input and zero output tokens.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens counted on the request side.
    pub input_tokens: u32,
    /// Tokens counted on the response side.
    pub output_tokens: u32,
}

impl TokenUsage {
    /// Returns `input_tokens + output_tokens`.
    ///
    /// The sum is computed in `u64`, so it cannot overflow even when both counts are near
    /// `u32::MAX`.
    pub fn total_tokens(&self) -> u64 {
        u64::from(self.input_tokens) + u64::from(self.output_tokens)
    }
}
/// Result of a successful [`Llm::complete`] call.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Text returned by the backend.
    pub content: String,
    /// Token accounting, or `None` when the backend did not supply it.
    pub usage: Option<TokenUsage>,
}
/// A chat-completion backend.
///
/// The `Sync + Send` supertraits let a single backend instance be shared across threads
/// and async tasks (e.g. behind an `Arc<dyn Llm>`).
#[async_trait]
pub trait Llm: Sync + Send {
    /// Sends `request` to the backend and returns its completion.
    ///
    /// # Errors
    /// Returns the error branch of [`FierrosResult`] when the completion cannot be
    /// produced; which variants are used is up to the implementation.
    async fn complete(&self, request: CompletionRequest) -> FierrosResult<CompletionResponse>;
}
/// Deterministic in-memory [`Llm`] for tests.
///
/// Every `complete` call replays the same canned outcome (a response or an error),
/// regardless of the request it receives.
#[derive(Clone, Debug)]
pub struct MockLlm {
    /// Outcome replayed on each call.
    behavior: MockLlmBehavior,
}
/// Canned outcome stored inside a [`MockLlm`].
#[derive(Clone, Debug)]
enum MockLlmBehavior {
    /// Succeed with this text and optional usage report.
    Success {
        response: String,
        usage: Option<TokenUsage>,
    },
    /// Fail with (a clone of) this error.
    Failure(FierrosError),
}
impl MockLlm {
    /// Creates a mock that always succeeds with `response`.
    ///
    /// Reported usage starts as `Some` with zero input and zero output tokens. Change it
    /// with [`MockLlm::with_usage`], or remove it with [`MockLlm::without_usage`].
    pub fn new(response: impl Into<String>) -> Self {
        Self {
            behavior: MockLlmBehavior::Success {
                response: response.into(),
                usage: Some(TokenUsage {
                    input_tokens: 0,
                    output_tokens: 0,
                }),
            },
        }
    }

    /// Creates a mock that always fails with a clone of `error`.
    pub fn failing(error: FierrosError) -> Self {
        Self {
            behavior: MockLlmBehavior::Failure(error),
        }
    }

    /// Sets the usage a succeeding mock reports.
    ///
    /// This has no effect on a mock built with [`MockLlm::failing`], because the failure
    /// outcome carries no usage.
    #[must_use]
    pub fn with_usage(self, usage: TokenUsage) -> Self {
        self.replace_usage(Some(usage))
    }

    /// Makes a succeeding mock report `usage: None`.
    ///
    /// Use it to exercise callers' handling of backends that omit token accounting. Like
    /// `with_usage`, it has no effect on a failing mock.
    #[must_use]
    pub fn without_usage(self) -> Self {
        self.replace_usage(None)
    }

    /// Overwrites the usage on the success path and leaves a failure outcome untouched.
    fn replace_usage(mut self, usage: Option<TokenUsage>) -> Self {
        if let MockLlmBehavior::Success { usage: slot, .. } = &mut self.behavior {
            *slot = usage;
        }
        self
    }
}
#[async_trait]
impl Llm for MockLlm {
    /// Replays the configured outcome and never inspects the request.
    ///
    /// # Errors
    /// Returns a clone of the stored error when the mock was built with `failing`.
    async fn complete(&self, _request: CompletionRequest) -> FierrosResult<CompletionResponse> {
        let (content, usage) = match &self.behavior {
            MockLlmBehavior::Failure(err) => return Err(err.clone()),
            MockLlmBehavior::Success { response, usage } => (response.clone(), usage.clone()),
        };
        Ok(CompletionResponse { content, usage })
    }
}
#[cfg(test)]
mod tests {
    use super::{CompletionRequest, Llm, Message, MessageRole, MockLlm, TokenUsage};
    use crate::FierrosError;

    // Each shorthand must set both the role it is named after and the given content.
    // Previously only `assistant` had its role checked.
    #[test]
    fn message_constructors_set_roles() {
        let system = Message::system("s");
        assert_eq!(system.role, MessageRole::System);
        assert_eq!(system.content, "s");

        let user = Message::user("x");
        assert_eq!(user.role, MessageRole::User);
        assert_eq!(user.content, "x");

        let assistant = Message::assistant("a");
        assert_eq!(assistant.role, MessageRole::Assistant);
        assert_eq!(assistant.content, "a");
    }

    #[tokio::test]
    async fn mock_llm_returns_configured_response() {
        let llm = MockLlm::new("answer");
        let response = llm
            .complete(CompletionRequest::from_user("question"))
            .await
            .unwrap();
        assert_eq!(response.content, "answer");
        assert_eq!(response.usage.unwrap().input_tokens, 0);
    }

    #[tokio::test]
    async fn mock_llm_can_override_usage() {
        let llm = MockLlm::new("answer").with_usage(TokenUsage {
            input_tokens: 12,
            output_tokens: 7,
        });
        let response = llm
            .complete(CompletionRequest::from_user("question"))
            .await
            .unwrap();
        assert_eq!(
            response.usage,
            Some(TokenUsage {
                input_tokens: 12,
                output_tokens: 7,
            })
        );
    }

    #[tokio::test]
    async fn mock_llm_can_return_configured_error() {
        let llm = MockLlm::failing(FierrosError::Provider("downstream unavailable".into()));
        let error = llm
            .complete(CompletionRequest::from_user("question"))
            .await
            .unwrap_err();
        assert_eq!(
            error,
            FierrosError::Provider("downstream unavailable".into())
        );
    }

    // `with_usage` only configures the success path; a failing mock must keep failing
    // with its original error.
    #[tokio::test]
    async fn mock_llm_failure_ignores_usage_override() {
        let llm = MockLlm::failing(FierrosError::Provider("boom".into())).with_usage(TokenUsage {
            input_tokens: 1,
            output_tokens: 1,
        });
        let error = llm
            .complete(CompletionRequest::from_user("question"))
            .await
            .unwrap_err();
        assert_eq!(error, FierrosError::Provider("boom".into()));
    }

    #[test]
    fn completion_request_from_user_has_defaults() {
        let request = CompletionRequest::from_user("question");
        assert_eq!(request.messages.len(), 1);
        assert_eq!(request.messages[0].role, MessageRole::User);
        assert_eq!(request.temperature, 0.0);
        assert_eq!(request.max_tokens, None);
    }
}