use std::{
fs::File,
io,
net::TcpStream,
path::{Path, PathBuf},
};
use chrono::Utc;
use ssh2::{RenameFlags, Session, Sftp};
use tracing::{debug, info, instrument};
use crate::{
progress::{ProgressDisplay, ProgressReader},
secret::Secret,
task::{Mode, Status},
};
/// Builder collecting the settings used to open an SFTP connection.
///
/// Start from [`ClientBuilder::new`] (or [`Client::builder`]), adjust the
/// settings with the chained setters and finish with [`ClientBuilder::build`].
pub struct ClientBuilder {
/// Host name or address of the SFTP server.
host: String,
/// TCP port of the SFTP server.
port: u16,
/// User name to authenticate as.
username: String,
/// Private key file used for public-key authentication. When `None`, the
/// connection code falls back to the SSH agent.
key_path: Option<PathBuf>,
/// Passphrase protecting the private key, if any.
key_password: Option<Secret>,
/// Callback producing the answer to a keyboard-interactive (second factor)
/// prompt.
two_factor_callback: Option<Box<dyn Fn() -> String + Send + Sync>>,
}
impl std::fmt::Debug for ClientBuilder {
    /// Renders the builder for diagnostics.
    ///
    /// The two-factor callback is an opaque closure and is therefore left out
    /// of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure so that a newly added field forces a decision here.
        let Self {
            host,
            port,
            username,
            key_path,
            key_password,
            two_factor_callback: _,
        } = self;
        let mut out = f.debug_struct("ClientBuilder");
        out.field("host", host);
        out.field("port", port);
        out.field("username", username);
        out.field("key_path", key_path);
        out.field("key_password", key_password);
        out.finish()
    }
}
impl Default for ClientBuilder {
    /// Returns a builder targeting `localhost:22` as user `user`, without a
    /// key file, key passphrase or two-factor callback.
    fn default() -> Self {
        let host = String::from("localhost");
        let username = String::from("user");
        Self {
            host,
            port: 22,
            username,
            key_path: None,
            key_password: None,
            two_factor_callback: None,
        }
    }
}
impl ClientBuilder {
    /// Creates a builder holding the [`Default`] settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Consumes the builder and turns it into a [`Client`].
    pub fn build(self) -> Client {
        Client::new(self)
    }

    /// Sets the host name or address of the server.
    pub fn host(self, host: impl Into<String>) -> Self {
        Self {
            host: host.into(),
            ..self
        }
    }

    /// Sets the TCP port of the server.
    pub fn port(self, port: u16) -> Self {
        Self { port, ..self }
    }

    /// Sets the user name to authenticate as.
    pub fn username(self, username: impl Into<String>) -> Self {
        Self {
            username: username.into(),
            ..self
        }
    }

    /// Sets (or clears, with `None`) the private key file.
    pub fn key_path(self, key_path: Option<impl Into<PathBuf>>) -> Self {
        Self {
            key_path: key_path.map(|path| path.into()),
            ..self
        }
    }

    /// Sets (or clears, with `None`) the passphrase of the private key.
    pub fn key_password(self, key_password: Option<impl Into<Secret>>) -> Self {
        Self {
            key_password: key_password.map(|password| password.into()),
            ..self
        }
    }

