1use std::collections::HashMap;
2use std::sync::atomic::{AtomicU8, AtomicU64, Ordering};
3use std::sync::{Arc, Mutex};
4use std::time::{Duration, Instant};
5
6use bytes::{Bytes, BytesMut};
7use log::{debug, trace, warn};
8use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
9use tokio_stream::StreamExt;
10use tokio_util::codec::{Encoder, FramedRead};
11
12use crate::codec::{FramingMode, NetconfCodec};
13use crate::config::Config;
14use crate::error::TransportError;
15use crate::hello::ServerHello;
16use crate::message::{self, DataPayload, RpcReply, RpcReplyBody, ServerMessage};
17use crate::stream::NetconfStream;
18
/// Lifecycle of a [`Session`].
///
/// The value is kept in an atomic `u8` inside the session's shared state,
/// hence `#[repr(u8)]` with explicit discriminants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum SessionState {
    /// The session accepts new RPCs.
    Ready = 0,
    /// `close_session` is in progress; new RPCs are rejected.
    Closing = 1,
    /// Terminal: the session was closed, dropped, or its reader task ended.
    Closed = 2,
}

impl SessionState {
    /// Decodes a raw state byte.
    ///
    /// Any byte that is not the `Ready` or `Closing` discriminant decodes as
    /// `Closed`.
    fn from_u8(raw: u8) -> Self {
        const READY: u8 = SessionState::Ready as u8;
        const CLOSING: u8 = SessionState::Closing as u8;
        match raw {
            READY => Self::Ready,
            CLOSING => Self::Closing,
            _ => Self::Closed,
        }
    }
}

impl std::fmt::Display for SessionState {
    /// Writes the variant name (`Ready`, `Closing`, or `Closed`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Ready => "Ready",
            Self::Closing => "Closing",
            Self::Closed => "Closed",
        };
        f.write_str(name)
    }
}
49
/// Why a session's connection ended, as reported by [`Session::disconnected`].
#[derive(Debug, Clone)]
pub enum DisconnectReason {
    /// The reader reached end-of-stream.
    Eof,
    /// Reading from the stream failed; carries the error's text.
    TransportError(String),
    /// The `Session` was dropped before any other reason was recorded.
    Dropped,
}

impl std::fmt::Display for DisconnectReason {
    /// Writes a short human-readable description of the reason.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::TransportError(detail) => write!(f, "transport error: {}", detail),
            Self::Eof => f.write_str("connection closed by remote"),
            Self::Dropped => f.write_str("session dropped"),
        }
    }
}
75
/// A NETCONF configuration datastore, used as the `<source>` or `<target>`
/// of datastore operations.
#[derive(Debug, Clone, Copy)]
pub enum Datastore {
    /// The `<running/>` datastore.
    Running,
    /// The `<candidate/>` datastore.
    Candidate,
    /// The `<startup/>` datastore.
    Startup,
}

