1use rsip::{headers::ToTypedHeader, prelude::HeadersExt};
2use rsipstack::{
3 rsip_ext::RsipResponseExt, transaction::endpoint::MessageInspector, transport::SipAddr,
4};
5use std::{
6 collections::HashMap,
7 net::IpAddr,
8 sync::{Arc, RwLock},
9};
10
11#[derive(Clone, Debug, Eq, PartialEq)]
12pub struct LearnedPublicAddress {
13 pub transport: rsip::Transport,
14 pub host_with_port: rsip::HostWithPort,
15}
16
17#[derive(Clone, Default)]
18pub struct LearnedPublicAddresses {
19 inner: Arc<RwLock<HashMap<rsip::Transport, rsip::HostWithPort>>>,
20}
21
22impl LearnedPublicAddresses {
23 pub fn store(
24 &self,
25 transport: Option<&rsip::Transport>,
26 host_with_port: rsip::HostWithPort,
27 ) -> bool {
28 let transport = normalize_transport(transport);
29 let mut guard = self.inner.write().unwrap();
30 if guard.get(&transport) == Some(&host_with_port) {
31 return false;
32 }
33 guard.insert(transport, host_with_port);
34 true
35 }
36
37 pub fn get(&self, transport: Option<&rsip::Transport>) -> Option<rsip::HostWithPort> {
38 let transport = normalize_transport(transport);
39 self.inner.read().unwrap().get(&transport).cloned()
40 }
41
42 pub fn learn_from_response(&self, response: &rsip::Response) -> Option<LearnedPublicAddress> {
43 let host_with_port = response.via_received()?;
44 let transport = response
45 .via_header()
46 .ok()
47 .and_then(|via| via.typed().ok())
48 .map(|via| via.transport)
49 .unwrap_or(rsip::Transport::Udp);
50 self.store(Some(&transport), host_with_port.clone());
51 Some(LearnedPublicAddress {
52 transport,
53 host_with_port,
54 })
55 }
56}
57
58pub fn normalize_transport(transport: Option<&rsip::Transport>) -> rsip::Transport {
59 transport.cloned().unwrap_or(rsip::Transport::Udp)
60}
61
62pub fn transport_for_uri(uri: &rsip::Uri) -> rsip::Transport {
63 if matches!(uri.scheme, Some(rsip::Scheme::Sips)) {
64 return rsip::Transport::Tls;
65 }
66
67 uri.params
68 .iter()
69 .find_map(|param| match param {
70 rsip::Param::Transport(transport) => Some(transport.clone()),
71 _ => None,
72 })
73 .unwrap_or(rsip::Transport::Udp)
74}
75
76pub fn contact_needs_public_resolution(contact: &rsip::Uri) -> bool {
77 if contact.scheme.is_none() {
78 return true;
79 }
80
81 match &contact.host_with_port.host {
82 rsip::Host::Domain(domain) => {
83 let host = domain.to_string();
84 host.eq_ignore_ascii_case("localhost")
85 }
86 rsip::Host::IpAddr(ip) => is_local_or_unspecified(ip),
87 }
88}
89
90pub fn build_contact_uri(
91 local_addr: &SipAddr,
92 learned_addr: Option<rsip::HostWithPort>,
93 username: Option<&str>,
94 template: Option<&rsip::Uri>,
95) -> rsip::Uri {
96 let mut uri = template
97 .cloned()
98 .unwrap_or_else(|| rsip::Uri::from(local_addr));
99
100 uri.host_with_port = learned_addr.unwrap_or_else(|| local_addr.addr.clone());
101 if uri.scheme.is_none() {
102 uri.scheme = Some(match local_addr.r#type {
103 Some(rsip::Transport::Tls)
104 | Some(rsip::Transport::Wss)
105 | Some(rsip::Transport::TlsSctp) => rsip::Scheme::Sips,
106 _ => rsip::Scheme::Sip,
107 });
108 }
109
110 if uri.auth.is_none() {
111 if let Some(username) = username.filter(|value| !value.is_empty()) {
112 uri.auth = Some(rsip::Auth {
113 user: username.to_string(),
114 password: None,
115 });
116 }
117 }
118
119 uri
120}
121
122pub fn build_public_contact_uri(
123 learned_public_addresses: &LearnedPublicAddresses,
124 auto_learn_public_address: bool,
125 local_addr: &SipAddr,
126 username: Option<&str>,
127 template: Option<&rsip::Uri>,
128) -> rsip::Uri {
129 let learned_addr = if auto_learn_public_address {
130 learned_public_addresses.get(local_addr.r#type.as_ref())
131 } else {
132 None
133 };
134 build_contact_uri(local_addr, learned_addr, username, template)
135}
136
137pub struct LearningMessageInspector {
138 learned_public_addresses: LearnedPublicAddresses,
139 next: Option<Box<dyn MessageInspector>>,
140}
141
142impl LearningMessageInspector {
143 pub fn new(
144 learned_public_addresses: LearnedPublicAddresses,
145 next: Option<Box<dyn MessageInspector>>,
146 ) -> Self {
147 Self {
148 learned_public_addresses,
149 next,
150 }
151 }
152}
153
154impl MessageInspector for LearningMessageInspector {
155 fn before_send(&self, msg: rsip::SipMessage, dest: Option<&SipAddr>) -> rsip::SipMessage {
156 if let Some(next) = &self.next {
157 next.before_send(msg, dest)
158 } else {
159 msg
160 }
161 }
162
163 fn after_received(&self, msg: rsip::SipMessage, from: &SipAddr) -> rsip::SipMessage {
164 let msg = if let Some(next) = &self.next {
165 next.after_received(msg, from)
166 } else {
167 msg
168 };
169
170 if let rsip::SipMessage::Response(response) = &msg {
171 self.learned_public_addresses.learn_from_response(response);
172 }
173
174 msg
175 }
176}
177
/// Returns `true` for loopback addresses and for the unspecified address
/// (`0.0.0.0` / `::`), in either address family.
fn is_local_or_unspecified(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => v4.is_loopback() || v4.is_unspecified(),
        IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified(),
    }
}
181
#[cfg(test)]
mod tests {
    use super::{
        LearnedPublicAddresses, build_contact_uri, build_public_contact_uri,
        contact_needs_public_resolution, transport_for_uri,
    };
    use rsip::transport::Transport;
    use rsipstack::transport::SipAddr;
    use std::net::SocketAddr;