    /// Sets (or clears, with `None`) the callback that answers
    /// keyboard-interactive prompts.
    pub fn two_factor_callback<F: Fn() -> String + Send + Sync + 'static>(
        self,
        two_factor_callback: Option<F>,
    ) -> Self {
        // Erase the concrete closure type so the builder itself stays non-generic.
        let boxed =
            two_factor_callback.map(|f| Box::new(f) as Box<dyn Fn() -> String + Send + Sync>);
        Self {
            two_factor_callback: boxed,
            ..self
        }
    }
}
/// SFTP client configuration; no network connection is held by this type.
pub struct Client {
/// Connection settings captured when the client was built.
builder: ClientBuilder,
}
impl std::fmt::Debug for Client {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Client")
.field(
"host_url",
&format!(
"sftp://{}@{}:{}",
self.builder.username, self.builder.host, self.builder.port
),
)
.finish()
}
}
impl Client {
pub fn builder() -> ClientBuilder {
ClientBuilder::new()
}
fn new(builder: ClientBuilder) -> Self {
Self { builder }
}
pub(crate) fn connect(&self) -> Result<ClientConnected, error::ConnectionError> {
Ok(ClientConnected {
inner: connect(&self.builder)?,
host_url: format!("sftp://{}:{}", self.builder.host, self.builder.port),
})
}
}
/// An established SFTP session together with the URL of the server it targets.
pub(crate) struct ClientConnected {
/// Handle to the open SFTP channel.
pub(crate) inner: Sftp,
/// Base URL of the form `sftp://<host>:<port>`, used to build destination URLs.
host_url: String,
}
impl ClientConnected {
    /// Builds the URL of `path` on the connected server.
    ///
    /// The path is appended to the host URL after a `/`; bytes that are not
    /// valid UTF-8 are replaced lossily.
    pub(crate) fn get_url(&self, path: &Path) -> String {
        let remote_path = path.to_string_lossy();
        [self.host_url.as_str(), &*remote_path].join("/")
    }
}
/// Opens a TCP connection to `host:port` and performs the SSH handshake on it.
///
/// The returned session is not yet authenticated.
///
/// # Errors
///
/// Fails with an I/O error when the TCP connection cannot be opened, or with
/// a session error when the session cannot be created or the handshake fails.
fn make_session(host: &str, port: u16) -> Result<Session, error::ConnectionError> {
    let address = format!("{host}:{port}");
    let stream = TcpStream::connect(address)?;
    let mut ssh = Session::new()?;
    ssh.set_tcp_stream(stream);
    ssh.handshake()?;
    Ok(ssh)
}
/// Error types reported by the SFTP client.
pub mod error {
    use std::{error::Error as StdError, fmt};

    /// The key passphrase could not be interpreted as UTF-8 text.
    #[derive(Debug)]
    pub struct CredentialsError(pub std::str::Utf8Error);

