1use arc_swap::ArcSwap;
2use rsip::{headers::ToTypedHeader, prelude::HeadersExt};
3use rsipstack::{
4 rsip_ext::RsipResponseExt, transaction::endpoint::MessageInspector, transport::SipAddr,
5};
6use std::{net::IpAddr, sync::Arc};
7
8pub type SharedPublicAddress = Arc<ArcSwap<rsip::HostWithPort>>;
9
10pub fn normalize_transport(transport: Option<&rsip::Transport>) -> rsip::Transport {
11 transport.cloned().unwrap_or(rsip::Transport::Udp)
12}
13
14pub fn transport_for_uri(uri: &rsip::Uri) -> rsip::Transport {
15 if matches!(uri.scheme, Some(rsip::Scheme::Sips)) {
16 return rsip::Transport::Tls;
17 }
18
19 uri.params
20 .iter()
21 .find_map(|param| match param {
22 rsip::Param::Transport(transport) => Some(transport.clone()),
23 _ => None,
24 })
25 .unwrap_or(rsip::Transport::Udp)
26}
27
28pub fn find_local_addr_for_uri(addrs: &[SipAddr], uri: &rsip::Uri) -> Option<SipAddr> {
29 let transport = transport_for_uri(uri);
30 addrs.iter()
31 .find(|addr| normalize_transport(addr.r#type.as_ref()) == transport)
32 .cloned()
33}
34
35pub fn contact_needs_public_resolution(contact: &rsip::Uri) -> bool {
36 if contact.scheme.is_none() {
37 return true;
38 }
39
40 match &contact.host_with_port.host {
41 rsip::Host::Domain(domain) => {
42 let host = domain.to_string();
43 host.eq_ignore_ascii_case("localhost")
44 }
45 rsip::Host::IpAddr(ip) => is_local_or_unspecified(ip),
46 }
47}
48
49pub fn build_contact_uri(
50 local_addr: &SipAddr,
51 learned_addr: Option<rsip::HostWithPort>,
52 username: Option<&str>,
53 template: Option<&rsip::Uri>,
54) -> rsip::Uri {
55 let mut uri = template
56 .cloned()
57 .unwrap_or_else(|| rsip::Uri::from(local_addr));
58
59 uri.host_with_port = learned_addr.unwrap_or_else(|| local_addr.addr.clone());
60 if uri.scheme.is_none() {
61 uri.scheme = Some(match local_addr.r#type {
62 Some(rsip::Transport::Tls)
63 | Some(rsip::Transport::Wss)
64 | Some(rsip::Transport::TlsSctp) => rsip::Scheme::Sips,
65 _ => rsip::Scheme::Sip,
66 });
67 }
68
69 if uri.auth.is_none() {
70 if let Some(username) = username.filter(|value| !value.is_empty()) {
71 uri.auth = Some(rsip::Auth {
72 user: username.to_string(),
73 password: None,
74 });
75 }
76 }
77
78 uri
79}
80
81pub fn build_contact(
82 local_addr: &SipAddr,
83 contact_address: Option<rsip::HostWithPort>,
84 username: Option<&str>,
85 template: Option<&rsip::Uri>,
86) -> rsip::typed::Contact {
87 let contact_uri = build_contact_uri(local_addr, contact_address, username, template);
88 rsip::typed::Contact {
89 display_name: None,
90 uri: contact_uri,
91 params: vec![],
92 }
93}
94
95pub fn build_public_contact_uri(
96 learned_public_address: &SharedPublicAddress,
97 auto_learn_public_address: bool,
98 local_addr: &SipAddr,
99 username: Option<&str>,
100 template: Option<&rsip::Uri>,
101) -> rsip::Uri {
102 let selected_addr = if auto_learn_public_address
103 && normalize_transport(local_addr.r#type.as_ref()) == rsip::Transport::Udp
104 {
105 Some(learned_public_address.load_full().as_ref().clone())
106 } else {
107 Some(local_addr.addr.clone())
108 };
109 build_contact_uri(local_addr, selected_addr, username, template)
110}
111
112pub struct LearningMessageInspector {
113 learned_public_address: SharedPublicAddress,
114 next: Option<Box<dyn MessageInspector>>,
115}
116
117impl LearningMessageInspector {
118 pub fn new(
119 initial_address: rsip::HostWithPort,
120 next: Option<Box<dyn MessageInspector>>,
121 ) -> Self {
122 Self {
123 learned_public_address: Arc::new(ArcSwap::from_pointee(initial_address)),
124 next,
125 }
126 }
127
128 pub fn shared_public_address(&self) -> SharedPublicAddress {
129 self.learned_public_address.clone()
130 }
131}
132
133impl MessageInspector for LearningMessageInspector {
134 fn before_send(&self, msg: rsip::SipMessage, dest: Option<&SipAddr>) -> rsip::SipMessage {
135 if let Some(next) = &self.next {
136 next.before_send(msg, dest)
137 } else {
138 msg
139 }
140 }
141
142 fn after_received(&self, msg: rsip::SipMessage, from: &SipAddr) -> rsip::SipMessage {
143 if let rsip::SipMessage::Response(response) = &msg
144 && let Ok(via) = response.via_header()
145 && let Ok(via) = via.typed()
146 && via.transport == rsip::Transport::Udp
147 && let Some(host_with_port) = response.via_received()
148 {
149 self.learned_public_address
150 .rcu(|previous: &Arc<rsip::HostWithPort>| {
151 if should_update_address(previous.as_ref(), &host_with_port) {
152 Arc::new(host_with_port.clone())
153 } else {
154 previous.clone()
155 }
156 });
157 }
158
159 if let Some(next) = &self.next {
160 next.after_received(msg, from)
161 } else {
162 msg
163 }
164 }
165}
166
167pub fn should_update_address(
168 previous: &rsip::HostWithPort,
169 current: &rsip::HostWithPort,
170) -> bool {
171 if previous == current {
172 return false;
173 }
174
175 let previous_is_public = is_public_address(previous);
176 let current_is_public = is_public_address(current);
177 (!previous_is_public && current_is_public)
178 || (previous_is_public && current_is_public && previous != current)
179}
180
181fn is_public_address(host_with_port: &rsip::HostWithPort) -> bool {
182 match &host_with_port.host {
183 rsip::Host::Domain(domain) => !domain.to_string().eq_ignore_ascii_case("localhost"),
184 rsip::Host::IpAddr(ip) => !is_local_or_unspecified(ip),
185 }
186}
187
/// Returns `true` for loopback (`127.0.0.0/8`, `::1`) and unspecified
/// (`0.0.0.0`, `::`) addresses.
///
/// IPv4-mapped IPv6 addresses such as `::ffff:127.0.0.1` are converted to
/// their IPv4 form first: the IPv6 predicates in the standard library only
/// recognise `::1` and `::`, so a mapped loopback or unspecified address
/// would otherwise be reported as non-local.
fn is_local_or_unspecified(ip: &IpAddr) -> bool {
    let canonical = ip.to_canonical();
    canonical.is_loopback() || canonical.is_unspecified()
}
191
#[cfg(test)]
mod tests {
    use super::{
        SharedPublicAddress, build_contact, build_contact_uri, build_public_contact_uri,
        contact_needs_public_resolution, find_local_addr_for_uri, should_update_address,
        transport_for_uri,
    };
    use arc_swap::ArcSwap;
    use rsip::transport::Transport;
    use rsipstack::transaction::endpoint::MessageInspector;
    use rsipstack::transport::SipAddr;
    use std::sync::Arc;

    /// Parses an `ip:port` literal into the host/port type used by the code
    /// under test.
    fn host_port(literal: &str) -> rsip::HostWithPort {
        literal
            .parse::<std::net::SocketAddr>()
            .expect("fixture must be a valid socket address")
            .into()
    }

    /// Builds a local SIP address fixture with an explicit transport.
    fn sip_addr(transport: Transport, literal: &str) -> SipAddr {
        SipAddr {
            r#type: Some(transport),
            addr: host_port(literal),
        }
    }

    /// Builds a shared address slot pre-populated with `literal`.
    fn shared_cache(literal: &str) -> SharedPublicAddress {
        Arc::new(ArcSwap::from_pointee(host_port(literal)))
    }

    #[test]
    fn learns_public_address_from_response_via() {
        let response: rsip::Response = concat!(
            "SIP/2.0 401 Unauthorized\r\n",
            "Via: SIP/2.0/UDP 10.0.0.1:5060;branch=z9hG4bK-1;received=203.0.113.10;rport=62000\r\n",
            "Content-Length: 0\r\n",
            "\r\n"
        )
        .try_into()
        .unwrap();

        let inspector = super::LearningMessageInspector::new(host_port("127.0.0.1:5060"), None);
        let cache = inspector.shared_public_address();
        inspector.after_received(
            rsip::SipMessage::Response(response),
            &sip_addr(Transport::Udp, "10.0.0.1:5060"),
        );
        assert_eq!(cache.load_full().as_ref().to_string(), "203.0.113.10:62000");
    }

    #[test]
    fn builds_contact_using_learned_public_address() {
        let local_addr = sip_addr(Transport::Udp, "10.0.0.5:5060");
        let template: rsip::Uri = "sip:alice@127.0.0.1:5060".try_into().unwrap();
        let learned_addr = Some(host_port("203.0.113.10:62000"));

        let contact = build_contact_uri(&local_addr, learned_addr, Some("alice"), Some(&template));
        assert_eq!(contact.to_string(), "sip:alice@203.0.113.10:62000");
    }

    #[test]
    fn identifies_contacts_that_need_resolution() {
        let local_contact: rsip::Uri = "sip:alice@127.0.0.1:5060".try_into().unwrap();
        let remote_contact: rsip::Uri = "sip:alice@203.0.113.10:62000".try_into().unwrap();
        assert!(contact_needs_public_resolution(&local_contact));
        assert!(!contact_needs_public_resolution(&remote_contact));
    }

    #[test]
    fn selects_local_addr_for_uri_transport() {
        let addrs = vec![
            sip_addr(Transport::Udp, "10.0.0.5:5060"),
            sip_addr(Transport::Tls, "10.0.0.5:5061"),
        ];

        let uri: rsip::Uri = "sips:alice@example.com".try_into().unwrap();
        let selected = find_local_addr_for_uri(&addrs, &uri).unwrap();

        assert_eq!(selected.to_string(), "TLS 10.0.0.5:5061");
    }

    #[test]
    fn builds_public_contact_from_shared_cache() {
        let cache = shared_cache("203.0.113.20:62000");
        let local_addr = sip_addr(Transport::Udp, "10.0.0.5:5060");

        let contact = build_public_contact_uri(&cache, true, &local_addr, Some("alice"), None);
        assert_eq!(contact.to_string(), "sip:alice@203.0.113.20:62000");
    }

    #[test]
    fn builds_typed_contact() {
        let local_addr = sip_addr(Transport::Udp, "10.0.0.5:5060");
        let contact = build_contact(
            &local_addr,
            Some(host_port("203.0.113.20:62000")),
            Some("alice"),
            None,
        );
        assert_eq!(contact.to_string(), "<sip:alice@203.0.113.20:62000>");
    }

    #[test]
    fn keeps_configured_contact_for_tls() {
        // The learned address must be ignored for non-UDP transports.
        let cache = shared_cache("203.0.113.20:62000");
        let local_addr = sip_addr(Transport::Tls, "10.0.0.5:5061");

        let contact = build_public_contact_uri(&cache, true, &local_addr, Some("alice"), None);
        assert_eq!(contact.to_string(), "sips:alice@10.0.0.5:5061;transport=TLS");
    }

    #[test]
    fn infers_transport_from_uri() {
        let sips_uri: rsip::Uri = "sips:alice@example.com".try_into().unwrap();
        let tcp_uri: rsip::Uri = "sip:alice@example.com;transport=tcp".try_into().unwrap();
        assert_eq!(transport_for_uri(&sips_uri), Transport::Tls);
        assert_eq!(transport_for_uri(&tcp_uri), Transport::Tcp);
    }

    #[test]
    fn updates_learned_address_from_local_to_public() {
        let previous = host_port("127.0.0.1:5060");
        let current = host_port("203.0.113.10:62000");

        assert!(should_update_address(&previous, &current));
    }

    #[test]
    fn does_not_update_learned_address_when_unchanged() {
        let current = host_port("203.0.113.10:62000");

        assert!(!should_update_address(&current, &current));
    }
}