use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use anvil_ssh::proxy::JumpHost;
use anvil_ssh::{AnvilConfig, AnvilSession, StrictHostKeyChecking};
use russh::keys::ssh_key::rand_core::OsRng;
use russh::keys::{Algorithm, HashAlg, PrivateKey};
use russh::server::{Auth, Msg, Server as _, Session};
use russh::{server, ChannelId};
use tokio::net::{TcpListener, TcpStream};
#[derive(Clone)]
struct TestServer;
impl server::Server for TestServer {
type Handler = TestSession;
fn new_client(&mut self, _: Option<SocketAddr>) -> Self::Handler {
TestSession
}
fn handle_session_error(&mut self, _error: <Self::Handler as server::Handler>::Error) {
}
}
struct TestSession;
impl server::Handler for TestSession {
type Error = russh::Error;
async fn auth_password(&mut self, _user: &str, _password: &str) -> Result<Auth, Self::Error> {
Ok(Auth::Accept)
}
async fn auth_publickey(
&mut self,
_user: &str,
_public_key: &russh::keys::ssh_key::PublicKey,
) -> Result<Auth, Self::Error> {
Ok(Auth::Accept)
}
async fn auth_publickey_offered(
&mut self,
_user: &str,
_public_key: &russh::keys::ssh_key::PublicKey,
) -> Result<Auth, Self::Error> {
Ok(Auth::Accept)
}
async fn channel_open_session(
&mut self,
_channel: russh::Channel<Msg>,
_session: &mut Session,
) -> Result<bool, Self::Error> {
Ok(true)
}
async fn channel_open_direct_tcpip(
&mut self,
channel: russh::Channel<Msg>,
host_to_connect: &str,
port_to_connect: u32,
_originator_address: &str,
_originator_port: u32,
session: &mut Session,
) -> Result<bool, Self::Error> {
let port_u16 = u16::try_from(port_to_connect).map_err(|_truncated| {
russh::Error::from(std::io::Error::other(format!(
"test fixture: direct-tcpip port {port_to_connect} out of u16 range",
)))
})?;
let upstream = TcpStream::connect((host_to_connect, port_u16))
.await
.map_err(russh::Error::from)?;
let session_handle = session.handle();
let channel_id = channel.id();
tokio::spawn(async move {
relay_channel_to_tcp(channel, upstream, session_handle, channel_id).await;
});
Ok(true)
}
}
/// Relays bytes between an SSH channel and a TCP stream until both
/// directions have finished.
///
/// * channel -> TCP: forwards every `Data` message to the socket and stops
///   on `Eof`, `Close`, end of the message stream, or a socket write
///   failure. Afterwards it asks the session handle to send EOF on
///   `channel_id`, ignoring the result.
/// * TCP -> channel: copies socket reads (32 KiB buffer) into the channel
///   writer and stops on a read error, a zero-length read, or a channel
///   write failure.
///
/// NOTE(review): neither direction explicitly shuts down the TCP write half
/// or signals EOF on the channel when the *upstream* closes — confirm
/// whether half-close propagation matters for this fixture.
async fn relay_channel_to_tcp(
    channel: russh::Channel<Msg>,
    tcp: TcpStream,
    session: server::Handle,
    channel_id: ChannelId,
) {
    use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};

    let (mut upstream_rx, mut upstream_tx) = tcp.into_split();
    let mut channel_tx = channel.make_writer();
    // The write half is kept bound (not `_`) so it lives until this
    // function returns.
    let (mut channel_rx, _channel_write_half) = channel.split();

    let channel_to_upstream = async {
        while let Some(msg) = channel_rx.wait().await {
            match msg {
                russh::ChannelMsg::Data { data } => {
                    if upstream_tx.write_all(&data).await.is_err() {
                        break;
                    }
                }
                russh::ChannelMsg::Eof | russh::ChannelMsg::Close => break,
                // Any other channel message is ignored.
                _ => {}
            }
        }
        let _ = session.eof(channel_id).await;
    };

    let upstream_to_channel = async {
        let mut chunk = vec![0_u8; 32 * 1024];
        while let Ok(len) = upstream_rx.read(&mut chunk).await {
            // A zero-length read means the peer closed; the short-circuit
            // ensures nothing is written in that case.
            if len == 0 || channel_tx.write_all(&chunk[..len]).await.is_err() {
                break;
            }
        }
    };

    tokio::join!(channel_to_upstream, upstream_to_channel);
}
/// Starts a [`TestServer`] on a loopback port chosen by the OS.
///
/// Returns the freshly generated Ed25519 host key together with the port
/// the listener is bound to. The server itself runs in a detached task and
/// its result is discarded.
///
/// # Panics
///
/// Panics if key generation, binding the listener, or reading the local
/// address fails.
async fn spawn_server() -> (PrivateKey, u16) {
    let host_key = PrivateKey::random(&mut OsRng, Algorithm::Ed25519).expect("ed25519 key");

    // Port 0 lets the OS choose a free port; read it back for the caller.
    let listener = TcpListener::bind(("127.0.0.1", 0))
        .await
        .expect("bind loopback");
    let port = listener.local_addr().expect("local_addr").port();

    let server_config = server::Config {
        inactivity_timeout: Some(Duration::from_secs(30)),
        auth_rejection_time: Duration::from_millis(50),
        auth_rejection_time_initial: Some(Duration::from_millis(50)),
        keys: vec![host_key.clone()],
        ..Default::default()
    };
    let config = Arc::new(server_config);

    tokio::spawn(async move {
        let mut server = TestServer;
        let _ = server.run_on_socket(config, &listener).await;
    });

    (host_key, port)
}
/// Renders the SHA-256 fingerprint of `key`'s public half as a string.
fn fingerprint_for(key: &PrivateKey) -> String {
    let fingerprint = key.public_key().fingerprint(HashAlg::Sha256);
    fingerprint.to_string()
}
/// Writes a two-entry known-hosts file to `path`.
///
/// Each entry is one line of the form `<host> <fingerprint>` terminated by
/// a newline: the bastion first, then the target.
///
/// # Panics
///
/// Panics if the file cannot be written.
fn write_known_hosts(
    path: &std::path::Path,
    bastion_host: &str,
    bastion_fp: &str,
    target_host: &str,
    target_fp: &str,
) {
    let entries = [(bastion_host, bastion_fp), (target_host, target_fp)];
    let contents: String = entries
        .iter()
        .map(|(host, fingerprint)| format!("{host} {fingerprint}\n"))
        .collect();
    std::fs::write(path, contents).expect("write known_hosts");
}
/// Reports whether the `GITWAY_INTEGRATION_TESTS` environment variable is
/// set to a non-empty value.
///
/// An unset variable, an empty value, or a value that cannot be read as
/// Unicode all yield `false`.
fn integration_enabled() -> bool {
    match std::env::var("GITWAY_INTEGRATION_TESTS") {
        Ok(value) => !value.is_empty(),
        Err(_) => false,
    }
}
/// Connects to a target test server through one bastion test server and
/// expects the jump-host connection to succeed.
///
/// Does nothing unless `integration_enabled()` returns true. Both servers
/// are local fixtures started via `spawn_server`; their fingerprints are
/// written to a temporary known-hosts file handed to the target config.
#[tokio::test]
#[ignore = "GITWAY_INTEGRATION_TESTS=1 + --ignored required; spins up two russh::server instances"]
async fn two_hop_chain_succeeds() {
    if !integration_enabled() {
        return;
    }

    let loopback = "127.0.0.1";
    let (bastion_key, bastion_port) = spawn_server().await;
    let (target_key, target_port) = spawn_server().await;

    // Record both hops in a throwaway known-hosts file.
    let known_hosts = tempfile::NamedTempFile::new().expect("temp file");
    let bastion_entry = format!("127.0.0.1:{bastion_port}");
    let target_entry = format!("127.0.0.1:{target_port}");
    write_known_hosts(
        known_hosts.path(),
        &bastion_entry,
        &fingerprint_for(&bastion_key),
        &target_entry,
        &fingerprint_for(&target_key),
    );

    let target_config = AnvilConfig::builder(loopback)
        .port(target_port)
        .username("user")
        .strict_host_key_checking(StrictHostKeyChecking::No)
        .custom_known_hosts(known_hosts.path().to_path_buf())
        .build();

    let bastion = JumpHost {
        host: loopback.to_owned(),
        port: bastion_port,
        user: Some("user".to_owned()),
        identity_files: Vec::new(),
    };
    let jumps = vec![bastion];

    let outcome = AnvilSession::connect_via_jump_hosts(&target_config, &jumps).await;
    assert!(
        outcome.is_ok(),
        "2-hop chain should succeed; err = {:?}",
        outcome.err(),
    );

    // Best-effort teardown; a close failure does not fail the test.
    let _ = outcome.expect("session").close().await;
}
/// Verifies that connecting with an empty jump-host slice fails and that
/// the error text mentions `empty jump-host list`.
///
/// Does nothing unless `integration_enabled()` returns true.
#[tokio::test]
#[ignore = "GITWAY_INTEGRATION_TESTS=1 + --ignored required"]
async fn empty_jump_chain_is_rejected() {
    if !integration_enabled() {
        return;
    }

    let config = AnvilConfig::builder("127.0.0.1")
        .port(22)
        .strict_host_key_checking(StrictHostKeyChecking::No)
        .build();

    let outcome = AnvilSession::connect_via_jump_hosts(&config, &[]).await;
    let rendered = outcome
        .expect_err("empty jumps should error")
        .to_string();

    assert!(
        rendered.contains("empty jump-host list"),
        "expected empty-list message, got: {rendered}",
    );
}
/// No-op that mentions `PathBuf` in a signature; it accepts a path and
/// discards it without doing anything.
#[allow(dead_code, reason = "helper used only when integration_enabled()")]
fn _silence_unused_paths(_unused: PathBuf) {}