    impl fmt::Display for CredentialsError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("password conversion to utf-8 failed")
        }
    }

    impl StdError for CredentialsError {
        fn source(&self) -> Option<&(dyn StdError + 'static)> {
            let Self(cause) = self;
            Some(cause)
        }
    }

    /// Failure while establishing or authenticating an SFTP connection.
    #[derive(Debug)]
    pub enum ConnectionError {
        /// The key passphrase was unusable.
        Credentials(CredentialsError),
        /// The TCP connection failed.
        Io(std::io::Error),
        /// The SSH library reported an error.
        Session(ssh2::Error),
        /// None of the identities offered by the SSH agent was accepted.
        AgentAuth,
        /// The server asked for authentication methods the client cannot handle.
        UnsupportedMethods(String),
        /// A second factor is required but no callback was configured.
        Missing2ndFactor,
    }

    impl fmt::Display for ConnectionError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("sftp connection error")?;
            match self {
                // Details of wrapped errors are exposed through `source()`.
                Self::Credentials(_) | Self::Io(_) | Self::Session(_) => Ok(()),
                Self::AgentAuth => f.write_str(": Agent authentication failed"),
                Self::UnsupportedMethods(methods) => write!(
                    f,
                    ": The following method(s) are not supported (client side) during multi factor authentication: {methods}"
                ),
                Self::Missing2ndFactor => f.write_str(
                    ": A second factor was requested but no two_factor_callback available",
                ),
            }
        }
    }

    impl StdError for ConnectionError {
        fn source(&self) -> Option<&(dyn StdError + 'static)> {
            match self {
                Self::Credentials(cause) => Some(cause),
                Self::Io(cause) => Some(cause),
                Self::Session(cause) => Some(cause),
                Self::AgentAuth | Self::UnsupportedMethods(_) | Self::Missing2ndFactor => None,
            }
        }
    }

    impl From<CredentialsError> for ConnectionError {
        fn from(value: CredentialsError) -> Self {
            Self::Credentials(value)
        }
    }

    impl From<std::io::Error> for ConnectionError {
        fn from(value: std::io::Error) -> Self {
            Self::Io(value)
        }
    }

    impl From<ssh2::Error> for ConnectionError {
        fn from(value: ssh2::Error) -> Self {
            Self::Session(value)
        }
    }

    /// Failure while removing an upload directory.
    #[derive(Debug)]
    pub enum DeleteError {
        /// The SSH library reported an error.
        Session(ssh2::Error),
        /// The named file is not eligible for deletion.
        Denied(String),
    }

    impl fmt::Display for DeleteError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("sftp delete error")?;
            match self {
                Self::Session(_) => Ok(()),
                Self::Denied(file) => write!(
                    f,
                    ": Cannot delete '{file}'. Only '*.part' files can be deleted"
                ),
            }
        }
    }

    impl StdError for DeleteError {
        fn source(&self) -> Option<&(dyn StdError + 'static)> {
            match self {
                Self::Session(cause) => Some(cause),
                Self::Denied(_) => None,
            }
        }
    }

    impl From<ssh2::Error> for DeleteError {
        fn from(value: ssh2::Error) -> Self {
            Self::Session(value)
        }
    }

    /// Metadata error specialised for packages read from a local path.
    type MetadataError = crate::package::error::MetadataError<
        <std::path::PathBuf as crate::package::source::PackageStream>::Error,
    >;

    /// Failure while uploading a package.
    #[derive(Debug)]
    pub enum UploadError {
        /// Reading the local file or writing the remote one failed.
        Io(std::io::Error),
        /// The blocking upload task could not be joined.
        Join(tokio::task::JoinError),
        /// The SSH library reported an error.
        Session(ssh2::Error),
        /// The connection could not be established.
        Connection(ConnectionError),
        /// The package metadata could not be read.
        Metadata(MetadataError),
    }

    impl fmt::Display for UploadError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("sftp upload error")
        }
    }

    impl StdError for UploadError {
        fn source(&self) -> Option<&(dyn StdError + 'static)> {
            // Every variant wraps an underlying error.
            let cause: &(dyn StdError + 'static) = match self {
                Self::Io(cause) => cause,
                Self::Join(cause) => cause,
                Self::Session(cause) => cause,
                Self::Connection(cause) => cause,
                Self::Metadata(cause) => cause,
            };
            Some(cause)
        }
    }

    impl From<std::io::Error> for UploadError {
        fn from(value: std::io::Error) -> Self {
            Self::Io(value)
        }
    }

    impl From<tokio::task::JoinError> for UploadError {
        fn from(value: tokio::task::JoinError) -> Self {
            Self::Join(value)
        }
    }

    impl From<ssh2::Error> for UploadError {
        fn from(value: ssh2::Error) -> Self {
            Self::Session(value)
        }
    }

    impl From<ConnectionError> for UploadError {
        fn from(value: ConnectionError) -> Self {
            Self::Connection(value)
        }
    }

    impl From<MetadataError> for UploadError {
        fn from(value: MetadataError) -> Self {
            Self::Metadata(value)
        }
    }
}
fn connect(sftp_opts: &ClientBuilder) -> Result<Sftp, error::ConnectionError> {
let mut session = make_session(sftp_opts.host.as_ref(), sftp_opts.port)?;
if let Some(key) = sftp_opts.key_path.as_deref() {
if let Some(password) = &sftp_opts.key_password {
password
.as_inner()
.map(|p| -> Result<(), error::ConnectionError> {
session.userauth_pubkey_file(
sftp_opts.username.as_ref(),
None,
Path::new(key),
Some(std::str::from_utf8(p.as_ref()).map_err(error::CredentialsError)?),
)?;
Ok(())
})?;
} else {
session.userauth_pubkey_file(
sftp_opts.username.as_ref(),
None,
Path::new(key),
None,
)?;
}
} else {
debug!("No SSH key used. Using SSH Agent");
connect_with_agent(sftp_opts.username.as_ref(), &mut session)?;
}
if !session.authenticated() {
let methods = session
.auth_methods(sftp_opts.username.as_ref())
.unwrap_or("none");
if methods != "keyboard-interactive" {
return Err(error::ConnectionError::UnsupportedMethods(methods.into()));
}
debug!(
"Partially connected. Trying second factor. Allowed methods: {}",
methods
);
if let Some(cb) = &sftp_opts.two_factor_callback {
let mut prompt = Prompt { cb };
session.userauth_keyboard_interactive(sftp_opts.username.as_ref(), &mut prompt)?;
} else {
return Err(error::ConnectionError::Missing2ndFactor);
}
}
Ok(session.sftp()?)
}
fn connect_with_agent(username: &str, session: &mut Session) -> Result<(), error::ConnectionError> {
let mut agent = session.agent()?;
agent.connect()?;
agent.list_identities()?;
let identities = agent.identities()?;
let key = &identities.iter().find(|i| {
agent
.userauth(username, i)
.or_else(|e| {
if e.code() == ssh2::ErrorCode::Session(-19) {
Ok(())
} else {
Err(e)
}
})
.map_err(|e| {
debug!("{:?}", e);
e
})
.is_ok()
});
agent.disconnect()?;
if key.is_none() {
return Err(error::ConnectionError::AgentAuth);
}
Ok(())
}
/// Remote directory that receives the files of a single upload.
pub(crate) struct UploadDir<'a> {
/// Remote path of the upload directory.
pub(crate) path: PathBuf,
/// Connection used to run remote operations on the directory.
client: &'a ClientConnected,
}
impl<'a> UploadDir<'a> {
    /// Describes a new upload directory below `base_path`.
    ///
    /// The directory name is the current UTC time formatted as
    /// `%Y%m%dT%H%M%S_%f`. Nothing is created on the server until
    /// [`UploadDir::create`] is called.
    pub(crate) fn new(base_path: &Path, client: &'a ClientConnected) -> Self {
        const DATETIME_FORMAT: &str = "%Y%m%dT%H%M%S_%f";
        Self {
            path: base_path.join(Utc::now().format(DATETIME_FORMAT).to_string()),
            client,
        }
    }

