1use std::{
2 convert::Infallible,
3 fmt::{self, Debug},
4 str::FromStr,
5};
6
7use async_trait::async_trait;
8use bytes::{Bytes, BytesMut};
9use memchr::memmem::Finder;
10use russh::{
11 client::{connect, Config},
12 ChannelMsg,
13};
14use russh_keys::key::PublicKey;
15use tokio::{net::ToSocketAddrs, sync::mpsc, task::JoinHandle};
16
17use super::{RecvHandle, SendHandle, Transport};
18use crate::{message::MARKER, Error};
19
20#[derive(Debug)]
21pub struct Ssh {
22 _task: JoinHandle<Result<(), Error>>,
23 in_queue: Receiver,
24 out_queue: Sender,
25}
26
27impl Ssh {
28 #[tracing::instrument]
29 pub(crate) async fn connect<A>(
30 addr: A,
31 username: String,
32 password: Password,
33 ) -> Result<Self, Error>
34 where
35 A: ToSocketAddrs + Debug + Send,
36 {
37 tracing::info!("attempting to establish SSH session");
38 let config = Config::default().into();
39 let handler = Handler::new();
40 let session = {
41 let mut session = connect(config, addr, handler).await?;
42 tracing::info!("ssh session established");
43 if !session
44 .authenticate_password(username.clone(), password.into_inner())
45 .await?
46 {
47 return Err(Error::Authentication { username });
48 };
49 tracing::info!("ssh authentication sucessful");
50 session
51 };
52 tracing::info!("attempting to open ssh channel");
53 let mut channel = session.channel_open_session().await?;
54 tracing::info!("ssh channel opened");
55 tracing::info!("requesting netconf ssh subsystem");
56 channel.request_subsystem(true, "netconf").await?;
57 tracing::info!("netconf ssh subsystem activated");
58 let (out_queue_tx, mut out_queue_rx) = mpsc::channel::<Bytes>(32);
59 let (in_queue_tx, in_queue_rx) = mpsc::channel(32);
60 let out_queue = Sender {
61 inner: out_queue_tx,
62 };
63 let in_queue = Receiver { inner: in_queue_rx };
64 let task = tokio::spawn(async move {
65 let mut in_buf = BytesMut::new();
66 let message_break = Finder::new(MARKER);
67 loop {
68 tokio::select! {
69 to_send = out_queue_rx.recv() => {
70 tracing::debug!("attempting to send message");
71 tracing::trace!(?to_send);
72 if let Some(data) = to_send {
75 channel.data(data.as_ref()).await?;
76 } else {
77 break;
78 };
79 tracing::trace!("message sent");
80 }
81 msg = channel.wait() => {
82 if let Some(msg) = msg {
83 tracing::debug!("processing received msg");
84 tracing::trace!(?msg);
85 match msg {
86 ChannelMsg::Data{ data } => {
87 tracing::debug!("got data on channel");
88 in_buf.extend_from_slice(&data);
89 tracing::debug!("checking for message break marker");
90 if let Some(index) = message_break.find(&in_buf) {
91 let end = index + MARKER.len();
92 tracing::info!("splitting {end} message bytes from input buffer");
93 let message = in_buf.split_to(end).freeze();
94 in_queue_tx.send(message).await?;
95 tracing::debug!("message data enqueued sucessfully");
96 };
97 }
98 ChannelMsg::Eof => {
99 tracing::info!("got eof, hanging up");
100 break;
101 }
102 _ => {
103 tracing::debug!("ignoring msg {msg:?}");
104 }
105 }
106 } else {
107 }
109 }
110 }
111 }
112 Ok(())
113 });
114 Ok(Self {
115 _task: task,
116 in_queue,
117 out_queue,
118 })
119 }
120}
121
122impl Transport for Ssh {
123 type SendHandle = Sender;
124 type RecvHandle = Receiver;
125
126 #[tracing::instrument(level = "debug")]
127 fn split(self) -> (Self::SendHandle, Self::RecvHandle) {
128 (self.out_queue, self.in_queue)
129 }
130}
131
132#[derive(Debug)]
133pub struct Sender {
134 inner: mpsc::Sender<Bytes>,
135}
136
137#[async_trait]
138impl SendHandle for Sender {
139 #[tracing::instrument(level = "trace")]
140 async fn send(&mut self, data: Bytes) -> Result<(), Error> {
141 Ok(self.inner.send(data).await?)
142 }
143}
144
145#[derive(Debug)]
146pub struct Receiver {
147 inner: mpsc::Receiver<Bytes>,
148}
149
150#[async_trait]
151impl RecvHandle for Receiver {
152 #[tracing::instrument(level = "trace")]
153 async fn recv(&mut self) -> Result<Bytes, Error> {
154 self.inner.recv().await.ok_or(Error::DequeueMessage)
155 }
156}
157
/// Stateless SSH client event handler used when connecting.
#[derive(Debug)]
struct Handler {}

impl Handler {
    /// Build a handler; usable in `const` contexts.
    const fn new() -> Self {
        Handler {}
    }
}
166
167#[async_trait]
168impl russh::client::Handler for Handler {
169 type Error = Error;
170
171 #[tracing::instrument(skip_all)]
173 async fn check_server_key(self, _: &PublicKey) -> Result<(Self, bool), Self::Error> {
174 tracing::info!("NOT checking server public key");
175 Ok((self, true))
176 }
177}
178
/// A password whose `Debug` output never reveals the secret.
#[derive(Clone)]
pub struct Password(String);

impl Password {
    /// Consume the wrapper and return the plain-text password.
    #[must_use]
    pub fn into_inner(self) -> String {
        let Self(secret) = self;
        secret
    }
}

impl Debug for Password {
    /// Format as a one-field tuple struct whose field is the fixed
    /// placeholder `"****"` rather than the stored value.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut masked = f.debug_tuple("Password");
        masked.field(&"****");
        masked.finish()
    }
}

impl FromStr for Password {
    type Err = Infallible;

    /// Wrap a copy of `s`; this conversion cannot fail.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let owned = String::from(s);
        Ok(Self(owned))
    }
}