#![allow(clippy::manual_async_fn)]
use elicitation_macros::elicit_trait_tools_router;
use rmcp::handler::server::wrapper::{Json, Parameters};
use rmcp::model::ServerInfo;
use rmcp::{ServerHandler, tool_router};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
// Input of the `echo` tool: the text to send back unchanged.
// (Plain `//` comments on purpose: schemars copies `///` docs into the schema.)
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize, JsonSchema)]
pub struct EchoParams {
    pub message: String,
}
// Output of the `echo` tool: carries the received message back to the caller.
// (Plain `//` comments on purpose: schemars copies `///` docs into the schema.)
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize, JsonSchema)]
pub struct EchoResult {
    pub echoed: String,
}
/// Single-method trait whose `echo` method is exposed as a tool by [`EchoServer`].
///
/// Implementors must be shareable across threads (`Sync + Send`), and the
/// future returned by `echo` must itself be `Send`.
pub trait EchoTrait: Sync + Send {
    /// Handles one `echo` tool call.
    fn echo(
        &self,
        params: Parameters<EchoParams>,
    ) -> impl std::future::Future<Output = Result<Json<EchoResult>, rmcp::ErrorData>> + Send;
}
/// Stateless [`EchoTrait`] implementation that returns the incoming message unchanged.
///
/// Being a unit struct, it is trivially `Copy` and `Default`; `Debug` is derived
/// so the type can appear in assertion failures and logs like any public type.
#[derive(Debug, Clone, Copy, Default)]
pub struct EchoHandler;
impl EchoTrait for EchoHandler {
fn echo(
&self,
params: Parameters<EchoParams>,
) -> impl std::future::Future<Output = Result<Json<EchoResult>, rmcp::ErrorData>> + Send {
async move {
Ok(Json(EchoResult {
echoed: params.0.message,
}))
}
}
}
/// Server under test: exposes an [`EchoTrait`] implementation through a
/// macro-generated tool router.
pub struct EchoServer<H>
where
    H: EchoTrait + 'static,
{
    /// The field named in `elicit_trait_tools_router(EchoTrait, handler, ...)`.
    handler: H,
}
// Empty impl that the two attribute macros fill in. `tool_router` produces the
// associated `echo_tools()` constructor used by the tests below.
// `elicit_trait_tools_router` receives the trait (`EchoTrait`), the field holding
// the implementation (`handler`), and the methods to expose (`[echo]`).
// NOTE(review): presumably it emits one tool method per listed trait method,
// delegating to `self.handler`. Confirm against the macro's expansion.
// Plain `//` comments are used so no extra `#[doc]` attribute reaches the macros.
#[elicit_trait_tools_router(EchoTrait, handler, [echo])]
#[tool_router(router = echo_tools)]
impl<H: EchoTrait + 'static> EchoServer<H> {}
impl<H> ServerHandler for EchoServer<H>
where
    H: EchoTrait + 'static,
{
    /// Advertises default server metadata; these tests only need the router.
    fn get_info(&self) -> ServerInfo {
        Default::default()
    }
}
/// The generated router for a one-method trait must expose exactly that method.
#[test]
fn test_simple_trait_tool_generation() {
    let _server = EchoServer {
        handler: EchoHandler,
    };
    let router = EchoServer::<EchoHandler>::echo_tools();
    let tool_names: Vec<String> = router
        .list_all()
        .iter()
        .map(|tool| tool.name.to_string())
        .collect();
    assert_eq!(tool_names, vec!["echo".to_string()]);
}
/// Calls `echo` through the server's handler and checks the returned payload.
#[test]
fn test_simple_trait_has_echo_method() {
    use std::future::Future;
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake, Waker};

    /// Polls `fut` exactly once with a no-op waker. This suits futures that
    /// never suspend, like the `async move { Ok(..) }` returned by `EchoHandler`.
    fn poll_ready<F: Future>(fut: F) -> F::Output {
        struct NoopWake;
        impl Wake for NoopWake {
            fn wake(self: Arc<Self>) {}
        }
        let waker = Waker::from(Arc::new(NoopWake));
        let mut cx = Context::from_waker(&waker);
        match std::pin::pin!(fut).poll(&mut cx) {
            Poll::Ready(output) => output,
            Poll::Pending => panic!("future unexpectedly returned Pending"),
        }
    }

    let server = EchoServer {
        handler: EchoHandler,
    };
    let params = Parameters(EchoParams {
        message: "hello".to_string(),
    });
    let Json(result) = poll_ready(server.handler.echo(params)).expect("echo should succeed");
    assert_eq!(result.echoed, "hello");
}
// Input of the `add` tool: the two operands.
// (Plain `//` comments on purpose: schemars copies `///` docs into the schema.)
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize, JsonSchema)]
pub struct AddParams {
    pub a: i32,
    pub b: i32,
}
// Output of the `add` tool.
// (Plain `//` comments on purpose: schemars copies `///` docs into the schema.)
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize, JsonSchema)]
pub struct AddResult {
    pub result: i32,
}
// Input of the `multiply` tool: the two factors.
// (Plain `//` comments on purpose: schemars copies `///` docs into the schema.)
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize, JsonSchema)]
pub struct MultiplyParams {
    pub x: i32,
    pub y: i32,
}
// Output of the `multiply` tool.
// (Plain `//` comments on purpose: schemars copies `///` docs into the schema.)
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize, JsonSchema)]
pub struct MultiplyResult {
    pub result: i32,
}
/// Two-method trait used to check that the router exposes every listed method.
///
/// Implementors must be `Sync + Send`, and each returned future must be `Send`.
pub trait MathOps: Sync + Send {
    /// Handles one `add` tool call.
    fn add(
        &self,
        params: Parameters<AddParams>,
    ) -> impl std::future::Future<Output = Result<Json<AddResult>, rmcp::ErrorData>> + Send;

