use std::collections::HashMap;
use std::convert::Infallible;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use pin_project_lite::pin_project;
use tokio::sync::Mutex;
use tower::util::BoxCloneService;
use tower::{Layer, ServiceExt};
use tower_service::Service;
use crate::context::RequestContext;
use crate::error::{Error, Result};
use crate::protocol::{
Content, GetPromptResult, PromptArgument, PromptDefinition, PromptMessage, PromptRole,
RequestId, ToolIcon,
};
/// Heap-allocated, pinned, `Send` future yielding `T` that may borrow data for `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Request value handed to prompt services: a request context plus the
/// string arguments supplied for the prompt.
#[derive(Debug, Clone)]
pub struct PromptRequest {
/// Context of the request being served.
pub context: RequestContext,
/// Prompt arguments, looked up by argument name.
pub arguments: HashMap<String, String>,
}
impl PromptRequest {
pub fn new(context: RequestContext, arguments: HashMap<String, String>) -> Self {
Self { context, arguments }
}
pub fn with_arguments(arguments: HashMap<String, String>) -> Self {
Self {
context: RequestContext::new(RequestId::Number(0)),
arguments,
}
}
}
/// Type-erased, cloneable prompt service: takes a [`PromptRequest`], yields a
/// [`GetPromptResult`], and has `Infallible` as its error type.
pub type BoxPromptService = BoxCloneService<PromptRequest, GetPromptResult, Infallible>;
/// Wrapper around an inner service `S`.
///
/// `Clone` and `Debug` are available whenever `S` provides them; the debug
/// output is a struct named `PromptCatchError` with a single `inner` field.
#[doc(hidden)]
#[derive(Clone, Debug)]
pub struct PromptCatchError<S> {
    inner: S,
}