impl Datastore {
    /// Returns the empty XML element that names this datastore.
    fn as_xml(&self) -> &'static str {
        match *self {
            Self::Running => "<running/>",
            Self::Candidate => "<candidate/>",
            Self::Startup => "<startup/>",
        }
    }
}
92
93pub struct RpcFuture {
98 rx: tokio::sync::oneshot::Receiver<crate::Result<RpcReply>>,
99 msg_id: u32,
100 rpc_timeout: Option<Duration>,
101}
102
103impl RpcFuture {
104 pub fn message_id(&self) -> u32 {
106 self.msg_id
107 }
108
109 pub async fn response(self) -> crate::Result<RpcReply> {
115 let result = match self.rpc_timeout {
116 Some(duration) => tokio::time::timeout(duration, self.rx)
117 .await
118 .map_err(|_| crate::Error::Transport(TransportError::Timeout(duration)))?,
119 None => self.rx.await,
120 };
121 result.map_err(|_| crate::Error::SessionClosed)?
122 }
123
124 pub async fn response_with_timeout(self, timeout: Duration) -> crate::Result<RpcReply> {
130 let result = tokio::time::timeout(timeout, self.rx)
131 .await
132 .map_err(|_| crate::Error::Transport(TransportError::Timeout(timeout)))?;
133 result.map_err(|_| crate::Error::SessionClosed)?
134 }
135}
136
137struct SessionInner {
138 pending: Mutex<HashMap<u32, tokio::sync::oneshot::Sender<crate::Result<RpcReply>>>>,
140 state: AtomicU8,
151 disconnect_tx: tokio::sync::watch::Sender<Option<DisconnectReason>>,
155 created_at: Instant,
157 last_rpc_nanos: AtomicU64,
160}
161
162impl SessionInner {
163 fn state(&self) -> SessionState {
164 SessionState::from_u8(self.state.load(Ordering::Acquire))
165 }
166
167 fn set_state(&self, state: SessionState) {
168 self.state.store(state as u8, Ordering::Release);
169 }
170
171 fn drain_pending(&self) -> usize {
172 let mut pending = self.pending.lock().unwrap();
173 let count = pending.len();
174 for (_, tx) in pending.drain() {
175 let _ = tx.send(Err(crate::Error::SessionClosed));
176 }
177 count
178 }
179}
180
/// Write-side state of a session, guarded by the session's async writer mutex.
///
/// NOTE: fields drop in declaration order (`writer`, then `codec`).
struct WriterState {
    /// Write half of the split session stream.
    writer: WriteHalf<NetconfStream>,
    /// Encoder that frames outgoing messages; built for the framing mode
    /// chosen during the hello exchange.
    codec: NetconfCodec,
}
187
/// A NETCONF client session.
///
/// RPCs are written through `writer`. A background task (`_reader_handle`)
/// reads the server's messages and resolves pending RPCs through `inner`.
///
/// NOTE: after `Drop::drop` runs, fields are dropped in the declaration order
/// below, so `_keep_alive` is released last. Keep it at the end.
pub struct Session {
    /// Write half of the stream plus its encoder. This is an async mutex because it
    /// is held across the write/flush awaits in `send_encoded`, which
    /// serializes outgoing frames.
    writer: tokio::sync::Mutex<WriterState>,

    /// State shared with the background reader task.
    inner: Arc<SessionInner>,

    /// The server's `<hello>` (session-id and capabilities).
    server_hello: ServerHello,

    /// Framing mode chosen during the hello exchange.
    framing: FramingMode,

    /// Default reply timeout copied into every `RpcFuture`; `None` waits indefinitely.
    rpc_timeout: Option<Duration>,

    /// Receiver for the disconnect notification; cloned by `disconnected`.
    disconnect_rx: tokio::sync::watch::Receiver<Option<DisconnectReason>>,

    /// Instant at which `build` constructed this session.
    connected_since: Instant,

    /// Background reader task; aborted in `Drop`.
    _reader_handle: tokio::task::JoinHandle<()>,

