use std::path::Path;
use std::sync::Arc;
use anyhow::{anyhow, bail, Context, Result};
use tokio::io::{ReadHalf, WriteHalf};
use tokio_util::io::SyncIoBridge;
/// Connection parameters for an SSH git remote, as produced by [`parse_ssh_url`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SshLocation {
/// Login user; `parse_ssh_url` fills in `git` when the URL names none.
pub user: String,
/// Remote host name or address to connect to.
pub host: String,
/// TCP port; 22 unless an `ssh://` URL specifies one.
pub port: u16,
/// Repository path passed (quoted) to `git-upload-pack` on the remote.
pub path: String,
}
/// Parses an SSH git remote address into an [`SshLocation`].
///
/// Accepted spellings:
///
/// * URL form: `ssh://[user@]host[:port]/path`, plus git's `git+ssh://` and
///   `ssh+git://` aliases. The returned path keeps its leading `/`.
/// * scp-like form: `[user@]host:path`. The path is returned verbatim, so it is
///   relative to the remote login directory unless it starts with `/`. An `@`
///   that appears after the first `:` belongs to the path, not to the user.
///
/// In both forms the host may be a bracketed IPv6 literal such as `[::1]`.
/// The brackets are removed from the returned host so it can be used directly as
/// a socket address. The user defaults to `git` and the port to 22.
///
/// # Errors
///
/// Fails in any of these cases:
///
/// * another scheme (e.g. `https://`) or neither form;
/// * an `ssh://` URL without a path;
/// * an empty user or host, or an empty scp-like path;
/// * an unterminated or malformed `[...]` host;
/// * a port that is not a valid `u16`.
pub fn parse_ssh_url(url: &str) -> Result<SshLocation> {
    let not_ssh = || anyhow!("`{url}` is not an SSH url (expected host:path or ssh://…)");

    let url_form = ["ssh://", "git+ssh://", "ssh+git://"]
        .iter()
        .find_map(|scheme| url.strip_prefix(*scheme));
    if let Some(rest) = url_form {
        let (authority, path) = rest
            .split_once('/')
            .ok_or_else(|| anyhow!("ssh url `{url}` has no path"))?;
        let (user, hostport) = authority.split_once('@').unwrap_or(("git", authority));
        let (host, port) = split_host(hostport)
            .ok_or_else(|| anyhow!("ssh url `{url}` has malformed bracketed host"))?;
        let port = match port {
            Some(p) => p
                .parse()
                .with_context(|| format!("ssh url `{url}` bad port"))?,
            None => 22,
        };
        if user.is_empty() {
            bail!("ssh url `{url}` has empty user");
        }
        if host.is_empty() {
            bail!("ssh url `{url}` has empty host");
        }
        return Ok(SshLocation {
            user: user.to_string(),
            host: host.to_string(),
            port,
            path: format!("/{path}"),
        });
    }

    if url.contains("://") {
        return Err(not_ssh());
    }

    // Only an '@' that comes before the first ':' separates a user; a later '@'
    // is part of the path (e.g. `host:dir@v2`).
    let (user, rest) = match url.find(|c: char| c == '@' || c == ':') {
        Some(at) if url[at..].starts_with('@') => (&url[..at], &url[at + 1..]),
        _ => ("git", url),
    };
    let (host, path) = match split_host(rest) {
        Some((host, Some(path))) => (host, path),
        Some((_, None)) => return Err(not_ssh()),
        None => bail!("ssh url `{url}` has malformed bracketed host"),
    };
    if user.is_empty() {
        bail!("ssh url `{url}` has empty user");
    }
    if host.is_empty() || path.is_empty() {
        bail!("ssh url `{url}` has empty host or path");
    }
    Ok(SshLocation {
        user: user.to_string(),
        host: host.to_string(),
        port: 22,
        path: path.to_string(),
    })
}

