use std::path::PathBuf;
use url::Url;
/// Category of failure reported while parsing a connection string.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParseErrorKind {
    /// The connection string was empty.
    Empty,
    /// The input could not be interpreted as a well-formed URI.
    InvalidUri,
    /// The URI scheme is not one the parser handles.
    UnsupportedScheme,
    /// One of the configured parsing limits was exceeded.
    LimitExceeded,
}

impl ParseErrorKind {
    /// Returns the upper-case code for this kind (e.g. `"INVALID_URI"`).
    ///
    /// This is the same code that prefixes the `Display` output of
    /// `ParseError`.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Empty => "EMPTY",
            Self::InvalidUri => "INVALID_URI",
            Self::UnsupportedScheme => "UNSUPPORTED_SCHEME",
            Self::LimitExceeded => "LIMIT_EXCEEDED",
        }
    }
}
/// Error returned when a connection string cannot be parsed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseError {
    /// Machine-readable category of the failure.
    pub kind: ParseErrorKind,
    /// Human-readable description of what went wrong.
    pub message: String,
}

impl ParseError {
    /// Builds an error of the given `kind`, converting `message` into an
    /// owned `String`.
    pub fn new(kind: ParseErrorKind, message: impl Into<String>) -> Self {
        let message = message.into();
        Self { kind, message }
    }
}

impl std::fmt::Display for ParseError {
    /// Writes the error as `<KIND_CODE>: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { kind, message } = self;
        write!(f, "{}: {}", kind.as_str(), message)
    }
}

impl std::error::Error for ParseError {}
/// Port assumed for `red://` / `reds://` URIs that do not specify one.
pub const DEFAULT_PORT_RED: u16 = 5050;
/// Port assumed for `grpc://` / `grpcs://` URIs that do not specify one.
pub const DEFAULT_PORT_GRPC: u16 = 5055;
/// Upper bounds enforced while parsing a connection string.
///
/// Exceeding any of them makes the parser fail with
/// `ParseErrorKind::LimitExceeded`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConnStringLimits {
    /// Maximum length of the whole connection string, in bytes.
    pub max_uri_bytes: usize,
    /// Maximum number of `&`-separated segments allowed in the query string.
    pub max_query_params: usize,
    /// Maximum number of comma-separated host entries in a cluster URI.
    pub max_cluster_hosts: usize,
}

