Skip to main content

netconf_rust/
session.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::io;
4use std::pin::Pin;
5use std::sync::atomic::{AtomicU8, AtomicU64, AtomicUsize, Ordering};
6use std::sync::{Arc, Mutex};
7use std::task::{Context, Poll};
8use std::time::{Duration, Instant};
9
10use bytes::{Buf, Bytes, BytesMut};
11use log::{debug, trace, warn};
12use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, ReadHalf, WriteHalf};
13use tokio_stream::StreamExt;
14use tokio_util::codec::{Encoder, FramedRead};
15
16use crate::codec::{DecodedFrame, FramingMode, NetconfCodec, extract_message_id_from_bytes};
17use crate::config::Config;
18use crate::error::TransportError;
19use crate::hello::ServerHello;
20use crate::message::{self, DataPayload, RpcReply, RpcReplyBody, ServerMessage};
21use crate::stream::NetconfStream;
22
/// Lifecycle state of a session.
///
/// Stored as a `u8` inside an `AtomicU8` (see `SessionInner::state`) so the
/// session and the background reader task can read/write it lock-free.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum SessionState {
    /// Hello exchange complete, ready for RPCs
    Ready = 0,
    /// A 'close-session' RPC has been sent, awaiting reply
    Closing = 1,
    /// Session terminated gracefully or with error
    Closed = 2,
}
33
34impl SessionState {
35    fn from_u8(v: u8) -> Self {
36        match v {
37            0 => Self::Ready,
38            1 => Self::Closing,
39            _ => Self::Closed,
40        }
41    }
42}
43
44impl std::fmt::Display for SessionState {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        match self {
47            Self::Ready => write!(f, "Ready"),
48            Self::Closing => write!(f, "Closing"),
49            Self::Closed => write!(f, "Closed"),
50        }
51    }
52}
53
/// Reason the session disconnected.
///
/// Delivered via [`Session::disconnected()`] when the background reader
/// task detects that the connection is no longer alive.
///
/// Carries only owned data (`Clone`) so the same reason can be handed to
/// every subscriber of the disconnect watch channel.
#[derive(Debug, Clone)]
pub enum DisconnectReason {
    /// The remote end closed the connection cleanly (TCP FIN / EOF).
    Eof,
    /// A transport error severed the connection.
    ///
    /// Contains the error's display string.
    TransportError(String),
    /// The [`Session`] was dropped without calling [`Session::close_session()`].
    Dropped,
}
69
70impl std::fmt::Display for DisconnectReason {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        match self {
73            Self::Eof => write!(f, "connection closed by remote"),
74            Self::TransportError(e) => write!(f, "transport error: {e}"),
75            Self::Dropped => write!(f, "session dropped"),
76        }
77    }
78}
79
/// A NETCONF configuration datastore (RFC 6241) used as the source or
/// target of config operations. Serialized via [`Datastore::as_xml`].
#[derive(Debug, Clone, Copy)]
pub enum Datastore {
    /// The active configuration (`<running/>`).
    Running,
    /// The candidate configuration (`<candidate/>`), applied with `<commit/>`.
    Candidate,
    /// The startup configuration (`<startup/>`).
    Startup,
}
86
87impl Datastore {
88    fn as_xml(&self) -> &'static str {
89        match self {
90            Datastore::Running => "<running/>",
91            Datastore::Candidate => "<candidate/>",
92            Datastore::Startup => "<startup/>",
93        }
94    }
95}
96
97// =============================================================================
98// PendingRpc — supports both normal (buffered) and streaming RPCs
99// =============================================================================
100
/// A pending RPC awaiting its response from the server.
///
/// Normal RPCs accumulate the full response and deliver it as an `RpcReply`
/// via a oneshot channel. Streaming RPCs forward individual chunks via an
/// mpsc channel, enabling the consumer to process data incrementally.
enum PendingRpc {
    /// Buffered RPC: the complete reply (or an error) is delivered exactly once.
    Normal(tokio::sync::oneshot::Sender<crate::Result<RpcReply>>),
    /// Streaming RPC: raw byte chunks are forwarded as they arrive.
    Stream(tokio::sync::mpsc::Sender<crate::Result<Bytes>>),
}
110
/// A handle to a pending RPC reply.
///
/// Created by [`Session::rpc_send()`], this allows sending multiple RPCs
/// before awaiting any replies (pipelining).
pub struct RpcFuture {
    /// Receives the routed reply (or error) from the background reader task.
    rx: tokio::sync::oneshot::Receiver<crate::Result<RpcReply>>,
    /// message-id allocated for this RPC when it was built.
    msg_id: u32,
    /// Session-level timeout applied by `response()`; `None` waits forever.
    rpc_timeout: Option<Duration>,
}
120
121impl RpcFuture {
122    /// The message-id of this RPC.
123    pub fn message_id(&self) -> u32 {
124        self.msg_id
125    }
126
127    /// Await the RPC reply.
128    ///
129    /// If an `rpc_timeout` was configured, this will fail with
130    /// [`TransportError::Timeout`]
131    /// if the server does not reply in time.
132    pub async fn response(self) -> crate::Result<RpcReply> {
133        let result = match self.rpc_timeout {
134            Some(duration) => tokio::time::timeout(duration, self.rx)
135                .await
136                .map_err(|_| crate::Error::Transport(TransportError::Timeout(duration)))?,
137            None => self.rx.await,
138        };
139        result.map_err(|_| crate::Error::SessionClosed)?
140    }
141
142    /// Await the RPC reply with an explicit timeout, ignoring the
143    /// session-level `rpc_timeout`.
144    ///
145    /// Fails with [`TransportError::Timeout`]
146    /// if the server does not reply within `timeout`.
147    pub async fn response_with_timeout(self, timeout: Duration) -> crate::Result<RpcReply> {
148        let result = tokio::time::timeout(timeout, self.rx)
149            .await
150            .map_err(|_| crate::Error::Transport(TransportError::Timeout(timeout)))?;
151        result.map_err(|_| crate::Error::SessionClosed)?
152    }
153}
154
155// =============================================================================
156// RpcStream — streaming RPC response with AsyncRead
157// =============================================================================
158
159/// A streaming RPC response that yields raw XML bytes as they arrive.
160///
161/// Implements [`AsyncRead`] so it can be plugged directly into compression
162/// encoders, file writers, or any byte-oriented consumer without buffering
163/// the entire response.
164///
165/// Created by [`Session::rpc_stream()`].
166///
167/// # Example
168///
169/// ```no_run
170/// # use netconf_rust::Session;
171/// # use tokio::io::AsyncReadExt;
172/// # async fn example(session: &Session) -> netconf_rust::Result<()> {
173/// let mut stream = session.rpc_stream("<get-config><source><running/></source></get-config>").await?;
174/// let mut buf = [0u8; 8192];
175/// loop {
176///     let n = stream.read(&mut buf).await?;
177///     if n == 0 { break; }
178///     // process buf[..n] — e.g. feed to a compressor
179/// }
180/// # Ok(())
181/// # }
182/// ```
pub struct RpcStream {
    /// Channel of chunks (or a routing error) forwarded by the reader task.
    rx: tokio::sync::mpsc::Receiver<crate::Result<Bytes>>,
    /// Partially consumed chunk from the last `poll_read`.
    current: Bytes,
    /// message-id of the streaming RPC this stream belongs to.
    msg_id: u32,
    /// Set when the channel closes or an error is surfaced; reads return EOF.
    done: bool,
}
190
impl RpcStream {
    /// The message-id of this streaming RPC.
    ///
    /// Useful for correlating this stream with server-side logs or with
    /// other pipelined RPCs.
    pub fn message_id(&self) -> u32 {
        self.msg_id
    }

