// netconf_rust/session.rs
1use std::collections::HashMap;
2use std::sync::atomic::{AtomicU8, Ordering};
3use std::sync::{Arc, Mutex};
4
5use bytes::{Bytes, BytesMut};
6use log::{debug, trace, warn};
7use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
8use tokio_stream::StreamExt;
9use tokio_util::codec::{Encoder, FramedRead};
10
11use crate::codec::{CodecConfig, FramingMode, NetconfCodec};
12use crate::hello::ServerHello;
13use crate::message::{self, DataPayload, RpcReply, RpcReplyBody, ServerMessage};
14use crate::stream::NetconfStream;
15
/// Lifecycle state of a NETCONF session.
///
/// Represented as a `u8` (`#[repr(u8)]`) so it can be stored in the
/// `AtomicU8` inside `SessionInner` and read lock-free from both the
/// session and the background reader task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum SessionState {
    /// Hello exchange complete, ready for RPCs
    Ready = 0,
    /// A 'close-session' RPC has been sent, awaiting reply
    Closing = 1,
    /// Session terminated gracefully or with error
    Closed = 2,
}
26
27impl SessionState {
28    fn from_u8(v: u8) -> Self {
29        match v {
30            0 => Self::Ready,
31            1 => Self::Closing,
32            _ => Self::Closed,
33        }
34    }
35}
36
37impl std::fmt::Display for SessionState {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        match self {
40            Self::Ready => write!(f, "Ready"),
41            Self::Closing => write!(f, "Closing"),
42            Self::Closed => write!(f, "Closed"),
43        }
44    }
45}
46
/// Configuration for a NETCONF session.
///
/// Currently only wraps the codec settings; new tunables belong here
/// rather than as extra parameters on the `connect_*` functions.
#[derive(Debug, Clone, Default)]
pub struct SessionConfig {
    /// Codec configuration (message size limits, etc.).
    pub codec: CodecConfig,
}
53
/// The standard NETCONF configuration datastores (see RFC 6241).
#[derive(Debug, Clone, Copy)]
pub enum Datastore {
    /// The currently active configuration (`<running/>`).
    Running,
    /// The working-copy datastore used with commit (`<candidate/>`).
    Candidate,
    /// The configuration loaded at device startup (`<startup/>`).
    Startup,
}
60
61impl Datastore {
62    fn as_xml(&self) -> &'static str {
63        match self {
64            Datastore::Running => "<running/>",
65            Datastore::Candidate => "<candidate/>",
66            Datastore::Startup => "<startup/>",
67        }
68    }
69}
70
/// A handle to a pending RPC reply.
///
/// Created by [`Session::rpc_send()`], this allows sending multiple RPCs
/// before awaiting any replies (pipelining).
pub struct RpcFuture {
    // Receiving end of the oneshot registered in the pending map; the
    // background reader task sends the reply (or an error) through it.
    rx: tokio::sync::oneshot::Receiver<crate::Result<RpcReply>>,
    // message-id assigned to this RPC when it was built.
    msg_id: u32,
}
79
80impl RpcFuture {
81    /// The message-id of this RPC.
82    pub fn message_id(&self) -> u32 {
83        self.msg_id
84    }
85
86    /// Await the RPC reply.
87    pub async fn response(self) -> crate::Result<RpcReply> {
88        self.rx.await.map_err(|_| crate::Error::SessionClosed)?
89    }
90}
91
/// State shared (via `Arc`) between the `Session` writer side and the
/// background reader task.
struct SessionInner {
    /// Pending RPC replies: message-id → oneshot sender.
    pending: Mutex<HashMap<u32, tokio::sync::oneshot::Sender<crate::Result<RpcReply>>>>,
    /// Session lifecycle state, stored as AtomicU8 for lock-free access
    /// across the session and reader task.
    ///
    /// State transitions:
    ///   Ready → Closing → Closed   (graceful: user calls close_session)
    ///   Ready → Closed             (abrupt: reader hits EOF or error)
    ///
    /// - Ready:   normal operation, RPCs can be sent
    /// - Closing: close-session RPC in flight, new RPCs rejected
    /// - Closed:  session terminated, all operations fail
    state: AtomicU8,
}
107
108impl SessionInner {
109    fn state(&self) -> SessionState {
110        SessionState::from_u8(self.state.load(Ordering::Acquire))
111    }
112
113    fn set_state(&self, state: SessionState) {
114        self.state.store(state as u8, Ordering::Release);
115    }
116}
117
/// A NETCONF session with pipelining support.
///
/// Architecture:
/// - **Writer**: the session holds the write half of the stream and a codec
///   for encoding outgoing RPCs.
/// - **Reader task**: a background tokio task reads framed messages from
///   the read half, classifies them, and routes each `<rpc-reply>` to the
///   pending RPC with the matching message-id.
/// - **Pipelining**: [`rpc_send()`](Session::rpc_send) writes an RPC and
///   returns an [`RpcFuture`] without waiting for the reply. Multiple RPCs
///   can be in flight simultaneously.
pub struct Session {
    /// Write half of the split stream, used to send RPCs to the server.
    writer: WriteHalf<NetconfStream>,

