use std::{future::Future, marker::PhantomData};
#[cfg(not(feature = "local"))]
use futures::future::BoxFuture;
use serde::de::DeserializeOwned;
use super::common::{AsRequestContext, FromContextPart};
pub use super::common::{Extension, RequestId};
use crate::{
RoleServer,
handler::server::wrapper::Parameters,
model::{GetPromptResult, PromptMessage},
service::{MaybeBoxFuture, MaybeSend, MaybeSendFuture, RequestContext},
};
/// Per-request state handed to a [`GetPromptHandler`]: the server value, the
/// requested prompt's name and raw arguments, and the request context.
///
/// Marked `#[non_exhaustive]`, so code outside this crate must build it with
/// [`PromptContext::new`].
#[non_exhaustive]
pub struct PromptContext<'a, S> {
/// The server value; method-style handlers receive it as `&S`.
pub server: &'a S,
/// Name of the requested prompt (what the `PromptName` extractor returns).
pub name: String,
/// Raw prompt arguments, if the request carried any. The `Parameters<P>`
/// extractor moves this map out of the context, leaving `None` behind.
pub arguments: Option<serde_json::Map<String, serde_json::Value>>,
/// The underlying request context, also exposed through `AsRequestContext`.
pub context: RequestContext<RoleServer>,
}
impl<'a, S> PromptContext<'a, S> {
pub fn new(
server: &'a S,
name: String,
arguments: Option<serde_json::Map<String, serde_json::Value>>,
context: RequestContext<RoleServer>,
) -> Self {
Self {
server,
name,
arguments,
context,
}
}
}
impl<S> AsRequestContext for PromptContext<'_, S> {
    /// Shared access to the request context carried by this prompt context.
    fn as_request_context(&self) -> &RequestContext<RoleServer> {
        let Self { context, .. } = self;
        context
    }

    /// Exclusive access to the request context carried by this prompt context.
    fn as_request_context_mut(&mut self) -> &mut RequestContext<RoleServer> {
        let Self { context, .. } = self;
        context
    }
}
/// A callable that can serve a prompt request given a [`PromptContext`].
///
/// `A` carries no data. It only distinguishes the blanket impls generated by
/// `impl_prompt_handler_for!` (extractor tuples and the `*Adapter` marker
/// types), so several handler shapes can each implement this trait.
pub trait GetPromptHandler<S, A> {
/// Consumes the handler and produces a future resolving to the prompt result
/// or an error. The blanket impls in this file run the handler's extractors
/// against `context` first.
///
/// NOTE(review): `MaybeBoxFuture` is presumably `BoxFuture` or `LocalBoxFuture`
/// depending on the `local` feature (mirroring `DynGetPromptHandler`) — confirm
/// in `crate::service`.
fn handle(
self,
context: PromptContext<'_, S>,
) -> MaybeBoxFuture<'_, Result<GetPromptResult, crate::ErrorData>>;
}
/// Type-erased prompt handler: takes a [`PromptContext`] and returns a boxed
/// future that may borrow from it for the context's lifetime `'a`.
///
/// This default (non-`local`) variant requires the handler to be `Send + Sync`
/// and returns a `Send` `BoxFuture`.
#[cfg(not(feature = "local"))]
pub type DynGetPromptHandler<S> = dyn for<'a> Fn(PromptContext<'a, S>) -> BoxFuture<'a, Result<GetPromptResult, crate::ErrorData>>
+ Send
+ Sync;
/// Type-erased prompt handler for the `local` feature: no `Send`/`Sync` bounds
/// on the handler, and the returned future is a non-`Send` `LocalBoxFuture`.
#[cfg(feature = "local")]
pub type DynGetPromptHandler<S> = dyn for<'a> Fn(
PromptContext<'a, S>,
) -> futures::future::LocalBoxFuture<
'a,
Result<GetPromptResult, crate::ErrorData>,
>;
// Zero-sized marker types used as the `A` parameter of `GetPromptHandler` to
// select a handler shape. The `PhantomData<fn(..) -> ..>` forms record type
// parameters without owning them; fn-pointer types are always `Send + Sync`,
// so those markers are too, whatever `P`/`R` are.
/// Marker type. NOTE(review): not referenced by the `GetPromptHandler` impls
/// in this file section — confirm whether it is used elsewhere.
pub struct AsyncMethodAdapter<T>(PhantomData<T>);
/// Marker type. NOTE(review): not referenced by the `GetPromptHandler` impls
/// in this file section — confirm whether it is used elsewhere.
pub struct AsyncMethodWithArgsAdapter<T>(PhantomData<T>);
/// Selects the impl for async free functions `FnOnce(P..) -> Fut`, where `Fut`
/// resolves to `Result<R, ErrorData>`.
// The nested fn-pointer marker type trips clippy's `type_complexity` lint.
#[allow(clippy::type_complexity)]
pub struct AsyncPromptAdapter<P, Fut, R>(PhantomData<fn(P) -> fn(Fut) -> R>);
/// Selects the impl for synchronous free functions `FnOnce(P..) -> Result<R, ErrorData>`.
pub struct SyncPromptAdapter<P, R>(PhantomData<fn(P) -> R>);
/// Marker type. NOTE(review): not referenced by the `GetPromptHandler` impls
/// in this file section — confirm whether it is used elsewhere.
pub struct AsyncPromptMethodAdapter<P, R>(PhantomData<fn(P) -> R>);
/// Selects the impl for synchronous methods `FnOnce(&S, P..) -> R`.
pub struct SyncPromptMethodAdapter<P, R>(PhantomData<fn(P) -> R>);
/// Conversion from a prompt handler's return value into the protocol result.
///
/// Implemented for a finished [`GetPromptResult`], for a bare list of
/// [`PromptMessage`]s, and for `Result<T, ErrorData>` wrapping any convertible `T`.
pub trait IntoGetPromptResult {
    /// Performs the conversion, surfacing failures as `crate::ErrorData`.
    fn into_get_prompt_result(self) -> Result<GetPromptResult, crate::ErrorData>;
}