    /// Whether the stream has finished (all chunks received).
    ///
    /// Once `true`, further reads return EOF immediately.
    pub fn is_done(&self) -> bool {
        self.done
    }
}
202
203impl AsyncRead for RpcStream {
204    fn poll_read(
205        mut self: Pin<&mut Self>,
206        cx: &mut Context<'_>,
207        buf: &mut ReadBuf<'_>,
208    ) -> Poll<io::Result<()>> {
209        // 1. Serve from the partially-consumed current chunk.
210        if !self.current.is_empty() {
211            let n = std::cmp::min(buf.remaining(), self.current.len());
212            buf.put_slice(&self.current[..n]);
213            self.current.advance(n);
214            return Poll::Ready(Ok(()));
215        }
216
217        // 2. If done, return EOF.
218        if self.done {
219            return Poll::Ready(Ok(()));
220        }
221
222        // 3. Poll the channel for the next chunk.
223        match self.rx.poll_recv(cx) {
224            Poll::Ready(Some(Ok(chunk))) => {
225                let n = std::cmp::min(buf.remaining(), chunk.len());
226                buf.put_slice(&chunk[..n]);
227                if n < chunk.len() {
228                    self.current = chunk.slice(n..);
229                }
230                Poll::Ready(Ok(()))
231            }
232            Poll::Ready(Some(Err(e))) => {
233                self.done = true;
234                Poll::Ready(Err(io::Error::other(e.to_string())))
235            }
236            Poll::Ready(None) => {
237                // Channel closed — stream complete.
238                self.done = true;
239                Poll::Ready(Ok(()))
240            }
241            Poll::Pending => Poll::Pending,
242        }
243    }
244}
245
246// =============================================================================
247// SessionInner
248// =============================================================================
249
/// State shared between the [`Session`] handle and the background reader
/// task, always accessed through an `Arc`.
struct SessionInner {
    /// Pending RPC replies: message-id → pending RPC (normal or stream).
    pending: Mutex<HashMap<u32, PendingRpc>>,
    /// Session lifecycle state, stored as AtomicU8 for lock-free access
    /// across the session and reader task.
    ///
    /// State transitions:
    ///   Ready → Closing → Closed   (graceful: user calls close_session)
    ///   Ready → Closed             (abrupt: reader hits EOF or error)
    ///
    /// - Ready:   normal operation, RPCs can be sent
    /// - Closing: close-session RPC in flight, new RPCs rejected
    /// - Closed:  session terminated, all operations fail
    state: AtomicU8,
    /// Sender for the disconnect notification. The reader_loop sends the
    /// reason just before exiting. Receivers (from `Session::disconnected()`)
    /// wake up immediately.
    disconnect_tx: tokio::sync::watch::Sender<Option<DisconnectReason>>,
    /// Anchor instant for computing `last_rpc_at` from `last_rpc_nanos`.
    created_at: Instant,
    /// Nanoseconds elapsed since `created_at` when the last RPC reply was
    /// successfully routed. 0 means no reply has been received yet.
    last_rpc_nanos: AtomicU64,
    /// Number of streaming RPCs currently in flight (removed from `pending`
    /// but not yet completed with `EndOfMessage`).
    active_streams: AtomicUsize,
}
277
278impl SessionInner {
279    fn state(&self) -> SessionState {
280        SessionState::from_u8(self.state.load(Ordering::Acquire))
281    }
282
283    fn set_state(&self, state: SessionState) {
284        self.state.store(state as u8, Ordering::Release);
285    }
286
287    fn drain_pending(&self) -> usize {
288        let mut pending = self.pending.lock().unwrap();
289        let count = pending.len();
290        for (_, rpc) in pending.drain() {
291            match rpc {
292                PendingRpc::Normal(tx) => {
293                    let _ = tx.send(Err(crate::Error::SessionClosed));
294                }
295                PendingRpc::Stream(tx) => {
296                    let _ = tx.try_send(Err(crate::Error::SessionClosed));
297                }
298            }
299        }
300        count
301    }
302}
303
304/// Write half + codec, behind a `tokio::sync::Mutex` so Session methods
305/// can take `&self`. The Mutex is held only during encode + write + flush.
struct WriterState {
    /// Write half of the split NETCONF transport stream.
    writer: WriteHalf<NetconfStream>,
    /// Encoder applying the negotiated framing (EOM or chunked) to outgoing XML.
    codec: NetconfCodec,
}
310
311/// A NETCONF session with pipelining and streaming support.
312///
313/// Architecture:
314/// - **Writer**: the session holds the write half of the stream and a codec
315///   for encoding outgoing RPCs.
316/// - **Reader task**: a background tokio task reads framed chunks from
317///   the read half, classifies them (`<rpc-reply>` vs `<notification>`),
318///   and routes them to the correct handler.
319/// - **Pipelining**: [`rpc_send()`](Session::rpc_send) writes an RPC and
320///   returns an [`RpcFuture`] without waiting for the reply. Multiple RPCs
321///   can be in flight simultaneously.
322/// - **Streaming**: [`rpc_stream()`](Session::rpc_stream) writes an RPC and
323///   returns an [`RpcStream`] that yields raw bytes as they arrive. This
324///   avoids buffering the entire response for large payloads.
pub struct Session {
    /// Write state behind an async Mutex. Held only for the duration of
    /// encoding and flushing a single message.
    writer: tokio::sync::Mutex<WriterState>,

