1use std::path::PathBuf;
18
19use url::Url;
20
/// Machine-readable category of a `ParseError`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParseErrorKind {
    /// The connection string was empty.
    Empty,
    /// The connection string could not be parsed, or a required component (host, path,
    /// port, IPv6 bracket) is missing or malformed.
    InvalidUri,
    /// The URI parsed, but its scheme is not one this parser accepts.
    UnsupportedScheme,
    /// One of the `ConnStringLimits` bounds was exceeded.
    LimitExceeded,
}

impl ParseErrorKind {
    /// Returns the stable upper-snake-case code for this kind, e.g. `"INVALID_URI"`.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Empty => "EMPTY",
            Self::InvalidUri => "INVALID_URI",
            Self::UnsupportedScheme => "UNSUPPORTED_SCHEME",
            Self::LimitExceeded => "LIMIT_EXCEEDED",
        }
    }
}

/// Error produced when a connection string cannot be turned into a connection target.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseError {
    /// Category of the failure.
    pub kind: ParseErrorKind,
    /// Human-readable detail; printed after the kind code by `Display`.
    pub message: String,
}

impl ParseError {
    /// Builds an error of the given `kind` carrying `message`.
    pub fn new(kind: ParseErrorKind, message: impl Into<String>) -> Self {
        let message = message.into();
        Self { kind, message }
    }
}

impl std::fmt::Display for ParseError {
    /// Renders as `<KIND>: <message>`, e.g. `EMPTY: empty connection string`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.kind.as_str())?;
        f.write_str(": ")?;
        f.write_str(&self.message)
    }
}

impl std::error::Error for ParseError {}
75
/// Port used for `red://` / `reds://` URIs (and their cluster form) when none is given.
pub const DEFAULT_PORT_RED: u16 = 5_050;
/// Port used for `grpc://` / `grpcs://` URIs (and their cluster form) when none is given.
pub const DEFAULT_PORT_GRPC: u16 = 5_055;
80
/// Upper bounds enforced while parsing a connection string; exceeding any of them yields a
/// `LimitExceeded` error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConnStringLimits {
    /// Maximum length of the whole connection string, in bytes.
    pub max_uri_bytes: usize,
    /// Maximum number of `&`-separated query parameters.
    pub max_query_params: usize,
    /// Maximum number of comma-separated hosts in a cluster URI.
    pub max_cluster_hosts: usize,
}

