use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use ssh_agent_lib::agent::{listen, Session};
use ssh_agent_lib::error::AgentError;
use ssh_agent_lib::proto::{
signature as proto_signature, AddIdentity, AddIdentityConstrained, Credential, Identity,
KeyConstraint, RemoveIdentity, SignRequest,
};
use ssh_key::private::KeypairData;
use ssh_key::{Algorithm, HashAlg, PrivateKey, Signature};
use tokio::sync::Mutex;
use crate::GitwayError;
/// Startup configuration for the agent daemon started by [`run`].
#[derive(Debug, Clone)]
pub struct AgentDaemonConfig {
    /// Where the agent listens: a Unix socket path on Unix, a named-pipe
    /// name on Windows.
    pub socket_path: PathBuf,
    /// Optional file that receives the daemon's process ID at startup and
    /// is removed again at shutdown.
    pub pid_file: Option<PathBuf>,
    /// Lifetime given to keys added without an explicit lifetime
    /// constraint; `None` keeps such keys until they are removed.
    pub default_ttl: Option<Duration>,
}
/// A loaded private key together with the constraints it was added with.
#[derive(Debug, Clone)]
struct StoredKey {
    /// Private key material used to produce signatures.
    key: PrivateKey,
    /// Deadline after which the key is evicted; `None` means no expiry.
    expires_at: Option<Instant>,
    /// When `true`, every signing request must first be approved through
    /// the askpass confirmation prompt.
    confirm: bool,
}
/// In-memory state shared by every agent session.
#[derive(Debug, Default)]
struct KeyStore {
    /// Loaded keys, indexed by the `to_string()` form of their SHA-256
    /// public-key fingerprint.
    keys: HashMap<String, StoredKey>,
    /// Passphrase supplied to `lock`; `Some` means the agent is locked.
    lock: Option<String>,
}

impl KeyStore {
    /// Creates an empty, unlocked store.
    fn new() -> Self {
        Self {
            keys: HashMap::new(),
            lock: None,
        }
    }

    /// Reports whether a lock passphrase is currently set.
    fn is_locked(&self) -> bool {
        matches!(self.lock, Some(_))
    }

    /// Drops every key whose deadline is at or before `now`.
    ///
    /// Keys without a deadline are always kept.
    fn evict_expired(&mut self, now: Instant) {
        self.keys
            .retain(|_, entry| !matches!(entry.expires_at, Some(deadline) if deadline <= now));
    }
}
/// Request handler given to `ssh_agent_lib::agent::listen`.
///
/// All clones share one [`KeyStore`] through the `Arc`, so a change made
/// through any clone is visible to the others. Presumably `listen` clones
/// this per client connection; confirm against ssh-agent-lib.
#[derive(Debug, Clone)]
struct AgentSession {
    /// Key store shared by every clone of this session.
    store: Arc<Mutex<KeyStore>>,
    /// Lifetime applied to keys added without a lifetime constraint.
    default_ttl: Option<Duration>,
}
#[async_trait]
impl Session for AgentSession {
    /// Lists the public key and comment of every loaded, unexpired key.
    ///
    /// Fails while the agent is locked.
    async fn request_identities(&mut self) -> Result<Vec<Identity>, AgentError> {
        let mut store = self.store.lock().await;
        if store.is_locked() {
            return Err(AgentError::Failure);
        }
        // The background sweeper only runs once per second; evict here too so
        // a key that is past its lifetime is never advertised.
        store.evict_expired(Instant::now());
        Ok(store
            .keys
            .values()
            .map(|s| Identity {
                pubkey: s.key.public_key().key_data().clone(),
                comment: s.key.comment().to_owned(),
            })
            .collect())
    }

    /// Adds a key with no explicit constraints (the daemon's default TTL,
    /// if configured, still applies).
    async fn add_identity(&mut self, req: AddIdentity) -> Result<(), AgentError> {
        self.add_inner(req, Vec::new()).await
    }

    /// Adds a key together with the constraints sent by the client.
    async fn add_identity_constrained(
        &mut self,
        req: AddIdentityConstrained,
    ) -> Result<(), AgentError> {
        self.add_inner(req.identity, req.constraints).await
    }