impl<S> PromptCatchError<S> {
    /// Wraps `inner`.
    pub fn new(inner: S) -> Self {
        PromptCatchError { inner }
    }
}
// Future wrapper generated through `pin_project!` so the wrapped future can be
// polled through a pinned projection (see the `Future` impl, which calls
// `self.project().inner.poll(cx)`).
pin_project! {
#[doc(hidden)]
pub struct PromptCatchErrorFuture<F> {
// The wrapped future; `#[pin]` makes the projection hand out `Pin<&mut F>`.
#[pin]
inner: F,
}
}
/// Resolves to the inner future's success value unchanged; an inner error is
/// turned into a successful [`GetPromptResult`] that describes the failure, so
/// the output error type is `Infallible`.
impl<F, E> Future for PromptCatchErrorFuture<F>
where
    F: Future<Output = std::result::Result<GetPromptResult, E>>,
    E: fmt::Display,
{
    type Output = std::result::Result<GetPromptResult, Infallible>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        this.inner.poll(cx).map(|outcome| {
            // Replace a failure with a single assistant text message carrying
            // the error's `Display` output.
            Ok(outcome.unwrap_or_else(|err| GetPromptResult {
                description: Some(format!("Prompt error: {}", err)),
                messages: vec![PromptMessage {
                    role: PromptRole::Assistant,
                    content: Content::Text {
                        text: format!("Error generating prompt: {}", err),
                        annotations: None,
                        meta: None,
                    },
                    meta: None,
                }],
                meta: None,
            }))
        })
    }
}
/// Exposes the wrapped service with an `Infallible` error type.
///
/// Errors produced by the inner call future are converted into a
/// [`GetPromptResult`] by [`PromptCatchErrorFuture`].
impl<S> Service<PromptRequest> for PromptCatchError<S>
where
    S: Service<PromptRequest, Response = GetPromptResult> + Clone + Send + 'static,
    S::Error: fmt::Display + Send,
    S::Future: Send,
{
    type Response = GetPromptResult;
    type Error = Infallible;
    type Future = PromptCatchErrorFuture<S::Future>;

    /// Waits for the inner service to become ready.
    ///
    /// A readiness *error* from the inner service is reported as "ready"
    /// rather than panicking: this service cannot return an error
    /// (`Infallible`), and a panic here would take down the caller for what
    /// is an ordinary, recoverable failure of a middleware layer.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
        match self.inner.poll_ready(cx) {
            Poll::Pending => Poll::Pending,
            // Covers both `Ok(())` and `Err(_)`.
            // NOTE(review): after a readiness error, the subsequent `call` is
            // expected to yield the failure through its future, where it is
            // converted into an error result — confirm for the layers in use.
            Poll::Ready(_) => Poll::Ready(Ok(())),
        }
    }

    /// Forwards the request and wraps the resulting future so its errors are
    /// converted into an error-describing result.
    fn call(&mut self, req: PromptRequest) -> Self::Future {
        PromptCatchErrorFuture {
            inner: self.inner.call(req),
        }
    }
}
/// Service adapter holding a prompt handler closure of type `F`.
///
/// Cloneable whenever the handler itself is cloneable.
#[doc(hidden)]
#[derive(Clone)]
pub struct PromptHandlerService<F> {
    handler: F,
}
impl<F, Fut> Service<PromptRequest> for PromptHandlerService<F>
where
F: Fn(HashMap<String, String>) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
{
type Response = GetPromptResult;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = std::result::Result<GetPromptResult, Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: PromptRequest) -> Self::Future {
let handler = self.handler.clone();
Box::pin(async move { handler(req.arguments).await })
}
}
/// Service adapter holding a context-aware prompt handler closure of type `F`.
///
/// Cloneable whenever the handler itself is cloneable.
#[doc(hidden)]
#[derive(Clone)]
pub struct PromptContextHandlerService<F> {
    handler: F,
}
impl<F, Fut> Service<PromptRequest> for PromptContextHandlerService<F>
where
F: Fn(RequestContext, HashMap<String, String>) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
{
type Response = GetPromptResult;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = std::result::Result<GetPromptResult, Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: PromptRequest) -> Self::Future {
let handler = self.handler.clone();
Box::pin(async move { handler(req.context, req.arguments).await })
}
}
/// Interface a [`Prompt`] uses (as `dyn PromptHandler`) to produce its
/// [`GetPromptResult`].
pub trait PromptHandler: Send + Sync {
/// Renders the prompt from `arguments`.
fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>>;
/// Renders the prompt with a request context.
///
/// The default implementation ignores the context and forwards to
/// [`PromptHandler::get`].
fn get_with_context(
&self,
_ctx: RequestContext,
arguments: HashMap<String, String>,
) -> BoxFuture<'_, Result<GetPromptResult>> {
self.get(arguments)
}
/// Reports whether the handler makes use of the request context; `false`
/// unless overridden.
fn uses_context(&self) -> bool {
false
}
}
/// A prompt: its descriptive metadata plus the handler that renders it.
///
/// Cloning copies the metadata and shares the same handler through the `Arc`.
#[derive(Clone)]
pub struct Prompt {
    /// Prompt name.
    pub name: String,
    /// Optional title.
    pub title: Option<String>,
    /// Optional description.
    pub description: Option<String>,
    /// Optional icons.
    pub icons: Option<Vec<ToolIcon>>,
    /// Declared arguments.
    pub arguments: Vec<PromptArgument>,
    /// Shared handler that produces the prompt result.
    handler: Arc<dyn PromptHandler>,
}
impl std::fmt::Debug for Prompt {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Prompt")
.field("name", &self.name)
.field("title", &self.title)
.field("description", &self.description)
.field("icons", &self.icons)
.field("arguments", &self.arguments)
.finish_non_exhaustive()
}
}
impl Prompt {
    /// Starts a [`PromptBuilder`] for a prompt called `name`.
    pub fn builder(name: impl Into<String>) -> PromptBuilder {
        PromptBuilder::new(name)
    }

    /// Produces a [`PromptDefinition`] by copying this prompt's metadata;
    /// `meta` is left as `None`.
    pub fn definition(&self) -> PromptDefinition {
        let Self {
            name,
            title,
            description,
            icons,
            arguments,
            handler: _,
        } = self;
        PromptDefinition {
            name: name.clone(),
            title: title.clone(),
            description: description.clone(),
            icons: icons.clone(),
            arguments: arguments.clone(),
            meta: None,
        }
    }

