use base64::{Engine, engine::general_purpose::STANDARD};
use core::fmt;
use derive_new::new;
use getset::Getters;
use mime::{FromStrError, Mime};
#[cfg(feature = "reqwest")]
use reqwest::header::{HeaderMap, ToStrError};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::Value;
use std::str::FromStr;
mod chat;
pub use chat::Chat;
/// Author of a piece of conversation content.
///
/// Serialized in camelCase: `"user"`, `"model"`, `"function"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum Role {
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"model"`.
    Model,
    /// Serialized as `"function"`.
    Function,
}
fn deserialize_mime<'de, D>(deserializer: D) -> Result<Mime, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Mime::from_str(&s).map_err(serde::de::Error::custom)
}
fn serialize_mime<S>(mime: &Mime, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(mime.as_ref())
}
/// Payload embedded directly in a request part, tagged with its MIME type.
///
/// `mime_type` is (de)serialized as a plain string by the module's
/// `serialize_mime` / `deserialize_mime` helpers. `data` is stored verbatim.
/// The URL/path constructors in this module fill it with standard base64.
#[derive(Debug, Clone, Serialize, Deserialize, Getters)]
pub struct InlineData {
    #[serde(serialize_with = "serialize_mime", deserialize_with = "deserialize_mime")]
    #[get = "pub"]
    mime_type: Mime,
    #[get = "pub"]
    data: String,
}
#[derive(thiserror::Error, Debug)]
pub enum InlineDataError {
#[error(transparent)]
#[cfg(feature = "reqwest")]
RequestFailed(reqwest::Error),
#[error("Checker function returned false")]
CheckerFalse,
#[error("Content-Type was missing in response headers")]
ContentTypeMissing,
#[error(transparent)]
#[cfg(feature = "reqwest")]
ContentTypeParseFailed(ToStrError),
#[error("Failed to parse mime type: {0}")]
InvalidMimeType(FromStrError),
}
impl InlineData {
    /// Creates inline data from a MIME type and an already-encoded payload.
    ///
    /// `data` is stored as-is. The other constructors here pass standard base64.
    pub fn new(mime_type: Mime, data: String) -> Self {
        Self { mime_type, data }
    }

    /// Downloads `url` and wraps the response body as base64-encoded inline data.
    ///
    /// The MIME type is taken from the response's `Content-Type` header.
    /// `checker` is called with the response headers once the status is known
    /// to be successful and before the body is read. Returning `false` aborts.
    ///
    /// # Errors
    /// - `RequestFailed` if the request fails, the server answers with a
    ///   4xx/5xx status, or the body cannot be read.
    /// - `CheckerFalse` if `checker` returns `false`.
    /// - `ContentTypeMissing` if the `Content-Type` header is absent.
    /// - `ContentTypeParseFailed` if the header is not a string.
    /// - `InvalidMimeType` if the header is not a valid MIME type.
    #[cfg(feature = "reqwest")]
    pub async fn from_url_with_check<F: FnOnce(&HeaderMap) -> bool>(
        url: &str,
        checker: F,
    ) -> Result<Self, InlineDataError> {
        // Reject non-success statuses. Otherwise an error page (e.g. a 404 HTML
        // body) would be silently encoded and passed on as if it were the file.
        let response = reqwest::get(url)
            .await
            .and_then(reqwest::Response::error_for_status)
            .map_err(InlineDataError::RequestFailed)?;
        if !checker(response.headers()) {
            return Err(InlineDataError::CheckerFalse);
        }
        let mime_type = response
            .headers()
            .get(reqwest::header::CONTENT_TYPE)
            .ok_or(InlineDataError::ContentTypeMissing)?
            .to_str()
            .map_err(InlineDataError::ContentTypeParseFailed)?;
        let mime_type = Mime::from_str(mime_type).map_err(InlineDataError::InvalidMimeType)?;
        let body = response
            .bytes()
            .await
            .map_err(InlineDataError::RequestFailed)?;
        Ok(InlineData::new(mime_type, STANDARD.encode(body)))
    }

    /// Downloads `url` without any header check.
    ///
    /// Equivalent to `from_url_with_check(url, |_| true)`.
    ///
    /// # Errors
    /// Same as `from_url_with_check`, except `CheckerFalse` is never returned.
    #[cfg(feature = "reqwest")]
    pub async fn from_url(url: &str) -> Result<Self, InlineDataError> {
        Self::from_url_with_check(url, |_| true).await
    }

    /// Reads the file at `file_path` and wraps it as base64 inline data with
    /// the given `mime_type`. The MIME type is not checked against the contents.
    ///
    /// # Errors
    /// Any I/O error returned while reading the file.
    #[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
    pub async fn from_path(file_path: &str, mime_type: Mime) -> Result<Self, std::io::Error> {
        let data = tokio::fs::read(file_path).await?;
        Ok(InlineData::new(mime_type, STANDARD.encode(data)))
    }
}
/// Programming language of an [`ExecutableCode`] snippet.
///
/// Serialized in SCREAMING_SNAKE_CASE.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Language {
    /// Serialized as `"LANGUAGE_UNSPECIFIED"`.
    LanguageUnspecified,
    /// Serialized as `"PYTHON"`.
    Python,
}

/// A code snippet together with its [`Language`].
#[derive(Debug, Clone, Serialize, Deserialize, Getters, new)]
pub struct ExecutableCode {
    #[get = "pub"]
    language: Language,
    #[get = "pub"]
    code: String,
}
/// A named function invocation with optional JSON arguments.
#[derive(Debug, Clone, Serialize, Deserialize, Getters, new)]
pub struct FunctionCall {
    #[get = "pub"]
    name: String,
    /// Arbitrary JSON arguments; the key is omitted from output when `None`.
    #[get = "pub"]
    #[serde(skip_serializing_if = "Option::is_none")]
    args: Option<Value>,
}

/// A named function's response, carried as arbitrary JSON.
#[derive(Debug, Clone, Serialize, Deserialize, Getters, new)]
pub struct FunctionResponse {
    #[get = "pub"]
    name: String,
    #[get = "pub"]
    response: Value,
}
/// Reference to a file by URI, with an optional MIME type.
///
/// Fields serialize under their snake_case names (`mime_type`, `file_uri`).
/// When deserializing, the camelCase spellings (`mimeType`, `fileUri`) are
/// accepted too.
#[derive(Debug, Clone, Serialize, Deserialize, Getters, new)]
pub struct FileData {
    #[get = "pub"]
    #[serde(alias = "mimeType", skip_serializing_if = "Option::is_none")]
    mime_type: Option<String>,
    #[get = "pub"]
    #[serde(alias = "fileUri")]
    file_uri: String,
}
/// Status of a code execution, serialized in SCREAMING_SNAKE_CASE.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Outcome {
    /// Serialized as `"OUTCOME_UNSPECIFIED"`.
    OutcomeUnspecified,
    /// Serialized as `"OUTCOME_OK"`.
    OutcomeOk,
    /// Serialized as `"OUTCOME_FAILED"`.
    OutcomeFailed,
    /// Serialized as `"OUTCOME_DEADLINE_EXCEEDED"`.
    OutcomeDeadlineExceeded,
}

/// An [`Outcome`] plus optional output text.
#[derive(Debug, Clone, Serialize, Deserialize, Getters, new)]
pub struct CodeExecutionResult {
    #[get = "pub"]
    outcome: Outcome,
    /// Output text; the key is omitted from output when `None`.
    #[get = "pub"]
    #[serde(skip_serializing_if = "Option::is_none")]
    output: Option<String>,
}
/// The payload carried by a [`Part`]: exactly one kind of content.
///
/// Externally tagged with camelCase keys, e.g. `{"text": "..."}` or
/// `{"inlineData": {...}}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum PartType {
    /// Key `text`.
    Text(String),
    /// Key `inlineData`.
    InlineData(InlineData),
    /// Key `executableCode`.
    ExecutableCode(ExecutableCode),
    /// Key `codeExecutionResult`.
    CodeExecutionResult(CodeExecutionResult),
    /// Key `functionCall`.
    FunctionCall(FunctionCall),
    /// Key `functionResponse`.
    FunctionResponse(FunctionResponse),
    /// Key `fileData`.
    FileData(FileData),
}
/// One element of a message's content: a payload plus optional thought metadata.
///
/// `data` is flattened, so the payload's tag (`text`, `functionCall`, ...)
/// appears in the same JSON object as `thought` / `thoughtSignature`.
/// `Debug` is implemented by hand so the signature is shortened.
#[derive(Clone, Serialize, Deserialize, Getters)]
#[serde(rename_all = "camelCase")]
pub struct Part {
    #[serde(flatten)]
    #[get = "pub"]
    data: PartType,
    /// Thought flag; `is_thought` treats only `Some(true)` as a thought.
    /// Omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[get = "pub"]
    thought: Option<bool>,
    /// Opaque signature string; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[get = "pub"]
    thought_signature: Option<String>,
}
impl fmt::Debug for Part {
    /// Formats the part and shortens `thought_signature` to its first three
    /// characters, followed by `..truncated`, to keep log output compact.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Cut on a char boundary. The previous `&s[..3]` panicked for
        // signatures shorter than 3 bytes (including "") and when byte 3 fell
        // inside a multi-byte character. Shorter strings are shown whole.
        let signature = self.thought_signature.as_deref().map(|s| {
            let prefix = s.char_indices().nth(3).map_or(s, |(end, _)| &s[..end]);
            format!("{}..truncated", prefix)
        });
        f.debug_struct("Part")
            .field("data", &self.data)
            .field("thought", &self.thought)
            .field("thought_signature", &signature)
            .finish()
    }
}
impl Part {
    /// Returns `true` only when the thought flag is explicitly `Some(true)`.
    pub fn is_thought(&self) -> bool {
        matches!(self.thought, Some(true))
    }

