use crate::error::LlmError;
use crate::params::gemini::GeminiParams;
use crate::request_factory::{RequestBuilder, RequestBuilderConfig};
use crate::types::{ChatMessage, ChatRequest, CommonParams, ProviderParams, Tool};
/// Translates provider-agnostic request parameters into the JSON parameter
/// object used for Gemini requests, and validates that object.
pub trait GeminiParameterMapper {
/// Converts [`CommonParams`] into a JSON object keyed by Gemini parameter names.
fn map_common_to_gemini(&self, params: &CommonParams) -> serde_json::Value;
/// Overlays the Gemini-specific `gemini_params` onto a previously mapped `base` object
/// and returns the combined object.
fn merge_gemini_params(
&self,
base: serde_json::Value,
gemini_params: &GeminiParams,
) -> serde_json::Value;
/// Checks a (merged) Gemini parameter object for invalid values.
///
/// # Errors
/// Returns an [`LlmError`] describing the first invalid parameter found.
fn validate_gemini_params(&self, params: &serde_json::Value) -> Result<(), LlmError>;
}
/// Builds [`ChatRequest`]s for Google Gemini.
///
/// Holds the provider-agnostic [`CommonParams`] together with Gemini-specific
/// [`GeminiParams`]; both are mapped, merged and validated when a request is built.
#[derive(Clone)]
pub struct GeminiRequestBuilder {
// Parameters shared across providers (model, temperature, max tokens, ...).
common_params: CommonParams,
// Gemini-only parameters, merged over the mapped common parameters.
gemini_params: GeminiParams,
}
impl GeminiParameterMapper for GeminiRequestBuilder {
    /// Maps [`CommonParams`] onto Gemini's camelCase parameter names.
    ///
    /// `model` is always emitted; `temperature`, `maxOutputTokens`, `topP` and
    /// `stopSequences` are emitted only when set. `seed` is not mapped.
    fn map_common_to_gemini(&self, params: &CommonParams) -> serde_json::Value {
        let mut json = serde_json::json!({
            "model": params.model
        });
        if let Some(temp) = params.temperature {
            json["temperature"] = temp.into();
        }
        if let Some(max_tokens) = params.max_tokens {
            json["maxOutputTokens"] = max_tokens.into();
        }
        if let Some(top_p) = params.top_p {
            json["topP"] = top_p.into();
        }
        // Borrow the option so only the inner Vec is cloned into the JSON value.
        if let Some(stop) = &params.stop_sequences {
            json["stopSequences"] = stop.clone().into();
        }
        json
    }

    /// Overlays the serialized `gemini_params` onto `base`.
    ///
    /// Keys are inserted exactly as `GeminiParams` serializes them, overwriting
    /// any existing key of the same name. `null` values are skipped so unset
    /// Gemini options never clobber mapped common parameters. This is
    /// best-effort: if serialization fails, or either side is not a JSON
    /// object, `base` is returned unchanged.
    fn merge_gemini_params(
        &self,
        mut base: serde_json::Value,
        gemini_params: &GeminiParams,
    ) -> serde_json::Value {
        if let Ok(gemini_json) = serde_json::to_value(gemini_params)
            && let Some(gemini_obj) = gemini_json.as_object()
            && let Some(base_obj) = base.as_object_mut()
        {
            for (key, value) in gemini_obj {
                if !value.is_null() {
                    base_obj.insert(key.clone(), value.clone());
                }
            }
        }
        base
    }

    /// Validates `params` using the default [`RequestBuilderConfig`].
    fn validate_gemini_params(&self, params: &serde_json::Value) -> Result<(), LlmError> {
        self.validate_gemini_params_with_config(params, &RequestBuilderConfig::default())
    }
}
impl GeminiRequestBuilder {
    /// Upper bound for `maxOutputTokens` enforced by local validation.
    ///
    /// Gemini 2.5-series models accept up to 65,536 output tokens. Models with a
    /// lower limit (e.g. 8,192) are still rejected by the API itself if exceeded.
    const MAX_OUTPUT_TOKENS: i64 = 65_536;
    /// Thinking budget value that asks Gemini to size the budget dynamically.
    const DYNAMIC_THINKING_BUDGET: i64 = -1;
    /// Largest thinking budget accepted by any Gemini model (2.5 Pro).
    const MAX_THINKING_BUDGET: i64 = 32_768;

