Skip to main content

opcua_client/transport/
channel.rs

1use std::{str::FromStr, sync::Arc, time::Duration};
2
3use crate::{
4    session::{process_unexpected_response, EndpointInfo},
5    transport::core::TransportPollResult,
6};
7use arc_swap::{ArcSwap, ArcSwapOption};
8use opcua_core::{
9    comms::secure_channel::{Role, SecureChannel},
10    sync::RwLock,
11    trace_read_lock, trace_write_lock, RequestMessage, ResponseMessage,
12};
13use opcua_crypto::{CertificateStore, PrivateKey, SecurityPolicy, X509};
14use opcua_types::{
15    ByteString, CloseSecureChannelRequest, ContextOwned, IntegerId, NodeId, RequestHeader,
16    SecurityTokenRequestType, StatusCode,
17};
18use tracing::{debug, error};
19
20use super::{
21    connect::{Connector, Transport},
22    state::{Request, RequestSend, SecureChannelState},
23};
24
25use crate::{
26    retry::SessionRetryPolicy,
27    transport::{tcp::TransportConfiguration, OutgoingMessage},
28};
29
/// Capacity of the bounded outgoing-message queue created for each new transport.
///
/// The value is intentionally far larger than anything reached in practice. It is only a
/// safety net: if the client ends up in an unexpected (bad) state, the number of queued
/// outgoing messages, and therefore memory use, still cannot grow without limit.
const MAX_INFLIGHT_MESSAGES: usize = 1_000 * 1_000;
34
/// Wrapper around an open secure channel
///
/// Bundles the low-level `SecureChannel` with the configuration and shared state needed
/// to (re)connect it and to send requests over it.
pub struct AsyncSecureChannel {
    /// Endpoint (URL, security policy/mode, server certificate) this channel connects to.
    endpoint_info: EndpointInfo,
    /// Policy whose backoff controls how `connect` retries failed connection attempts.
    session_retry_policy: SessionRetryPolicy,
    /// Shared low-level channel holding security policy/mode, keys, nonces and the security token.
    pub(crate) secure_channel: Arc<RwLock<SecureChannel>>,
    /// Store from which the client's own certificate and private key are read.
    certificate_store: Arc<RwLock<CertificateStore>>,
    /// Configuration handed to the connector each time a transport is created.
    transport_config: TransportConfiguration,
    /// Shared request state (request handles, request headers, auth token); also given to the transport.
    state: Arc<SecureChannelState>,
    /// Serialises security-token renewal so concurrent `send` calls issue at most one Renew request.
    issue_channel_lock: tokio::sync::Mutex<()>,
    /// Lifetime value passed whenever the security token is issued or renewed.
    channel_lifetime: u32,