impl Default for ConnStringLimits {
    /// Defaults: 8 KiB of URI, 32 query parameters, 64 cluster hosts.
    fn default() -> Self {
        const KIB: usize = 1024;
        ConnStringLimits {
            max_uri_bytes: 8 * KIB,
            max_query_params: 32,
            max_cluster_hosts: 64,
        }
    }
}
/// Destination described by a successfully parsed connection string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionTarget {
/// Produced for `memory:` / `memory://`; carries no further data.
Memory,
/// Produced for `file://<path>`; `path` is everything after the `file://` prefix.
File { path: PathBuf },
/// Single-host `grpc://` / `grpcs://` URI; the parser stores the endpoint as
/// `http://<host>:<port>`.
Grpc { endpoint: String },
/// Comma-separated multi-host URI. Each endpoint is stored as
/// `http://<host>:<port>`; the first listed host becomes `primary` and the
/// remaining ones `replicas`, in input order.
GrpcCluster {
primary: String,
replicas: Vec<String>,
/// Set when the query string contains `route=primary` (ASCII case-insensitive).
force_primary: bool,
},
/// `http://` / `https://` URI; `base_url` is `<scheme>://<host>:<port>` with
/// the port always written out (80 / 443 when the URI omits it).
Http { base_url: String },
/// Single-host `red://` / `reds://` URI; `tls` is true for `reds`.
RedWire { host: String, port: u16, tls: bool },
}
/// Parses `uri` into a [`ConnectionTarget`] using [`ConnStringLimits::default`].
///
/// # Errors
///
/// Returns exactly what [`parse_with_limits`] returns for the same input
/// under the default limits.
pub fn parse(uri: &str) -> Result<ConnectionTarget, ParseError> {
    let limits = ConnStringLimits::default();
    parse_with_limits(uri, limits)
}
/// Parses `uri` into a [`ConnectionTarget`], enforcing the supplied `limits`.
///
/// Processing order:
/// 1. Reject empty input, then input longer than `limits.max_uri_bytes`.
/// 2. Lower-case the scheme (everything after it is left untouched).
/// 3. `memory:` / `memory://` -> [`ConnectionTarget::Memory`].
/// 4. `file://<path>` -> [`ConnectionTarget::File`].
/// 5. Comma-separated multi-host URIs -> [`ConnectionTarget::GrpcCluster`].
/// 6. Everything else goes through `Url::parse` and is dispatched on the
///    scheme (`red`/`reds`, `grpc`/`grpcs`, `http`/`https`).
///
/// # Errors
///
/// * `Empty` - `uri` is the empty string.
/// * `LimitExceeded` - the URI length, query-parameter count or cluster host
///   count exceeds `limits`.
/// * `InvalidUri` - the URI is malformed, `file://` has no path, or a
///   network scheme has no host.
/// * `UnsupportedScheme` - the scheme is none of the ones listed above.
pub fn parse_with_limits(
    uri: &str,
    limits: ConnStringLimits,
) -> Result<ConnectionTarget, ParseError> {
    if uri.is_empty() {
        return Err(ParseError::new(
            ParseErrorKind::Empty,
            "empty connection string",
        ));
    }
    let actual_len = uri.len();
    if actual_len > limits.max_uri_bytes {
        let message = format!(
            "max_uri_bytes exceeded: limit={} actual={}",
            limits.max_uri_bytes, actual_len,
        );
        return Err(ParseError::new(ParseErrorKind::LimitExceeded, message));
    }

    // From here on `uri` refers to the scheme-normalised text, which is also
    // what appears in error messages.
    let lowered = normalise_scheme(uri);
    let uri: &str = &lowered;

    if matches!(uri, "memory://" | "memory:") {
        return Ok(ConnectionTarget::Memory);
    }

    if let Some(path) = uri.strip_prefix("file://") {
        return if path.is_empty() {
            Err(ParseError::new(
                ParseErrorKind::InvalidUri,
                "file:// URI is missing a path",
            ))
        } else {
            Ok(ConnectionTarget::File {
                path: PathBuf::from(path),
            })
        };
    }

    // Multi-host URIs must be handled before `Url::parse` sees them.
    if let Some(cluster) = try_parse_grpc_cluster(uri, &limits)? {
        return Ok(cluster);
    }

    let parsed = match Url::parse(uri) {
        Ok(url) => url,
        Err(e) => {
            return Err(ParseError::new(
                ParseErrorKind::InvalidUri,
                format!("{e}: {uri}"),
            ))
        }
    };
    enforce_query_param_limit(&parsed, &limits)?;

    let invalid = |message: &str| ParseError::new(ParseErrorKind::InvalidUri, message);
    let scheme = parsed.scheme();
    match scheme {
        "red" | "reds" => {
            let Some(host) = parsed.host_str() else {
                return Err(invalid("red:// URI is missing a host"));
            };
            Ok(ConnectionTarget::RedWire {
                host: host.to_owned(),
                port: parsed.port().unwrap_or(DEFAULT_PORT_RED),
                tls: scheme == "reds",
            })
        }
        "grpc" | "grpcs" => {
            let Some(host) = parsed.host_str() else {
                return Err(invalid("grpc:// URI is missing a host"));
            };
            let port = parsed.port().unwrap_or(DEFAULT_PORT_GRPC);
            // NOTE(review): `grpcs` also yields an `http://` endpoint here;
            // confirm that TLS selection happens elsewhere.
            let endpoint = format!("http://{host}:{port}");
            Ok(ConnectionTarget::Grpc { endpoint })
        }
        "http" | "https" => {
            let Some(host) = parsed.host_str() else {
                return Err(invalid("http(s):// URI is missing a host"));
            };
            let fallback_port: u16 = match scheme {
                "https" => 443,
                _ => 80,
            };
            let port = parsed.port().unwrap_or(fallback_port);
            let base_url = format!("{scheme}://{host}:{port}");
            Ok(ConnectionTarget::Http { base_url })
        }
        other => Err(ParseError::new(
            ParseErrorKind::UnsupportedScheme,
            format!("unsupported scheme: {other}"),
        )),
    }
}
/// Returns `uri` with its scheme (the text before the first `:`) converted
/// to ASCII lower case.
///
/// The input is returned unchanged when it has no `:`, when the text before
/// the first `:` is empty, or when that text contains any byte other than
/// an ASCII letter, an ASCII digit, `+`, `.` or `-`. Everything from the
/// `:` onwards is never modified.
fn normalise_scheme(uri: &str) -> String {
    let Some((scheme, _)) = uri.split_once(':') else {
        return uri.to_owned();
    };
    let is_scheme_byte = |b: u8| b.is_ascii_alphanumeric() || matches!(b, b'+' | b'.' | b'-');
    if scheme.is_empty() || !scheme.bytes().all(is_scheme_byte) {
        return uri.to_owned();
    }
    // The scheme is pure ASCII at this point, so `scheme.len()` is a valid
    // char boundary and lower-casing in place cannot change the length.
    let mut out = uri.to_owned();
    out[..scheme.len()].make_ascii_lowercase();
    out
}
/// Checks that the query string of `url` has no more than
/// `limits.max_query_params` parameters.
///
/// Parameters are counted as `&`-separated segments of the raw query, so
/// empty segments (as in `a=1&&b=2`) are counted too. A missing or empty
/// query always passes.
///
/// # Errors
///
/// Returns `ParseErrorKind::LimitExceeded` when the count is over the limit.
fn enforce_query_param_limit(url: &Url, limits: &ConnStringLimits) -> Result<(), ParseError> {
    let count = match url.query() {
        None | Some("") => return Ok(()),
        Some(query) => query.split('&').count(),
    };
    if count <= limits.max_query_params {
        return Ok(());
    }
    Err(ParseError::new(
        ParseErrorKind::LimitExceeded,
        format!(
            "max_query_params exceeded: limit={} actual={}",
            limits.max_query_params, count,
        ),
    ))
}
/// Attempts to interpret `uri` as a multi-host (cluster) connection string.
///
/// Returns `Ok(None)` when `uri` does not start with `grpc://`, `grpcs://`,
/// `red://` or `reds://`, or when the part before the first `?` contains no
/// comma; the caller then falls back to single-host parsing.
///
/// Otherwise every comma-separated entry is trimmed and parsed as
/// `host[:port]` or `[ipv6][:port]`. Entries without a port get the
/// scheme's default port. Each entry is rendered as `http://<host>:<port>`;
/// the first becomes `primary` and the rest `replicas`. `force_primary` is
/// set when the query contains `route=primary` (ASCII case-insensitive).
///
/// # Errors
///
/// * `LimitExceeded` - more entries than `limits.max_cluster_hosts`, or
///   more query parameters than `limits.max_query_params`.
/// * `InvalidUri` - an entry is empty, has an empty host, has an
///   unparsable port, or has a malformed IPv6 bracket.
fn try_parse_grpc_cluster(
    uri: &str,
    limits: &ConnStringLimits,
) -> Result<Option<ConnectionTarget>, ParseError> {
    // Shorthand for building an `InvalidUri` error.
    fn invalid(message: impl Into<String>) -> ParseError {
        ParseError::new(ParseErrorKind::InvalidUri, message)
    }

    // Parses `text` as a port; `raw` is the whole entry, quoted on failure.
    fn parse_port(text: &str, raw: &str) -> Result<u16, ParseError> {
        text.parse::<u16>()
            .map_err(|_| invalid(format!("invalid port in cluster URI: {raw}")))
    }

    // Recognised prefixes and the port assumed when an entry omits one.
    // The prefixes are mutually exclusive, so lookup order is irrelevant.
    const PREFIXES: [(&str, u16); 4] = [
        ("grpc://", DEFAULT_PORT_GRPC),
        ("grpcs://", DEFAULT_PORT_GRPC),
        ("red://", DEFAULT_PORT_RED),
        ("reds://", DEFAULT_PORT_RED),
    ];
    let Some((rest, default_port)) = PREFIXES
        .iter()
        .find_map(|&(prefix, port)| uri.strip_prefix(prefix).map(|tail| (tail, port)))
    else {
        return Ok(None);
    };

    let (host_part, query_part) = match rest.split_once('?') {
        Some((hosts, query)) => (hosts, Some(query)),
        None => (rest, None),
    };
    // A single host is not a cluster; leave it to the generic URL parser.
    if !host_part.contains(',') {
        return Ok(None);
    }

    let raw_count = host_part.split(',').count();
    if raw_count > limits.max_cluster_hosts {
        return Err(ParseError::new(
            ParseErrorKind::LimitExceeded,
            format!(
                "max_cluster_hosts exceeded: limit={} actual={}",
                limits.max_cluster_hosts, raw_count,
            ),
        ));
    }

    let mut endpoints: Vec<String> = Vec::with_capacity(raw_count);
    for raw in host_part.split(',').map(str::trim) {
        if raw.is_empty() {
            return Err(invalid("grpc cluster URI has an empty host entry"));
        }
        let (host, port) = match raw.strip_prefix('[') {
            // Bracketed IPv6 literal, optionally followed by `:port`.
            Some(after_bracket) => {
                let Some((host, tail)) = after_bracket.split_once(']') else {
                    return Err(invalid(format!(
                        "unterminated IPv6 bracket in cluster URI: {raw}"
                    )));
                };
                let port = match (tail.is_empty(), tail.strip_prefix(':')) {
                    (true, _) => default_port,
                    (false, Some(p)) => parse_port(p, raw)?,
                    (false, None) => {
                        return Err(invalid(format!(
                            "trailing junk after IPv6 bracket in cluster URI: {raw}"
                        )));
                    }
                };
                (format!("[{host}]"), port)
            }
            // Plain host: the text after the last `:` (if any) is the port.
            None => match raw.rsplit_once(':') {
                Some((h, p)) => (h.to_owned(), parse_port(p, raw)?),
                None => (raw.to_owned(), default_port),
            },
        };
        if matches!(host.as_str(), "" | "[]") {
            return Err(invalid("grpc cluster URI has an empty host entry"));
        }
        endpoints.push(format!("http://{host}:{port}"));
    }

    // An empty query counts as zero parameters and can never exceed a limit.
    if let Some(q) = query_part.filter(|q| !q.is_empty()) {
        let qcount = q.split('&').count();
        if qcount > limits.max_query_params {
            return Err(ParseError::new(
                ParseErrorKind::LimitExceeded,
                format!(
                    "max_query_params exceeded: limit={} actual={}",
                    limits.max_query_params, qcount,
                ),
            ));
        }
    }

    let force_primary = query_part.map_or(false, |q| {
        q.split('&').any(|kv| {
            // A segment without `=` is treated as a key with an empty value.
            let (k, v) = kv.split_once('=').unwrap_or((kv, ""));
            k.eq_ignore_ascii_case("route") && v.eq_ignore_ascii_case("primary")
        })
    });

    let mut remaining = endpoints.into_iter();
    let primary = remaining
        .next()
        .expect("split on ',' yields at least one entry");
    let replicas: Vec<String> = remaining.collect();
    Ok(Some(ConnectionTarget::GrpcCluster {
        primary,
        replicas,
        force_primary,
    }))
}