    /// Creates the directory on the server, using `mode` or `0o755` when `None`.
    pub(crate) fn create(&self, mode: Option<i32>) -> Result<(), ssh2::Error> {
        self.client.inner.mkdir(&self.path, mode.unwrap_or(0o755))?;
        Ok(())
    }

    /// Marks the upload as complete by creating an empty `done.txt` in the directory.
    pub(crate) fn finalize(self) -> Result<(), ssh2::Error> {
        const UPLOAD_FINISHED_MARKER_NAME: &str = "done.txt";
        self.client
            .inner
            .create(&self.path.join(UPLOAD_FINISHED_MARKER_NAME))?;
        Ok(())
    }

    /// Removes the directory, provided it only contains unfinished (`*.part`) files.
    ///
    /// All entries are checked before anything is removed, so a refused
    /// deletion leaves the directory untouched.
    ///
    /// # Errors
    ///
    /// Returns [`error::DeleteError::Denied`] naming the first entry whose
    /// extension is not `part` (entries without any extension included), and
    /// [`error::DeleteError::Session`] when a remote operation fails.
    pub(crate) fn delete(&self) -> Result<(), error::DeleteError> {
        const DELETABLE_EXTENSION: &str = "part";
        let entries = self.client.inner.readdir(&self.path)?;

        // An entry is deletable only when it has exactly the temporary
        // extension. A missing extension must not slip through.
        let forbidden = entries.iter().find(|(file, _)| {
            !file
                .extension()
                .is_some_and(|ext| ext == DELETABLE_EXTENSION)
        });
        if let Some((file, _)) = forbidden {
            return Err(error::DeleteError::Denied(
                file.to_string_lossy().to_string(),
            ));
        }

        for (file, _) in &entries {
            self.client.inner.unlink(file)?;
        }
        self.client.inner.rmdir(&self.path)?;
        Ok(())
    }
}
/// Remote locations of a package during and after its upload.
pub(crate) struct DpkgPath<'a> {
/// Temporary path the package is written to while the transfer is running.
pub(crate) tmp: PathBuf,
/// Final path of the package once the transfer is complete.
pub(crate) path: PathBuf,
/// Connection used to rename the temporary file to its final path.
client: &'a ClientConnected,
}
impl<'a> DpkgPath<'a> {
pub(crate) fn new<P: AsRef<Path>, S: AsRef<str>>(
base: P,
name: S,
client: &'a ClientConnected,
) -> Self {
const UPLOAD_TMP_SUFFIX: &str = ".part";
let p: PathBuf = base.as_ref().into();
Self {
tmp: p.join(format!("{}.{}", name.as_ref(), UPLOAD_TMP_SUFFIX)),
path: p.join(name.as_ref()),
client,
}
}
pub(crate) fn finalize(&self) -> Result<(), ssh2::Error> {
self.client.inner.rename(
&self.tmp,
&self.path,
Some(RenameFlags::ATOMIC | RenameFlags::NATIVE),
)?;
Ok(())
}
}
/// Uploads a verified package to an SFTP server.
///
/// The package is written into a fresh, timestamped directory below
/// `base_path` under a temporary name, renamed to its final name once all
/// data has been written, and the directory is then marked as complete.
/// With [`Mode::Check`], only the destination URL and the source size are
/// computed and nothing is created on the server.
///
/// The transfer itself runs on a blocking worker thread; `progress`, when
/// given, is started with the source size and fed as data is read.
///
/// # Errors
///
/// Returns an [`error::UploadError`] when the metadata cannot be read, the
/// connection fails, a local or remote I/O operation fails, or the worker
/// task cannot be joined.
#[instrument(skip(progress), err(Debug, level=tracing::Level::ERROR))]
pub async fn upload(
    package: &crate::package::Package<PathBuf, crate::package::state::Verified>,
    client: &Client,
    base_path: &Path,
    mode: Mode,
    progress: Option<impl ProgressDisplay + Send + 'static>,
) -> Result<Status, error::UploadError> {
    let metadata = package.metadata().await?;
    let path = package.source().to_path_buf();
    let name = package.name.clone();
    let base_path = base_path.to_path_buf();
    let parent_span = tracing::Span::current();
    let client = client.connect()?;
    let handle = tokio::task::spawn_blocking(move || -> Result<Status, error::UploadError> {
        // Re-attach the worker thread to the caller's span so its events stay correlated.
        let thread_span = tracing::info_span!(parent: &parent_span, "sftp upload thread");
        let _enter = thread_span.enter();
        let source_size = path.metadata()?.len();
        let upload_dir = UploadDir::new(&base_path, &client);
        let dpkg_path = DpkgPath::new(&upload_dir.path, &name, &client);
        let destination = client.get_url(&dpkg_path.path);
        if let Mode::Check = mode {
            debug!(
                destination,
                source_size, "Checked {name} for transfer into {destination}"
            );
            return Ok(Status::Checked {
                destination,
                source_size,
            });
        }
        upload_dir.create(None)?;
        const BUF_SIZE: usize = 1 << 22;
        let mut reader = io::BufReader::with_capacity(BUF_SIZE, File::open(path)?);
        let mut fout = io::BufWriter::with_capacity(BUF_SIZE, client.inner.create(&dpkg_path.tmp)?);
        if let Some(p) = progress {
            let mut reader = ProgressReader::new(reader, p.start(source_size));
            io::copy(&mut reader, &mut fout)?;
        } else {
            io::copy(&mut reader, &mut fout)?;
        }
        // `io::copy` leaves the tail of the data in the `BufWriter`. Relying
        // on `Drop` would write it only after the rename and the completion
        // marker below, and would silently discard any write error. Flush
        // explicitly, surface failures, and release the remote file handle
        // before the file is renamed.
        let remote_file = fout.into_inner().map_err(|e| e.into_error())?;
        drop(remote_file);
        dpkg_path.finalize()?;
        upload_dir.finalize()?;
        info!(
            destination,
            source_size,
            destination_size = source_size,
            "Successfully transferred {name} into {destination}"
        );
        Ok(Status::Completed {
            destination,
            source_size,
            destination_size: source_size,
            metadata,
        })
    });
    handle.await?
}
/// Adapter answering keyboard-interactive prompts with the result of a callback.
struct Prompt<Cb: Fn() -> String> {
/// Callback invoked once per prompt to obtain the response.
cb: Cb,
}
impl<Cb: Fn() -> String> ssh2::KeyboardInteractivePrompt for Prompt<Cb> {
    /// Answers every prompt of a keyboard-interactive round with the callback's result.
    ///
    /// The callback is invoked once per entry in `prompts`, in order. The
    /// request is logged at debug level; the responses are not, because they
    /// are credentials (for example one-time codes).
    fn prompt(
        &mut self,
        username: &str,
        instructions: &str,
        prompts: &[ssh2::Prompt],
    ) -> Vec<String> {
        debug!(
            "prompt: username='{}', instructions='{}', prompts={:?}",
            username, instructions, prompts
        );
        prompts
            .iter()
            .map(|p| {
                debug!("prompting for '{}'", p.text);
                // Deliberately not logged: the response is a secret.
                (self.cb)()
            })
            .collect()
    }
}