    /// Validates a merged Gemini parameter object against the documented API ranges.
    ///
    /// All checks are skipped when `config.provider_validation` is `false`.
    ///
    /// # Errors
    /// Returns [`LlmError::InvalidParameter`] when `temperature`, `topP`,
    /// `maxOutputTokens` or `thinking_budget` is out of range.
    fn validate_gemini_params_with_config(
        &self,
        params: &serde_json::Value,
        config: &RequestBuilderConfig,
    ) -> Result<(), LlmError> {
        if !config.provider_validation {
            return Ok(());
        }
        if let Some(temp) = params.get("temperature").and_then(|v| v.as_f64())
            && !(0.0..=2.0).contains(&temp)
        {
            return Err(LlmError::InvalidParameter(
                "Gemini temperature must be between 0.0 and 2.0 per official API spec (validation can be disabled)".to_string(),
            ));
        }
        if let Some(top_p) = params.get("topP").and_then(|v| v.as_f64())
            && !(0.0..=1.0).contains(&top_p)
        {
            return Err(LlmError::InvalidParameter(
                "Gemini topP must be between 0.0 and 1.0 per official API spec (validation can be disabled)".to_string(),
            ));
        }
        if let Some(max_tokens) = params.get("maxOutputTokens").and_then(|v| v.as_i64()) {
            if max_tokens <= 0 {
                return Err(LlmError::InvalidParameter(
                    "Gemini maxOutputTokens must be positive per official API spec (validation can be disabled)".to_string(),
                ));
            }
            if max_tokens > Self::MAX_OUTPUT_TOKENS {
                return Err(LlmError::InvalidParameter(format!(
                    "Gemini maxOutputTokens cannot exceed {} per official API spec (validation can be disabled)",
                    Self::MAX_OUTPUT_TOKENS
                )));
            }
        }
        // Gemini accepts -1 (dynamic thinking), 0 (thinking disabled, Flash models)
        // or a positive budget. Per-model minimums (e.g. 128 on 2.5 Pro) are left
        // for the API to enforce, since the model-specific range is not known here.
        if let Some(thinking_budget) = params.get("thinking_budget").and_then(|v| v.as_i64())
            && thinking_budget != Self::DYNAMIC_THINKING_BUDGET
            && !(0..=Self::MAX_THINKING_BUDGET).contains(&thinking_budget)
        {
            return Err(LlmError::InvalidParameter(format!(
                "Gemini thinking budget must be {} (dynamic) or between 0 and {} tokens per official API spec (validation can be disabled)",
                Self::DYNAMIC_THINKING_BUDGET,
                Self::MAX_THINKING_BUDGET
            )));
        }
        Ok(())
    }

    /// Creates a builder from common and Gemini-specific parameters.
    pub fn new(common_params: CommonParams, gemini_params: GeminiParams) -> Self {
        Self {
            common_params,
            gemini_params,
        }
    }

    /// Wraps a copy of the Gemini parameters as [`ProviderParams`].
    fn create_provider_params(&self) -> ProviderParams {
        ProviderParams::from_gemini(self.gemini_params.clone())
    }

    /// Checks that the request names a Gemini model.
    ///
    /// # Errors
    /// Returns [`LlmError::InvalidParameter`] when the model is empty or does
    /// not start with `gemini-`.
    fn validate_gemini_request(&self, request: &ChatRequest) -> Result<(), LlmError> {
        let model = &request.common_params.model;
        if model.is_empty() {
            return Err(LlmError::InvalidParameter(
                "Model name is required for Gemini".to_string(),
            ));
        }
        if !model.starts_with("gemini-") {
            return Err(LlmError::InvalidParameter(
                "Gemini model names should start with 'gemini-'".to_string(),
            ));
        }
        Ok(())
    }
}
impl RequestBuilder for GeminiRequestBuilder {
    /// Builds a chat request using the default [`RequestBuilderConfig`].
    fn build_chat_request(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        stream: bool,
    ) -> Result<ChatRequest, LlmError> {
        self.build_chat_request_with_config(
            messages,
            tools,
            stream,
            &RequestBuilderConfig::default(),
        )
    }

