#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use log::debug;
use ssh_key::PrivateKey;
use tokio::net::TcpListener;
use crate::client::{Config, Handler, connect};
use crate::keys::PrivateKeyWithHashAlg;
use rand::rng;
use crate::server::{self, Auth, Handler as ServerHandler, Server, Session};
use crate::{ChannelId, SshId}; use crate::Error;
#[derive(Clone)]
struct TestServer {
clients: Arc<Mutex<HashMap<(usize, ChannelId), server::Handle>>>,
id: usize,
}
impl server::Server for TestServer {
type Handler = Self;
fn new_client(&mut self, _: Option<std::net::SocketAddr>) -> Self {
let s = self.clone();
self.id += 1;
s
}
}
impl ServerHandler for TestServer {
type Error = Error;
async fn channel_open_session(
&mut self,
channel: crate::channels::Channel<server::Msg>,
session: &mut Session,
) -> Result<bool, Self::Error> {
{
let mut clients = self.clients.lock().unwrap();
clients.insert((self.id, channel.id()), session.handle());
}
Ok(true)
}
async fn auth_publickey(
&mut self,
_: &str,
_: &ssh_key::PublicKey,
) -> Result<Auth, Self::Error> {
debug!("auth_publickey");
Ok(Auth::Accept)
}
async fn data(
&mut self,
channel: ChannelId,
data: &[u8],
session: &mut Session,
) -> Result<(), Self::Error> {
debug!("server received data: {:?}", std::str::from_utf8(data));
session.data(channel, data.to_vec())?;
Ok(())
}
}
struct Client {}
impl Handler for Client {
type Error = Error;
async fn check_server_key(&mut self, _: &ssh_key::PublicKey) -> Result<bool, Self::Error> {
Ok(true)
}
}
#[tokio::test]
async fn test_client_connects_to_protocol_1_99() {
let _ = env_logger::try_init();
let client_key = PrivateKey::random(&mut rng(), ssh_key::Algorithm::Ed25519).unwrap();
let mut config = server::Config::default();
config.auth_rejection_time = std::time::Duration::from_secs(1);
config.server_id = SshId::Standard("SSH-1.99-CustomServer_1.0".into());
config.inactivity_timeout = None;
config
.keys
.push(PrivateKey::random(&mut rng(), ssh_key::Algorithm::Ed25519).unwrap());
let config = Arc::new(config);
let mut server = TestServer {
clients: Arc::new(Mutex::new(HashMap::new())),
id: 0,
};
let socket = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = socket.local_addr().unwrap();
tokio::spawn(async move {
let (socket, _) = socket.accept().await.unwrap();
let server_handler = server.new_client(None);
server::run_stream(config, socket, server_handler)
.await
.unwrap();
});
println!("Server listening on {addr}");
let client_config = Arc::new(Config::default());
let mut session = connect(client_config, addr, Client {}).await.unwrap();
let auth_result = session
.authenticate_publickey(
std::env::var("USER").unwrap_or("user".to_string()),
PrivateKeyWithHashAlg::new(
Arc::new(client_key),
session.best_supported_rsa_hash().await.unwrap().flatten(),
),
)
.await
.unwrap();
assert!(auth_result.success());
let mut channel = session.channel_open_session().await.unwrap();
let test_data = b"Hello, 1.99 protocol server!";
channel.data(&test_data[..]).await.unwrap();
let msg = channel.wait().await.unwrap();
match msg {
crate::channels::ChannelMsg::Data { data: msg_data } => {
assert_eq!(test_data.as_slice(), &msg_data[..]);
}
msg => panic!("Unexpected message {msg:?}"),
}
}
}