1use std::{str::FromStr, sync::Arc, time::Duration};
2
3use crate::{
4 session::{process_unexpected_response, EndpointInfo},
5 transport::core::TransportPollResult,
6};
7use arc_swap::{ArcSwap, ArcSwapOption};
8use opcua_core::{
9 comms::secure_channel::{Role, SecureChannel},
10 sync::RwLock,
11 trace_read_lock, trace_write_lock, RequestMessage, ResponseMessage,
12};
13use opcua_crypto::{CertificateStore, PrivateKey, SecurityPolicy, X509};
14use opcua_types::{
15 ByteString, CloseSecureChannelRequest, ContextOwned, IntegerId, NodeId, RequestHeader,
16 SecurityTokenRequestType, StatusCode,
17};
18use tracing::{debug, error};
19
20use super::{
21 connect::{Connector, Transport},
22 state::{Request, RequestSend, SecureChannelState},
23};
24
25use crate::{
26 retry::SessionRetryPolicy,
27 transport::{tcp::TransportConfiguration, OutgoingMessage},
28};
29
/// Capacity of the bounded `tokio::sync::mpsc` queue that carries outgoing messages from the
/// secure channel to a newly created transport, i.e. how many messages may be queued before
/// the queue is full.
const MAX_INFLIGHT_MESSAGES: usize = 1_000 * 1_000;
34
/// Client-side secure channel: wraps a [`SecureChannel`] together with the sender half of the
/// outgoing message queue of the transport it is currently connected through.
pub struct AsyncSecureChannel {
    /// Endpoint this channel connects to; its security policy URI, security mode and server
    /// certificate are applied to the secure channel whenever a transport is created.
    endpoint_info: EndpointInfo,
    /// Policy whose backoff determines the delays between connection attempts in `connect`.
    session_retry_policy: SessionRetryPolicy,
    /// Shared secure channel (keys, certificates, security token).
    pub(crate) secure_channel: Arc<RwLock<SecureChannel>>,
    /// Store from which the client's own certificate and private key are read.
    certificate_store: Arc<RwLock<CertificateStore>>,
    /// Configuration handed to the connector each time a transport is created.
    transport_config: TransportConfiguration,
    /// Shared state used to build request headers, hand out request handles, hold the auth
    /// token and begin issue/renew requests; also passed to the connector.
    state: Arc<SecureChannelState>,
    /// Serializes security-token renewal in `send`.
    issue_channel_lock: tokio::sync::Mutex<()>,
    /// Requested lifetime sent with every issue/renew request (unit as expected by
    /// `begin_issue_or_renew_secure_channel` — presumably milliseconds, verify).
    channel_lifetime: u32,