    /// Renders the prompt by delegating to the handler's `get`.
    pub fn get(
        &self,
        arguments: HashMap<String, String>,
    ) -> BoxFuture<'_, Result<GetPromptResult>> {
        let Self { handler, .. } = self;
        handler.get(arguments)
    }

    /// Renders the prompt by delegating to the handler's `get_with_context`.
    pub fn get_with_context(
        &self,
        ctx: RequestContext,
        arguments: HashMap<String, String>,
    ) -> BoxFuture<'_, Result<GetPromptResult>> {
        let Self { handler, .. } = self;
        handler.get_with_context(ctx, arguments)
    }

    /// Returns the handler's answer to `uses_context`.
    pub fn uses_context(&self) -> bool {
        let Self { handler, .. } = self;
        handler.uses_context()
    }
}
/// Accumulates a prompt's metadata before a handler is attached.
pub struct PromptBuilder {
// Prompt name supplied to `new`.
name: String,
// Optional title; set by `title`.
title: Option<String>,
// Optional description; set by `description`.
description: Option<String>,
// Icons added via `icon` / `icon_with_meta`; `None` until the first icon.
icons: Option<Vec<ToolIcon>>,
// Arguments in the order they were added.
arguments: Vec<PromptArgument>,
}
impl PromptBuilder {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
title: None,
description: None,
icons: None,
arguments: Vec::new(),
}
}
pub fn title(mut self, title: impl Into<String>) -> Self {
self.title = Some(title.into());
self
}
pub fn description(mut self, description: impl Into<String>) -> Self {
self.description = Some(description.into());
self
}
pub fn icon(mut self, src: impl Into<String>) -> Self {
self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
src: src.into(),
mime_type: None,
sizes: None,
theme: None,
});
self
}
pub fn icon_with_meta(
mut self,
src: impl Into<String>,
mime_type: Option<String>,
sizes: Option<Vec<String>>,
) -> Self {
self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
src: src.into(),
mime_type,
sizes,
theme: None,
});
self
}
pub fn required_arg(mut self, name: impl Into<String>, description: impl Into<String>) -> Self {
self.arguments.push(PromptArgument {
name: name.into(),
description: Some(description.into()),
required: true,
});
self
}
pub fn optional_arg(mut self, name: impl Into<String>, description: impl Into<String>) -> Self {
self.arguments.push(PromptArgument {
name: name.into(),
description: Some(description.into()),
required: false,
});
self
}
pub fn argument(mut self, arg: PromptArgument) -> Self {
self.arguments.push(arg);
self
}
pub fn handler<F, Fut>(self, handler: F) -> PromptBuilderWithHandler<F>
where
F: Fn(HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
{
PromptBuilderWithHandler {
name: self.name,
title: self.title,
description: self.description,
icons: self.icons,
arguments: self.arguments,
handler,
}
}
pub fn handler_with_context<F, Fut>(self, handler: F) -> PromptBuilderWithContextHandler<F>
where
F: Fn(RequestContext, HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
{
PromptBuilderWithContextHandler {
name: self.name,
title: self.title,
description: self.description,
icons: self.icons,
arguments: self.arguments,
handler,
}
}
pub fn static_prompt(self, messages: Vec<PromptMessage>) -> Prompt {
let description = self.description.clone();
self.handler(move |_| {
let messages = messages.clone();
let description = description.clone();
async move {
Ok(GetPromptResult {
description,
messages,
meta: None,
})
}
})
.build()
}
pub fn user_message(self, text: impl Into<String>) -> Prompt {
let text = text.into();
self.static_prompt(vec![PromptMessage {
role: PromptRole::User,
content: Content::Text {
text,
annotations: None,
meta: None,
},
meta: None,
}])
}
pub fn build<F, Fut>(self, handler: F) -> Prompt
where
F: Fn(HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
{
self.handler(handler).build()
}
}
/// Builder stage reached through `PromptBuilder::handler`: the collected
/// metadata plus an arguments-only handler of type `F`.
#[doc(hidden)]
pub struct PromptBuilderWithHandler<F> {
// Metadata carried over unchanged from `PromptBuilder`.
name: String,
title: Option<String>,
description: Option<String>,
icons: Option<Vec<ToolIcon>>,
arguments: Vec<PromptArgument>,
// Handler invoked with the prompt arguments.
handler: F,
}
impl<F, Fut> PromptBuilderWithHandler<F>
where
F: Fn(HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
{
pub fn build(self) -> Prompt {
Prompt {
name: self.name,
title: self.title,
description: self.description,
icons: self.icons,
arguments: self.arguments,
handler: Arc::new(FnHandler {
handler: self.handler,
}),
}
}
pub fn layer<L>(self, layer: L) -> Prompt
where
L: Layer<PromptHandlerService<F>> + Send + Sync + 'static,
L::Service: Service<PromptRequest, Response = GetPromptResult> + Clone + Send + 'static,
<L::Service as Service<PromptRequest>>::Error: fmt::Display + Send,
<L::Service as Service<PromptRequest>>::Future: Send,
{
let service = PromptHandlerService {
handler: self.handler,
};
let wrapped = layer.layer(service);
let boxed = BoxCloneService::new(PromptCatchError::new(wrapped));
Prompt {
name: self.name,
title: self.title,
description: self.description,
icons: self.icons,
arguments: self.arguments,
handler: Arc::new(ServiceHandler {
service: Mutex::new(boxed),
}),
}
}
}
/// Builder stage reached through `PromptBuilder::handler_with_context`: the
/// collected metadata plus a handler of type `F` that also takes the request
/// context.
#[doc(hidden)]
pub struct PromptBuilderWithContextHandler<F> {
// Metadata carried over unchanged from `PromptBuilder`.
name: String,
title: Option<String>,
description: Option<String>,
icons: Option<Vec<ToolIcon>>,
arguments: Vec<PromptArgument>,
// Handler invoked with the request context and the prompt arguments.
handler: F,
}
impl<F, Fut> PromptBuilderWithContextHandler<F>
where
F: Fn(RequestContext, HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
{
pub fn build(self) -> Prompt {
Prompt {
name: self.name,
title: self.title,
description: self.description,
icons: self.icons,
arguments: self.arguments,
handler: Arc::new(ContextAwareHandler {
handler: self.handler,
}),
}
}
pub fn layer<L>(self, layer: L) -> Prompt
where
L: Layer<PromptContextHandlerService<F>> + Send + Sync + 'static,
L::Service: Service<PromptRequest, Response = GetPromptResult> + Clone + Send + 'static,
<L::Service as Service<PromptRequest>>::Error: fmt::Display + Send,
<L::Service as Service<PromptRequest>>::Future: Send,
{
let service = PromptContextHandlerService {
handler: self.handler,
};
let wrapped = layer.layer(service);
let boxed = BoxCloneService::new(PromptCatchError::new(wrapped));
Prompt {
name: self.name,
title: self.title,
description: self.description,
icons: self.icons,
arguments: self.arguments,
handler: Arc::new(ServiceContextHandler {
service: Mutex::new(boxed),
}),
}
}
}
/// [`PromptHandler`] backed by a plain arguments-only closure.
struct FnHandler<F> {
    handler: F,
}

impl<F, Fut> PromptHandler for FnHandler<F>
where
    F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
{
    /// Calls the closure with `arguments` and boxes the future it returns.
    fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
        let fut = (self.handler)(arguments);
        Box::pin(fut)
    }
}
/// [`PromptHandler`] backed by a closure that takes the request context as
/// well as the arguments.
struct ContextAwareHandler<F> {
    handler: F,
}

impl<F, Fut> PromptHandler for ContextAwareHandler<F>
where
    F: Fn(RequestContext, HashMap<String, String>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
{
    /// Context-free entry point: supplies a placeholder context with request
    /// id `RequestId::Number(0)`.
    fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
        let placeholder = RequestContext::new(RequestId::Number(0));
        self.get_with_context(placeholder, arguments)
    }

    /// Calls the closure with the context and arguments and boxes its future.
    fn get_with_context(
        &self,
        ctx: RequestContext,
        arguments: HashMap<String, String>,
    ) -> BoxFuture<'_, Result<GetPromptResult>> {
        let fut = (self.handler)(ctx, arguments);
        Box::pin(fut)
    }

    /// This handler consumes the request context.
    fn uses_context(&self) -> bool {
        true
    }
}
struct ServiceHandler {
service: Mutex<BoxPromptService>,
}
impl PromptHandler for ServiceHandler {
fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
Box::pin(async move {
let req = PromptRequest::with_arguments(arguments);
let mut service = self.service.lock().await.clone();
match service.ready().await {
Ok(svc) => svc.call(req).await.map_err(|e| match e {}),
Err(e) => match e {},
}
})
}
fn get_with_context(
&self,
ctx: RequestContext,
arguments: HashMap<String, String>,
) -> BoxFuture<'_, Result<GetPromptResult>> {
Box::pin(async move {
let req = PromptRequest::new(ctx, arguments);
let mut service = self.service.lock().await.clone();
match service.ready().await {
Ok(svc) => svc.call(req).await.map_err(|e| match e {}),
Err(e) => match e {},
}
})
}
}
/// [`PromptHandler`] that dispatches through a boxed, infallible service and
/// reports itself as context-aware.
struct ServiceContextHandler {
    service: Mutex<BoxPromptService>,
}

