// fastmcp-rs 0.2.0
//
// Rust prototype for the FastMCP server.
// Documentation
use std::collections::HashSet;
use std::sync::Arc;

use indexmap::IndexMap;
use parking_lot::RwLock;
use regex::Regex;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use tracing::{trace, warn};

use crate::error::{FastMcpError, Result, expect_object};
use crate::tool::DuplicateBehavior;

fn annotations_is_empty(map: &Map<String, Value>) -> bool {
    map.is_empty()
}

/// Serde `skip_serializing_if` predicate: true when no parameter schema is present.
fn params_is_none(value: &Option<Value>) -> bool {
    Option::is_none(value)
}

/// Payload carried by a single [`PromptMessage`].
///
/// Serialized as an internally tagged object: the `type` field holds the
/// snake_case variant name (`"text"` or `"json"`) next to the variant's own field.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum PromptMessageContent {
    /// Plain text; `{{ name }}` placeholders in it are substituted by
    /// `PromptTemplate::instantiate`.
    Text { text: String },
    /// Arbitrary JSON; `PromptTemplate::instantiate` copies it through unchanged.
    Json { value: Value },
}

/// One message of a prompt: who is speaking and what is said.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PromptMessage {
    /// Free-form role label (the test in this file uses `"system"`).
    pub role: String,
    /// The message body.
    pub content: PromptMessageContent,
}

/// Descriptive information about a prompt, without its messages.
///
/// Produced by `PromptTemplate::metadata` and returned in bulk by
/// `PromptManager::list`. Every field except `name` is optional on input and is
/// left out of the serialized form when it carries no data.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PromptDefinitionMetadata {
    /// Name of the prompt.
    pub name: String,
    /// Optional description; omitted from the output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Optional parameter schema; omitted from the output when `None`.
    #[serde(default, skip_serializing_if = "params_is_none")]
    pub parameters: Option<Value>,
    /// Extra key/value annotations; omitted from the output when the map is empty.
    #[serde(default, skip_serializing_if = "annotations_is_empty")]
    pub annotations: Map<String, Value>,
}

/// A named prompt made of messages whose text may contain `{{ name }}` placeholders.
#[derive(Clone)]
pub struct PromptTemplate {
    /// Name of the prompt; `PromptManager` uses it as the registry key.
    pub name: String,
    /// Optional description, copied into the prompt's metadata.
    pub description: Option<String>,
    /// Optional parameter schema. `instantiate` reads its `required` array when it
    /// is called without arguments.
    pub parameters: Option<Value>,
    /// Extra key/value annotations, copied into the prompt's metadata.
    pub annotations: Map<String, Value>,
    /// The messages that `instantiate` renders, in order.
    pub messages: Vec<PromptMessage>,
    /// Compiled regex used to locate `{{ name }}` placeholders in text messages.
    placeholder_pattern: Arc<Regex>,
}

impl PromptTemplate {
    pub fn new(name: impl Into<String>, messages: Vec<PromptMessage>) -> Self {
        let pattern = Regex::new(r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}").unwrap();
        Self {
            name: name.into(),
            description: None,
            parameters: None,
            annotations: Map::new(),
            messages,
            placeholder_pattern: Arc::new(pattern),
        }
    }

    pub fn with_description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }

    pub fn with_parameters(mut self, schema: Value) -> Self {
        self.parameters = Some(schema);
        self
    }

    pub fn with_annotations(mut self, annotations: Map<String, Value>) -> Self {
        self.annotations = annotations;
        self
    }

    pub fn metadata(&self) -> PromptDefinitionMetadata {
        PromptDefinitionMetadata {
            name: self.name.clone(),
            description: self.description.clone(),
            parameters: self.parameters.clone(),
            annotations: self.annotations.clone(),
        }
    }

    pub fn instantiate(&self, arguments: Option<&Value>) -> Result<Vec<PromptMessage>> {
        let context = match (arguments, &self.parameters) {
            (Some(value), _) => expect_object(value, "prompt arguments")?.clone(),
            (None, Some(schema)) => {
                // Check if schema declares required fields.
                let required = schema
                    .get("required")
                    .and_then(|value| value.as_array())
                    .map(|arr| {
                        arr.iter()
                            .filter_map(|value| value.as_str().map(str::to_string))
                            .collect::<HashSet<_>>()
                    })
                    .unwrap_or_default();
                if !required.is_empty() {
                    return Err(FastMcpError::InvalidInvocation(format!(
                        "prompt '{}' expects parameters: {}",
                        self.name,
                        required.into_iter().collect::<Vec<_>>().join(", ")
                    )));
                }
                Map::new()
            }
            (None, None) => Map::new(),
        };

        let mut instantiated = Vec::with_capacity(self.messages.len());
        for message in &self.messages {
            instantiated.push(PromptMessage {
                role: message.role.clone(),
                content: match &message.content {
                    PromptMessageContent::Text { text } => PromptMessageContent::Text {
                        text: self.interpolate(text, &context)?,
                    },
                    PromptMessageContent::Json { value } => PromptMessageContent::Json {
                        value: value.clone(),
                    },
                },
            });
        }
        Ok(instantiated)
    }

    fn interpolate(&self, template: &str, ctx: &Map<String, Value>) -> Result<String> {
        let mut output = String::with_capacity(template.len());
        let mut last_match = 0;
        for capture in self.placeholder_pattern.captures_iter(template) {
            if let Some(m) = capture.get(0) {
                output.push_str(&template[last_match..m.start()]);
                let key = capture.get(1).expect("capture group missing").as_str();
                let replacement = ctx
                    .get(key)
                    .and_then(|value| {
                        if value.is_string() {
                            value.as_str().map(str::to_string)
                        } else if value.is_number() || value.is_boolean() {
                            Some(value.to_string())
                        } else {
                            Some(value.to_string())
                        }
                    })
                    .ok_or_else(|| {
                        FastMcpError::InvalidInvocation(format!(
                            "missing prompt argument '{key}' for prompt '{}'",
                            self.name
                        ))
                    })?;
                output.push_str(&replacement);
                last_match = m.end();
            }
        }
        output.push_str(&template[last_match..]);
        Ok(output)
    }
}

