use serde::{Deserialize, Serialize};
/// Identifies who authored a [`Message`].
///
/// With serde, each variant is written and read as its lowercase name
/// (`"system"`, `"user"`, `"assistant"`) because of `rename_all = "lowercase"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Role assigned by [`Message::system`].
    System,
    /// Role assigned by [`Message::user`]; also the value of `Role::default()`.
    #[default]
    User,
    /// Role assigned by [`Message::assistant`].
    Assistant,
}
/// A single chat message: the role of its author together with its text.
///
/// Usually created through [`Message::user`], [`Message::assistant`] or
/// [`Message::system`] rather than by filling the fields by hand.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Who authored the message.
pub role: Role,
/// The message text, stored as an owned `String`.
pub content: String,
}
impl Message {
pub fn user(content: impl Into<String>) -> Self {
Self {
role: Role::User,
content: content.into(),
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: Role::Assistant,
content: content.into(),
}
}
pub fn system(content: impl Into<String>) -> Self {
Self {
role: Role::System,
content: content.into(),
}
}
}
/// Parameters for one LLM call.
///
/// Typically created with [`LlmRequest::prompt`] or [`LlmRequest::chat`] and
/// refined with the `with_*` builder methods. Every field is optional and the
/// derived `Default` leaves all of them `None`. Nothing in this type validates
/// the values; they are only stored.
#[derive(Debug, Clone, Default)]
pub struct LlmRequest {
/// Model identifier, set by `with_model`. How `None` is resolved is up to
/// the consumer of the request (not visible in this file).
pub model: Option<String>,
/// Single text prompt. `to_messages` turns it into one user message, but
/// only when `messages` is `None`.
pub prompt: Option<String>,
/// Chat history. When present, `to_messages` uses it and ignores `prompt`.
pub messages: Option<Vec<Message>>,
/// System instruction. `to_messages` emits it as the first message, with
/// `Role::System`.
pub system: Option<String>,
/// Token limit set by `with_max_tokens`. Presumably caps the generated
/// output; confirm against the backend that consumes this request.
pub max_tokens: Option<u32>,
/// Sampling temperature set by `with_temperature`. Valid range is
/// backend-specific and not checked here.
pub temperature: Option<f32>,
/// Presumably the nucleus-sampling (top-p) threshold; semantics are
/// defined by the consuming backend, not by this file.
pub top_p: Option<f32>,
/// Presumably the top-k sampling cutoff; semantics are defined by the
/// consuming backend, not by this file.
pub top_k: Option<u32>,
/// Stop strings set by `with_stop`. Presumably generation halts when one
/// is produced; confirm against the consuming backend.
pub stop: Option<Vec<String>>,
/// Presumably a penalty on token frequency; semantics and range are
/// backend-specific and not checked here.
pub frequency_penalty: Option<f32>,
/// Presumably a penalty on token presence; semantics and range are
/// backend-specific and not checked here.
pub presence_penalty: Option<f32>,
}
impl LlmRequest {
pub fn prompt(text: impl Into<String>) -> Self {
Self {
prompt: Some(text.into()),
..Default::default()
}
}
pub fn chat(messages: Vec<Message>) -> Self {
Self {
messages: Some(messages),
..Default::default()
}
}
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = Some(model.into());
self
}
pub fn with_system(mut self, system: impl Into<String>) -> Self {
self.system = Some(system.into());
self
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = Some(temperature);
self
}
pub fn with_stop(mut self, stop: Vec<String>) -> Self {
self.stop = Some(stop);
self
}
pub fn to_messages(&self) -> Vec<Message> {
let mut messages = Vec::new();
if let Some(ref system) = self.system {
messages.push(Message::system(system.clone()));
}
if let Some(ref msgs) = self.messages {
messages.extend(msgs.clone());
} else if let Some(ref prompt) = self.prompt {
messages.push(Message::user(prompt.clone()));
}
messages
}
}
/// Unit tests for the message constructors and the request builder.
#[cfg(test)]
mod tests {
    use super::*;

    /// Every constructor must set both the role and the content.
    #[test]
    fn test_message_constructors() {
        let user = Message::user("Hello");
        assert_eq!(user.role, Role::User);
        assert_eq!(user.content, "Hello");

        let assistant = Message::assistant("Hi there!");
        assert_eq!(assistant.role, Role::Assistant);
        assert_eq!(assistant.content, "Hi there!");

        let system = Message::system("Be helpful");
        assert_eq!(system.role, Role::System);
        assert_eq!(system.content, "Be helpful");
    }

    /// `#[default]` marks `User` as the default role.
    #[test]
    fn test_role_default_is_user() {
        assert_eq!(Role::default(), Role::User);
    }

    #[test]
    fn test_request_prompt() {
        let request = LlmRequest::prompt("Tell me a joke");
        assert_eq!(request.prompt, Some("Tell me a joke".to_string()));
        assert!(request.messages.is_none());
    }

    #[test]
    fn test_request_chat() {
        let messages = vec![
            Message::user("Hello"),
            Message::assistant("Hi!"),
            Message::user("How are you?"),
        ];
        // `messages` is not used afterwards, so it is moved rather than cloned.
        let request = LlmRequest::chat(messages);
        assert_eq!(request.messages.as_ref().map(Vec::len), Some(3));
        assert!(request.prompt.is_none());
    }

    #[test]
    fn test_request_builder() {
        let request = LlmRequest::prompt("Test")
            .with_model("gpt-4o-mini")
            .with_system("Be concise")
            .with_max_tokens(100)
            .with_temperature(0.5);
        assert_eq!(request.model, Some("gpt-4o-mini".to_string()));
        assert_eq!(request.system, Some("Be concise".to_string()));
        assert_eq!(request.max_tokens, Some(100));
        assert_eq!(request.temperature, Some(0.5));
    }

    #[test]
    fn test_request_with_stop() {
        let request = LlmRequest::prompt("Test").with_stop(vec!["END".to_string()]);
        assert_eq!(request.stop, Some(vec!["END".to_string()]));
    }

    #[test]
    fn test_to_messages_with_system() {
        let request = LlmRequest::prompt("Hello").with_system("Be helpful");
        let messages = request.to_messages();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role, Role::System);
        assert_eq!(messages[0].content, "Be helpful");
        assert_eq!(messages[1].role, Role::User);
        assert_eq!(messages[1].content, "Hello");
    }

    /// Without a system instruction a prompt becomes exactly one user message.
    #[test]
    fn test_to_messages_prompt_only() {
        let messages = LlmRequest::prompt("Hello").to_messages();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, Role::User);
        assert_eq!(messages[0].content, "Hello");
    }

    /// When both `messages` and `prompt` are set, only `messages` is emitted.
    #[test]
    fn test_to_messages_chat_takes_precedence_over_prompt() {
        let request = LlmRequest {
            prompt: Some("ignored".to_string()),
            messages: Some(vec![Message::user("Hello"), Message::assistant("Hi!")]),
            ..Default::default()
        };
        let messages = request.to_messages();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, "Hello");
        assert_eq!(messages[1].role, Role::Assistant);
        assert_eq!(messages[1].content, "Hi!");
    }

    /// A request with nothing set flattens to no messages at all.
    #[test]
    fn test_to_messages_empty_request() {
        assert!(LlmRequest::default().to_messages().is_empty());
    }
}