/// A complete result is passed through unchanged.
impl IntoGetPromptResult for GetPromptResult {
    fn into_get_prompt_result(self) -> Result<GetPromptResult, crate::ErrorData> {
        let finished = self;
        Ok(finished)
    }
}

/// A list of messages becomes a result without a description.
impl IntoGetPromptResult for Vec<PromptMessage> {
    fn into_get_prompt_result(self) -> Result<GetPromptResult, crate::ErrorData> {
        let wrapped = GetPromptResult {
            messages: self,
            description: None,
        };
        Ok(wrapped)
    }
}

/// Errors propagate as-is; successful values are converted in turn.
impl<T: IntoGetPromptResult> IntoGetPromptResult for Result<T, crate::ErrorData> {
    fn into_get_prompt_result(self) -> Result<GetPromptResult, crate::ErrorData> {
        match self {
            Ok(inner) => inner.into_get_prompt_result(),
            Err(error) => Err(error),
        }
    }
}
// Future that yields `Result<GetPromptResult, ErrorData>` from one of two
// states: a pending future whose output is converted with
// `IntoGetPromptResult` once it completes, or an already-computed result.
// NOTE(review): not referenced elsewhere in this file section — confirm usage.
pin_project_lite::pin_project! {
#[project = IntoGetPromptResultFutProj]
#[non_exhaustive]
pub enum IntoGetPromptResultFut<F, R> {
// Wraps a future whose output `R` is converted when it resolves.
Pending {
#[pin]
fut: F,
// Ties the output type `R` to the enum without storing a value.
_marker: PhantomData<R>,
},
// Holds a result that was computed up front.
Ready {
#[pin]
result: futures::future::Ready<Result<GetPromptResult, crate::ErrorData>>,
}
}
}
impl<F, R> Future for IntoGetPromptResultFut<F, R>
where
    F: Future<Output = R>,
    R: IntoGetPromptResult,
{
    type Output = Result<GetPromptResult, crate::ErrorData>;

    /// Polls whichever state the future is in.
    ///
    /// A ready result is forwarded directly. A pending inner future is polled
    /// and, once it completes, its output is converted with
    /// [`IntoGetPromptResult::into_get_prompt_result`].
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        match self.project() {
            IntoGetPromptResultFutProj::Ready { result } => result.poll(cx),
            IntoGetPromptResultFutProj::Pending { fut, .. } => match fut.poll(cx) {
                std::task::Poll::Ready(output) => {
                    std::task::Poll::Ready(output.into_get_prompt_result())
                }
                std::task::Poll::Pending => std::task::Poll::Pending,
            },
        }
    }
}
/// Extractor that yields the name of the requested prompt.
#[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")]
pub struct PromptName(pub String);