    /// Builds and validates a Gemini chat request.
    ///
    /// The common parameters are mapped to Gemini names and merged with the
    /// Gemini-specific parameters. The merged object is validated against
    /// `config`, and the assembled request is then checked by
    /// [`RequestBuilder::validate_request`].
    ///
    /// # Errors
    /// Returns [`LlmError::InvalidParameter`] if any parameter or the request is invalid.
    fn build_chat_request_with_config(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        stream: bool,
        config: &RequestBuilderConfig,
    ) -> Result<ChatRequest, LlmError> {
        // Validate the parameters in the merged shape Gemini will receive them.
        let params_json = self.merge_gemini_params(
            self.map_common_to_gemini(&self.common_params),
            &self.gemini_params,
        );
        self.validate_gemini_params_with_config(&params_json, config)?;
        let request = ChatRequest {
            messages,
            tools,
            common_params: self.common_params.clone(),
            provider_params: Some(self.create_provider_params()),
            http_config: None,
            web_search: None,
            stream,
        };
        // `validate_request` already runs the Gemini-specific model checks.
        self.validate_request(&request)?;
        Ok(request)
    }

    fn get_common_params(&self) -> &CommonParams {
        &self.common_params
    }

    fn get_provider_params(&self) -> Option<ProviderParams> {
        Some(self.create_provider_params())
    }

    /// Rejects requests with no messages or no model, then applies the
    /// Gemini model-name checks.
    fn validate_request(&self, request: &ChatRequest) -> Result<(), LlmError> {
        if request.messages.is_empty() {
            return Err(LlmError::InvalidParameter(
                "Messages cannot be empty".to_string(),
            ));
        }
        if request.common_params.model.is_empty() {
            return Err(LlmError::InvalidParameter(
                "Model must be specified".to_string(),
            ));
        }
        self.validate_gemini_request(request)
    }
}
pub fn create_gemini_request_builder(
common_params: CommonParams,
gemini_params: GeminiParams,
) -> GeminiRequestBuilder {
GeminiRequestBuilder::new(common_params, gemini_params)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{MessageContent, MessageRole};

    /// Builds a single plain-text user message for request-building tests.
    fn user_message(text: &str) -> ChatMessage {
        ChatMessage {
            role: MessageRole::User,
            content: MessageContent::Text(text.to_string()),
            metadata: Default::default(),
            tool_calls: None,
            tool_call_id: None,
        }
    }

    #[test]
    fn test_gemini_parameter_mapping() {
        let common_params = CommonParams {
            model: "gemini-1.5-pro".to_string(),
            temperature: Some(0.7),
            max_tokens: Some(1000),
            top_p: Some(0.9),
            stop_sequences: Some(vec!["STOP".to_string()]),
            seed: Some(42),
        };
        let builder = GeminiRequestBuilder::new(common_params.clone(), GeminiParams::default());
        let mapped = builder.map_common_to_gemini(&common_params);

        assert_eq!(mapped["model"], "gemini-1.5-pro");
        assert!((mapped["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
        assert_eq!(mapped["maxOutputTokens"], 1000);
        assert!((mapped["topP"].as_f64().unwrap() - 0.9).abs() < 0.001);
        assert_eq!(mapped["stopSequences"], serde_json::json!(["STOP"]));
        // `seed` is not part of the Gemini mapping.
        assert!(mapped.get("seed").is_none());
    }

    #[test]
    fn test_gemini_validation() {
        let invalid_model_params = CommonParams {
            model: "gpt-4".to_string(),
            ..Default::default()
        };
        let invalid_builder =
            GeminiRequestBuilder::new(invalid_model_params, GeminiParams::default());

        let result = invalid_builder.build_chat_request(vec![user_message("Hello")], None, false);

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("should start with 'gemini-'")
        );
    }

    #[test]
    fn test_empty_messages_rejected() {
        let params = CommonParams {
            model: "gemini-1.5-pro".to_string(),
            ..Default::default()
        };
        let builder = GeminiRequestBuilder::new(params, GeminiParams::default());

        let err = builder
            .build_chat_request(Vec::new(), None, false)
            .unwrap_err();

        assert!(err.to_string().contains("Messages cannot be empty"));
    }
}