    /// Sender for the current transport's outgoing queue; `None` while not connected.
    request_send: ArcSwapOption<RequestSend>,
    /// Encoding context shared with the inner [`SecureChannel`].
    encoding_context: Arc<RwLock<ContextOwned>>,
}
49
/// Event loop returned by a successful connect; it owns the transport of that connection.
///
/// Presumably the caller drives the connection by calling `poll` repeatedly (the connect
/// logic itself polls the transport while waiting for the channel to open) — confirm.
pub struct SecureChannelEventLoop<T> {
    /// Transport created for this connection.
    transport: T,
}
54
55impl<T: Transport + Send + Sync + 'static> SecureChannelEventLoop<T> {
56 pub async fn poll(&mut self) -> TransportPollResult {
59 self.transport.poll().await
60 }
61
62 pub fn connected_url(&self) -> &str {
66 self.transport.connected_url()
67 }
68}
69
70impl AsyncSecureChannel {
71 pub(crate) fn make_request_header(&self, timeout: Duration) -> RequestHeader {
72 self.state.make_request_header(timeout)
73 }
74
75 pub fn request_handle(&self) -> IntegerId {
77 self.state.request_handle()
78 }
79
80 pub(crate) fn update_from_created_session(
81 &self,
82 nonce: &ByteString,
83 certificate: &ByteString,
84 auth_token: &NodeId,
85 ) -> Result<(), StatusCode> {
86 let mut secure_channel = trace_write_lock!(self.secure_channel);
87 secure_channel.set_remote_nonce_from_byte_string(nonce)?;
88 secure_channel.set_remote_cert_from_byte_string(certificate)?;
89 self.set_auth_token(auth_token.clone());
90 Ok(())
91 }
92
93 pub(crate) fn security_policy(&self) -> SecurityPolicy {
94 let secure_channel = trace_read_lock!(self.secure_channel);
95 secure_channel.security_policy()
96 }
97
98 pub fn endpoint_info(&self) -> &EndpointInfo {
100 &self.endpoint_info
101 }
102
103 pub fn encoding_context(&self) -> &RwLock<ContextOwned> {
105 &self.encoding_context
106 }
107
108 pub fn set_auth_token(&self, token: NodeId) {
110 self.state.set_auth_token(token);
111 }
112
113 pub(crate) fn read_own_private_key(&self) -> Option<PrivateKey> {
114 let cert_store = trace_read_lock!(self.certificate_store);
115 cert_store.read_own_pkey().ok()
116 }
117
118 pub(crate) fn read_own_certificate(&self) -> Option<X509> {
119 let cert_store = trace_read_lock!(self.certificate_store);
120 cert_store.read_own_cert().ok()
121 }
122
123 pub(crate) fn certificate_store(&self) -> &RwLock<CertificateStore> {
124 &self.certificate_store
125 }
126}
127
128impl AsyncSecureChannel {
129 #[allow(clippy::too_many_arguments)]
131 pub fn new(
132 certificate_store: Arc<RwLock<CertificateStore>>,
133 endpoint_info: EndpointInfo,
134 session_retry_policy: SessionRetryPolicy,
135 ignore_clock_skew: bool,
136 auth_token: Arc<ArcSwap<NodeId>>,
137 transport_config: TransportConfiguration,
138 channel_lifetime: u32,
139 encoding_context: Arc<RwLock<ContextOwned>>,
140 ) -> Self {
141 let secure_channel = Arc::new(RwLock::new(SecureChannel::new(
142 certificate_store.clone(),
143 Role::Client,
144 encoding_context.clone(),
145 )));
146
147 Self {
148 transport_config,
149 issue_channel_lock: tokio::sync::Mutex::new(()),
150 state: Arc::new(SecureChannelState::new(
151 ignore_clock_skew,
152 secure_channel.clone(),
153 auth_token,
154 )),
155 endpoint_info,
156 secure_channel,
157 certificate_store,
158 session_retry_policy,
159 request_send: Default::default(),
160 channel_lifetime,
161 encoding_context,
162 }
163 }
164
165 pub async fn send(
167 &self,
168 request: impl Into<RequestMessage>,
169 timeout: Duration,
170 ) -> Result<ResponseMessage, StatusCode> {
171 let sender = self.request_send.load().as_deref().cloned();
172 let Some(send) = sender else {
173 return Err(StatusCode::BadNotConnected);
174 };
175
176 let should_renew_security_token = {
177 let secure_channel = trace_read_lock!(self.secure_channel);
178 secure_channel.should_renew_security_token()
179 };
180
181 if should_renew_security_token {
182 let guard = self.issue_channel_lock.lock().await;
187 let should_renew_security_token = {
188 let secure_channel = trace_read_lock!(self.secure_channel);
189 secure_channel.should_renew_security_token()
190 };
191
192 if should_renew_security_token {
193 let request = self.state.begin_issue_or_renew_secure_channel(
194 SecurityTokenRequestType::Renew,
195 self.channel_lifetime,
196 Duration::from_secs(30),
197 send.clone(),
198 );
199
200 let resp = request.send().await?;
201
202 if !matches!(resp, ResponseMessage::OpenSecureChannel(_)) {
203 return Err(process_unexpected_response(resp));
204 }
205 }
206
207 drop(guard);
208 }
209
210 Request::new(request, send, timeout).send().await
211 }
212
213 pub async fn connect<T: Connector>(
216 &self,
217 connector: &T,
218 ) -> Result<SecureChannelEventLoop<T::Transport>, StatusCode> {
219 self.request_send.store(None);
220 let mut backoff = self.session_retry_policy.new_backoff();
221 loop {
222 match self.connect_no_retry(connector).await {
223 Ok(event_loop) => {
224 break Ok(event_loop);
225 }
226 Err(s) => {
227 let Some(delay) = backoff.next() else {
228 break Err(s);
229 };
230
231 tokio::time::sleep(delay).await
232 }
233 }
234 }
235 }
236
237 pub async fn connect_no_retry<T: Connector>(
239 &self,
240 connector: &T,
241 ) -> Result<SecureChannelEventLoop<T::Transport>, StatusCode> {
242 {
243 let mut secure_channel = trace_write_lock!(self.secure_channel);
244 secure_channel.clear_security_token();
245 }
246
247 let (mut transport, send) = self.create_transport(connector).await?;
248
249 let request = self.state.begin_issue_or_renew_secure_channel(
250 SecurityTokenRequestType::Issue,
251 self.channel_lifetime,
252 Duration::from_secs(30),
253 send.clone(),
254 );
255
256 let request_fut = request.send();
257 tokio::pin!(request_fut);
258
259 let resp = loop {
261 tokio::select! {
262 r = &mut request_fut => break r?,
263 r = transport.poll() => {
264 if let TransportPollResult::Closed(e) = r {
265 return Err(e);
266 }
267 }
268 }
269 };
270
271 self.request_send.store(Some(Arc::new(send)));
272 if !matches!(resp, ResponseMessage::OpenSecureChannel(_)) {
273 return Err(process_unexpected_response(resp));
274 }
275
276 Ok(SecureChannelEventLoop { transport })
277 }
278
279 async fn create_transport<T: Connector>(
280 &self,
281 connector: &T,
282 ) -> Result<(T::Transport, tokio::sync::mpsc::Sender<OutgoingMessage>), StatusCode> {
283 debug!("Connect");
284 let security_policy =
285 SecurityPolicy::from_str(self.endpoint_info.endpoint.security_policy_uri.as_ref())
286 .map_err(|_| StatusCode::BadSecurityPolicyRejected)?;
287
288 if security_policy == SecurityPolicy::Unknown {
289 error!(
290 "connect, security policy \"{}\" is unknown",
291 self.endpoint_info.endpoint.security_policy_uri.as_ref()
292 );
293 Err(StatusCode::BadSecurityPolicyRejected)
294 } else {
295 let (cert, key) = {
296 let certificate_store = trace_write_lock!(self.certificate_store);
297 (
298 certificate_store.read_own_cert().ok(),
299 certificate_store.read_own_pkey().ok(),
300 )
301 };
302
303 {
304 let mut secure_channel = trace_write_lock!(self.secure_channel);
305 secure_channel.set_private_key(key);
306 secure_channel.set_cert(cert);
307 secure_channel.set_security_policy(security_policy);
308 secure_channel.set_security_mode(self.endpoint_info.endpoint.security_mode);
309 secure_channel.set_remote_cert_from_byte_string(
310 &self.endpoint_info.endpoint.server_certificate,
311 )?;
312 debug!("Security policy = {:?}", security_policy);
313 debug!(
314 "Security mode = {:?}",
315 self.endpoint_info.endpoint.security_mode
316 );
317 }
318
319 let (send, recv) = tokio::sync::mpsc::channel(MAX_INFLIGHT_MESSAGES);
320 let transport = connector
321 .connect(self.state.clone(), recv, self.transport_config.clone())
322 .await?;
323
324 Ok((transport, send))
325 }
326 }
327
328 pub async fn close_channel(&self) {
330 let msg = CloseSecureChannelRequest {
331 request_header: self.state.make_request_header(Duration::from_secs(60)),
332 };
333
334 let sender = self.request_send.load().as_deref().cloned();
335 let request = sender.map(|s| Request::new(msg, s, Duration::from_secs(60)));
336
337 if let Some(request) = request {
339 if let Err(e) = request.send_no_response().await {
340 error!("Failed to send disconnect message, queue full: {e}");
341 }
342 }
343 }
344}