//! NETCONF session management (`netconf_rust/session.rs`).
//!
//! Provides the [`Session`] type: a pipelined NETCONF client session with a
//! background reader task and oneshot-based reply routing.
1use std::collections::HashMap;
2use std::sync::atomic::{AtomicU8, Ordering};
3use std::sync::{Arc, Mutex};
4
5use bytes::{Bytes, BytesMut};
6use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
7use tokio_stream::StreamExt;
8use tokio_util::codec::{Encoder, FramedRead};
9
10use crate::codec::{CodecConfig, FramingMode, NetconfCodec};
11use crate::hello::ServerHello;
12use crate::message::{self, DataPayload, RpcReply, RpcReplyBody, ServerMessage};
13use crate::stream::NetconfStream;
14
/// Lifecycle state of a [`Session`].
///
/// Stored as a raw `u8` inside an atomic (see `SessionInner::state`), hence
/// the explicit `#[repr(u8)]` discriminants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum SessionState {
    /// Hello exchange complete, ready for RPCs
    Ready = 0,
    /// A 'close-session' RPC has been sent, awaiting reply
    Closing = 1,
    /// Session terminated gracefully or with error
    Closed = 2,
}
25
26impl SessionState {
27    fn from_u8(v: u8) -> Self {
28        match v {
29            0 => Self::Ready,
30            1 => Self::Closing,
31            _ => Self::Closed,
32        }
33    }
34}
35
36impl std::fmt::Display for SessionState {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        match self {
39            Self::Ready => write!(f, "Ready"),
40            Self::Closing => write!(f, "Closing"),
41            Self::Closed => write!(f, "Closed"),
42        }
43    }
44}
45
/// Configuration for a NETCONF session.
///
/// `Default` yields the codec's default limits; customize via struct update
/// syntax when needed.
#[derive(Debug, Clone, Default)]
pub struct SessionConfig {
    /// Codec configuration (message size limits, etc.).
    pub codec: CodecConfig,
}
52
/// The NETCONF datastores a session can address.
#[derive(Debug, Clone, Copy)]
pub enum Datastore {
    /// The currently active configuration.
    Running,
    /// Scratch datastore edited first, then committed to running.
    Candidate,
    /// Configuration loaded at device boot.
    Startup,
}
59
60impl Datastore {
61    fn as_xml(&self) -> &'static str {
62        match self {
63            Datastore::Running => "<running/>",
64            Datastore::Candidate => "<candidate/>",
65            Datastore::Startup => "<startup/>",
66        }
67    }
68}
69
/// A handle to a pending RPC reply.
///
/// Created by [`Session::rpc_send()`], this allows sending multiple RPCs
/// before awaiting any replies (pipelining).
pub struct RpcFuture {
    // Receives the reply (or SessionClosed) from the background reader task.
    rx: tokio::sync::oneshot::Receiver<crate::Result<RpcReply>>,
    // The message-id allocated for this RPC when it was sent.
    msg_id: u32,
}
78
79impl RpcFuture {
80    /// The message-id of this RPC.
81    pub fn message_id(&self) -> u32 {
82        self.msg_id
83    }
84
85    /// Await the RPC reply.
86    pub async fn response(self) -> crate::Result<RpcReply> {
87        self.rx.await.map_err(|_| crate::Error::SessionClosed)?
88    }
89}
90
/// State shared (via `Arc`) between the `Session` writer side and the
/// background reader task.
struct SessionInner {
    /// Pending RPC replies: message-id → oneshot sender.
    pending: Mutex<HashMap<u32, tokio::sync::oneshot::Sender<crate::Result<RpcReply>>>>,
    /// Session lifecycle state, stored as AtomicU8 for lock-free access
    /// across the session and reader task.
    ///
    /// State transitions:
    ///   Ready → Closing → Closed   (graceful: user calls close_session)
    ///   Ready → Closed             (abrupt: reader hits EOF or error)
    ///
    /// - Ready:   normal operation, RPCs can be sent
    /// - Closing: close-session RPC in flight, new RPCs rejected
    /// - Closed:  session terminated, all operations fail
    state: AtomicU8,
}
106
107impl SessionInner {
108    fn state(&self) -> SessionState {
109        SessionState::from_u8(self.state.load(Ordering::Acquire))
110    }
111
112    fn set_state(&self, state: SessionState) {
113        self.state.store(state as u8, Ordering::Release);
114    }
115}
116
/// A NETCONF session with pipelining support.
///
/// Architecture:
/// - **Writer**: the session holds the write half of the stream and a codec
///   for encoding outgoing RPCs.
/// - **Reader task**: a background tokio task reads framed messages from
///   the read half, classifies them, and routes replies to the pending map.
/// - **Pipelining**: [`rpc_send()`](Session::rpc_send) writes an RPC and
///   returns an [`RpcFuture`] without waiting for the reply. Multiple RPCs
///   can be in flight simultaneously.
pub struct Session {
    /// Write half of the split stream, used to send RPCs to the server.
    writer: WriteHalf<NetconfStream>,

