use std::collections::HashSet;
use std::sync::Arc;
use indexmap::IndexMap;
use parking_lot::RwLock;
use regex::Regex;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use tracing::{trace, warn};
use crate::error::{FastMcpError, Result, expect_object};
use crate::tool::DuplicateBehavior;
fn annotations_is_empty(map: &Map<String, Value>) -> bool {
map.is_empty()
}
/// Serde `skip_serializing_if` predicate for an optional parameter schema:
/// `true` when no schema is present, so the field is left out of the output.
fn params_is_none(value: &Option<Value>) -> bool {
    Option::is_none(value)
}
/// Body of a single prompt message.
///
/// Serialized as an internally tagged object: a `type` field holding the
/// snake_case variant name (`"text"` or `"json"`) next to the variant's field.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum PromptMessageContent {
    /// Plain text; `{{ name }}` placeholders are substituted when a
    /// `PromptTemplate` is instantiated.
    Text { text: String },
    /// Arbitrary JSON payload, copied through instantiation unchanged.
    Json { value: Value },
}
/// One message of a prompt: a role label paired with its content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptMessage {
    /// Free-form role label (for example `"system"`); not validated here.
    pub role: String,
    /// The message body.
    pub content: PromptMessageContent,
}
/// Serializable summary of a prompt template (name, description, argument
/// schema, annotations). Absent or empty optional fields are omitted when
/// serializing and default when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptDefinitionMetadata {
    /// Prompt name.
    pub name: String,
    /// Optional human-readable description; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Optional JSON schema describing the arguments; omitted when `None`.
    #[serde(default, skip_serializing_if = "params_is_none")]
    pub parameters: Option<Value>,
    /// Free-form annotations; omitted when empty.
    #[serde(default, skip_serializing_if = "annotations_is_empty")]
    pub annotations: Map<String, Value>,
}
/// A named, reusable prompt: a list of messages whose text content may contain
/// `{{ name }}` placeholders that are filled in at instantiation time.
///
/// Derives `Debug` so that `Result<Arc<PromptTemplate>>` values can be used
/// with `unwrap_err`/`expect_err` and printed in diagnostics.
#[derive(Clone, Debug)]
pub struct PromptTemplate {
    /// Unique name; used as the registry key by `PromptManager`.
    pub name: String,
    /// Optional human-readable description surfaced in metadata.
    pub description: Option<String>,
    /// Optional JSON schema for the arguments; instantiation reads its
    /// `required` array when no arguments are supplied.
    pub parameters: Option<Value>,
    /// Free-form annotations surfaced in metadata.
    pub annotations: Map<String, Value>,
    /// Messages rendered, in order, by instantiation.
    pub messages: Vec<PromptMessage>,
    /// Compiled placeholder regex; wrapped in `Arc` so clones share it.
    placeholder_pattern: Arc<Regex>,
}
impl PromptTemplate {
pub fn new(name: impl Into<String>, messages: Vec<PromptMessage>) -> Self {
let pattern = Regex::new(r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}").unwrap();
Self {
name: name.into(),
description: None,
parameters: None,
annotations: Map::new(),
messages,
placeholder_pattern: Arc::new(pattern),
}
}
pub fn with_description(mut self, description: impl Into<String>) -> Self {
self.description = Some(description.into());
self
}
pub fn with_parameters(mut self, schema: Value) -> Self {
self.parameters = Some(schema);
self
}
pub fn with_annotations(mut self, annotations: Map<String, Value>) -> Self {
self.annotations = annotations;
self
}
pub fn metadata(&self) -> PromptDefinitionMetadata {
PromptDefinitionMetadata {
name: self.name.clone(),
description: self.description.clone(),
parameters: self.parameters.clone(),
annotations: self.annotations.clone(),
}
}
pub fn instantiate(&self, arguments: Option<&Value>) -> Result<Vec<PromptMessage>> {
let context = match (arguments, &self.parameters) {
(Some(value), _) => expect_object(value, "prompt arguments")?.clone(),
(None, Some(schema)) => {
let required = schema
.get("required")
.and_then(|value| value.as_array())
.map(|arr| {
arr.iter()
.filter_map(|value| value.as_str().map(str::to_string))
.collect::<HashSet<_>>()
})
.unwrap_or_default();
if !required.is_empty() {
return Err(FastMcpError::InvalidInvocation(format!(
"prompt '{}' expects parameters: {}",
self.name,
required.into_iter().collect::<Vec<_>>().join(", ")
)));
}
Map::new()
}
(None, None) => Map::new(),
};
let mut instantiated = Vec::with_capacity(self.messages.len());
for message in &self.messages {
instantiated.push(PromptMessage {
role: message.role.clone(),
content: match &message.content {
PromptMessageContent::Text { text } => PromptMessageContent::Text {
text: self.interpolate(text, &context)?,
},
PromptMessageContent::Json { value } => PromptMessageContent::Json {
value: value.clone(),
},
},
});
}
Ok(instantiated)
}
fn interpolate(&self, template: &str, ctx: &Map<String, Value>) -> Result<String> {
let mut output = String::with_capacity(template.len());
let mut last_match = 0;
for capture in self.placeholder_pattern.captures_iter(template) {
if let Some(m) = capture.get(0) {
output.push_str(&template[last_match..m.start()]);
let key = capture.get(1).expect("capture group missing").as_str();
let replacement = ctx
.get(key)
.and_then(|value| {
if value.is_string() {
value.as_str().map(str::to_string)
} else if value.is_number() || value.is_boolean() {
Some(value.to_string())
} else {
Some(value.to_string())
}
})
.ok_or_else(|| {
FastMcpError::InvalidInvocation(format!(
"missing prompt argument '{key}' for prompt '{}'",
self.name
))
})?;
output.push_str(&replacement);
last_match = m.end();
}
}
output.push_str(&template[last_match..]);
Ok(output)
}
}
/// Registry of prompt templates keyed by name.
///
/// The map sits behind a read–write lock, so registration and lookup both work
/// through `&self`.
pub struct PromptManager {
    /// Policy applied when `register` sees a name that is already present.
    duplicate_behavior: DuplicateBehavior,
    /// Templates by name; `IndexMap` keeps registration order for listing.
    prompts: RwLock<IndexMap<String, Arc<PromptTemplate>>>,
}
impl PromptManager {
    /// Creates an empty registry that resolves name collisions according to
    /// `duplicate_behavior`.
    pub fn new(duplicate_behavior: DuplicateBehavior) -> Self {
        let prompts = RwLock::new(IndexMap::new());
        Self {
            duplicate_behavior,
            prompts,
        }
    }

