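//! Example SSH agent built with ssh-agent-lib 0.5.2.
//!
//! ssh-agent-lib is a collection of types for writing custom SSH agents; see
//! its documentation for details. This example keeps identities in memory and
//! can sign only with RSA keys.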
use std::sync::{Arc, Mutex};

use async_trait::async_trait;
use log::info;
use rsa::pkcs1v15::SigningKey;
use rsa::sha2::{Sha256, Sha512};
use rsa::signature::{RandomizedSigner, SignatureEncoding};
use sha1::Sha1;
#[cfg(windows)]
use ssh_agent_lib::agent::NamedPipeListener as Listener;
use ssh_agent_lib::agent::{listen, Session};
use ssh_agent_lib::error::AgentError;
use ssh_agent_lib::proto::extension::{QueryResponse, RestrictDestination, SessionBind};
use ssh_agent_lib::proto::{
    message, signature, AddIdentity, AddIdentityConstrained, AddSmartcardKeyConstrained,
    Credential, Extension, KeyConstraint, RemoveIdentity, SignRequest, SmartcardKey,
};
use ssh_key::{
    private::{KeypairData, PrivateKey},
    public::PublicKey,
    Algorithm, Signature,
};
#[cfg(not(windows))]
use tokio::net::UnixListener as Listener;

/// A private key held by the agent, together with its public half and comment.
#[derive(Debug, Clone, PartialEq)]
struct Identity {
    /// Public key derived from `privkey`; identities are looked up by this key.
    pubkey: PublicKey,
    /// Private key used to produce signatures.
    privkey: PrivateKey,
    /// Comment supplied by the client when the key was added.
    comment: String,
}

/// In-memory identity store.
///
/// The list lives behind an `Arc`, so every clone of a `KeyStorage` is a
/// handle to the same underlying identities.
#[derive(Clone, Default)]
struct KeyStorage {
    /// Stored identities, guarded by a mutex for shared mutable access.
    identities: Arc<Mutex<Vec<Identity>>>,
}

impl KeyStorage {
    /// Returns the position in `identities` of the entry whose public key
    /// equals `pubkey`, or `None` if there is no such entry.
    fn identity_index_from_pubkey(identities: &[Identity], pubkey: &PublicKey) -> Option<usize> {
        identities
            .iter()
            .position(|identity| &identity.pubkey == pubkey)
    }

    /// Returns a copy of the stored identity whose public key equals `pubkey`,
    /// or `None` if it is not stored.
    ///
    /// # Panics
    /// Panics if the identity mutex is poisoned.
    fn identity_from_pubkey(&self, pubkey: &PublicKey) -> Option<Identity> {
        let identities = self.identities.lock().unwrap();
        let index = Self::identity_index_from_pubkey(&identities, pubkey)?;
        identities.get(index).cloned()
    }

    /// Stores `identity` unless one with the same public key is already
    /// present. In that case the call is a no-op and the existing entry,
    /// including its comment, is kept.
    ///
    /// # Panics
    /// Panics if the identity mutex is poisoned.
    fn identity_add(&self, identity: Identity) {
        let mut identities = self.identities.lock().unwrap();
        if Self::identity_index_from_pubkey(&identities, &identity.pubkey).is_none() {
            identities.push(identity);
        }
    }

    /// Removes the identity whose public key equals `pubkey`.
    ///
    /// # Errors
    /// Returns an error if no identity with that public key is stored.
    ///
    /// # Panics
    /// Panics if the identity mutex is poisoned.
    fn identity_remove(&self, pubkey: &PublicKey) -> Result<(), AgentError> {
        let mut identities = self.identities.lock().unwrap();

        if let Some(index) = Self::identity_index_from_pubkey(&identities, pubkey) {
            identities.remove(index);
            Ok(())
        } else {
            Err(std::io::Error::other("Failed to remove identity: identity not found").into())
        }
    }
}

#[crate::async_trait]
impl Session for KeyStorage {
    async fn sign(&mut self, sign_request: SignRequest) -> Result<Signature, AgentError> {
        let pubkey: PublicKey = sign_request.pubkey.clone().into();

        if let Some(identity) = self.identity_from_pubkey(&pubkey) {
            match identity.privkey.key_data() {
                KeypairData::Rsa(ref key) => {
                    let algorithm;

                    let private_key: rsa::RsaPrivateKey = rsa::RsaPrivateKey::from_components(
                        rsa::BigUint::try_from(&key.public.n).map_err(AgentError::other)?,
                        rsa::BigUint::try_from(&key.public.e).map_err(AgentError::other)?,
                        rsa::BigUint::try_from(&key.private.d).map_err(AgentError::other)?,
                        vec![
                            rsa::BigUint::try_from(&key.private.p).map_err(AgentError::other)?,
                            // ssh-key 0.6.7 uses "p" here so we need to inline a fixed version:
                            rsa::BigUint::try_from(&key.private.q).map_err(AgentError::other)?,
                        ],
                    )
                    .map_err(AgentError::other)?;
                    let mut rng = rand::thread_rng();
                    let data = &sign_request.data;

                    let signature = if sign_request.flags & signature::RSA_SHA2_512 != 0 {
                        algorithm = "rsa-sha2-512";
                        SigningKey::<Sha512>::new(private_key).sign_with_rng(&mut rng, data)
                    } else if sign_request.flags & signature::RSA_SHA2_256 != 0 {
                        algorithm = "rsa-sha2-256";
                        SigningKey::<Sha256>::new(private_key).sign_with_rng(&mut rng, data)
                    } else {
                        algorithm = "ssh-rsa";
                        SigningKey::<Sha1>::new(private_key).sign_with_rng(&mut rng, data)
                    };
                    Ok(Signature::new(
                        Algorithm::new(algorithm).map_err(AgentError::other)?,
                        signature.to_bytes().to_vec(),
                    )
                    .map_err(AgentError::other)?)
                }
                _ => Err(std::io::Error::other("Signature for key type not implemented").into()),
            }
        } else {
            Err(std::io::Error::other("Failed to create signature: identity not found").into())
        }
    }

    async fn request_identities(&mut self) -> Result<Vec<message::Identity>, AgentError> {
        let mut identities = vec![];
        for identity in self.identities.lock().unwrap().iter() {
            identities.push(message::Identity {
                pubkey: identity.pubkey.key_data().clone(),
                comment: identity.comment.clone(),
            })
        }
        Ok(identities)
    }

    async fn add_identity(&mut self, identity: AddIdentity) -> Result<(), AgentError> {
        if let Credential::Key { privkey, comment } = identity.credential {
            let privkey = PrivateKey::try_from(privkey).map_err(AgentError::other)?;
            self.identity_add(Identity {
                pubkey: PublicKey::from(&privkey),
                privkey,
                comment,
            });
            Ok(())
        } else {
            info!("Unsupported key type: {:#?}", identity.credential);
            Ok(())
        }
    }

    async fn add_identity_constrained(
        &mut self,
        identity: AddIdentityConstrained,
    ) -> Result<(), AgentError> {
        let AddIdentityConstrained {
            identity,
            constraints,
        } = identity;
        info!("Would use these constraints: {constraints:#?}");
        for constraint in constraints {
            if let KeyConstraint::Extension(extension) = constraint {
                if let Some(destination) =
                    extension.parse_key_constraint::<RestrictDestination>()?
                {
                    info!("Destination constraint: {destination:?}");
                }

                if let Credential::Key { privkey, comment } = identity.credential.clone() {
                    let privkey = PrivateKey::try_from(privkey).map_err(AgentError::other)?;
                    self.identity_add(Identity {
                        pubkey: PublicKey::from(&privkey),
                        privkey,
                        comment,
                    });
                }
            }
        }
        self.add_identity(identity).await
    }

    async fn remove_identity(&mut self, identity: RemoveIdentity) -> Result<(), AgentError> {
        let pubkey: PublicKey = identity.pubkey.into();
        self.identity_remove(&pubkey)?;
        Ok(())
    }

    async fn add_smartcard_key(&mut self, key: SmartcardKey) -> Result<(), AgentError> {
        info!("Adding smartcard key: {key:?}");

        Ok(())
    }

    async fn add_smartcard_key_constrained(
        &mut self,
        key: AddSmartcardKeyConstrained,
    ) -> Result<(), AgentError> {
        info!("Adding smartcard key with constraints: {key:?}");
        Ok(())
    }
    async fn lock(&mut self, pwd: String) -> Result<(), AgentError> {
        info!("Locked with password: {pwd:?}");
        Ok(())
    }

    async fn unlock(&mut self, pwd: String) -> Result<(), AgentError> {
        info!("Unlocked with password: {pwd:?}");
        Ok(())
    }

    async fn extension(&mut self, extension: Extension) -> Result<Option<Extension>, AgentError> {
        info!("Extension: {extension:?}");

        match extension.name.as_str() {
            "query" => {
                let response = Extension::new_message(QueryResponse {
                    extensions: vec!["query".into(), "session-bind@openssh.com".into()],
                })?;
                Ok(Some(response))
            }
            "session-bind@openssh.com" => match extension.parse_message::<SessionBind>()? {
                Some(bind) => {
                    bind.verify_signature()
                        .map_err(|_| AgentError::ExtensionFailure)?;

                    info!("Session binding: {bind:?}");
                    Ok(None)
                }
                None => Err(AgentError::Failure),
            },
            _ => Err(AgentError::Failure),
        }
    }
}

/// Binds the agent endpoint and serves requests from an in-memory
/// [`KeyStorage`] via `listen`.
///
/// The endpoint is a Unix socket file on non-Windows platforms and a named
/// pipe on Windows.
#[tokio::main]
async fn main() -> Result<(), AgentError> {
    env_logger::init();

    #[cfg(not(windows))]
    let socket = "ssh-agent.sock";
    #[cfg(windows)]
    let socket = r"\\.\pipe\agent";

    // A leftover socket file from a previous run would make `bind` fail.
    // Removal errors (e.g. the file does not exist) are deliberately ignored.
    // A named pipe is not a file, so there is nothing to remove on Windows.
    #[cfg(not(windows))]
    let _ = std::fs::remove_file(socket);

    let listener = Listener::bind(socket)?;

    // This is only used for integration tests on Windows. The marker is
    // created after binding so the tests cannot see it before the pipe exists.
    #[cfg(windows)]
    std::fs::File::create("server-started")?;

    listen(listener, KeyStorage::default()).await?;
    Ok(())
}