    /// Codec used to frame outgoing messages (EOM or chunked).
    /// Kept separate from FramedRead because the write half isn't wrapped
    /// in FramedWrite — we encode manually and write_all to the writer.
    write_codec: NetconfCodec,

    /// Shared state between this session and the background reader task.
    /// Contains the pending RPC map and the session lifecycle state.
    inner: Arc<SessionInner>,

    /// The server's hello response, containing its capabilities and session ID.
    server_hello: ServerHello,

    /// Negotiated framing mode (EOM for 1.0-only servers, chunked for 1.1).
    framing: FramingMode,

    /// Holds the SSH handle alive. Dropping this tears down the SSH connection,
    /// which would invalidate the stream. Never accessed, just kept alive.
    _keep_alive: Option<Box<dyn std::any::Any + Send>>,

    /// Handle to the background reader task. Aborted on drop to ensure
    /// the task doesn't outlive the session.
    _reader_handle: tokio::task::JoinHandle<()>,
}
155
impl Drop for Session {
    fn drop(&mut self) {
        // Stop the background reader task; without this it would keep
        // polling the read half until the runtime shuts down.
        self._reader_handle.abort();
    }
}
161
162impl Session {
163    /// Connect to a NETCONF server over SSH with password authentication.
164    pub async fn connect(
165        host: &str,
166        port: u16,
167        username: &str,
168        password: &str,
169    ) -> crate::Result<Self> {
170        Self::connect_with_config(host, port, username, password, SessionConfig::default()).await
171    }
172
173    /// Connect with custom configuration.
174    pub async fn connect_with_config(
175        host: &str,
176        port: u16,
177        username: &str,
178        password: &str,
179        config: SessionConfig,
180    ) -> crate::Result<Self> {
181        let (mut stream, keep_alive) =
182            crate::transport::connect(host, port, username, password).await?;
183        let (server_hello, framing) =
184            crate::hello::exchange(&mut stream, config.codec.max_message_size).await?;
185        Self::build(stream, Some(keep_alive), server_hello, framing, config)
186    }
187
188    /// Create a session from an existing stream (useful for testing).
189    pub async fn from_stream<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
190        stream: S,
191    ) -> crate::Result<Self> {
192        Self::from_stream_with_config(stream, SessionConfig::default()).await
193    }
194
195    /// Create a session from an existing stream with custom configuration.
196    pub async fn from_stream_with_config<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
197        mut stream: S,
198        config: SessionConfig,
199    ) -> crate::Result<Self> {
200        let (server_hello, framing) =
201            crate::hello::exchange(&mut stream, config.codec.max_message_size).await?;
202        let boxed: NetconfStream = Box::new(stream);
203        Self::build(boxed, None, server_hello, framing, config)
204    }
205
206    fn build(
207        stream: NetconfStream,
208        keep_alive: Option<Box<dyn std::any::Any + Send>>,
209        server_hello: ServerHello,
210        framing: FramingMode,
211        config: SessionConfig,
212    ) -> crate::Result<Self> {
213        let (read_half, write_half) = tokio::io::split(stream);
214
215        let read_codec = NetconfCodec::new(framing, config.codec);
216        let write_codec = NetconfCodec::new(framing, config.codec);
217        let reader = FramedRead::new(read_half, read_codec);
218
219        let inner = Arc::new(SessionInner {
220            pending: Mutex::new(HashMap::new()),
221            state: AtomicU8::new(SessionState::Ready as u8),
222        });
223
224        let reader_inner = Arc::clone(&inner);
225        let reader_handle = tokio::spawn(reader_loop(reader, reader_inner));
226
227        Ok(Self {
228            writer: write_half,
229            write_codec,
230            inner,
231            server_hello,
232            framing,
233            _keep_alive: keep_alive,
234            _reader_handle: reader_handle,
235        })
236    }
237
238    pub fn session_id(&self) -> u32 {
239        self.server_hello.session_id
240    }
241
242    pub fn server_capabilities(&self) -> &[String] {
243        &self.server_hello.capabilities
244    }
245
246    pub fn framing_mode(&self) -> FramingMode {
247        self.framing
248    }
249
250    pub fn state(&self) -> SessionState {
251        self.inner.state()
252    }
253
254    fn check_state(&self) -> crate::Result<()> {
255        let state = self.inner.state();
256        if state != SessionState::Ready {
257            return Err(crate::Error::InvalidState(state.to_string()));
258        }
259        Ok(())
260    }
261
262    /// Encode a message with the negotiated framing (EOM or chunked) and
263    /// write it to the stream. We encode manually rather than using
264    /// FramedWrite because NETCONF requires complete XML documents —
265    /// the server doesn't process anything until the message delimiter
266    /// arrives. A simple encode + write_all + flush is sufficient since
267    /// we always write whole messages. The kernel's TCP send buffer
268    /// handles chunking large writes over the wire, and the single
269    /// Session owner means there's no need for Sink-based backpressure.
270    async fn send_encoded(&mut self, xml: &str) -> crate::Result<()> {
271        let mut buf = BytesMut::new();
272        self.write_codec
273            .encode(Bytes::from(xml.to_string()), &mut buf)?;
274        self.writer.write_all(&buf).await?;
275        self.writer.flush().await?;
276        Ok(())
277    }
278
279    /// Send a raw RPC and return a future for the reply (pipelining).
280    ///
281    /// This writes the RPC to the server immediately but does not wait
282    /// for the reply. Call [`RpcFuture::response()`] to await the reply.
283    /// Multiple RPCs can be pipelined by calling this repeatedly before
284    /// awaiting any of them.
285    pub async fn rpc_send(&mut self, inner_xml: &str) -> crate::Result<RpcFuture> {
286        self.check_state()?;
287        let (msg_id, xml) = message::build_rpc(inner_xml);
288        let (tx, rx) = tokio::sync::oneshot::channel();
289
290        self.inner.pending.lock().unwrap().insert(msg_id, tx);
291
292        if let Err(e) = self.send_encoded(&xml).await {
293            self.inner.pending.lock().unwrap().remove(&msg_id);
294            return Err(e);
295        }
296        Ok(RpcFuture { rx, msg_id })
297    }
298
299    /// Send a raw RPC and wait for the reply.
300    pub async fn rpc_raw(&mut self, inner_xml: &str) -> crate::Result<RpcReply> {
301        let future = self.rpc_send(inner_xml).await?;
302        future.response().await
303    }
304
305    /// Internal rpc send that skips state check. Only used for sending close-session.
306    async fn rpc_send_unchecked(&mut self, inner_xml: &str) -> crate::Result<RpcFuture> {
307        let (msg_id, xml) = message::build_rpc(inner_xml);
308        let (tx, rx) = tokio::sync::oneshot::channel();
309
310        self.inner.pending.lock().unwrap().insert(msg_id, tx);
311
312        if let Err(e) = self.send_encoded(&xml).await {
313            self.inner.pending.lock().unwrap().remove(&msg_id);
314            return Err(e);
315        }
316
317        Ok(RpcFuture { rx, msg_id })
318    }
319
320    /// Retrieve configuration from a datastore.
321    pub async fn get_config(
322        &mut self,
323        source: Datastore,
324        filter: Option<&str>,
325    ) -> crate::Result<String> {
326        let filter_xml = match filter {
327            Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
328            None => String::new(),
329        };
330        let inner = format!(
331            "<get-config><source>{}</source>{filter_xml}</get-config>",
332            source.as_xml()
333        );
334        let reply = self.rpc_raw(&inner).await?;
335        reply_to_data(reply)
336    }
337
338    /// Retrieve configuration as a zero-copy `DataPayload`.
339    ///
340    /// Same as `get_config()` but returns a `DataPayload` instead of `String`,
341    /// avoiding a copy of the response body. Use `payload.as_str()` for a
342    /// zero-copy `&str` view, or `payload.reader()` for streaming XML events.
343    pub async fn get_config_payload(
344        &mut self,
345        source: Datastore,
346        filter: Option<&str>,
347    ) -> crate::Result<DataPayload> {
348        let filter_xml = match filter {
349            Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
350            None => String::new(),
351        };
352        let inner = format!(
353            "<get-config><source>{}</source>{filter_xml}</get-config>",
354            source.as_xml()
355        );
356        let reply = self.rpc_raw(&inner).await?;
357        reply.into_data()
358    }
359
360    /// Retrieve running configuration and state data.
361    pub async fn get(&mut self, filter: Option<&str>) -> crate::Result<String> {
362        let filter_xml = match filter {
363            Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
364            None => String::new(),
365        };
366        let inner = format!("<get>{filter_xml}</get>");
367        let reply = self.rpc_raw(&inner).await?;
368        reply_to_data(reply)
369    }
370
371    /// Retrieve running configuration and state data as a zero-copy `DataPayload`.
372    ///
373    /// Same as `get()` but returns a `DataPayload` instead of `String`,
374    /// avoiding a copy of the response body.
375    pub async fn get_payload(&mut self, filter: Option<&str>) -> crate::Result<DataPayload> {
376        let filter_xml = match filter {
377            Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
378            None => String::new(),
379        };
380        let inner = format!("<get>{filter_xml}</get>");
381        let reply = self.rpc_raw(&inner).await?;
382        reply.into_data()
383    }
384
385    /// Edit the configuration of a target datastore.
386    pub async fn edit_config(&mut self, target: Datastore, config: &str) -> crate::Result<()> {
387        let inner = format!(
388            "<edit-config><target>{}</target><config>{config}</config></edit-config>",
389            target.as_xml()
390        );
391        let reply = self.rpc_raw(&inner).await?;
392        reply_to_ok(reply)
393    }
394
395    /// Lock a datastore
396    pub async fn lock(&mut self, target: Datastore) -> crate::Result<()> {
397        let inner = format!("<lock><target>{}</target></lock>", target.as_xml());
398        let reply = self.rpc_raw(&inner).await?;
399        reply_to_ok(reply)
400    }
401
402    /// Unlock a datastore.
403    pub async fn unlock(&mut self, target: Datastore) -> crate::Result<()> {
404        let inner = format!("<unlock><target>{}</target></unlock>", target.as_xml());
405        let reply = self.rpc_raw(&inner).await?;
406        reply_to_ok(reply)
407    }
408
409    /// Commit the candidate configuration to running.
410    pub async fn commit(&mut self) -> crate::Result<()> {
411        let reply = self.rpc_raw("<commit/>").await?;
412        reply_to_ok(reply)
413    }
414
415    /// Gracefully close the NETCONF session.
416    pub async fn close_session(&mut self) -> crate::Result<()> {
417        self.check_state()?;
418        self.inner.set_state(SessionState::Closing);
419        let result = self.rpc_send_unchecked("<close-session/>").await;
420        match result {
421            Ok(future) => {
422                let reply = future.response().await;
423                self.inner.set_state(SessionState::Closed);
424                reply_to_ok(reply?)
425            }
426            Err(e) => {
427                self.inner.set_state(SessionState::Closed);
428                Err(e)
429            }
430        }
431    }
432
433    /// Force-close another NETCONF session.
434    pub async fn kill_session(&mut self, session_id: u32) -> crate::Result<()> {
435        let inner = format!("<kill-session><session-id>{session_id}</session-id></kill-session>");
436        let reply = self.rpc_raw(&inner).await?;
437        reply_to_ok(reply)
438    }
439}
440
441/// This loop is the only thing reading from the SSH stream. It runs in a background tokio task.
442/// The session's main API (the writer side) never reads — it only writes RPCs and waits on oneshot channels.
443/// This separation is what makes pipelining work.
444async fn reader_loop(
445    mut reader: FramedRead<ReadHalf<NetconfStream>, NetconfCodec>,
446    inner: Arc<SessionInner>,
447) {
448    // FramedRead wraps the ReadHalf with the NetconfCodec decoder.
449    // Each .next() reads bytes from SSH, feeds them to decode(), and yields a
450    // complete framed message as Bytes.
451    while let Some(result) = reader.next().await {
452        match result {
453            // Takes the raw Bytes, validates UTF-8, finds the root element, parses it.
454            Ok(bytes) => match message::classify_message(bytes) {
455                Ok(ServerMessage::RpcReply(reply)) => {
456                    // When rpc_send sends an RPC, it inserts a oneshot sender into the pending map
457                    // keyed by message-id. Here, it is removed and used to send the reply
458                    // through the channel. The caller who's .awaiting the RpcFuture receives it.
459                    let tx = {
460                        let mut pending = inner.pending.lock().unwrap();
461                        pending.remove(&reply.message_id)
462                    };
463                    if let Some(tx) = tx {
464                        // we can ignore if the receiver was dropped. Nothing to do about it.
465                        let _ = tx.send(Ok(reply));
466                    } else {
467                        log::warn!("received reply for unknown message-id {}", reply.message_id);
468                    }
469                }
470                Err(e) => {
471                    log::warn!("failed to classify message: {e}");
472                }
473            },
474            // The stream broke (EOF or Error). Every pending RPC would hang forever waiting for a reply that's never coming.
475            // So we drain the map and send SessionClosed to each one, unblocking all waiters.
476            Err(e) => {
477                log::error!("reader error: {e}");
478                let mut pending = inner.pending.lock().unwrap();
479                for (_, tx) in pending.drain() {
480                    let _ = tx.send(Err(crate::Error::SessionClosed));
481                }
482                break;
483            }
484        }
485    }
486    // Drain any remaining pending RPCs — this handles clean EOF where the while loop exits
487    // via None (stream closed) rather than Err (which drains inline above).
488    {
489        let mut pending = inner.pending.lock().unwrap();
490        for (_, tx) in pending.drain() {
491            let _ = tx.send(Err(crate::Error::SessionClosed));
492        }
493    }
494    // Mark the session as closed so future rpc_send calls fail immediately with InvalidState
495    // instead of inserting into the pending map and hanging.
496    inner.set_state(SessionState::Closed);
497}
498
499fn reply_to_data(reply: RpcReply) -> crate::Result<String> {
500    match reply.body {
501        RpcReplyBody::Data(payload) => Ok(payload.into_string()),
502        RpcReplyBody::Ok => Ok(String::new()),
503        RpcReplyBody::Error(errors) => Err(crate::Error::Rpc {
504            message_id: reply.message_id,
505            error: errors
506                .first()
507                .map(|e| e.error_message.clone())
508                .unwrap_or_default(),
509        }),
510    }
511}
512
513fn reply_to_ok(reply: RpcReply) -> crate::Result<()> {
514    match reply.body {
515        RpcReplyBody::Ok => Ok(()),
516        RpcReplyBody::Data(_) => Ok(()),
517        RpcReplyBody::Error(errors) => Err(crate::Error::Rpc {
518            message_id: reply.message_id,
519            error: errors
520                .first()
521                .map(|e| e.error_message.clone())
522                .unwrap_or_default(),
523        }),
524    }
525}