1use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6
7use crate::error::TraceError;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11#[serde(tag = "kind", rename_all = "snake_case")]
12pub enum PromptBody {
13 Text {
15 prompt: String,
17 },
18 Chat {
20 messages: Vec<ChatMessageTemplate>,
22 },
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27#[serde(tag = "type", rename_all = "lowercase")]
28pub enum ChatMessageTemplate {
29 #[serde(rename = "chatmessage")]
31 Message {
32 role: String,
34 content: String,
36 },
37 #[serde(rename = "placeholder")]
39 Placeholder {
40 name: String,
42 },
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct Prompt {
48 pub name: String,
50 pub version: u32,
52 pub body: PromptBody,
54 #[serde(default)]
56 pub config: serde_json::Value,
57 #[serde(default)]
59 pub labels: Vec<String>,
60}
61
62#[async_trait]
64pub trait PromptStore: Send + Sync {
65 async fn get(&self, name: &str) -> Result<Prompt, TraceError>;
68
69 async fn get_version(&self, name: &str, version: u32) -> Result<Prompt, TraceError>;
71
72 async fn get_label(&self, name: &str, label: &str) -> Result<Prompt, TraceError>;
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Chat body shared by the chat tests: one system message followed by a
    /// `history` placeholder.
    fn sample_chat_body() -> PromptBody {
        PromptBody::Chat {
            messages: vec![
                ChatMessageTemplate::Message {
                    role: "system".into(),
                    content: "you are helpful".into(),
                },
                ChatMessageTemplate::Placeholder {
                    name: "history".into(),
                },
            ],
        }
    }

    #[test]
    fn prompt_body_text_round_trips() {
        let p = PromptBody::Text {
            prompt: "hi {name}".into(),
        };
        let s = serde_json::to_string(&p).unwrap();
        let p2: PromptBody = serde_json::from_str(&s).unwrap();
        match p2 {
            PromptBody::Text { prompt } => assert_eq!(prompt, "hi {name}"),
            PromptBody::Chat { .. } => panic!("wrong variant"),
        }
    }

    // Round-trips alone cannot catch a changed tag or rename, because both
    // directions use the same definitions; pin the exact wire shape instead.
    #[test]
    fn prompt_body_text_wire_format() {
        let p = PromptBody::Text {
            prompt: "hi {name}".into(),
        };
        assert_eq!(
            serde_json::to_value(&p).unwrap(),
            json!({ "kind": "text", "prompt": "hi {name}" })
        );
    }

    #[test]
    fn prompt_body_chat_with_placeholder_round_trips() {
        let s = serde_json::to_string(&sample_chat_body()).unwrap();
        let p2: PromptBody = serde_json::from_str(&s).unwrap();
        let messages = match p2 {
            PromptBody::Chat { messages } => messages,
            PromptBody::Text { .. } => panic!("wrong variant"),
        };
        assert_eq!(messages.len(), 2);
        match &messages[0] {
            ChatMessageTemplate::Message { role, content } => {
                assert_eq!(role, "system");
                assert_eq!(content, "you are helpful");
            }
            other => panic!("expected a message, got {:?}", other),
        }
        match &messages[1] {
            ChatMessageTemplate::Placeholder { name } => assert_eq!(name, "history"),
            other => panic!("expected a placeholder, got {:?}", other),
        }
    }

    #[test]
    fn prompt_body_chat_wire_format() {
        assert_eq!(
            serde_json::to_value(sample_chat_body()).unwrap(),
            json!({
                "kind": "chat",
                "messages": [
                    { "type": "chatmessage", "role": "system", "content": "you are helpful" },
                    { "type": "placeholder", "name": "history" }
                ]
            })
        );
    }

    #[test]
    fn prompt_body_rejects_unknown_kind() {
        let parsed = serde_json::from_str::<PromptBody>(r#"{"kind":"audio","prompt":"x"}"#);
        assert!(parsed.is_err());
    }

    #[test]
    fn prompt_missing_config_and_labels_use_defaults() {
        let p: Prompt = serde_json::from_str(
            r#"{"name":"greet","version":3,"body":{"kind":"text","prompt":"hi"}}"#,
        )
        .unwrap();
        assert_eq!(p.name, "greet");
        assert_eq!(p.version, 3);
        assert!(p.config.is_null());
        assert!(p.labels.is_empty());
    }
}