    /// Wraps `data` in a part with no thought flag and no thought signature.
    pub fn new(data: PartType) -> Self {
        Self {
            thought: None,
            thought_signature: None,
            data,
        }
    }

    /// Mutable access to the payload.
    pub fn data_mut(&mut self) -> &mut PartType {
        let payload = &mut self.data;
        payload
    }
}
impl From<PartType> for Part {
fn from(value: PartType) -> Self {
Self::new(value)
}
}
impl From<String> for Part {
fn from(value: String) -> Self {
Self::new(PartType::Text(value))
}
}
impl From<&str> for Part {
fn from(value: &str) -> Self {
Self::new(PartType::Text(value.into()))
}
}
impl From<InlineData> for Part {
fn from(value: InlineData) -> Self {
Self::new(PartType::InlineData(value))
}
}
impl From<ExecutableCode> for Part {
fn from(value: ExecutableCode) -> Self {
Self::new(PartType::ExecutableCode(value))
}
}
impl From<CodeExecutionResult> for Part {
fn from(value: CodeExecutionResult) -> Self {
Self::new(PartType::CodeExecutionResult(value))
}
}
impl From<FunctionCall> for Part {
fn from(value: FunctionCall) -> Self {
Self::new(PartType::FunctionCall(value))
}
}
impl From<FunctionResponse> for Part {
fn from(value: FunctionResponse) -> Self {
Self::new(PartType::FunctionResponse(value))
}
}
impl From<FileData> for Part {
fn from(value: FileData) -> Self {
Self::new(PartType::FileData(value))
}
}
/// Named thinking level, serialized in SCREAMING_SNAKE_CASE.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ThinkingLevel {
    /// The `Default` value; serialized as `"THINKING_LEVEL_UNSPECIFIED"`.
    #[default]
    ThinkingLevelUnspecified,
    Minimal,
    Low,
    Medium,
    High,
}