    /// Removes the key whose SHA-256 fingerprint matches `req.pubkey`.
    ///
    /// Fails if the agent is locked or no such key is loaded.
    async fn remove_identity(&mut self, req: RemoveIdentity) -> Result<(), AgentError> {
        let mut store = self.store.lock().await;
        if store.is_locked() {
            return Err(AgentError::Failure);
        }
        let pk = ssh_key::PublicKey::from(req.pubkey);
        let fp = pk.fingerprint(HashAlg::Sha256).to_string();
        if store.keys.remove(&fp).is_none() {
            return Err(AgentError::Failure);
        }
        Ok(())
    }

    /// Drops every loaded key. Fails if the agent is locked.
    async fn remove_all_identities(&mut self) -> Result<(), AgentError> {
        let mut store = self.store.lock().await;
        if store.is_locked() {
            return Err(AgentError::Failure);
        }
        store.keys.clear();
        Ok(())
    }

    /// Signs `req.data` with the loaded key matching `req.pubkey`.
    ///
    /// Fails if the agent is locked, the key is not loaded (or has expired),
    /// the user declines a confirmation prompt, or signing itself fails.
    async fn sign(&mut self, req: SignRequest) -> Result<Signature, AgentError> {
        let pk = ssh_key::PublicKey::from(req.pubkey.clone());
        let fp = pk.fingerprint(HashAlg::Sha256).to_string();
        let stored = {
            let mut store = self.store.lock().await;
            if store.is_locked() {
                return Err(AgentError::Failure);
            }
            // Enforce lifetimes at use time instead of relying on the
            // once-per-second sweeper.
            store.evict_expired(Instant::now());
            store.keys.get(&fp).ok_or(AgentError::Failure)?.clone()
        };
        if stored.confirm {
            let prompt = format!("Allow use of SSH key {fp} ({})?", stored.key.comment());
            if !super::askpass::confirm(&prompt).await {
                return Err(AgentError::Failure);
            }
            // The store mutex was released while the prompt was showing. In
            // the meantime the agent may have been locked, the key removed,
            // or its lifetime may have run out, so re-validate all three.
            let mut store = self.store.lock().await;
            if store.is_locked() {
                return Err(AgentError::Failure);
            }
            store.evict_expired(Instant::now());
            if !store.keys.contains_key(&fp) {
                return Err(AgentError::Failure);
            }
        }
        sign_with_key(&stored.key, &req.data, req.flags).map_err(|e| {
            log::warn!("gitway-agent: sign failed for {fp}: {e}");
            AgentError::Failure
        })
    }

    /// Locks the agent with passphrase `key`.
    ///
    /// While locked, listing, adding, removing and signing all fail. Locking
    /// an already-locked agent fails.
    async fn lock(&mut self, key: String) -> Result<(), AgentError> {
        let mut store = self.store.lock().await;
        if store.is_locked() {
            return Err(AgentError::Failure);
        }
        store.lock = Some(key);
        Ok(())
    }

