use super::error::GeminiResponseError;
use super::types::caching::{CachedContent, CachedContentList, CachedContentUpdate};
use super::types::request::*;
use super::types::response::*;
use super::types::sessions::Session;
use reqwest::Client;
use serde_json::{Value, json};
use std::time::Duration;
/// Prefix of every model-scoped endpoint; requests go to `{BASE_URL}/{model}:{action}`.
const BASE_URL: &str = concat!(
    "https://generativelanguage.googleapis.com",
    "/v1beta",
    "/models"
);
/// Gemini API client: an HTTP client, credentials and model name, plus the request settings
/// (system prompt, generation config, safety settings, tools, tool config, cached content)
/// that are attached to every generation request.
///
/// `Debug` is implemented by hand so that the API key is never printed.
#[derive(Clone, Default)]
pub struct Gemini {
    /// HTTP client used for every API call.
    client: Client,
    /// API key used to authenticate requests; redacted from `Debug` output.
    api_key: String,
    /// Model name interpolated into generation endpoint URLs.
    model: String,
    /// Optional system instruction sent with every generation request.
    sys_prompt: Option<SystemInstruction>,
    /// Free-form generation config JSON sent with every generation request.
    generation_config: Option<Value>,
    /// Optional safety settings sent with every generation request.
    safety_settings: Option<Vec<SafetySetting>>,
    /// Optional tool declarations sent with every generation request.
    tools: Option<Vec<Tool>>,
    /// Optional tool configuration sent with every generation request.
    tool_config: Option<ToolConfig>,
    /// Optional cached-content resource name sent with every generation request.
    cached_content: Option<String>,
}