/// Thinking control: either a named [`ThinkingLevel`] or a numeric budget.
///
/// Externally tagged with camelCase keys (`thinkingLevel` / `thinkingBudget`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum ThinkingControl {
    ThinkingLevel(ThinkingLevel),
    /// Numeric budget. `ThinkingConfig`'s constructors use `0` to disable
    /// thinking and `-1` for dynamic thinking.
    ThinkingBudget(i32),
}
impl From<ThinkingLevel> for ThinkingControl {
    /// Wraps a named level.
    fn from(value: ThinkingLevel) -> Self {
        Self::ThinkingLevel(value)
    }
}

impl From<u32> for ThinkingControl {
    /// Converts a numeric budget, saturating at `i32::MAX`.
    ///
    /// A plain `as` cast would wrap values above `i32::MAX` to negative numbers.
    /// For example, `u32::MAX` would become `-1`, which is the dynamic-thinking
    /// sentinel, instead of a large fixed budget.
    fn from(value: u32) -> Self {
        // After clamping, the value fits in i32, so this cast is lossless.
        Self::ThinkingBudget(value.min(i32::MAX as u32) as i32)
    }
}
/// Thinking settings: whether to include thoughts, plus an optional control.
///
/// Serialized in camelCase. When present, `control` is flattened, so its
/// `thinkingLevel` / `thinkingBudget` key sits beside `includeThoughts`.
#[derive(Debug, Clone, Serialize, Deserialize, Getters)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
    #[get = "pub"]
    include_thoughts: bool,
    // No `#[get]`: the hand-written `control()` accessor returns
    // `Option<&ThinkingControl>` instead of `&Option<ThinkingControl>`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    control: Option<ThinkingControl>,
}
impl ThinkingConfig {
    /// Borrows the thinking control, if one is set.
    pub fn control(&self) -> Option<&ThinkingControl> {
        Option::as_ref(&self.control)
    }

    /// Builds a config with an explicit control. `control` may be a
    /// [`ThinkingLevel`], a `u32` budget, or a [`ThinkingControl`].
    pub fn new(include_thoughts: bool, control: impl Into<ThinkingControl>) -> Self {
        let control = Some(control.into());
        Self {
            include_thoughts,
            control,
        }
    }

    /// Builds a config that disables thinking: no thoughts, budget `0`.
    pub fn new_disable_thinking() -> Self {
        Self::new(false, ThinkingControl::ThinkingBudget(0))
    }