    /// Adds `prompt` to the registry under its `name`.
    ///
    /// If the name is already taken, the configured `DuplicateBehavior` decides:
    /// `Error` rejects the new prompt, `Ignore` keeps the existing one (trace log),
    /// and `Replace` / `Warn` overwrite the existing entry in place, keeping its
    /// listing position. `Replace` emits a trace log; `Warn` emits a warning.
    ///
    /// # Errors
    ///
    /// `FastMcpError::DuplicatePrompt` when the name is taken and the behavior is
    /// `DuplicateBehavior::Error`.
    pub fn register(&self, prompt: PromptTemplate) -> Result<()> {
        let mut registry = self.prompts.write();
        if let Some(slot) = registry.get_mut(&prompt.name) {
            match self.duplicate_behavior {
                DuplicateBehavior::Error => {
                    return Err(FastMcpError::DuplicatePrompt(prompt.name));
                }
                DuplicateBehavior::Ignore => {
                    trace!("Ignoring duplicate prompt {}", prompt.name);
                }
                DuplicateBehavior::Replace => {
                    trace!("Replacing prompt {}", prompt.name);
                    *slot = Arc::new(prompt);
                }
                DuplicateBehavior::Warn => {
                    warn!("Replacing duplicate prompt {}", prompt.name);
                    *slot = Arc::new(prompt);
                }
            }
        } else {
            registry.insert(prompt.name.clone(), Arc::new(prompt));
        }
        Ok(())
    }

    /// Returns metadata for every registered prompt, in registration order.
    pub fn list(&self) -> Vec<PromptDefinitionMetadata> {
        let registry = self.prompts.read();
        registry.values().map(|template| template.metadata()).collect()
    }

