use std::fmt;
use std::str::FromStr;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// URL schemes that [`Url::parse`] accepts; a successfully parsed URL whose
/// scheme is not listed here is rejected with [`UrlError::DisallowedScheme`].
const ALLOWED_SCHEMES: &[&str] = &["http", "https"];
/// Newtype around [`url::Url`] whose constructors in this file ([`Url::parse`],
/// `FromStr`, and `Deserialize`) only succeed for schemes in `ALLOWED_SCHEMES`.
///
/// The wrapped field is private, so code outside this module cannot build a
/// `Url` by wrapping an arbitrary `url::Url` directly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Url(url::Url);
impl Url {
pub fn parse(input: &str) -> Result<Self, UrlError> {
let parsed = url::Url::parse(input).map_err(|e| UrlError::Invalid(e.to_string()))?;
let scheme = parsed.scheme();
if !ALLOWED_SCHEMES.contains(&scheme) {
return Err(UrlError::DisallowedScheme(scheme.to_string()));
}
Ok(Self(parsed))
}
pub fn as_str(&self) -> &str {
self.0.as_str()
}
pub fn host_str(&self) -> Option<&str> {
self.0.host_str()
}
pub fn inner(&self) -> &url::Url {
&self.0
}
}
impl fmt::Display for Url {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&self.0, f)
}
}
impl FromStr for Url {
type Err = UrlError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::parse(s)
}
}
impl AsRef<str> for Url {
    /// Borrows the URL as a string slice; same value as [`Url::as_str`].
    fn as_ref(&self) -> &str {
        self.as_str()
    }
}
/// Reasons [`Url::parse`] can refuse an input string.
#[derive(Debug, thiserror::Error)]
pub enum UrlError {
/// The input was rejected by `url::Url::parse`; the payload is that
/// parser error rendered with `to_string()`.
#[error("invalid URL: {0}")]
Invalid(String),
/// The input parsed, but its scheme is not in `ALLOWED_SCHEMES`; the
/// payload is the rejected scheme.
#[error("disallowed URL scheme: {0:?} (only http/https are accepted by the basemind crawler)")]
DisallowedScheme(String),
}
impl Serialize for Url {
    /// Serializes the URL as a plain string (its `as_str` form).
    fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let text = self.as_str();
        ser.serialize_str(text)
    }
}
impl<'de> Deserialize<'de> for Url {
fn deserialize<D: Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
let s = String::deserialize(de)?;
Url::parse(&s).map_err(serde::de::Error::custom)
}
}
impl rmcp::schemars::JsonSchema for Url {
    /// Name under which this type appears in generated schemas.
    fn schema_name() -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Borrowed("Url")
    }

    /// Describes the type as a JSON string with `format: uri`; the generator
    /// argument is not used.
    fn json_schema(_generator: &mut rmcp::schemars::SchemaGenerator) -> rmcp::schemars::Schema {
        let schema = rmcp::schemars::json_schema!({
            "type": "string",
            "format": "uri",
            "description": "An absolute http or https URL. Other schemes are rejected at parse time."
        });
        schema
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Both allowed schemes parse successfully.
    #[test]
    fn parses_http_and_https() {
        for input in ["http://example.com/", "https://example.com/path?q=1"] {
            assert!(Url::parse(input).is_ok());
        }
    }

    /// `file:` URLs are rejected and the error reports the scheme.
    #[test]
    fn rejects_file_scheme() {
        let err = Url::parse("file:///etc/passwd").expect_err("file:// must be rejected");
        if let UrlError::DisallowedScheme(s) = err {
            assert_eq!(s, "file");
        } else {
            let other = err;
            panic!("expected DisallowedScheme, got {other:?}");
        }
    }

    /// `javascript:` URLs are rejected on scheme grounds.
    #[test]
    fn rejects_javascript_scheme() {
        let outcome = Url::parse("javascript:alert(1)");
        assert!(matches!(outcome, Err(UrlError::DisallowedScheme(_))));
    }

    /// `data:` URLs are rejected on scheme grounds.
    #[test]
    fn rejects_data_scheme() {
        let outcome = Url::parse("data:text/plain,hello");
        assert!(matches!(outcome, Err(UrlError::DisallowedScheme(_))));
    }

    /// A bare path is not a URL at all, so it fails as `Invalid`.
    #[test]
    fn rejects_relative_path() {
        let outcome = Url::parse("/just/a/path");
        assert!(matches!(outcome, Err(UrlError::Invalid(_))));
    }

    /// Serialization yields a JSON string that deserializes to an equal value.
    #[test]
    fn serde_roundtrips_via_json_string() {
        let original = Url::parse("https://example.com/x").unwrap();
        let encoded = serde_json::to_string(&original).unwrap();
        assert_eq!(encoded, "\"https://example.com/x\"");
        let decoded = serde_json::from_str::<Url>(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// Deserialization applies the same scheme check as `Url::parse`.
    #[test]
    fn deserialize_rejects_disallowed_scheme() {
        let outcome = serde_json::from_str::<Url>("\"file:///etc/passwd\"");
        assert!(outcome.is_err());
    }

    /// `host_str` returns the host component of the parsed URL.
    #[test]
    fn host_str_reports_authority() {
        let parsed = Url::parse("https://docs.rs/rmcp/").unwrap();
        let host = parsed.host_str();
        assert_eq!(host, Some("docs.rs"));
    }
}