use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::error::TraceError;
/// The template content of a [`Prompt`].
///
/// Serialized as an internally tagged object whose `kind` field is either
/// `"text"` or `"chat"`, e.g. `{"kind":"text","prompt":"..."}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum PromptBody {
    /// A single plain-text template.
    Text { prompt: String },
    /// An ordered list of chat message templates.
    Chat { messages: Vec<ChatMessageTemplate> },
}

/// One entry of a [`PromptBody::Chat`] template.
///
/// Serialized as an internally tagged object whose `type` field is either
/// `"chatmessage"` or `"placeholder"`. The tag is required: an object without
/// a `type` field fails to deserialize.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
// `rename_all` only applies to variants without an explicit `rename`; both
// current variants override it to match their wire names.
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ChatMessageTemplate {
    /// A concrete message with a role (e.g. `"system"`) and template content.
    #[serde(rename = "chatmessage")]
    Message { role: String, content: String },
    /// A named slot in the message list.
    // NOTE(review): presumably expanded into caller-supplied messages when the
    // prompt is rendered — confirm against the renderer.
    #[serde(rename = "placeholder")]
    Placeholder { name: String },
}

/// A versioned prompt, as returned by a [`PromptStore`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Prompt {
    /// Prompt name; the key that [`PromptStore`] lookups take.
    pub name: String,
    /// Version number of this prompt.
    pub version: u32,
    /// The template content.
    pub body: PromptBody,
    /// Free-form JSON configuration; `null` when the field is absent from the input.
    #[serde(default)]
    pub config: serde_json::Value,
    /// Labels attached to this version; empty when the field is absent from the input.
    #[serde(default)]
    pub labels: Vec<String>,
}
/// Asynchronous source of [`Prompt`]s, looked up by name.
///
/// Declared through `#[async_trait]` so implementations can be used behind
/// `dyn PromptStore`. The `Send + Sync` supertraits let a single store be shared
/// across threads. Every lookup reports failure as a [`TraceError`].
#[async_trait]
pub trait PromptStore: Send + Sync {
    /// Fetches the prompt called `name` without pinning a version or label.
    // NOTE(review): which version this resolves to (latest? a default label?)
    // is left to the implementation — confirm and document the contract.
    async fn get(&self, name: &str) -> Result<Prompt, TraceError>;

    /// Fetches the prompt called `name` that is identified by `label`.
    async fn get_label(&self, name: &str, label: &str) -> Result<Prompt, TraceError>;

    /// Fetches exactly `version` of the prompt called `name`.
    async fn get_version(&self, name: &str, version: u32) -> Result<Prompt, TraceError>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Chat body used by several tests: one concrete message plus one placeholder.
    fn sample_chat() -> PromptBody {
        PromptBody::Chat {
            messages: vec![
                ChatMessageTemplate::Message {
                    role: "system".into(),
                    content: "you are helpful".into(),
                },
                ChatMessageTemplate::Placeholder {
                    name: "history".into(),
                },
            ],
        }
    }

    #[test]
    fn prompt_body_text_round_trips() {
        let p = PromptBody::Text {
            prompt: "hi {name}".into(),
        };
        let s = serde_json::to_string(&p).unwrap();
        let p2: PromptBody = serde_json::from_str(&s).unwrap();
        // Exhaustive match: adding a variant should force this test to be revisited.
        match p2 {
            PromptBody::Text { prompt } => assert_eq!(prompt, "hi {name}"),
            PromptBody::Chat { .. } => panic!("wrong variant"),
        }
    }

    #[test]
    fn prompt_body_text_wire_format() {
        let v = serde_json::to_value(PromptBody::Text {
            prompt: "hi".into(),
        })
        .unwrap();
        assert_eq!(v, serde_json::json!({"kind": "text", "prompt": "hi"}));
    }

    #[test]
    fn prompt_body_chat_with_placeholder_round_trips() {
        let s = serde_json::to_string(&sample_chat()).unwrap();
        let p2: PromptBody = serde_json::from_str(&s).unwrap();

        let messages = match p2 {
            PromptBody::Chat { messages } => messages,
            PromptBody::Text { .. } => panic!("expected Chat variant"),
        };
        assert_eq!(messages.len(), 2);

        match &messages[0] {
            ChatMessageTemplate::Message { role, content } => {
                assert_eq!(role, "system");
                assert_eq!(content, "you are helpful");
            }
            ChatMessageTemplate::Placeholder { .. } => panic!("expected Message at index 0"),
        }
        match &messages[1] {
            ChatMessageTemplate::Placeholder { name } => assert_eq!(name, "history"),
            ChatMessageTemplate::Message { .. } => panic!("expected Placeholder at index 1"),
        }
    }

    #[test]
    fn prompt_body_chat_wire_format() {
        // Pins the tag field names and values that other systems read.
        let v = serde_json::to_value(sample_chat()).unwrap();
        assert_eq!(
            v,
            serde_json::json!({
                "kind": "chat",
                "messages": [
                    {"type": "chatmessage", "role": "system", "content": "you are helpful"},
                    {"type": "placeholder", "name": "history"}
                ]
            })
        );
    }

    #[test]
    fn prompt_defaults_missing_config_and_labels() {
        let p: Prompt = serde_json::from_str(
            r#"{"name":"greet","version":3,"body":{"kind":"text","prompt":"hi"}}"#,
        )
        .unwrap();
        assert_eq!(p.name, "greet");
        assert_eq!(p.version, 3);
        assert!(p.config.is_null());
        assert!(p.labels.is_empty());
    }
}