use crate::base::entity::node::Node;
use chrono::Utc;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::hash::{Hash, Hasher};
use thiserror::Error;
use uuid::Uuid;
/// A prompt entry: a thin wrapper around a [`Node`] whose payload is [`PromptData`].
///
/// Equality, ordering and (de)serialization are derived, so they delegate to the
/// wrapped node. `Hash` is implemented by hand elsewhere in this file.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct PromptItem {
/// The wrapped node. The methods on `PromptItem` read its `uuid`, `name` and
/// `modified` fields and reach the prompt payload through `node.node`.
pub node: Node<PromptData>,
}
impl Hash for PromptItem {
fn hash<H: Hasher>(&self, state: &mut H) {
self.node.hash(state);
}
}
/// Payload carried by a [`PromptItem`]'s node: the prompt itself plus its
/// settings and descriptive fields.
///
/// The constructors on `PromptItem` start every field except `prompt_type`
/// empty (`None`, an empty `Vec`, an empty map) and `parameters` at its `Default`.
///
/// Field order is significant: the derived `PartialOrd`/`Ord` compare fields in
/// declaration order.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct PromptData {
/// Which kind of prompt this is, together with that kind's content.
pub prompt_type: PromptType,
// Two fields are declared on the next line:
// `content_attachments` - UUIDs appended by `PromptItem::add_content_attachment`
//   (presumably IDs of content stored elsewhere; confirm with the caller);
// `parameters` - the `PromptParameters` for this prompt, replaced wholesale by
//   `PromptItem::set_parameters`.
pub content_attachments: Vec<Uuid>, pub parameters: PromptParameters,
/// Optional context string; replaced by `PromptItem::set_context`.
pub context: Option<String>,
/// Optional expected output. NOTE(review): nothing visible in this file reads
/// or writes it after construction; confirm the intended consumer.
pub expected_output: Option<String>,
/// Optional list of tags. NOTE(review): not used by any visible code.
pub tags: Option<Vec<String>>,
/// Optional category label. NOTE(review): not used by any visible code.
pub category: Option<String>,
/// Optional author. NOTE(review): not used by any visible code.
pub author: Option<String>,
/// String-to-string metadata. A `BTreeMap` (unlike `HashMap`) implements
/// `Ord` and `Hash`, which the derives on this struct require.
pub metadata: BTreeMap<String, String>,
}
/// The kind of prompt stored in a [`PromptData`], each variant wrapping the
/// struct that holds that kind's content.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` rank variants in
/// declaration order before comparing their payloads.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum PromptType {
/// Free-form text tagged with an explicit [`PromptRole`].
Text(TextPrompt),
/// A system prompt; see [`SystemPrompt`].
System(SystemPrompt),
/// A user prompt; see [`UserPrompt`].
User(UserPrompt),
/// An assistant prompt; see [`AssistantPrompt`].
Assistant(AssistantPrompt),
/// A function prompt; see [`FunctionPrompt`].
Function(FunctionPrompt),
}
/// Tunable parameters attached to a prompt; every field is optional.
///
/// Only `Debug`, `Clone` and the serde traits are derived. `f32` implements
/// neither `Eq`, `Ord` nor `Hash`, so the comparison and hashing traits are
/// written by hand elsewhere in this file.
///
/// NOTE(review): the field names match the usual text-generation sampling
/// options; their valid ranges are not enforced anywhere in the visible code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptParameters {
/// Optional token limit (`Default` uses `Some(1000)`).
pub max_tokens: Option<u32>,
/// Optional sampling temperature (`Default` uses `Some(0.7)`).
pub temperature: Option<f32>,
/// Optional top-p value (`Default` uses `Some(1.0)`).
pub top_p: Option<f32>,
/// Optional frequency penalty (`Default` uses `Some(0.0)`).
pub frequency_penalty: Option<f32>,
/// Optional presence penalty (`Default` uses `Some(0.0)`).
pub presence_penalty: Option<f32>,
/// Optional list of stop sequences (`Default` uses `None`).
pub stop_sequences: Option<Vec<String>>,
}
/// Field-by-field equality for [`PromptParameters`].
///
/// The `f32` fields are compared by bit pattern (`f32::to_bits`) instead of
/// with `==`. Float `==` is not an equivalence relation: a NaN is unequal to
/// itself, which would make the `Eq` impl below unsound, and `0.0 == -0.0`
/// even though the hand-written `Ord` and `Hash` impls in this file treat
/// values with different bit patterns as different. Bitwise comparison keeps
/// `a == b` in step with `a.cmp(&b) == Equal` and with equal hashes.
impl PartialEq for PromptParameters {
    /// Returns `true` when every field matches, floats being compared bit for bit.
    fn eq(&self, other: &Self) -> bool {
        // `None` only equals `None`; two `Some`s are equal when their raw
        // IEEE-754 bits are identical.
        fn same_bits(lhs: Option<f32>, rhs: Option<f32>) -> bool {
            lhs.map(f32::to_bits) == rhs.map(f32::to_bits)
        }

        self.max_tokens == other.max_tokens
            && same_bits(self.temperature, other.temperature)
            && same_bits(self.top_p, other.top_p)
            && same_bits(self.frequency_penalty, other.frequency_penalty)
            && same_bits(self.presence_penalty, other.presence_penalty)
            && self.stop_sequences == other.stop_sequences
    }
}

