use crate::error::{McpSdkError, ProtocolErrorKind, SdkResult};
use crate::schema::{ClientMessages, ProtocolVersion, SdkError};
use std::cmp::Ordering;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use time::format_description::well_known::Iso8601;
use time::OffsetDateTime;
#[cfg(feature = "auth")]
use url::Url;
/// Guard that aborts a tokio task when the guard is dropped.
///
/// Holding this value ties the lifetime of the referenced task to the owner of the guard:
/// once the guard goes out of scope, `abort()` is called on the stored handle.
pub struct AbortTaskOnDrop {
    /// Handle of the task that will be aborted when this guard is dropped.
    pub handle: tokio::task::AbortHandle,
}

impl Drop for AbortTaskOnDrop {
    fn drop(&mut self) {
        // Destructure to make explicit that the only resource this guard manages is the handle.
        let Self { handle } = self;
        handle.abort();
    }
}
/// Converts a Unix timestamp, expressed in whole seconds since `UNIX_EPOCH`, into a
/// [`SystemTime`].
///
/// # Panics
///
/// Panics if `UNIX_EPOCH + timestamp` seconds is not representable as a [`SystemTime`] on the
/// current platform, because `SystemTime`'s `Add<Duration>` impl panics on overflow.
pub fn unix_timestamp_to_systemtime(timestamp: u64) -> SystemTime {
    let elapsed_since_epoch = Duration::from_secs(timestamp);
    UNIX_EPOCH + elapsed_since_epoch
}
/// Requires the client and server protocol versions to be exactly identical.
///
/// # Errors
///
/// Returns [`McpSdkError::Protocol`] with [`ProtocolErrorKind::IncompatibleVersion`]
/// (`requested` = client version, `current` = server version) whenever the two strings differ,
/// whether the client's version is older or newer.
// NOTE(review): `allow(unused)` presumably because only some feature/transport combinations
// call this helper; confirm before removing.
#[allow(unused)]
pub fn ensure_server_protocole_compatibility(
    client_protocol_version: &str,
    server_protocol_version: &str,
) -> SdkResult<()> {
    if client_protocol_version == server_protocol_version {
        return Ok(());
    }
    Err(McpSdkError::Protocol {
        kind: ProtocolErrorKind::IncompatibleVersion {
            requested: client_protocol_version.to_owned(),
            current: server_protocol_version.to_owned(),
        },
    })
}
/// Decides which protocol version to use when the client and server versions may differ.
///
/// Versions are compared as plain strings (lexicographic order).
/// NOTE(review): lexicographic order only matches chronological order for same-format
/// versions such as `YYYY-MM-DD`; confirm all inputs use that format.
///
/// # Returns
///
/// - `Ok(None)` when both versions are identical.
/// - `Ok(Some(client_version))` when the client's version sorts before the server's; the
///   client's version string is returned (presumably so the caller can downgrade to it).
///
/// # Errors
///
/// Returns [`ProtocolErrorKind::IncompatibleVersion`] when the client's version sorts after the
/// server's version.
// NOTE(review): `allow(unused)` presumably because only some feature/transport combinations
// call this helper; confirm before removing.
#[allow(unused)]
pub fn enforce_compatible_protocol_version(
    client_protocol_version: &str,
    server_protocol_version: &str,
) -> SdkResult<Option<String>> {
    let ordering = client_protocol_version.cmp(server_protocol_version);
    match ordering {
        Ordering::Less => Ok(Some(client_protocol_version.to_owned())),
        Ordering::Equal => Ok(None),
        Ordering::Greater => {
            let kind = ProtocolErrorKind::IncompatibleVersion {
                requested: client_protocol_version.to_owned(),
                current: server_protocol_version.to_owned(),
            };
            Err(McpSdkError::Protocol { kind })
        }
    }
}
/// Checks that `mcp_protocol_version` can be parsed as a [`ProtocolVersion`].
///
/// The parsed value is discarded; only success or failure matters.
///
/// # Errors
///
/// Returns [`McpSdkError::Protocol`] with [`ProtocolErrorKind::ParseError`] wrapping the
/// conversion error when `ProtocolVersion::try_from` rejects the string.
pub fn validate_mcp_protocol_version(mcp_protocol_version: &str) -> SdkResult<()> {
    match ProtocolVersion::try_from(mcp_protocol_version) {
        Ok(_) => Ok(()),
        Err(parse_err) => Err(McpSdkError::Protocol {
            kind: ProtocolErrorKind::ParseError(parse_err),
        }),
    }
}
/// Strips the query string (`?…`) and fragment (`#…`) from an endpoint, keeping only the path.
///
/// An endpoint whose path part is empty (e.g. `""`, `"?a=b"`, `"#top"`) yields `"/"`.
// NOTE(review): `allow(unused)` presumably because only some feature/transport combinations
// call this helper; confirm before removing.
#[allow(unused)]
pub(crate) fn remove_query_and_hash(endpoint: &str) -> String {
    // The path ends at the first '?' or '#', whichever comes first. A '?' that only appears
    // after '#' belongs to the fragment, so cutting at the earliest of the two is equivalent
    // to stripping the fragment first and then the query.
    let path = match endpoint.find(|c: char| c == '?' || c == '#') {
        Some(cut) => &endpoint[..cut],
        None => endpoint,
    };
    if path.is_empty() {
        return String::from("/");
    }
    path.to_owned()
}
pub fn valid_initialize_method(json_str: &str) -> SdkResult<()> {
let Ok(request) = serde_json::from_str::<ClientMessages>(json_str) else {
return Err(SdkError::bad_request()
.with_message("Bad Request: Session not found")
.into());
};
match request {
ClientMessages::Single(client_message) => {
if !client_message.is_initialize_request() {
return Err(SdkError::bad_request()
.with_message("Bad Request: Session not found")
.into());
}
}
ClientMessages::Batch(client_messages) => {
let count = client_messages
.iter()
.filter(|item| item.is_initialize_request())
.count();
if count > 1 {
return Err(SdkError::invalid_request()
.with_message("Bad Request: Only one initialization request is allowed")
.into());
}
}
};
Ok(())
}
/// Returns the current UTC time, optionally shifted by `ms_offset` milliseconds.
///
/// - `None` returns `OffsetDateTime::now_utc()` unchanged.
/// - `Some(ms)` returns `now + ms` when that is representable.
/// - If `now + ms` overflows, the shift falls back to 180 000 ms (3 minutes) in the same
///   direction as `ms`, and to the unshifted `now` if even that is not representable.
///
/// NOTE(review): the rationale for the 180 000 ms fallback is not visible here; confirm with
/// callers before relying on it.
pub fn current_utc_time(ms_offset: Option<i64>) -> OffsetDateTime {
    let now = OffsetDateTime::now_utc();
    let Some(ms) = ms_offset else {
        return now;
    };

    now.checked_add(time::Duration::milliseconds(ms))
        .or_else(|| {
            let fallback = time::Duration::milliseconds(180_000);
            if ms > 0 {
                now.checked_add(fallback)
            } else {
                now.checked_sub(fallback)
            }
        })
        .unwrap_or(now)
}
/// Formats `time_value` using the default ISO 8601 format description.
///
/// Formatting errors are not propagated: if formatting fails, an empty `String` is returned.
pub fn iso8601_time(time_value: OffsetDateTime) -> String {
    match time_value.format(&Iso8601::DEFAULT) {
        Ok(formatted) => formatted,
        Err(_) => String::new(),
    }
}
/// Appends the `/`-separated components of `segment` to the path of `base`.
///
/// Empty components (from leading, trailing, or repeated `/` in `segment`) are skipped. A
/// trailing empty segment on `base` (i.e. a trailing slash) is removed first, so
/// `http://h/api` and `http://h/api/` both join with `user/info` to `http://h/api/user/info`.
/// Components are added through `PathSegmentsMut::extend`, which percent-encodes them.
///
/// # Errors
///
/// Returns [`url::ParseError::RelativeUrlWithoutBase`] when `base` cannot be a base
/// (e.g. `mailto:` URLs), because such URLs have no path segments to extend.
#[cfg(feature = "auth")]
pub fn join_url(base: &Url, segment: &str) -> Result<Url, url::ParseError> {
    // Leading slashes only yield empty parts, which the filter already drops.
    let extra_segments = segment.split('/').filter(|part| !part.is_empty());

    let mut joined = base.clone();
    {
        // `path_segments_mut` fails exactly when the URL cannot be a base.
        let mut path = joined
            .path_segments_mut()
            .map_err(|_| url::ParseError::RelativeUrlWithoutBase)?;
        path.pop_if_empty();
        path.extend(extra_segments);
        // `path` is dropped at the end of this scope, releasing the borrow of `joined`.
    }
    Ok(joined)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_remove_query_and_hash() {
        assert_eq!(remove_query_and_hash("/messages"), "/messages");
        assert_eq!(
            remove_query_and_hash("/messages?foo=bar&baz=qux"),
            "/messages"
        );
        assert_eq!(remove_query_and_hash("/messages#section1"), "/messages");
        assert_eq!(
            remove_query_and_hash("/messages?key=value#section2"),
            "/messages"
        );
        // A '?' inside the fragment is not a query separator.
        assert_eq!(
            remove_query_and_hash("/messages#frag?notquery"),
            "/messages"
        );
        assert_eq!(remove_query_and_hash("/"), "/");
        // An empty path part normalizes to the root path.
        assert_eq!(remove_query_and_hash(""), "/");
        assert_eq!(remove_query_and_hash("?foo=bar"), "/");
        assert_eq!(remove_query_and_hash("#top"), "/");
    }

    #[test]
    fn test_unix_timestamp_to_systemtime() {
        assert_eq!(unix_timestamp_to_systemtime(0), UNIX_EPOCH);
        assert_eq!(
            unix_timestamp_to_systemtime(86_400),
            UNIX_EPOCH + Duration::from_secs(86_400)
        );
    }

    #[test]
    fn test_ensure_server_protocole_compatibility() {
        assert!(ensure_server_protocole_compatibility("2025-03-26", "2025-03-26").is_ok());
        assert!(ensure_server_protocole_compatibility("2024-11-05", "2025-03-26").is_err());
        assert!(ensure_server_protocole_compatibility("2025-06-18", "2025-03-26").is_err());
    }

    #[test]
    fn test_enforce_compatible_protocol_version() {
        assert_eq!(
            enforce_compatible_protocol_version("2025-03-26", "2025-03-26").ok(),
            Some(None)
        );
        assert_eq!(
            enforce_compatible_protocol_version("2024-11-05", "2025-03-26").ok(),
            Some(Some("2024-11-05".to_string()))
        );
        assert!(enforce_compatible_protocol_version("2025-06-18", "2025-03-26").is_err());
    }

    // `join_url` and `Url` only exist with the `auth` feature, so this test must be gated the
    // same way; otherwise `cargo test` without `auth` fails to compile.
    #[cfg(feature = "auth")]
    #[test]
    fn test_join_url() {
        let expect = "http://example.com/api/user/userinfo";
        let result = join_url(
            &Url::parse("http://example.com/api").unwrap(),
            "/user/userinfo",
        )
        .unwrap();
        assert_eq!(result.to_string(), expect);
        let result = join_url(
            &Url::parse("http://example.com/api").unwrap(),
            "user/userinfo",
        )
        .unwrap();
        assert_eq!(result.to_string(), expect);
        let result = join_url(
            &Url::parse("http://example.com/api/").unwrap(),
            "/user/userinfo",
        )
        .unwrap();
        assert_eq!(result.to_string(), expect);
        let result = join_url(
            &Url::parse("http://example.com/api/").unwrap(),
            "user/userinfo",
        )
        .unwrap();
        assert_eq!(result.to_string(), expect);
        // Repeated and trailing slashes in the segment are ignored.
        let result = join_url(
            &Url::parse("http://example.com/api/").unwrap(),
            "user//userinfo/",
        )
        .unwrap();
        assert_eq!(result.to_string(), expect);
        // URLs that cannot be a base have no path to extend.
        assert!(join_url(
            &Url::parse("mailto:someone@example.com").unwrap(),
            "user/userinfo"
        )
        .is_err());
    }
}