1use std::collections::HashMap;
2use std::sync::atomic::{AtomicU8, Ordering};
3use std::sync::{Arc, Mutex};
4
5use bytes::{Bytes, BytesMut};
6use log::{debug, trace, warn};
7use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
8use tokio_stream::StreamExt;
9use tokio_util::codec::{Encoder, FramedRead};
10
11use crate::codec::{CodecConfig, FramingMode, NetconfCodec};
12use crate::hello::ServerHello;
13use crate::message::{self, DataPayload, RpcReply, RpcReplyBody, ServerMessage};
14use crate::stream::NetconfStream;
15
/// Lifecycle state of a [`Session`], stored atomically as its `u8` discriminant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum SessionState {
    /// The session accepts new RPCs.
    Ready = 0,
    /// `close_session` has started; ordinary RPCs are refused.
    Closing = 1,
    /// The reader loop has ended or the session was closed.
    Closed = 2,
}

impl SessionState {
    /// Decodes a value previously produced by `state as u8`.
    ///
    /// Any value other than the `Ready` or `Closing` discriminant decodes as `Closed`.
    fn from_u8(raw: u8) -> Self {
        const READY: u8 = SessionState::Ready as u8;
        const CLOSING: u8 = SessionState::Closing as u8;
        match raw {
            READY => Self::Ready,
            CLOSING => Self::Closing,
            _ => Self::Closed,
        }
    }
}

impl std::fmt::Display for SessionState {
    /// Writes the variant name (`Ready`, `Closing` or `Closed`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Ready => "Ready",
            Self::Closing => "Closing",
            Self::Closed => "Closed",
        };
        f.write_str(label)
    }
}
46
47#[derive(Debug, Clone, Default)]
49pub struct SessionConfig {
50 pub codec: CodecConfig,
52}
53
/// A configuration datastore named in `<source>`/`<target>` elements.
///
/// Derives `PartialEq`, `Eq` and `Hash` in addition to `Copy` so callers can
/// compare datastores and use them as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Datastore {
    /// Rendered as `<running/>`.
    Running,
    /// Rendered as `<candidate/>`.
    Candidate,
    /// Rendered as `<startup/>`.
    Startup,
}