    /// Shared state between this session and the background reader task.
    /// Contains the pending RPC map, notification channel, and session state.
    inner: Arc<SessionInner>,

    /// The server's hello response, containing its capabilities and session ID.
    server_hello: ServerHello,

    /// Negotiated framing mode (EOM for 1.0-only servers, chunked for 1.1).
    framing: FramingMode,

    /// Timeout applied to each RPC response wait. `None` waits indefinitely.
    rpc_timeout: Option<Duration>,

    /// Receiver for disconnect notifications. Cloned for each call to
    /// `disconnected()`, allowing multiple independent subscribers.
    disconnect_rx: tokio::sync::watch::Receiver<Option<DisconnectReason>>,

    /// Instant when the session was established (hello exchange complete).
    connected_since: Instant,

    /// Handle to the background reader task. Aborted on drop to ensure
    /// the task doesn't outlive the session. Declared before `_keep_alive`
    /// so it is dropped first (struct fields drop in declaration order) —
    /// the reader task must be aborted before the SSH connection is torn down.
    _reader_handle: tokio::task::JoinHandle<()>,

    /// Holds the SSH handle alive. Dropping this tears down the SSH connection,
    /// which would invalidate the stream. Never accessed, just kept alive.
    _keep_alive: Option<Box<dyn std::any::Any + Send + Sync>>,
}
360
impl Drop for Session {
    /// Deterministic teardown for a session dropped without `close()`.
    /// The step order matters: waiters must be failed before the reader
    /// task is aborted, so no RpcFuture is left dangling.
    fn drop(&mut self) {
        // 1. Drain pending RPCs → waiters get deterministic SessionClosed
        let drained = self.inner.drain_pending();
        if drained > 0 {
            debug!(
                "session {}: drop: drained {drained} pending RPCs",
                self.server_hello.session_id
            );
        }
        // 2. Mark session closed
        self.inner.set_state(SessionState::Closed);
        // 3. Notify disconnect subscribers (only if reader_loop hasn't already
        //    published a more specific reason such as Eof/TransportError)
        self.inner.disconnect_tx.send_if_modified(|current| {
            if current.is_none() {
                *current = Some(DisconnectReason::Dropped);
                true
            } else {
                false
            }
        });
        // 4. Abort reader task (its cleanup is now a no-op)
        self._reader_handle.abort();
    }
}
386
387impl Session {
    /// Connect to a NETCONF server over SSH with password authentication.
    ///
    /// Convenience wrapper around
    /// [`connect_with_config()`](Self::connect_with_config) using
    /// [`Config::default()`].
    pub async fn connect(
        host: &str,
        port: u16,
        username: &str,
        password: &str,
    ) -> crate::Result<Self> {
        Self::connect_with_config(host, port, username, password, Config::default()).await
    }
397
    /// Connect with custom configuration.
    ///
    /// Establishes the SSH transport, performs the NETCONF `<hello>`
    /// exchange, then assembles the session (spawning the reader task).
    pub async fn connect_with_config(
        host: &str,
        port: u16,
        username: &str,
        password: &str,
        config: Config,
    ) -> crate::Result<Self> {
        let (mut stream, keep_alive) =
            crate::transport::connect(host, port, username, password, &config).await?;
        let (server_hello, framing) = exchange_hello(&mut stream, &config).await?;
        // keep_alive owns the SSH handle so the connection outlives this fn.
        Self::build(stream, Some(keep_alive), server_hello, framing, config)
    }
411
    /// Create a session from an existing stream (useful for testing).
    ///
    /// Uses [`Config::default()`]; see
    /// [`from_stream_with_config()`](Self::from_stream_with_config).
    pub async fn from_stream<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
        stream: S,
    ) -> crate::Result<Self> {
        Self::from_stream_with_config(stream, Config::default()).await
    }
418
    /// Create a session from an existing stream with custom configuration.
    ///
    /// Performs the `<hello>` exchange on the raw stream before boxing it
    /// into a [`NetconfStream`]; no SSH keep-alive handle is needed.
    pub async fn from_stream_with_config<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
        mut stream: S,
        config: Config,
    ) -> crate::Result<Self> {
        let (server_hello, framing) = exchange_hello(&mut stream, &config).await?;
        let boxed: NetconfStream = Box::new(stream);
        Self::build(boxed, None, server_hello, framing, config)
    }
428
    /// Assemble a `Session` after the hello exchange: split the stream,
    /// spawn the background reader task, and wire up the shared state.
    fn build(
        stream: NetconfStream,
        keep_alive: Option<Box<dyn std::any::Any + Send + Sync>>,
        server_hello: ServerHello,
        framing: FramingMode,
        config: Config,
    ) -> crate::Result<Self> {
        debug!(
            "session {}: building (framing={:?}, capabilities={})",
            server_hello.session_id,
            framing,
            server_hello.capabilities.len()
        );
        let (read_half, write_half) = tokio::io::split(stream);

        // Read and write halves each get their own codec instance so their
        // framing state stays independent.
        let read_codec = NetconfCodec::new(framing, config.codec);
        let write_codec = NetconfCodec::new(framing, config.codec);
        let reader = FramedRead::new(read_half, read_codec);

        let (disconnect_tx, disconnect_rx) = tokio::sync::watch::channel(None);

        let inner = Arc::new(SessionInner {
            pending: Mutex::new(HashMap::new()),
            state: AtomicU8::new(SessionState::Ready as u8),
            disconnect_tx,
            created_at: Instant::now(),
            last_rpc_nanos: AtomicU64::new(0),
            active_streams: AtomicUsize::new(0),
        });

        // The reader task owns the read half and routes incoming frames to
        // the pending-RPC map until EOF or a transport error.
        let reader_inner = Arc::clone(&inner);
        let session_id = server_hello.session_id;
        let reader_handle = tokio::spawn(async move {
            reader_loop(reader, reader_inner, session_id).await;
        });

        Ok(Self {
            writer: tokio::sync::Mutex::new(WriterState {
                writer: write_half,
                codec: write_codec,
            }),
            inner,
            server_hello,
            framing,
            rpc_timeout: config.rpc_timeout,
            disconnect_rx,
            connected_since: Instant::now(),
            _reader_handle: reader_handle,
            _keep_alive: keep_alive,
        })
    }
