netconf/transport/
ssh.rs

1use std::{
2    convert::Infallible,
3    fmt::{self, Debug},
4    str::FromStr,
5};
6
7use async_trait::async_trait;
8use bytes::{Bytes, BytesMut};
9use memchr::memmem::Finder;
10use russh::{
11    client::{connect, Config},
12    ChannelMsg,
13};
14use russh_keys::key::PublicKey;
15use tokio::{net::ToSocketAddrs, sync::mpsc, task::JoinHandle};
16
17use super::{RecvHandle, SendHandle, Transport};
18use crate::{message::MARKER, Error};
19
20#[derive(Debug)]
21pub struct Ssh {
22    _task: JoinHandle<Result<(), Error>>,
23    in_queue: Receiver,
24    out_queue: Sender,
25}
26
27impl Ssh {
28    #[tracing::instrument]
29    pub(crate) async fn connect<A>(
30        addr: A,
31        username: String,
32        password: Password,
33    ) -> Result<Self, Error>
34    where
35        A: ToSocketAddrs + Debug + Send,
36    {
37        tracing::info!("attempting to establish SSH session");
38        let config = Config::default().into();
39        let handler = Handler::new();
40        let session = {
41            let mut session = connect(config, addr, handler).await?;
42            tracing::info!("ssh session established");
43            if !session
44                .authenticate_password(username.clone(), password.into_inner())
45                .await?
46            {
47                return Err(Error::Authentication { username });
48            };
49            tracing::info!("ssh authentication sucessful");
50            session
51        };
52        tracing::info!("attempting to open ssh channel");
53        let mut channel = session.channel_open_session().await?;
54        tracing::info!("ssh channel opened");
55        tracing::info!("requesting netconf ssh subsystem");
56        channel.request_subsystem(true, "netconf").await?;
57        tracing::info!("netconf ssh subsystem activated");
58        let (out_queue_tx, mut out_queue_rx) = mpsc::channel::<Bytes>(32);
59        let (in_queue_tx, in_queue_rx) = mpsc::channel(32);
60        let out_queue = Sender {
61            inner: out_queue_tx,
62        };
63        let in_queue = Receiver { inner: in_queue_rx };
64        let task = tokio::spawn(async move {
65            let mut in_buf = BytesMut::new();
66            let message_break = Finder::new(MARKER);
67            loop {
68                tokio::select! {
69                    to_send = out_queue_rx.recv() => {
70                        tracing::debug!("attempting to send message");
71                        tracing::trace!(?to_send);
72                        // TODO:
73                        // we should probably handle this error?
74                        if let Some(data) = to_send {
75                            channel.data(data.as_ref()).await?;
76                        } else {
77                            break;
78                        };
79                        tracing::trace!("message sent");
80                    }
81                    msg = channel.wait() => {
82                        if let Some(msg) = msg {
83                            tracing::debug!("processing received msg");
84                            tracing::trace!(?msg);
85                            match msg {
86                                ChannelMsg::Data{ data } => {
87                                    tracing::debug!("got data on channel");
88                                    in_buf.extend_from_slice(&data);
89                                    tracing::debug!("checking for message break marker");
90                                    if let Some(index) = message_break.find(&in_buf) {
91                                        let end  = index + MARKER.len();
92                                        tracing::info!("splitting {end} message bytes from input buffer");
93                                        let message = in_buf.split_to(end).freeze();
94                                        in_queue_tx.send(message).await?;
95                                        tracing::debug!("message data enqueued sucessfully");
96                                    };
97                                }
98                                ChannelMsg::Eof => {
99                                    tracing::info!("got eof, hanging up");
100                                    break;
101                                }
102                                _ => {
103                                    tracing::debug!("ignoring msg {msg:?}");
104                                }
105                            }
106                        } else {
107                            // TODO: what should we do if it's None?
108                        }
109                    }
110                }
111            }
112            Ok(())
113        });
114        Ok(Self {
115            _task: task,
116            in_queue,
117            out_queue,
118        })
119    }
120}
121
122impl Transport for Ssh {
123    type SendHandle = Sender;
124    type RecvHandle = Receiver;
125
126    #[tracing::instrument(level = "debug")]
127    fn split(self) -> (Self::SendHandle, Self::RecvHandle) {
128        (self.out_queue, self.in_queue)
129    }
130}
131
132#[derive(Debug)]
133pub struct Sender {
134    inner: mpsc::Sender<Bytes>,
135}
136
137#[async_trait]
138impl SendHandle for Sender {
139    #[tracing::instrument(level = "trace")]
140    async fn send(&mut self, data: Bytes) -> Result<(), Error> {
141        Ok(self.inner.send(data).await?)
142    }
143}
144
145#[derive(Debug)]
146pub struct Receiver {
147    inner: mpsc::Receiver<Bytes>,
148}
149
150#[async_trait]
151impl RecvHandle for Receiver {
152    #[tracing::instrument(level = "trace")]
153    async fn recv(&mut self) -> Result<Bytes, Error> {
154        self.inner.recv().await.ok_or(Error::DequeueMessage)
155    }
156}
157
/// Stateless SSH client event handler.
#[derive(Debug)]
struct Handler {}

impl Handler {
    /// Builds a handler; usable in `const` contexts.
    const fn new() -> Self {
        Handler {}
    }
}
166
167#[async_trait]
168impl russh::client::Handler for Handler {
169    type Error = Error;
170
171    // TODO
172    #[tracing::instrument(skip_all)]
173    async fn check_server_key(self, _: &PublicKey) -> Result<(Self, bool), Self::Error> {
174        tracing::info!("NOT checking server public key");
175        Ok((self, true))
176    }
177}
178
/// A password string whose `Debug` representation is masked so that the
/// secret does not show up in debug-formatted output.
#[derive(Clone)]
pub struct Password(String);

impl Password {
    /// Consumes the wrapper and returns the plain-text password.
    #[must_use]
    pub fn into_inner(self) -> String {
        let Self(secret) = self;
        secret
    }
}

impl Debug for Password {
    /// Formats as a tuple struct whose single field is the fixed
    /// placeholder `"****"`; the real value is never written.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut masked = f.debug_tuple("Password");
        masked.field(&"****");
        masked.finish()
    }
}

impl FromStr for Password {
    type Err = Infallible;

    /// Wraps a copy of `s`; this conversion cannot fail.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(Self(String::from(s)))
    }
}