    /// Codec used to frame outgoing messages (EOM or chunked).
    /// Kept separate from FramedRead because the write half isn't wrapped
    /// in FramedWrite — we encode manually and write_all to the writer.
    write_codec: NetconfCodec,

    /// Shared state between this session and the background reader task.
    /// Contains the pending RPC map and the session lifecycle state.
    inner: Arc<SessionInner>,

    /// The server's hello response, containing its capabilities and session ID.
    server_hello: ServerHello,

    /// Negotiated framing mode (EOM for 1.0-only servers, chunked for 1.1).
    framing: FramingMode,

    /// Holds the SSH handle alive. Dropping this tears down the SSH connection,
    /// which would invalidate the stream. Never accessed, just kept alive.
    _keep_alive: Option<Box<dyn std::any::Any + Send>>,

    /// Handle to the background reader task. Aborted on drop to ensure
    /// the task doesn't outlive the session.
    _reader_handle: tokio::task::JoinHandle<()>,
}
156
impl Drop for Session {
    fn drop(&mut self) {
        // Stop the background reader task; otherwise it would keep the
        // Arc<SessionInner> (and the read half of the stream) alive after
        // the session itself is gone.
        self._reader_handle.abort();
    }
}
162
163impl Session {
164    /// Connect to a NETCONF server over SSH with password authentication.
165    pub async fn connect(
166        host: &str,
167        port: u16,
168        username: &str,
169        password: &str,
170    ) -> crate::Result<Self> {
171        Self::connect_with_config(host, port, username, password, SessionConfig::default()).await
172    }
173
174    /// Connect with custom configuration.
175    pub async fn connect_with_config(
176        host: &str,
177        port: u16,
178        username: &str,
179        password: &str,
180        config: SessionConfig,
181    ) -> crate::Result<Self> {
182        let (mut stream, keep_alive) =
183            crate::transport::connect(host, port, username, password).await?;
184        let (server_hello, framing) =
185            crate::hello::exchange(&mut stream, config.codec.max_message_size).await?;
186        Self::build(stream, Some(keep_alive), server_hello, framing, config)
187    }
188
189    /// Create a session from an existing stream (useful for testing).
190    pub async fn from_stream<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
191        stream: S,
192    ) -> crate::Result<Self> {
193        Self::from_stream_with_config(stream, SessionConfig::default()).await
194    }
195
196    /// Create a session from an existing stream with custom configuration.
197    pub async fn from_stream_with_config<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
198        mut stream: S,
199        config: SessionConfig,
200    ) -> crate::Result<Self> {
201        let (server_hello, framing) =
202            crate::hello::exchange(&mut stream, config.codec.max_message_size).await?;
203        let boxed: NetconfStream = Box::new(stream);
204        Self::build(boxed, None, server_hello, framing, config)
205    }
206
207    fn build(
208        stream: NetconfStream,
209        keep_alive: Option<Box<dyn std::any::Any + Send>>,
210        server_hello: ServerHello,
211        framing: FramingMode,
212        config: SessionConfig,
213    ) -> crate::Result<Self> {
214        debug!(
215            "session {}: building (framing={:?}, capabilities={})",
216            server_hello.session_id,
217            framing,
218            server_hello.capabilities.len()
219        );
220        let (read_half, write_half) = tokio::io::split(stream);
221
222        let read_codec = NetconfCodec::new(framing, config.codec);
223        let write_codec = NetconfCodec::new(framing, config.codec);
224        let reader = FramedRead::new(read_half, read_codec);
225
226        let inner = Arc::new(SessionInner {
227            pending: Mutex::new(HashMap::new()),
228            state: AtomicU8::new(SessionState::Ready as u8),
229        });
230
231        let reader_inner = Arc::clone(&inner);
232        let session_id = server_hello.session_id;
233        let reader_handle = tokio::spawn(async move {
234            reader_loop(reader, reader_inner, session_id).await;
235        });
236
237        Ok(Self {
238            writer: write_half,
239            write_codec,
240            inner,
241            server_hello,
242            framing,
243            _keep_alive: keep_alive,
244            _reader_handle: reader_handle,
245        })
246    }
247
248    pub fn session_id(&self) -> u32 {
249        self.server_hello.session_id
250    }
251
252    pub fn server_capabilities(&self) -> &[String] {
253        &self.server_hello.capabilities
254    }
255
256    pub fn framing_mode(&self) -> FramingMode {
257        self.framing
258    }
259
260    pub fn state(&self) -> SessionState {
261        self.inner.state()
262    }
263
264    fn check_state(&self) -> crate::Result<()> {
265        let state = self.inner.state();
266        if state != SessionState::Ready {
267            return Err(crate::Error::InvalidState(state.to_string()));
268        }
269        Ok(())
270    }
271
272    /// Encode a message with the negotiated framing (EOM or chunked) and
273    /// write it to the stream. We encode manually rather than using
274    /// FramedWrite because NETCONF requires complete XML documents —
275    /// the server doesn't process anything until the message delimiter
276    /// arrives. A simple encode + write_all + flush is sufficient since
277    /// we always write whole messages. The kernel's TCP send buffer
278    /// handles chunking large writes over the wire, and the single
279    /// Session owner means there's no need for Sink-based backpressure.
280    async fn send_encoded(&mut self, xml: &str) -> crate::Result<()> {
281        let mut buf = BytesMut::new();
282        self.write_codec
283            .encode(Bytes::from(xml.to_string()), &mut buf)?;
284        trace!(
285            "session {}: writing {} bytes to stream",
286            self.server_hello.session_id,
287            buf.len()
288        );
289        self.writer.write_all(&buf).await?;
290        self.writer.flush().await?;
291        Ok(())
292    }
293
294    /// Send a raw RPC and return a future for the reply (pipelining).
295    ///
296    /// This writes the RPC to the server immediately but does not wait
297    /// for the reply. Call [`RpcFuture::response()`] to await the reply.
298    /// Multiple RPCs can be pipelined by calling this repeatedly before
299    /// awaiting any of them.
300    pub async fn rpc_send(&mut self, inner_xml: &str) -> crate::Result<RpcFuture> {
301        self.check_state()?;
302        let (msg_id, xml) = message::build_rpc(inner_xml);
303        debug!(
304            "session {}: sending rpc message-id={} ({} bytes)",
305            self.server_hello.session_id,
306            msg_id,
307            xml.len()
308        );
309        trace!(
310            "session {}: rpc content: {}",
311            self.server_hello.session_id, inner_xml
312        );
313        let (tx, rx) = tokio::sync::oneshot::channel();
314
315        self.inner.pending.lock().unwrap().insert(msg_id, tx);
316
317        if let Err(e) = self.send_encoded(&xml).await {
318            debug!(
319                "session {}: send failed for message-id={}: {}",
320                self.server_hello.session_id, msg_id, e
321            );
322            self.inner.pending.lock().unwrap().remove(&msg_id);
323            return Err(e);
324        }
325        Ok(RpcFuture { rx, msg_id })
326    }
327
328    /// Send a raw RPC and wait for the reply.
329    pub async fn rpc_raw(&mut self, inner_xml: &str) -> crate::Result<RpcReply> {
330        let future = self.rpc_send(inner_xml).await?;
331        future.response().await
332    }
333
334    /// Internal rpc send that skips state check. Only used for sending close-session.
335    async fn rpc_send_unchecked(&mut self, inner_xml: &str) -> crate::Result<RpcFuture> {
336        let (msg_id, xml) = message::build_rpc(inner_xml);
337        let (tx, rx) = tokio::sync::oneshot::channel();
338
339        self.inner.pending.lock().unwrap().insert(msg_id, tx);
340
341        if let Err(e) = self.send_encoded(&xml).await {
342            self.inner.pending.lock().unwrap().remove(&msg_id);
343            return Err(e);
344        }
345
346        Ok(RpcFuture { rx, msg_id })
347    }
348
349    /// Retrieve configuration from a datastore.
350    pub async fn get_config(
351        &mut self,
352        source: Datastore,
353        filter: Option<&str>,
354    ) -> crate::Result<String> {
355        let filter_xml = match filter {
356            Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
357            None => String::new(),
358        };
359        let inner = format!(
360            "<get-config><source>{}</source>{filter_xml}</get-config>",
361            source.as_xml()
362        );
363        let reply = self.rpc_raw(&inner).await?;
364        reply_to_data(reply)
365    }
366
367    /// Retrieve configuration as a zero-copy `DataPayload`.
368    ///
369    /// Same as `get_config()` but returns a `DataPayload` instead of `String`,
370    /// avoiding a copy of the response body. Use `payload.as_str()` for a
371    /// zero-copy `&str` view, or `payload.reader()` for streaming XML events.
372    pub async fn get_config_payload(
373        &mut self,
374        source: Datastore,
375        filter: Option<&str>,
376    ) -> crate::Result<DataPayload> {
377        let filter_xml = match filter {
378            Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
379            None => String::new(),
380        };
381        let inner = format!(
382            "<get-config><source>{}</source>{filter_xml}</get-config>",
383            source.as_xml()
384        );
385        let reply = self.rpc_raw(&inner).await?;
386        reply.into_data()
387    }
388
389    /// Retrieve running configuration and state data.
390    pub async fn get(&mut self, filter: Option<&str>) -> crate::Result<String> {
391        let filter_xml = match filter {
392            Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
393            None => String::new(),
394        };
395        let inner = format!("<get>{filter_xml}</get>");
396        let reply = self.rpc_raw(&inner).await?;
397        reply_to_data(reply)
398    }
399
400    /// Retrieve running configuration and state data as a zero-copy `DataPayload`.
401    ///
402    /// Same as `get()` but returns a `DataPayload` instead of `String`,
403    /// avoiding a copy of the response body.
404    pub async fn get_payload(&mut self, filter: Option<&str>) -> crate::Result<DataPayload> {
405        let filter_xml = match filter {
406            Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
407            None => String::new(),
408        };
409        let inner = format!("<get>{filter_xml}</get>");
410        let reply = self.rpc_raw(&inner).await?;
411        reply.into_data()
412    }
413
414    /// Edit the configuration of a target datastore.
415    pub async fn edit_config(&mut self, target: Datastore, config: &str) -> crate::Result<()> {
416        let inner = format!(
417            "<edit-config><target>{}</target><config>{config}</config></edit-config>",
418            target.as_xml()
419        );
420        let reply = self.rpc_raw(&inner).await?;
421        reply_to_ok(reply)
422    }
423
424    /// Lock a datastore
425    pub async fn lock(&mut self, target: Datastore) -> crate::Result<()> {
426        let inner = format!("<lock><target>{}</target></lock>", target.as_xml());
427        let reply = self.rpc_raw(&inner).await?;
428        reply_to_ok(reply)
429    }
430
431    /// Unlock a datastore.
432    pub async fn unlock(&mut self, target: Datastore) -> crate::Result<()> {
433        let inner = format!("<unlock><target>{}</target></unlock>", target.as_xml());
434        let reply = self.rpc_raw(&inner).await?;
435        reply_to_ok(reply)
436    }
437
438    /// Commit the candidate configuration to running.
439    pub async fn commit(&mut self) -> crate::Result<()> {
440        let reply = self.rpc_raw("<commit/>").await?;
441        reply_to_ok(reply)
442    }
443
444    /// Gracefully close the NETCONF session.
445    pub async fn close_session(&mut self) -> crate::Result<()> {
446        self.check_state()?;
447        debug!("session {}: closing", self.server_hello.session_id);
448        self.inner.set_state(SessionState::Closing);
449        let result = self.rpc_send_unchecked("<close-session/>").await;
450        match result {
451            Ok(future) => {
452                let reply = future.response().await;
453                self.inner.set_state(SessionState::Closed);
454                debug!(
455                    "session {}: closed gracefully",
456                    self.server_hello.session_id
457                );
458                reply_to_ok(reply?)
459            }
460            Err(e) => {
461                self.inner.set_state(SessionState::Closed);
462                debug!(
463                    "session {}: close failed: {}",
464                    self.server_hello.session_id, e
465                );
466                Err(e)
467            }
468        }
469    }
470
471    /// Force-close another NETCONF session.
472    pub async fn kill_session(&mut self, session_id: u32) -> crate::Result<()> {
473        let inner = format!("<kill-session><session-id>{session_id}</session-id></kill-session>");
474        let reply = self.rpc_raw(&inner).await?;
475        reply_to_ok(reply)
476    }
477}
478
479/// This loop is the only thing reading from the SSH stream. It runs in a background tokio task.
480/// The session's main API (the writer side) never reads — it only writes RPCs and waits on oneshot channels.
481/// This separation is what makes pipelining work.
482async fn reader_loop(
483    mut reader: FramedRead<ReadHalf<NetconfStream>, NetconfCodec>,
484    inner: Arc<SessionInner>,
485    session_id: u32,
486) {
487    debug!("session {}: reader loop started", session_id);
488    // FramedRead wraps the ReadHalf with the NetconfCodec decoder.
489    // Each .next() reads bytes from SSH, feeds them to decode(), and yields a
490    // complete framed message as Bytes.
491    while let Some(result) = reader.next().await {
492        match result {
493            // Takes the raw Bytes, validates UTF-8, finds the root element, parses it.
494            Ok(bytes) => {
495                trace!(
496                    "session {}: received frame ({} bytes)",
497                    session_id,
498                    bytes.len()
499                );
500                match message::classify_message(bytes) {
501                    Ok(ServerMessage::RpcReply(reply)) => {
502                        debug!(
503                            "session {}: received rpc-reply message-id={}",
504                            session_id, reply.message_id
505                        );
506                        // When rpc_send sends an RPC, it inserts a oneshot sender into the pending map
507                        // keyed by message-id. Here, it is removed and used to send the reply
508                        // through the channel. The caller who's .awaiting the RpcFuture receives it.
509                        let tx = {
510                            let mut pending = inner.pending.lock().unwrap();
511                            pending.remove(&reply.message_id)
512                        };
513                        if let Some(tx) = tx {
514                            // we can ignore if the receiver was dropped. Nothing to do about it.
515                            let _ = tx.send(Ok(reply));
516                        } else {
517                            warn!(
518                                "session {}: received reply for unknown message-id {}",
519                                session_id, reply.message_id
520                            );
521                        }
522                    }
523                    Err(e) => {
524                        warn!("session {}: failed to classify message: {e}", session_id);
525                    }
526                }
527            }
528            // The stream broke (EOF or Error). Every pending RPC would hang forever waiting for a reply that's never coming.
529            // So we drain the map and send SessionClosed to each one, unblocking all waiters.
530            Err(e) => {
531                debug!("session {}: reader error: {e}", session_id);
532                let mut pending = inner.pending.lock().unwrap();
533                let drained = pending.len();
534                for (_, tx) in pending.drain() {
535                    let _ = tx.send(Err(crate::Error::SessionClosed));
536                }
537                if drained > 0 {
538                    debug!(
539                        "session {}: drained {} pending RPCs after error",
540                        session_id, drained
541                    );
542                }
543                break;
544            }
545        }
546    }
547    // Drain any remaining pending RPCs — this handles clean EOF where the while loop exits
548    // via None (stream closed) rather than Err (which drains inline above).
549    {
550        let mut pending = inner.pending.lock().unwrap();
551        let drained = pending.len();
552        for (_, tx) in pending.drain() {
553            let _ = tx.send(Err(crate::Error::SessionClosed));
554        }
555        if drained > 0 {
556            debug!(
557                "session {}: drained {} pending RPCs on stream close",
558                session_id, drained
559            );
560        }
561    }
562    // Mark the session as closed so future rpc_send calls fail immediately with InvalidState
563    // instead of inserting into the pending map and hanging.
564    inner.set_state(SessionState::Closed);
565    debug!("session {}: reader loop ended", session_id);
566}
567
568fn reply_to_data(reply: RpcReply) -> crate::Result<String> {
569    match reply.body {
570        RpcReplyBody::Data(payload) => Ok(payload.into_string()),
571        RpcReplyBody::Ok => Ok(String::new()),
572        RpcReplyBody::Error(errors) => Err(crate::Error::Rpc {
573            message_id: reply.message_id,
574            error: errors
575                .first()
576                .map(|e| e.error_message.clone())
577                .unwrap_or_default(),
578        }),
579    }
580}
581
582fn reply_to_ok(reply: RpcReply) -> crate::Result<()> {
583    match reply.body {
584        RpcReplyBody::Ok => Ok(()),
585        RpcReplyBody::Data(_) => Ok(()),
586        RpcReplyBody::Error(errors) => Err(crate::Error::Rpc {
587            message_id: reply.message_id,
588            error: errors
589                .first()
590                .map(|e| e.error_message.clone())
591                .unwrap_or_default(),
592        }),
593    }
594}