    /// Sender for outgoing messages on the current connection; `None` while not connected,
    /// in which case `send` fails with `BadNotConnected`.
    request_send: ArcSwapOption<RequestSend>,
    /// Shared encoding context; a clone is also handed to the inner `SecureChannel`.
    encoding_context: Arc<RwLock<ContextOwned>>,
}
49
/// Event loop for a secure channel. This must be polled to make progress.
///
/// Returned by `AsyncSecureChannel::connect` and `AsyncSecureChannel::connect_no_retry`.
/// Call `poll` repeatedly; it yields `TransportPollResult::Closed` once the underlying
/// transport has shut down. Dropping this value without polling it leaves the channel
/// unable to exchange any messages, hence `#[must_use]`.
#[must_use = "the secure channel makes no progress unless its event loop is polled"]
pub struct SecureChannelEventLoop<T> {
    transport: T,
}
54
55impl<T: Transport + Send + Sync + 'static> SecureChannelEventLoop<T> {
56    /// Poll the channel, processing any pending incoming or outgoing messages and returning the
57    /// action that was taken.
58    pub async fn poll(&mut self) -> TransportPollResult {
59        self.transport.poll().await
60    }
61
62    /// Get the URL of the connected server.
63    /// This was either the URL used to establish the connection, or the URL
64    /// reported by the server in ReverseHello.
65    pub fn connected_url(&self) -> &str {
66        self.transport.connected_url()
67    }
68}
69
70impl AsyncSecureChannel {
71    pub(crate) fn make_request_header(&self, timeout: Duration) -> RequestHeader {
72        self.state.make_request_header(timeout)
73    }
74
75    /// Get the next request handle on the channel.
76    pub fn request_handle(&self) -> IntegerId {
77        self.state.request_handle()
78    }
79
80    pub(crate) fn update_from_created_session(
81        &self,
82        nonce: &ByteString,
83        certificate: &ByteString,
84        auth_token: &NodeId,
85    ) -> Result<(), StatusCode> {
86        let mut secure_channel = trace_write_lock!(self.secure_channel);
87        secure_channel.set_remote_nonce_from_byte_string(nonce)?;
88        secure_channel.set_remote_cert_from_byte_string(certificate)?;
89        self.set_auth_token(auth_token.clone());
90        Ok(())
91    }
92
93    pub(crate) fn security_policy(&self) -> SecurityPolicy {
94        let secure_channel = trace_read_lock!(self.secure_channel);
95        secure_channel.security_policy()
96    }
97
98    /// Get the target endpoint of the secure channel.
99    pub fn endpoint_info(&self) -> &EndpointInfo {
100        &self.endpoint_info
101    }
102
103    /// Get the current global encoding context in use by this channel.
104    pub fn encoding_context(&self) -> &RwLock<ContextOwned> {
105        &self.encoding_context
106    }
107
108    /// Set the active authentication token for this channel.
109    pub fn set_auth_token(&self, token: NodeId) {
110        self.state.set_auth_token(token);
111    }
112
113    pub(crate) fn read_own_private_key(&self) -> Option<PrivateKey> {
114        let cert_store = trace_read_lock!(self.certificate_store);
115        cert_store.read_own_pkey().ok()
116    }
117
118    pub(crate) fn read_own_certificate(&self) -> Option<X509> {
119        let cert_store = trace_read_lock!(self.certificate_store);
120        cert_store.read_own_cert().ok()
121    }
122
123    pub(crate) fn certificate_store(&self) -> &RwLock<CertificateStore> {
124        &self.certificate_store
125    }
126}
127
128impl AsyncSecureChannel {
129    /// Create a new client secure channel.
130    #[allow(clippy::too_many_arguments)]
131    pub fn new(
132        certificate_store: Arc<RwLock<CertificateStore>>,
133        endpoint_info: EndpointInfo,
134        session_retry_policy: SessionRetryPolicy,
135        ignore_clock_skew: bool,
136        auth_token: Arc<ArcSwap<NodeId>>,
137        transport_config: TransportConfiguration,
138        channel_lifetime: u32,
139        encoding_context: Arc<RwLock<ContextOwned>>,
140    ) -> Self {
141        let secure_channel = Arc::new(RwLock::new(SecureChannel::new(
142            certificate_store.clone(),
143            Role::Client,
144            encoding_context.clone(),
145        )));
146
147        Self {
148            transport_config,
149            issue_channel_lock: tokio::sync::Mutex::new(()),
150            state: Arc::new(SecureChannelState::new(
151                ignore_clock_skew,
152                secure_channel.clone(),
153                auth_token,
154            )),
155            endpoint_info,
156            secure_channel,
157            certificate_store,
158            session_retry_policy,
159            request_send: Default::default(),
160            channel_lifetime,
161            encoding_context,
162        }
163    }
164
165    /// Send a message on the secure channel, and wait for a response.
166    pub async fn send(
167        &self,
168        request: impl Into<RequestMessage>,
169        timeout: Duration,
170    ) -> Result<ResponseMessage, StatusCode> {
171        let sender = self.request_send.load().as_deref().cloned();
172        let Some(send) = sender else {
173            return Err(StatusCode::BadNotConnected);
174        };
175
176        let should_renew_security_token = {
177            let secure_channel = trace_read_lock!(self.secure_channel);
178            secure_channel.should_renew_security_token()
179        };
180
181        if should_renew_security_token {
182            // Grab the lock, then check again whether we should renew the secure channel,
183            // this avoids renewing it multiple times if the client sends many requests in quick
184            // succession.
185            // Also, if the channel is currently being renewed, we need to wait for the new security token.
186            let guard = self.issue_channel_lock.lock().await;
187            let should_renew_security_token = {
188                let secure_channel = trace_read_lock!(self.secure_channel);
189                secure_channel.should_renew_security_token()
190            };
191
192            if should_renew_security_token {
193                let request = self.state.begin_issue_or_renew_secure_channel(
194                    SecurityTokenRequestType::Renew,
195                    self.channel_lifetime,
196                    Duration::from_secs(30),
197                    send.clone(),
198                );
199
200                let resp = request.send().await?;
201
202                if !matches!(resp, ResponseMessage::OpenSecureChannel(_)) {
203                    return Err(process_unexpected_response(resp));
204                }
205            }
206
207            drop(guard);
208        }
209
210        Request::new(request, send, timeout).send().await
211    }
212
213    /// Attempt to establish a connection using this channel, returning an event loop
214    /// for polling the connection.
215    pub async fn connect<T: Connector>(
216        &self,
217        connector: &T,
218    ) -> Result<SecureChannelEventLoop<T::Transport>, StatusCode> {
219        self.request_send.store(None);
220        let mut backoff = self.session_retry_policy.new_backoff();
221        loop {
222            match self.connect_no_retry(connector).await {
223                Ok(event_loop) => {
224                    break Ok(event_loop);
225                }
226                Err(s) => {
227                    let Some(delay) = backoff.next() else {
228                        break Err(s);
229                    };
230
231                    tokio::time::sleep(delay).await
232                }
233            }
234        }
235    }
236
237    /// Connect to the server without attempting to retry if it fails.
238    pub async fn connect_no_retry<T: Connector>(
239        &self,
240        connector: &T,
241    ) -> Result<SecureChannelEventLoop<T::Transport>, StatusCode> {
242        {
243            let mut secure_channel = trace_write_lock!(self.secure_channel);
244            secure_channel.clear_security_token();
245        }
246
247        let (mut transport, send) = self.create_transport(connector).await?;
248
249        let request = self.state.begin_issue_or_renew_secure_channel(
250            SecurityTokenRequestType::Issue,
251            self.channel_lifetime,
252            Duration::from_secs(30),
253            send.clone(),
254        );
255
256        let request_fut = request.send();
257        tokio::pin!(request_fut);
258
259        // Temporarily poll the transport task while we're waiting for a response.
260        let resp = loop {
261            tokio::select! {
262                r = &mut request_fut => break r?,
263                r = transport.poll() => {
264                    if let TransportPollResult::Closed(e) = r {
265                        return Err(e);
266                    }
267                }
268            }
269        };
270
271        self.request_send.store(Some(Arc::new(send)));
272        if !matches!(resp, ResponseMessage::OpenSecureChannel(_)) {
273            return Err(process_unexpected_response(resp));
274        }
275
276        Ok(SecureChannelEventLoop { transport })
277    }
278
279    async fn create_transport<T: Connector>(
280        &self,
281        connector: &T,
282    ) -> Result<(T::Transport, tokio::sync::mpsc::Sender<OutgoingMessage>), StatusCode> {
283        debug!("Connect");
284        let security_policy =
285            SecurityPolicy::from_str(self.endpoint_info.endpoint.security_policy_uri.as_ref())
286                .map_err(|_| StatusCode::BadSecurityPolicyRejected)?;
287
288        if security_policy == SecurityPolicy::Unknown {
289            error!(
290                "connect, security policy \"{}\" is unknown",
291                self.endpoint_info.endpoint.security_policy_uri.as_ref()
292            );
293            Err(StatusCode::BadSecurityPolicyRejected)
294        } else {
295            let (cert, key) = {
296                let certificate_store = trace_write_lock!(self.certificate_store);
297                (
298                    certificate_store.read_own_cert().ok(),
299                    certificate_store.read_own_pkey().ok(),
300                )
301            };
302
303            {
304                let mut secure_channel = trace_write_lock!(self.secure_channel);
305                secure_channel.set_private_key(key);
306                secure_channel.set_cert(cert);
307                secure_channel.set_security_policy(security_policy);
308                secure_channel.set_security_mode(self.endpoint_info.endpoint.security_mode);
309                secure_channel.set_remote_cert_from_byte_string(
310                    &self.endpoint_info.endpoint.server_certificate,
311                )?;
312                debug!("Security policy = {:?}", security_policy);
313                debug!(
314                    "Security mode = {:?}",
315                    self.endpoint_info.endpoint.security_mode
316                );
317            }
318
319            let (send, recv) = tokio::sync::mpsc::channel(MAX_INFLIGHT_MESSAGES);
320            let transport = connector
321                .connect(self.state.clone(), recv, self.transport_config.clone())
322                .await?;
323
324            Ok((transport, send))
325        }
326    }
327
328    /// Close the secure channel, optionally wait for the channel to close.
329    pub async fn close_channel(&self) {
330        let msg = CloseSecureChannelRequest {
331            request_header: self.state.make_request_header(Duration::from_secs(60)),
332        };
333
334        let sender = self.request_send.load().as_deref().cloned();
335        let request = sender.map(|s| Request::new(msg, s, Duration::from_secs(60)));
336
337        // Instruct the channel to not attempt to reopen.
338        if let Some(request) = request {
339            if let Err(e) = request.send_no_response().await {
340                error!("Failed to send disconnect message, queue full: {e}");
341            }
342        }
343    }
344}