impl std::fmt::Debug for Gemini {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Gemini")
            .field("client", &self.client)
            // Never print the credential: Debug output routinely ends up in logs and panics.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("sys_prompt", &self.sys_prompt)
            .field("generation_config", &self.generation_config)
            .field("safety_settings", &self.safety_settings)
            .field("tools", &self.tools)
            .field("tool_config", &self.tool_config)
            .field("cached_content", &self.cached_content)
            .finish()
    }
}
impl Gemini {
pub fn new(
api_key: impl Into<String>,
model: impl Into<String>,
sys_prompt: Option<SystemInstruction>,
) -> Self {
Self {
client: Client::builder()
.timeout(Duration::from_secs(60))
.build()
.unwrap(),
api_key: api_key.into(),
model: model.into(),
sys_prompt,
generation_config: None,
safety_settings: None,
tools: None,
tool_config: None,
cached_content: None,
}
}
#[deprecated]
pub fn new_with_timeout(
api_key: impl Into<String>,
model: impl Into<String>,
sys_prompt: Option<SystemInstruction>,
api_timeout: Duration,
) -> Self {
Self {
client: Client::builder().timeout(api_timeout).build().unwrap(),
api_key: api_key.into(),
model: model.into(),
sys_prompt,
generation_config: None,
safety_settings: None,
tools: None,
tool_config: None,
cached_content: None,
}
}
pub fn new_with_client(
api_key: impl Into<String>,
model: impl Into<String>,
sys_prompt: Option<SystemInstruction>,
client: Client,
) -> Self {
Self {
client,
api_key: api_key.into(),
model: model.into(),
sys_prompt,
generation_config: None,
safety_settings: None,
tools: None,
tool_config: None,
cached_content: None,
}
}
pub fn set_generation_config(&mut self) -> &mut Value {
if let None = self.generation_config {
self.generation_config = Some(json!({}));
}
self.generation_config.as_mut().unwrap()
}
pub fn set_tool_config(mut self, config: ToolConfig) -> Self {
self.tool_config = Some(config);
self
}
pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
if let Value::Object(map) = self.set_generation_config() {
if let Ok(thinking_value) = serde_json::to_value(config) {
map.insert("thinking_config".to_string(), thinking_value);
}
}
self
}
pub fn set_model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
self.sys_prompt = sys_prompt;
self
}
pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
self.safety_settings = settings;
self
}
pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
self.api_key = api_key.into();
self
}
pub fn set_json_mode(mut self, schema: Value) -> Self {
let config = self.set_generation_config();
config["response_mime_type"] = "application/json".into();
config["response_schema"] = schema.into();
self
}
pub fn remove_json_mode(mut self) -> Self {
if let Some(ref mut generation_config) = self.generation_config {
generation_config["response_schema"] = None::<Value>.into();
generation_config["response_mime_type"] = None::<Value>.into();
}
self
}
pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
self.tools = Some(tools);
self
}
pub fn remove_tools(mut self) -> Self {
self.tools = None;
self
}
pub fn set_cached_content(mut self, name: impl Into<String>) -> Self {
self.cached_content = Some(name.into());
self
}
pub fn remove_cached_content(mut self) -> Self {
self.cached_content = None;
self
}
pub async fn create_cache(
&self,
cached_content: &CachedContent,
) -> Result<CachedContent, GeminiResponseError> {
let req_url = format!(
"https://generativelanguage.googleapis.com/v1beta/cachedContents?key={}",
self.api_key
);
let response = self
.client
.post(req_url)
.json(cached_content)
.send()
.await
.map_err(|e| GeminiResponseError::ReqwestError(e))?;
if !response.status().is_success() {
let error = response
.json()
.await
.map_err(|e| GeminiResponseError::ReqwestError(e))?;
return Err(GeminiResponseError::StatusNotOk(error));
}
let cached_content: CachedContent = response
.json()
.await
.map_err(|e| GeminiResponseError::ReqwestError(e))?;
Ok(cached_content)
}
pub async fn list_caches(&self) -> Result<CachedContentList, GeminiResponseError> {
let req_url = format!(
"https://generativelanguage.googleapis.com/v1beta/cachedContents?key={}",
self.api_key
);
let response = self
.client
.get(req_url)
.send()
.await
.map_err(|e| GeminiResponseError::ReqwestError(e))?;
if !response.status().is_success() {
let error = response
.json()
.await
.map_err(|e| GeminiResponseError::ReqwestError(e))?;
return Err(GeminiResponseError::StatusNotOk(error));
}
let list: CachedContentList = response
.json()
.await
.map_err(|e| GeminiResponseError::ReqwestError(e))?;
Ok(list)
}
pub async fn get_cache(&self, name: &str) -> Result<CachedContent, GeminiResponseError> {
let req_url = format!(
"https://generativelanguage.googleapis.com/v1beta/{}?key={}",
name, self.api_key
);
let response = self
.client
.get(req_url)
.send()
.await
.map_err(|e| GeminiResponseError::ReqwestError(e))?;
if !response.status().is_success() {
let error = response
.json()
.await
.map_err(|e| GeminiResponseError::ReqwestError(e))?;
return Err(GeminiResponseError::StatusNotOk(error));
}
let cached_content: CachedContent = response
.json()
.await
.map_err(|e| GeminiResponseError::ReqwestError(e))?;
Ok(cached_content)
}
pub async fn update_cache(
&self,
name: &str,
update: &CachedContentUpdate,
) -> Result<CachedContent, GeminiResponseError> {
let req_url = format!(
"https://generativelanguage.googleapis.com/v1beta/{}?key={}",
name, self.api_key
);
let response = self
.client
.patch(req_url)
.json(update)
.send()
.await
.map_err(|e| GeminiResponseError::ReqwestError(e))?;
if !response.status().is_success() {
let error = response
.json()
.await
.map_err(|e| GeminiResponseError::ReqwestError(e))?;
return Err(GeminiResponseError::StatusNotOk(error));
}
let cached_content: CachedContent = response
.json()
.await
.map_err(|e| GeminiResponseError::ReqwestError(e))?;
Ok(cached_content)
}
pub async fn delete_cache(&self, name: &str) -> Result<(), GeminiResponseError> {
let req_url = format!(
"https://generativelanguage.googleapis.com/v1beta/{}?key={}",
name, self.api_key
);
let response = self
.client
.delete(req_url)
.send()
.await
.map_err(|e| GeminiResponseError::ReqwestError(e))?;
if !response.status().is_success() {
let error = response
.json()
.await
.map_err(|e| GeminiResponseError::ReqwestError(e))?;
return Err(GeminiResponseError::StatusNotOk(error));
}
Ok(())
}
pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
if session
.get_last_chat()
.is_some_and(|chat| *chat.role() == Role::Model)
{
return Err(GeminiResponseError::NothingToRespond);
}
let req_url = format!(
"{BASE_URL}/{}:generateContent?key={}",
self.model, self.api_key
);
let response = self
.client
.post(req_url)
.json(&GeminiRequestBody::new(
self.sys_prompt.as_ref(),
self.tools.as_deref(),
&session.get_history().as_slice(),
self.generation_config.as_ref(),
self.safety_settings.as_deref(),
self.tool_config.as_ref(),
self.cached_content.clone(),
))
.send()
.await
.map_err(|e| GeminiResponseError::ReqwestError(e))?;
if !response.status().is_success() {
let error = response
.json()
.await
.map_err(|e| GeminiResponseError::ReqwestError(e))?;
return Err(GeminiResponseError::StatusNotOk(error));
}
let reply = GeminiResponse::new(response)
.await
.map_err(|e| GeminiResponseError::ReqwestError(e))?;
session.update(&reply);
Ok(reply)
}
pub async fn ask_as_stream_with_extractor<F, StreamType>(
&self,
session: Session,
data_extractor: F,
) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
where
F: FnMut(&Session, GeminiResponse) -> StreamType,
{
if session
.get_last_chat()
.is_some_and(|chat| *chat.role() == Role::Model)
{
return Err((session, GeminiResponseError::NothingToRespond));
}
let req_url = format!(
"{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
self.model, self.api_key
);
let request = self
.client
.post(req_url)
.json(&GeminiRequestBody::new(
self.sys_prompt.as_ref(),
self.tools.as_deref(),
session.get_history().as_slice(),
self.generation_config.as_ref(),
self.safety_settings.as_deref(),
self.tool_config.as_ref(),
self.cached_content.clone(),
))
.send()
.await;
let response = match request {
Ok(response) => response,
Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
};
if !response.status().is_success() {
let error = match response.json().await {
Ok(response) => response,
Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
};
return Err((session, GeminiResponseError::StatusNotOk(error)));
}
Ok(ResponseStream::new(
Box::new(response.bytes_stream()),
session,
data_extractor,
))
}
pub async fn ask_as_stream(
&self,
session: Session,
) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
self.ask_as_stream_with_extractor(
session,
(|_, gemini_response| gemini_response)
as fn(&Session, GeminiResponse) -> GeminiResponse,
)
.await
}
}