1use std::collections::HashMap;
2use std::sync::atomic::{AtomicU8, Ordering};
3use std::sync::{Arc, Mutex};
4
5use bytes::{Bytes, BytesMut};
6use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
7use tokio_stream::StreamExt;
8use tokio_util::codec::{Encoder, FramedRead};
9
10use crate::codec::{CodecConfig, FramingMode, NetconfCodec};
11use crate::hello::ServerHello;
12use crate::message::{self, DataPayload, RpcReply, RpcReplyBody, ServerMessage};
13use crate::stream::NetconfStream;
14
/// Lifecycle state of a [`Session`].
///
/// Kept in an `AtomicU8` as its `u8` discriminant so the background reader task
/// and the session can both read and update it without a lock.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum SessionState {
    /// The session accepts new RPCs.
    Ready = 0,
    /// A `<close-session>` request is in progress.
    Closing = 1,
    /// The session is closed; no further RPCs are accepted.
    Closed = 2,
}

impl SessionState {
    /// Decodes a discriminant previously produced by `state as u8`.
    ///
    /// Any value that is not the `Ready` or `Closing` discriminant decodes as
    /// `Closed`.
    fn from_u8(v: u8) -> Self {
        if v == Self::Ready as u8 {
            Self::Ready
        } else if v == Self::Closing as u8 {
            Self::Closing
        } else {
            Self::Closed
        }
    }
}

impl std::fmt::Display for SessionState {
    /// Writes the variant name: `Ready`, `Closing`, or `Closed`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match *self {
            Self::Ready => "Ready",
            Self::Closing => "Closing",
            Self::Closed => "Closed",
        };
        f.write_str(name)
    }
}
45
/// Options applied when building a [`Session`].
#[derive(Debug, Clone, Default)]
pub struct SessionConfig {
    /// Codec settings. Used to build both the read and the write codec, and its
    /// `max_message_size` is also passed to the hello exchange.
    pub codec: CodecConfig,
}
52
/// A NETCONF configuration datastore used as the `<source>` or `<target>` of an
/// operation.
///
/// Derives `PartialEq`, `Eq` and `Hash` so callers can compare datastores and use
/// them as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Datastore {
    /// Serialized as `<running/>`.
    Running,
    /// Serialized as `<candidate/>`.
    Candidate,
    /// Serialized as `<startup/>`.
    Startup,
}

impl Datastore {
    /// Returns the empty XML element that names this datastore, ready to be placed
    /// inside a `<source>` or `<target>` element.
    fn as_xml(&self) -> &'static str {
        match self {
            Datastore::Running => "<running/>",
            Datastore::Candidate => "<candidate/>",
            Datastore::Startup => "<startup/>",
        }
    }
}
69
/// Handle to the reply of an RPC that has already been written to the server.
///
/// Returned by `Session::rpc_send`; await [`RpcFuture::response`] to obtain the
/// reply. If the handle is dropped, the reply is discarded when it arrives.
pub struct RpcFuture {
    /// Receives the reply (or a session-closed error) from the background reader task.
    rx: tokio::sync::oneshot::Receiver<crate::Result<RpcReply>>,
    /// The `message-id` assigned to the request.
    msg_id: u32,
}
78
79impl RpcFuture {
80 pub fn message_id(&self) -> u32 {
82 self.msg_id
83 }
84
85 pub async fn response(self) -> crate::Result<RpcReply> {
87 self.rx.await.map_err(|_| crate::Error::SessionClosed)?
88 }
89}
90
/// State shared between a [`Session`] and its background reader task.
struct SessionInner {
    /// Reply senders for in-flight RPCs, keyed by `message-id`. An entry is removed
    /// when its reply arrives, when writing the request fails, or when the reader
    /// task shuts down and fails all outstanding requests.
    pending: Mutex<HashMap<u32, tokio::sync::oneshot::Sender<crate::Result<RpcReply>>>>,
    /// Current [`SessionState`], stored as its `u8` discriminant.
    state: AtomicU8,
}
106
107impl SessionInner {
108 fn state(&self) -> SessionState {
109 SessionState::from_u8(self.state.load(Ordering::Acquire))
110 }
111
112 fn set_state(&self, state: SessionState) {
113 self.state.store(state as u8, Ordering::Release);
114 }
115}
116
/// A NETCONF client session.
///
/// Requests are written directly on the write half of the stream. Replies are
/// read by a background task (spawned in `build`) and routed to the matching
/// [`RpcFuture`] by `message-id`.
pub struct Session {
    /// Write half of the split transport stream.
    writer: WriteHalf<NetconfStream>,