    /// Unlocks the agent if `key` equals the passphrase given to `lock`.
    ///
    /// Fails on a wrong passphrase and when the agent is not locked.
    async fn unlock(&mut self, key: String) -> Result<(), AgentError> {
        // Equality check that visits every byte of equal-length inputs rather
        // than stopping at the first mismatch, so response timing does not
        // reveal how long a correct prefix the caller guessed.
        fn ct_eq(a: &[u8], b: &[u8]) -> bool {
            a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
        }

        let mut store = self.store.lock().await;
        let matches = match store.lock.as_deref() {
            Some(current) => ct_eq(current.as_bytes(), key.as_bytes()),
            None => false,
        };
        if !matches {
            return Err(AgentError::Failure);
        }
        store.lock = None;
        Ok(())
    }
}
impl AgentSession {
    /// Shared implementation of `add_identity` and `add_identity_constrained`.
    ///
    /// Parses the credential, applies `constraints` (falling back to the
    /// daemon's default TTL), and stores the key under its SHA-256
    /// fingerprint. Any existing entry for the same key, including its old
    /// constraints, is replaced.
    ///
    /// Fails when the agent is locked, the credential is a certificate or
    /// cannot be parsed, or a constraint is not supported by this agent.
    async fn add_inner(
        &mut self,
        req: AddIdentity,
        constraints: Vec<KeyConstraint>,
    ) -> Result<(), AgentError> {
        let mut store = self.store.lock().await;
        if store.is_locked() {
            return Err(AgentError::Failure);
        }
        let key = match req.credential {
            Credential::Key { privkey, comment } => {
                let mut pk = PrivateKey::try_from(privkey).map_err(|e| {
                    log::warn!("gitway-agent: add failed to parse credential: {e}");
                    AgentError::Failure
                })?;
                pk.set_comment(&comment);
                pk
            }
            Credential::Cert { .. } => {
                return Err(AgentError::Failure);
            }
        };

        let now = Instant::now();
        // `checked_add` instead of `+`: `Instant + Duration` panics on
        // overflow. A lifetime too large to represent is treated as "never
        // expires", which is what such a value means in practice.
        let mut expires_at = self.default_ttl.and_then(|d| now.checked_add(d));
        let mut confirm = false;
        for c in constraints {
            match c {
                KeyConstraint::Lifetime(secs) => {
                    expires_at = now.checked_add(Duration::from_secs(u64::from(secs)));
                }
                KeyConstraint::Confirm => {
                    confirm = true;
                }
                KeyConstraint::Extension(_) => {
                    // The agent protocol requires refusing a key whose
                    // constraints are not understood. Silently dropping e.g. a
                    // destination restriction would load the key with less
                    // protection than the user asked for.
                    log::warn!("gitway-agent: add refused: unsupported extension constraint");
                    return Err(AgentError::Failure);
                }
            }
        }

        let fp = key.public_key().fingerprint(HashAlg::Sha256).to_string();
        store.keys.insert(
            fp,
            StoredKey {
                key,
                expires_at,
                confirm,
            },
        );
        Ok(())
    }
}
/// Produces an SSH signature over `data` with `key`.
///
/// Ed25519 and ECDSA keys sign directly and ignore `flags`. RSA keys are
/// delegated to [`sign_rsa`], which uses `flags` to pick the hash. Any other
/// algorithm is rejected.
fn sign_with_key(key: &PrivateKey, data: &[u8], flags: u32) -> Result<Signature, GitwayError> {
    use signature::Signer as _;

    let algorithm = key.algorithm();
    if let Algorithm::Rsa { .. } = algorithm {
        return sign_rsa(key, data, flags);
    }
    if matches!(algorithm, Algorithm::Ed25519 | Algorithm::Ecdsa { .. }) {
        return key
            .try_sign(data)
            .map_err(|e| GitwayError::signing(format!("sign failed: {e}")));
    }
    Err(GitwayError::invalid_config(format!(
        "agent daemon sign: algorithm {} not supported",
        algorithm.as_str()
    )))
}
/// Signs `data` with an RSA `key` using PKCS#1 v1.5.
///
/// `flags` selects the hash: `RSA_SHA2_512` takes precedence, then
/// `RSA_SHA2_256`. Legacy SHA-1 `ssh-rsa` signatures (neither flag set) are
/// refused. The key's components are decoded before the flags are examined,
/// so a malformed key is reported ahead of a SHA-1 request.
fn sign_rsa(key: &PrivateKey, data: &[u8], flags: u32) -> Result<Signature, GitwayError> {
    use rsa::pkcs1v15::SigningKey;
    use rsa::signature::{RandomizedSigner, SignatureEncoding};
    use sha2::{Sha256, Sha512};

    let rsa_keypair = match key.key_data() {
        KeypairData::Rsa(kp) => kp,
        _ => {
            return Err(GitwayError::signing(
                "sign_rsa invoked on non-RSA key".to_string(),
            ));
        }
    };

    // Decode the multi-precision integers in the order n, e, d, p, q.
    let modulus = rsa::BigUint::try_from(&rsa_keypair.public.n)
        .map_err(|e| GitwayError::signing(format!("rsa modulus parse: {e}")))?;
    let public_exponent = rsa::BigUint::try_from(&rsa_keypair.public.e)
        .map_err(|e| GitwayError::signing(format!("rsa exponent parse: {e}")))?;
    let private_exponent = rsa::BigUint::try_from(&rsa_keypair.private.d)
        .map_err(|e| GitwayError::signing(format!("rsa private exponent parse: {e}")))?;
    let prime_p = rsa::BigUint::try_from(&rsa_keypair.private.p)
        .map_err(|e| GitwayError::signing(format!("rsa prime p parse: {e}")))?;
    let prime_q = rsa::BigUint::try_from(&rsa_keypair.private.q)
        .map_err(|e| GitwayError::signing(format!("rsa prime q parse: {e}")))?;
    let private = rsa::RsaPrivateKey::from_components(
        modulus,
        public_exponent,
        private_exponent,
        vec![prime_p, prime_q],
    )
    .map_err(|e| GitwayError::signing(format!("rsa from_components: {e}")))?;

    let mut rng = rand_core::OsRng;
    let (hash, sig_bytes) = if flags & proto_signature::RSA_SHA2_512 != 0 {
        let sig = SigningKey::<Sha512>::new(private).sign_with_rng(&mut rng, data);
        (HashAlg::Sha512, sig.to_bytes().into_vec())
    } else if flags & proto_signature::RSA_SHA2_256 != 0 {
        let sig = SigningKey::<Sha256>::new(private).sign_with_rng(&mut rng, data);
        (HashAlg::Sha256, sig.to_bytes().into_vec())
    } else {
        return Err(GitwayError::signing(
            "rsa sign: SHA-1 `ssh-rsa` requested but not supported — \
             client must request rsa-sha2-256 or rsa-sha2-512 \
             (OpenSSH has done so since 8.2)"
                .to_string(),
        ));
    };

    Signature::new(Algorithm::Rsa { hash: Some(hash) }, sig_bytes)
        .map_err(|e| GitwayError::signing(format!("ssh signature encode: {e}")))
}
/// Runs the agent daemon until a shutdown signal arrives or the accept loop
/// ends.
///
/// Startup sequence:
/// 1. Write the PID file (if configured).
/// 2. Start a once-per-second task that evicts expired keys.
/// 3. Serve agent requests on `config.socket_path`.
///
/// On the way out, the eviction task is aborted and the socket and PID files
/// are removed, whether serving ended normally or failed.
///
/// # Errors
///
/// Returns an error if any of the following fail: writing the PID file,
/// registering the shutdown-signal handlers, or binding the socket / named
/// pipe. Errors from the accept loop itself are logged, not returned.
pub async fn run(config: AgentDaemonConfig) -> Result<(), GitwayError> {
    write_pid_file(config.pid_file.as_deref())?;
    let store = Arc::new(Mutex::new(KeyStore::new()));
    let session = AgentSession {
        store: Arc::clone(&store),
        default_ttl: config.default_ttl,
    };
    let evict_store = Arc::clone(&store);
    let evict_handle = tokio::spawn(async move {
        let mut ticker = tokio::time::interval(Duration::from_secs(1));
        loop {
            ticker.tick().await;
            let now = Instant::now();
            let mut s = evict_store.lock().await;
            s.evict_expired(now);
        }
    });
    let outcome = accept_until_shutdown(&config.socket_path, session).await;
    evict_handle.abort();
    cleanup(&config);
    outcome
}

