/// The prompt(s) to send to the AI model: either one string or a FIFO queue
/// of strings consumed front-to-back via [`Prompt::next`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub enum Prompt {
    /// A single prompt string.
    Single(String),
    /// Multiple prompts, popped from the front one at a time.
    Multi(std::collections::VecDeque<String>),
}
impl std::fmt::Display for Prompt {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Prompt::Single(s) => write!(f, "{}", s),
Prompt::Multi(m) => write!(f, "{:?}", m),
}
}
}
#[cfg(feature = "openai")]
impl std::str::FromStr for Prompt {
    type Err = serde_json::Error;
    /// Parse a prompt from a string. A valid JSON array of strings becomes
    /// [`Prompt::Multi`]; anything else falls back to [`Prompt::Single`]
    /// holding the raw input, so this never actually returns `Err`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match serde_json::from_str::<Vec<String>>(s) {
            Ok(list) => Ok(Prompt::Multi(std::collections::VecDeque::from(list))),
            Err(_) => Ok(Prompt::Single(s.to_string())),
        }
    }
}
impl Prompt {
    /// Build a prompt holding one string.
    pub fn new_single(prompt: &str) -> Self {
        Prompt::Single(prompt.into())
    }
    /// Build a prompt holding a queue of strings, consumed front-to-back.
    pub fn new_multiple(prompt: std::collections::VecDeque<String>) -> Self {
        Prompt::Multi(prompt)
    }
    /// Pop the next prompt, consuming it.
    ///
    /// For `Single`, the string is moved out on the first call (the variant
    /// is left holding an empty string) and `None` is returned thereafter.
    /// For `Multi`, prompts are popped from the front until the queue is empty.
    pub fn next(&mut self) -> Option<String> {
        match self {
            Prompt::Single(prompt) => {
                if prompt.is_empty() {
                    None
                } else {
                    // O(1) move of the string; the previous
                    // `prompt.drain(..).collect()` allocated and copied a
                    // fresh String just to empty this one.
                    Some(std::mem::take(prompt))
                }
            }
            Prompt::Multi(prompts) => prompts.pop_front(),
        }
    }
}
impl Default for Prompt {
fn default() -> Self {
Prompt::Single(Default::default())
}
}
/// Configuration for driving a GPT model: the prompt(s) to send, model
/// selection, sampling parameters, and optional per-URL overrides/caching.
#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GPTConfigs {
    /// The prompt(s) to run (single string or queue of strings).
    pub prompt: Prompt,
    /// The model identifier, e.g. "gpt-3.5-turbo-16k".
    pub model: String,
    /// Maximum tokens to request for the completion.
    pub max_tokens: u16,
    /// Sampling temperature, if set.
    pub temperature: Option<f32>,
    /// End-user identifier forwarded to the API, if set.
    pub user: Option<String>,
    /// Nucleus-sampling probability mass, if set.
    pub top_p: Option<f32>,
    /// Per-URL configuration overrides, keyed case-insensitively.
    /// NOTE(review): presumably keys are page URLs matched against crawled
    /// pages — confirm against the caller that consumes this map.
    pub prompt_url_map:
        Option<hashbrown::HashMap<case_insensitive_string::CaseInsensitiveString, Self>>,
    /// Whether to request extra AI data alongside the main result.
    #[cfg_attr(feature = "serde", serde(default))]
    pub extra_ai_data: bool,
    /// Whether to map results by path — TODO confirm exact semantics at the use site.
    #[cfg_attr(feature = "serde", serde(default))]
    pub paths_map: bool,
    /// Whether to take a screenshot as part of the AI step.
    #[cfg_attr(feature = "serde", serde(default))]
    pub screenshot: bool,
    /// API key override; when `None` the caller's ambient credentials are
    /// presumably used — verify against the request-building code.
    #[cfg_attr(feature = "serde", serde(default))]
    pub api_key: Option<String>,
    /// Response cache; never serialized (runtime-only state).
    #[cfg_attr(
        feature = "serde",
        serde(default),
        serde(skip_serializing, skip_deserializing)
    )]
    pub cache: Option<AICache>,
}
/// Token accounting for one OpenAI request/response round trip.
#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct OpenAIUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens produced in the completion.
    pub completion_tokens: u32,
    /// Total tokens (prompt + completion).
    pub total_tokens: u32,
    /// Whether this result was served from the local cache.
    pub cached: bool,
}
/// A completion result: the response text paired with its token usage.
pub type OpenAIReturn = (String, OpenAIUsage);
/// Async response cache keyed by a `u64` hash of the request.
#[cfg(feature = "cache_openai")]
pub type AICache = moka::future::Cache<u64, OpenAIReturn>;
/// Placeholder alias so `GPTConfigs::cache` still compiles when the
/// `cache_openai` feature is disabled.
#[cfg(not(feature = "cache_openai"))]
pub type AICache = String;
impl GPTConfigs {
    /// Create a config with a single prompt and no cache.
    pub fn new(model: &str, prompt: &str, max_tokens: u16) -> GPTConfigs {
        // Delegate to the cache-aware constructor to avoid duplicating the
        // field list (previously a copy-paste of `new_cache` minus `cache`).
        Self::new_cache(model, prompt, max_tokens, None)
    }
    /// Create a config with a single prompt and an optional response cache.
    pub fn new_cache(
        model: &str,
        prompt: &str,
        max_tokens: u16,
        cache: Option<AICache>,
    ) -> GPTConfigs {
        Self {
            model: model.into(),
            prompt: Prompt::Single(prompt.into()),
            max_tokens,
            cache,
            ..Default::default()
        }
    }
    /// Create a config with multiple prompts (consumed in order) and no cache.
    pub fn new_multi<I, S>(model: &str, prompt: I, max_tokens: u16) -> GPTConfigs
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        // Same delegation as `new` → `new_cache`.
        Self::new_multi_cache(model, prompt, max_tokens, None)
    }
    /// Create a config with multiple prompts and an optional response cache.
    pub fn new_multi_cache<I, S>(
        model: &str,
        prompt: I,
        max_tokens: u16,
        cache: Option<AICache>,
    ) -> GPTConfigs
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        Self {
            model: model.into(),
            prompt: Prompt::Multi(prompt.into_iter().map(|s| s.as_ref().to_string()).collect()),
            max_tokens,
            cache,
            ..Default::default()
        }
    }
    /// Set whether extra AI data should be requested; returns `&mut Self`
    /// for chaining.
    pub fn set_extra(&mut self, extra_ai_data: bool) -> &mut Self {
        self.extra_ai_data = extra_ai_data;
        self
    }
}
#[cfg(feature = "serde")]
mod prompt_deserializer {
    //! Custom `Deserialize` for [`Prompt`]: accepts either a bare JSON
    //! string (`Single`) or an array of strings (`Multi`).
    use super::Prompt;
    use serde::{
        de::{self, SeqAccess, Visitor},
        Deserialize, Deserializer,
    };
    use std::collections::VecDeque;
    use std::fmt;