impl<S> FromContextPart<PromptContext<'_, S>> for PromptName {
    /// Copies the prompt name out of the context. This extraction never fails.
    fn from_context_part(context: &mut PromptContext<S>) -> Result<Self, crate::ErrorData> {
        let requested = String::clone(&context.name);
        Ok(PromptName(requested))
    }
}
impl<S, P> FromContextPart<PromptContext<'_, S>> for Parameters<P>
where
P: DeserializeOwned,
{
fn from_context_part(context: &mut PromptContext<S>) -> Result<Self, crate::ErrorData> {
let params = if let Some(args_map) = context.arguments.take() {
let args_value = serde_json::Value::Object(args_map);
serde_json::from_value::<P>(args_value).map_err(|e| {
crate::ErrorData::invalid_params(format!("Failed to parse parameters: {}", e), None)
})?
} else {
serde_json::from_value::<P>(serde_json::json!({})).map_err(|e| {
crate::ErrorData::invalid_params(
format!("Missing required parameters: {}", e),
None,
)
})?
};
Ok(Parameters(params))
}
}
// Generates `GetPromptHandler` impls for prompt handlers that take a variable
// number of extractor arguments.
//
// Called with a list of type-parameter names (see the invocation below), the
// first three rules walk that list and emit one `@impl` expansion per prefix:
// `[]`, `[T0]`, `[T0 T1]`, ..., up to the full list. With 16 names this covers
// handlers taking 0 through 16 extractors.
macro_rules! impl_prompt_handler_for {
// Entry point: start with an empty prefix and the whole list still to consume.
($($T: ident)*) => {
impl_prompt_handler_for!([] [$($T)*]);
};
// List exhausted: emit the impls for the full prefix.
([$($Tn: ident)*] []) => {
impl_prompt_handler_for!(@impl $($Tn)*);
};
// Emit the impls for the current prefix, then extend it by one name and recurse.
([$($Tn: ident)*] [$Tn_1: ident $($Rest: ident)*]) => {
impl_prompt_handler_for!(@impl $($Tn)*);
impl_prompt_handler_for!([$($Tn)* $Tn_1] [$($Rest)*]);
};
// Emits four `GetPromptHandler` impls for the extractor types `$Tn`. The
// second trait parameter differs per impl, which lets all four coexist for the
// same `F`. It selects the handler shape:
//   1. `($Tn,)`                  - `FnOnce(&S, $Tn..) -> MaybeBoxFuture<'_, R>`
//   2. `SyncPromptMethodAdapter` - `FnOnce(&S, $Tn..) -> R`
//   3. `AsyncPromptAdapter`      - `FnOnce($Tn..) -> Fut`, `Fut: Future<Output = Result<R, ErrorData>>`
//   4. `SyncPromptAdapter`       - `FnOnce($Tn..) -> Result<R, ErrorData>`
// In every impl each `$Tn` is built with `FromContextPart`, in parameter order.
// The first failure is returned as an already-resolved error future. Order
// matters: e.g. `Parameters<P>` takes the arguments map out of the context.
(@impl $($Tn: ident)*) => {
// 1. Async method: borrows the server and returns a boxed future.
impl<$($Tn,)* S, F, R> GetPromptHandler<S, ($($Tn,)*)> for F
where
$(
$Tn: for<'a> FromContextPart<PromptContext<'a, S>> + MaybeSendFuture,
)*
F: FnOnce(&S, $($Tn,)*) -> MaybeBoxFuture<'_, R> + MaybeSendFuture,
R: IntoGetPromptResult + MaybeSendFuture + 'static,
S: MaybeSend + 'static,
{
// non_snake_case: extracted values are bound to variables named after their
// type parameters. unused_*: with zero extractors `context` is never mutated.
#[allow(unused_variables, non_snake_case, unused_mut)]
fn handle(
self,
mut context: PromptContext<'_, S>,
) -> MaybeBoxFuture<'_, Result<GetPromptResult, crate::ErrorData>>
{
$(
let result = $Tn::from_context_part(&mut context);
let $Tn = match result {
Ok(value) => value,
Err(e) => return Box::pin(std::future::ready(Err(e))),
};
)*
// The server reference outlives this call (it has the context's lifetime),
// so the handler's future may borrow it.
let service = context.server;
let fut = self(service, $($Tn,)*);
Box::pin(async move {
let result = fut.await;
result.into_get_prompt_result()
})
}
}
// 2. Sync method: called immediately; the converted result is wrapped in a
//    ready future.
impl<$($Tn,)* S, F, R> GetPromptHandler<S, SyncPromptMethodAdapter<($($Tn,)*), R>> for F
where
$(
$Tn: for<'a> FromContextPart<PromptContext<'a, S>> + MaybeSendFuture,
)*
F: FnOnce(&S, $($Tn,)*) -> R + MaybeSendFuture,
R: IntoGetPromptResult + MaybeSendFuture,
S: MaybeSend,
{
// See impl 1 for why these lints are allowed.
#[allow(unused_variables, non_snake_case, unused_mut)]
fn handle(
self,
mut context: PromptContext<'_, S>,
) -> MaybeBoxFuture<'_, Result<GetPromptResult, crate::ErrorData>>
{
$(
let result = $Tn::from_context_part(&mut context);
let $Tn = match result {
Ok(value) => value,
Err(e) => return Box::pin(std::future::ready(Err(e))),
};
)*
let service = context.server;
let result = self(service, $($Tn,)*);
Box::pin(std::future::ready(result.into_get_prompt_result()))
}
}
// 3. Async free function (no server argument): the handler and extracted
//    values move into the async block, and a handler error short-circuits via `?`.
impl<$($Tn,)* S, F, Fut, R> GetPromptHandler<S, AsyncPromptAdapter<($($Tn,)*), Fut, R>> for F
where
$(
$Tn: for<'a> FromContextPart<PromptContext<'a, S>> + MaybeSendFuture + 'static,
)*
F: FnOnce($($Tn,)*) -> Fut + MaybeSendFuture + 'static,
Fut: Future<Output = Result<R, crate::ErrorData>> + MaybeSendFuture + 'static,
R: IntoGetPromptResult + MaybeSendFuture + 'static,
S: MaybeSend + 'static,
{
// With zero extractors `context` is neither read nor mutated here, hence
// unused_variables/unused_mut; non_snake_case as in impl 1.
#[allow(unused_variables, non_snake_case, unused_mut)]
fn handle(
self,
mut context: PromptContext<'_, S>,
) -> MaybeBoxFuture<'_, Result<GetPromptResult, crate::ErrorData>>
{
$(
let result = $Tn::from_context_part(&mut context);
let $Tn = match result {
Ok(value) => value,
Err(e) => return Box::pin(std::future::ready(Err(e))),
};
)*
Box::pin(async move {
let result = self($($Tn,)*).await?;
result.into_get_prompt_result()
})
}
}
// 4. Sync free function returning `Result<R, ErrorData>`: the handler's error
//    and the conversion's error are flattened into one `Result`.
impl<$($Tn,)* S, F, R> GetPromptHandler<S, SyncPromptAdapter<($($Tn,)*), R>> for F
where
$(
$Tn: for<'a> FromContextPart<PromptContext<'a, S>> + MaybeSendFuture + 'static,
)*
F: FnOnce($($Tn,)*) -> Result<R, crate::ErrorData> + MaybeSendFuture + 'static,
R: IntoGetPromptResult + MaybeSendFuture + 'static,
S: MaybeSend,
{
// See impl 3 for why these lints are allowed.
#[allow(unused_variables, non_snake_case, unused_mut)]
fn handle(
self,
mut context: PromptContext<'_, S>,
) -> MaybeBoxFuture<'_, Result<GetPromptResult, crate::ErrorData>>
{
$(
let result = $Tn::from_context_part(&mut context);
let $Tn = match result {
Ok(value) => value,
Err(e) => return Box::pin(std::future::ready(Err(e))),
};
)*
let result = self($($Tn,)*);
Box::pin(std::future::ready(result.and_then(|r| r.into_get_prompt_result())))
}
}
};
}
// Instantiate the impls for handlers taking 0 through 16 extractor arguments.
impl_prompt_handler_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15);
/// Derives the list of prompt arguments from the JSON schema of `T`.
///
/// The schema comes from `super::common::schema_for_type::<T>()`. Each entry of
/// its top-level `properties` object becomes one [`crate::model::PromptArgument`]:
///
/// * `name` is the property key,
/// * `description` is the property's `description` string, if present,
/// * `required` is `Some(true)` when the key appears in the schema's top-level
///   `required` array, and `Some(false)` otherwise,
/// * `title` is always `None`.
///
/// Arguments are produced in the iteration order of the `properties` map.
///
/// Returns `None` when the schema has no `properties` object, or when that
/// object is empty.
pub fn cached_arguments_from_schema<T: schemars::JsonSchema + std::any::Any>()
-> Option<Vec<crate::model::PromptArgument>> {
    let schema = super::common::schema_for_type::<T>();
    // Look the keys up on the schema map in place. Deep-cloning the whole schema
    // into a `serde_json::Value` just to read two keys is wasted work.
    let props = schema.get("properties")?.as_object()?;

    // Names listed under `required`; non-string entries are ignored.
    let required: std::collections::HashSet<&str> = schema
        .get("required")
        .and_then(|r| r.as_array())
        .map(|names| names.iter().filter_map(|v| v.as_str()).collect())
        .unwrap_or_default();

    let arguments: Vec<crate::model::PromptArgument> = props
        .iter()
        .map(|(name, prop_schema)| crate::model::PromptArgument {
            name: name.clone(),
            title: None,
            description: prop_schema
                .get("description")
                .and_then(|d| d.as_str())
                .map(str::to_owned),
            required: Some(required.contains(name.as_str())),
        })
        .collect();

    (!arguments.is_empty()).then_some(arguments)
}