Skip to main content

sip_header/
via.rs

//! SIP Via header parser (RFC 3261 §20.42).
2
3use std::fmt;
4
/// Errors produced while parsing a SIP Via header.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SipViaError {
    /// The header value was empty (or whitespace only), or yielded no entries.
    Empty,
    /// The header value could not be parsed; the payload explains why.
    InvalidFormat(String),
}

impl fmt::Display for SipViaError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SipViaError::Empty => f.write_str("Via header is empty"),
            SipViaError::InvalidFormat(msg) => write!(f, "Invalid Via format: {msg}"),
        }
    }
}

impl std::error::Error for SipViaError {}
25
/// One Via hop: `protocol/version/transport host[:port][;params]`.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SipViaEntry {
    /// Protocol name from `sent-protocol` (e.g. `SIP`).
    protocol_name: String,
    /// Protocol version from `sent-protocol` (e.g. `2.0`).
    protocol_version: String,
    /// Transport from `sent-protocol` (e.g. `UDP`).
    transport: String,
    /// Host from `sent-by`, stored without IPv6 brackets.
    host: String,
    /// Port from `sent-by`, if one was given.
    port: Option<u16>,
    /// Parameters in header order; keys lowercased, `None` for valueless params.
    params: Vec<(String, Option<String>)>,
    /// The `rport` parameter, validated once at parse time.
    rport: Option<Option<u16>>,
}
38
39impl SipViaEntry {
40    /// Returns the protocol name (e.g., "SIP").
41    pub fn protocol(&self) -> &str {
42        &self.protocol_name
43    }
44
45    /// Returns the protocol version (e.g., "2.0").
46    pub fn version(&self) -> &str {
47        &self.protocol_version
48    }
49
50    /// Returns the transport protocol (e.g., "UDP", "TCP", "TLS").
51    pub fn transport(&self) -> &str {
52        &self.transport
53    }
54
55    /// Returns the host.
56    pub fn host(&self) -> &str {
57        &self.host
58    }
59
60    /// Returns the port, if present.
61    pub fn port(&self) -> Option<u16> {
62        self.port
63    }
64
65    /// Returns all parameters.
66    pub fn params(&self) -> &[(String, Option<String>)] {
67        &self.params
68    }
69
70    /// Returns a specific parameter value by key (case-insensitive).
71    pub fn param(&self, key: &str) -> Option<Option<&str>> {
72        let key_lower = key.to_ascii_lowercase();
73        self.params
74            .iter()
75            .find(|(k, _)| k == &key_lower)
76            .map(|(_, v)| v.as_deref())
77    }
78
79    /// Returns the `branch` parameter value, if present.
80    pub fn branch(&self) -> Option<&str> {
81        self.param("branch")
82            .flatten()
83    }
84
85    /// Returns the `received` parameter value, if present.
86    pub fn received(&self) -> Option<&str> {
87        self.param("received")
88            .flatten()
89    }
90
91    /// Returns the `rport` parameter.
92    ///
93    /// - `None` if the parameter is absent
94    /// - `Some(None)` if present without a value
95    /// - `Some(Some(port))` if present with a value
96    ///
97    /// Invalid rport values are rejected at parse time, so this accessor
98    /// is infallible.
99    pub fn rport(&self) -> Option<Option<u16>> {
100        self.rport
101    }
102
103    fn parse(entry: &str) -> Result<Self, SipViaError> {
104        let trimmed = entry.trim();
105        if trimmed.is_empty() {
106            return Err(SipViaError::InvalidFormat("empty Via entry".to_string()));
107        }
108
109        // Split on first semicolon to separate sent-protocol/sent-by from params
110        let (main_part, params_part) = if let Some(semi_idx) = trimmed.find(';') {
111            (&trimmed[..semi_idx], Some(&trimmed[semi_idx + 1..]))
112        } else {
113            (trimmed, None)
114        };
115
116        // Parse sent-protocol and sent-by
117        let parts: Vec<&str> = main_part
118            .split_whitespace()
119            .collect();
120        if parts.len() != 2 {
121            return Err(SipViaError::InvalidFormat(format!(
122                "expected 'protocol/version/transport host[:port]', got '{}'",
123                main_part
124            )));
125        }
126
127        let sent_protocol = parts[0];
128        let sent_by = parts[1];
129
130        // Parse sent-protocol: protocol-name/version/transport
131        let protocol_parts: Vec<&str> = sent_protocol
132            .split('/')
133            .collect();
134        if protocol_parts.len() != 3 {
135            return Err(SipViaError::InvalidFormat(format!(
136                "expected 'protocol/version/transport', got '{}'",
137                sent_protocol
138            )));
139        }
140
141        let protocol_name = protocol_parts[0].to_string();
142        let protocol_version = protocol_parts[1].to_string();
143        let transport = protocol_parts[2].to_string();
144
145        // Parse sent-by: host[:port]
146        // Handle IPv6 bracket notation [::1]:port
147        let (host, port) = parse_host_port(sent_by)?;
148
149        // Parse params
150        let mut params = Vec::new();
151        if let Some(params_str) = params_part {
152            for param in params_str.split(';') {
153                let param = param.trim();
154                if param.is_empty() {
155                    continue;
156                }
157
158                if let Some(eq_idx) = param.find('=') {
159                    let key = param[..eq_idx]
160                        .trim()
161                        .to_ascii_lowercase();
162                    let value = param[eq_idx + 1..]
163                        .trim()
164                        .to_string();
165                    params.push((key, Some(value)));
166                } else {
167                    // Parameter without value (e.g., rport)
168                    params.push((param.to_ascii_lowercase(), None));
169                }
170            }
171        }
172
173        let rport = params
174            .iter()
175            .find(|(k, _)| k == "rport")
176            .map(|(_, v)| match v {
177                None => Ok(None),
178                Some(s) => s
179                    .parse::<u16>()
180                    .map(Some)
181                    .map_err(|_| SipViaError::InvalidFormat(format!("invalid rport value: {s}"))),
182            })
183            .transpose()?;
184
185        Ok(Self {
186            protocol_name,
187            protocol_version,
188            transport,
189            host,
190            port,
191            params,
192            rport,
193        })
194    }
195}
196
197impl fmt::Display for SipViaEntry {
198    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199        write!(
200            f,
201            "{}/{}/{}",
202            self.protocol_name, self.protocol_version, self.transport
203        )?;
204
205        // Handle IPv6 addresses with brackets
206        if self
207            .host
208            .contains(':')
209            && !self
210                .host
211                .starts_with('[')
212        {
213            write!(f, " [{}]", self.host)?;
214        } else {
215            write!(f, " {}", self.host)?;
216        }
217
218        if let Some(port) = self.port {
219            write!(f, ":{}", port)?;
220        }
221
222        for (key, value) in &self.params {
223            if let Some(val) = value {
224                write!(f, ";{}={}", key, val)?;
225            } else {
226                write!(f, ";{}", key)?;
227            }
228        }
229
230        Ok(())
231    }
232}
233
234/// Parsed SIP Via header.
235#[derive(Debug, Clone, PartialEq, Eq)]
236#[non_exhaustive]
237pub struct SipVia {
238    entries: Vec<SipViaEntry>,
239}
240
241impl SipVia {
242    /// Parses a Via header value.
243    pub fn parse(raw: &str) -> Result<Self, SipViaError> {
244        let raw = raw.trim();
245        if raw.is_empty() {
246            return Err(SipViaError::Empty);
247        }
248
249        let parts = crate::split_comma_entries(raw);
250        let mut entries = Vec::new();
251
252        for part in parts {
253            entries.push(SipViaEntry::parse(part)?);
254        }
255
256        if entries.is_empty() {
257            return Err(SipViaError::Empty);
258        }
259
260        Ok(Self { entries })
261    }
262
263    /// Returns the Via entries.
264    pub fn entries(&self) -> &[SipViaEntry] {
265        &self.entries
266    }
267
268    /// Consume self and return entries as a `Vec`.
269    pub fn into_entries(self) -> Vec<SipViaEntry> {
270        self.entries
271    }
272
273    /// Returns the number of Via entries.
274    pub fn len(&self) -> usize {
275        self.entries
276            .len()
277    }
278
279    /// Returns `true` if there are no Via entries.
280    pub fn is_empty(&self) -> bool {
281        self.entries
282            .is_empty()
283    }
284}
285
/// Formats the whole header from its entries, in order.
impl fmt::Display for SipVia {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegates to the crate-level helper with a ", " separator.
        // NOTE(review): presumably writes each entry's `Display` output with
        // ", " between them — confirm against `crate::fmt_joined`.
        crate::fmt_joined(f, &self.entries, ", ")
    }
}
291
// Provides `FromStr` for `SipVia` so that `"...".parse::<SipVia>()` works.
// NOTE(review): presumably `Err = SipViaError` and the impl delegates to
// `SipVia::parse` — confirm against the macro definition.
impl_from_str_via_parse!(SipVia, SipViaError);
293
294impl IntoIterator for SipVia {
295    type Item = SipViaEntry;
296    type IntoIter = std::vec::IntoIter<SipViaEntry>;
297
298    fn into_iter(self) -> Self::IntoIter {
299        self.entries
300            .into_iter()
301    }
302}
303
304impl<'a> IntoIterator for &'a SipVia {
305    type Item = &'a SipViaEntry;
306    type IntoIter = std::slice::Iter<'a, SipViaEntry>;
307
308    fn into_iter(self) -> Self::IntoIter {
309        self.entries
310            .iter()
311    }
312}
313
314fn parse_host_port(sent_by: &str) -> Result<(String, Option<u16>), SipViaError> {
315    // Handle IPv6 bracket notation [::1]:port
316    if sent_by.starts_with('[') {
317        // Find the closing bracket
318        if let Some(close_bracket) = sent_by.find(']') {
319            let host = sent_by[1..close_bracket].to_string();
320            let remainder = &sent_by[close_bracket + 1..];
321
322            if remainder.is_empty() {
323                return Ok((host, None));
324            }
325
326            if let Some(port_str) = remainder.strip_prefix(':') {
327                let port = port_str
328                    .parse::<u16>()
329                    .map_err(|_| {
330                        SipViaError::InvalidFormat(format!("invalid port: {}", port_str))
331                    })?;
332                return Ok((host, Some(port)));
333            }
334
335            return Err(SipViaError::InvalidFormat(format!(
336                "unexpected characters after IPv6 address: {}",
337                remainder
338            )));
339        } else {
340            return Err(SipViaError::InvalidFormat(
341                "unclosed IPv6 bracket".to_string(),
342            ));
343        }
344    }
345
346    // IPv4 or hostname with optional port
347    // Find the last colon (to handle IPv6 without brackets, though that's not valid in Via)
348    if let Some(colon_idx) = sent_by.rfind(':') {
349        let host = sent_by[..colon_idx].to_string();
350        let port_str = &sent_by[colon_idx + 1..];
351
352        // Check if this looks like an IPv6 address without brackets (invalid but handle gracefully)
353        if host.contains(':') {
354            // This is likely a bare IPv6 address, return as-is without port
355            return Ok((sent_by.to_string(), None));
356        }
357
358        let port = port_str
359            .parse::<u16>()
360            .map_err(|_| SipViaError::InvalidFormat(format!("invalid port: {}", port_str)))?;
361        Ok((host, Some(port)))
362    } else {
363        Ok((sent_by.to_string(), None))
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_single_via() {
        let via = SipVia::parse("SIP/2.0/UDP 198.51.100.1:5060").unwrap();
        assert_eq!(via.len(), 1);

        let entry = &via.entries()[0];
        assert_eq!(entry.protocol(), "SIP");
        assert_eq!(entry.version(), "2.0");
        assert_eq!(entry.transport(), "UDP");
        assert_eq!(entry.host(), "198.51.100.1");
        assert_eq!(entry.port(), Some(5060));
        assert!(entry
            .params()
            .is_empty());
    }

    #[test]
    fn test_multiple_vias() {
        let via = SipVia::parse("SIP/2.0/UDP 198.51.100.1:5060, SIP/2.0/TCP 203.0.113.5").unwrap();
        assert_eq!(via.len(), 2);

        let entry1 = &via.entries()[0];
        assert_eq!(entry1.host(), "198.51.100.1");
        assert_eq!(entry1.port(), Some(5060));
        assert_eq!(entry1.transport(), "UDP");

        let entry2 = &via.entries()[1];
        assert_eq!(entry2.host(), "203.0.113.5");
        assert_eq!(entry2.port(), None);
        assert_eq!(entry2.transport(), "TCP");
    }

    #[test]
    fn test_via_with_params() {
        let via = SipVia::parse(
            "SIP/2.0/UDP 198.51.100.1:5060;branch=z9hG4bKnashds8;received=203.0.113.10;rport=5061",
        )
        .unwrap();

        let entry = &via.entries()[0];
        assert_eq!(entry.branch(), Some("z9hG4bKnashds8"));
        assert_eq!(entry.received(), Some("203.0.113.10"));
        assert_eq!(entry.rport(), Some(Some(5061)));
    }

    #[test]
    fn test_via_with_rport_no_value() {
        let via = SipVia::parse("SIP/2.0/UDP 198.51.100.1:5060;rport").unwrap();

        let entry = &via.entries()[0];
        assert_eq!(entry.rport(), Some(None));
    }

    #[test]
    fn test_via_without_rport() {
        let via = SipVia::parse("SIP/2.0/UDP 198.51.100.1:5060").unwrap();

        let entry = &via.entries()[0];
        assert_eq!(entry.rport(), None);
    }

    #[test]
    fn test_via_ipv6() {
        let via = SipVia::parse("SIP/2.0/UDP [2001:db8::1]:5060").unwrap();

        let entry = &via.entries()[0];
        assert_eq!(entry.host(), "2001:db8::1");
        assert_eq!(entry.port(), Some(5060));
    }

    #[test]
    fn test_via_ipv6_no_port() {
        let via = SipVia::parse("SIP/2.0/UDP [2001:db8::1]").unwrap();

        let entry = &via.entries()[0];
        assert_eq!(entry.host(), "2001:db8::1");
        assert_eq!(entry.port(), None);
    }

    #[test]
    fn test_via_hostname() {
        let via = SipVia::parse("SIP/2.0/TLS example.com:5061").unwrap();

        let entry = &via.entries()[0];
        assert_eq!(entry.host(), "example.com");
        assert_eq!(entry.port(), Some(5061));
        assert_eq!(entry.transport(), "TLS");
    }

    #[test]
    fn test_empty_via() {
        let result = SipVia::parse("");
        assert!(matches!(result, Err(SipViaError::Empty)));
    }

    #[test]
    fn test_empty_via_whitespace() {
        let result = SipVia::parse("   ");
        assert!(matches!(result, Err(SipViaError::Empty)));
    }

    #[test]
    fn test_invalid_format() {
        let result = SipVia::parse("invalid");
        assert!(matches!(result, Err(SipViaError::InvalidFormat(_))));
    }

    #[test]
    fn test_rport_invalid_value_is_error() {
        let result = SipVia::parse("SIP/2.0/UDP 198.51.100.1:5060;rport=garbage");
        assert!(matches!(result, Err(SipViaError::InvalidFormat(_))));
    }

    #[test]
    fn test_display_roundtrip() {
        let original =
            "SIP/2.0/UDP 198.51.100.1:5060;branch=z9hG4bKnashds8;received=203.0.113.10;rport";
        let via = SipVia::parse(original).unwrap();
        let displayed = via.to_string();

        let reparsed = SipVia::parse(&displayed).unwrap();
        assert_eq!(via, reparsed);
    }

    #[test]
    fn test_display_multiple_vias() {
        let via = SipVia::parse("SIP/2.0/UDP 198.51.100.1:5060, SIP/2.0/TCP 203.0.113.5").unwrap();
        let displayed = via.to_string();
        let first = displayed
            .find("198.51.100.1")
            .expect("first entry missing from display");
        let second = displayed
            .find("203.0.113.5")
            .expect("second entry missing from display");
        assert!(first < second, "entries must keep header order: {displayed}");
        assert_eq!(SipVia::parse(&displayed).unwrap(), via);
    }

    #[test]
    fn test_into_iterator() {
        let via = SipVia::parse("SIP/2.0/UDP 198.51.100.1:5060, SIP/2.0/TCP 203.0.113.5").unwrap();

        let mut count = 0;
        for entry in &via {
            assert!(entry.host() == "198.51.100.1" || entry.host() == "203.0.113.5");
            count += 1;
        }
        assert_eq!(count, 2);

        let entries: Vec<_> = via
            .into_iter()
            .collect();
        assert_eq!(entries.len(), 2);
    }

    #[test]
    fn test_into_entries() {
        let via = SipVia::parse("SIP/2.0/UDP 198.51.100.1:5060, SIP/2.0/TCP 203.0.113.5").unwrap();
        let entries = via.into_entries();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].host(), "198.51.100.1");
        assert_eq!(entries[1].host(), "203.0.113.5");
    }

    #[test]
    fn test_from_str() {
        let via: SipVia = "SIP/2.0/UDP 198.51.100.1:5060"
            .parse()
            .unwrap();
        assert_eq!(via.len(), 1);
    }

    #[test]
    fn test_param_case_insensitive() {
        let via = SipVia::parse("SIP/2.0/UDP 198.51.100.1:5060;Branch=test").unwrap();
        let entry = &via.entries()[0];
        assert_eq!(entry.param("branch"), Some(Some("test")));
        assert_eq!(entry.param("BRANCH"), Some(Some("test")));
    }

    #[test]
    fn test_display_ipv6() {
        let via = SipVia::parse("SIP/2.0/UDP [2001:db8::1]:5060").unwrap();
        let displayed = via.to_string();
        assert!(displayed.contains("[2001:db8::1]"));
    }

    #[test]
    fn test_entry_display_is_canonical() {
        let raw = "SIP/2.0/UDP 198.51.100.1:5060;branch=z9hG4bKnashds8;rport";
        let via = SipVia::parse(raw).unwrap();
        assert_eq!(via.entries()[0].to_string(), raw);
    }

    #[test]
    fn test_entry_display_ipv6_exact() {
        let raw = "SIP/2.0/TCP [2001:db8::1]:5060;branch=z9hG4bK1";
        let via = SipVia::parse(raw).unwrap();
        assert_eq!(via.entries()[0].to_string(), raw);
    }

    #[test]
    fn test_params_keep_order_and_lowercase_keys() {
        let via =
            SipVia::parse("SIP/2.0/UDP example.com;Received=192.0.2.7;BRANCH=AbC;maddr").unwrap();
        let entry = &via.entries()[0];
        let expected = vec![
            ("received".to_string(), Some("192.0.2.7".to_string())),
            ("branch".to_string(), Some("AbC".to_string())),
            ("maddr".to_string(), None),
        ];
        assert_eq!(entry.params(), expected.as_slice());
        // Keys are case-folded; values are not.
        assert_eq!(entry.branch(), Some("AbC"));
    }

    #[test]
    fn test_param_missing_and_valueless() {
        let via = SipVia::parse("SIP/2.0/UDP 198.51.100.1;rport").unwrap();
        let entry = &via.entries()[0];
        assert_eq!(entry.param("rport"), Some(None));
        assert_eq!(entry.param("maddr"), None);
        assert_eq!(entry.branch(), None);
        assert_eq!(entry.received(), None);
    }

    #[test]
    fn test_params_tolerate_surrounding_whitespace() {
        let via = SipVia::parse("SIP/2.0/UDP 198.51.100.1 ; Branch = z9hG4bK1 ; rport").unwrap();
        let entry = &via.entries()[0];
        assert_eq!(entry.host(), "198.51.100.1");
        assert_eq!(entry.branch(), Some("z9hG4bK1"));
        assert_eq!(entry.rport(), Some(None));
    }

    #[test]
    fn test_invalid_port_is_error() {
        for raw in [
            "SIP/2.0/UDP 198.51.100.1:99999",
            "SIP/2.0/UDP 198.51.100.1:abc",
            "SIP/2.0/UDP [2001:db8::1]:",
        ] {
            assert!(
                matches!(SipVia::parse(raw), Err(SipViaError::InvalidFormat(_))),
                "expected InvalidFormat for {raw}"
            );
        }
    }

    #[test]
    fn test_malformed_ipv6_reference_is_error() {
        for raw in ["SIP/2.0/UDP [2001:db8::1:5060", "SIP/2.0/UDP [2001:db8::1]x"] {
            assert!(
                matches!(SipVia::parse(raw), Err(SipViaError::InvalidFormat(_))),
                "expected InvalidFormat for {raw}"
            );
        }
    }

    #[test]
    fn test_malformed_sent_protocol_is_error() {
        for raw in [
            "SIP/2.0 198.51.100.1",
            "SIP/2.0/UDP",
            "SIP/2.0/UDP ;branch=z9hG4bK1",
            "SIP/2.0/UDP 198.51.100.1 extra",
            "SIP/2.0/UDP/X 198.51.100.1",
        ] {
            assert!(
                matches!(SipVia::parse(raw), Err(SipViaError::InvalidFormat(_))),
                "expected InvalidFormat for {raw}"
            );
        }
    }

    #[test]
    fn test_invalid_entry_in_list_is_error() {
        let result = SipVia::parse("SIP/2.0/UDP 198.51.100.1:5060, garbage");
        assert!(matches!(result, Err(SipViaError::InvalidFormat(_))));
    }

    #[test]
    fn test_error_display() {
        assert_eq!(SipViaError::Empty.to_string(), "Via header is empty");
        assert_eq!(
            SipViaError::InvalidFormat("oops".to_string()).to_string(),
            "Invalid Via format: oops"
        );
    }

    #[test]
    fn test_parse_host_port() {
        assert_eq!(
            parse_host_port("example.com"),
            Ok(("example.com".to_string(), None))
        );
        assert_eq!(
            parse_host_port("example.com:5060"),
            Ok(("example.com".to_string(), Some(5060)))
        );
        assert_eq!(
            parse_host_port("[2001:db8::1]:5061"),
            Ok(("2001:db8::1".to_string(), Some(5061)))
        );
        // An unbracketed IPv6 address is kept whole, with no port split off.
        assert_eq!(
            parse_host_port("2001:db8::1"),
            Ok(("2001:db8::1".to_string(), None))
        );
        assert!(parse_host_port("example.com:").is_err());
    }
}