/// Binds the Unix socket and serves agent requests until SIGINT or SIGTERM.
///
/// Signal handlers are registered before binding. A registration failure is
/// therefore reported as an error instead of resolving the signal future
/// immediately, which would look like a received signal and shut the daemon
/// down.
#[cfg(unix)]
async fn accept_until_shutdown(socket_path: &Path, session: AgentSession) -> Result<(), GitwayError> {
    use tokio::signal::unix::{signal, SignalKind};

    let mut sigint = signal(SignalKind::interrupt())?;
    let mut sigterm = signal(SignalKind::terminate())?;

    let listener = bind_unix_socket(socket_path).map_err(|e| {
        log::warn!("gitway-agent: bind failed: {e}");
        e
    })?;

    tokio::select! {
        res = listen(listener, session) => {
            if let Err(e) = res {
                log::warn!("gitway-agent: accept loop ended with error: {e}");
            }
        }
        _ = sigint.recv() => {
            log::info!("gitway-agent: SIGINT received, shutting down");
        }
        _ = sigterm.recv() => {
            log::info!("gitway-agent: SIGTERM received, shutting down");
        }
    }
    Ok(())
}

/// Binds the named pipe and serves agent requests until Ctrl+C.
///
/// The Ctrl+C listener is registered before binding. A registration failure
/// is therefore reported as an error instead of being mistaken for a
/// keypress.
#[cfg(windows)]
async fn accept_until_shutdown(socket_path: &Path, session: AgentSession) -> Result<(), GitwayError> {
    use ssh_agent_lib::agent::NamedPipeListener;

    let mut ctrl_c = tokio::signal::windows::ctrl_c()?;

    let listener = NamedPipeListener::bind(socket_path.as_os_str()).map_err(|e| {
        log::warn!(
            "gitway-agent: named-pipe bind failed for {}: {e}",
            socket_path.display()
        );
        e
    })?;

    tokio::select! {
        res = listen(listener, session) => {
            if let Err(e) = res {
                log::warn!("gitway-agent: accept loop ended with error: {e}");
            }
        }
        _ = ctrl_c.recv() => {
            log::info!("gitway-agent: Ctrl+C received, shutting down");
        }
    }
    Ok(())
}
/// Binds a Unix listener at `path` and restricts it to `SOCKET_MODE`.
///
/// A leftover *socket* at `path` (e.g. from a crashed run) is removed first.
/// Anything else at that path (a regular file, directory or symlink) is left
/// untouched, and the bind then fails with "address in use" rather than the
/// daemon silently deleting the user's file.
///
/// NOTE(review): between `bind` and `set_permissions` the socket exists with
/// umask-derived permissions. Callers should place it in a directory only
/// the owner can traverse (mode 0700).
#[cfg(unix)]
fn bind_unix_socket(path: &Path) -> Result<tokio::net::UnixListener, GitwayError> {
    use std::os::unix::fs::{FileTypeExt, PermissionsExt};

    // `symlink_metadata` so that a symlink is judged as itself, not its target.
    if let Ok(meta) = std::fs::symlink_metadata(path) {
        if meta.file_type().is_socket() {
            let _ = std::fs::remove_file(path);
        }
    }
    let listener = tokio::net::UnixListener::bind(path)?;
    std::fs::set_permissions(path, std::fs::Permissions::from_mode(SOCKET_MODE))?;
    Ok(listener)
}
/// Writes this process's ID, followed by a newline, to `path`.
///
/// Does nothing when `path` is `None`. I/O errors are propagated.
fn write_pid_file(path: Option<&Path>) -> Result<(), GitwayError> {
    if let Some(target) = path {
        let contents = format!("{}\n", std::process::id());
        std::fs::write(target, contents)?;
    }
    Ok(())
}
/// Best-effort removal of the files the daemon created; errors are ignored.
///
/// On Unix the socket path is removed only if it is actually a socket. If
/// binding failed because something else occupies the path, that file is
/// left alone. The PID file, when configured, is always removed.
fn cleanup(config: &AgentDaemonConfig) {
    #[cfg(unix)]
    {
        use std::os::unix::fs::FileTypeExt;

        let is_socket = std::fs::symlink_metadata(&config.socket_path)
            .map(|meta| meta.file_type().is_socket())
            .unwrap_or(false);
        if is_socket {
            let _ = std::fs::remove_file(&config.socket_path);
        }
    }
    if let Some(p) = &config.pid_file {
        let _ = std::fs::remove_file(p);
    }
}
/// Permission bits applied to the agent socket: read/write for the owner
/// only (`rw-------`).
#[cfg(unix)]
const SOCKET_MODE: u32 = 0o600;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::keygen::{generate, KeyType};

    /// Fingerprint string in the same form the agent uses as its map key.
    fn fingerprint_of(key: &PrivateKey) -> String {
        key.public_key().fingerprint(HashAlg::Sha256).to_string()
    }

    /// Wraps `key` in an unconfirmed `StoredKey` expiring at `deadline`.
    fn entry(key: PrivateKey, deadline: Instant) -> StoredKey {
        StoredKey {
            key,
            expires_at: Some(deadline),
            confirm: false,
        }
    }

    #[test]
    fn evict_expired_drops_past_keys_only() {
        let stale = generate(KeyType::Ed25519, None, "now").unwrap();
        let fresh = generate(KeyType::Ed25519, None, "later").unwrap();
        let stale_fp = fingerprint_of(&stale);
        let fresh_fp = fingerprint_of(&fresh);

        let already_past = Instant::now()
            .checked_sub(Duration::from_secs(1))
            .expect("test runs after process start; Instant never underflows");
        let a_minute_out = Instant::now() + Duration::from_secs(60);

        let mut store = KeyStore::new();
        store.keys.insert(stale_fp.clone(), entry(stale, already_past));
        store.keys.insert(fresh_fp.clone(), entry(fresh, a_minute_out));

        store.evict_expired(Instant::now());

        assert!(!store.keys.contains_key(&stale_fp));
        assert!(store.keys.contains_key(&fresh_fp));
    }

    #[test]
    fn sign_ed25519_roundtrip_verifies_with_public_key() {
        use ed25519_dalek::Verifier as _;

        let key = generate(KeyType::Ed25519, None, "roundtrip").unwrap();
        let message = b"hello gitway agent";
        let sig = sign_with_key(&key, message, 0).unwrap();
        assert_eq!(sig.algorithm(), ssh_key::Algorithm::Ed25519);

        let ssh_key::public::KeyData::Ed25519(public) = key.public_key().key_data() else {
            unreachable!()
        };
        let verifying_key = ed25519_dalek::VerifyingKey::from_bytes(&public.0).unwrap();
        let raw: [u8; 64] = sig.as_bytes().try_into().unwrap();
        let dalek_sig = ed25519_dalek::Signature::from_bytes(&raw);
        verifying_key.verify(message, &dalek_sig).unwrap();
    }

    /// Signs with a fresh key of `kind` (flags = 0) and checks the result
    /// with ssh-key's own verifier.
    fn sign_verify_roundtrip(kind: KeyType) {
        use signature::Verifier;

        let key = generate(kind, None, "roundtrip").unwrap();
        let message = b"hello gitway agent";
        let sig = sign_with_key(&key, message, 0).unwrap();
        if let Err(e) = key.public_key().key_data().verify(message, &sig) {
            panic!("verify failed for {kind:?}: {e}");
        }
    }

    #[test]
    fn sign_ecdsa_p256_roundtrip() {
        sign_verify_roundtrip(KeyType::EcdsaP256);
    }

    #[test]
    fn sign_ecdsa_p384_roundtrip() {
        sign_verify_roundtrip(KeyType::EcdsaP384);
    }

    #[test]
    fn sign_ecdsa_p521_roundtrip() {
        sign_verify_roundtrip(KeyType::EcdsaP521);
    }

    /// Signs with a fresh 2048-bit RSA key under `flags`, checks the
    /// signature advertises `expected_hash`, and verifies it.
    fn sign_rsa_roundtrip(flags: u32, expected_hash: HashAlg) {
        use signature::Verifier;

        let key = generate(KeyType::Rsa, Some(2048), "rsa-roundtrip").unwrap();
        let message = b"hello gitway agent";
        let sig = sign_with_key(&key, message, flags).unwrap();
        let expected = Algorithm::Rsa {
            hash: Some(expected_hash),
        };
        assert_eq!(sig.algorithm(), expected);
        key.public_key()
            .key_data()
            .verify(message, &sig)
            .expect("rsa roundtrip verify");
    }

    #[test]
    fn sign_rsa_sha256_roundtrip() {
        sign_rsa_roundtrip(proto_signature::RSA_SHA2_256, HashAlg::Sha256);
    }

    #[test]
    fn sign_rsa_sha512_roundtrip() {
        sign_rsa_roundtrip(proto_signature::RSA_SHA2_512, HashAlg::Sha512);
    }

    #[test]
    fn sign_rsa_prefers_sha512_when_both_flags_set() {
        let both = proto_signature::RSA_SHA2_256 | proto_signature::RSA_SHA2_512;
        sign_rsa_roundtrip(both, HashAlg::Sha512);
    }

    #[test]
    fn sign_rsa_rejects_sha1_request() {
        let key = generate(KeyType::Rsa, Some(2048), "rsa-sha1").unwrap();
        let err = sign_with_key(&key, b"data", 0).unwrap_err();
        let text = err.to_string();
        assert!(text.contains("SHA-1"), "unexpected error: {err}");
    }
}