// netconf/session.rs

1use std::{
2    collections::{hash_map::Entry, HashMap},
3    fmt::{self, Debug, Display},
4    future::Future,
5    mem,
6    num::NonZeroU32,
7    str::FromStr,
8    sync::Arc,
9};
10
11#[cfg(feature = "tls")]
12use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
13
14use tokio::{net::ToSocketAddrs, sync::Mutex};
15
16use crate::{
17    capabilities::{Base, Capabilities},
18    message::{
19        rpc::{
20            self,
21            operation::{Builder, CloseSession},
22            IntoResult,
23        },
24        ClientHello, ClientMsg, ReadError, ServerHello, ServerMsg,
25    },
26    transport::Transport,
27    Error,
28};
29
30#[cfg(feature = "tls")]
31use crate::transport::Tls;
32
33#[cfg(feature = "ssh")]
34use crate::transport::{Password, Ssh};
35
36#[cfg(feature = "junos")]
37use crate::transport::JunosLocal;
38
/// An identifier used by a NETCONF server to uniquely identify a session.
///
/// The identifier wraps a [`NonZeroU32`], so a zero `session-id` cannot be represented.
/// It is cheap to copy and can be compared and hashed, which allows it to be used as a key in
/// hash-based collections.
// The `Session` prefix repeats the module name on purpose: it matches the `session-id` term
// used for this value in the rest of this file's documentation.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SessionId(NonZeroU32);
43
44impl SessionId {
45    pub(crate) fn new(n: u32) -> Result<Self, Error> {
46        NonZeroU32::new(n)
47            .ok_or(Error::InvalidSessionId { session_id: n })
48            .map(Self)
49    }
50}
51
52impl FromStr for SessionId {
53    type Err = ReadError;
54
55    fn from_str(s: &str) -> Result<Self, Self::Err> {
56        Ok(Self(s.parse().map_err(Self::Err::SessionIdParse)?))
57    }
58}
59
60impl Display for SessionId {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62        Display::fmt(&self.0, f)
63    }
64}
65
66/// A NETCONF client session over a secure transport `T`.
67///
68/// [`Session`] instances provide direct access to asynchronous NETCONF protocol operations. The
69/// library user is responsible for ensuring the correct ordering of operations to ensure, for
70/// example, safe config modification. See [RFC6241] for additional guidance.
71///
72/// [RFC6241]: https://datatracker.ietf.org/doc/html/rfc6241#appendix-E
73#[derive(Debug)]
74pub struct Session<T: Transport> {
75    transport_tx: Arc<Mutex<T::SendHandle>>,
76    transport_rx: Arc<Mutex<T::RecvHandle>>,
77    context: Context,
78    last_message_id: rpc::MessageId,
79    requests: Arc<Mutex<HashMap<rpc::MessageId, OutstandingRequest>>>,
80}
81
82/// NETCONF session state container.
83#[derive(Debug)]
84pub struct Context {
85    session_id: SessionId,
86    protocol_version: Base,
87    client_capabilities: Capabilities,
88    server_capabilities: Capabilities,
89}
90
91impl Context {
92    const fn new(
93        session_id: SessionId,
94        protocol_version: Base,
95        client_capabilities: Capabilities,
96        server_capabilities: Capabilities,
97    ) -> Self {
98        Self {
99            session_id,
100            protocol_version,
101            client_capabilities,
102            server_capabilities,
103        }
104    }
105
106    /// The NETCONF `session-id` of the current session.
107    #[must_use]
108    pub const fn session_id(&self) -> SessionId {
109        self.session_id
110    }
111
112    /// The base NETCONF protocol version negotiated on the current session.
113    #[must_use]
114    pub const fn protocol_version(&self) -> Base {
115        self.protocol_version
116    }
117
118    /// The set of NETCONF capabilities advertised by the client during `<hello>` message exchange.
119    #[must_use]
120    pub const fn client_capabilities(&self) -> &Capabilities {
121        &self.client_capabilities
122    }
123
124    /// The set of NETCONF capabilities advertised by the client during `<hello>` message exchange.
125    #[must_use]
126    pub const fn server_capabilities(&self) -> &Capabilities {
127        &self.server_capabilities
128    }
129}
130
131#[derive(Debug)]
132enum OutstandingRequest {
133    Pending,
134    Ready(rpc::PartialReply),
135    Complete,
136}
137
138impl OutstandingRequest {
139    #[tracing::instrument(level = "trace")]
140    fn take(&mut self) -> Result<Option<rpc::PartialReply>, Error> {
141        match mem::replace(self, Self::Complete) {
142            mut pending @ Self::Pending => {
143                mem::swap(self, &mut pending);
144                Ok(None)
145            }
146            Self::Complete => Err(Error::RequestComplete),
147            Self::Ready(reply) => Ok(Some(reply)),
148        }
149    }
150}
151
#[cfg(feature = "ssh")]
impl Session<Ssh> {
    /// Establish a new NETCONF session over an SSH transport.
    ///
    /// Connects to `addr` and authenticates as `username` with `password`, then performs the
    /// NETCONF `<hello>` exchange.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] if the SSH transport cannot be established or if session setup
    /// fails.
    // `password` is excluded from the span fields so that the credential is never recorded by
    // a tracing subscriber, in line with how `tls()` skips its key material.
    // NOTE(review): whether `Password`'s `Debug` output is already redacted is not visible
    // from this file - confirm.
    #[tracing::instrument(skip(password), level = "debug")]
    pub async fn ssh<A>(addr: A, username: String, password: Password) -> Result<Self, Error>
    where
        A: ToSocketAddrs + Send + Debug,
    {
        tracing::info!("starting ssh transport");
        let transport = Ssh::connect(addr, username, password).await?;
        Self::new(transport).await
    }
}
165
#[cfg(feature = "tls")]
impl Session<Tls> {
    /// Establish a new NETCONF session over a TLS transport.
    ///
    /// The certificate and key arguments are excluded from the tracing span.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] if the TLS transport cannot be established or if session setup
    /// fails.
    #[tracing::instrument(level = "debug", skip(ca_cert, client_cert, client_key))]
    pub async fn tls<A, S>(
        addr: A,
        server_name: S,
        ca_cert: CertificateDer<'_>,
        client_cert: CertificateDer<'static>,
        client_key: PrivateKeyDer<'static>,
    ) -> Result<Self, Error>
    where
        A: ToSocketAddrs + Debug + Send,
        S: TryInto<ServerName<'static>> + Debug + Send,
        Error: From<S::Error>,
    {
        tracing::info!("starting tls transport");
        Self::new(Tls::connect(addr, server_name, ca_cert, client_cert, client_key).await?).await
    }
}
187
#[cfg(feature = "junos")]
impl Session<JunosLocal> {
    /// Establish a new NETCONF session via the local Junos `cli` binary.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] if the local transport cannot be started or if session setup
    /// fails.
    #[tracing::instrument(level = "debug")]
    pub async fn junos_local() -> Result<Self, Error> {
        tracing::info!("starting local junos transport");
        Self::new(JunosLocal::connect().await?).await
    }
}
198
199impl<T: Transport> Session<T> {
200    #[tracing::instrument(skip(transport), level = "trace")]
201    async fn new(transport: T) -> Result<Self, Error> {
202        let client_hello = ClientHello::default();
203        let (mut tx, mut rx) = transport.split();
204        let ((), server_hello) =
205            tokio::try_join!(client_hello.send(&mut tx), ServerHello::recv(&mut rx))?;
206        let transport_tx = Arc::new(Mutex::new(tx));
207        let transport_rx = Arc::new(Mutex::new(rx));
208        let session_id = server_hello.session_id();
209        let server_capabilities = server_hello.capabilities();
210        let client_capabilities = client_hello.capabilities();
211        let protocol_version = client_capabilities.highest_common_version(&server_capabilities)?;
212        let context = Context::new(
213            session_id,
214            protocol_version,
215            client_capabilities,
216            server_capabilities,
217        );
218        let requests = Arc::new(Mutex::new(HashMap::default()));
219        Ok(Self {
220            transport_tx,
221            transport_rx,
222            context,
223            requests,
224            last_message_id: rpc::MessageId::default(),
225        })
226    }
227
228    /// Get the session state [`Context`] of this session.
229    #[must_use]
230    pub const fn context(&self) -> &Context {
231        &self.context
232    }
233
234    /// Execute a NETCONF RPC operation on the current session.
235    ///
236    /// See the [`rpc::operation`] module for available operations and their request builder APIs.
237    ///
238    /// RPC requests are built and validated against the [`Context`] of the current session - in
239    /// particular, against the list of capabilities advertised by the NETCONF server in the
240    /// `<hello>` message exchange.
241    ///
242    /// The `build_fn` closure must accept an instance of the operation request
243    /// [`Builder`][rpc::Operation::Builder], configure the builder, and then convert it to a
244    /// validated request by calling [`Builder::finish()`][rpc::operation::Builder::finish].
245    ///
246    /// This method returns a nested [`Future`], reflecting the fact that the request is sent to
247    /// the NETCONF server asynchronously and then the response is later received asynchronously.
248    ///
249    /// The `Output` of both the outer and inner `Future` are of type `Result`.
250    ///
251    /// An [`Err`] variant returned by awaiting the outer future indicates either a request validation
252    /// error or a session/transport error encountered while sending the RPC request.
253    ///
254    /// An [`Err`] variant returned by awaiting the inner future indicates either a
255    /// session/transport error while receiving the `<rpc-reply>` message, an error parsing the
256    /// received XML, or one-or-more application layer errors returned by the NETCONF server. The
257    /// latter case may be identified by matching on the [`Error::RpcError`] variant.
258    #[tracing::instrument(skip(self, build_fn), level = "debug")]
259    pub async fn rpc<O, F>(
260        &mut self,
261        build_fn: F,
262    ) -> Result<impl Future<Output = Result<<O::Reply as IntoResult>::Ok, Error>>, Error>
263    where
264        O: rpc::Operation,
265        F: FnOnce(O::Builder<'_>) -> Result<O, Error> + Send,
266    {
267        let message_id = self.last_message_id.increment();
268        let request = O::new(&self.context, build_fn)
269            .map(|operation| rpc::Request::new(message_id, operation))?;
270        #[allow(clippy::significant_drop_in_scrutinee)]
271        match self.requests.lock().await.entry(message_id) {
272            Entry::Occupied(_) => return Err(Error::MessageIdCollision { message_id }),
273            Entry::Vacant(entry) => {
274                request.send(&mut *self.transport_tx.lock().await).await?;
275                _ = entry.insert(OutstandingRequest::Pending);
276            }
277        };
278        let requests = self.requests.clone();
279        let rx = self.transport_rx.clone();
280        Ok(Self::recv::<O>(message_id, requests, rx))
281    }
282
283    #[tracing::instrument(skip(requests, rx), level = "debug")]
284    async fn recv<O>(
285        message_id: rpc::MessageId,
286        requests: Arc<Mutex<HashMap<rpc::MessageId, OutstandingRequest>>>,
287        rx: Arc<Mutex<<T as Transport>::RecvHandle>>,
288    ) -> Result<<O::Reply as IntoResult>::Ok, Error>
289    where
290        O: rpc::Operation,
291    {
292        // TODO:
293        // Try using a background task to read from the transport, and then just check that task
294        // and `take()` from requests in a `select!` here.
295        loop {
296            let mut rx_guard = rx.lock().await;
297            tracing::trace!(?requests);
298            tracing::debug!("checking for ready response");
299            if let Some(partial) = requests
300                .lock()
301                .await
302                .get_mut(&message_id)
303                .ok_or(Error::RequestNotFound { message_id })?
304                .take()?
305            {
306                tracing::debug!("found ready response");
307                let reply: rpc::Reply<O> = partial.try_into()?;
308                break reply.into_result();
309            };
310            tracing::debug!("response to {message_id:?} not yet ready");
311            let reply = rpc::PartialReply::recv(&mut *rx_guard).await?;
312            #[allow(clippy::significant_drop_in_scrutinee)]
313            match requests
314                .lock()
315                .await
316                .get_mut(&reply.message_id())
317                .ok_or_else(|| Error::RequestNotFound {
318                    message_id: reply.message_id(),
319                })? {
320                OutstandingRequest::Complete => break Err(Error::RequestComplete),
321                OutstandingRequest::Ready(_) => {
322                    break Err(Error::MessageIdCollision {
323                        message_id: reply.message_id(),
324                    })
325                }
326                pending @ OutstandingRequest::Pending => {
327                    tracing::debug!("storing response to {:?}", reply.message_id());
328                    _ = mem::replace(pending, OutstandingRequest::Ready(reply));
329                }
330            };
331            drop(rx_guard);
332        }
333    }
334
335    /// Close the NETCONF session gracefully using the `<close-session>` RPC operation.
336    #[tracing::instrument(skip(self), level = "debug")]
337    pub async fn close(mut self) -> Result<impl Future<Output = Result<(), Error>>, Error> {
338        self.rpc::<CloseSession, _>(Builder::finish)
339            .await
340            .map(|fut| async move { fut.await.map(|()| drop(self)) })
341    }
342}