use crate::client::AiClient;
use crate::error::{AiError, AiResult};
use crate::model::Model;
use crate::types::{ChatCompletionResponse, Message};
/// Response validator run by [`PromptBuilder::execute_verified`].
///
/// Return `Ok(())` to accept a model response, or `Err(feedback)` to reject
/// it; when a retry remains, the feedback text is sent back to the model.
pub type VerifyFn = Box<dyn Fn(&str) -> Result<(), String> + Send + Sync + 'static>;
/// Fluent builder for a single chat request sent through an [`AiClient`].
///
/// The request is assembled as: optional system message, then the prior
/// conversation history, then every user message in insertion order.
pub struct PromptBuilder<'a> {
    /// Client that performs the request.
    client: &'a AiClient,
    /// Model the request is sent to.
    model: Model,

    // --- conversation content ---
    /// Optional leading system message.
    system_message: Option<String>,
    /// Earlier turns placed between the system message and the user messages.
    history: Vec<Message>,
    /// User messages, sent in the order they were added.
    user_messages: Vec<String>,

    // --- sampling overrides, passed through to the client as-is ---
    temperature: Option<f32>,
    max_tokens: Option<u32>,

    // --- verification / retry settings ---
    /// Validator applied to each response by the verified execution path.
    verify_fn: Option<VerifyFn>,
    /// Extra attempts allowed after the first response fails verification.
    max_retries: usize,
}
impl<'a> PromptBuilder<'a> {
    /// Creates a builder bound to `client`.
    ///
    /// Starts with [`Model::default_general`], no system message, no user
    /// messages, empty history, no temperature or max-token overrides, no
    /// verifier, and `max_retries` set to 3.
    pub fn new(client: &'a AiClient) -> Self {
        Self {
            client,
            model: Model::default_general(),
            system_message: None,
            user_messages: Vec::new(),
            history: Vec::new(),
            temperature: None,
            max_tokens: None,
            verify_fn: None,
            max_retries: 3,
        }
    }

    /// Selects the model the request is sent to.
    pub fn model(mut self, model: Model) -> Self {
        self.model = model;
        self
    }

    /// Sets the system message, replacing any previously set one.
    pub fn system(mut self, message: impl Into<String>) -> Self {
        self.system_message = Some(message.into());
        self
    }

    /// Appends a user message. This may be called repeatedly; messages are
    /// sent in the order they were added.
    pub fn user(mut self, message: impl Into<String>) -> Self {
        self.user_messages.push(message.into());
        self
    }

    /// Alias for [`PromptBuilder::user`].
    pub fn prompt(self, message: impl Into<String>) -> Self {
        self.user(message)
    }