    /// Builds a config with dynamic thinking (budget `-1`).
    pub fn new_dynamic_thinking(include_thoughts: bool) -> Self {
        Self::new(include_thoughts, ThinkingControl::ThinkingBudget(-1))
    }
}
impl Default for ThinkingConfig {
    /// Includes thoughts and leaves the thinking control unset.
    fn default() -> Self {
        let include_thoughts = true;
        Self {
            control: None,
            include_thoughts,
        }
    }
}
#[derive(Serialize, Deserialize, Getters, new, Debug, Clone)]
pub struct SystemInstruction {
#[get = "pub"]
parts: Vec<Part>,
}
impl From<String> for SystemInstruction {
fn from(prompt: String) -> Self {
Self {
parts: vec![prompt.into()],
}
}
}
impl<'a> From<&'a str> for SystemInstruction {
fn from(prompt: &'a str) -> Self {
Self {
parts: vec![prompt.into()],
}
}
}
/// Harm category, serialized in SCREAMING_SNAKE_CASE
/// (e.g. `"HARM_CATEGORY_HARASSMENT"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmCategory {
    HarmCategoryHarassment,
    HarmCategoryHateSpeech,
    HarmCategorySexuallyExplicit,
    HarmCategoryDangerousContent,
}

/// Blocking threshold, serialized in SCREAMING_SNAKE_CASE
/// (e.g. `"BLOCK_ONLY_HIGH"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum BlockThreshold {
    BlockNone,
    BlockOnlyHigh,
    BlockMediumAndAbove,
    BlockLowAndAbove,
}

/// Pairs a [`HarmCategory`] with the [`BlockThreshold`] applied to it.
#[derive(Debug, Clone, Serialize, Deserialize, Getters, new)]
pub struct SafetySetting {
    #[get = "pub"]
    category: HarmCategory,
    #[get = "pub"]
    threshold: BlockThreshold,
}
/// Tool-usage configuration; serialized in camelCase with `None` fields omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_calling_config: Option<FunctionCallingConfig>,
}

/// Function-calling settings; serialized in camelCase with `None` fields omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallingConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mode: Option<FunctionCallingMode>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allowed_function_names: Option<Vec<String>>,
}

/// Function-calling mode, serialized as `"AUTO"`, `"ANY"` or `"NONE"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FunctionCallingMode {
    Auto,
    Any,
    None,
}
/// Borrowed request body serialized for the Gemini API.
///
/// Serialized in camelCase. Every optional field, including
/// `system_instruction`, is omitted from the JSON when `None`.
#[derive(Serialize, new)]
#[serde(rename_all = "camelCase")]
pub struct GeminiRequestBody<'a> {
    // Previously this was the only optional field without `skip_serializing_if`,
    // so it was emitted as `"systemInstruction": null`. It is now omitted like
    // the others.
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<&'a SystemInstruction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<&'a [Tool]>,
    contents: &'a [&'a Chat],
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<&'a Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    safety_settings: Option<&'a [SafetySetting]>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_config: Option<&'a ToolConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    cached_content: Option<String>,
}
/// A tool entry, externally tagged with camelCase keys.
///
/// For example, `{"googleSearch": <value>}` or `{"functionDeclarations": [...]}`.
/// Payloads are raw JSON values.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum Tool {
    /// Key `googleSearch`.
    GoogleSearch(Value),
    /// Key `functionDeclarations`.
    FunctionDeclarations(Vec<Value>),
    /// Key `codeExecution`.
    CodeExecution(Value),
    /// Key `urlContext`.
    UrlContext(Value),
}
/// Appends `updator`'s parts onto `updating`, merging consecutive text.
///
/// An incoming text part is appended to the last part of `updating` when:
/// - both are text,
/// - both have the same thought status (`is_thought`), and
/// - at most one of them carries a `thought_signature`.
///
/// The merged part keeps whichever signature exists, so a signature arriving
/// on a later chunk is not lost. Every other incoming part is cloned and
/// pushed as a new part.
pub fn concatenate_parts(updating: &mut Vec<Part>, updator: &[Part]) {
    for updator_part in updator {
        if let (Some(updating_last), PartType::Text(updator_text)) =
            (updating.last_mut(), &updator_part.data)
        {
            let same_thought = updating_last.is_thought() == updator_part.is_thought();
            // Merging two signed parts would have to discard one signature,
            // so keep them as separate parts instead.
            let signatures_compatible = updating_last.thought_signature.is_none()
                || updator_part.thought_signature.is_none();
            if same_thought && signatures_compatible {
                if let PartType::Text(updating_text) = &mut updating_last.data {
                    updating_text.push_str(updator_text);
                    if updating_last.thought_signature.is_none() {
                        updating_last.thought_signature = updator_part.thought_signature.clone();
                    }
                    continue;
                }
            }
        }
        updating.push(updator_part.clone());
    }
}