480
    /// The session-id assigned by the server in its `<hello>`.
    pub fn session_id(&self) -> u32 {
        self.server_hello.session_id
    }

    /// Capability URIs advertised by the server during the hello exchange.
    pub fn server_capabilities(&self) -> &[String] {
        &self.server_hello.capabilities
    }

    /// Negotiated framing mode (EOM for 1.0-only servers, chunked for 1.1).
    pub fn framing_mode(&self) -> FramingMode {
        self.framing
    }

    /// Current lifecycle state of the session.
    pub fn state(&self) -> SessionState {
        self.inner.state()
    }
496
    /// Returns a future that completes when the session disconnects.
    ///
    /// Can be called multiple times — each call clones an internal
    /// `watch::Receiver`, so multiple tasks can independently await
    /// the same disconnect event. If the session is already disconnected
    /// when called, returns immediately.
    ///
    /// The returned future is `'static`: it does not borrow `self` and may
    /// outlive the `Session` (in which case it resolves to `Dropped`).
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use netconf_rust::Session;
    /// # async fn example(session: &Session) {
    /// let reason = session.disconnected().await;
    /// println!("session died: {reason}");
    /// # }
    /// ```
    pub fn disconnected(&self) -> impl Future<Output = DisconnectReason> + Send + 'static {
        let mut rx = self.disconnect_rx.clone();
        async move {
            // Check if already disconnected (late subscriber).
            if let Some(reason) = rx.borrow_and_update().clone() {
                return reason;
            }
            // Wait for the reader_loop to send the reason. If the sender
            // is dropped (Session dropped, reader aborted), treat as Dropped.
            loop {
                if rx.changed().await.is_err() {
                    return DisconnectReason::Dropped;
                }
                // A change may carry None (initial value re-broadcast);
                // only a Some(reason) terminates the wait.
                if let Some(reason) = rx.borrow_and_update().clone() {
                    return reason;
                }
            }
        }
    }
532
533    fn check_state(&self) -> crate::Result<()> {
534        let state = self.inner.state();
535        if state != SessionState::Ready {
536            return Err(crate::Error::InvalidState(state.to_string()));
537        }
538        Ok(())
539    }
540
541    /// Encode a message with the negotiated framing (EOM or chunked) and
542    /// write it to the stream.
543    async fn send_encoded(&self, xml: &str) -> crate::Result<()> {
544        let mut buf = BytesMut::new();
545        let mut state = self.writer.lock().await;
546        state.codec.encode(Bytes::from(xml.to_string()), &mut buf)?;
547        trace!(
548            "session {}: writing {} bytes to stream",
549            self.server_hello.session_id,
550            buf.len()
551        );
552        state.writer.write_all(&buf).await?;
553        state.writer.flush().await?;
554        Ok(())
555    }
556
    /// Send a raw RPC and return a future for the reply (pipelining).
    ///
    /// This writes the RPC to the server immediately but does not wait
    /// for the reply. Call [`RpcFuture::response()`] to await the reply.
    /// Multiple RPCs can be pipelined by calling this repeatedly before
    /// awaiting any of them.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidState` if the session is not `Ready`, or with the
    /// transport error if the write fails (the pending entry is rolled back).
    pub async fn rpc_send(&self, inner_xml: &str) -> crate::Result<RpcFuture> {
        self.check_state()?;
        let (msg_id, xml) = message::build_rpc(inner_xml);
        debug!(
            "session {}: sending rpc message-id={} ({} bytes)",
            self.server_hello.session_id,
            msg_id,
            xml.len()
        );
        trace!(
            "session {}: rpc content: {}",
            self.server_hello.session_id, inner_xml
        );
        let (tx, rx) = tokio::sync::oneshot::channel();

        // Register the waiter BEFORE writing, so the reader task can route
        // a reply that arrives immediately after the send completes.
        self.inner
            .pending
            .lock()
            .unwrap()
            .insert(msg_id, PendingRpc::Normal(tx));

        if let Err(e) = self.send_encoded(&xml).await {
            debug!(
                "session {}: send failed for message-id={}: {}",
                self.server_hello.session_id, msg_id, e
            );
            // Roll back the registration so the pending map doesn't leak.
            self.inner.pending.lock().unwrap().remove(&msg_id);
            return Err(e);
        }
        Ok(RpcFuture {
            rx,
            msg_id,
            rpc_timeout: self.rpc_timeout,
        })
    }
598
599    /// Send a raw RPC and wait for the reply.
600    pub async fn rpc_raw(&self, inner_xml: &str) -> crate::Result<RpcReply> {
601        let future = self.rpc_send(inner_xml).await?;
602        future.response().await
603    }
604
    /// Send a raw RPC and return a streaming response.
    ///
    /// Unlike [`rpc_send()`](Session::rpc_send) which buffers the entire
    /// response, this returns an [`RpcStream`] that yields raw XML bytes
    /// as individual chunks arrive from the server. The stream implements
    /// [`AsyncRead`], so it can be piped directly into compression encoders,
    /// file writers, or any byte-oriented consumer.
    ///
    /// Normal (buffered) and streaming RPCs can be freely interleaved on
    /// the same session — the reader task routes each message independently
    /// based on its `message-id`.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use netconf_rust::Session;
    /// # use tokio::io::AsyncReadExt;
    /// # async fn example(session: &Session) -> netconf_rust::Result<()> {
    /// // Pipeline a normal RPC alongside a streaming one
    /// let edit_future = session.rpc_send("<edit-config>...</edit-config>").await?;
    /// let mut config_stream = session.rpc_stream("<get-config><source><running/></source></get-config>").await?;
    ///
    /// // Stream the large response
    /// let mut buf = [0u8; 8192];
    /// loop {
    ///     let n = config_stream.read(&mut buf).await?;
    ///     if n == 0 { break; }
    ///     // process buf[..n]
    /// }
    ///
    /// // Collect the normal RPC reply
    /// let edit_reply = edit_future.response().await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn rpc_stream(&self, inner_xml: &str) -> crate::Result<RpcStream> {
        self.check_state()?;
        let (msg_id, xml) = message::build_rpc(inner_xml);
        debug!(
            "session {}: sending streaming rpc message-id={} ({} bytes)",
            self.server_hello.session_id,
            msg_id,
            xml.len()
        );

        // Bounded channel (32 chunks): a slow consumer back-pressures the
        // reader task rather than buffering the whole response in memory.
        let (tx, rx) = tokio::sync::mpsc::channel(32);

        // Register the waiter BEFORE writing, so the reader task can route
        // chunks that arrive immediately after the send completes.
        self.inner
            .pending
            .lock()
            .unwrap()
            .insert(msg_id, PendingRpc::Stream(tx));

        if let Err(e) = self.send_encoded(&xml).await {
            debug!(
                "session {}: send failed for streaming message-id={}: {}",
                self.server_hello.session_id, msg_id, e
            );
            // Roll back the registration so the pending map doesn't leak.
            self.inner.pending.lock().unwrap().remove(&msg_id);
            return Err(e);
        }

        Ok(RpcStream {
            rx,
            current: Bytes::new(),
            msg_id,
            done: false,
        })
    }
