use std::sync::Arc;
use rig::{
client::{CompletionClient, ProviderClient},
completion::Prompt,
providers::openai,
tool::{rmcp::McpClientHandler, server::ToolServer},
};
use rmcp::{
RoleServer, ServerHandler,
handler::server::{router::tool::ToolRouter, wrapper::Parameters},
model::*,
schemars,
service::RequestContext,
tool, tool_handler, tool_router,
};
use serde_json::json;
use tokio::sync::Mutex;
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn::auto::Builder,
service::TowerToHyperService,
};
use rmcp::transport::streamable_http_server::{
StreamableHttpService, session::local::LocalSessionManager,
};
/// Input arguments for the `sum` tool; deserialized from the MCP
/// `call_tool` request and exposed to clients via the JSON schema derive.
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct StructRequest {
    // First addend.
    pub a: i32,
    // Second addend.
    pub b: i32,
}
/// Example MCP server state: a shared counter plus the router that
/// dispatches `#[tool]`-annotated methods to their handlers.
#[derive(Clone)]
pub struct Counter {
    // Shared count behind a tokio (async) mutex; clones of `Counter`
    // share the same underlying value via the `Arc`.
    pub counter: Arc<Mutex<i32>>,
    // Populated from `Self::tool_router()`, which is generated by the
    // `#[tool_router]` macro on the impl block below.
    tool_router: ToolRouter<Counter>,
}
impl Default for Counter {
fn default() -> Self {
Self::new()
}
}
// `#[tool_router]` scans this impl for `#[tool]` methods and generates an
// associated `Counter::tool_router()` constructor that registers them.
#[tool_router]
impl Counter {
    /// Create a counter starting at 0 with the macro-generated tool router.
    #[allow(dead_code)] // constructed indirectly via `Default` / the service factory in `main`
    pub fn new() -> Self {
        Self {
            counter: Arc::new(Mutex::new(0)),
            // Provided by the `#[tool_router]` macro expansion above.
            tool_router: Self::tool_router(),
        }
    }

    /// Helper: wrap a (uri, name) pair into a `Resource` with no annotations.
    fn _create_resource_text(&self, uri: &str, name: &str) -> Resource {
        RawResource::new(uri, name.to_string()).no_annotation()
    }

    /// MCP tool: returns `a + b` rendered as a single text content block.
    #[tool(description = "Calculate the sum of two numbers")]
    fn sum(
        &self,
        Parameters(StructRequest { a, b }): Parameters<StructRequest>,
    ) -> Result<CallToolResult, ErrorData> {
        Ok(CallToolResult::success(vec![Content::text(
            (a + b).to_string(),
        )]))
    }
}
// `#[tool_handler]` supplies the `call_tool`/`list_tools` handler methods,
// dispatching through `self.tool_router`.
#[tool_handler]
impl ServerHandler for Counter {
    /// Advertise server capabilities (resources + tools) and usage text.
    fn get_info(&self) -> ServerInfo {
        ServerInfo::new(
            ServerCapabilities::builder()
                .enable_resources()
                .enable_tools()
                .build(),
        )
        .with_protocol_version(ProtocolVersion::LATEST)
        .with_server_info(Implementation::from_build_env())
        // Fix: the previous instructions described 'increment'/'decrement'/
        // 'get_value' tools that this server never registers; the only tool
        // exposed by the `#[tool_router]` impl is `sum`. Describe reality so
        // clients (and LLMs reading these instructions) are not misled.
        .with_instructions("This server provides a 'sum' tool that calculates the sum of two numbers. Call 'sum' with integer arguments 'a' and 'b' to receive their total as text. Two example text resources are also available.")
    }

    /// List the two hard-coded example resources (a fake cwd and a memo).
    async fn list_resources(
        &self,
        _request: Option<PaginatedRequestParams>,
        _: RequestContext<RoleServer>,
    ) -> Result<ListResourcesResult, ErrorData> {
        Ok(ListResourcesResult {
            resources: vec![
                self._create_resource_text("str:////Users/to/some/path/", "cwd"),
                self._create_resource_text("memo://insights", "memo-name"),
            ],
            next_cursor: None,
            meta: None,
        })
    }

    /// Serve the contents of one of the two known resource URIs; any other
    /// URI yields a `resource_not_found` error carrying the requested URI.
    async fn read_resource(
        &self,
        ReadResourceRequestParams { uri, .. }: ReadResourceRequestParams,
        _: RequestContext<RoleServer>,
    ) -> Result<ReadResourceResult, ErrorData> {
        match uri.as_str() {
            "str:////Users/to/some/path/" => {
                let cwd = "/Users/to/some/path/";
                Ok(ReadResourceResult::new(vec![ResourceContents::text(
                    cwd, uri,
                )]))
            }
            "memo://insights" => {
                let memo = "Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ...";
                Ok(ReadResourceResult::new(vec![ResourceContents::text(
                    memo, uri,
                )]))
            }
            _ => Err(ErrorData::resource_not_found(
                "resource_not_found",
                Some(json!({
                    "uri": uri
                })),
            )),
        }
    }

