use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use serde_json::Value;
use log::warn;
/// A single message produced by rendering a prompt.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct PromptMessage {
    /// Speaker role for this message (free-form; not validated here).
    pub role: String,
    /// Message text.
    pub content: String,
    /// Optional participant name; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
/// Serializable summary of a registered prompt, as produced by
/// [`PromptManager::list_prompts`] and by serializing a [`FunctionPrompt`].
///
/// Every optional field is omitted from serialized output when `None`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Prompt {
    /// Prompt name (the key under which the prompt is registered).
    pub name: String,
    /// Human-readable description, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Tags attached to the prompt, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tags: Option<Vec<String>>,
    /// Free-form annotations; the `FunctionPrompt` conversions in this module
    /// always leave this `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub annotations: Option<HashMap<String, Value>>,
    /// Free-form metadata copied from a non-empty JSON object.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub meta: Option<HashMap<String, Value>>,
}
/// Callable that renders a prompt from optional named arguments.
///
/// Public because it appears in public signatures (`FunctionPrompt::function`,
/// `PromptManager::get_prompt_function`); wrapped in `Arc` so cloning a
/// prompt or handing out its callable is a reference-count bump. Errors are
/// reported as plain `String` messages.
pub type PromptFunction = Arc<
    dyn Fn(Option<HashMap<String, Value>>) -> Result<Vec<PromptMessage>, String> + Send + Sync,
>;
/// A prompt backed by a Rust closure, together with the metadata advertised for it.
#[derive(Clone)]
pub struct FunctionPrompt {
    /// Rendering callable; when `None`, `FunctionPrompt::get` returns an error.
    pub function: Option<PromptFunction>,
    /// Prompt name; `PromptManager` uses it as the registry key.
    pub name: String,
    /// Description; an empty string is reported as "no description" when listed.
    pub description: String,
    /// Tags; an empty list is reported as "no tags" when listed.
    pub tags: Vec<String>,
    /// Declared arguments (presumably name -> description; not interpreted in this file).
    pub arguments: Option<HashMap<String, String>>,
    /// Arbitrary metadata; only a non-empty JSON object is carried into listings.
    pub meta: Option<Value>,
}
impl FunctionPrompt {
    /// Builds a `FunctionPrompt` from a closure and its metadata.
    ///
    /// A `None` `description` becomes an empty string and a `None` `tags`
    /// becomes an empty list; both are reported as absent when the prompt is
    /// listed or serialized.
    pub fn from_function<F>(
        function: F,
        name: String,
        description: Option<String>,
        tags: Option<Vec<String>>,
        arguments: Option<HashMap<String, String>>,
        meta: Option<Value>,
    ) -> Self
    where
        F: Fn(Option<HashMap<String, Value>>) -> Result<Vec<PromptMessage>, String>
            + Send
            + Sync
            + 'static,
    {
        Self {
            function: Some(Arc::new(function)),
            name,
            description: description.unwrap_or_default(),
            tags: tags.unwrap_or_default(),
            arguments,
            meta,
        }
    }

    /// Renders the prompt by invoking its callable with `arguments`.
    ///
    /// # Errors
    ///
    /// Returns `Err("Prompt function not available")` when `function` is
    /// `None`; otherwise returns whatever the callable returns.
    pub fn get(
        &self,
        arguments: Option<HashMap<String, Value>>,
    ) -> Result<Vec<PromptMessage>, String> {
        let func = self
            .function
            .as_ref()
            .ok_or_else(|| "Prompt function not available".to_string())?;
        func(arguments)
    }
}
impl Serialize for FunctionPrompt {
    /// Serializes the prompt's public metadata in the shape of [`Prompt`].
    ///
    /// The callable and `arguments` are not serialized. An empty
    /// `description` or `tags` is omitted, `annotations` is always omitted,
    /// and `meta` is included only when it is a non-empty JSON object.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let meta = match &self.meta {
            // Copy entries straight out of the borrowed map; cloning the whole
            // `serde_json::Map` first would allocate a throwaway intermediate copy.
            Some(Value::Object(obj)) if !obj.is_empty() => Some(
                obj.iter()
                    .map(|(key, value)| (key.clone(), value.clone()))
                    .collect::<HashMap<String, Value>>(),
            ),
            _ => None,
        };
        let prompt = Prompt {
            name: self.name.clone(),
            description: (!self.description.is_empty()).then(|| self.description.clone()),
            tags: (!self.tags.is_empty()).then(|| self.tags.clone()),
            annotations: None,
            meta,
        };
        prompt.serialize(serializer)
    }
}
impl std::fmt::Debug for FunctionPrompt {
    /// Formats every field except the callable, which has no `Debug` impl.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructure: adding a field to `FunctionPrompt` fails to
        // compile here until it is either printed or explicitly skipped.
        let Self {
            function: _,
            name,
            description,
            tags,
            arguments,
            meta,
        } = self;
        let mut out = f.debug_struct("FunctionPrompt");
        out.field("name", name);
        out.field("description", description);
        out.field("tags", tags);
        out.field("arguments", arguments);
        out.field("meta", meta);
        out.finish()
    }
}
/// What `PromptManager::add_prompt` does when a prompt with the same name is
/// already registered.
///
/// A small fieldless policy value, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DuplicateBehavior {
    /// Log a warning, then replace the existing prompt.
    Warn,
    /// Panic.
    Error,
    /// Silently replace the existing prompt.
    Replace,
    /// Keep the existing prompt and discard the new one.
    Ignore,
}
/// Registry of [`FunctionPrompt`]s keyed by prompt name.
#[derive(Debug, Clone)]
pub struct PromptManager {
    /// Registered prompts, keyed by each prompt's `name`.
    prompts: HashMap<String, FunctionPrompt>,
    /// Policy applied when `add_prompt` receives an already-registered name.
    duplicate_behavior: DuplicateBehavior,
}
impl PromptManager {
    /// Creates an empty manager that logs a warning and replaces on duplicate names.
    pub fn new() -> Self {
        Self::with_behavior(DuplicateBehavior::Warn)
    }