    /// Looks up a prompt by name and returns a shared handle to it.
    ///
    /// # Errors
    ///
    /// `FastMcpError::PromptNotFound` carrying `name` when no such prompt exists.
    pub fn get(&self, name: &str) -> Result<Arc<PromptTemplate>> {
        let registry = self.prompts.read();
        match registry.get(name) {
            Some(template) => Ok(Arc::clone(template)),
            None => Err(FastMcpError::PromptNotFound(name.to_string())),
        }
    }
}
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Builds a text message with the given role and body.
    fn text_message(role: &str, text: &str) -> PromptMessage {
        PromptMessage {
            role: role.into(),
            content: PromptMessageContent::Text { text: text.into() },
        }
    }

    /// Extracts the text of a rendered message, panicking on JSON content.
    fn rendered_text(message: &PromptMessage) -> &str {
        match &message.content {
            PromptMessageContent::Text { text } => text,
            _ => panic!("expected text content"),
        }
    }

    #[test]
    fn instantiates_prompt_with_arguments() {
        let prompt = PromptTemplate::new("welcome", vec![text_message("system", "Hello {{ user }}!")]);
        let messages = prompt.instantiate(Some(&json!({ "user": "Dev" }))).unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, "system");
        assert_eq!(rendered_text(&messages[0]), "Hello Dev!");
    }

    #[test]
    fn renders_non_string_arguments_as_compact_json() {
        let prompt = PromptTemplate::new(
            "render",
            vec![text_message("user", "{{n}} {{flag}} {{list}} {{obj}}")],
        );
        let args = json!({ "n": 3, "flag": true, "list": [1, 2], "obj": { "a": null } });
        let messages = prompt.instantiate(Some(&args)).unwrap();
        assert_eq!(rendered_text(&messages[0]), r#"3 true [1,2] {"a":null}"#);
    }

    #[test]
    fn missing_placeholder_argument_is_rejected() {
        let prompt = PromptTemplate::new("greet", vec![text_message("user", "Hi {{ name }}")]);
        let err = prompt.instantiate(Some(&json!({}))).unwrap_err();
        assert!(matches!(err, FastMcpError::InvalidInvocation(_)));
    }

    #[test]
    fn missing_required_arguments_are_rejected() {
        let prompt = PromptTemplate::new("greet", vec![text_message("user", "Hi")])
            .with_parameters(json!({ "type": "object", "required": ["name"] }));
        let err = prompt.instantiate(None).unwrap_err();
        assert!(matches!(err, FastMcpError::InvalidInvocation(_)));
    }

    #[test]
    fn static_prompt_needs_no_arguments() {
        let prompt = PromptTemplate::new("static", vec![text_message("system", "Be nice.")]);
        let messages = prompt.instantiate(None).unwrap();
        assert_eq!(rendered_text(&messages[0]), "Be nice.");
    }

    #[test]
    fn json_content_is_passed_through_unchanged() {
        let payload = json!({ "note": "{{ user }}" });
        let prompt = PromptTemplate::new(
            "data",
            vec![PromptMessage {
                role: "user".into(),
                content: PromptMessageContent::Json {
                    value: payload.clone(),
                },
            }],
        );
        let messages = prompt.instantiate(Some(&json!({ "user": "Dev" }))).unwrap();
        match &messages[0].content {
            PromptMessageContent::Json { value } => assert_eq!(value, &payload),
            _ => panic!("expected json content"),
        }
    }

    #[test]
    fn duplicate_registration_follows_configured_behavior() {
        let strict = PromptManager::new(DuplicateBehavior::Error);
        strict.register(PromptTemplate::new("p", Vec::new())).unwrap();
        let err = strict
            .register(PromptTemplate::new("p", Vec::new()))
            .unwrap_err();
        assert!(matches!(err, FastMcpError::DuplicatePrompt(name) if name == "p"));

        let replacing = PromptManager::new(DuplicateBehavior::Replace);
        replacing
            .register(PromptTemplate::new("p", Vec::new()).with_description("old"))
            .unwrap();
        replacing
            .register(PromptTemplate::new("p", Vec::new()).with_description("new"))
            .unwrap();
        let listed = replacing.list();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].description.as_deref(), Some("new"));

        let ignoring = PromptManager::new(DuplicateBehavior::Ignore);
        ignoring
            .register(PromptTemplate::new("p", Vec::new()).with_description("old"))
            .unwrap();
        ignoring
            .register(PromptTemplate::new("p", Vec::new()).with_description("new"))
            .unwrap();
        let listed = ignoring.list();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].description.as_deref(), Some("old"));
    }

    #[test]
    fn get_reports_unknown_prompt() {
        let manager = PromptManager::new(DuplicateBehavior::Error);
        assert!(matches!(
            manager.get("missing"),
            Err(FastMcpError::PromptNotFound(name)) if name == "missing"
        ));
    }
}