/// Sound because the bitwise `eq` above is reflexive even when a field holds NaN.
impl Eq for PromptParameters {}
/// Total order over [`PromptParameters`].
///
/// Fields are compared lexicographically in declaration order: `max_tokens`,
/// `temperature`, `top_p`, `frequency_penalty`, `presence_penalty`, then
/// `stop_sequences`. For every optional field `None` sorts before `Some`.
///
/// Floats are ordered with `f32::total_cmp`, the IEEE-754 total order. It is
/// numerically sensible (`-2.0 < -1.0 < -0.0 < 0.0 < 1.0`, NaNs at the
/// extremes) and reports `Equal` exactly when the two bit patterns are
/// identical. Comparing raw `to_bits()` values as unsigned integers would rank
/// every negative number above every positive one.
impl Ord for PromptParameters {
    /// Compares `self` with `other` field by field, stopping at the first
    /// field that differs.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        // Orders two optional floats: `None` first, then by IEEE-754 total order.
        fn cmp_float(lhs: Option<f32>, rhs: Option<f32>) -> Ordering {
            match (lhs, rhs) {
                (Some(a), Some(b)) => a.total_cmp(&b),
                (None, None) => Ordering::Equal,
                (None, Some(_)) => Ordering::Less,
                (Some(_), None) => Ordering::Greater,
            }
        }