    /// Creates an empty manager using `duplicate_behavior` for repeated names.
    pub fn with_behavior(duplicate_behavior: DuplicateBehavior) -> Self {
        let prompts = HashMap::new();
        Self {
            prompts,
            duplicate_behavior,
        }
    }
}
impl Default for PromptManager {
fn default() -> Self {
Self::new()
}
}
impl PromptManager {
    /// Registers `prompt` under its `name`, applying the manager's
    /// [`DuplicateBehavior`] when that name is already taken.
    ///
    /// # Panics
    ///
    /// Panics if the name is already registered and the behavior is
    /// [`DuplicateBehavior::Error`].
    pub fn add_prompt(&mut self, prompt: FunctionPrompt) {
        use std::collections::hash_map::Entry;

        // A single hash lookup via the entry API instead of `contains_key` + `insert`.
        match self.prompts.entry(prompt.name.clone()) {
            Entry::Vacant(slot) => {
                slot.insert(prompt);
            }
            Entry::Occupied(mut slot) => match self.duplicate_behavior {
                DuplicateBehavior::Warn => {
                    warn!("Prompt '{}' already exists, replacing", prompt.name);
                    slot.insert(prompt);
                }
                DuplicateBehavior::Error => {
                    panic!("Prompt '{}' already exists", prompt.name);
                }
                DuplicateBehavior::Replace => {
                    slot.insert(prompt);
                }
                DuplicateBehavior::Ignore => {}
            },
        }
    }

    /// Returns a [`Prompt`] summary for every registered prompt, sorted by name.
    ///
    /// Sorting makes the result deterministic; `HashMap` iteration order is not.
    /// An empty `description` or `tags` becomes `None`, `meta` is kept only when
    /// it is a non-empty JSON object, and `annotations` is always `None`.
    pub fn list_prompts(&self) -> Vec<Prompt> {
        let mut listed: Vec<Prompt> = self
            .prompts
            .values()
            .map(|p| Prompt {
                name: p.name.clone(),
                description: (!p.description.is_empty()).then(|| p.description.clone()),
                tags: (!p.tags.is_empty()).then(|| p.tags.clone()),
                annotations: None,
                meta: match &p.meta {
                    // Copy entries directly rather than cloning the whole Map first.
                    Some(Value::Object(obj)) if !obj.is_empty() => Some(
                        obj.iter()
                            .map(|(key, value)| (key.clone(), value.clone()))
                            .collect(),
                    ),
                    _ => None,
                },
            })
            .collect();
        listed.sort_unstable_by(|a, b| a.name.cmp(&b.name));
        listed
    }

    /// Returns a shared handle to the named prompt's callable, or `None` if the
    /// prompt is not registered or has no callable.
    pub fn get_prompt_function(&self, name: &str) -> Option<PromptFunction> {
        self.prompts
            .get(name)
            .and_then(|prompt| prompt.function.clone())
    }

    /// Renders the prompt registered under `name` with `arguments`.
    ///
    /// # Errors
    ///
    /// Returns `Err("Prompt not found: <name>")` if no prompt is registered
    /// under `name`; otherwise returns the result of [`FunctionPrompt::get`].
    pub fn get_prompt(
        &self,
        name: &str,
        arguments: Option<HashMap<String, Value>>,
    ) -> Result<Vec<PromptMessage>, String> {
        self.prompts
            .get(name)
            .ok_or_else(|| format!("Prompt not found: {}", name))?
            .get(arguments)
    }
}