    /// Replaces the conversation history. The history is placed after the
    /// system message and before the user messages.
    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.history = history;
        self
    }

    /// Sets the sampling temperature forwarded to the client.
    pub fn temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the maximum-token limit forwarded to the client.
    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Installs the response validator used by
    /// [`PromptBuilder::execute_verified`], replacing any previous one.
    ///
    /// The closure returns `Ok(())` to accept a response or `Err(feedback)`
    /// to reject it. When a retry remains, the feedback is sent back to the
    /// model.
    pub fn verify<F>(mut self, verify_fn: F) -> Self
    where
        F: Fn(&str) -> Result<(), String> + Send + Sync + 'static,
    {
        self.verify_fn = Some(Box::new(verify_fn));
        self
    }

    /// Sets how many *additional* attempts `execute_verified` may make after
    /// the first response is rejected, for up to `max_retries + 1` requests.
    pub fn max_retries(mut self, max_retries: usize) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Assembles the conversation: system message (if any), then history,
    /// then user messages in insertion order.
    fn build_messages(&self) -> Vec<Message> {
        let mut messages = Vec::with_capacity(
            usize::from(self.system_message.is_some())
                + self.history.len()
                + self.user_messages.len(),
        );
        if let Some(system) = &self.system_message {
            messages.push(Message::system(system));
        }
        messages.extend(self.history.clone());
        for user_msg in &self.user_messages {
            messages.push(Message::user(user_msg));
        }
        messages
    }

    /// Like `build_messages`, but rejects an empty conversation so no
    /// execution path sends an empty request.
    fn build_non_empty_messages(&self) -> AiResult<Vec<Message>> {
        let messages = self.build_messages();
        if messages.is_empty() {
            return Err(AiError::InvalidRequest("No messages provided".to_string()));
        }
        Ok(messages)
    }

    /// Sends the request once and returns the raw response.
    ///
    /// # Errors
    ///
    /// Returns `AiError::InvalidRequest` if no messages were set. Otherwise
    /// it returns whatever error `AiClient::chat_with_options` produces.
    pub fn execute(&self) -> AiResult<ChatCompletionResponse> {
        let messages = self.build_non_empty_messages()?;
        self.client
            .chat_with_options(self.model, messages, self.temperature, self.max_tokens)
    }

    /// Sends the request once and returns the response text.
    ///
    /// # Errors
    ///
    /// Returns the same errors as [`PromptBuilder::execute`], plus
    /// `AiError::ParseError` when the response has no content.
    pub fn execute_content(&self) -> AiResult<String> {
        let response = self.execute()?;
        response
            .content()
            .map(|s| s.to_string())
            .ok_or_else(|| AiError::ParseError("No content in response".to_string()))
    }

    /// Sends the request and checks the reply with the installed verifier,
    /// retrying with feedback when the reply is rejected.
    ///
    /// Without a verifier this behaves like [`PromptBuilder::execute_content`].
    /// After each rejected attempt except the last, the rejected reply and
    /// the verifier's feedback are appended to the conversation before the
    /// next request.
    ///
    /// # Errors
    ///
    /// - `AiError::InvalidRequest` if no messages were set.
    /// - A client error or `AiError::ParseError` (missing content) aborts
    ///   immediately, without retrying.
    /// - `AiError::VerificationFailed` with the last feedback when every
    ///   attempt is rejected.
    pub fn execute_verified(&self) -> AiResult<String> {
        let Some(verify_fn) = &self.verify_fn else {
            return self.execute_content();
        };
        // Same validation as `execute`: never send an empty conversation.
        let mut messages = self.build_non_empty_messages()?;
        let mut last_error = String::new();
        for attempt in 0..=self.max_retries {
            let response = self.client.chat_with_options(
                self.model,
                messages.clone(),
                self.temperature,
                self.max_tokens,
            )?;
            let content = response
                .content()
                .ok_or_else(|| AiError::ParseError("No content in response".to_string()))?;
            let feedback = match verify_fn(content) {
                Ok(()) => return Ok(content.to_string()),
                Err(feedback) => feedback,
            };
            if attempt < self.max_retries {
                messages.push(Message::assistant(content));
                messages.push(Message::user(format!(
                    "The previous response was not acceptable. Please fix the following issue and try again:\n\n{}",
                    feedback
                )));
            }
            // Move (not clone) the feedback once it has been used in the retry prompt.
            last_error = feedback;
        }
        Err(AiError::VerificationFailed {
            retries: self.max_retries,
            message: last_error,
        })
    }
}
/// Extension trait for starting a [`PromptBuilder`] directly from a value.
pub trait PromptBuilderExt {
/// Returns a new [`PromptBuilder`] that borrows `self`.
fn prompt(&self) -> PromptBuilder<'_>;
}
/// Lets any [`AiClient`] start a request with `client.prompt()`.
impl PromptBuilderExt for AiClient {
    /// Creates a [`PromptBuilder`] with default settings bound to this client.
    fn prompt(&self) -> PromptBuilder<'_> {
        let client: &Self = self;
        PromptBuilder::new(client)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `messages` has exactly the `expected` contents, in order.
    #[track_caller]
    fn assert_contents(messages: &[Message], expected: &[&str]) {
        assert_eq!(messages.len(), expected.len());
        for (message, want) in messages.iter().zip(expected) {
            assert_eq!(message.content, *want);
        }
    }

    #[test]
    fn test_prompt_builder_messages() {
        let client = AiClient::new();
        let messages = client
            .prompt()
            .system("You are a helpful assistant")
            .user("Hello")
            .user("How are you?")
            .build_messages();
        assert_contents(
            &messages,
            &["You are a helpful assistant", "Hello", "How are you?"],
        );
    }

    #[test]
    fn test_prompt_builder_with_history() {
        let client = AiClient::new();
        let earlier_turns = vec![
            Message::user("Previous question"),
            Message::assistant("Previous answer"),
        ];
        let messages = client
            .prompt()
            .system("System")
            .with_history(earlier_turns)
            .user("New question")
            .build_messages();
        assert_eq!(messages.len(), 4);
        // Only the history and the new user turn are checked after the system slot.
        assert_contents(
            &messages[1..],
            &["Previous question", "Previous answer", "New question"],
        );
    }

    #[test]
    fn test_empty_messages_error() {
        let client = AiClient::new();
        let outcome = client.prompt().execute();
        assert!(matches!(outcome, Err(AiError::InvalidRequest(_))));
    }
}