/// Registry of prompt templates keyed by name.
pub struct PromptManager {
    /// Policy applied by `register` when a prompt with the same name already exists.
    duplicate_behavior: DuplicateBehavior,
    /// Registered templates behind a read/write lock, keyed by prompt name.
    prompts: RwLock<IndexMap<String, Arc<PromptTemplate>>>,
}

impl PromptManager {
    /// Builds an empty registry that resolves name clashes according to
    /// `duplicate_behavior`.
    pub fn new(duplicate_behavior: DuplicateBehavior) -> Self {
        let prompts = RwLock::new(IndexMap::new());
        Self {
            duplicate_behavior,
            prompts,
        }
    }

    /// Adds `prompt` to the registry under its own name.
    ///
    /// If the name is already taken, the configured duplicate policy decides the
    /// outcome: fail with [`FastMcpError::DuplicatePrompt`], keep the existing
    /// entry, or overwrite it (silently or with a warning).
    pub fn register(&self, prompt: PromptTemplate) -> Result<()> {
        let mut registry = self.prompts.write();

        if let Some(slot) = registry.get_mut(&prompt.name) {
            // A prompt with this name is already registered.
            match self.duplicate_behavior {
                DuplicateBehavior::Error => {
                    return Err(FastMcpError::DuplicatePrompt(prompt.name));
                }
                DuplicateBehavior::Ignore => {
                    trace!("Ignoring duplicate prompt {}", prompt.name);
                }
                DuplicateBehavior::Replace => {
                    trace!("Replacing prompt {}", prompt.name);
                    *slot = Arc::new(prompt);
                }
                DuplicateBehavior::Warn => {
                    warn!("Replacing duplicate prompt {}", prompt.name);
                    *slot = Arc::new(prompt);
                }
            }
        } else {
            let key = prompt.name.clone();
            registry.insert(key, Arc::new(prompt));
        }

        Ok(())
    }

    /// Returns the metadata of every registered prompt, in the registry's
    /// iteration order.
    pub fn list(&self) -> Vec<PromptDefinitionMetadata> {
        let registry = self.prompts.read();
        registry
            .values()
            .map(|template| template.metadata())
            .collect()
    }

    /// Looks up the prompt registered as `name`.
    ///
    /// Fails with [`FastMcpError::PromptNotFound`] when no such prompt exists.
    pub fn get(&self, name: &str) -> Result<Arc<PromptTemplate>> {
        match self.prompts.read().get(name) {
            Some(template) => Ok(Arc::clone(template)),
            None => Err(FastMcpError::PromptNotFound(name.to_string())),
        }
    }
}

#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// A `{{ user }}` placeholder is replaced by the matching string argument.
    #[test]
    fn instantiates_prompt_with_arguments() {
        let greeting = PromptMessage {
            role: "system".into(),
            content: PromptMessageContent::Text {
                text: "Hello {{ user }}!".into(),
            },
        };
        let prompt = PromptTemplate::new("welcome", vec![greeting]);
        let arguments = json!({ "user": "Dev" });

        let rendered = prompt.instantiate(Some(&arguments)).unwrap();

        assert_eq!(rendered.len(), 1);
        let PromptMessageContent::Text { text } = &rendered[0].content else {
            panic!("expected text content");
        };
        assert_eq!(text, "Hello Dev!");
    }
}