    /// This example defines no resource templates; return an empty list.
    async fn list_resource_templates(
        &self,
        _request: Option<PaginatedRequestParams>,
        _: RequestContext<RoleServer>,
    ) -> Result<ListResourceTemplatesResult, ErrorData> {
        Ok(ListResourceTemplatesResult {
            next_cursor: None,
            resource_templates: Vec::new(),
            meta: None,
        })
    }

    /// Log the HTTP request parts (headers + URI) when initialization comes
    /// in over the streamable-HTTP transport, then reply with `get_info()`.
    async fn initialize(
        &self,
        _request: InitializeRequestParams,
        context: RequestContext<RoleServer>,
    ) -> Result<InitializeResult, ErrorData> {
        // Only present when the transport is HTTP-based; absent otherwise.
        if let Some(http_request_part) = context.extensions.get::<axum::http::request::Parts>() {
            let initialize_headers = &http_request_part.headers;
            let initialize_uri = &http_request_part.uri;
            tracing::info!(?initialize_headers, %initialize_uri, "initialize from http server");
        }
        Ok(self.get_info())
    }
}
/// Example entry point: starts a streamable-HTTP MCP server in-process,
/// then connects a rig OpenAI agent to it as a client and runs one prompt.
///
/// # Errors
/// Returns an error if binding the listener, connecting the MCP client, or
/// the agent prompt fails (e.g. no network or missing `OPENAI_API_KEY`).
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt::init();

    // --- Server side: one fresh `Counter` per HTTP session. ---
    let service = TowerToHyperService::new(StreamableHttpService::new(
        || Ok(Counter::new()),
        LocalSessionManager::default().into(),
        Default::default(),
    ));
    let listener = tokio::net::TcpListener::bind("localhost:8080").await?;

    // Detached accept loop: serves connections until Ctrl+C is received.
    tokio::spawn({
        let service = service.clone();
        async move {
            loop {
                tokio::select! {
                    _ = tokio::signal::ctrl_c() => {
                        println!("Received Ctrl+C, shutting down");
                        break;
                    }
                    accept = listener.accept() => {
                        match accept {
                            Ok((stream, _addr)) => {
                                let io = TokioIo::new(stream);
                                let service = service.clone();
                                // One task per connection; errors are logged,
                                // not fatal to the accept loop.
                                tokio::spawn(async move {
                                    if let Err(e) = Builder::new(TokioExecutor::default())
                                        .serve_connection(io, service)
                                        .await
                                    {
                                        eprintln!("Connection error: {e:?}");
                                    }
                                });
                            }
                            Err(e) => {
                                eprintln!("Accept error: {e:?}");
                            }
                        }
                    }
                }
            }
        }
    });

    // --- Client side: connect rig's MCP client to the server above. ---
    let client_info = ClientInfo::new(
        ClientCapabilities::default(),
        Implementation::new("rig-core", env!("CARGO_PKG_VERSION")),
    );
    let tool_server_handle = ToolServer::new().run();
    let handler = McpClientHandler::new(client_info, tool_server_handle.clone());
    let transport =
        rmcp::transport::StreamableHttpClientTransport::from_uri("http://localhost:8080");
    let mcp_service = handler.connect(transport).await.inspect_err(|e| {
        tracing::error!("MCP client error: {:?}", e);
    })?;
    let server_info = mcp_service.peer_info();
    tracing::info!("Connected to server: {server_info:#?}");

    let openai_client = openai::Client::from_env();
    let agent = openai_client
        .agent(openai::GPT_4O)
        // Fix: the previous preamble described increment/decrement counter
        // tools, but the server only exposes a 'sum' tool — describe what
        // the model can actually call.
        .preamble("You are a helpful assistant who has access to a number of tools from an MCP server, including a 'sum' tool that calculates the sum of two numbers.")
        .tool_server_handle(tool_server_handle)
        .build();

    // Fix: propagate prompt failures via `?` instead of panicking with
    // `unwrap()` — network/API errors are expected, recoverable conditions
    // and `main` already returns `anyhow::Result`.
    let res = agent.prompt("What is 2+5?").max_turns(2).await?;
    println!("GPT-4o: {res}");
    Ok(())
}