impl Default for ConnStringLimits {
    /// 8 KiB of URI, 32 query parameters, 64 cluster hosts.
    fn default() -> Self {
        const KIB: usize = 1024;
        Self {
            max_uri_bytes: 8 * KIB,
            max_query_params: 32,
            max_cluster_hosts: 64,
        }
    }
}
107
/// Destination described by a parsed connection string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionTarget {
    /// In-memory database (`memory://` or `memory:`).
    Memory,
    /// Local database file; `path` is the text after `file://`, taken verbatim.
    File { path: PathBuf },
    /// Single gRPC server; `endpoint` is a `scheme://host:port` URL built by the parser.
    Grpc { endpoint: String },
    /// Several servers given as a comma-separated host list.
    GrpcCluster {
        /// Endpoint of the first listed host.
        primary: String,
        /// Endpoints of the remaining hosts, in the order they were listed.
        replicas: Vec<String>,
        /// Set when the query contains `route=primary` (compared case-insensitively).
        force_primary: bool,
    },
    /// HTTP(S) server; `base_url` is `scheme://host:port` with any path dropped.
    Http { base_url: String },
    /// RedWire server; `tls` is set for the `reds` scheme.
    RedWire { host: String, port: u16, tls: bool },
}
139
140pub fn parse(uri: &str) -> Result<ConnectionTarget, ParseError> {
151 parse_with_limits(uri, ConnStringLimits::default())
152}
153
/// Reports whether `uri` is treated as an embedded connection URI.
///
/// Surrounding whitespace is ignored. The URI matches if it is exactly one of `red://`,
/// `red:`, `red://:memory` or `red://:memory:`, or if it begins with `red:///`. The
/// comparison is case-sensitive.
pub fn is_embedded_connection_uri(uri: &str) -> bool {
    const BARE_FORMS: [&str; 4] = ["red://", "red:", "red://:memory", "red://:memory:"];
    let candidate = uri.trim();
    // `red:///` on its own is covered by the prefix test.
    candidate.starts_with("red:///") || BARE_FORMS.contains(&candidate)
}
167
168pub fn parse_with_limits(
173 uri: &str,
174 limits: ConnStringLimits,
175) -> Result<ConnectionTarget, ParseError> {
176 if uri.is_empty() {
177 return Err(ParseError::new(
178 ParseErrorKind::Empty,
179 "empty connection string",
180 ));
181 }
182
183 if uri.len() > limits.max_uri_bytes {
184 return Err(ParseError::new(
185 ParseErrorKind::LimitExceeded,
186 format!(
187 "max_uri_bytes exceeded: limit={} actual={}",
188 limits.max_uri_bytes,
189 uri.len(),
190 ),
191 ));
192 }
193
194 let normalised = normalise_scheme(uri);
199 let uri = normalised.as_str();
200
201 if uri == "memory://" || uri == "memory:" {
202 return Ok(ConnectionTarget::Memory);
203 }
204
205 if let Some(rest) = uri.strip_prefix("file://") {
206 if rest.is_empty() {
207 return Err(ParseError::new(
208 ParseErrorKind::InvalidUri,
209 "file:// URI is missing a path",
210 ));
211 }
212 return Ok(ConnectionTarget::File {
213 path: PathBuf::from(rest),
214 });
215 }
216
217 if let Some(cluster) = try_parse_grpc_cluster(uri, &limits)? {
218 return Ok(cluster);
219 }
220
221 let parsed = Url::parse(uri)
222 .map_err(|e| ParseError::new(ParseErrorKind::InvalidUri, format!("{e}: {uri}")))?;
223
224 enforce_query_param_limit(&parsed, &limits)?;
225
226 match parsed.scheme() {
227 "red" | "reds" => {
228 let host = parsed.host_str().ok_or_else(|| {
229 ParseError::new(ParseErrorKind::InvalidUri, "red:// URI is missing a host")
230 })?;
231 let port = parsed.port().unwrap_or(DEFAULT_PORT_RED);
232 Ok(ConnectionTarget::RedWire {
233 host: host.to_string(),
234 port,
235 tls: parsed.scheme() == "reds",
236 })
237 }
238 "grpc" | "grpcs" => {
239 let host = parsed.host_str().ok_or_else(|| {
240 ParseError::new(ParseErrorKind::InvalidUri, "grpc:// URI is missing a host")
241 })?;
242 let port = parsed.port().unwrap_or(DEFAULT_PORT_GRPC);
243 Ok(ConnectionTarget::Grpc {
244 endpoint: format!("http://{host}:{port}"),
245 })
246 }
247 "http" | "https" => {
248 let host = parsed.host_str().ok_or_else(|| {
249 ParseError::new(
250 ParseErrorKind::InvalidUri,
251 "http(s):// URI is missing a host",
252 )
253 })?;
254 let scheme = parsed.scheme();
255 let port = parsed
256 .port()
257 .unwrap_or(if scheme == "https" { 443 } else { 80 });
258 Ok(ConnectionTarget::Http {
259 base_url: format!("{scheme}://{host}:{port}"),
260 })
261 }
262 other => Err(ParseError::new(
263 ParseErrorKind::UnsupportedScheme,
264 format!("unsupported scheme: {other}"),
265 )),
266 }
267}
268
/// Returns `uri` with its scheme (the text before the first `:`) lower-cased.
///
/// The input is returned unchanged if it has no `:`, or if the would-be scheme is empty or
/// contains any byte other than ASCII letters, digits, `+`, `.` or `-`. Only the scheme is
/// touched; everything from the first `:` onward is copied as-is.
fn normalise_scheme(uri: &str) -> String {
    let Some((scheme, remainder)) = uri.split_once(':') else {
        return uri.to_string();
    };

    let looks_like_scheme = !scheme.is_empty()
        && scheme.bytes().all(|b| {
            matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'+' | b'.' | b'-')
        });
    if !looks_like_scheme {
        return uri.to_string();
    }

    format!("{}:{}", scheme.to_ascii_lowercase(), remainder)
}
297
298fn enforce_query_param_limit(url: &Url, limits: &ConnStringLimits) -> Result<(), ParseError> {
299 let Some(q) = url.query() else {
300 return Ok(());
301 };
302 if q.is_empty() {
303 return Ok(());
304 }
305 let count = q.split('&').count();
306 if count > limits.max_query_params {
307 return Err(ParseError::new(
308 ParseErrorKind::LimitExceeded,
309 format!(
310 "max_query_params exceeded: limit={} actual={}",
311 limits.max_query_params, count,
312 ),
313 ));
314 }
315 Ok(())
316}
317
318fn try_parse_grpc_cluster(
321 uri: &str,
322 limits: &ConnStringLimits,
323) -> Result<Option<ConnectionTarget>, ParseError> {
324 let (rest, default_port) = if let Some(r) = uri.strip_prefix("grpc://") {
325 (r, DEFAULT_PORT_GRPC)
326 } else if let Some(r) = uri.strip_prefix("grpcs://") {
327 (r, DEFAULT_PORT_GRPC)
328 } else if let Some(r) = uri
329 .strip_prefix("red://")
330 .or_else(|| uri.strip_prefix("reds://"))
331 {
332 (r, DEFAULT_PORT_RED)
333 } else {
334 return Ok(None);
335 };
336
337 let (host_part, query_part) = match rest.find('?') {
338 Some(i) => (&rest[..i], Some(&rest[i + 1..])),
339 None => (rest, None),
340 };
341
342 if !host_part.contains(',') {
343 return Ok(None);
344 }
345
346 let raw_count = host_part.split(',').count();
347 if raw_count > limits.max_cluster_hosts {
348 return Err(ParseError::new(
349 ParseErrorKind::LimitExceeded,
350 format!(
351 "max_cluster_hosts exceeded: limit={} actual={}",
352 limits.max_cluster_hosts, raw_count,
353 ),
354 ));
355 }
356
357 let mut endpoints: Vec<String> = Vec::with_capacity(raw_count);
358 for raw in host_part.split(',') {
359 let raw = raw.trim();
360 if raw.is_empty() {
361 return Err(ParseError::new(
362 ParseErrorKind::InvalidUri,
363 "grpc cluster URI has an empty host entry",
364 ));
365 }
366 let (host, port) = if let Some(after_bracket) = raw.strip_prefix('[') {
368 let end = after_bracket.find(']').ok_or_else(|| {
369 ParseError::new(
370 ParseErrorKind::InvalidUri,
371 format!("unterminated IPv6 bracket in cluster URI: {raw}"),
372 )
373 })?;
374 let host = &after_bracket[..end];
375 let tail = &after_bracket[end + 1..];
376 let port = if tail.is_empty() {
377 default_port
378 } else if let Some(p) = tail.strip_prefix(':') {
379 p.parse::<u16>().map_err(|_| {
380 ParseError::new(
381 ParseErrorKind::InvalidUri,
382 format!("invalid port in cluster URI: {raw}"),
383 )
384 })?
385 } else {
386 return Err(ParseError::new(
387 ParseErrorKind::InvalidUri,
388 format!("trailing junk after IPv6 bracket in cluster URI: {raw}"),
389 ));
390 };
391 (format!("[{host}]"), port)
392 } else {
393 match raw.rsplit_once(':') {
394 Some((h, p)) => {
395 let port: u16 = p.parse().map_err(|_| {
396 ParseError::new(
397 ParseErrorKind::InvalidUri,
398 format!("invalid port in cluster URI: {raw}"),
399 )
400 })?;
401 (h.to_string(), port)
402 }
403 None => (raw.to_string(), default_port),
404 }
405 };
406 if host.is_empty() || host == "[]" {
407 return Err(ParseError::new(
408 ParseErrorKind::InvalidUri,
409 "grpc cluster URI has an empty host entry",
410 ));
411 }
412 endpoints.push(format!("http://{host}:{port}"));
413 }
414
415 if let Some(q) = query_part {
416 let qcount = if q.is_empty() {
417 0
418 } else {
419 q.split('&').count()
420 };
421 if qcount > limits.max_query_params {
422 return Err(ParseError::new(
423 ParseErrorKind::LimitExceeded,
424 format!(
425 "max_query_params exceeded: limit={} actual={}",
426 limits.max_query_params, qcount,
427 ),
428 ));
429 }
430 }
431
432 let force_primary = query_part
433 .map(|q| {
434 q.split('&').any(|kv| {
435 let mut parts = kv.splitn(2, '=');
436 let k = parts.next().unwrap_or("");
437 let v = parts.next().unwrap_or("");
438 k.eq_ignore_ascii_case("route") && v.eq_ignore_ascii_case("primary")
439 })
440 })
441 .unwrap_or(false);
442
443 let mut iter = endpoints.into_iter();
444 let primary = iter.next().expect("split on ',' yields at least one entry");
445 let replicas: Vec<String> = iter.collect();
446
447 Ok(Some(ConnectionTarget::GrpcCluster {
448 primary,
449 replicas,
450 force_primary,
451 }))
452}