impl PromptHandler for ServiceContextHandler {
    /// Context-free entry point: supplies a placeholder context with request
    /// id `RequestId::Number(0)`.
    fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
        let placeholder = RequestContext::new(RequestId::Number(0));
        self.get_with_context(placeholder, arguments)
    }

    /// Clones the stored service, waits for the clone to be ready, and calls
    /// it with a request built from `ctx` and `arguments`.
    fn get_with_context(
        &self,
        ctx: RequestContext,
        arguments: HashMap<String, String>,
    ) -> BoxFuture<'_, Result<GetPromptResult>> {
        let fut = async move {
            let req = PromptRequest::new(ctx, arguments);
            // The lock is only held for the duration of the clone.
            let mut svc = self.service.lock().await.clone();
            let ready = match svc.ready().await {
                Ok(ready) => ready,
                // The service error type is `Infallible`; this arm cannot run.
                Err(never) => match never {},
            };
            match ready.call(req).await {
                Ok(result) => Ok(result),
                Err(never) => match never {},
            }
        };
        Box::pin(fut)
    }

    /// This handler forwards the request context to the service.
    fn uses_context(&self) -> bool {
        true
    }
}
pub trait McpPrompt: Send + Sync + 'static {
const NAME: &'static str;
const DESCRIPTION: &'static str;
fn arguments(&self) -> Vec<PromptArgument> {
Vec::new()
}
fn get(
&self,
arguments: HashMap<String, String>,
) -> impl Future<Output = Result<GetPromptResult>> + Send;
fn into_prompt(self) -> Prompt
where
Self: Sized,
{
let arguments = self.arguments();
let prompt = Arc::new(self);
Prompt {
name: Self::NAME.to_string(),
title: None,
description: Some(Self::DESCRIPTION.to_string()),
icons: None,
arguments,
handler: Arc::new(McpPromptHandler { prompt }),
}
}
}
/// [`PromptHandler`] adapter around a shared [`McpPrompt`] implementation.
struct McpPromptHandler<T: McpPrompt> {
    prompt: Arc<T>,
}

