use crate::config::{ConnectionMode, ResolvedServer};
use anyhow::Result;
use std::sync::mpsc;
/// Direction of a file transfer, seen from the local machine.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScpDirection {
    /// Local file sent to the remote host.
    Upload,
    /// Remote file fetched to the local machine.
    Download,
}

impl ScpDirection {
    /// Display name of the direction.
    pub fn label(&self) -> &'static str {
        match *self {
            ScpDirection::Download => "Download",
            ScpDirection::Upload => "Upload",
        }
    }
}
/// Notification sent by a background transfer to its caller.
#[derive(Debug)]
pub enum ScpEvent {
    /// Number of bytes the transfer is expected to move (sent before the data).
    FileSize(u64),
    /// Completion percentage, capped at 100.
    Progress(u8),
    /// The transfer finished; this module sends `Done(true)` on success.
    Done(bool),
    /// The transfer failed; carries the error message.
    Error(String),
}
#[cfg(unix)]
pub fn spawn_sftp(
server: &ResolvedServer,
mode: ConnectionMode,
direction: ScpDirection,
local: &str,
remote: &str,
) -> Result<mpsc::Receiver<ScpEvent>> {
if mode == ConnectionMode::Wallix {
anyhow::bail!("SFTP non disponible en mode Wallix");
}
let server = server.clone();
let local = shellexpand::tilde(local).into_owned();
let remote = remote.to_string();
let (tx, rx) = mpsc::channel::<ScpEvent>();
std::thread::spawn(move || {
let result = transfer_inner(&server, mode, &direction, &local, &remote, &tx);
match result {
Ok(()) => {
let _ = tx.send(ScpEvent::Done(true));
}
Err(e) => {
let _ = tx.send(ScpEvent::Error(e.to_string()));
}
}
});
Ok(rx)
}
/// Non-Unix stub: file transfers are not supported on this platform.
///
/// # Errors
/// Always returns an error.
#[cfg(not(unix))]
pub fn spawn_sftp(
    _server: &ResolvedServer,
    _mode: ConnectionMode,
    _direction: ScpDirection,
    _local: &str,
    _remote: &str,
) -> Result<mpsc::Receiver<ScpEvent>> {
    Err(anyhow::anyhow!("SFTP non disponible sur cette plateforme"))
}
#[cfg(unix)]
fn transfer_inner(
server: &ResolvedServer,
mode: ConnectionMode,
direction: &ScpDirection,
local: &str,
remote: &str,
tx: &mpsc::Sender<ScpEvent>,
) -> Result<()> {
use std::fs::File;
use std::path::{Path, PathBuf};
let sess = open_session(server, mode)?;
let remote_path: PathBuf = {
let sftp = sess.sftp()?;
let raw = resolve_remote_path(&sftp, remote);
if remote.ends_with('/') || remote.ends_with('\\') || raw.as_os_str().is_empty() {
let filename = Path::new(local)
.file_name()
.unwrap_or_else(|| std::ffi::OsStr::new("transfer"));
raw.join(filename)
} else {
match sftp.stat(&raw) {
Ok(stat) if stat.is_dir() => {
let filename = Path::new(local)
.file_name()
.unwrap_or_else(|| std::ffi::OsStr::new("transfer"));
raw.join(filename)
}
_ => raw,
}
}
};
match direction {
ScpDirection::Upload => {
let local_path = Path::new(local);
let file_size = std::fs::metadata(local_path).map(|m| m.len()).unwrap_or(0);
let _ = tx.send(ScpEvent::FileSize(file_size));
let mut src = File::open(local_path)?;
let mut channel = sess.scp_send(&remote_path, 0o644, file_size, None)?;
copy_with_progress(&mut src, &mut channel, file_size, tx)?;
channel.send_eof()?;
channel.wait_eof()?;
channel.close()?;
channel.wait_close()?;
}
ScpDirection::Download => {
let local_path = {
let p = Path::new(local);
if local.ends_with('/') || local.ends_with('\\') || p.is_dir() {
let filename = remote_path
.file_name()
.unwrap_or_else(|| std::ffi::OsStr::new("download"));
p.join(filename)
} else {
p.to_path_buf()
}
};
let (mut channel, scp_stat) = sess.scp_recv(&remote_path)?;
let file_size = scp_stat.size();
let _ = tx.send(ScpEvent::FileSize(file_size));
let mut dst = File::create(&local_path)?;
copy_with_progress(&mut channel, &mut dst, file_size, tx)?;
}
}
Ok(())
}
#[cfg(unix)]
fn open_session(server: &ResolvedServer, mode: ConnectionMode) -> Result<ssh2::Session> {
use std::net::TcpStream;
match mode {
ConnectionMode::Direct => {
let (host, port) = resolve_host_port(server);
let tcp = TcpStream::connect(format!("{}:{}", host, port))?;
let mut sess = ssh2::Session::new()?;
sess.set_tcp_stream(tcp);
sess.handshake()?;
auth_session(&mut sess, &server.user, &server.ssh_key)?;
Ok(sess)
}
ConnectionMode::Jump => {
let jump_str = server
.jump_host
.as_deref()
.filter(|s| !s.is_empty())
.ok_or_else(|| anyhow::anyhow!("Jump host non configuré pour ce serveur"))?;
open_session_via_jump(jump_str, server)
}
ConnectionMode::Wallix => anyhow::bail!("SFTP non disponible en mode Wallix"),
}
}
/// Splits `server.host` into a `(host, port)` pair.
///
/// Accepted forms:
/// * `name` / `1.2.3.4`: uses `server.port`;
/// * `name:2222`: explicit port override;
/// * `[2001:db8::1]:2222` / `[2001:db8::1]`: bracketed IPv6 literal;
/// * `2001:db8::1`: bare IPv6 literal (several colons, never split).
///
/// A port suffix that does not parse as `u16` falls back to `server.port`.
#[cfg(unix)]
fn resolve_host_port(server: &ResolvedServer) -> (String, u16) {
    let host = server.host.as_str();

    // Bracketed IPv6 literal, optionally followed by ":port".
    let bracketed = host
        .strip_prefix('[')
        .and_then(|rest| rest.split_once(']'));
    if let Some((addr, tail)) = bracketed {
        let port = tail
            .strip_prefix(':')
            .and_then(|p| p.parse::<u16>().ok())
            .unwrap_or(server.port);
        return (addr.to_string(), port);
    }

    // Exactly one colon means "host:port"; more than one is a bare IPv6
    // address, which must not be cut at its first colon.
    match host.split_once(':') {
        Some((h, p)) if !p.contains(':') => {
            (h.to_string(), p.parse::<u16>().unwrap_or(server.port))
        }
        _ => (host.to_string(), server.port),
    }
}
/// Authenticates `sess` as `username` using public-key methods only.
///
/// Strategies are tried in this order, stopping as soon as the session reports
/// itself authenticated:
/// 1. every identity offered by a reachable ssh-agent;
/// 2. the configured key file `ssh_key` (tilde-expanded), when non-empty;
/// 3. `~/.ssh/id_ed25519`, `~/.ssh/id_rsa`, `~/.ssh/id_ecdsa`, skipping files
///    that do not exist.
///
/// Individual failures are ignored; only exhausting every strategy is an error.
///
/// # Errors
/// Returns an error when no strategy authenticates the session.
#[cfg(unix)]
fn auth_session(sess: &mut ssh2::Session, username: &str, ssh_key: &str) -> Result<()> {
    use std::path::{Path, PathBuf};

    if let Ok(mut agent) = sess.agent()
        && agent.connect().is_ok()
        && agent.list_identities().is_ok()
    {
        let agent_ok = agent
            .identities()
            .unwrap_or_default()
            .iter()
            .any(|identity| agent.userauth(username, identity).is_ok() && sess.authenticated());
        if agent_ok {
            return Ok(());
        }
    }

    // Attempt a single private-key file; true once the session is authenticated.
    let try_key_file = |key_path: &Path| {
        sess.userauth_pubkey_file(username, None, key_path, None).is_ok() && sess.authenticated()
    };

    if !ssh_key.is_empty() {
        let configured = PathBuf::from(shellexpand::tilde(ssh_key).to_string());
        if try_key_file(configured.as_path()) {
            return Ok(());
        }
    }

    let default_key_ok = ["~/.ssh/id_ed25519", "~/.ssh/id_rsa", "~/.ssh/id_ecdsa"]
        .iter()
        .map(|key| PathBuf::from(shellexpand::tilde(key).to_string()))
        .any(|key_path| key_path.exists() && try_key_file(key_path.as_path()));
    if default_key_ok {
        return Ok(());
    }

    anyhow::bail!(
        "Authentification SSH échouée pour {} (agent SSH + clés par défaut épuisés)",
        username
    )
}
/// Opens an authenticated SSH session to `server`, tunnelled through a jump host.
///
/// Only the first entry of the comma-separated `jump_str` is used. It has the
/// form `[user@]host[:port]`; the jump user defaults to `server.user` and the
/// port to 22 (also when the port does not parse). The same `server.ssh_key`
/// is used to authenticate on both the jump host and the target.
///
/// A `direct-tcpip` channel to the target is opened on the jump session and
/// bridged to one end of a Unix socketpair by a detached thread running
/// `bridge_bidirectional`. The target session runs over the other end. This is
/// needed because an `ssh2::Channel` has no file descriptor to hand to
/// `set_tcp_stream`.
///
/// # Errors
/// Connection, handshake, authentication, channel or socketpair failures.
/// If the target handshake/auth fails, dropping `target_sess` closes
/// `local_stream`, and the bridge thread then sees EOF and stops.
#[cfg(unix)]
fn open_session_via_jump(jump_str: &str, server: &ResolvedServer) -> Result<ssh2::Session> {
use std::net::TcpStream;
use std::os::unix::io::FromRawFd;
use std::os::unix::net::UnixStream;
// NOTE(review): any further hops of a ProxyJump-style list are silently ignored.
let first = jump_str.split(',').next().unwrap_or(jump_str);
let (jump_user, jump_host_port) = match first.split_once('@') {
Some((u, hp)) => (u, hp),
None => (server.user.as_str(), first),
};
// NOTE(review): splitting at the first ':' mis-parses IPv6 jump hosts.
let (jump_host, jump_port) = match jump_host_port.split_once(':') {
Some((h, p)) => (h, p.parse::<u16>().unwrap_or(22)),
None => (jump_host_port, 22u16),
};
let jump_tcp = TcpStream::connect(format!("{}:{}", jump_host, jump_port))?;
// Keep the raw fd so the bridge thread can poll() the jump socket after the
// TcpStream has been moved into `jump_sess`. It stays valid as long as
// `jump_sess` (later owned by the bridge thread) is alive.
let jump_raw_fd = {
use std::os::unix::io::AsRawFd;
jump_tcp.as_raw_fd()
};
let mut jump_sess = ssh2::Session::new()?;
jump_sess.set_tcp_stream(jump_tcp);
// Compression is requested on the jump leg only, before its handshake.
jump_sess.set_compress(true);
jump_sess.handshake()?;
auth_session(&mut jump_sess, jump_user, &server.ssh_key)?;
let (target_host, target_port) = resolve_host_port(server);
let channel = jump_sess.channel_direct_tcpip(&target_host, target_port, None)?;
let mut pair_fds: [libc::c_int; 2] = [-1; 2];
// SAFETY: `pair_fds` is a writable array of exactly two c_ints, as socketpair(2) requires.
if unsafe { libc::socketpair(libc::AF_UNIX, libc::SOCK_STREAM, 0, pair_fds.as_mut_ptr()) } != 0
{
anyhow::bail!("socketpair: {}", std::io::Error::last_os_error());
}
// Mark both ends close-on-exec so spawned child processes do not inherit them.
// NOTE(review): the fcntl result is ignored, and passing SOCK_CLOEXEC to
// socketpair would close the window between creation and this call.
for &fd in &pair_fds {
unsafe { libc::fcntl(fd, libc::F_SETFD, libc::FD_CLOEXEC) };
}
// SAFETY: socketpair succeeded, so both fds are open and owned by nothing
// else. Each is wrapped exactly once and is closed by its UnixStream on drop.
let local_stream = unsafe { UnixStream::from_raw_fd(pair_fds[0]) };
let bridge_stream = unsafe { UnixStream::from_raw_fd(pair_fds[1]) };
// Detached pump thread. It owns the jump session, the channel and one socket
// end, and exits when either side reaches EOF or errors.
std::thread::spawn(move || {
bridge_bidirectional(jump_sess, channel, bridge_stream, jump_raw_fd);
});
let mut target_sess = ssh2::Session::new()?;
target_sess.set_tcp_stream(local_stream);
target_sess.handshake()?;
auth_session(&mut target_sess, &server.user, &server.ssh_key)?;
Ok(target_sess)
}
/// Pumps bytes between an SSH `direct-tcpip` channel and a local Unix socket
/// until either side closes or an error occurs, then closes the channel.
///
/// `sess` is the session that owns `channel`. It is kept alive for the whole
/// loop and switched to non-blocking mode, so channel reads and writes report
/// `WouldBlock` instead of stalling. `jump_raw_fd` is the fd of the TCP socket
/// owned by `sess`; it is polled together with `stream` so the loop wakes on
/// incoming traffic. It is called from the detached thread spawned by
/// `open_session_via_jump`.
#[cfg(unix)]
fn bridge_bidirectional(
sess: ssh2::Session,
mut channel: ssh2::Channel,
mut stream: std::os::unix::net::UnixStream,
jump_raw_fd: libc::c_int,
) {
use std::io::{ErrorKind, Read, Write};
use std::os::unix::io::AsRawFd;
let stream_fd = stream.as_raw_fd();
sess.set_blocking(false);
let mut buf = vec![0u8; 32 * 1024];
'outer: loop {
let mut fds = [
libc::pollfd {
fd: jump_raw_fd,
events: libc::POLLIN,
revents: 0,
},
libc::pollfd {
fd: stream_fd,
events: libc::POLLIN,
revents: 0,
},
];
// SAFETY: `fds` is a valid, mutable array of exactly 2 pollfd structs for the call.
// The 50 ms timeout makes the loop wake periodically even without socket activity.
let ret = unsafe { libc::poll(fds.as_mut_ptr(), 2, 50) };
// poll() failure ends the bridge. NOTE(review): this includes EINTR.
if ret < 0 {
break; }
// Drain the channel into the local socket on every iteration. fds[0].revents
// is never inspected, presumably because libssh2 may already hold buffered
// data that poll() on the raw socket would not report.
loop {
match channel.read(&mut buf) {
// Ok(0): channel EOF ends the bridge; a failed local write also ends it.
Ok(0) => break 'outer, Ok(n) => {
if stream.write_all(&buf[..n]).is_err() {
break 'outer;
}
}
Err(ref e) if e.kind() == ErrorKind::WouldBlock => break,
Err(_) => break 'outer,
}
}
// Forward one chunk from the local socket to the channel, retrying partial writes.
if fds[1].revents & libc::POLLIN != 0 {
match stream.read(&mut buf) {
// Ok(0): the local end (the target session's stream) was closed.
Ok(0) => break, Ok(n) => {
let mut pos = 0;
// NOTE(review): spins with yield_now while the channel would block (busy-wait).
while pos < n {
match channel.write(&buf[pos..n]) {
Ok(k) if k > 0 => pos += k,
Ok(_) => std::thread::yield_now(),
Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
std::thread::yield_now();
}
Err(_) => break 'outer,
}
}
}
// WouldBlock on the local read is ignored; any other error ends the bridge.
Err(ref e) if e.kind() == ErrorKind::WouldBlock => {} Err(_) => break,
}
}
// Remote side signalled EOF on the channel.
if channel.eof() {
break;
}
}
// Best-effort close; the error is ignored since the bridge is shutting down anyway.
let _ = channel.close();
}
/// Expands a leading `~` in `remote` against the SFTP session's home.
///
/// The home is taken as `realpath(".")` on the server. This assumes the SFTP
/// working directory is the login home; confirm for unusual server setups. An
/// empty path is used when that call fails. `"~"` maps to the home itself.
/// Other `~`-prefixed paths have leading `~/`, `~` and `/` stripped and are
/// joined onto the home. Paths without a leading `~` are returned unchanged.
#[cfg(unix)]
fn resolve_remote_path(sftp: &ssh2::Sftp, remote: &str) -> std::path::PathBuf {
    use std::path::{Path, PathBuf};

    if !remote.starts_with('~') {
        return PathBuf::from(remote);
    }

    let home = match sftp.realpath(Path::new(".")) {
        Ok(resolved) => resolved,
        Err(_) => PathBuf::from(""),
    };
    if remote == "~" {
        return home;
    }

    let relative = remote
        .trim_start_matches("~/")
        .trim_start_matches('~')
        .trim_start_matches('/');
    home.join(relative)
}
/// Copies `src` into `dst`, reporting percentage progress on `tx`.
///
/// `total` is the number of bytes expected. [`ScpEvent::Progress`] is sent
/// whenever the integer percentage changes (never when `total` is 0). A final
/// `Progress(100)` follows a complete copy, unless 100 was already reported.
///
/// # Errors
/// * any read or write error (reads interrupted by a signal are retried);
/// * `"transfert annulé"` when the receiving end of `tx` has been dropped;
/// * `"transfert incomplet"` when `src` hits EOF before `total` bytes, so a
///   truncated transfer is not reported as a success.
#[cfg(unix)]
fn copy_with_progress(
    src: &mut dyn std::io::Read,
    dst: &mut dyn std::io::Write,
    total: u64,
    tx: &mpsc::Sender<ScpEvent>,
) -> Result<()> {
    use std::io::ErrorKind;

    let mut buf = vec![0u8; 256 * 1024];
    let mut transferred: u64 = 0;
    let mut last_pct: u8 = 0;
    loop {
        let n = match src.read(&mut buf) {
            Ok(0) => break,
            Ok(n) => n,
            // EINTR is not a failure: retry, as `std::io::copy` does.
            Err(e) if e.kind() == ErrorKind::Interrupted => continue,
            Err(e) => return Err(e.into()),
        };
        dst.write_all(&buf[..n])?;
        transferred += n as u64;
        if total > 0 {
            // u128 keeps `transferred * 100` exact for any u64 size; the value
            // is clamped to 100 before narrowing, so the cast cannot truncate.
            let pct = (u128::from(transferred) * 100 / u128::from(total)).min(100) as u8;
            if pct != last_pct {
                if tx.send(ScpEvent::Progress(pct)).is_err() {
                    anyhow::bail!("transfert annulé");
                }
                last_pct = pct;
            }
        }
    }
    // Early EOF: the peer closed prematurely or the local file shrank.
    if transferred < total {
        anyhow::bail!("transfert incomplet : {} / {} octets", transferred, total);
    }
    if last_pct < 100 {
        let _ = tx.send(ScpEvent::Progress(100));
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ConnectionMode;

    /// Minimal server description: `test@127.0.0.1:22`, direct mode, no extras.
    fn base_server() -> ResolvedServer {
        ResolvedServer {
            name: "test".into(),
            host: "127.0.0.1".into(),
            user: "test".into(),
            port: 22,
            namespace: String::new(),
            group_name: String::new(),
            env_name: String::new(),
            ssh_key: String::new(),
            ssh_options: vec![],
            default_mode: ConnectionMode::Direct,
            jump_host: None,
            bastion_host: None,
            bastion_user: None,
            bastion_template: String::new(),
            use_system_ssh_config: false,
            probe_filesystems: vec![],
            tunnels: vec![],
            tags: vec![],
            control_master: false,
            control_path: String::new(),
            control_persist: "10m".to_string(),
            pre_connect_hook: None,
            post_disconnect_hook: None,
            hook_timeout_secs: 5,
            wallix_group: None,
            wallix_account: "default".to_string(),
            wallix_protocol: "SSH".to_string(),
            wallix_auto_select: true,
            wallix_fail_if_menu_match_error: true,
            wallix_selection_timeout_secs: 8,
        }
    }

    #[test]
    #[cfg(unix)]
    fn wallix_returns_error_immediately() {
        let outcome = spawn_sftp(
            &base_server(),
            ConnectionMode::Wallix,
            ScpDirection::Upload,
            "/tmp/file",
            "/remote/file",
        );
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Wallix"));
    }

    #[test]
    #[cfg(unix)]
    fn unreachable_server_emits_error_event() {
        // Nothing is expected to listen on loopback port 1, so connecting should fail fast.
        let server = ResolvedServer {
            host: "127.0.0.1".into(),
            port: 1,
            ..base_server()
        };
        let rx = spawn_sftp(
            &server,
            ConnectionMode::Direct,
            ScpDirection::Upload,
            "/tmp/file",
            "/remote/file",
        )
        .expect("spawn_sftp ne doit pas échouer immédiatement");
        let first_event = rx.recv_timeout(std::time::Duration::from_secs(5));
        assert!(
            matches!(first_event, Ok(ScpEvent::Error(_))),
            "attendu ScpEvent::Error, obtenu {:?}",
            first_event
        );
    }
}