1use std::collections::HashMap;
2use std::future::Future;
3use std::io;
4use std::pin::Pin;
5use std::sync::atomic::{AtomicU8, AtomicU64, AtomicUsize, Ordering};
6use std::sync::{Arc, Mutex};
7use std::task::{Context, Poll};
8use std::time::{Duration, Instant};
9
10use bytes::{Buf, Bytes, BytesMut};
11use log::{debug, trace, warn};
12use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, ReadHalf, WriteHalf};
13use tokio_stream::StreamExt;
14use tokio_util::codec::{Encoder, FramedRead};
15
16use crate::codec::{DecodedFrame, FramingMode, NetconfCodec, extract_message_id_from_bytes};
17use crate::config::Config;
18use crate::error::TransportError;
19use crate::hello::ServerHello;
20use crate::message::{self, DataPayload, RpcReply, RpcReplyBody, ServerMessage};
21use crate::stream::NetconfStream;
22
/// Lifecycle phase of a session.
///
/// The discriminants are explicit because the value is stored as a raw `u8`
/// and decoded again with [`SessionState::from_u8`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum SessionState {
    /// The session is usable.
    Ready = 0,
    /// A close has been started but has not finished yet.
    Closing = 1,
    /// The session is finished.
    Closed = 2,
}

impl SessionState {
    /// Decodes a raw discriminant.
    ///
    /// Any value other than the `Ready` or `Closing` discriminant decodes to
    /// `Closed`.
    fn from_u8(v: u8) -> Self {
        if v == Self::Ready as u8 {
            Self::Ready
        } else if v == Self::Closing as u8 {
            Self::Closing
        } else {
            Self::Closed
        }
    }
}

impl std::fmt::Display for SessionState {
    /// Writes the variant name (`Ready`, `Closing` or `Closed`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Ready => "Ready",
            Self::Closing => "Closing",
            Self::Closed => "Closed",
        };
        write!(f, "{label}")
    }
}
53
/// Why a session stopped being connected.
#[derive(Debug, Clone)]
pub enum DisconnectReason {
    /// The remote side closed the connection.
    Eof,
    /// The transport failed; the payload is the error rendered as text.
    TransportError(String),
    /// The session object was dropped.
    Dropped,
}

impl std::fmt::Display for DisconnectReason {
    /// Renders a short human-readable description of the reason.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::TransportError(detail) => write!(f, "transport error: {detail}"),
            Self::Eof => f.write_str("connection closed by remote"),
            Self::Dropped => f.write_str("session dropped"),
        }
    }
}
79
/// A configuration datastore that an operation can name as source or target.
#[derive(Debug, Clone, Copy)]
pub enum Datastore {
    /// Rendered as `<running/>`.
    Running,
    /// Rendered as `<candidate/>`.
    Candidate,
    /// Rendered as `<startup/>`.
    Startup,
}