674
675    /// Internal rpc send that skips state check. Only used for sending close-session.
676    async fn rpc_send_unchecked(&self, inner_xml: &str) -> crate::Result<RpcFuture> {
677        let (msg_id, xml) = message::build_rpc(inner_xml);
678        let (tx, rx) = tokio::sync::oneshot::channel();
679
680        self.inner
681            .pending
682            .lock()
683            .unwrap()
684            .insert(msg_id, PendingRpc::Normal(tx));
685
686        if let Err(e) = self.send_encoded(&xml).await {
687            self.inner.pending.lock().unwrap().remove(&msg_id);
688            return Err(e);
689        }
690
691        Ok(RpcFuture {
692            rx,
693            msg_id,
694            rpc_timeout: self.rpc_timeout,
695        })
696    }
697
698    /// Retrieve configuration from a datastore.
699    pub async fn get_config(
700        &self,
701        source: Datastore,
702        filter: Option<&str>,
703    ) -> crate::Result<String> {
704        let filter_xml = match filter {
705            Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
706            None => String::new(),
707        };
708        let inner = format!(
709            "<get-config><source>{}</source>{filter_xml}</get-config>",
710            source.as_xml()
711        );
712        let reply = self.rpc_raw(&inner).await?;
713        reply_to_data(reply)
714    }
715
716    /// Retrieve configuration as a zero-copy `DataPayload`.
717    ///
718    /// Same as `get_config()` but returns a `DataPayload` instead of `String`,
719    /// avoiding a copy of the response body. Use `payload.as_str()` for a
720    /// zero-copy `&str` view, or `payload.reader()` for streaming XML events.
721    pub async fn get_config_payload(
722        &self,
723        source: Datastore,
724        filter: Option<&str>,
725    ) -> crate::Result<DataPayload> {
726        let filter_xml = match filter {
727            Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
728            None => String::new(),
729        };
730        let inner = format!(
731            "<get-config><source>{}</source>{filter_xml}</get-config>",
732            source.as_xml()
733        );
734        let reply = self.rpc_raw(&inner).await?;
735        reply.into_data()
736    }
737
738    /// Retrieve configuration from a datastore as a streaming response.
739    ///
740    /// Same as `get_config()` but returns an [`RpcStream`] instead of
741    /// buffering the entire response.
742    pub async fn get_config_stream(
743        &self,
744        source: Datastore,
745        filter: Option<&str>,
746    ) -> crate::Result<RpcStream> {
747        let filter_xml = match filter {
748            Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
749            None => String::new(),
750        };
751        let inner = format!(
752            "<get-config><source>{}</source>{filter_xml}</get-config>",
753            source.as_xml()
754        );
755        self.rpc_stream(&inner).await
756    }
757
758    /// Retrieve running configuration and state data.
759    pub async fn get(&self, filter: Option<&str>) -> crate::Result<String> {
760        let filter_xml = match filter {
761            Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
762            None => String::new(),
763        };
764        let inner = format!("<get>{filter_xml}</get>");
765        let reply = self.rpc_raw(&inner).await?;
766        reply_to_data(reply)
767    }
768
769    /// Retrieve running configuration and state data as a zero-copy `DataPayload`.
770    ///
771    /// Same as `get()` but returns a `DataPayload` instead of `String`,
772    /// avoiding a copy of the response body.
773    pub async fn get_payload(&self, filter: Option<&str>) -> crate::Result<DataPayload> {
774        let filter_xml = match filter {
775            Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
776            None => String::new(),
777        };
778        let inner = format!("<get>{filter_xml}</get>");
779        let reply = self.rpc_raw(&inner).await?;
780        reply.into_data()
781    }
782
783    /// Retrieve running configuration and state data as a streaming response.
784    ///
785    /// Same as `get()` but returns an [`RpcStream`] instead of buffering
786    /// the entire response.
787    pub async fn get_stream(&self, filter: Option<&str>) -> crate::Result<RpcStream> {
788        let filter_xml = match filter {
789            Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
790            None => String::new(),
791        };
792        let inner = format!("<get>{filter_xml}</get>");
793        self.rpc_stream(&inner).await
794    }
795
796    /// Edit the configuration of a target datastore.
797    pub async fn edit_config(&self, target: Datastore, config: &str) -> crate::Result<()> {
798        let inner = format!(
799            "<edit-config><target>{}</target><config>{config}</config></edit-config>",
800            target.as_xml()
801        );
802        let reply = self.rpc_raw(&inner).await?;
803        reply_to_ok(reply)
804    }
805
806    /// Lock a datastore
807    pub async fn lock(&self, target: Datastore) -> crate::Result<()> {
808        let inner = format!("<lock><target>{}</target></lock>", target.as_xml());
809        let reply = self.rpc_raw(&inner).await?;
810        reply_to_ok(reply)
811    }
812
813    /// Unlock a datastore.
814    pub async fn unlock(&self, target: Datastore) -> crate::Result<()> {
815        let inner = format!("<unlock><target>{}</target></unlock>", target.as_xml());
816        let reply = self.rpc_raw(&inner).await?;
817        reply_to_ok(reply)
818    }
819
820    /// Commit the candidate configuration to running.
821    pub async fn commit(&self) -> crate::Result<()> {
822        let reply = self.rpc_raw("<commit/>").await?;
823        reply_to_ok(reply)
824    }
825
    /// Gracefully close the NETCONF session.
    ///
    /// Sends `<close-session/>` and waits for the reply. The session ends
    /// in state `Closed` regardless of whether the close RPC succeeds.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidState` if the session is not `Ready` (e.g. a
    /// concurrent close is already in progress), or propagates any
    /// send/reply error from the close RPC itself.
    pub async fn close_session(&self) -> crate::Result<()> {
        // Atomically transition Ready → Closing. If another caller already
        // moved us out of Ready, we fail immediately.
        let prev = self.inner.state.compare_exchange(
            SessionState::Ready as u8,
            SessionState::Closing as u8,
            Ordering::AcqRel,
            Ordering::Acquire,
        );
        if let Err(current) = prev {
            let state = SessionState::from_u8(current);
            return Err(crate::Error::InvalidState(state.to_string()));
        }
        debug!("session {}: closing", self.server_hello.session_id);
        // Must bypass check_state: the state is now Closing, not Ready.
        let result = self.rpc_send_unchecked("<close-session/>").await;
        match result {
            Ok(future) => {
                let reply = future.response().await;
                // Mark Closed before inspecting the reply so the state is
                // terminal even if the reply carried an error.
                self.inner.set_state(SessionState::Closed);
                debug!(
                    "session {}: closed gracefully",
                    self.server_hello.session_id
                );
                reply_to_ok(reply?)
            }
            Err(e) => {
                self.inner.set_state(SessionState::Closed);
                debug!(
                    "session {}: close failed: {}",
                    self.server_hello.session_id, e
                );
                Err(e)
            }
        }
    }