/// Splits `host[:rest]` or `[host][:rest]` at the first `:` that follows the host.
///
/// Returns the host without brackets, plus the text after that `:` (`None` when
/// there is no `:`). Returns `None` overall when a bracketed host is unterminated
/// or is followed by anything other than `:`.
fn split_host(s: &str) -> Option<(&str, Option<&str>)> {
    match s.strip_prefix('[') {
        Some(bracketed) => {
            let (host, after) = bracketed.split_once(']')?;
            if after.is_empty() {
                Some((host, None))
            } else {
                after.strip_prefix(':').map(|rest| (host, Some(rest)))
            }
        }
        None => Some(match s.split_once(':') {
            Some((host, rest)) => (host, Some(rest)),
            None => (s, None),
        }),
    }
}
pub fn parse_ref_advertisement(mut buf: &[u8]) -> Result<Vec<(String, String)>> {
let mut refs = Vec::new();
loop {
if buf.len() < 4 {
break;
}
let len_hex = std::str::from_utf8(&buf[..4]).context("pkt-line length not utf8")?;
let len = usize::from_str_radix(len_hex, 16)
.with_context(|| format!("pkt-line length `{len_hex}` not hex"))?;
if len == 0 {
break;
}
if len < 4 || len > buf.len() {
bail!("pkt-line length {len} out of range (have {} bytes)", buf.len());
}
let payload = &buf[4..len];
buf = &buf[len..];
let line = payload.strip_suffix(b"\n").unwrap_or(payload);
let line = line.split(|&b| b == 0).next().unwrap_or(line);
let text = std::str::from_utf8(line).context("ref line not utf8")?;
if text.starts_with("# service=") {
continue;
}
if let Some((sha, name)) = text.split_once(' ') {
if sha.len() == 40 && sha.bytes().all(|b| b.is_ascii_hexdigit()) {
refs.push((sha.to_string(), name.to_string()));
}
}
}
Ok(refs)
}
/// `russh` client event handler used for every SSH connection made by this module.
///
/// SECURITY NOTE(review): `check_server_key` below accepts every host key. It does not
/// consult `known_hosts` or any pinned fingerprint, so an attacker in the network path
/// could impersonate the server and serve forged refs or pack data. TODO: verify the
/// server key before relying on this outside trusted networks.
struct Client;
impl russh::client::Handler for Client {
type Error = russh::Error;
/// Receives the server's host key during the handshake; returning `Ok(true)` accepts
/// it. This implementation accepts unconditionally (see the security note above).
async fn check_server_key(
&mut self,
_server_public_key: &russh::keys::ssh_key::PublicKey,
) -> Result<bool, Self::Error> {
Ok(true)
}
}
/// Opens an SSH connection to `loc` and authenticates as `loc.user` with a key file.
///
/// `key_path` must hold an unencrypted private key in OpenSSH format. For RSA keys,
/// the SHA-2 signature variant advertised by the server is used. The legacy SHA-1
/// `ssh-rsa` signature is only a fallback when the server advertises nothing better.
///
/// # Errors
///
/// Fails in any of these cases:
/// * the key file cannot be read or parsed;
/// * the key is passphrase-protected;
/// * the TCP/SSH connection fails;
/// * the server rejects the key.
async fn connect_session(
    loc: &SshLocation,
    key_path: &Path,
) -> Result<russh::client::Handle<Client>> {
    use russh::keys::{PrivateKeyWithHashAlg, ssh_key::PrivateKey};

    let key_pem = std::fs::read_to_string(key_path)
        .with_context(|| format!("read ssh key {}", key_path.display()))?;
    let key = PrivateKey::from_openssh(&key_pem)
        .with_context(|| format!("parse OpenSSH key {}", key_path.display()))?;
    // An encrypted key parses fine but cannot sign; report that clearly instead of
    // surfacing an opaque authentication failure later.
    if key.is_encrypted() {
        bail!(
            "ssh key {} is passphrase-protected, which is not supported",
            key_path.display()
        );
    }

    let config = Arc::new(russh::client::Config::default());
    let mut session = russh::client::connect(config, (loc.host.as_str(), loc.port), Client)
        .await
        .with_context(|| format!("ssh connect {}:{}", loc.host, loc.port))?;

    // For RSA keys, a `None` hash means SHA-1 `ssh-rsa` signatures, which many servers
    // refuse. Ask the server which rsa-sha2 variant it accepts instead.
    let rsa_hash = session
        .best_supported_rsa_hash()
        .await
        .context("query server RSA signature algorithms")?
        .flatten();
    let auth = session
        .authenticate_publickey(&loc.user, PrivateKeyWithHashAlg::new(Arc::new(key), rsa_hash))
        .await
        .context("ssh publickey auth")?;
    if !auth.success() {
        bail!("ssh publickey auth rejected for {}@{}", loc.user, loc.host);
    }
    Ok(session)
}
pub async fn ls_remote(loc: &SshLocation, key_path: &Path) -> Result<Vec<(String, String)>> {
let session = connect_session(loc, key_path).await?;
let mut channel = session
.channel_open_session()
.await
.context("ssh open session channel")?;
let cmd = format!("git-upload-pack '{}'", loc.path);
channel.exec(true, cmd.as_bytes()).await.context("ssh exec git-upload-pack")?;
channel
.data(&b"0000"[..])
.await
.context("ssh send flush-pkt")?;
let mut out: Vec<u8> = Vec::new();
while let Some(msg) = channel.wait().await {
match msg {
russh::ChannelMsg::Data { ref data } => out.extend_from_slice(data),
russh::ChannelMsg::Eof | russh::ChannelMsg::ExitStatus { .. } => break,
_ => {}
}
}
parse_ref_advertisement(&out)
}
pub fn ls_remote_blocking(url: &str, key_path: &Path) -> Result<Vec<(String, String)>> {
let loc = parse_ssh_url(url)?;
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.context("build runtime for ssh ls-remote")?;
rt.block_on(ls_remote(&loc, key_path))
}
pub struct UploadPack {
_rt: tokio::runtime::Runtime,
_session: russh::client::Handle<Client>,
pub reader: SyncIoBridge<ReadHalf<russh::ChannelStream<russh::client::Msg>>>,
pub writer: SyncIoBridge<WriteHalf<russh::ChannelStream<russh::client::Msg>>>,
}
pub fn connect_upload_pack(loc: &SshLocation, key_path: &Path) -> Result<UploadPack> {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.context("build runtime for ssh upload-pack")?;
let (session, read_half, write_half) = rt.block_on(async {
let session = connect_session(loc, key_path).await?;
let channel = session
.channel_open_session()
.await
.context("ssh open session channel")?;
channel.set_env(false, "GIT_PROTOCOL", "version=2").await.ok();
let cmd = format!("git-upload-pack '{}'", loc.path);
channel
.exec(true, cmd.as_bytes())
.await
.context("ssh exec git-upload-pack")?;
let (r, w) = tokio::io::split(channel.into_stream());
Ok::<_, anyhow::Error>((session, r, w))
})?;
let handle = rt.handle().clone();
let reader = SyncIoBridge::new_with_handle(read_half, handle.clone());
let writer = SyncIoBridge::new_with_handle(write_half, handle);
Ok(UploadPack { _rt: rt, _session: session, reader, writer })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Frames `body` as one pkt-line: four hex digits of total length, then the body.
    fn pkt_line(body: &str) -> Vec<u8> {
        let mut framed = format!("{:04x}", body.len() + 4).into_bytes();
        framed.extend_from_slice(body.as_bytes());
        framed
    }

    #[test]
    fn parse_scp_like() {
        let parsed = parse_ssh_url("git@github.com:octocat/Hello-World.git").unwrap();
        let expected = SshLocation {
            user: "git".into(),
            host: "github.com".into(),
            port: 22,
            path: "octocat/Hello-World.git".into(),
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn parse_ssh_scheme_with_port_and_user() {
        let parsed = parse_ssh_url("ssh://deploy@git.example.com:2222/srv/repos/foo.git").unwrap();
        let expected = SshLocation {
            user: "deploy".into(),
            host: "git.example.com".into(),
            port: 2222,
            path: "/srv/repos/foo.git".into(),
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn reject_https() {
        let rejected = parse_ssh_url("https://github.com/octocat/Hello-World");
        assert!(rejected.is_err(), "https URLs must not parse as SSH");
    }

    #[test]
    fn parse_advertisement() {
        let head_oid = "7fd1a60b01f91b314f59955a4e4d4e80d8edf11d";
        let main_oid = "1111111111111111111111111111111111111111";

        let mut wire = Vec::new();
        wire.extend(pkt_line(&format!(
            "{head_oid} HEAD\0multi_ack symref=HEAD:refs/heads/main\n"
        )));
        wire.extend(pkt_line(&format!("{main_oid} refs/heads/main\n")));
        wire.extend_from_slice(b"0000");

        let refs = parse_ref_advertisement(&wire).unwrap();
        assert_eq!(
            refs,
            vec![
                (head_oid.to_string(), "HEAD".to_string()),
                (main_oid.to_string(), "refs/heads/main".to_string()),
            ]
        );
    }
}