Skip to main content

cognis_trace/
prompts.rs

1//! `PromptStore` trait — pull versioned prompts from an external store.
2//! Concrete impls live alongside their backend (`exporters/langfuse/prompts.rs`).
3
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6
7use crate::error::TraceError;
8
9/// Two prompt shapes Langfuse supports.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11#[serde(tag = "kind", rename_all = "snake_case")]
12pub enum PromptBody {
13    /// Single string template.
14    Text {
15        /// Templated string.
16        prompt: String,
17    },
18    /// Sequence of role/content messages, possibly with placeholders.
19    Chat {
20        /// Templated messages.
21        messages: Vec<ChatMessageTemplate>,
22    },
23}
24
25/// One message in a chat prompt template.
26#[derive(Debug, Clone, Serialize, Deserialize)]
27#[serde(tag = "type", rename_all = "lowercase")]
28pub enum ChatMessageTemplate {
29    /// Concrete role/content.
30    #[serde(rename = "chatmessage")]
31    Message {
32        /// Role string ("system", "user", "assistant", ...).
33        role: String,
34        /// Templated content.
35        content: String,
36    },
37    /// Placeholder for runtime-injected messages (Langfuse's `placeholder`).
38    #[serde(rename = "placeholder")]
39    Placeholder {
40        /// Placeholder name.
41        name: String,
42    },
43}
44
45/// A versioned prompt fetched from a `PromptStore`.
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct Prompt {
48    /// Stable name.
49    pub name: String,
50    /// Monotonic version number.
51    pub version: u32,
52    /// The template body.
53    pub body: PromptBody,
54    /// Free-form config (e.g. `{"model": "gpt-4o", "temperature": 0.7}`).
55    #[serde(default)]
56    pub config: serde_json::Value,
57    /// Deployment labels (e.g. ["production", "experimental"]).
58    #[serde(default)]
59    pub labels: Vec<String>,
60}
61
62/// Fetcher for versioned prompts.
63#[async_trait]
64pub trait PromptStore: Send + Sync {
65    /// Fetch by name; the implementation chooses what "current" means
66    /// (latest version, "production" label, etc.).
67    async fn get(&self, name: &str) -> Result<Prompt, TraceError>;
68
69    /// Fetch a specific version.
70    async fn get_version(&self, name: &str, version: u32) -> Result<Prompt, TraceError>;
71
72    /// Fetch by label (e.g. "production").
73    async fn get_label(&self, name: &str, label: &str) -> Result<Prompt, TraceError>;
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps the messages of a chat body, failing the test on any other shape.
    fn expect_chat(body: PromptBody) -> Vec<ChatMessageTemplate> {
        match body {
            PromptBody::Chat { messages } => messages,
            other => panic!("expected a chat body, got {:?}", other),
        }
    }

    #[test]
    fn prompt_body_text_round_trips() {
        let p = PromptBody::Text {
            prompt: "hi {name}".into(),
        };
        let s = serde_json::to_string(&p).unwrap();
        let p2: PromptBody = serde_json::from_str(&s).unwrap();
        match p2 {
            PromptBody::Text { prompt } => assert_eq!(prompt, "hi {name}"),
            other => panic!("expected a text body, got {:?}", other),
        }
    }

    #[test]
    fn prompt_body_text_wire_format_uses_kind_tag() {
        let p = PromptBody::Text {
            prompt: "hi".into(),
        };
        let v = serde_json::to_value(&p).unwrap();
        assert_eq!(v, serde_json::json!({ "kind": "text", "prompt": "hi" }));
    }

    #[test]
    fn prompt_body_chat_with_placeholder_round_trips() {
        let p = PromptBody::Chat {
            messages: vec![
                ChatMessageTemplate::Message {
                    role: "system".into(),
                    content: "you are helpful".into(),
                },
                ChatMessageTemplate::Placeholder {
                    name: "history".into(),
                },
            ],
        };

        // Pin the wire format so tag/rename regressions are caught.
        let v = serde_json::to_value(&p).unwrap();
        assert_eq!(
            v,
            serde_json::json!({
                "kind": "chat",
                "messages": [
                    { "type": "chatmessage", "role": "system", "content": "you are helpful" },
                    { "type": "placeholder", "name": "history" }
                ]
            })
        );

        // And check the content actually survives the round trip.
        let s = serde_json::to_string(&p).unwrap();
        let messages = expect_chat(serde_json::from_str(&s).unwrap());
        assert_eq!(messages.len(), 2);
        match &messages[0] {
            ChatMessageTemplate::Message { role, content } => {
                assert_eq!(role, "system");
                assert_eq!(content, "you are helpful");
            }
            other => panic!("expected a chat message, got {:?}", other),
        }
        match &messages[1] {
            ChatMessageTemplate::Placeholder { name } => assert_eq!(name, "history"),
            other => panic!("expected a placeholder, got {:?}", other),
        }
    }

    #[test]
    fn prompt_body_rejects_unknown_kind() {
        let res = serde_json::from_str::<PromptBody>(r#"{"kind":"image","prompt":"x"}"#);
        assert!(res.is_err());
    }

    #[test]
    fn prompt_missing_config_and_labels_use_defaults() {
        let json = r#"{"name":"greet","version":3,"body":{"kind":"text","prompt":"hi"}}"#;
        let p: Prompt = serde_json::from_str(json).unwrap();
        assert_eq!(p.name, "greet");
        assert_eq!(p.version, 3);
        assert!(p.config.is_null());
        assert!(p.labels.is_empty());
        match p.body {
            PromptBody::Text { prompt } => assert_eq!(prompt, "hi"),
            other => panic!("expected a text body, got {:?}", other),
        }
    }
}