    /// Opaque value returned by the transport layer and held for the session's
    /// lifetime. It presumably owns resources the stream depends on (confirm in
    /// `crate::transport`). It is `None` for sessions built from a caller-supplied stream.
    _keep_alive: Option<Box<dyn std::any::Any + Send + Sync>>,
}
234
235impl Drop for Session {
236 fn drop(&mut self) {
237 let drained = self.inner.drain_pending();
239 if drained > 0 {
240 debug!(
241 "session {}: drop: drained {drained} pending RPCs",
242 self.server_hello.session_id
243 );
244 }
245 self.inner.set_state(SessionState::Closed);
247 self.inner.disconnect_tx.send_if_modified(|current| {
249 if current.is_none() {
250 *current = Some(DisconnectReason::Dropped);
251 true
252 } else {
253 false
254 }
255 });
256 self._reader_handle.abort();
258 }
259}
260
261impl Session {
262 pub async fn connect(
264 host: &str,
265 port: u16,
266 username: &str,
267 password: &str,
268 ) -> crate::Result<Self> {
269 Self::connect_with_config(host, port, username, password, Config::default()).await
270 }
271
272 pub async fn connect_with_config(
274 host: &str,
275 port: u16,
276 username: &str,
277 password: &str,
278 config: Config,
279 ) -> crate::Result<Self> {
280 let (mut stream, keep_alive) =
281 crate::transport::connect(host, port, username, password, &config).await?;
282 let (server_hello, framing) = exchange_hello(&mut stream, &config).await?;
283 Self::build(stream, Some(keep_alive), server_hello, framing, config)
284 }
285
286 pub async fn from_stream<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
288 stream: S,
289 ) -> crate::Result<Self> {
290 Self::from_stream_with_config(stream, Config::default()).await
291 }
292
293 pub async fn from_stream_with_config<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
295 mut stream: S,
296 config: Config,
297 ) -> crate::Result<Self> {
298 let (server_hello, framing) = exchange_hello(&mut stream, &config).await?;
299 let boxed: NetconfStream = Box::new(stream);
300 Self::build(boxed, None, server_hello, framing, config)
301 }
302
303 fn build(
304 stream: NetconfStream,
305 keep_alive: Option<Box<dyn std::any::Any + Send + Sync>>,
306 server_hello: ServerHello,
307 framing: FramingMode,
308 config: Config,
309 ) -> crate::Result<Self> {
310 debug!(
311 "session {}: building (framing={:?}, capabilities={})",
312 server_hello.session_id,
313 framing,
314 server_hello.capabilities.len()
315 );
316 let (read_half, write_half) = tokio::io::split(stream);
317
318 let read_codec = NetconfCodec::new(framing, config.codec);
319 let write_codec = NetconfCodec::new(framing, config.codec);
320 let reader = FramedRead::new(read_half, read_codec);
321
322 let (disconnect_tx, disconnect_rx) = tokio::sync::watch::channel(None);
323
324 let inner = Arc::new(SessionInner {
325 pending: Mutex::new(HashMap::new()),
326 state: AtomicU8::new(SessionState::Ready as u8),
327 disconnect_tx,
328 created_at: Instant::now(),
329 last_rpc_nanos: AtomicU64::new(0),
330 });
331
332 let reader_inner = Arc::clone(&inner);
333 let session_id = server_hello.session_id;
334 let reader_handle = tokio::spawn(async move {
335 reader_loop(reader, reader_inner, session_id).await;
336 });
337
338 Ok(Self {
339 writer: tokio::sync::Mutex::new(WriterState {
340 writer: write_half,
341 codec: write_codec,
342 }),
343 inner,
344 server_hello,
345 framing,
346 rpc_timeout: config.rpc_timeout,
347 disconnect_rx,
348 connected_since: Instant::now(),
349 _reader_handle: reader_handle,
350 _keep_alive: keep_alive,
351 })
352 }
353
354 pub fn session_id(&self) -> u32 {
355 self.server_hello.session_id
356 }
357
358 pub fn server_capabilities(&self) -> &[String] {
359 &self.server_hello.capabilities
360 }
361
362 pub fn framing_mode(&self) -> FramingMode {
363 self.framing
364 }
365
366 pub fn state(&self) -> SessionState {
367 self.inner.state()
368 }
369
370 pub async fn disconnected(&self) -> DisconnectReason {
387 let mut rx = self.disconnect_rx.clone();
388 if let Some(reason) = rx.borrow_and_update().clone() {
390 return reason;
391 }
392 loop {
395 if rx.changed().await.is_err() {
396 return DisconnectReason::Dropped;
397 }
398 if let Some(reason) = rx.borrow_and_update().clone() {
399 return reason;
400 }
401 }
402 }
403
404 fn check_state(&self) -> crate::Result<()> {
405 let state = self.inner.state();
406 if state != SessionState::Ready {
407 return Err(crate::Error::InvalidState(state.to_string()));
408 }
409 Ok(())
410 }
411
412 async fn send_encoded(&self, xml: &str) -> crate::Result<()> {
423 let mut buf = BytesMut::new();
424 let mut state = self.writer.lock().await;
425 state.codec.encode(Bytes::from(xml.to_string()), &mut buf)?;
426 trace!(
427 "session {}: writing {} bytes to stream",
428 self.server_hello.session_id,
429 buf.len()
430 );
431 state.writer.write_all(&buf).await?;
432 state.writer.flush().await?;
433 Ok(())
434 }
435
436 pub async fn rpc_send(&self, inner_xml: &str) -> crate::Result<RpcFuture> {
443 self.check_state()?;
444 let (msg_id, xml) = message::build_rpc(inner_xml);
445 debug!(
446 "session {}: sending rpc message-id={} ({} bytes)",
447 self.server_hello.session_id,
448 msg_id,
449 xml.len()
450 );
451 trace!(
452 "session {}: rpc content: {}",
453 self.server_hello.session_id, inner_xml
454 );
455 let (tx, rx) = tokio::sync::oneshot::channel();
456
457 self.inner.pending.lock().unwrap().insert(msg_id, tx);
458
459 if let Err(e) = self.send_encoded(&xml).await {
460 debug!(
461 "session {}: send failed for message-id={}: {}",
462 self.server_hello.session_id, msg_id, e
463 );
464 self.inner.pending.lock().unwrap().remove(&msg_id);
465 return Err(e);
466 }
467 Ok(RpcFuture {
468 rx,
469 msg_id,
470 rpc_timeout: self.rpc_timeout,
471 })
472 }
473
474 pub async fn rpc_raw(&self, inner_xml: &str) -> crate::Result<RpcReply> {
476 let future = self.rpc_send(inner_xml).await?;
477 future.response().await
478 }
479
480 async fn rpc_send_unchecked(&self, inner_xml: &str) -> crate::Result<RpcFuture> {
482 let (msg_id, xml) = message::build_rpc(inner_xml);
483 let (tx, rx) = tokio::sync::oneshot::channel();
484
485 self.inner.pending.lock().unwrap().insert(msg_id, tx);
486
487 if let Err(e) = self.send_encoded(&xml).await {
488 self.inner.pending.lock().unwrap().remove(&msg_id);
489 return Err(e);
490 }
491
492 Ok(RpcFuture {
493 rx,
494 msg_id,
495 rpc_timeout: self.rpc_timeout,
496 })
497 }
498
499 pub async fn get_config(
501 &self,
502 source: Datastore,
503 filter: Option<&str>,
504 ) -> crate::Result<String> {
505 let filter_xml = match filter {
506 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
507 None => String::new(),
508 };
509 let inner = format!(
510 "<get-config><source>{}</source>{filter_xml}</get-config>",
511 source.as_xml()
512 );
513 let reply = self.rpc_raw(&inner).await?;
514 reply_to_data(reply)
515 }
516
517 pub async fn get_config_payload(
523 &self,
524 source: Datastore,
525 filter: Option<&str>,
526 ) -> crate::Result<DataPayload> {
527 let filter_xml = match filter {
528 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
529 None => String::new(),
530 };
531 let inner = format!(
532 "<get-config><source>{}</source>{filter_xml}</get-config>",
533 source.as_xml()
534 );
535 let reply = self.rpc_raw(&inner).await?;
536 reply.into_data()
537 }
538
539 pub async fn get(&self, filter: Option<&str>) -> crate::Result<String> {
541 let filter_xml = match filter {
542 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
543 None => String::new(),
544 };
545 let inner = format!("<get>{filter_xml}</get>");
546 let reply = self.rpc_raw(&inner).await?;
547 reply_to_data(reply)
548 }
549
550 pub async fn get_payload(&self, filter: Option<&str>) -> crate::Result<DataPayload> {
555 let filter_xml = match filter {
556 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
557 None => String::new(),
558 };
559 let inner = format!("<get>{filter_xml}</get>");
560 let reply = self.rpc_raw(&inner).await?;
561 reply.into_data()
562 }
563
564 pub async fn edit_config(&self, target: Datastore, config: &str) -> crate::Result<()> {
566 let inner = format!(
567 "<edit-config><target>{}</target><config>{config}</config></edit-config>",
568 target.as_xml()
569 );
570 let reply = self.rpc_raw(&inner).await?;
571 reply_to_ok(reply)
572 }
573
574 pub async fn lock(&self, target: Datastore) -> crate::Result<()> {
576 let inner = format!("<lock><target>{}</target></lock>", target.as_xml());
577 let reply = self.rpc_raw(&inner).await?;
578 reply_to_ok(reply)
579 }
580
581 pub async fn unlock(&self, target: Datastore) -> crate::Result<()> {
583 let inner = format!("<unlock><target>{}</target></unlock>", target.as_xml());
584 let reply = self.rpc_raw(&inner).await?;
585 reply_to_ok(reply)
586 }
587
588 pub async fn commit(&self) -> crate::Result<()> {
590 let reply = self.rpc_raw("<commit/>").await?;
591 reply_to_ok(reply)
592 }
593
594 pub async fn close_session(&self) -> crate::Result<()> {
596 let prev = self.inner.state.compare_exchange(
599 SessionState::Ready as u8,
600 SessionState::Closing as u8,
601 Ordering::AcqRel,
602 Ordering::Acquire,
603 );
604 if let Err(current) = prev {
605 let state = SessionState::from_u8(current);
606 return Err(crate::Error::InvalidState(state.to_string()));
607 }
608 debug!("session {}: closing", self.server_hello.session_id);
609 let result = self.rpc_send_unchecked("<close-session/>").await;
610 match result {
611 Ok(future) => {
612 let reply = future.response().await;
613 self.inner.set_state(SessionState::Closed);
614 debug!(
615 "session {}: closed gracefully",
616 self.server_hello.session_id
617 );
618 reply_to_ok(reply?)
619 }
620 Err(e) => {
621 self.inner.set_state(SessionState::Closed);
622 debug!(
623 "session {}: close failed: {}",
624 self.server_hello.session_id, e
625 );
626 Err(e)
627 }
628 }
629 }
630
631 pub async fn kill_session(&self, session_id: u32) -> crate::Result<()> {
633 let inner = format!("<kill-session><session-id>{session_id}</session-id></kill-session>");
634 let reply = self.rpc_raw(&inner).await?;
635 reply_to_ok(reply)
636 }
637
638 pub fn with_timeout(&self, timeout: Duration) -> SessionWithTimeout<'_> {
641 SessionWithTimeout {
642 session: self,
643 timeout,
644 }
645 }
646
647 pub fn pending_rpc_count(&self) -> usize {
649 self.inner.pending.lock().unwrap().len()
650 }
651
652 pub fn last_rpc_at(&self) -> Option<Instant> {
655 let nanos = self.inner.last_rpc_nanos.load(Ordering::Acquire);
656 if nanos == 0 {
657 None
658 } else {
659 Some(self.inner.created_at + Duration::from_nanos(nanos))
660 }
661 }
662
663 pub fn connected_since(&self) -> Instant {
666 self.connected_since
667 }
668}
669
670pub struct SessionWithTimeout<'a> {
675 session: &'a Session,
676 timeout: Duration,
677}
678
679impl SessionWithTimeout<'_> {
680 pub async fn rpc_raw(&self, inner_xml: &str) -> crate::Result<RpcReply> {
682 let future = self.session.rpc_send(inner_xml).await?;
683 future.response_with_timeout(self.timeout).await
684 }
685
686 pub async fn get_config(
688 &self,
689 source: Datastore,
690 filter: Option<&str>,
691 ) -> crate::Result<String> {
692 let filter_xml = match filter {
693 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
694 None => String::new(),
695 };
696 let inner = format!(
697 "<get-config><source>{}</source>{filter_xml}</get-config>",
698 source.as_xml()
699 );
700 let reply = self.rpc_raw(&inner).await?;
701 reply_to_data(reply)
702 }
703
704 pub async fn get_config_payload(
706 &self,
707 source: Datastore,
708 filter: Option<&str>,
709 ) -> crate::Result<DataPayload> {
710 let filter_xml = match filter {
711 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
712 None => String::new(),
713 };
714 let inner = format!(
715 "<get-config><source>{}</source>{filter_xml}</get-config>",
716 source.as_xml()
717 );
718 let reply = self.rpc_raw(&inner).await?;
719 reply.into_data()
720 }
721
722 pub async fn get(&self, filter: Option<&str>) -> crate::Result<String> {
724 let filter_xml = match filter {
725 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
726 None => String::new(),
727 };
728 let inner = format!("<get>{filter_xml}</get>");
729 let reply = self.rpc_raw(&inner).await?;
730 reply_to_data(reply)
731 }
732
733 pub async fn get_payload(&self, filter: Option<&str>) -> crate::Result<DataPayload> {
735 let filter_xml = match filter {
736 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
737 None => String::new(),
738 };
739 let inner = format!("<get>{filter_xml}</get>");
740 let reply = self.rpc_raw(&inner).await?;
741 reply.into_data()
742 }
743
744 pub async fn edit_config(&self, target: Datastore, config: &str) -> crate::Result<()> {
746 let inner = format!(
747 "<edit-config><target>{}</target><config>{config}</config></edit-config>",
748 target.as_xml()
749 );
750 let reply = self.rpc_raw(&inner).await?;
751 reply_to_ok(reply)
752 }
753
754 pub async fn lock(&self, target: Datastore) -> crate::Result<()> {
756 let inner = format!("<lock><target>{}</target></lock>", target.as_xml());
757 let reply = self.rpc_raw(&inner).await?;
758 reply_to_ok(reply)
759 }
760
761 pub async fn unlock(&self, target: Datastore) -> crate::Result<()> {
763 let inner = format!("<unlock><target>{}</target></unlock>", target.as_xml());
764 let reply = self.rpc_raw(&inner).await?;
765 reply_to_ok(reply)
766 }
767
768 pub async fn commit(&self) -> crate::Result<()> {
770 let reply = self.rpc_raw("<commit/>").await?;
771 reply_to_ok(reply)
772 }
773}
774
775async fn exchange_hello<S: AsyncRead + AsyncWrite + Unpin>(
777 stream: &mut S,
778 config: &Config,
779) -> crate::Result<(ServerHello, FramingMode)> {
780 let fut = crate::hello::exchange(stream, config.codec.max_message_size);
781 match config.hello_timeout {
782 Some(duration) => tokio::time::timeout(duration, fut)
783 .await
784 .map_err(|_| crate::Error::Transport(TransportError::Timeout(duration)))?,
785 None => fut.await,
786 }
787}
788
789async fn reader_loop(
793 mut reader: FramedRead<ReadHalf<NetconfStream>, NetconfCodec>,
794 inner: Arc<SessionInner>,
795 session_id: u32,
796) {
797 debug!("session {}: reader loop started", session_id);
798 let mut disconnect_reason = DisconnectReason::Eof;
800 while let Some(result) = reader.next().await {
804 match result {
805 Ok(bytes) => {
807 trace!(
808 "session {}: received frame ({} bytes)",
809 session_id,
810 bytes.len()
811 );
812 match message::classify_message(bytes) {
813 Ok(ServerMessage::RpcReply(reply)) => {
814 debug!(
815 "session {}: received rpc-reply message-id={}",
816 session_id, reply.message_id
817 );
818 let tx = {
822 let mut pending = inner.pending.lock().unwrap();
823 pending.remove(&reply.message_id)
824 };
825 if let Some(tx) = tx {
826 let nanos = inner.created_at.elapsed().as_nanos() as u64;
827 inner.last_rpc_nanos.store(nanos, Ordering::Release);
828 let _ = tx.send(Ok(reply));
830 } else {
831 warn!(
832 "session {}: received reply for unknown message-id {}",
833 session_id, reply.message_id
834 );
835 }
836 }
837 Err(e) => {
838 warn!("session {}: failed to classify message: {e}", session_id);
839 }
840 }
841 }
842 Err(e) => {
845 debug!("session {}: reader error: {e}", session_id);
846 disconnect_reason = DisconnectReason::TransportError(e.to_string());
847 let drained = inner.drain_pending();
848 if drained > 0 {
849 debug!(
850 "session {}: drained {} pending RPCs after error",
851 session_id, drained
852 );
853 }
854 break;
855 }
856 }
857 }
858 {
861 let drained = inner.drain_pending();
862 if drained > 0 {
863 debug!(
864 "session {}: drained {} pending RPCs on stream close",
865 session_id, drained
866 );
867 }
868 }
869 inner.set_state(SessionState::Closed);
872 let _ = inner.disconnect_tx.send(Some(disconnect_reason));
875 debug!("session {}: reader loop ended", session_id);
876}
877
878fn reply_to_data(reply: RpcReply) -> crate::Result<String> {
879 match reply.body {
880 RpcReplyBody::Data(payload) => Ok(payload.into_string()),
881 RpcReplyBody::Ok => Ok(String::new()),
882 RpcReplyBody::Error(errors) => Err(crate::Error::Rpc {
883 message_id: reply.message_id,
884 error: errors
885 .first()
886 .map(|e| e.error_message.clone())
887 .unwrap_or_default(),
888 }),
889 }
890}
891
892fn reply_to_ok(reply: RpcReply) -> crate::Result<()> {
893 match reply.body {
894 RpcReplyBody::Ok => Ok(()),
895 RpcReplyBody::Data(_) => Ok(()),
896 RpcReplyBody::Error(errors) => Err(crate::Error::Rpc {
897 message_id: reply.message_id,
898 error: errors
899 .first()
900 .map(|e| e.error_message.clone())
901 .unwrap_or_default(),
902 }),
903 }
904}