862
863    /// Gracefully close the NETCONF session and shut down the transport.
864    ///
865    /// Sends a `<close-session/>` RPC, shuts down the write half of the
866    /// stream, and drops the session (aborting the reader task and releasing
867    /// the SSH handle). Prefer this over [`close_session()`](Self::close_session)
868    /// when you are done with the session entirely.
869    pub async fn close(self) -> crate::Result<()> {
870        let result = self.close_session().await;
871        self.writer.lock().await.writer.shutdown().await.ok();
872        // `self` is dropped here — reader task aborted, SSH handle released
873        result
874    }
875
876    /// Force-close another NETCONF session.
877    pub async fn kill_session(&self, session_id: u32) -> crate::Result<()> {
878        let inner = format!("<kill-session><session-id>{session_id}</session-id></kill-session>");
879        let reply = self.rpc_raw(&inner).await?;
880        reply_to_ok(reply)
881    }
882
883    /// Return a wrapper that applies `timeout` to every RPC sent through it,
884    /// overriding the session-level `rpc_timeout`.
885    pub fn with_timeout(&self, timeout: Duration) -> SessionWithTimeout<'_> {
886        SessionWithTimeout {
887            session: self,
888            timeout,
889        }
890    }
891
892    /// Number of RPCs that have been sent but not yet fully replied to.
893    ///
894    /// This includes both RPCs awaiting their first reply byte (in the
895    /// pending map) and streaming RPCs whose response is in flight but
896    /// not yet complete.
897    pub fn pending_rpc_count(&self) -> usize {
898        self.inner.pending.lock().unwrap().len() + self.inner.active_streams.load(Ordering::Acquire)
899    }
900
901    /// The [`Instant`] when the most recent RPC reply was received, or
902    /// `None` if no reply has been received yet.
903    pub fn last_rpc_at(&self) -> Option<Instant> {
904        let nanos = self.inner.last_rpc_nanos.load(Ordering::Acquire);
905        if nanos == 0 {
906            None
907        } else {
908            Some(self.inner.created_at + Duration::from_nanos(nanos))
909        }
910    }
911
    /// The [`Instant`] when the session was established (hello exchange
    /// complete and reader task started).
    ///
    /// Session uptime can be computed as `connected_since().elapsed()`.
    pub fn connected_since(&self) -> Instant {
        self.connected_since
    }