        // `then_with` is lazy, so later fields are only inspected on a tie.
        self.max_tokens
            .cmp(&other.max_tokens)
            .then_with(|| cmp_float(self.temperature, other.temperature))
            .then_with(|| cmp_float(self.top_p, other.top_p))
            .then_with(|| cmp_float(self.frequency_penalty, other.frequency_penalty))
            .then_with(|| cmp_float(self.presence_penalty, other.presence_penalty))
            .then_with(|| self.stop_sequences.cmp(&other.stop_sequences))
    }
}
/// Partial ordering for [`PromptParameters`], defined entirely by its total order.
impl PartialOrd for PromptParameters {
    /// Always returns `Some`, wrapping whatever `Ord::cmp` reports for the pair,
    /// so the two traits can never disagree.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
/// Hashing for [`PromptParameters`].
///
/// `f32` does not implement `Hash`, so each optional float is hashed as an
/// `Option<u32>` holding its raw bits; values with identical bit patterns
/// therefore always hash identically.
///
/// Hashing the `Option` (not just the inner bits when present) feeds the
/// `Some`/`None` discriminant to the hasher. Without it a `None` contributes
/// nothing, so two values that differ only in *which* float field is set,
/// e.g. `temperature = Some(x), top_p = None` versus the reverse, would send
/// the same byte stream to the hasher and always collide.
impl Hash for PromptParameters {
    /// Feeds every field into `state` in declaration order.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.max_tokens.hash(state);
        self.temperature.map(f32::to_bits).hash(state);
        self.top_p.map(f32::to_bits).hash(state);
        self.frequency_penalty.map(f32::to_bits).hash(state);
        self.presence_penalty.map(f32::to_bits).hash(state);
        self.stop_sequences.hash(state);
    }
}
impl Default for PromptParameters {
fn default() -> Self {
Self {
max_tokens: Some(1000),
temperature: Some(0.7),
top_p: Some(1.0),
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
stop_sequences: None,
}
}
}
/// Content of a [`PromptType::Text`] prompt: text plus the role it is attributed to.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct TextPrompt {
/// The prompt text.
pub content: String,
/// The role this text is attributed to.
pub role: PromptRole,
}
/// Role attached to a [`TextPrompt`].
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` rank the
/// variants in declaration order (`System` lowest, `Function` highest).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum PromptRole {
System,
User,
Assistant,
Function,
}
/// Content of a [`PromptType::System`] prompt.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct SystemPrompt {
/// The instruction text.
pub instructions: String,
/// Optional list of constraint strings. NOTE(review): how they are combined
/// with `instructions` is not shown in this file.
pub constraints: Option<Vec<String>>,
}
/// Content of a [`PromptType::User`] prompt.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct UserPrompt {
/// The query text.
pub query: String,
/// Optional context accompanying the query. This is a separate field from
/// `PromptData::context`.
pub context: Option<String>,
}
/// Content of a [`PromptType::Assistant`] prompt.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct AssistantPrompt {
/// The response text.
pub response: String,
/// Optional reasoning text accompanying the response.
pub reasoning: Option<String>,
}
/// Content of a [`PromptType::Function`] prompt.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct FunctionPrompt {
/// Name of the function.
pub function_name: String,
/// Arguments as string key/value pairs. A `BTreeMap` (unlike `HashMap`)
/// implements `Ord` and `Hash`, which the derives on this struct require.
pub arguments: BTreeMap<String, String>,
/// Optional human-readable description.
pub description: Option<String>,
}
impl PromptItem {
    /// Creates an untitled prompt item for `prompt_type`.
    ///
    /// Every other payload field starts empty and the parameters start at
    /// `PromptParameters::default()`.
    ///
    /// # Errors
    ///
    /// No code path here produces an `Err`; the call always returns `Ok`.
    pub fn new(prompt_type: PromptType) -> Result<Self, PromptItemError> {
        Ok(Self::assemble(prompt_type, None))
    }

    /// Creates a prompt item for `prompt_type` whose node is named `title`.
    ///
    /// Apart from the title, the result is the same as [`PromptItem::new`].
    ///
    /// # Errors
    ///
    /// No code path here produces an `Err`; the call always returns `Ok`.
    pub fn new_with_title(prompt_type: PromptType, title: String) -> Result<Self, PromptItemError> {
        Ok(Self::assemble(prompt_type, Some(title)))
    }

    /// Builds the default payload around `prompt_type` and wraps it in a node
    /// carrying the optional `name`. Shared by both public constructors.
    fn assemble(prompt_type: PromptType, name: Option<String>) -> Self {
        let payload = PromptData {
            prompt_type,
            content_attachments: Vec::new(),
            parameters: PromptParameters::default(),
            context: None,
            expected_output: None,
            tags: None,
            category: None,
            author: None,
            metadata: BTreeMap::new(),
        };
        PromptItem {
            node: Node::new(payload, name),
        }
    }

    /// Returns the UUID of the underlying node.
    pub fn uuid(&self) -> Uuid {
        self.node.uuid
    }

    /// Returns the node's name, which serves as the prompt title, if one was set.
    pub fn title(&self) -> Option<&String> {
        let name = &self.node.name;
        name.as_ref()
    }

    /// Returns the prompt kind and its content.
    pub fn prompt_type(&self) -> &PromptType {
        let payload = &self.node.node;
        &payload.prompt_type
    }

    /// Returns the parameters currently stored on the prompt.
    pub fn parameters(&self) -> &PromptParameters {
        let payload = &self.node.node;
        &payload.parameters
    }

    /// Returns the attached content IDs in insertion order.
    pub fn content_attachments(&self) -> &[Uuid] {
        self.node.node.content_attachments.as_slice()
    }

    /// Appends `content_id` to the attachment list and refreshes the node's
    /// `modified` timestamp. Duplicates are not filtered out.
    pub fn add_content_attachment(&mut self, content_id: Uuid) {
        let attachments = &mut self.node.node.content_attachments;
        attachments.push(content_id);
        self.touch();
    }

    /// Replaces the stored parameters and refreshes the node's `modified` timestamp.
    pub fn set_parameters(&mut self, parameters: PromptParameters) {
        self.node.node.parameters = parameters;
        self.touch();
    }

    /// Replaces the stored context (`None` clears it) and refreshes the node's
    /// `modified` timestamp.
    pub fn set_context(&mut self, context: Option<String>) {
        self.node.node.context = context;
        self.touch();
    }

    /// Stamps the node's `modified` field with the current UTC time.
    fn touch(&mut self) {
        self.node.modified = Utc::now();
    }
}
/// Errors named by the `PromptItem` constructors' return types.
///
/// `Display` text comes from the `#[error(...)]` attributes; each variant
/// carries a `String` that is interpolated into its message.
///
/// NOTE(review): no code visible in this file constructs any of these variants.
#[derive(Debug, Clone, Error)]
pub enum PromptItemError {
/// The prompt configuration is invalid; the payload is the detail text.
#[error("Invalid prompt configuration: {0}")]
InvalidConfiguration(String),
/// A required field is missing; the payload is interpolated after the prefix.
#[error("Missing required field: {0}")]
MissingField(String),
/// A parameter value is invalid; the payload is the detail text.
#[error("Invalid parameter value: {0}")]
InvalidParameter(String),
}
// Unit tests for the prompt types defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// Two parameter sets built from identical field values must compare equal.
#[test]
fn test_prompt_parameters_equality() {
let params1 = PromptParameters {
max_tokens: Some(100),
temperature: Some(0.5),
top_p: Some(0.9),
frequency_penalty: Some(0.1),
presence_penalty: Some(0.2),
stop_sequences: Some(vec!["END".to_string()]),
};
let params2 = PromptParameters {
max_tokens: Some(100),
temperature: Some(0.5),
top_p: Some(0.9),
frequency_penalty: Some(0.1),
presence_penalty: Some(0.2),
stop_sequences: Some(vec!["END".to_string()]),
};
assert_eq!(params1, params2);
}
// `max_tokens` is the first ordering key: with all other fields at their
// defaults, the smaller limit sorts first.
#[test]
fn test_prompt_parameters_ordering() {
let params1 = PromptParameters {
max_tokens: Some(100),
..Default::default()
};
let params2 = PromptParameters {
max_tokens: Some(200),
..Default::default()
};
assert!(params1 < params2);
}
// A titled text prompt keeps its title and its `TextPrompt` payload intact.
#[test]
fn test_text_prompt_creation() {
let text_prompt = TextPrompt {
content: "Hello, world!".to_string(),
role: PromptRole::User,
};
let prompt_item =
PromptItem::new_with_title(PromptType::Text(text_prompt), "Test Prompt".to_string());
assert!(prompt_item.is_ok());
let item = prompt_item.unwrap();
assert_eq!(item.title(), Some(&"Test Prompt".to_string()));
match item.prompt_type() {
PromptType::Text(text) => {
assert_eq!(text.content, "Hello, world!");
assert_eq!(text.role, PromptRole::User);
}
_ => panic!("Expected text prompt"),
}
}
// Exercises the three mutators: attachment push, parameter replacement and
// context replacement. The `modified` timestamp is not asserted here.
#[test]
fn test_prompt_item_modifications() {
let text_prompt = TextPrompt {
content: "Test content".to_string(),
role: PromptRole::System,
};
let mut prompt_item = PromptItem::new_with_title(
PromptType::Text(text_prompt),
"Modifiable Prompt".to_string(),
)
.unwrap();
let content_id = Uuid::new_v4();
prompt_item.add_content_attachment(content_id);
assert_eq!(prompt_item.content_attachments().len(), 1);
assert_eq!(prompt_item.content_attachments()[0], content_id);
let new_params = PromptParameters {
max_tokens: Some(500),
temperature: Some(0.8),
..Default::default()
};
prompt_item.set_parameters(new_params.clone());
assert_eq!(prompt_item.parameters(), &new_params);
prompt_item.set_context(Some("Test context".to_string()));
assert_eq!(
prompt_item.node.node.context,
Some("Test context".to_string())
);
}
}