use crate::types::{GeminiConfig, SystemInstructions};
use serde::{Deserialize, Serialize};
use std::fmt::Write;
/// One prior turn of a conversation, replayed to Gemini as request history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatEntry {
/// Speaker role, copied verbatim into the `role` field of the request's
/// content entry. NOTE(review): the Gemini API is understood to accept only
/// `"user"` and `"model"` here — confirm callers never pass e.g. `"assistant"`.
pub role: String,
/// Message text, sent to the API as a single text part.
pub content: String,
}
/// A fully prepared HTTP request for the Gemini `generateContent` endpoint.
///
/// The URL is provided both whole (`url`) and split into `scheme`,
/// `authority` and `path`, so callers can feed whichever form their HTTP
/// client expects.
///
/// `Debug` is implemented by hand so credential header values (such as the
/// `x-goog-api-key` header carrying the API key) are never written to logs.
#[derive(Clone)]
pub struct GeminiRequest {
    /// Complete request URL (`scheme://authority/path`).
    pub url: String,
    /// URL scheme, e.g. `"https"`.
    pub scheme: String,
    /// Host part of the URL.
    pub authority: String,
    /// Path part of the URL, starting with `/`.
    pub path: String,
    /// Header name/value pairs. Values are raw bytes and may contain secrets.
    pub headers: Vec<(String, Vec<u8>)>,
    /// Serialized JSON request body.
    pub body: Vec<u8>,
}

impl std::fmt::Debug for GeminiRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Header names whose values are credentials and must not be printed.
        const SECRET_HEADERS: [&str; 2] = ["x-goog-api-key", "authorization"];

        let headers: Vec<(&str, std::borrow::Cow<'_, str>)> = self
            .headers
            .iter()
            .map(|(name, value)| {
                let shown = if SECRET_HEADERS
                    .iter()
                    .any(|secret| name.eq_ignore_ascii_case(secret))
                {
                    std::borrow::Cow::Borrowed("<redacted>")
                } else {
                    String::from_utf8_lossy(value)
                };
                (name.as_str(), shown)
            })
            .collect();

        f.debug_struct("GeminiRequest")
            .field("url", &self.url)
            .field("scheme", &self.scheme)
            .field("authority", &self.authority)
            .field("path", &self.path)
            .field("headers", &headers)
            // Render the JSON body as text rather than a list of byte values.
            .field("body", &String::from_utf8_lossy(&self.body))
            .finish()
    }
}
/// Builder-style client for the Gemini `generateContent` REST endpoint.
///
/// The methods in this file build request payloads and parse response
/// bytes; actually sending the request is left to the caller. The API key
/// is supplied per request rather than stored on the client.
#[derive(Debug, Clone)]
pub struct GeminiDirectClient {
/// Model id interpolated into the `/v1beta/models/{model}:generateContent` path.
model: String,
/// Optional text sent as the request's `systemInstruction` when set.
system_instruction: Option<String>,
/// Sent as `generationConfig.temperature`; the constructors default it to 0.7.
temperature: f64,
/// Sent as `generationConfig.maxOutputTokens`; the constructors default it to 2048.
max_output_tokens: u32,
}
impl GeminiDirectClient {
pub fn new(config: &GeminiConfig) -> Self {
Self {
model: config.models.default.name.clone(),
system_instruction: None,
temperature: 0.7,
max_output_tokens: 2048,
}
}
pub fn with_model(model: impl Into<String>) -> Self {
Self {
model: model.into(),
system_instruction: None,
temperature: 0.7,
max_output_tokens: 2048,
}
}
pub fn with_system_instruction(mut self, instruction: String) -> Self {
self.system_instruction = Some(instruction);
self
}
pub fn with_system_instructions(mut self, instructions: &SystemInstructions) -> Self {
match instructions {
SystemInstructions::Custom(custom) => {
self.system_instruction = Some(custom.text.clone());
}
SystemInstructions::Appended(appended) => {
let mut text = String::new();
if let Some(ref identity) = appended.custom_identity {
text.push_str(identity);
text.push('\n');
}
for section in &appended.appended_sections {
let _ = write!(text, "## {}\n{}\n\n", section.title, section.content);
}
self.system_instruction = Some(text);
}
}
self
}
pub const fn with_temperature(mut self, temperature: f64) -> Self {
self.temperature = temperature;
self
}
pub const fn with_max_output_tokens(mut self, max_tokens: u32) -> Self {
self.max_output_tokens = max_tokens;
self
}
pub fn model(&self) -> &str {
&self.model
}
pub fn build_request(
&self,
api_key: &str,
message: &str,
history: &[ChatEntry],
) -> Result<GeminiRequest, anyhow::Error> {
let authority = "generativelanguage.googleapis.com";
let path = format!("/v1beta/models/{}:generateContent", self.model);
let url = format!("https://{authority}{path}");
let mut contents = Vec::new();
for entry in history {
contents.push(serde_json::json!({
"role": entry.role,
"parts": [{ "text": entry.content }]
}));
}
contents.push(serde_json::json!({
"role": "user",
"parts": [{ "text": message }]
}));
let mut request_body = serde_json::json!({
"contents": contents,
"generationConfig": {
"temperature": self.temperature,
"maxOutputTokens": self.max_output_tokens
}
});
if let Some(ref instruction) = self.system_instruction {
request_body["systemInstruction"] = serde_json::json!({
"parts": [{ "text": instruction }]
});
}
let body = serde_json::to_vec(&request_body)?;
let headers = vec![
("content-type".to_string(), b"application/json".to_vec()),
("x-goog-api-key".to_string(), api_key.as_bytes().to_vec()),
];
Ok(GeminiRequest {
url,
scheme: "https".to_string(),
authority: authority.to_string(),
path,
headers,
body,
})
}
pub fn parse_response(response_body: &[u8]) -> Result<String, anyhow::Error> {
let json: serde_json::Value = serde_json::from_slice(response_body)?;
let text = json["candidates"][0]["content"]["parts"][0]["text"]
.as_str()
.ok_or_else(|| {
anyhow::anyhow!(
"Unexpected Gemini response structure: {}",
String::from_utf8_lossy(&response_body[..response_body.len().min(500)])
)
})?
.to_string();
Ok(text)
}
pub fn parse_error(status: u16, response_body: &[u8]) -> String {
let body_text = String::from_utf8_lossy(response_body);
format!("Gemini API error ({}): {}", status, body_text)
}
pub fn entries_to_contents(history: &[ChatEntry], new_message: &str) -> serde_json::Value {
let mut contents = Vec::new();
for entry in history {
contents.push(serde_json::json!({
"role": entry.role,
"parts": [{ "text": entry.content }]
}));
}
contents.push(serde_json::json!({
"role": "user",
"parts": [{ "text": new_message }]
}));
serde_json::Value::Array(contents)
}
}