    /// Handles one `multiply` tool call.
    fn multiply(
        &self,
        params: Parameters<MultiplyParams>,
    ) -> impl std::future::Future<Output = Result<Json<MultiplyResult>, rmcp::ErrorData>> + Send;
}
/// Stateless [`MathOps`] implementation.
///
/// Being a unit struct, it is trivially `Copy` and `Default`; `Debug` is derived
/// so the type can appear in assertion failures and logs like any public type.
#[derive(Debug, Clone, Copy, Default)]
pub struct Calculator;
impl MathOps for Calculator {
fn add(
&self,
params: Parameters<AddParams>,
) -> impl std::future::Future<Output = Result<Json<AddResult>, rmcp::ErrorData>> + Send {
async move {
Ok(Json(AddResult {
result: params.0.a + params.0.b,
}))
}
}
fn multiply(
&self,
params: Parameters<MultiplyParams>,
) -> impl std::future::Future<Output = Result<Json<MultiplyResult>, rmcp::ErrorData>> + Send
{
async move {
Ok(Json(MultiplyResult {
result: params.0.x * params.0.y,
}))
}
}
}
/// Server under test: exposes a [`MathOps`] implementation through a
/// macro-generated tool router.
pub struct MathServer<C>
where
    C: MathOps + 'static,
{
    /// The field named in `elicit_trait_tools_router(MathOps, calculator, ...)`.
    calculator: C,
}
// Empty impl that the two attribute macros fill in. `tool_router` produces the
// associated `math_tools()` constructor. `elicit_trait_tools_router` is told to
// expose `add` and `multiply` from `MathOps` via the `calculator` field.
// `test_tool_router_discovers_methods` checks that both show up as tools.
// NOTE(review): the delegation to `self.calculator` is inferred from the macro
// arguments. Confirm against the macro's expansion.
// Plain `//` comments are used so no extra `#[doc]` attribute reaches the macros.
#[elicit_trait_tools_router(MathOps, calculator, [add, multiply])]
#[tool_router(router = math_tools)]
impl<C: MathOps + 'static> MathServer<C> {}
impl<C> ServerHandler for MathServer<C>
where
    C: MathOps + 'static,
{
    /// Advertises default server metadata; these tests only need the router.
    fn get_info(&self) -> ServerInfo {
        Default::default()
    }
}
/// Compile-level check: a multi-method trait yields a constructible server and router.
#[test]
fn test_multiple_methods_compile() {
    let _server = MathServer {
        calculator: Calculator,
    };
    let _router = MathServer::<Calculator>::math_tools();
}
/// Calls both `MathOps` methods through the server's field and checks the results.
#[test]
fn test_multiple_methods_exist() {
    use std::future::Future;
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake, Waker};

    /// Polls `fut` exactly once with a no-op waker. This suits futures that
    /// never suspend, like the `async move { .. }` blocks returned by `Calculator`.
    fn poll_ready<F: Future>(fut: F) -> F::Output {
        struct NoopWake;
        impl Wake for NoopWake {
            fn wake(self: Arc<Self>) {}
        }
        let waker = Waker::from(Arc::new(NoopWake));
        let mut cx = Context::from_waker(&waker);
        match std::pin::pin!(fut).poll(&mut cx) {
            Poll::Ready(output) => output,
            Poll::Pending => panic!("future unexpectedly returned Pending"),
        }
    }

    let server = MathServer {
        calculator: Calculator,
    };

    let Json(sum) = poll_ready(server.calculator.add(Parameters(AddParams { a: 2, b: 3 })))
        .expect("add should succeed");
    assert_eq!(sum.result, 5);

    let Json(product) =
        poll_ready(server.calculator.multiply(Parameters(MultiplyParams { x: 4, y: 5 })))
            .expect("multiply should succeed");
    assert_eq!(product.result, 20);
}
/// The generated router lists exactly the two requested `MathOps` methods.
#[test]
fn test_tool_router_discovers_methods() {
    let _server = MathServer {
        calculator: Calculator,
    };
    let router = MathServer::<Calculator>::math_tools();
    let tools = router.list_all();
    assert_eq!(tools.len(), 2);

    let has_tool = |wanted: &str| tools.iter().any(|tool| tool.name.to_string() == wanted);
    assert!(has_tool("add"));
    assert!(has_tool("multiply"));
}
// Input of the `greet` tool: who to greet.
// (Plain `//` comments on purpose: schemars copies `///` docs into the schema.)
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize, JsonSchema)]
pub struct GreetParams {
    pub name: String,
}
// Output of the `greet` tool: the formatted greeting.
// (Plain `//` comments on purpose: schemars copies `///` docs into the schema.)
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize, JsonSchema)]
pub struct GreetResult {
    pub greeting: String,
}
/// Trait declared with `#[async_trait]` rather than `impl Future` returns,
/// to check that the router macro handles that style too.
#[async_trait::async_trait]
pub trait Greeter: Sync + Send {
    /// Handles one `greet` tool call.
    async fn greet(&self, params: Parameters<GreetParams>)
        -> Result<Json<GreetResult>, rmcp::ErrorData>;
}
/// Stateless [`Greeter`] implementation.
///
/// Being a unit struct, it is trivially `Copy` and `Default`; `Debug` is derived
/// so the type can appear in assertion failures and logs like any public type.
#[derive(Debug, Clone, Copy, Default)]
pub struct SimpleGreeter;
#[async_trait::async_trait]
impl Greeter for SimpleGreeter {
async fn greet(
&self,
params: Parameters<GreetParams>,
) -> Result<Json<GreetResult>, rmcp::ErrorData> {
Ok(Json(GreetResult {
greeting: format!("Hello, {}!", params.0.name),
}))
}
}
/// Server under test: exposes an `#[async_trait]`-based [`Greeter`] through a
/// macro-generated tool router.
pub struct GreeterServer<G>
where
    G: Greeter + 'static,
{
    /// The field named in `elicit_trait_tools_router(Greeter, greeter, ...)`.
    greeter: G,
}
// Empty impl that the two attribute macros fill in. `tool_router` produces the
// associated `greeter_tools()` constructor. `elicit_trait_tools_router` is told to
// expose `greet` from the `#[async_trait]` trait `Greeter` via the `greeter` field.
// NOTE(review): the delegation to `self.greeter` is inferred from the macro
// arguments. Confirm against the macro's expansion.
// Plain `//` comments are used so no extra `#[doc]` attribute reaches the macros.
#[elicit_trait_tools_router(Greeter, greeter, [greet])]
#[tool_router(router = greeter_tools)]
impl<G: Greeter + 'static> GreeterServer<G> {}
impl<G> ServerHandler for GreeterServer<G>
where
    G: Greeter + 'static,
{
    /// Advertises default server metadata; these tests only need the router.
    fn get_info(&self) -> ServerInfo {
        Default::default()
    }
}
/// Compile-level check: an `#[async_trait]` trait yields a constructible server and router.
#[test]
fn test_async_trait_tool_generation() {
    let _server = GreeterServer {
        greeter: SimpleGreeter,
    };
    let _router = GreeterServer::<SimpleGreeter>::greeter_tools();
}
/// The router generated for the `#[async_trait]` trait lists exactly `greet`.
#[test]
fn test_async_trait_tool_router_integration() {
    let _server = GreeterServer {
        greeter: SimpleGreeter,
    };
    let router = GreeterServer::<SimpleGreeter>::greeter_tools();
    let tools = router.list_all();
    assert_eq!(tools.len(), 1);

    let found_greet = tools.iter().any(|tool| tool.name.to_string() == "greet");
    assert!(found_greet);
}