use mcprotocol_rs::{
transport::{ClientTransportFactory, ServerTransportFactory, TransportConfig, TransportType},
Message, Method, Notification, Request, RequestId, Result,
};
use std::{collections::HashSet, time::Duration};
use tokio::{self, time::sleep, time::timeout};
/// Nominal pause between consecutive pings; the client caps the actual sleep at 30 seconds.
const PING_INTERVAL: Duration = Duration::from_secs(60);
/// How long the client waits for the reply to a ping or shutdown request.
const PING_TIMEOUT: Duration = Duration::from_secs(2);
/// Upper bound on the client transport's initialization.
const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
/// Maximum time the server stays up when no exit signal arrives.
const SERVER_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Port number reported in the server's startup log line.
const SERVER_PORT: u16 = 3000;
/// `host:port` handed to the server transport; the client prefixes it with `http://`.
const SERVER_URL: &str = "127.0.0.1:3000";
/// Auth token configured on both the server and the client transport.
const AUTH_TOKEN: &str = "test-auth-token";
#[tokio::main]
async fn main() -> Result<()> {
let server_handle = tokio::spawn(run_server());
sleep(Duration::from_millis(100)).await;
let client_handle = tokio::spawn(run_client());
match tokio::try_join!(server_handle, client_handle) {
Ok((server_result, client_result)) => {
server_result?;
client_result?;
Ok(())
}
Err(e) => {
eprintln!("Error in task execution: {}", e);
Err(mcprotocol_rs::Error::Transport(e.to_string()))
}
}
}
/// Starts the HTTP server transport, keeps it open until an exit signal
/// arrives or `SERVER_TIMEOUT` elapses, and then closes it.
///
/// # Errors
/// Propagates any error returned while creating, initializing, or closing
/// the transport.
async fn run_server() -> Result<()> {
    let config = TransportConfig {
        transport_type: TransportType::Http {
            base_url: SERVER_URL.to_string(),
            auth_token: Some(AUTH_TOKEN.to_string()),
        },
        parameters: None,
    };
    let factory = ServerTransportFactory;
    let mut server = factory.create(config)?;
    server.initialize().await?;
    eprintln!(
        "Server started and waiting for ping requests on port {}",
        SERVER_PORT
    );

    // The sender must stay bound for the whole function: dropping it would
    // resolve the receiver immediately and stop the server at once. Nothing
    // here ever sends on it, so as written only the timeout branch below can
    // complete; the exit branch is a hook for a future shutdown trigger.
    let (_exit_tx, exit_rx) = tokio::sync::oneshot::channel::<()>();

    tokio::select! {
        // Fires on either a sent value or a dropped sender; the result is ignored.
        _ = exit_rx => {
            eprintln!("Server received exit signal");
        }
        _ = tokio::time::sleep(SERVER_TIMEOUT) => {
            eprintln!("Server timeout after {} seconds", SERVER_TIMEOUT.as_secs());
        }
    }

    server.close().await?;
    eprintln!("Server stopped");
    Ok(())
}
/// Runs the example client: connects, sends up to three pings, then performs
/// the shutdown/exit handshake and closes the transport.
///
/// Each ping uses the request ID `ping-N` and waits up to `PING_TIMEOUT` for
/// the response carrying the same ID. Messages that are not that response
/// (a response with another ID, or a non-response message) are logged and
/// skipped without resending, so a ping ID is only ever issued once.
/// Pinging stops early on a receive error, an error response, a timeout, or
/// when the session is within 30 seconds of `SERVER_TIMEOUT`. The shutdown
/// request is only sent when all pings succeeded.
///
/// # Errors
/// Returns an error if the transport cannot be created or initialized
/// (including an initialization timeout), if sending a message fails, or if
/// closing the transport fails. Once initialization has succeeded the
/// transport is closed on every path, including when the session itself
/// failed; in that case the session error is the one reported.
async fn run_client() -> Result<()> {
    let config = TransportConfig {
        transport_type: TransportType::Http {
            base_url: format!("http://{}", SERVER_URL),
            auth_token: Some(AUTH_TOKEN.to_string()),
        },
        parameters: None,
    };
    let factory = ClientTransportFactory;
    let mut client = factory.create(config)?;
    match timeout(CONNECTION_TIMEOUT, client.initialize()).await {
        Ok(result) => result?,
        Err(_) => {
            return Err(mcprotocol_rs::Error::Transport(
                "Client initialization timeout".into(),
            ))
        }
    }
    eprintln!("Client started");

    // The session runs inside its own async block so that an early `?` exits
    // only the block; the transport is then closed below on every path.
    let session = async {
        let mut session_ids = HashSet::new();
        let mut ping_count = 0;
        let total_pings = 3;
        // Loop-invariant values; `saturating_sub` cannot panic on underflow
        // if the constants are ever changed.
        let session_budget = SERVER_TIMEOUT.saturating_sub(Duration::from_secs(30));
        let pause_between_pings = PING_INTERVAL.min(Duration::from_secs(30));
        let start_time = std::time::Instant::now();

        'pings: while ping_count < total_pings {
            if start_time.elapsed() > session_budget {
                eprintln!("Approaching server timeout, ending session");
                break;
            }
            let request_id = RequestId::String(format!("ping-{}", ping_count + 1));
            let ping_request = Request::new(Method::Ping, None, request_id.clone());
            // NOTE(review): presumably this records the ID in `session_ids`
            // and reports `false` for a repeat; confirm against the crate.
            if !ping_request.validate_id_uniqueness(&mut session_ids) {
                eprintln!("Request ID has already been used in this session");
                break;
            }
            eprintln!("Sending ping request #{}", ping_count + 1);
            // The request is not needed after sending, so it is moved, not cloned.
            client.send(Message::Request(ping_request)).await?;

            // Wait for the matching response. Unrelated messages are skipped
            // but all reads share one PING_TIMEOUT budget, so this inner loop
            // is bounded.
            let sent_at = std::time::Instant::now();
            'awaiting: loop {
                let remaining = PING_TIMEOUT.saturating_sub(sent_at.elapsed());
                match timeout(remaining, client.receive()).await {
                    Ok(Ok(Message::Response(response))) => {
                        if !request_id_matches(&request_id, &response.id) {
                            eprintln!(
                                "Received response with mismatched ID: expected {}, got {}",
                                request_id_to_string(&request_id),
                                request_id_to_string(&response.id)
                            );
                            continue 'awaiting;
                        }
                        if response.error.is_some() {
                            eprintln!("Received error response: {:?}", response.error);
                            break 'pings;
                        }
                        eprintln!("Received pong response #{}", ping_count + 1);
                        break 'awaiting;
                    }
                    Ok(Ok(message)) => {
                        eprintln!("Unexpected message type: {:?}", message);
                        continue 'awaiting;
                    }
                    Ok(Err(e)) => {
                        eprintln!("Error receiving response: {}", e);
                        break 'pings;
                    }
                    Err(_) => {
                        eprintln!("Ping timeout for request #{}", ping_count + 1);
                        break 'pings;
                    }
                }
            }

            ping_count += 1;
            if ping_count < total_pings {
                sleep(pause_between_pings).await;
            }
        }

        // Only a fully successful ping run proceeds to the shutdown handshake.
        if ping_count == total_pings {
            let shutdown_request = Request::new(
                Method::Shutdown,
                None,
                RequestId::String("shutdown".to_string()),
            );
            if shutdown_request.validate_id_uniqueness(&mut session_ids) {
                client.send(Message::Request(shutdown_request)).await?;
                match timeout(PING_TIMEOUT, client.receive()).await {
                    Ok(Ok(Message::Response(response))) => {
                        if response.error.is_some() {
                            eprintln!("Shutdown failed: {:?}", response.error);
                        } else {
                            // Shutdown was accepted: follow up with the exit notification.
                            let exit_notification = Notification::new(Method::Exit, None);
                            client
                                .send(Message::Notification(exit_notification))
                                .await?;
                        }
                    }
                    Ok(Ok(_)) => eprintln!("Unexpected response type"),
                    Ok(Err(e)) => eprintln!("Error receiving shutdown response: {}", e),
                    Err(_) => eprintln!("Shutdown response timeout"),
                }
            } else {
                eprintln!("Shutdown request ID has already been used in this session");
            }
        }
        Ok::<(), mcprotocol_rs::Error>(())
    }
    .await;

    // Always release the transport; the session error, if any, wins.
    let closed = client.close().await;
    session?;
    closed?;
    eprintln!("Client stopped");
    Ok(())
}
/// Reports whether a response ID refers to the given request ID.
///
/// Two IDs match only when they are the same variant and hold equal values;
/// a string ID never matches a numeric ID, even if they print identically.
fn request_id_matches(request_id: &RequestId, response_id: &RequestId) -> bool {
    match request_id {
        RequestId::String(expected) => {
            matches!(response_id, RequestId::String(actual) if actual == expected)
        }
        RequestId::Number(expected) => {
            matches!(response_id, RequestId::Number(actual) if actual == expected)
        }
    }
}
/// Renders a request ID as text for log messages.
///
/// Numeric IDs are formatted with `to_string`; string IDs are copied as-is.
fn request_id_to_string(id: &RequestId) -> String {
    match id {
        RequestId::Number(number) => number.to_string(),
        RequestId::String(text) => text.to_owned(),
    }
}