impl Datastore {
    /// Returns the empty XML element that names this datastore.
    fn as_xml(&self) -> &'static str {
        match *self {
            Self::Running => "<running/>",
            Self::Candidate => "<candidate/>",
            Self::Startup => "<startup/>",
        }
    }
}
96
97enum PendingRpc {
107 Normal(tokio::sync::oneshot::Sender<crate::Result<RpcReply>>),
108 Stream(tokio::sync::mpsc::Sender<crate::Result<Bytes>>),
109}
110
111pub struct RpcFuture {
116 rx: tokio::sync::oneshot::Receiver<crate::Result<RpcReply>>,
117 msg_id: u32,
118 rpc_timeout: Option<Duration>,
119}
120
121impl RpcFuture {
122 pub fn message_id(&self) -> u32 {
124 self.msg_id
125 }
126
127 pub async fn response(self) -> crate::Result<RpcReply> {
133 let result = match self.rpc_timeout {
134 Some(duration) => tokio::time::timeout(duration, self.rx)
135 .await
136 .map_err(|_| crate::Error::Transport(TransportError::Timeout(duration)))?,
137 None => self.rx.await,
138 };
139 result.map_err(|_| crate::Error::SessionClosed)?
140 }
141
142 pub async fn response_with_timeout(self, timeout: Duration) -> crate::Result<RpcReply> {
148 let result = tokio::time::timeout(timeout, self.rx)
149 .await
150 .map_err(|_| crate::Error::Transport(TransportError::Timeout(timeout)))?;
151 result.map_err(|_| crate::Error::SessionClosed)?
152 }
153}
154
155pub struct RpcStream {
184 rx: tokio::sync::mpsc::Receiver<crate::Result<Bytes>>,
185 current: Bytes,
187 msg_id: u32,
188 done: bool,
189}
190
191impl RpcStream {
192 pub fn message_id(&self) -> u32 {
194 self.msg_id
195 }
196
197 pub fn is_done(&self) -> bool {
199 self.done
200 }
201}
202
203impl AsyncRead for RpcStream {
204 fn poll_read(
205 mut self: Pin<&mut Self>,
206 cx: &mut Context<'_>,
207 buf: &mut ReadBuf<'_>,
208 ) -> Poll<io::Result<()>> {
209 if !self.current.is_empty() {
211 let n = std::cmp::min(buf.remaining(), self.current.len());
212 buf.put_slice(&self.current[..n]);
213 self.current.advance(n);
214 return Poll::Ready(Ok(()));
215 }
216
217 if self.done {
219 return Poll::Ready(Ok(()));
220 }
221
222 match self.rx.poll_recv(cx) {
224 Poll::Ready(Some(Ok(chunk))) => {
225 let n = std::cmp::min(buf.remaining(), chunk.len());
226 buf.put_slice(&chunk[..n]);
227 if n < chunk.len() {
228 self.current = chunk.slice(n..);
229 }
230 Poll::Ready(Ok(()))
231 }
232 Poll::Ready(Some(Err(e))) => {
233 self.done = true;
234 Poll::Ready(Err(io::Error::other(e.to_string())))
235 }
236 Poll::Ready(None) => {
237 self.done = true;
239 Poll::Ready(Ok(()))
240 }
241 Poll::Pending => Poll::Pending,
242 }
243 }
244}
245
246struct SessionInner {
251 pending: Mutex<HashMap<u32, PendingRpc>>,
253 state: AtomicU8,
264 disconnect_tx: tokio::sync::watch::Sender<Option<DisconnectReason>>,
268 created_at: Instant,
270 last_rpc_nanos: AtomicU64,
273 active_streams: AtomicUsize,
276}
277
278impl SessionInner {
279 fn state(&self) -> SessionState {
280 SessionState::from_u8(self.state.load(Ordering::Acquire))
281 }
282
283 fn set_state(&self, state: SessionState) {
284 self.state.store(state as u8, Ordering::Release);
285 }
286
287 fn drain_pending(&self) -> usize {
288 let mut pending = self.pending.lock().unwrap();
289 let count = pending.len();
290 for (_, rpc) in pending.drain() {
291 match rpc {
292 PendingRpc::Normal(tx) => {
293 let _ = tx.send(Err(crate::Error::SessionClosed));
294 }
295 PendingRpc::Stream(tx) => {
296 let _ = tx.try_send(Err(crate::Error::SessionClosed));
297 }
298 }
299 }
300 count
301 }
302}
303
304struct WriterState {
307 writer: WriteHalf<NetconfStream>,
308 codec: NetconfCodec,
309}
310
311pub struct Session {
326 writer: tokio::sync::Mutex<WriterState>,
329
330 inner: Arc<SessionInner>,
333
334 server_hello: ServerHello,
336
337 framing: FramingMode,
339
340 rpc_timeout: Option<Duration>,
342
343 disconnect_rx: tokio::sync::watch::Receiver<Option<DisconnectReason>>,
346
347 connected_since: Instant,
349
350 _reader_handle: tokio::task::JoinHandle<()>,
355
356 _keep_alive: Option<Box<dyn std::any::Any + Send + Sync>>,
359}
360
361impl Drop for Session {
362 fn drop(&mut self) {
363 let drained = self.inner.drain_pending();
365 if drained > 0 {
366 debug!(
367 "session {}: drop: drained {drained} pending RPCs",
368 self.server_hello.session_id
369 );
370 }
371 self.inner.set_state(SessionState::Closed);
373 self.inner.disconnect_tx.send_if_modified(|current| {
375 if current.is_none() {
376 *current = Some(DisconnectReason::Dropped);
377 true
378 } else {
379 false
380 }
381 });
382 self._reader_handle.abort();
384 }
385}
386
387impl Session {
388 pub async fn connect(
390 host: &str,
391 port: u16,
392 username: &str,
393 password: &str,
394 ) -> crate::Result<Self> {
395 Self::connect_with_config(host, port, username, password, Config::default()).await
396 }
397
398 pub async fn connect_with_config(
400 host: &str,
401 port: u16,
402 username: &str,
403 password: &str,
404 config: Config,
405 ) -> crate::Result<Self> {
406 let (mut stream, keep_alive) =
407 crate::transport::connect(host, port, username, password, &config).await?;
408 let (server_hello, framing) = exchange_hello(&mut stream, &config).await?;
409 Self::build(stream, Some(keep_alive), server_hello, framing, config)
410 }
411
412 pub async fn from_stream<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
414 stream: S,
415 ) -> crate::Result<Self> {
416 Self::from_stream_with_config(stream, Config::default()).await
417 }
418
419 pub async fn from_stream_with_config<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
421 mut stream: S,
422 config: Config,
423 ) -> crate::Result<Self> {
424 let (server_hello, framing) = exchange_hello(&mut stream, &config).await?;
425 let boxed: NetconfStream = Box::new(stream);
426 Self::build(boxed, None, server_hello, framing, config)
427 }
428
429 fn build(
430 stream: NetconfStream,
431 keep_alive: Option<Box<dyn std::any::Any + Send + Sync>>,
432 server_hello: ServerHello,
433 framing: FramingMode,
434 config: Config,
435 ) -> crate::Result<Self> {
436 debug!(
437 "session {}: building (framing={:?}, capabilities={})",
438 server_hello.session_id,
439 framing,
440 server_hello.capabilities.len()
441 );
442 let (read_half, write_half) = tokio::io::split(stream);
443
444 let read_codec = NetconfCodec::new(framing, config.codec);
445 let write_codec = NetconfCodec::new(framing, config.codec);
446 let reader = FramedRead::new(read_half, read_codec);
447
448 let (disconnect_tx, disconnect_rx) = tokio::sync::watch::channel(None);
449
450 let inner = Arc::new(SessionInner {
451 pending: Mutex::new(HashMap::new()),
452 state: AtomicU8::new(SessionState::Ready as u8),
453 disconnect_tx,
454 created_at: Instant::now(),
455 last_rpc_nanos: AtomicU64::new(0),
456 active_streams: AtomicUsize::new(0),
457 });
458
459 let reader_inner = Arc::clone(&inner);
460 let session_id = server_hello.session_id;
461 let reader_handle = tokio::spawn(async move {
462 reader_loop(reader, reader_inner, session_id).await;
463 });
464
465 Ok(Self {
466 writer: tokio::sync::Mutex::new(WriterState {
467 writer: write_half,
468 codec: write_codec,
469 }),
470 inner,
471 server_hello,
472 framing,
473 rpc_timeout: config.rpc_timeout,
474 disconnect_rx,
475 connected_since: Instant::now(),
476 _reader_handle: reader_handle,
477 _keep_alive: keep_alive,
478 })
479 }
480
481 pub fn session_id(&self) -> u32 {
482 self.server_hello.session_id
483 }
484
485 pub fn server_capabilities(&self) -> &[String] {
486 &self.server_hello.capabilities
487 }
488
489 pub fn framing_mode(&self) -> FramingMode {
490 self.framing
491 }
492
493 pub fn state(&self) -> SessionState {
494 self.inner.state()
495 }
496
497 pub fn disconnected(&self) -> impl Future<Output = DisconnectReason> + Send + 'static {
514 let mut rx = self.disconnect_rx.clone();
515 async move {
516 if let Some(reason) = rx.borrow_and_update().clone() {
518 return reason;
519 }
520 loop {
523 if rx.changed().await.is_err() {
524 return DisconnectReason::Dropped;
525 }
526 if let Some(reason) = rx.borrow_and_update().clone() {
527 return reason;
528 }
529 }
530 }
531 }
532
533 fn check_state(&self) -> crate::Result<()> {
534 let state = self.inner.state();
535 if state != SessionState::Ready {
536 return Err(crate::Error::InvalidState(state.to_string()));
537 }
538 Ok(())
539 }
540
541 async fn send_encoded(&self, xml: &str) -> crate::Result<()> {
544 let mut buf = BytesMut::new();
545 let mut state = self.writer.lock().await;
546 state.codec.encode(Bytes::from(xml.to_string()), &mut buf)?;
547 trace!(
548 "session {}: writing {} bytes to stream",
549 self.server_hello.session_id,
550 buf.len()
551 );
552 state.writer.write_all(&buf).await?;
553 state.writer.flush().await?;
554 Ok(())
555 }
556
557 pub async fn rpc_send(&self, inner_xml: &str) -> crate::Result<RpcFuture> {
564 self.check_state()?;
565 let (msg_id, xml) = message::build_rpc(inner_xml);
566 debug!(
567 "session {}: sending rpc message-id={} ({} bytes)",
568 self.server_hello.session_id,
569 msg_id,
570 xml.len()
571 );
572 trace!(
573 "session {}: rpc content: {}",
574 self.server_hello.session_id, inner_xml
575 );
576 let (tx, rx) = tokio::sync::oneshot::channel();
577
578 self.inner
579 .pending
580 .lock()
581 .unwrap()
582 .insert(msg_id, PendingRpc::Normal(tx));
583
584 if let Err(e) = self.send_encoded(&xml).await {
585 debug!(
586 "session {}: send failed for message-id={}: {}",
587 self.server_hello.session_id, msg_id, e
588 );
589 self.inner.pending.lock().unwrap().remove(&msg_id);
590 return Err(e);
591 }
592 Ok(RpcFuture {
593 rx,
594 msg_id,
595 rpc_timeout: self.rpc_timeout,
596 })
597 }
598
599 pub async fn rpc_raw(&self, inner_xml: &str) -> crate::Result<RpcReply> {
601 let future = self.rpc_send(inner_xml).await?;
602 future.response().await
603 }
604
605 pub async fn rpc_stream(&self, inner_xml: &str) -> crate::Result<RpcStream> {
641 self.check_state()?;
642 let (msg_id, xml) = message::build_rpc(inner_xml);
643 debug!(
644 "session {}: sending streaming rpc message-id={} ({} bytes)",
645 self.server_hello.session_id,
646 msg_id,
647 xml.len()
648 );
649
650 let (tx, rx) = tokio::sync::mpsc::channel(32);
651
652 self.inner
653 .pending
654 .lock()
655 .unwrap()
656 .insert(msg_id, PendingRpc::Stream(tx));
657
658 if let Err(e) = self.send_encoded(&xml).await {
659 debug!(
660 "session {}: send failed for streaming message-id={}: {}",
661 self.server_hello.session_id, msg_id, e
662 );
663 self.inner.pending.lock().unwrap().remove(&msg_id);
664 return Err(e);
665 }
666
667 Ok(RpcStream {
668 rx,
669 current: Bytes::new(),
670 msg_id,
671 done: false,
672 })
673 }
674
675 async fn rpc_send_unchecked(&self, inner_xml: &str) -> crate::Result<RpcFuture> {
677 let (msg_id, xml) = message::build_rpc(inner_xml);
678 let (tx, rx) = tokio::sync::oneshot::channel();
679
680 self.inner
681 .pending
682 .lock()
683 .unwrap()
684 .insert(msg_id, PendingRpc::Normal(tx));
685
686 if let Err(e) = self.send_encoded(&xml).await {
687 self.inner.pending.lock().unwrap().remove(&msg_id);
688 return Err(e);
689 }
690
691 Ok(RpcFuture {
692 rx,
693 msg_id,
694 rpc_timeout: self.rpc_timeout,
695 })
696 }
697
698 pub async fn get_config(
700 &self,
701 source: Datastore,
702 filter: Option<&str>,
703 ) -> crate::Result<String> {
704 let filter_xml = match filter {
705 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
706 None => String::new(),
707 };
708 let inner = format!(
709 "<get-config><source>{}</source>{filter_xml}</get-config>",
710 source.as_xml()
711 );
712 let reply = self.rpc_raw(&inner).await?;
713 reply_to_data(reply)
714 }
715
716 pub async fn get_config_payload(
722 &self,
723 source: Datastore,
724 filter: Option<&str>,
725 ) -> crate::Result<DataPayload> {
726 let filter_xml = match filter {
727 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
728 None => String::new(),
729 };
730 let inner = format!(
731 "<get-config><source>{}</source>{filter_xml}</get-config>",
732 source.as_xml()
733 );
734 let reply = self.rpc_raw(&inner).await?;
735 reply.into_data()
736 }
737
738 pub async fn get_config_stream(
743 &self,
744 source: Datastore,
745 filter: Option<&str>,
746 ) -> crate::Result<RpcStream> {
747 let filter_xml = match filter {
748 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
749 None => String::new(),
750 };
751 let inner = format!(
752 "<get-config><source>{}</source>{filter_xml}</get-config>",
753 source.as_xml()
754 );
755 self.rpc_stream(&inner).await
756 }
757
758 pub async fn get(&self, filter: Option<&str>) -> crate::Result<String> {
760 let filter_xml = match filter {
761 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
762 None => String::new(),
763 };
764 let inner = format!("<get>{filter_xml}</get>");
765 let reply = self.rpc_raw(&inner).await?;
766 reply_to_data(reply)
767 }
768
769 pub async fn get_payload(&self, filter: Option<&str>) -> crate::Result<DataPayload> {
774 let filter_xml = match filter {
775 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
776 None => String::new(),
777 };
778 let inner = format!("<get>{filter_xml}</get>");
779 let reply = self.rpc_raw(&inner).await?;
780 reply.into_data()
781 }
782
783 pub async fn get_stream(&self, filter: Option<&str>) -> crate::Result<RpcStream> {
788 let filter_xml = match filter {
789 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
790 None => String::new(),
791 };
792 let inner = format!("<get>{filter_xml}</get>");
793 self.rpc_stream(&inner).await
794 }
795
796 pub async fn edit_config(&self, target: Datastore, config: &str) -> crate::Result<()> {
798 let inner = format!(
799 "<edit-config><target>{}</target><config>{config}</config></edit-config>",
800 target.as_xml()
801 );
802 let reply = self.rpc_raw(&inner).await?;
803 reply_to_ok(reply)
804 }
805
806 pub async fn lock(&self, target: Datastore) -> crate::Result<()> {
808 let inner = format!("<lock><target>{}</target></lock>", target.as_xml());
809 let reply = self.rpc_raw(&inner).await?;
810 reply_to_ok(reply)
811 }
812
813 pub async fn unlock(&self, target: Datastore) -> crate::Result<()> {
815 let inner = format!("<unlock><target>{}</target></unlock>", target.as_xml());
816 let reply = self.rpc_raw(&inner).await?;
817 reply_to_ok(reply)
818 }
819
820 pub async fn commit(&self) -> crate::Result<()> {
822 let reply = self.rpc_raw("<commit/>").await?;
823 reply_to_ok(reply)
824 }
825
826 pub async fn close_session(&self) -> crate::Result<()> {
828 let prev = self.inner.state.compare_exchange(
831 SessionState::Ready as u8,
832 SessionState::Closing as u8,
833 Ordering::AcqRel,
834 Ordering::Acquire,
835 );
836 if let Err(current) = prev {
837 let state = SessionState::from_u8(current);
838 return Err(crate::Error::InvalidState(state.to_string()));
839 }
840 debug!("session {}: closing", self.server_hello.session_id);
841 let result = self.rpc_send_unchecked("<close-session/>").await;
842 match result {
843 Ok(future) => {
844 let reply = future.response().await;
845 self.inner.set_state(SessionState::Closed);
846 debug!(
847 "session {}: closed gracefully",
848 self.server_hello.session_id
849 );
850 reply_to_ok(reply?)
851 }
852 Err(e) => {
853 self.inner.set_state(SessionState::Closed);
854 debug!(
855 "session {}: close failed: {}",
856 self.server_hello.session_id, e
857 );
858 Err(e)
859 }
860 }
861 }
862
863 pub async fn close(self) -> crate::Result<()> {
870 let result = self.close_session().await;
871 self.writer.lock().await.writer.shutdown().await.ok();
872 result
874 }
875
876 pub async fn kill_session(&self, session_id: u32) -> crate::Result<()> {
878 let inner = format!("<kill-session><session-id>{session_id}</session-id></kill-session>");
879 let reply = self.rpc_raw(&inner).await?;
880 reply_to_ok(reply)
881 }
882
883 pub fn with_timeout(&self, timeout: Duration) -> SessionWithTimeout<'_> {
886 SessionWithTimeout {
887 session: self,
888 timeout,
889 }
890 }
891
892 pub fn pending_rpc_count(&self) -> usize {
898 self.inner.pending.lock().unwrap().len() + self.inner.active_streams.load(Ordering::Acquire)
899 }
900
901 pub fn last_rpc_at(&self) -> Option<Instant> {
904 let nanos = self.inner.last_rpc_nanos.load(Ordering::Acquire);
905 if nanos == 0 {
906 None
907 } else {
908 Some(self.inner.created_at + Duration::from_nanos(nanos))
909 }
910 }
911
912 pub fn connected_since(&self) -> Instant {
915 self.connected_since
916 }
917}
918
919pub struct SessionWithTimeout<'a> {
924 session: &'a Session,
925 timeout: Duration,
926}
927
928impl SessionWithTimeout<'_> {
929 pub async fn rpc_raw(&self, inner_xml: &str) -> crate::Result<RpcReply> {
931 let future = self.session.rpc_send(inner_xml).await?;
932 future.response_with_timeout(self.timeout).await
933 }
934
935 pub async fn get_config(
937 &self,
938 source: Datastore,
939 filter: Option<&str>,
940 ) -> crate::Result<String> {
941 let filter_xml = match filter {
942 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
943 None => String::new(),
944 };
945 let inner = format!(
946 "<get-config><source>{}</source>{filter_xml}</get-config>",
947 source.as_xml()
948 );
949 let reply = self.rpc_raw(&inner).await?;
950 reply_to_data(reply)
951 }
952
953 pub async fn get_config_payload(
955 &self,
956 source: Datastore,
957 filter: Option<&str>,
958 ) -> crate::Result<DataPayload> {
959 let filter_xml = match filter {
960 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
961 None => String::new(),
962 };
963 let inner = format!(
964 "<get-config><source>{}</source>{filter_xml}</get-config>",
965 source.as_xml()
966 );
967 let reply = self.rpc_raw(&inner).await?;
968 reply.into_data()
969 }
970
971 pub async fn get(&self, filter: Option<&str>) -> crate::Result<String> {
973 let filter_xml = match filter {
974 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
975 None => String::new(),
976 };
977 let inner = format!("<get>{filter_xml}</get>");
978 let reply = self.rpc_raw(&inner).await?;
979 reply_to_data(reply)
980 }
981
982 pub async fn get_payload(&self, filter: Option<&str>) -> crate::Result<DataPayload> {
984 let filter_xml = match filter {
985 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
986 None => String::new(),
987 };
988 let inner = format!("<get>{filter_xml}</get>");
989 let reply = self.rpc_raw(&inner).await?;
990 reply.into_data()
991 }
992
993 pub async fn edit_config(&self, target: Datastore, config: &str) -> crate::Result<()> {
995 let inner = format!(
996 "<edit-config><target>{}</target><config>{config}</config></edit-config>",
997 target.as_xml()
998 );
999 let reply = self.rpc_raw(&inner).await?;
1000 reply_to_ok(reply)
1001 }
1002
1003 pub async fn lock(&self, target: Datastore) -> crate::Result<()> {
1005 let inner = format!("<lock><target>{}</target></lock>", target.as_xml());
1006 let reply = self.rpc_raw(&inner).await?;
1007 reply_to_ok(reply)
1008 }
1009
1010 pub async fn unlock(&self, target: Datastore) -> crate::Result<()> {
1012 let inner = format!("<unlock><target>{}</target></unlock>", target.as_xml());
1013 let reply = self.rpc_raw(&inner).await?;
1014 reply_to_ok(reply)
1015 }
1016
1017 pub async fn commit(&self) -> crate::Result<()> {
1019 let reply = self.rpc_raw("<commit/>").await?;
1020 reply_to_ok(reply)
1021 }
1022}
1023
1024async fn exchange_hello<S: AsyncRead + AsyncWrite + Unpin>(
1026 stream: &mut S,
1027 config: &Config,
1028) -> crate::Result<(ServerHello, FramingMode)> {
1029 let fut = crate::hello::exchange(stream, config.codec.max_message_size);
1030 match config.hello_timeout {
1031 Some(duration) => tokio::time::timeout(duration, fut)
1032 .await
1033 .map_err(|_| crate::Error::Transport(TransportError::Timeout(duration)))?,
1034 None => fut.await,
1035 }
1036}
1037
1038enum ReaderMessageState {
1048 AwaitingHeader { buf: BytesMut },
1051 Accumulating { msg_id: u32, buf: BytesMut },
1053 Streaming {
1055 msg_id: u32,
1056 tx: tokio::sync::mpsc::Sender<crate::Result<Bytes>>,
1057 },
1058}
1059
1060async fn reader_loop(
1071 mut reader: FramedRead<ReadHalf<NetconfStream>, NetconfCodec>,
1072 inner: Arc<SessionInner>,
1073 session_id: u32,
1074) {
1075 debug!("session {}: reader loop started", session_id);
1076 let mut disconnect_reason = DisconnectReason::Eof;
1077 let mut state = ReaderMessageState::AwaitingHeader {
1078 buf: BytesMut::new(),
1079 };
1080
1081 loop {
1082 if inner.state() == SessionState::Closing {
1084 reader.decoder_mut().set_closing();
1085 }
1086 let Some(result) = reader.next().await else {
1087 break;
1088 };
1089 match result {
1090 Ok(frame) => {
1091 state = process_frame(frame, state, &inner, session_id).await;
1092 }
1093 Err(e) => {
1094 debug!("session {}: reader error: {e}", session_id);
1095 disconnect_reason = DisconnectReason::TransportError(e.to_string());
1096
1097 if let ReaderMessageState::Streaming { tx, .. } = &state {
1099 let _ = tx.try_send(Err(crate::Error::SessionClosed));
1100 }
1101
1102 let drained = inner.drain_pending();
1103 if drained > 0 {
1104 debug!(
1105 "session {}: drained {} pending RPCs after error",
1106 session_id, drained
1107 );
1108 }
1109 break;
1110 }
1111 }
1112 }
1113
1114 if let ReaderMessageState::Streaming { tx, .. } = &state {
1116 let _ = tx.try_send(Err(crate::Error::SessionClosed));
1117 }
1118
1119 {
1121 let drained = inner.drain_pending();
1122 if drained > 0 {
1123 debug!(
1124 "session {}: drained {} pending RPCs on stream close",
1125 session_id, drained
1126 );
1127 }
1128 }
1129
1130 inner.set_state(SessionState::Closed);
1131 let _ = inner.disconnect_tx.send(Some(disconnect_reason));
1132 debug!("session {}: reader loop ended", session_id);
1133}
1134
1135async fn process_frame(
1137 frame: DecodedFrame,
1138 state: ReaderMessageState,
1139 inner: &SessionInner,
1140 session_id: u32,
1141) -> ReaderMessageState {
1142 match frame {
1143 DecodedFrame::Chunk(chunk) => match state {
1144 ReaderMessageState::AwaitingHeader { mut buf } => {
1145 buf.extend_from_slice(&chunk);
1146
1147 if let Some(msg_id) = extract_message_id_from_bytes(&buf) {
1149 let is_stream = {
1151 let pending = inner.pending.lock().unwrap();
1152 matches!(pending.get(&msg_id), Some(PendingRpc::Stream(_)))
1153 };
1154
1155 if is_stream {
1156 let tx = {
1159 let mut pending = inner.pending.lock().unwrap();
1160 match pending.remove(&msg_id) {
1161 Some(PendingRpc::Stream(tx)) => tx,
1162 _ => {
1165 return ReaderMessageState::Accumulating { msg_id, buf };
1166 }
1167 }
1168 };
1169 inner.active_streams.fetch_add(1, Ordering::Release);
1170 let _ = tx.send(Ok(buf.freeze())).await;
1172 debug!(
1173 "session {}: streaming rpc message-id={}",
1174 session_id, msg_id
1175 );
1176 ReaderMessageState::Streaming { msg_id, tx }
1177 } else {
1178 ReaderMessageState::Accumulating { msg_id, buf }
1180 }
1181 } else {
1182 ReaderMessageState::AwaitingHeader { buf }
1184 }
1185 }
1186 ReaderMessageState::Accumulating { msg_id, mut buf } => {
1187 buf.extend_from_slice(&chunk);
1188 ReaderMessageState::Accumulating { msg_id, buf }
1189 }
1190 ReaderMessageState::Streaming { msg_id, tx } => {
1191 let _ = tx.send(Ok(chunk)).await;
1195 ReaderMessageState::Streaming { msg_id, tx }
1196 }
1197 },
1198
1199 DecodedFrame::EndOfMessage => match state {
1200 ReaderMessageState::AwaitingHeader { .. } => {
1201 trace!("session {}: empty or unparseable message", session_id);
1203 ReaderMessageState::AwaitingHeader {
1204 buf: BytesMut::new(),
1205 }
1206 }
1207 ReaderMessageState::Accumulating { msg_id, buf } => {
1208 let bytes = buf.freeze();
1210 trace!(
1211 "session {}: complete message for msg-id={} ({} bytes)",
1212 session_id,
1213 msg_id,
1214 bytes.len()
1215 );
1216
1217 match message::classify_message(bytes) {
1218 Ok(ServerMessage::RpcReply(reply)) => {
1219 debug!(
1220 "session {}: received rpc-reply message-id={}",
1221 session_id, reply.message_id
1222 );
1223 let tx = {
1224 let mut pending = inner.pending.lock().unwrap();
1225 pending.remove(&reply.message_id)
1226 };
1227 if let Some(PendingRpc::Normal(tx)) = tx {
1228 let nanos = inner.created_at.elapsed().as_nanos() as u64;
1229 inner.last_rpc_nanos.store(nanos, Ordering::Release);
1230 let _ = tx.send(Ok(reply));
1231 } else {
1232 warn!(
1233 "session {}: received reply for unknown message-id {}",
1234 session_id, reply.message_id
1235 );
1236 }
1237 }
1238 Err(e) => {
1239 warn!("session {}: failed to classify message: {e}", session_id);
1240 }
1241 }
1242
1243 ReaderMessageState::AwaitingHeader {
1244 buf: BytesMut::new(),
1245 }
1246 }
1247 ReaderMessageState::Streaming { msg_id, tx } => {
1248 drop(tx);
1250 inner.active_streams.fetch_sub(1, Ordering::Release);
1251 let nanos = inner.created_at.elapsed().as_nanos() as u64;
1252 inner.last_rpc_nanos.store(nanos, Ordering::Release);
1253 debug!(
1254 "session {}: streaming message complete for msg-id={}",
1255 session_id, msg_id
1256 );
1257 ReaderMessageState::AwaitingHeader {
1258 buf: BytesMut::new(),
1259 }
1260 }
1261 },
1262 }
1263}
1264
1265fn reply_to_data(reply: RpcReply) -> crate::Result<String> {
1266 match reply.body {
1267 RpcReplyBody::Data(payload) => Ok(payload.into_string()),
1268 RpcReplyBody::Ok => Ok(String::new()),
1269 RpcReplyBody::Error(errors) => Err(crate::Error::Rpc {
1270 message_id: reply.message_id,
1271 error: errors
1272 .first()
1273 .map(|e| e.error_message.clone())
1274 .unwrap_or_default(),
1275 }),
1276 }
1277}
1278
1279fn reply_to_ok(reply: RpcReply) -> crate::Result<()> {
1280 match reply.body {
1281 RpcReplyBody::Ok => Ok(()),
1282 RpcReplyBody::Data(_) => Ok(()),
1283 RpcReplyBody::Error(errors) => Err(crate::Error::Rpc {
1284 message_id: reply.message_id,
1285 error: errors
1286 .first()
1287 .map(|e| e.error_message.clone())
1288 .unwrap_or_default(),
1289 }),
1290 }
1291}