use std::path::Path;
use base64::prelude::*;
use enum_iterator::all;
use file_format::FileFormat;
use serde_json::from_str;
use thiserror::Error;
use crate::google::{
GoogleModel,
common::{Blob, Content, HarmCategory, Modality, Part, Role},
request::{GenerateContentRequest, GenerationConfig, HarmBlockThreshold, SafetySettings},
response::ContentResponse,
};
/// Base URL of the model endpoints; the model name is appended to it when
/// the request URL is built.
const URL_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models";
/// Method suffix placed after the model name when the request URL is built.
const URL_EXTENSION: &str = ":streamGenerateContent";
/// Errors returned by the fallible [`Client`] operations.
#[derive(Error, Debug)]
pub enum Error {
/// JSON serialization or deserialization failed (`serde_json`).
#[error(transparent)]
SerdeJson(#[from] serde_json::Error),
/// The HTTP layer (`reqwest`) reported an error.
#[error(transparent)]
Reqwest(#[from] reqwest::Error),
/// A response carried an `error` object; the payload is that object
/// serialized to a JSON string.
#[error("{0}")]
Request(String),
/// An I/O operation failed, e.g. reading an image file from disk.
#[error(transparent)]
Io(#[from] std::io::Error),
}
/// Stateful client for one conversation with a Google generative model.
///
/// The client owns the request body it sends, including the accumulated
/// conversation contents, so successive calls continue the same conversation.
///
/// `Debug` is implemented by hand rather than derived so that the API key is
/// never written to logs, panic messages or any other `{:?}` output.
#[derive(Clone)]
pub struct Client {
    /// HTTP client used for every request.
    client: reqwest::Client,
    /// Model the requests are addressed to.
    model: GoogleModel,
    /// API key sent with every request. Secret: must not be printed.
    key: String,
    /// Request body sent on each call; its `contents` hold the conversation.
    request: GenerateContentRequest,
}

impl std::fmt::Debug for Client {
    /// Formats every field like the derived implementation would, except
    /// that the API key is replaced by a fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("client", &self.client)
            .field("model", &self.model)
            .field("key", &"<redacted>")
            .field("request", &self.request)
            .finish()
    }
}
/// The [`ContentResponse`] values obtained from one request, kept in the
/// order in which they appeared in the response body.
#[derive(Debug)]
pub struct Responses(Vec<ContentResponse>);
impl Responses {
pub fn extract_text(&self) -> Option<String> {
let mut text = String::new();
for content in &self.0 {
for candidate in &content.candidates {
for part in &candidate.content.parts {
if let Part::Text(txt) = part {
text += txt
}
}
}
}
if text.is_empty() { None } else { Some(text) }
}
pub fn extract_images(&self) -> Vec<(String, String)> {
let mut images = Vec::new();
for content in &self.0 {
for candidate in &content.candidates {
for part in &candidate.content.parts {
if let Part::InlineData(blob) = part {
images.push((blob.mime_type.clone(), blob.data.clone()));
}
}
}
}
images
}
}
impl Client {
pub fn new(model: &GoogleModel, key: &str) -> Self {
Client {
client: reqwest::Client::new(),
model: model.clone(),
key: key.to_string(),
request: GenerateContentRequest {
system_instruction: None,
contents: vec![],
tools: vec![],
tool_config: None,
safety_settings: vec![],
generation_config: None,
cached_content: None,
},
}
}
pub async fn with_defaults(&mut self) -> Self {
let safety_settings = all::<HarmCategory>()
.collect::<Vec<_>>()
.into_iter()
.map(|cat| SafetySettings {
category: cat,
threshold: HarmBlockThreshold::default(),
})
.collect();
let generation_config = match &self.model {
GoogleModel::Gemini20FlashExpImageGen(_) => GenerationConfig {
response_modalities: vec![Modality::Text, Modality::Image],
..Default::default()
},
GoogleModel::Gemini20Flash(_) => GenerationConfig {
response_modalities: vec![Modality::Text],
..Default::default()
},
};
self.request.safety_settings = safety_settings;
self.request.generation_config = Some(generation_config);
self.to_owned()
}
pub async fn with_safety(&mut self, safety_settings: &[SafetySettings]) -> Self {
self.request.safety_settings = safety_settings.to_vec();
self.to_owned()
}
pub fn with_instruction(&mut self, system_instruction: &str) -> &mut Self {
match self.model {
GoogleModel::Gemini20FlashExpImageGen(_) => {
let mut contents = vec![Content {
parts: vec![Part::Text(system_instruction.to_string())],
role: Role::User,
}];
contents.extend(self.request.contents.clone());
self.request.contents = contents;
}
GoogleModel::Gemini20Flash(_) => {
self.request.system_instruction = Some(Content {
role: Role::User,
parts: vec![Part::Text(system_instruction.to_string())],
});
}
}
self
}
fn merge_response(&mut self, responses: &[ContentResponse]) -> Result<Responses, Error> {
let mut success = Vec::new();
for response in responses {
if let Some(error) = &response.error {
return Err(Error::Request(serde_json::to_string(error)?));
} else {
for candidate in &response.candidates {
if !candidate.content.parts.is_empty() {
self.request.contents.push(candidate.content.clone());
}
}
success.push(response.clone());
}
}
Ok(Responses(success))
}
async fn post(&mut self) -> Result<Responses, Error> {
let responses = from_str::<Vec<ContentResponse>>(
&self
.client
.post(self.url())
.query(&[("key", &self.key)])
.json(&self.request)
.send()
.await?
.text()
.await?,
)?;
self.merge_response(&responses)
}
pub async fn send_text(&mut self, text: &str) -> Result<Responses, Error> {
self.request.contents.push(Content {
parts: vec![Part::Text(text.to_string())],
role: Role::User,
});
self.post().await
}
pub async fn send_image(
&mut self,
message: Option<String>,
img: &Path,
) -> Result<Responses, Error> {
let format = FileFormat::from_file(img)?;
let data = BASE64_URL_SAFE.encode(&tokio::fs::read(img).await?);
self.send_image_bytes(message, format.media_type(), &data)
.await
}
pub async fn send_image_bytes(
&mut self,
message: Option<String>,
mime_type: &str,
data: &str,
) -> Result<Responses, Error> {
let mut parts = Vec::new();
if let Some(message) = message {
parts.push(Part::Text(message.to_string()));
}
parts.push(Part::InlineData(Blob {
mime_type: mime_type.to_string(),
data: data.to_string(),
}));
self.request.contents.push(Content {
parts,
role: Role::User,
});
self.post().await
}
fn url(&self) -> String {
format!("{URL_BASE}/{}{URL_EXTENSION}", self.model.name())
}
pub fn history(&self) -> &[Content] {
&self.request.contents
}
}