1use std::path::PathBuf;
18
19use url::Url;
20
/// Category of a connection-string parse failure.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ParseErrorKind {
    /// The connection string was the empty string.
    Empty,
    /// The string was malformed: unparseable, missing a host or path, or
    /// carrying a bad port.
    InvalidUri,
    /// The scheme is not one this parser handles.
    UnsupportedScheme,
    /// One of the caps in `ConnStringLimits` was exceeded.
    LimitExceeded,
}

impl ParseErrorKind {
    /// Returns the upper-snake-case code for this kind, e.g. `"INVALID_URI"`.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Empty => "EMPTY",
            Self::InvalidUri => "INVALID_URI",
            Self::UnsupportedScheme => "UNSUPPORTED_SCHEME",
            Self::LimitExceeded => "LIMIT_EXCEEDED",
        }
    }
}
51
52#[derive(Debug, Clone, PartialEq, Eq)]
54pub struct ParseError {
55 pub kind: ParseErrorKind,
56 pub message: String,
57}
58
59impl ParseError {
60 pub fn new(kind: ParseErrorKind, message: impl Into<String>) -> Self {
61 Self {
62 kind,
63 message: message.into(),
64 }
65 }
66}
67
68impl std::fmt::Display for ParseError {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 write!(f, "{}: {}", self.kind.as_str(), self.message)
71 }
72}
73
74impl std::error::Error for ParseError {}
75
/// Port used for `red://` / `reds://` hosts that do not specify one
/// (including hosts listed in a `red://`/`reds://` cluster URI).
pub const DEFAULT_PORT_RED: u16 = 5050;
/// Port used for `grpc://` / `grpcs://` hosts that do not specify one.
pub const DEFAULT_PORT_GRPC: u16 = 5055;
80
/// Upper bounds applied while parsing a connection string.
///
/// [`Default`] allows 8 KiB URIs, 32 query parameters and 64 cluster hosts.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ConnStringLimits {
    /// Maximum length of the raw connection string, in bytes.
    pub max_uri_bytes: usize,
    /// Maximum number of `&`-separated query-string segments (an empty query
    /// counts as zero).
    pub max_query_params: usize,
    /// Maximum number of comma-separated entries in a cluster URI's host list.
    pub max_cluster_hosts: usize,
}