917}
918
/// A wrapper around [`Session`] that applies a per-call timeout to every RPC.
///
/// Created by [`Session::with_timeout()`]. Each method sends the RPC via
/// the underlying session and awaits the reply with the configured timeout.
pub struct SessionWithTimeout<'a> {
    // The underlying session that actually sends the RPCs.
    session: &'a Session,
    // Timeout applied while awaiting each reply (the send itself is not
    // bounded by this — see `rpc_raw`, which only times the response wait).
    timeout: Duration,
}
927
928impl SessionWithTimeout<'_> {
929    /// Send a raw RPC and wait for the reply with the configured timeout.
930    pub async fn rpc_raw(&self, inner_xml: &str) -> crate::Result<RpcReply> {
931        let future = self.session.rpc_send(inner_xml).await?;
932        future.response_with_timeout(self.timeout).await
933    }
934
935    /// Retrieve configuration from a datastore.
936    pub async fn get_config(
937        &self,
938        source: Datastore,
939        filter: Option<&str>,
940    ) -> crate::Result<String> {
941        let filter_xml = match filter {
942            Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
943            None => String::new(),
944        };
945        let inner = format!(
946            "<get-config><source>{}</source>{filter_xml}</get-config>",
947            source.as_xml()
948        );
949        let reply = self.rpc_raw(&inner).await?;
950        reply_to_data(reply)
951    }
952
953    /// Retrieve configuration as a zero-copy `DataPayload`.
954    pub async fn get_config_payload(
955        &self,
956        source: Datastore,
957        filter: Option<&str>,
958    ) -> crate::Result<DataPayload> {
959        let filter_xml = match filter {
960            Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
961            None => String::new(),
962        };
963        let inner = format!(
964            "<get-config><source>{}</source>{filter_xml}</get-config>",
965            source.as_xml()
966        );
967        let reply = self.rpc_raw(&inner).await?;
968        reply.into_data()
969    }
970
971    /// Retrieve running configuration and state data.
972    pub async fn get(&self, filter: Option<&str>) -> crate::Result<String> {
973        let filter_xml = match filter {
974            Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
975            None => String::new(),
976        };
977        let inner = format!("<get>{filter_xml}</get>");
978        let reply = self.rpc_raw(&inner).await?;
979        reply_to_data(reply)
980    }
981
982    /// Retrieve running configuration and state data as a zero-copy `DataPayload`.
983    pub async fn get_payload(&self, filter: Option<&str>) -> crate::Result<DataPayload> {
984        let filter_xml = match filter {
985            Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
986            None => String::new(),
987        };
988        let inner = format!("<get>{filter_xml}</get>");
989        let reply = self.rpc_raw(&inner).await?;
990        reply.into_data()
991    }
992
993    /// Edit the configuration of a target datastore.
994    pub async fn edit_config(&self, target: Datastore, config: &str) -> crate::Result<()> {
995        let inner = format!(
996            "<edit-config><target>{}</target><config>{config}</config></edit-config>",
997            target.as_xml()
998        );
999        let reply = self.rpc_raw(&inner).await?;
1000        reply_to_ok(reply)
1001    }
1002
1003    /// Lock a datastore.
1004    pub async fn lock(&self, target: Datastore) -> crate::Result<()> {
1005        let inner = format!("<lock><target>{}</target></lock>", target.as_xml());
1006        let reply = self.rpc_raw(&inner).await?;
1007        reply_to_ok(reply)
1008    }
1009
1010    /// Unlock a datastore.
1011    pub async fn unlock(&self, target: Datastore) -> crate::Result<()> {
1012        let inner = format!("<unlock><target>{}</target></unlock>", target.as_xml());
1013        let reply = self.rpc_raw(&inner).await?;
1014        reply_to_ok(reply)
1015    }
1016
1017    /// Commit the candidate configuration to running.
1018    pub async fn commit(&self) -> crate::Result<()> {
1019        let reply = self.rpc_raw("<commit/>").await?;
1020        reply_to_ok(reply)
1021    }
1022}
1023
1024/// Perform the NETCONF hello exchange, optionally with a timeout.
1025async fn exchange_hello<S: AsyncRead + AsyncWrite + Unpin>(
1026    stream: &mut S,
1027    config: &Config,
1028) -> crate::Result<(ServerHello, FramingMode)> {
1029    let fut = crate::hello::exchange(stream, config.codec.max_message_size);
1030    match config.hello_timeout {
1031        Some(duration) => tokio::time::timeout(duration, fut)
1032            .await
1033            .map_err(|_| crate::Error::Transport(TransportError::Timeout(duration)))?,
1034        None => fut.await,
1035    }
1036}
1037
1038// =============================================================================
1039// Reader task — state machine for chunk-level routing
1040// =============================================================================
1041
/// Per-message state machine for the reader task.
///
/// Tracks whether the current message is being accumulated (normal RPC)
/// or streamed (streaming RPC). The state transitions on each
/// [`DecodedFrame`] from the codec.
enum ReaderMessageState {
    /// Waiting for the first chunk of a new message.
    /// Accumulates bytes until the `message-id` can be extracted.
    AwaitingHeader { buf: BytesMut },
    /// Accumulating chunks for a normal (non-streaming) RPC.
    /// `msg_id` is the `message-id` extracted from the header bytes;
    /// `buf` holds the full message body so far.
    Accumulating { msg_id: u32, buf: BytesMut },
    /// Forwarding chunks for a streaming RPC.
    /// Each chunk is sent to `tx` as it arrives; dropping `tx` at
    /// end-of-message signals EOF to the consumer.
    Streaming {
        msg_id: u32,
        tx: tokio::sync::mpsc::Sender<crate::Result<Bytes>>,
    },
}
1059
/// This loop is the only thing reading from the SSH stream. It runs in a
/// background tokio task. The session's main API (the writer side) never
/// reads — it only writes RPCs and waits on oneshot/mpsc channels.
/// This separation is what makes pipelining and streaming work.
///
/// The codec now yields [`DecodedFrame`] items (individual chunks and
/// end-of-message markers) instead of complete messages. The reader task
/// uses a [`ReaderMessageState`] state machine to decide per-message
/// whether to accumulate chunks (normal RPCs) or forward them (streaming
/// RPCs). The `message-id` from the first chunk determines the routing.
///
/// On exit (EOF or transport error) the loop fails any in-flight streaming
/// RPC, drains all pending RPC waiters, marks the session `Closed`, and
/// publishes the [`DisconnectReason`] via `disconnect_tx`.
async fn reader_loop(
    mut reader: FramedRead<ReadHalf<NetconfStream>, NetconfCodec>,
    inner: Arc<SessionInner>,
    session_id: u32,
) {
    debug!("session {}: reader loop started", session_id);
    // Default reason is a clean EOF; overwritten if the codec errors.
    let mut disconnect_reason = DisconnectReason::Eof;
    let mut state = ReaderMessageState::AwaitingHeader {
        buf: BytesMut::new(),
    };

    loop {
        // Propagate the closing flag so decode_eof can discard holdback bytes.
        // This check deliberately happens before awaiting the next frame.
        if inner.state() == SessionState::Closing {
            reader.decoder_mut().set_closing();
        }
        // `None` from the framed stream means EOF — exit with the default
        // `DisconnectReason::Eof`.
        let Some(result) = reader.next().await else {
            break;
        };
        match result {
            Ok(frame) => {
                state = process_frame(frame, state, &inner, session_id).await;
            }
            Err(e) => {
                debug!("session {}: reader error: {e}", session_id);
                disconnect_reason = DisconnectReason::TransportError(e.to_string());

                // Notify any in-flight streaming RPC about the error.
                // `try_send` so the shutdown path never blocks on a full channel.
                if let ReaderMessageState::Streaming { tx, .. } = &state {
                    let _ = tx.try_send(Err(crate::Error::SessionClosed));
                }

                let drained = inner.drain_pending();
                if drained > 0 {
                    debug!(
                        "session {}: drained {} pending RPCs after error",
                        session_id, drained
                    );
                }
                break;
            }
        }
    }

    // Notify any in-flight streaming RPC about session close.
    if let ReaderMessageState::Streaming { tx, .. } = &state {
        let _ = tx.try_send(Err(crate::Error::SessionClosed));
    }

    // Drain any remaining pending RPCs.
    {
        let drained = inner.drain_pending();
        if drained > 0 {
            debug!(
                "session {}: drained {} pending RPCs on stream close",
                session_id, drained
            );
        }
    }

    inner.set_state(SessionState::Closed);
    // Ignore send failure: the receiver may already be gone if the session
    // handle was dropped.
    let _ = inner.disconnect_tx.send(Some(disconnect_reason));
    debug!("session {}: reader loop ended", session_id);
}
1134
/// Process a single decoded frame, advancing the per-message state machine.
///
/// Consumes the current [`ReaderMessageState`] and returns the next one.
/// Chunks either accumulate into a buffer (normal RPCs) or are forwarded
/// to a streaming consumer; `EndOfMessage` finalizes the current message
/// and resets to `AwaitingHeader`.
async fn process_frame(
    frame: DecodedFrame,
    state: ReaderMessageState,
    inner: &SessionInner,
    session_id: u32,
) -> ReaderMessageState {
    match frame {
        DecodedFrame::Chunk(chunk) => match state {
            ReaderMessageState::AwaitingHeader { mut buf } => {
                buf.extend_from_slice(&chunk);

                // Try to extract the message-id from what we have so far.
                if let Some(msg_id) = extract_message_id_from_bytes(&buf) {
                    // Look up the pending RPC type to decide routing.
                    // NOTE: the lock is taken twice (check, then remove) to
                    // keep each critical section minimal; the race window is
                    // handled by the fallback below.
                    let is_stream = {
                        let pending = inner.pending.lock().unwrap();
                        matches!(pending.get(&msg_id), Some(PendingRpc::Stream(_)))
                    };

                    if is_stream {
                        // Remove the stream sender from pending and transition
                        // to Streaming state.
                        let tx = {
                            let mut pending = inner.pending.lock().unwrap();
                            match pending.remove(&msg_id) {
                                Some(PendingRpc::Stream(tx)) => tx,
                                // Race: might have been drained between the
                                // check above and here. Fall back to accumulate.
                                _ => {
                                    return ReaderMessageState::Accumulating { msg_id, buf };
                                }
                            }
                        };
                        inner.active_streams.fetch_add(1, Ordering::Release);
                        // Forward the buffered header as the first chunk.
                        // An error here means the consumer dropped the
                        // receiver; ignored — see the Streaming arm below.
                        let _ = tx.send(Ok(buf.freeze())).await;
                        debug!(
                            "session {}: streaming rpc message-id={}",
                            session_id, msg_id
                        );
                        ReaderMessageState::Streaming { msg_id, tx }
                    } else {
                        // Normal RPC or unknown — accumulate.
                        ReaderMessageState::Accumulating { msg_id, buf }
                    }
                } else {
                    // Need more data to find message-id.
                    ReaderMessageState::AwaitingHeader { buf }
                }
            }
            ReaderMessageState::Accumulating { msg_id, mut buf } => {
                buf.extend_from_slice(&chunk);
                ReaderMessageState::Accumulating { msg_id, buf }
            }
            ReaderMessageState::Streaming { msg_id, tx } => {
                // Forward chunk to the streaming consumer. If the consumer
                // dropped the receiver, just discard — we still need to
                // consume until EndOfMessage so we can start the next message.
                let _ = tx.send(Ok(chunk)).await;
                ReaderMessageState::Streaming { msg_id, tx }
            }
        },

        DecodedFrame::EndOfMessage => match state {
            ReaderMessageState::AwaitingHeader { .. } => {
                // Empty message or we never found a message-id — reset.
                trace!("session {}: empty or unparseable message", session_id);
                ReaderMessageState::AwaitingHeader {
                    buf: BytesMut::new(),
                }
            }
            ReaderMessageState::Accumulating { msg_id, buf } => {
                // Complete normal message — classify and route.
                let bytes = buf.freeze();
                trace!(
                    "session {}: complete message for msg-id={} ({} bytes)",
                    session_id,
                    msg_id,
                    bytes.len()
                );

                match message::classify_message(bytes) {
                    Ok(ServerMessage::RpcReply(reply)) => {
                        debug!(
                            "session {}: received rpc-reply message-id={}",
                            session_id, reply.message_id
                        );
                        let tx = {
                            let mut pending = inner.pending.lock().unwrap();
                            pending.remove(&reply.message_id)
                        };
                        if let Some(PendingRpc::Normal(tx)) = tx {
                            // Record the reply time as a nanosecond offset
                            // from session creation (see `last_rpc_at`).
                            let nanos = inner.created_at.elapsed().as_nanos() as u64;
                            inner.last_rpc_nanos.store(nanos, Ordering::Release);
                            let _ = tx.send(Ok(reply));
                        } else {
                            // Waiter was drained (timeout/close) or the
                            // server sent an id we never used.
                            warn!(
                                "session {}: received reply for unknown message-id {}",
                                session_id, reply.message_id
                            );
                        }
                    }
                    Err(e) => {
                        warn!("session {}: failed to classify message: {e}", session_id);
                    }
                }

                ReaderMessageState::AwaitingHeader {
                    buf: BytesMut::new(),
                }
            }
            ReaderMessageState::Streaming { msg_id, tx } => {
                // End of streaming message — drop the sender to signal EOF.
                drop(tx);
                inner.active_streams.fetch_sub(1, Ordering::Release);
                let nanos = inner.created_at.elapsed().as_nanos() as u64;
                inner.last_rpc_nanos.store(nanos, Ordering::Release);
                debug!(
                    "session {}: streaming message complete for msg-id={}",
                    session_id, msg_id
                );
                ReaderMessageState::AwaitingHeader {
                    buf: BytesMut::new(),
                }
            }
        },
    }
}
1264
1265fn reply_to_data(reply: RpcReply) -> crate::Result<String> {
1266    match reply.body {
1267        RpcReplyBody::Data(payload) => Ok(payload.into_string()),
1268        RpcReplyBody::Ok => Ok(String::new()),
1269        RpcReplyBody::Error(errors) => Err(crate::Error::Rpc {
1270            message_id: reply.message_id,
1271            error: errors
1272                .first()
1273                .map(|e| e.error_message.clone())
1274                .unwrap_or_default(),
1275        }),
1276    }
1277}
1278
1279fn reply_to_ok(reply: RpcReply) -> crate::Result<()> {
1280    match reply.body {
1281        RpcReplyBody::Ok => Ok(()),
1282        RpcReplyBody::Data(_) => Ok(()),
1283        RpcReplyBody::Error(errors) => Err(crate::Error::Rpc {
1284            message_id: reply.message_id,
1285            error: errors
1286                .first()
1287                .map(|e| e.error_message.clone())
1288                .unwrap_or_default(),
1289        }),
1290    }
1291}