use crate::error::LlmError;
use crate::types::{ChatMessage, ChatRequest, CommonParams, ProviderParams, Tool};
/// Options accepted by [`RequestBuilder::build_chat_request_with_config`].
///
/// The derived `Default` sets both flags to `false`. The trait's default
/// `build_chat_request_with_config` ignores this value entirely; what the flags
/// mean is up to any implementor that overrides that method.
#[derive(Debug, Clone, Default)]
pub struct RequestBuilderConfig {
    /// Presumably requests stricter validation (inferred from the name).
    /// NOTE(review): nothing in this file reads it — confirm which builders honour it.
    pub strict_validation: bool,
    /// Presumably requests provider-specific validation (inferred from the name).
    /// NOTE(review): nothing in this file reads it — confirm which builders honour it.
    pub provider_validation: bool,
}
pub trait RequestBuilder: Send + Sync {
fn build_chat_request(
&self,
messages: Vec<ChatMessage>,
tools: Option<Vec<Tool>>,
stream: bool,
) -> Result<ChatRequest, LlmError>;
fn build_chat_request_with_config(
&self,
messages: Vec<ChatMessage>,
tools: Option<Vec<Tool>>,
stream: bool,
config: &RequestBuilderConfig,
) -> Result<ChatRequest, LlmError> {
let _ = config; self.build_chat_request(messages, tools, stream)
}
fn get_common_params(&self) -> &CommonParams;
fn get_provider_params(&self) -> Option<ProviderParams>;
fn validate_request(&self, request: &ChatRequest) -> Result<(), LlmError> {
if request.messages.is_empty() {
return Err(LlmError::InvalidParameter(
"Messages cannot be empty".to_string(),
));
}
if request.common_params.model.is_empty() {
return Err(LlmError::InvalidParameter(
"Model must be specified".to_string(),
));
}
Ok(())
}
fn validate_configuration(&self) -> Result<(), LlmError> {
let common_params = self.get_common_params();
if common_params.model.is_empty() {
return Err(LlmError::ConfigurationError(
"Model must be specified".to_string(),
));
}
if let Some(temp) = common_params.temperature
&& !(0.0..=2.0).contains(&temp)
{
return Err(LlmError::ConfigurationError(format!(
"Temperature must be between 0.0 and 2.0, got {}",
temp
)));
}
if let Some(top_p) = common_params.top_p
&& !(0.0..=1.0).contains(&top_p)
{
return Err(LlmError::ConfigurationError(format!(
"top_p must be between 0.0 and 1.0, got {}",
top_p
)));
}
if let Some(max_tokens) = common_params.max_tokens
&& max_tokens == 0
{
return Err(LlmError::ConfigurationError(
"max_tokens must be greater than 0".to_string(),
));
}
Ok(())
}
fn get_validated_common_params(&self) -> Result<&CommonParams, LlmError> {
self.validate_configuration()?;
Ok(self.get_common_params())
}
fn get_validated_provider_params(&self) -> Result<Option<ProviderParams>, LlmError> {
self.validate_configuration()?;
Ok(self.get_provider_params())
}
}
#[derive(Clone)]
pub struct StandardRequestBuilder {
common_params: CommonParams,
provider_params: Option<ProviderParams>,
}
impl StandardRequestBuilder {
pub fn new(common_params: CommonParams, provider_params: Option<ProviderParams>) -> Self {
Self {
common_params,
provider_params,
}
}
pub fn build_standard_request(
&self,
messages: Vec<ChatMessage>,
tools: Option<Vec<Tool>>,
stream: bool,
) -> Result<ChatRequest, LlmError> {
let request = ChatRequest {
messages,
tools,
common_params: self.common_params.clone(),
provider_params: self.provider_params.clone(),
http_config: None,
web_search: None,
stream,
};
self.validate_standard_request(&request)?;
Ok(request)
}
fn validate_standard_request(&self, request: &ChatRequest) -> Result<(), LlmError> {
if request.messages.is_empty() {
return Err(LlmError::InvalidParameter(
"Messages cannot be empty".to_string(),
));
}
if request.common_params.model.is_empty() {
return Err(LlmError::InvalidParameter(
"Model must be specified".to_string(),
));
}
Ok(())
}
}
impl RequestBuilder for StandardRequestBuilder {
fn build_chat_request(
&self,
messages: Vec<ChatMessage>,
tools: Option<Vec<Tool>>,
stream: bool,
) -> Result<ChatRequest, LlmError> {
self.build_standard_request(messages, tools, stream)
}
fn get_common_params(&self) -> &CommonParams {
&self.common_params
}
fn get_provider_params(&self) -> Option<ProviderParams> {
self.provider_params.clone()
}
fn validate_request(&self, request: &ChatRequest) -> Result<(), LlmError> {
self.validate_standard_request(request)
}
}
#[derive(Clone)]
pub struct RequestBuilderFactory;
impl RequestBuilderFactory {
pub fn create_builder(
provider_type: &crate::types::ProviderType,
common_params: CommonParams,
provider_params: Option<ProviderParams>,
) -> Box<dyn RequestBuilder> {
match provider_type {
crate::types::ProviderType::OpenAi => {
Box::new(StandardRequestBuilder::new(common_params, provider_params))
}
crate::types::ProviderType::Anthropic => {
Box::new(StandardRequestBuilder::new(common_params, provider_params))
}
crate::types::ProviderType::Gemini => {
Box::new(StandardRequestBuilder::new(common_params, provider_params))
}
crate::types::ProviderType::Ollama => {
Box::new(StandardRequestBuilder::new(common_params, provider_params))
}
crate::types::ProviderType::XAI => {
Box::new(StandardRequestBuilder::new(common_params, provider_params))
}
crate::types::ProviderType::Groq => {
Box::new(StandardRequestBuilder::new(common_params, provider_params))
}
crate::types::ProviderType::Custom(name) => {
match name.as_str() {
"deepseek" | "openrouter" | "groq" | "xai" => {
Box::new(StandardRequestBuilder::new(common_params, provider_params))
}
_ => {
Box::new(StandardRequestBuilder::new(common_params, provider_params))
}
}
}
}
}
pub fn create_and_validate_builder(
provider_type: &crate::types::ProviderType,
common_params: CommonParams,
provider_params: Option<ProviderParams>,
) -> Result<Box<dyn RequestBuilder>, LlmError> {
let builder = Self::create_builder(provider_type, common_params, provider_params);
builder.validate_configuration()?;
Ok(builder)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{MessageContent, MessageRole};

    /// Compile-time check that `T` is `Send + Sync`; does nothing at runtime.
    fn assert_send_sync<T: Send + Sync>() {}

    /// A single user text message with no metadata or tool-call data.
    fn user_message(text: &str) -> ChatMessage {
        ChatMessage {
            role: MessageRole::User,
            content: MessageContent::Text(text.to_string()),
            metadata: Default::default(),
            tool_calls: None,
            tool_call_id: None,
        }
    }

    /// `CommonParams` with only `model` set; every other field is defaulted.
    fn params_for(model: &str) -> CommonParams {
        CommonParams {
            model: model.to_string(),
            ..Default::default()
        }
    }

    /// True when `params` fail `validate_configuration` with a `ConfigurationError`.
    fn rejects_config(params: CommonParams) -> bool {
        matches!(
            StandardRequestBuilder::new(params, None).validate_configuration(),
            Err(LlmError::ConfigurationError(_))
        )
    }

    #[test]
    fn test_request_builder_send_sync() {
        assert_send_sync::<StandardRequestBuilder>();
        assert_send_sync::<Box<dyn RequestBuilder>>();
    }

    #[test]
    fn test_provider_specific_request_builders_send_sync() {
        use crate::providers::anthropic::request::AnthropicRequestBuilder;
        use crate::providers::gemini::request::GeminiRequestBuilder;
        use crate::providers::openai::request::OpenAiRequestBuilder;
        assert_send_sync::<OpenAiRequestBuilder>();
        assert_send_sync::<AnthropicRequestBuilder>();
        assert_send_sync::<GeminiRequestBuilder>();
    }

    #[test]
    fn test_standard_request_builder() {
        let common_params = CommonParams {
            temperature: Some(0.7),
            ..params_for("test-model")
        };
        let builder = StandardRequestBuilder::new(common_params, None);
        let request = builder
            .build_chat_request(vec![user_message("Hello")], None, false)
            .expect("Should build request successfully");
        assert_eq!(request.common_params.model, "test-model");
        assert_eq!(request.common_params.temperature, Some(0.7));
        assert_eq!(request.messages.len(), 1);
        assert!(!request.stream);
        assert!(request.tools.is_none());
        assert!(request.provider_params.is_none());
    }

    #[test]
    fn test_stream_flag_is_propagated() {
        let builder = StandardRequestBuilder::new(params_for("test-model"), None);
        let request = builder
            .build_chat_request(vec![user_message("Hello")], None, true)
            .expect("Should build streaming request");
        assert!(request.stream);
    }

    #[test]
    fn test_request_validation() {
        let builder = StandardRequestBuilder::new(params_for(""), None);
        let result = builder.build_chat_request(vec![user_message("Hello")], None, false);
        assert!(matches!(result, Err(LlmError::InvalidParameter(_))));
    }

    #[test]
    fn test_empty_messages_rejected() {
        let builder = StandardRequestBuilder::new(params_for("test-model"), None);
        let result = builder.build_chat_request(Vec::new(), None, false);
        assert!(matches!(result, Err(LlmError::InvalidParameter(_))));
    }

    #[test]
    fn test_build_with_default_config_delegates() {
        let builder = StandardRequestBuilder::new(params_for("test-model"), None);
        let request = builder
            .build_chat_request_with_config(
                vec![user_message("Hello")],
                None,
                true,
                &RequestBuilderConfig::default(),
            )
            .expect("Default config should not change the build result");
        assert!(request.stream);
        assert_eq!(request.common_params.model, "test-model");
    }

    #[test]
    fn test_configuration_validation() {
        let valid_params = CommonParams {
            temperature: Some(0.7),
            top_p: Some(0.9),
            max_tokens: Some(1000),
            ..params_for("test-model")
        };
        let builder = StandardRequestBuilder::new(valid_params, None);
        assert!(builder.validate_configuration().is_ok());

        assert!(rejects_config(CommonParams {
            temperature: Some(3.0),
            ..params_for("test-model")
        }));
        assert!(rejects_config(CommonParams {
            top_p: Some(1.5),
            ..params_for("test-model")
        }));
        assert!(rejects_config(CommonParams {
            max_tokens: Some(0),
            ..params_for("test-model")
        }));
        assert!(rejects_config(params_for("")));
    }

    #[test]
    fn test_configuration_bounds_are_inclusive() {
        let upper = CommonParams {
            temperature: Some(2.0),
            top_p: Some(1.0),
            max_tokens: Some(1),
            ..params_for("test-model")
        };
        assert!(StandardRequestBuilder::new(upper, None)
            .validate_configuration()
            .is_ok());

        let lower = CommonParams {
            temperature: Some(0.0),
            top_p: Some(0.0),
            ..params_for("test-model")
        };
        assert!(StandardRequestBuilder::new(lower, None)
            .validate_configuration()
            .is_ok());

        assert!(rejects_config(CommonParams {
            temperature: Some(-0.1),
            ..params_for("test-model")
        }));
        assert!(rejects_config(CommonParams {
            top_p: Some(-0.1),
            ..params_for("test-model")
        }));
    }

    #[test]
    fn test_validated_accessors() {
        let builder = StandardRequestBuilder::new(params_for("test-model"), None);
        let params = builder
            .get_validated_common_params()
            .expect("valid configuration");
        assert_eq!(params.model, "test-model");
        assert!(builder
            .get_validated_provider_params()
            .expect("valid configuration")
            .is_none());

        let invalid = StandardRequestBuilder::new(params_for(""), None);
        assert!(invalid.get_validated_common_params().is_err());
        assert!(invalid.get_validated_provider_params().is_err());
    }

    #[test]
    fn test_factory_validation() {
        use crate::types::ProviderType;

        let valid_params = CommonParams {
            temperature: Some(0.7),
            ..params_for("test-model")
        };
        let result = RequestBuilderFactory::create_and_validate_builder(
            &ProviderType::OpenAi,
            valid_params,
            None,
        );
        assert!(result.is_ok());

        let invalid_result = RequestBuilderFactory::create_and_validate_builder(
            &ProviderType::OpenAi,
            params_for(""),
            None,
        );
        assert!(matches!(
            invalid_result,
            Err(LlmError::ConfigurationError(_))
        ));
    }
}