impl Default for ConnStringLimits {
    fn default() -> Self {
        const KIB: usize = 1024;
        ConnStringLimits {
            max_cluster_hosts: 64,
            max_query_params: 32,
            max_uri_bytes: 8 * KIB,
        }
    }
}
107
/// The destination a parsed connection string resolves to.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ConnectionTarget {
    /// Selected by `memory:` or `memory://`.
    Memory,
    /// Selected by `file://<path>`; `path` is the text after `file://`,
    /// used verbatim (no percent-decoding).
    File { path: PathBuf },
    /// A single gRPC endpoint, formatted as `http://host:port`.
    Grpc { endpoint: String },
    /// A multi-host deployment from a comma-separated host list. `primary`
    /// is the first listed host and `replicas` the rest, each formatted as
    /// `http://host:port`.
    GrpcCluster {
        primary: String,
        replicas: Vec<String>,
        /// True when the query string contains `route=primary`
        /// (ASCII case-insensitive).
        force_primary: bool,
    },
    /// An HTTP(S) server; `base_url` is `scheme://host:port`.
    Http { base_url: String },
    /// A RedWire server; `tls` is true for the `reds://` scheme.
    RedWire { host: String, port: u16, tls: bool },
}
139
140pub fn parse(uri: &str) -> Result<ConnectionTarget, ParseError> {
151 parse_with_limits(uri, ConnStringLimits::default())
152}
153
154pub fn parse_with_limits(
159 uri: &str,
160 limits: ConnStringLimits,
161) -> Result<ConnectionTarget, ParseError> {
162 if uri.is_empty() {
163 return Err(ParseError::new(
164 ParseErrorKind::Empty,
165 "empty connection string",
166 ));
167 }
168
169 if uri.len() > limits.max_uri_bytes {
170 return Err(ParseError::new(
171 ParseErrorKind::LimitExceeded,
172 format!(
173 "max_uri_bytes exceeded: limit={} actual={}",
174 limits.max_uri_bytes,
175 uri.len(),
176 ),
177 ));
178 }
179
180 let normalised = normalise_scheme(uri);
185 let uri = normalised.as_str();
186
187 if uri == "memory://" || uri == "memory:" {
188 return Ok(ConnectionTarget::Memory);
189 }
190
191 if let Some(rest) = uri.strip_prefix("file://") {
192 if rest.is_empty() {
193 return Err(ParseError::new(
194 ParseErrorKind::InvalidUri,
195 "file:// URI is missing a path",
196 ));
197 }
198 return Ok(ConnectionTarget::File {
199 path: PathBuf::from(rest),
200 });
201 }
202
203 if let Some(cluster) = try_parse_grpc_cluster(uri, &limits)? {
204 return Ok(cluster);
205 }
206
207 let parsed = Url::parse(uri)
208 .map_err(|e| ParseError::new(ParseErrorKind::InvalidUri, format!("{e}: {uri}")))?;
209
210 enforce_query_param_limit(&parsed, &limits)?;
211
212 match parsed.scheme() {
213 "red" | "reds" => {
214 let host = parsed.host_str().ok_or_else(|| {
215 ParseError::new(ParseErrorKind::InvalidUri, "red:// URI is missing a host")
216 })?;
217 let port = parsed.port().unwrap_or(DEFAULT_PORT_RED);
218 Ok(ConnectionTarget::RedWire {
219 host: host.to_string(),
220 port,
221 tls: parsed.scheme() == "reds",
222 })
223 }
224 "grpc" | "grpcs" => {
225 let host = parsed.host_str().ok_or_else(|| {
226 ParseError::new(ParseErrorKind::InvalidUri, "grpc:// URI is missing a host")
227 })?;
228 let port = parsed.port().unwrap_or(DEFAULT_PORT_GRPC);
229 Ok(ConnectionTarget::Grpc {
230 endpoint: format!("http://{host}:{port}"),
231 })
232 }
233 "http" | "https" => {
234 let host = parsed.host_str().ok_or_else(|| {
235 ParseError::new(
236 ParseErrorKind::InvalidUri,
237 "http(s):// URI is missing a host",
238 )
239 })?;
240 let scheme = parsed.scheme();
241 let port = parsed
242 .port()
243 .unwrap_or(if scheme == "https" { 443 } else { 80 });
244 Ok(ConnectionTarget::Http {
245 base_url: format!("{scheme}://{host}:{port}"),
246 })
247 }
248 other => Err(ParseError::new(
249 ParseErrorKind::UnsupportedScheme,
250 format!("unsupported scheme: {other}"),
251 )),
252 }
253}
254
/// Lower-cases the scheme of `uri` (the text before the first `:`) so that
/// scheme matching is case-insensitive. Everything from the `:` onward is
/// left byte-for-byte intact.
///
/// The input is returned unchanged in three cases:
/// * it has no `:`;
/// * nothing precedes the first `:`;
/// * that prefix contains a byte outside `[A-Za-z0-9+.-]`, i.e. it does not
///   look like a scheme.
fn normalise_scheme(uri: &str) -> String {
    let Some((scheme, rest)) = uri.split_once(':') else {
        return uri.to_string();
    };

    let is_scheme_byte = |b: u8| b.is_ascii_alphanumeric() || matches!(b, b'+' | b'.' | b'-');
    if scheme.is_empty() || !scheme.bytes().all(is_scheme_byte) {
        return uri.to_string();
    }

    let mut lowered = scheme.to_ascii_lowercase();
    lowered.reserve(rest.len() + 1);
    lowered.push(':');
    lowered.push_str(rest);
    lowered
}
283
284fn enforce_query_param_limit(url: &Url, limits: &ConnStringLimits) -> Result<(), ParseError> {
285 let Some(q) = url.query() else {
286 return Ok(());
287 };
288 if q.is_empty() {
289 return Ok(());
290 }
291 let count = q.split('&').count();
292 if count > limits.max_query_params {
293 return Err(ParseError::new(
294 ParseErrorKind::LimitExceeded,
295 format!(
296 "max_query_params exceeded: limit={} actual={}",
297 limits.max_query_params, count,
298 ),
299 ));
300 }
301 Ok(())
302}
303
304fn try_parse_grpc_cluster(
307 uri: &str,
308 limits: &ConnStringLimits,
309) -> Result<Option<ConnectionTarget>, ParseError> {
310 let (rest, default_port) = if let Some(r) = uri.strip_prefix("grpc://") {
311 (r, DEFAULT_PORT_GRPC)
312 } else if let Some(r) = uri.strip_prefix("grpcs://") {
313 (r, DEFAULT_PORT_GRPC)
314 } else if let Some(r) = uri
315 .strip_prefix("red://")
316 .or_else(|| uri.strip_prefix("reds://"))
317 {
318 (r, DEFAULT_PORT_RED)
319 } else {
320 return Ok(None);
321 };
322
323 let (host_part, query_part) = match rest.find('?') {
324 Some(i) => (&rest[..i], Some(&rest[i + 1..])),
325 None => (rest, None),
326 };
327
328 if !host_part.contains(',') {
329 return Ok(None);
330 }
331
332 let raw_count = host_part.split(',').count();
333 if raw_count > limits.max_cluster_hosts {
334 return Err(ParseError::new(
335 ParseErrorKind::LimitExceeded,
336 format!(
337 "max_cluster_hosts exceeded: limit={} actual={}",
338 limits.max_cluster_hosts, raw_count,
339 ),
340 ));
341 }
342
343 let mut endpoints: Vec<String> = Vec::with_capacity(raw_count);
344 for raw in host_part.split(',') {
345 let raw = raw.trim();
346 if raw.is_empty() {
347 return Err(ParseError::new(
348 ParseErrorKind::InvalidUri,
349 "grpc cluster URI has an empty host entry",
350 ));
351 }
352 let (host, port) = if let Some(after_bracket) = raw.strip_prefix('[') {
354 let end = after_bracket.find(']').ok_or_else(|| {
355 ParseError::new(
356 ParseErrorKind::InvalidUri,
357 format!("unterminated IPv6 bracket in cluster URI: {raw}"),
358 )
359 })?;
360 let host = &after_bracket[..end];
361 let tail = &after_bracket[end + 1..];
362 let port = if tail.is_empty() {
363 default_port
364 } else if let Some(p) = tail.strip_prefix(':') {
365 p.parse::<u16>().map_err(|_| {
366 ParseError::new(
367 ParseErrorKind::InvalidUri,
368 format!("invalid port in cluster URI: {raw}"),
369 )
370 })?
371 } else {
372 return Err(ParseError::new(
373 ParseErrorKind::InvalidUri,
374 format!("trailing junk after IPv6 bracket in cluster URI: {raw}"),
375 ));
376 };
377 (format!("[{host}]"), port)
378 } else {
379 match raw.rsplit_once(':') {
380 Some((h, p)) => {
381 let port: u16 = p.parse().map_err(|_| {
382 ParseError::new(
383 ParseErrorKind::InvalidUri,
384 format!("invalid port in cluster URI: {raw}"),
385 )
386 })?;
387 (h.to_string(), port)
388 }
389 None => (raw.to_string(), default_port),
390 }
391 };
392 if host.is_empty() || host == "[]" {
393 return Err(ParseError::new(
394 ParseErrorKind::InvalidUri,
395 "grpc cluster URI has an empty host entry",
396 ));
397 }
398 endpoints.push(format!("http://{host}:{port}"));
399 }
400
401 if let Some(q) = query_part {
402 let qcount = if q.is_empty() {
403 0
404 } else {
405 q.split('&').count()
406 };
407 if qcount > limits.max_query_params {
408 return Err(ParseError::new(
409 ParseErrorKind::LimitExceeded,
410 format!(
411 "max_query_params exceeded: limit={} actual={}",
412 limits.max_query_params, qcount,
413 ),
414 ));
415 }
416 }
417
418 let force_primary = query_part
419 .map(|q| {
420 q.split('&').any(|kv| {
421 let mut parts = kv.splitn(2, '=');
422 let k = parts.next().unwrap_or("");
423 let v = parts.next().unwrap_or("");
424 k.eq_ignore_ascii_case("route") && v.eq_ignore_ascii_case("primary")
425 })
426 })
427 .unwrap_or(false);
428
429 let mut iter = endpoints.into_iter();
430 let primary = iter.next().expect("split on ',' yields at least one entry");
431 let replicas: Vec<String> = iter.collect();
432
433 Ok(Some(ConnectionTarget::GrpcCluster {
434 primary,
435 replicas,
436 force_primary,
437 }))
438}