    /// Parses an `ip:port` literal into the host/port type used by the code
    /// under test.
    fn host_with_port(literal: &str) -> rsip::HostWithPort {
        literal.parse::<SocketAddr>().unwrap().into()
    }

    /// The private UDP listener address shared by the contact-building tests.
    fn udp_local_addr() -> SipAddr {
        SipAddr {
            r#type: Some(Transport::Udp),
            addr: host_with_port("10.0.0.5:5060"),
        }
    }

    /// Parses a SIP URI literal, panicking on malformed input.
    fn uri(literal: &str) -> rsip::Uri {
        rsip::Uri::try_from(literal).unwrap()
    }

    // A response whose Via carries `received`/`rport` must populate the cache.
    #[test]
    fn learns_public_address_from_response_via() {
        let raw = concat!(
            "SIP/2.0 401 Unauthorized\r\n",
            "Via: SIP/2.0/UDP 10.0.0.1:5060;branch=z9hG4bK-1;received=203.0.113.10;rport=62000\r\n",
            "Content-Length: 0\r\n",
            "\r\n"
        );
        let response = rsip::Response::try_from(raw).unwrap();
        let cache = LearnedPublicAddresses::default();

        let learned = cache.learn_from_response(&response).unwrap();

        assert_eq!(learned.transport, Transport::Udp);
        assert_eq!(learned.host_with_port.to_string(), "203.0.113.10:62000");
        assert_eq!(
            cache.get(Some(&Transport::Udp)).unwrap().to_string(),
            "203.0.113.10:62000"
        );
    }

    // A learned address replaces the template's host while keeping its user.
    #[test]
    fn builds_contact_using_learned_public_address() {
        let local_addr = udp_local_addr();
        let template = uri("sip:alice@127.0.0.1:5060");
        let learned_addr = Some(host_with_port("203.0.113.10:62000"));

        let contact = build_contact_uri(&local_addr, learned_addr, Some("alice"), Some(&template));

        assert_eq!(contact.to_string(), "sip:alice@203.0.113.10:62000");
    }

    // Loopback contacts need resolution; routable ones do not.
    #[test]
    fn identifies_contacts_that_need_resolution() {
        let local_contact = uri("sip:alice@127.0.0.1:5060");
        let remote_contact = uri("sip:alice@203.0.113.10:62000");

        assert!(contact_needs_public_resolution(&local_contact));
        assert!(!contact_needs_public_resolution(&remote_contact));
    }

    // With auto-learning on, the cached address for the transport is used.
    #[test]
    fn builds_public_contact_from_shared_cache() {
        let cache = LearnedPublicAddresses::default();
        cache.store(Some(&Transport::Udp), host_with_port("203.0.113.20:62000"));
        let local_addr = udp_local_addr();

        let contact = build_public_contact_uri(&cache, true, &local_addr, Some("alice"), None);

        assert_eq!(contact.to_string(), "sip:alice@203.0.113.20:62000");
    }

    // `sips:` implies TLS; an explicit `transport` parameter is honoured.
    #[test]
    fn infers_transport_from_uri() {
        let sips_uri = uri("sips:alice@example.com");
        let tcp_uri = uri("sip:alice@example.com;transport=tcp");

        assert_eq!(transport_for_uri(&sips_uri), Transport::Tls);
        assert_eq!(transport_for_uri(&tcp_uri), Transport::Tcp);
    }
}