use russh::keys::{ssh_key, PrivateKey};
use russh::server::{Auth, Handler, Msg, Session};
use russh::{Channel, ChannelId, CryptoVec, MethodKind, MethodSet};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::process::Child;
use tokio::sync::watch;
use crate::auth::check_credentials;
use crate::git;
use crate::Config;
pub async fn serve(
config: Arc<Config>,
shutdown: watch::Receiver<bool>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let key_path = config.root.join(".host_key");
let key = if key_path.exists() {
let seed = tokio::fs::read(&key_path).await?;
let seed: [u8; 32] = seed.try_into().map_err(|_| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Invalid host key file (expected 32 bytes). Delete .host_key and restart.",
)
})?;
let sk = ed25519_dalek::SigningKey::from_bytes(&seed);
let vk = ed25519_dalek::VerifyingKey::from(&sk);
PrivateKey::from(ssh_key::private::Ed25519Keypair {
public: ssh_key::public::Ed25519PublicKey::from(vk),
private: ssh_key::private::Ed25519PrivateKey::from(&sk),
})
} else {
let key = PrivateKey::random(
&mut rand::rngs::OsRng,
ssh_key::Algorithm::Ed25519,
)
.expect("Failed to generate Ed25519 key");
if let ssh_key::private::KeypairData::Ed25519(kp) = key.key_data() {
tokio::fs::write(&key_path, kp.private.as_ref()).await?;
}
key
};
let ssh_config = russh::server::Config {
keys: vec![key],
..Default::default()
};
let addr = format!("{}:{}", config.host, config.ssh_port);
let listener = TcpListener::bind(&addr).await?;
let ssh_config = Arc::new(ssh_config);
let mut shutdown = shutdown;
loop {
tokio::select! {
accept = listener.accept() => {
match accept {
Ok((stream, peer)) => {
let handler = SshSession {
config: config.clone(),
channels: HashMap::new(),
processes: HashMap::new(),
};
let cfg = ssh_config.clone();
tokio::spawn(async move {
let session = match russh::server::run_stream(cfg, stream, handler).await {
Ok(s) => s,
Err(e) => {
eprintln!("SSH connection setup failed from {:?}: {}", peer, e);
return;
}
};
if let Err(_e) = session.await {
}
});
}
Err(e) => {
eprintln!("SSH accept error: {}", e);
}
}
}
_ = shutdown.changed() => {
break;
}
}
}
Ok(())
}
struct SshSession {
config: Arc<Config>,
channels: HashMap<ChannelId, Channel<Msg>>,
processes: HashMap<ChannelId, Child>,
}
impl Handler for SshSession {
type Error = russh::Error;
async fn auth_none(&mut self, _user: &str) -> Result<Auth, Self::Error> {
if self.config.user.is_none() {
Ok(Auth::Accept)
} else {
Ok(Auth::Reject {
proceed_with_methods: Some(MethodSet::from(
&[MethodKind::Password][..],
)),
partial_success: false,
})
}
}
async fn auth_password(&mut self, user: &str, password: &str) -> Result<Auth, Self::Error> {
if check_credentials(user, password, &self.config) {
Ok(Auth::Accept)
} else {
Ok(Auth::Reject {
proceed_with_methods: None,
partial_success: false,
})
}
}
async fn channel_open_session(
&mut self,
channel: Channel<Msg>,
_session: &mut Session,
) -> Result<bool, Self::Error> {
self.channels.insert(channel.id(), channel);
Ok(true)
}
async fn exec_request(
&mut self,
channel_id: ChannelId,
data: &[u8],
session: &mut Session,
) -> Result<(), Self::Error> {
let command = String::from_utf8_lossy(data).to_string();
let (git_cmd, repo_path) = match parse_ssh_command(&command) {
Some(v) => v,
None => {
let _ = session.data(
channel_id,
CryptoVec::from_slice(
format!("Error: unsupported command: {}\n", command).as_bytes(),
),
);
let _ = session.close(channel_id);
return Ok(());
}
};
let repo_name = repo_path
.trim_matches('\'')
.trim_matches('"')
.trim_matches('/')
.trim_end_matches(".git");
if repo_name.contains("..") {
let _ = session.data(
channel_id,
CryptoVec::from_slice(b"Error: invalid repo path\n"),
);
let _ = session.close(channel_id);
return Ok(());
}
let full_path = self.config.root.join(repo_name);
if git_cmd == "git-receive-pack" && !git::is_git_repo(&full_path) {
if let Some(parent) = full_path.parent() {
tokio::fs::create_dir_all(parent).await.ok();
}
if let Err(e) = git::init_repo(&full_path, &self.config).await {
let _ = session.data(
channel_id,
CryptoVec::from_slice(format!("Error: {}\n", e).as_bytes()),
);
let _ = session.close(channel_id);
return Ok(());
}
}
if !git::is_git_repo(&full_path) {
let _ = session.data(
channel_id,
CryptoVec::from_slice(
format!("Error: repository '{}' not found\n", repo_name).as_bytes(),
),
);
let _ = session.close(channel_id);
return Ok(());
}
let activity_kind = match git_cmd {
"git-receive-pack" => "push",
"git-upload-pack" => "pull",
_ => "access",
};
let activity_path = full_path.clone();
tokio::spawn(async move {
git::record_activity(&activity_path, activity_kind).await;
});
let child = if git_cmd == "git-upload-archive" {
git::spawn_upload_archive(&full_path).await
} else {
git::spawn_git(&full_path, git_cmd).await
};
let mut child = match child {
Ok(c) => c,
Err(e) => {
let _ = session.data(
channel_id,
CryptoVec::from_slice(format!("Error: {}\n", e).as_bytes()),
);
let _ = session.close(channel_id);
return Ok(());
}
};
let stdout = child.stdout.take();
let stderr = child.stderr.take();
let handle = session.handle();
if let Some(mut stdout) = stdout {
let handle = handle.clone();
tokio::spawn(async move {
let mut buf = vec![0u8; 32768];
loop {
match stdout.read(&mut buf).await {
Ok(0) => break,
Ok(n) => {
if handle
.data(channel_id, CryptoVec::from_slice(&buf[..n]))
.await
.is_err()
{
break;
}
}
Err(_) => break,
}
}
});
}
if let Some(mut stderr) = stderr {
let handle = handle.clone();
tokio::spawn(async move {
let mut buf = vec![0u8; 8192];
loop {
match stderr.read(&mut buf).await {
Ok(0) => break,
Ok(n) => {
if handle
.extended_data(channel_id, 1, CryptoVec::from_slice(&buf[..n]))
.await
.is_err()
{
break;
}
}
Err(_) => break,
}
}
});
}
self.processes.insert(channel_id, child);
Ok(())
}
async fn data(
&mut self,
channel_id: ChannelId,
data: &[u8],
_session: &mut Session,
) -> Result<(), Self::Error> {
if let Some(child) = self.processes.get_mut(&channel_id) {
if let Some(stdin) = child.stdin.as_mut() {
let _ = stdin.write_all(data).await;
}
}
Ok(())
}
async fn channel_eof(
&mut self,
channel_id: ChannelId,
session: &mut Session,
) -> Result<(), Self::Error> {
if let Some(mut child) = self.processes.remove(&channel_id) {
drop(child.stdin.take());
let handle = session.handle();
tokio::spawn(async move {
let status = child.wait().await;
let code = status
.map(|s| s.code().unwrap_or(1) as u32)
.unwrap_or(1);
let _ = handle.exit_status_request(channel_id, code).await;
let _ = handle.eof(channel_id).await;
let _ = handle.close(channel_id).await;
});
}
Ok(())
}
}
/// Parses an SSH exec command of the form `<git-command> <path>`.
///
/// Accepts both the dashed (`git-upload-pack`) and spaced (`git upload-pack`)
/// spellings of `upload-pack`, `receive-pack` and `upload-archive`. Returns
/// the dashed command name together with the path argument, which is trimmed
/// but still quoted as the client sent it.
///
/// The command name must be followed by whitespace, so e.g.
/// `git-upload-packer repo` is rejected. Returns `None` for any other command
/// or when the path is missing.
fn parse_ssh_command(cmd: &str) -> Option<(&str, &str)> {
    // (accepted spelling, normalized dashed name)
    const COMMANDS: [(&str, &str); 6] = [
        ("git-upload-pack", "git-upload-pack"),
        ("git-receive-pack", "git-receive-pack"),
        ("git-upload-archive", "git-upload-archive"),
        ("git upload-pack", "git-upload-pack"),
        ("git receive-pack", "git-receive-pack"),
        ("git upload-archive", "git-upload-archive"),
    ];
    let cmd = cmd.trim();
    COMMANDS.iter().find_map(|&(prefix, normalized)| {
        let rest = cmd.strip_prefix(prefix)?;
        // Require a word boundary so longer command names don't match.
        if !rest.starts_with(char::is_whitespace) {
            return None;
        }
        let path = rest.trim();
        (!path.is_empty()).then_some((normalized, path))
    })
}