    /// Codec that frames outgoing messages using the negotiated framing mode.
    write_codec: NetconfCodec,

    /// State shared with the background reader task.
    inner: Arc<SessionInner>,

    /// The server's hello, as returned by the hello exchange.
    server_hello: ServerHello,

    /// Framing mode negotiated during the hello exchange.
    framing: FramingMode,

    /// Opaque object returned by `crate::transport::connect` alongside the stream;
    /// it is held only so that it lives as long as the session (presumably
    /// connection state the stream depends on — confirm in `transport`). `None`
    /// for sessions built from a caller-supplied stream.
    _keep_alive: Option<Box<dyn std::any::Any + Send>>,

    /// Handle to the background reader task; aborted when the session is dropped.
    _reader_handle: tokio::task::JoinHandle<()>,
}
155
156impl Drop for Session {
157 fn drop(&mut self) {
158 self._reader_handle.abort();
159 }
160}
161
162impl Session {
163 pub async fn connect(
165 host: &str,
166 port: u16,
167 username: &str,
168 password: &str,
169 ) -> crate::Result<Self> {
170 Self::connect_with_config(host, port, username, password, SessionConfig::default()).await
171 }
172
173 pub async fn connect_with_config(
175 host: &str,
176 port: u16,
177 username: &str,
178 password: &str,
179 config: SessionConfig,
180 ) -> crate::Result<Self> {
181 let (mut stream, keep_alive) =
182 crate::transport::connect(host, port, username, password).await?;
183 let (server_hello, framing) =
184 crate::hello::exchange(&mut stream, config.codec.max_message_size).await?;
185 Self::build(stream, Some(keep_alive), server_hello, framing, config)
186 }
187
188 pub async fn from_stream<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
190 stream: S,
191 ) -> crate::Result<Self> {
192 Self::from_stream_with_config(stream, SessionConfig::default()).await
193 }
194
195 pub async fn from_stream_with_config<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
197 mut stream: S,
198 config: SessionConfig,
199 ) -> crate::Result<Self> {
200 let (server_hello, framing) =
201 crate::hello::exchange(&mut stream, config.codec.max_message_size).await?;
202 let boxed: NetconfStream = Box::new(stream);
203 Self::build(boxed, None, server_hello, framing, config)
204 }
205
206 fn build(
207 stream: NetconfStream,
208 keep_alive: Option<Box<dyn std::any::Any + Send>>,
209 server_hello: ServerHello,
210 framing: FramingMode,
211 config: SessionConfig,
212 ) -> crate::Result<Self> {
213 let (read_half, write_half) = tokio::io::split(stream);
214
215 let read_codec = NetconfCodec::new(framing, config.codec);
216 let write_codec = NetconfCodec::new(framing, config.codec);
217 let reader = FramedRead::new(read_half, read_codec);
218
219 let inner = Arc::new(SessionInner {
220 pending: Mutex::new(HashMap::new()),
221 state: AtomicU8::new(SessionState::Ready as u8),
222 });
223
224 let reader_inner = Arc::clone(&inner);
225 let reader_handle = tokio::spawn(reader_loop(reader, reader_inner));
226
227 Ok(Self {
228 writer: write_half,
229 write_codec,
230 inner,
231 server_hello,
232 framing,
233 _keep_alive: keep_alive,
234 _reader_handle: reader_handle,
235 })
236 }
237
238 pub fn session_id(&self) -> u32 {
239 self.server_hello.session_id
240 }
241
242 pub fn server_capabilities(&self) -> &[String] {
243 &self.server_hello.capabilities
244 }
245
246 pub fn framing_mode(&self) -> FramingMode {
247 self.framing
248 }
249
250 pub fn state(&self) -> SessionState {
251 self.inner.state()
252 }
253
254 fn check_state(&self) -> crate::Result<()> {
255 let state = self.inner.state();
256 if state != SessionState::Ready {
257 return Err(crate::Error::InvalidState(state.to_string()));
258 }
259 Ok(())
260 }
261
262 async fn send_encoded(&mut self, xml: &str) -> crate::Result<()> {
271 let mut buf = BytesMut::new();
272 self.write_codec
273 .encode(Bytes::from(xml.to_string()), &mut buf)?;
274 self.writer.write_all(&buf).await?;
275 self.writer.flush().await?;
276 Ok(())
277 }
278
279 pub async fn rpc_send(&mut self, inner_xml: &str) -> crate::Result<RpcFuture> {
286 self.check_state()?;
287 let (msg_id, xml) = message::build_rpc(inner_xml);
288 let (tx, rx) = tokio::sync::oneshot::channel();
289
290 self.inner.pending.lock().unwrap().insert(msg_id, tx);
291
292 if let Err(e) = self.send_encoded(&xml).await {
293 self.inner.pending.lock().unwrap().remove(&msg_id);
294 return Err(e);
295 }
296 Ok(RpcFuture { rx, msg_id })
297 }
298
299 pub async fn rpc_raw(&mut self, inner_xml: &str) -> crate::Result<RpcReply> {
301 let future = self.rpc_send(inner_xml).await?;
302 future.response().await
303 }
304
305 async fn rpc_send_unchecked(&mut self, inner_xml: &str) -> crate::Result<RpcFuture> {
307 let (msg_id, xml) = message::build_rpc(inner_xml);
308 let (tx, rx) = tokio::sync::oneshot::channel();
309
310 self.inner.pending.lock().unwrap().insert(msg_id, tx);
311
312 if let Err(e) = self.send_encoded(&xml).await {
313 self.inner.pending.lock().unwrap().remove(&msg_id);
314 return Err(e);
315 }
316
317 Ok(RpcFuture { rx, msg_id })
318 }
319
320 pub async fn get_config(
322 &mut self,
323 source: Datastore,
324 filter: Option<&str>,
325 ) -> crate::Result<String> {
326 let filter_xml = match filter {
327 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
328 None => String::new(),
329 };
330 let inner = format!(
331 "<get-config><source>{}</source>{filter_xml}</get-config>",
332 source.as_xml()
333 );
334 let reply = self.rpc_raw(&inner).await?;
335 reply_to_data(reply)
336 }
337
338 pub async fn get_config_payload(
344 &mut self,
345 source: Datastore,
346 filter: Option<&str>,
347 ) -> crate::Result<DataPayload> {
348 let filter_xml = match filter {
349 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
350 None => String::new(),
351 };
352 let inner = format!(
353 "<get-config><source>{}</source>{filter_xml}</get-config>",
354 source.as_xml()
355 );
356 let reply = self.rpc_raw(&inner).await?;
357 reply.into_data()
358 }
359
360 pub async fn get(&mut self, filter: Option<&str>) -> crate::Result<String> {
362 let filter_xml = match filter {
363 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
364 None => String::new(),
365 };
366 let inner = format!("<get>{filter_xml}</get>");
367 let reply = self.rpc_raw(&inner).await?;
368 reply_to_data(reply)
369 }
370
371 pub async fn get_payload(&mut self, filter: Option<&str>) -> crate::Result<DataPayload> {
376 let filter_xml = match filter {
377 Some(f) => format!(r#"<filter type="subtree">{f}</filter>"#),
378 None => String::new(),
379 };
380 let inner = format!("<get>{filter_xml}</get>");
381 let reply = self.rpc_raw(&inner).await?;
382 reply.into_data()
383 }
384
385 pub async fn edit_config(&mut self, target: Datastore, config: &str) -> crate::Result<()> {
387 let inner = format!(
388 "<edit-config><target>{}</target><config>{config}</config></edit-config>",
389 target.as_xml()
390 );
391 let reply = self.rpc_raw(&inner).await?;
392 reply_to_ok(reply)
393 }
394
395 pub async fn lock(&mut self, target: Datastore) -> crate::Result<()> {
397 let inner = format!("<lock><target>{}</target></lock>", target.as_xml());
398 let reply = self.rpc_raw(&inner).await?;
399 reply_to_ok(reply)
400 }
401
402 pub async fn unlock(&mut self, target: Datastore) -> crate::Result<()> {
404 let inner = format!("<unlock><target>{}</target></unlock>", target.as_xml());
405 let reply = self.rpc_raw(&inner).await?;
406 reply_to_ok(reply)
407 }
408
409 pub async fn commit(&mut self) -> crate::Result<()> {
411 let reply = self.rpc_raw("<commit/>").await?;
412 reply_to_ok(reply)
413 }
414
415 pub async fn close_session(&mut self) -> crate::Result<()> {
417 self.check_state()?;
418 self.inner.set_state(SessionState::Closing);
419 let result = self.rpc_send_unchecked("<close-session/>").await;
420 match result {
421 Ok(future) => {
422 let reply = future.response().await;
423 self.inner.set_state(SessionState::Closed);
424 reply_to_ok(reply?)
425 }
426 Err(e) => {
427 self.inner.set_state(SessionState::Closed);
428 Err(e)
429 }
430 }
431 }
432
433 pub async fn kill_session(&mut self, session_id: u32) -> crate::Result<()> {
435 let inner = format!("<kill-session><session-id>{session_id}</session-id></kill-session>");
436 let reply = self.rpc_raw(&inner).await?;
437 reply_to_ok(reply)
438 }
439}
440
441async fn reader_loop(
445 mut reader: FramedRead<ReadHalf<NetconfStream>, NetconfCodec>,
446 inner: Arc<SessionInner>,
447) {
448 while let Some(result) = reader.next().await {
452 match result {
453 Ok(bytes) => match message::classify_message(bytes) {
455 Ok(ServerMessage::RpcReply(reply)) => {
456 let tx = {
460 let mut pending = inner.pending.lock().unwrap();
461 pending.remove(&reply.message_id)
462 };
463 if let Some(tx) = tx {
464 let _ = tx.send(Ok(reply));
466 } else {
467 log::warn!("received reply for unknown message-id {}", reply.message_id);
468 }
469 }
470 Err(e) => {
471 log::warn!("failed to classify message: {e}");
472 }
473 },
474 Err(e) => {
477 log::error!("reader error: {e}");
478 let mut pending = inner.pending.lock().unwrap();
479 for (_, tx) in pending.drain() {
480 let _ = tx.send(Err(crate::Error::SessionClosed));
481 }
482 break;
483 }
484 }
485 }
486 {
489 let mut pending = inner.pending.lock().unwrap();
490 for (_, tx) in pending.drain() {
491 let _ = tx.send(Err(crate::Error::SessionClosed));
492 }
493 }
494 inner.set_state(SessionState::Closed);
497}
498
499fn reply_to_data(reply: RpcReply) -> crate::Result<String> {
500 match reply.body {
501 RpcReplyBody::Data(payload) => Ok(payload.into_string()),
502 RpcReplyBody::Ok => Ok(String::new()),
503 RpcReplyBody::Error(errors) => Err(crate::Error::Rpc {
504 message_id: reply.message_id,
505 error: errors
506 .first()
507 .map(|e| e.error_message.clone())
508 .unwrap_or_default(),
509 }),
510 }
511}
512
513fn reply_to_ok(reply: RpcReply) -> crate::Result<()> {
514 match reply.body {
515 RpcReplyBody::Ok => Ok(()),
516 RpcReplyBody::Data(_) => Ok(()),
517 RpcReplyBody::Error(errors) => Err(crate::Error::Rpc {
518 message_id: reply.message_id,
519 error: errors
520 .first()
521 .map(|e| e.error_message.clone())
522 .unwrap_or_default(),
523 }),
524 }
525}