impl Datastore {
    /// Returns the empty XML element that names this datastore.
    fn as_xml(&self) -> &'static str {
        match self {
            Datastore::Running => "<running/>",
            Datastore::Candidate => "<candidate/>",
            Datastore::Startup => "<startup/>",
        }
    }
}
70
71pub struct RpcFuture {
76 rx: tokio::sync::oneshot::Receiver<crate::Result<RpcReply>>,
77 msg_id: u32,
78}
79
80impl RpcFuture {
81 pub fn message_id(&self) -> u32 {
83 self.msg_id
84 }
85
86 pub async fn response(self) -> crate::Result<RpcReply> {
88 self.rx.await.map_err(|_| crate::Error::SessionClosed)?
89 }
90}
91
92struct SessionInner {
93 pending: Mutex<HashMap<u32, tokio::sync::oneshot::Sender<crate::Result<RpcReply>>>>,
95 state: AtomicU8,
106}
107
108impl SessionInner {
109 fn state(&self) -> SessionState {
110 SessionState::from_u8(self.state.load(Ordering::Acquire))
111 }
112
113 fn set_state(&self, state: SessionState) {
114 self.state.store(state as u8, Ordering::Release);
115 }
116}
117
118pub struct Session {
130 writer: WriteHalf<NetconfStream>,
132
133 write_codec: NetconfCodec,
137
138 inner: Arc<SessionInner>,
141
142 server_hello: ServerHello,
144
145 framing: FramingMode,
147
148 _keep_alive: Option<Box<dyn std::any::Any + Send>>,
151
152 _reader_handle: tokio::task::JoinHandle<()>,
155}
156
157impl Drop for Session {
158 fn drop(&mut self) {
159 self._reader_handle.abort();
160 }
161}
162
163impl Session {
164 pub async fn connect(
166 host: &str,
167 port: u16,
168 username: &str,
169 password: &str,
170 ) -> crate::Result<Self> {
171 Self::connect_with_config(host, port, username, password, SessionConfig::default()).await
172 }
173
174 pub async fn connect_with_config(
176 host: &str,
177 port: u16,
178 username: &str,
179 password: &str,
180 config: SessionConfig,
181 ) -> crate::Result<Self> {
182 let (mut stream, keep_alive) =
183 crate::transport::connect(host, port, username, password).await?;
184 let (server_hello, framing) =
185 crate::hello::exchange(&mut stream, config.codec.max_message_size).await?;
186 Self::build(stream, Some(keep_alive), server_hello, framing, config)
187 }
188
189 pub async fn from_stream<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
191 stream: S,
192 ) -> crate::Result<Self> {
193 Self::from_stream_with_config(stream, SessionConfig::default()).await
194 }
195
196 pub async fn from_stream_with_config<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
198 mut stream: S,
199 config: SessionConfig,
200 ) -> crate::Result<Self> {
201 let (server_hello, framing) =
202 crate::hello::exchange(&mut stream, config.codec.max_message_size).await?;
203 let boxed: NetconfStream = Box::new(stream);
204 Self::build(boxed, None, server_hello, framing, config)
205 }
206
207 fn build(
208 stream: NetconfStream,
209 keep_alive: Option<Box<dyn std::any::Any + Send>>,
210 server_hello: ServerHello,
211 framing: FramingMode,
212 config: SessionConfig,
213 ) -> crate::Result<Self> {
214 debug!(
215 "session {}: building (framing={:?}, capabilities={})",
216 server_hello.session_id,
217 framing,
218 server_hello.capabilities.len()
219 );
220 let (read_half, write_half) = tokio::io::split(stream);
221
222 let read_codec = NetconfCodec::new(framing, config.codec);
223 let write_codec = NetconfCodec::new(framing, config.codec);
224 let reader = FramedRead::new(read_half, read_codec);
225
226 let inner = Arc::new(SessionInner {
227 pending: Mutex::new(HashMap::new()),
228 state: AtomicU8::new(SessionState::Ready as u8),
229 });
230
231 let reader_inner = Arc::clone(&inner);
232 let session_id = server_hello.session_id;
233 let reader_handle = tokio::spawn(async move {
234 reader_loop(reader, reader_inner, session_id).await;
235 });
236
237 Ok(Self {
238 writer: write_half,
239 write_codec,
240 inner,
241 server_hello,
242 framing,
243 _keep_alive: keep_alive,
244 _reader_handle: reader_handle,
245 })
246 }
247
248 pub fn session_id(&self) -> u32 {
249 self.server_hello.session_id
250 }
251
252 pub fn server_capabilities(&self) -> &[String] {
253 &self.server_hello.capabilities
254 }
255
256 pub fn framing_mode(&self) -> FramingMode {
257 self.framing
258 }
259
260 pub fn state(&self) -> SessionState {
261 self.inner.state()
262 }
263
264 fn check_state(&self) -> crate::Result<()> {
265 let state = self.inner.state();
266 if state != SessionState::Ready {
267 return Err(crate::Error::InvalidState(state.to_string()));
268 }
269 Ok(())
270 }
271
272 async fn send_encoded(&mut self, xml: &str) -> crate::Result<()> {
281 let mut buf = BytesMut::new();
282 self.write_codec
283 .encode(Bytes::from(xml.to_string()), &mut buf)?;
284 trace!(
285 "session {}: writing {} bytes to stream",
286 self.server_hello.session_id,
287 buf.len()
288 );
289 self.writer.write_all(&buf).await?;
290 self.writer.flush().await?;
291 Ok(())
292 }
293
294 pub async fn rpc_send(&mut self, inner_xml: &str) -> crate::Result<RpcFuture> {
301 self.check_state()?;
302 let (msg_id, xml) = message::build_rpc(inner_xml);
303 debug!(
304 "session {}: sending rpc message-id={} ({} bytes)",
305 self.server_hello.session_id,
306 msg_id,
307 xml.len()
308 );
309 trace!(
310 "session {}: rpc content: {}",
311 self.server_hello.session_id, inner_xml
312 );
313 let (tx, rx) = tokio::sync::oneshot::channel();
314
315 self.inner.pending.lock().unwrap().insert(msg_id, tx);
316
317 if let Err(e) = self.send_encoded(&xml).await {
318 debug!(
319 "session {}: send failed for message-id={}: {}",
320 self.server_hello.session_id, msg_id, e
321 );
322 self.inner.pending.lock().unwrap().remove(&msg_id);
323 return Err(e);
324 }
325 Ok(RpcFuture { rx, msg_id })
326 }
327
328 pub async fn rpc_raw(&mut self, inner_xml: &str) -> crate::Result<RpcReply> {
330 let future = self.rpc_send(inner_xml).await?;
331 future.response().await
332 }
333
334 async fn rpc_send_unchecked(&mut self, inner_xml: &str) -> crate::Result<RpcFuture> {
336 let (msg_id, xml) = message::build_rpc(inner_xml);
337 let (tx, rx) = tokio::sync::oneshot::channel();
338
339 self.inner.pending.lock().unwrap().insert(msg_id, tx);
340
341 if let Err(e) = self.send_encoded(&xml).await {
342 self.inner.pending.lock().unwrap().remove(&msg_id);
343 return Err(e);
344 }
345
346 Ok(RpcFuture { rx, msg_id })
347 }
348
349 pub async fn get_config(
351 &mut self,
352 source: Datastore,
353 filter: Option<&str>,
354 ) -> crate::Result<String> {
355 let filter_xml = match filter {
356 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
357 None => String::new(),
358 };
359 let inner = format!(
360 "<get-config><source>{}</source>{filter_xml}</get-config>",
361 source.as_xml()
362 );
363 let reply = self.rpc_raw(&inner).await?;
364 reply_to_data(reply)
365 }
366
367 pub async fn get_config_payload(
373 &mut self,
374 source: Datastore,
375 filter: Option<&str>,
376 ) -> crate::Result<DataPayload> {
377 let filter_xml = match filter {
378 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
379 None => String::new(),
380 };
381 let inner = format!(
382 "<get-config><source>{}</source>{filter_xml}</get-config>",
383 source.as_xml()
384 );
385 let reply = self.rpc_raw(&inner).await?;
386 reply.into_data()
387 }
388
389 pub async fn get(&mut self, filter: Option<&str>) -> crate::Result<String> {
391 let filter_xml = match filter {
392 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
393 None => String::new(),
394 };
395 let inner = format!("<get>{filter_xml}</get>");
396 let reply = self.rpc_raw(&inner).await?;
397 reply_to_data(reply)
398 }
399
400 pub async fn get_payload(&mut self, filter: Option<&str>) -> crate::Result<DataPayload> {
405 let filter_xml = match filter {
406 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
407 None => String::new(),
408 };
409 let inner = format!("<get>{filter_xml}</get>");
410 let reply = self.rpc_raw(&inner).await?;
411 reply.into_data()
412 }
413
414 pub async fn edit_config(&mut self, target: Datastore, config: &str) -> crate::Result<()> {
416 let inner = format!(
417 "<edit-config><target>{}</target><config>{config}</config></edit-config>",
418 target.as_xml()
419 );
420 let reply = self.rpc_raw(&inner).await?;
421 reply_to_ok(reply)
422 }
423
424 pub async fn lock(&mut self, target: Datastore) -> crate::Result<()> {
426 let inner = format!("<lock><target>{}</target></lock>", target.as_xml());
427 let reply = self.rpc_raw(&inner).await?;
428 reply_to_ok(reply)
429 }
430
431 pub async fn unlock(&mut self, target: Datastore) -> crate::Result<()> {
433 let inner = format!("<unlock><target>{}</target></unlock>", target.as_xml());
434 let reply = self.rpc_raw(&inner).await?;
435 reply_to_ok(reply)
436 }
437
438 pub async fn commit(&mut self) -> crate::Result<()> {
440 let reply = self.rpc_raw("<commit/>").await?;
441 reply_to_ok(reply)
442 }
443
444 pub async fn close_session(&mut self) -> crate::Result<()> {
446 self.check_state()?;
447 debug!("session {}: closing", self.server_hello.session_id);
448 self.inner.set_state(SessionState::Closing);
449 let result = self.rpc_send_unchecked("<close-session/>").await;
450 match result {
451 Ok(future) => {
452 let reply = future.response().await;
453 self.inner.set_state(SessionState::Closed);
454 debug!(
455 "session {}: closed gracefully",
456 self.server_hello.session_id
457 );
458 reply_to_ok(reply?)
459 }
460 Err(e) => {
461 self.inner.set_state(SessionState::Closed);
462 debug!(
463 "session {}: close failed: {}",
464 self.server_hello.session_id, e
465 );
466 Err(e)
467 }
468 }
469 }
470
471 pub async fn kill_session(&mut self, session_id: u32) -> crate::Result<()> {
473 let inner = format!("<kill-session><session-id>{session_id}</session-id></kill-session>");
474 let reply = self.rpc_raw(&inner).await?;
475 reply_to_ok(reply)
476 }
477}
478
479async fn reader_loop(
483 mut reader: FramedRead<ReadHalf<NetconfStream>, NetconfCodec>,
484 inner: Arc<SessionInner>,
485 session_id: u32,
486) {
487 debug!("session {}: reader loop started", session_id);
488 while let Some(result) = reader.next().await {
492 match result {
493 Ok(bytes) => {
495 trace!(
496 "session {}: received frame ({} bytes)",
497 session_id,
498 bytes.len()
499 );
500 match message::classify_message(bytes) {
501 Ok(ServerMessage::RpcReply(reply)) => {
502 debug!(
503 "session {}: received rpc-reply message-id={}",
504 session_id, reply.message_id
505 );
506 let tx = {
510 let mut pending = inner.pending.lock().unwrap();
511 pending.remove(&reply.message_id)
512 };
513 if let Some(tx) = tx {
514 let _ = tx.send(Ok(reply));
516 } else {
517 warn!(
518 "session {}: received reply for unknown message-id {}",
519 session_id, reply.message_id
520 );
521 }
522 }
523 Err(e) => {
524 warn!("session {}: failed to classify message: {e}", session_id);
525 }
526 }
527 }
528 Err(e) => {
531 debug!("session {}: reader error: {e}", session_id);
532 let mut pending = inner.pending.lock().unwrap();
533 let drained = pending.len();
534 for (_, tx) in pending.drain() {
535 let _ = tx.send(Err(crate::Error::SessionClosed));
536 }
537 if drained > 0 {
538 debug!(
539 "session {}: drained {} pending RPCs after error",
540 session_id, drained
541 );
542 }
543 break;
544 }
545 }
546 }
547 {
550 let mut pending = inner.pending.lock().unwrap();
551 let drained = pending.len();
552 for (_, tx) in pending.drain() {
553 let _ = tx.send(Err(crate::Error::SessionClosed));
554 }
555 if drained > 0 {
556 debug!(
557 "session {}: drained {} pending RPCs on stream close",
558 session_id, drained
559 );
560 }
561 }
562 inner.set_state(SessionState::Closed);
565 debug!("session {}: reader loop ended", session_id);
566}
567
568fn reply_to_data(reply: RpcReply) -> crate::Result<String> {
569 match reply.body {
570 RpcReplyBody::Data(payload) => Ok(payload.into_string()),
571 RpcReplyBody::Ok => Ok(String::new()),
572 RpcReplyBody::Error(errors) => Err(crate::Error::Rpc {
573 message_id: reply.message_id,
574 error: errors
575 .first()
576 .map(|e| e.error_message.clone())
577 .unwrap_or_default(),
578 }),
579 }
580}
581
582fn reply_to_ok(reply: RpcReply) -> crate::Result<()> {
583 match reply.body {
584 RpcReplyBody::Ok => Ok(()),
585 RpcReplyBody::Data(_) => Ok(()),
586 RpcReplyBody::Error(errors) => Err(crate::Error::Rpc {
587 message_id: reply.message_id,
588 error: errors
589 .first()
590 .map(|e| e.error_message.clone())
591 .unwrap_or_default(),
592 }),
593 }
594}