    struct PromptVisitor;

    impl<'de> Visitor<'de> for PromptVisitor {
        type Value = Prompt;

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("a string or an array of strings")
        }

        // A lone string deserializes to the `Single` variant.
        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            Ok(Prompt::Single(value.to_owned()))
        }

        // A sequence of strings deserializes to the `Multi` variant.
        fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
        where
            S: SeqAccess<'de>,
        {
            let mut queue = VecDeque::new();
            while let Some(item) = seq.next_element::<String>()? {
                queue.push_back(item);
            }
            Ok(Prompt::Multi(queue))
        }
    }

    impl<'de> Deserialize<'de> for Prompt {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            // `deserialize_any` lets the input's own shape (string vs. seq)
            // pick the visitor method.
            deserializer.deserialize_any(PromptVisitor)
        }
    }
}
#[test]
#[cfg(feature = "openai")]
fn deserialize_gpt_configs() {
    // A representative config payload with a single string prompt.
    let gpt_configs_json = "{\"prompt\":\"change background blue\",\"model\":\"gpt-3.5-turbo-16k\",\"max_tokens\":256,\"temperature\":0.54,\"top_p\":0.17}";
    let parsed = serde_json::from_str::<GPTConfigs>(gpt_configs_json);
    assert!(parsed.is_ok())
}