use crate::features::openai_common::{Prompt, ResponseFormatJsonSchema};
/// Model identifier used when callers do not specify a Gemini model
/// (see [`GeminiConfigs::new_default`]).
pub const DEFAULT_GEMINI_MODEL: &str = "gemini-flash-latest";
/// Configuration for a Gemini request: the prompt(s), model selection,
/// generation parameters, and optional per-URL overrides and caching.
#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
// `PartialEq` is only derived when none of these features are enabled —
// presumably because they pull in members (e.g. the moka cache type) that
// do not implement `PartialEq`. NOTE(review): the `regex`/`openai` gates
// likely mirror sibling config types; confirm against the crate's features.
#[cfg_attr(
    all(
        not(feature = "regex"),
        not(feature = "openai"),
        not(feature = "cache_openai"),
        not(feature = "gemini"),
        not(feature = "cache_gemini")
    ),
    derive(PartialEq)
)]
pub struct GeminiConfigs {
    /// The prompt (single or multi) sent to the model.
    pub prompt: Prompt,
    /// Model name, e.g. [`DEFAULT_GEMINI_MODEL`].
    #[cfg_attr(feature = "serde", serde(default))]
    pub model: String,
    /// Maximum tokens to generate for the response.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_tokens: u16,
    /// Sampling temperature; `None` leaves the API default.
    #[cfg_attr(feature = "serde", serde(default))]
    pub temperature: Option<f32>,
    /// Nucleus-sampling parameter; `None` leaves the API default.
    #[cfg_attr(feature = "serde", serde(default))]
    pub top_p: Option<f32>,
    /// Top-k sampling parameter; `None` leaves the API default.
    #[cfg_attr(feature = "serde", serde(default))]
    pub top_k: Option<i32>,
    /// Optional per-URL configuration overrides, keyed case-insensitively.
    /// Boxed to keep the common (None) case small.
    pub prompt_url_map:
        Option<Box<hashbrown::HashMap<case_insensitive_string::CaseInsensitiveString, Box<Self>>>>,
    /// Request extra AI metadata alongside the response.
    #[cfg_attr(feature = "serde", serde(default))]
    pub extra_ai_data: bool,
    /// Enable path-based mapping of prompts.
    #[cfg_attr(feature = "serde", serde(default))]
    pub paths_map: bool,
    /// Attach a screenshot to the request context.
    #[cfg_attr(feature = "serde", serde(default))]
    pub screenshot: bool,
    /// API key override; `None` presumably falls back to an env var — confirm at call sites.
    #[cfg_attr(feature = "serde", serde(default))]
    pub api_key: Option<String>,
    /// In-process response cache; never serialized (runtime-only handle).
    #[cfg_attr(
        feature = "serde",
        serde(default),
        serde(skip_serializing, skip_deserializing)
    )]
    pub cache: Option<GeminiCache>,
    /// Optional JSON schema constraining the response format.
    #[cfg_attr(feature = "serde", serde(default))]
    pub json_schema: Option<ResponseFormatJsonSchema>,
}
/// Token accounting for a single Gemini call.
#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GeminiUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens produced in the completion.
    pub completion_tokens: u32,
    /// Prompt + completion tokens.
    pub total_tokens: u32,
    /// Whether this result was served from the cache.
    pub cached: bool,
}
/// Result of a Gemini call: the raw response text, usage stats,
/// and an error message when the call did not fully succeed.
#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GeminiReturn {
    /// The model's response body.
    pub response: String,
    /// Token usage for this call.
    pub usage: GeminiUsage,
    /// Error description, if any part of the call failed.
    pub error: Option<String>,
}
/// Async response cache keyed by `u64` — presumably a hash of the
/// request; confirm at the call sites that populate it.
#[cfg(feature = "cache_gemini")]
pub type GeminiCache = moka::future::Cache<u64, GeminiReturn>;
/// Placeholder alias so `GeminiConfigs::cache` keeps the same shape
/// when the `cache_gemini` feature is disabled.
#[cfg(not(feature = "cache_gemini"))]
pub type GeminiCache = String;
impl GeminiConfigs {
pub fn new(model: &str, prompt: &str, max_tokens: u16) -> GeminiConfigs {
Self {
model: model.into(),
prompt: Prompt::Single(prompt.into()),
max_tokens,
..Default::default()
}
}
pub fn new_default(prompt: &str, max_tokens: u16) -> GeminiConfigs {
Self::new(DEFAULT_GEMINI_MODEL, prompt, max_tokens)
}
pub fn new_cache(
model: &str,
prompt: &str,
max_tokens: u16,
cache: Option<GeminiCache>,
) -> GeminiConfigs {
Self {
model: model.into(),
prompt: Prompt::Single(prompt.into()),
max_tokens,
cache,
..Default::default()
}
}
pub fn new_multi<I, S>(model: &str, prompt: I, max_tokens: u16) -> GeminiConfigs
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
Self {
model: model.into(),
prompt: Prompt::Multi(prompt.into_iter().map(|s| s.as_ref().to_string()).collect()),
max_tokens,
..Default::default()
}
}
pub fn new_multi_cache<I, S>(
model: &str,
prompt: I,
max_tokens: u16,
cache: Option<GeminiCache>,
) -> GeminiConfigs
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
Self {
model: model.into(),
prompt: Prompt::Multi(prompt.into_iter().map(|s| s.as_ref().to_string()).collect()),
max_tokens,
cache,
..Default::default()
}
}
pub fn set_extra(&mut self, extra_ai_data: bool) -> &mut Self {
self.extra_ai_data = extra_ai_data;
self
}
pub fn set_top_k(&mut self, top_k: Option<i32>) -> &mut Self {
self.top_k = top_k;
self
}
}
/// Round-trip check: a typical JSON payload deserializes into
/// `GeminiConfigs` (relies on the `serde` derives being enabled).
#[test]
#[cfg(feature = "gemini")]
fn deserialize_gemini_configs() {
    let gemini_configs_json = r#"{"prompt":"change background blue","model":"gemini-flash-latest","max_tokens":256,"temperature":0.54,"top_p":0.17}"#;
    // Keep the `Result` instead of discarding the error with `.ok()`,
    // so a failing deserialization reports *why* it failed.
    let configs = serde_json::from_str::<GeminiConfigs>(gemini_configs_json);
    assert!(
        configs.is_ok(),
        "failed to deserialize GeminiConfigs: {:?}",
        configs.err()
    );
}