impl<T: McpPrompt> PromptHandler for McpPromptHandler<T> {
    /// Clones the `Arc` so the boxed future owns its own handle to the
    /// prompt, then awaits the prompt's `get`.
    fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
        let shared = Arc::clone(&self.prompt);
        let fut = async move { shared.get(arguments).await };
        Box::pin(fut)
    }
}
#[cfg(test)]
mod tests {
use super::*;
/// A closure-backed prompt keeps its metadata and renders the greeting from
/// the supplied `name` argument.
#[tokio::test]
async fn test_builder_prompt() {
    let greet = PromptBuilder::new("greet")
        .description("A greeting prompt")
        .required_arg("name", "Name to greet")
        .handler(|args| async move {
            let who = args.get("name").map(|s| s.as_str()).unwrap_or("World");
            let message = PromptMessage {
                role: PromptRole::User,
                content: Content::Text {
                    text: format!("Hello, {}!", who),
                    annotations: None,
                    meta: None,
                },
                meta: None,
            };
            Ok(GetPromptResult {
                description: Some("Greeting".to_string()),
                messages: vec![message],
                meta: None,
            })
        })
        .build();

    assert_eq!(greet.name, "greet");
    assert_eq!(greet.description.as_deref(), Some("A greeting prompt"));
    assert_eq!(greet.arguments.len(), 1);
    assert!(greet.arguments[0].required);

    let args = HashMap::from([("name".to_string(), "Alice".to_string())]);
    let rendered = greet.get(args).await.unwrap();
    assert_eq!(rendered.messages.len(), 1);
    let Content::Text { text, .. } = &rendered.messages[0].content else {
        panic!("Expected text content");
    };
    assert_eq!(text, "Hello, Alice!");
}
/// `user_message` yields a prompt that returns its fixed text.
#[tokio::test]
async fn test_static_prompt() {
    let help = PromptBuilder::new("help")
        .description("Help prompt")
        .user_message("How can I help you today?");

    let no_args: HashMap<String, String> = HashMap::new();
    let rendered = help.get(no_args).await.unwrap();
    assert_eq!(rendered.messages.len(), 1);
    let Content::Text { text, .. } = &rendered.messages[0].content else {
        panic!("Expected text content");
    };
    assert_eq!(text, "How can I help you today?");
}
/// A type implementing `McpPrompt` converts into a working `Prompt`.
#[tokio::test]
async fn test_trait_prompt() {
    struct TestPrompt;

    impl McpPrompt for TestPrompt {
        const NAME: &'static str = "test";
        const DESCRIPTION: &'static str = "A test prompt";

        fn arguments(&self) -> Vec<PromptArgument> {
            let input = PromptArgument {
                name: "input".to_string(),
                description: Some("Test input".to_string()),
                required: true,
            };
            vec![input]
        }

        async fn get(&self, args: HashMap<String, String>) -> Result<GetPromptResult> {
            let input = args.get("input").map(|s| s.as_str()).unwrap_or("default");
            let message = PromptMessage {
                role: PromptRole::User,
                content: Content::Text {
                    text: format!("Input: {}", input),
                    annotations: None,
                    meta: None,
                },
                meta: None,
            };
            Ok(GetPromptResult {
                description: Some("Test".to_string()),
                messages: vec![message],
                meta: None,
            })
        }
    }

    let prompt = TestPrompt.into_prompt();
    assert_eq!(prompt.name, "test");
    assert_eq!(prompt.arguments.len(), 1);

    let args = HashMap::from([("input".to_string(), "hello".to_string())]);
    let rendered = prompt.get(args).await.unwrap();
    let Content::Text { text, .. } = &rendered.messages[0].content else {
        panic!("Expected text content");
    };
    assert_eq!(text, "Input: hello");
}
/// `definition` mirrors the builder's name, description and argument flags.
#[test]
fn test_prompt_definition() {
    let def = PromptBuilder::new("test")
        .description("Test description")
        .required_arg("arg1", "First arg")
        .optional_arg("arg2", "Second arg")
        .user_message("Test")
        .definition();

    assert_eq!(def.name, "test");
    assert_eq!(def.description.as_deref(), Some("Test description"));
    assert_eq!(def.arguments.len(), 2);
    let (first, second) = (&def.arguments[0], &def.arguments[1]);
    assert!(first.required);
    assert!(!second.required);
}
/// A context-aware handler reports `uses_context` and renders through
/// `get_with_context`.
#[tokio::test]
async fn test_handler_with_context() {
    let prompt = PromptBuilder::new("context_prompt")
        .description("A prompt with context")
        .handler_with_context(|ctx: RequestContext, args| async move {
            let _ = ctx.is_cancelled();
            let who = args.get("name").map(|s| s.as_str()).unwrap_or("World");
            let message = PromptMessage {
                role: PromptRole::User,
                content: Content::Text {
                    text: format!("Hello, {}!", who),
                    annotations: None,
                    meta: None,
                },
                meta: None,
            };
            Ok(GetPromptResult {
                description: Some("Context prompt".to_string()),
                messages: vec![message],
                meta: None,
            })
        })
        .build();

    assert_eq!(prompt.name, "context_prompt");
    assert!(prompt.uses_context());

    let ctx = RequestContext::new(RequestId::Number(1));
    let args = HashMap::from([("name".to_string(), "Alice".to_string())]);
    let rendered = prompt.get_with_context(ctx, args).await.unwrap();
    let Content::Text { text, .. } = &rendered.messages[0].content else {
        panic!("Expected text content");
    };
    assert_eq!(text, "Hello, Alice!");
}
/// A fast handler behind a generous timeout layer still returns its own result.
#[tokio::test]
async fn test_prompt_with_timeout_layer() {
    use std::time::Duration;
    use tower::timeout::TimeoutLayer;

    let timeout = TimeoutLayer::new(Duration::from_secs(5));
    let prompt = PromptBuilder::new("timeout_prompt")
        .description("A prompt with timeout")
        .handler(|args: HashMap<String, String>| async move {
            let who = args.get("name").map(|s| s.as_str()).unwrap_or("World");
            let message = PromptMessage {
                role: PromptRole::User,
                content: Content::Text {
                    text: format!("Hello, {}!", who),
                    annotations: None,
                    meta: None,
                },
                meta: None,
            };
            Ok(GetPromptResult {
                description: Some("Timeout prompt".to_string()),
                messages: vec![message],
                meta: None,
            })
        })
        .layer(timeout);

    assert_eq!(prompt.name, "timeout_prompt");

    let args = HashMap::from([("name".to_string(), "Alice".to_string())]);
    let rendered = prompt.get(args).await.unwrap();
    let Content::Text { text, .. } = &rendered.messages[0].content else {
        panic!("Expected text content");
    };
    assert_eq!(text, "Hello, Alice!");
}
/// When the timeout fires first, the prompt returns the error-describing
/// result instead of the handler's output.
#[tokio::test]
async fn test_prompt_timeout_expires() {
    use std::time::Duration;
    use tower::timeout::TimeoutLayer;

    let timeout = TimeoutLayer::new(Duration::from_millis(50));
    let prompt = PromptBuilder::new("slow_prompt")
        .description("A slow prompt")
        .handler(|_args: HashMap<String, String>| async move {
            tokio::time::sleep(Duration::from_secs(1)).await;
            let message = PromptMessage {
                role: PromptRole::User,
                content: Content::Text {
                    text: "This should not appear".to_string(),
                    annotations: None,
                    meta: None,
                },
                meta: None,
            };
            Ok(GetPromptResult {
                description: Some("Slow prompt".to_string()),
                messages: vec![message],
                meta: None,
            })
        })
        .layer(timeout);

    let rendered = prompt.get(HashMap::new()).await.unwrap();
    let description = rendered.description.as_deref().unwrap();
    assert!(description.contains("error"));
    let Content::Text { text, .. } = &rendered.messages[0].content else {
        panic!("Expected text content");
    };
    assert!(text.contains("Error generating prompt"));
}
/// A layered context-aware handler still reports `uses_context` and renders
/// through `get_with_context`.
#[tokio::test]
async fn test_context_handler_with_layer() {
    use std::time::Duration;
    use tower::timeout::TimeoutLayer;

    let timeout = TimeoutLayer::new(Duration::from_secs(5));
    let prompt = PromptBuilder::new("context_timeout")
        .description("Context prompt with timeout")
        .handler_with_context(
            |_ctx: RequestContext, args: HashMap<String, String>| async move {
                let who = args.get("name").map(|s| s.as_str()).unwrap_or("World");
                let message = PromptMessage {
                    role: PromptRole::User,
                    content: Content::Text {
                        text: format!("Hello, {}!", who),
                        annotations: None,
                        meta: None,
                    },
                    meta: None,
                };
                Ok(GetPromptResult {
                    description: Some("Context timeout".to_string()),
                    messages: vec![message],
                    meta: None,
                })
            },
        )
        .layer(timeout);

    assert_eq!(prompt.name, "context_timeout");
    assert!(prompt.uses_context());

    let ctx = RequestContext::new(RequestId::Number(1));
    let args = HashMap::from([("name".to_string(), "Bob".to_string())]);
    let rendered = prompt.get_with_context(ctx, args).await.unwrap();
    let Content::Text { text, .. } = &rendered.messages[0].content else {
        panic!("Expected text content");
    };
    assert_eq!(text, "Hello, Bob!");
}
/// Both `PromptRequest` constructors keep the supplied arguments.
#[test]
fn test_prompt_request_construction() {
    let args = HashMap::from([("key".to_string(), "value".to_string())]);
    let expected = "value".to_string();

    let bare = PromptRequest::with_arguments(args.clone());
    assert_eq!(bare.arguments.get("key"), Some(&expected));

    let with_ctx = PromptRequest::new(RequestContext::new(RequestId::Number(42)), args);
    assert_eq!(with_ctx.arguments.get("key"), Some(&expected));
}
/// `PromptCatchError` can be cloned when the wrapped service is cloneable.
#[test]
fn test_prompt_catch_error_clone() {
    let respond = |_args: HashMap<String, String>| async {
        Ok::<GetPromptResult, Error>(GetPromptResult {
            description: None,
            messages: Vec::new(),
            meta: None,
        })
    };
    let wrapped = PromptCatchError::new(PromptHandlerService { handler: respond });
    let _clone = wrapped.clone();
}
/// The handler produces one message both with and without the optional
/// `style` argument.
#[tokio::test]
async fn test_prompt_handler_with_arguments() {
    let prompt = PromptBuilder::new("greet")
        .description("Greeting prompt")
        .required_arg("name", "Person to greet")
        .optional_arg("style", "Greeting style")
        .handler(|args: HashMap<String, String>| async move {
            let name = args.get("name").map(|s| s.as_str()).unwrap_or("World");
            let style = args.get("style").map(|s| s.as_str()).unwrap_or("casual");
            let text = if style == "formal" {
                format!("Good day, {name}.")
            } else {
                format!("Hey {name}!")
            };
            Ok(GetPromptResult::user_message(text))
        })
        .build();

    let formal = HashMap::from([
        ("name".to_string(), "Alice".to_string()),
        ("style".to_string(), "formal".to_string()),
    ]);
    let casual = HashMap::from([("name".to_string(), "Bob".to_string())]);

    for args in [formal, casual] {
        let rendered = prompt.get(args).await.unwrap();
        assert_eq!(rendered.messages.len(), 1);
    }
}
/// `definition` carries over name, title, description and argument flags.
#[tokio::test]
async fn test_prompt_definition_fields() {
    let def = PromptBuilder::new("test_prompt")
        .title("Test Prompt")
        .description("A test prompt")
        .required_arg("input", "The input")
        .optional_arg("format", "Output format")
        .handler(|_args: HashMap<String, String>| async move {
            Ok(GetPromptResult::user_message("test"))
        })
        .build()
        .definition();

    assert_eq!(def.name, "test_prompt");
    assert_eq!(def.title.as_deref(), Some("Test Prompt"));
    assert_eq!(def.description.as_deref(), Some("A test prompt"));
    assert_eq!(def.arguments.len(), 2);
    let (first, second) = (&def.arguments[0], &def.arguments[1]);
    assert!(first.required);
    assert!(!second.required);
}
#[tokio::test]
async fn test_prompt_with_context_handler() {
let prompt = PromptBuilder::new("ctx_prompt")
.description("Context-aware prompt")
.handler_with_context(
|ctx: RequestContext, args: HashMap<String, String>| async move {
let _ = ctx;
let name = args.get("name").map(|s| s.as_str()).unwrap_or("default");
Ok(GetPromptResult::user_message(format!("ctx: {name}")))
},
)
.build();
assert!(prompt.uses_context());
let mut args = HashMap::new();
args.insert("name".to_string(), "test".to_string());
let ctx = RequestContext::new(RequestId::Number(1));
let result: std::result::Result<GetPromptResult, Error> =
prompt.get_with_context(ctx, args).await;
assert!(result.is_ok());
assert_eq!(result.unwrap().messages.len(), 1);
}
#[tokio::test]
async fn test_prompt_with_layer_catches_timeout() {
use std::time::Duration;
use tower::timeout::TimeoutLayer;
let prompt = PromptBuilder::new("slow_prompt")
.description("Will timeout")
.handler(|_args: HashMap<String, String>| async move {
tokio::time::sleep(Duration::from_secs(10)).await;
Ok(GetPromptResult::user_message("too late"))
})
.layer(TimeoutLayer::new(Duration::from_millis(10)));
let result = prompt.get(HashMap::new()).await;
match result {
Ok(r) => {
assert!(
!r.messages.is_empty(),
"Expected error message in prompt result"
);
}
Err(_) => {
}
}
}
/// A cloned prompt keeps its name and remains callable.
#[tokio::test]
async fn test_prompt_clone() {
    let original = PromptBuilder::new("cloneable")
        .description("Can be cloned")
        .handler(|_args: HashMap<String, String>| async move {
            Ok(GetPromptResult::user_message("original"))
        })
        .build();

    let copy = original.clone();
    assert_eq!(copy.name, "cloneable");

    let no_args: HashMap<String, String> = HashMap::new();
    let rendered = copy.get(no_args).await.unwrap();
    assert_eq!(rendered.messages.len(), 1);
}
}