use std::fmt;
use std::io;
use async_trait::async_trait;
use futures::{SinkExt, TryStreamExt};
pub use service_binding;
use ssh_key::Signature;
use tokio::io::{AsyncRead, AsyncWrite};
#[cfg(windows)]
use tokio::net::windows::named_pipe::{NamedPipeServer, ServerOptions};
use tokio::net::{TcpListener, TcpStream};
#[cfg(unix)]
use tokio::net::{UnixListener, UnixStream};
use tokio_util::codec::Framed;
use super::error::AgentError;
use super::proto::message::{Request, Response};
use crate::codec::Codec;
use crate::proto::AddIdentity;
use crate::proto::AddIdentityConstrained;
use crate::proto::AddSmartcardKeyConstrained;
use crate::proto::Extension;
use crate::proto::Identity;
use crate::proto::ProtoError;
use crate::proto::RemoveIdentity;
use crate::proto::SignRequest;
use crate::proto::SmartcardKey;
/// A bound server socket that hands out accepted client connections.
///
/// Implemented below for tokio's TCP listener, for tokio's Unix listener (Unix only) and
/// for [`NamedPipeListener`] (Windows only).
#[async_trait]
pub trait ListeningSocket {
    /// Connection type produced by [`ListeningSocket::accept`].
    type Stream: AsyncRead + AsyncWrite + Unpin + Send + fmt::Debug + 'static;

    /// Waits for the next client and returns its connected stream.
    async fn accept(&mut self) -> io::Result<Self::Stream>;
}
/// Accepts Unix-domain-socket clients; the peer address is discarded.
#[cfg(unix)]
#[async_trait]
impl ListeningSocket for UnixListener {
    type Stream = UnixStream;

    async fn accept(&mut self) -> io::Result<Self::Stream> {
        // Path syntax selects tokio's inherent `accept`, not this trait method.
        let (stream, _peer) = UnixListener::accept(self).await?;
        Ok(stream)
    }
}
/// Accepts TCP clients; the peer address is discarded.
#[async_trait]
impl ListeningSocket for TcpListener {
    type Stream = TcpStream;

    async fn accept(&mut self) -> io::Result<Self::Stream> {
        // Path syntax selects tokio's inherent `accept`, not this trait method.
        let (stream, _peer) = TcpListener::accept(self).await?;
        Ok(stream)
    }
}
/// Listener for a Windows named pipe.
///
/// Field `0` is the server instance that the next `accept` waits on. Field `1` is the
/// pipe name, kept so a fresh server instance can be created each time a connected
/// instance is handed out.
#[cfg(windows)]
#[derive(Debug)]
pub struct NamedPipeListener(NamedPipeServer, std::ffi::OsString);
#[cfg(windows)]
impl NamedPipeListener {
    /// Creates the first server instance of the named pipe `pipe`.
    ///
    /// `first_pipe_instance(true)` is requested so that creation fails if an instance of
    /// a pipe with this name already exists, e.g. one owned by another process.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported when creating the pipe instance.
    pub fn bind(pipe: impl Into<std::ffi::OsString>) -> std::io::Result<Self> {
        let name: std::ffi::OsString = pipe.into();
        let first_instance = ServerOptions::new()
            .first_pipe_instance(true)
            .create(&name)?;
        Ok(Self(first_instance, name))
    }
}
/// Accepts named-pipe clients.
///
/// The method waits for a client on the current server instance, then creates the next
/// instance under the same name. It stores that new instance and returns the connected one.
#[cfg(windows)]
#[async_trait]
impl ListeningSocket for NamedPipeListener {
    type Stream = NamedPipeServer;

    async fn accept(&mut self) -> io::Result<Self::Stream> {
        self.0.connect().await?;
        // If creating the replacement fails, the connected instance stays in `self.0`.
        let next_instance = ServerOptions::new().create(&self.1)?;
        let connected = std::mem::replace(&mut self.0, next_instance);
        Ok(connected)
    }
}
/// Per-connection handler for SSH agent requests.
///
/// Every request method has a default body that rejects the request with
/// [`ProtoError::UnsupportedCommand`], tagged with the request's message number. These
/// numbers match the `SSH_AGENTC_*` codes of the SSH agent protocol. Implementors
/// override only the operations they support. [`Session::handle`] routes each decoded
/// [`Request`] to the matching method and builds the [`Response`].
#[async_trait]
pub trait Session: 'static + Sync + Send + Unpin {
    /// Handles [`Request::RequestIdentities`] (message 11).
    async fn request_identities(&mut self) -> Result<Vec<Identity>, AgentError> {
        Err(ProtoError::UnsupportedCommand { command: 11 }.into())
    }

    /// Handles [`Request::SignRequest`] (message 13).
    async fn sign(&mut self, _request: SignRequest) -> Result<Signature, AgentError> {
        Err(ProtoError::UnsupportedCommand { command: 13 }.into())
    }

    /// Handles [`Request::AddIdentity`] (message 17).
    async fn add_identity(&mut self, _identity: AddIdentity) -> Result<(), AgentError> {
        Err(ProtoError::UnsupportedCommand { command: 17 }.into())
    }

    /// Handles [`Request::AddIdConstrained`] (message 25).
    async fn add_identity_constrained(
        &mut self,
        _identity: AddIdentityConstrained,
    ) -> Result<(), AgentError> {
        Err(ProtoError::UnsupportedCommand { command: 25 }.into())
    }

    /// Handles [`Request::RemoveIdentity`] (message 18).
    async fn remove_identity(&mut self, _identity: RemoveIdentity) -> Result<(), AgentError> {
        Err(ProtoError::UnsupportedCommand { command: 18 }.into())
    }

    /// Handles [`Request::RemoveAllIdentities`] (message 19).
    async fn remove_all_identities(&mut self) -> Result<(), AgentError> {
        Err(ProtoError::UnsupportedCommand { command: 19 }.into())
    }

    /// Handles [`Request::AddSmartcardKey`] (message 20).
    async fn add_smartcard_key(&mut self, _key: SmartcardKey) -> Result<(), AgentError> {
        Err(ProtoError::UnsupportedCommand { command: 20 }.into())
    }

    /// Handles [`Request::AddSmartcardKeyConstrained`] (message 26).
    async fn add_smartcard_key_constrained(
        &mut self,
        _key: AddSmartcardKeyConstrained,
    ) -> Result<(), AgentError> {
        Err(ProtoError::UnsupportedCommand { command: 26 }.into())
    }

    /// Handles [`Request::RemoveSmartcardKey`] (message 21).
    async fn remove_smartcard_key(&mut self, _key: SmartcardKey) -> Result<(), AgentError> {
        Err(ProtoError::UnsupportedCommand { command: 21 }.into())
    }

    /// Handles [`Request::Lock`] (message 22).
    async fn lock(&mut self, _key: String) -> Result<(), AgentError> {
        Err(ProtoError::UnsupportedCommand { command: 22 }.into())
    }

    /// Handles [`Request::Unlock`] (message 23).
    async fn unlock(&mut self, _key: String) -> Result<(), AgentError> {
        Err(ProtoError::UnsupportedCommand { command: 23 }.into())
    }

    /// Handles [`Request::Extension`] (message 27).
    ///
    /// `Ok(Some(ext))` is answered with [`Response::ExtensionResponse`]; `Ok(None)` is
    /// answered with plain [`Response::Success`].
    async fn extension(&mut self, _extension: Extension) -> Result<Option<Extension>, AgentError> {
        Err(ProtoError::UnsupportedCommand { command: 27 }.into())
    }

    /// Dispatches `message` to the matching handler method and wraps its outcome.
    ///
    /// Requests whose handler returns `()` are answered with [`Response::Success`]. Any
    /// error from a handler is returned to the caller unchanged.
    async fn handle(&mut self, message: Request) -> Result<Response, AgentError> {
        match message {
            Request::RequestIdentities => Ok(Response::IdentitiesAnswer(
                self.request_identities().await?,
            )),
            Request::SignRequest(request) => Ok(Response::SignResponse(self.sign(request).await?)),
            Request::Extension(extension) => Ok(self
                .extension(extension)
                .await?
                .map_or(Response::Success, Response::ExtensionResponse)),
            Request::AddIdentity(identity) => {
                self.add_identity(identity).await.map(|()| Response::Success)
            }
            Request::RemoveIdentity(identity) => {
                self.remove_identity(identity).await.map(|()| Response::Success)
            }
            Request::RemoveAllIdentities => {
                self.remove_all_identities().await.map(|()| Response::Success)
            }
            Request::AddSmartcardKey(key) => {
                self.add_smartcard_key(key).await.map(|()| Response::Success)
            }
            Request::RemoveSmartcardKey(key) => {
                self.remove_smartcard_key(key).await.map(|()| Response::Success)
            }
            Request::Lock(key) => self.lock(key).await.map(|()| Response::Success),
            Request::Unlock(key) => self.unlock(key).await.map(|()| Response::Success),
            Request::AddIdConstrained(identity) => self
                .add_identity_constrained(identity)
                .await
                .map(|()| Response::Success),
            Request::AddSmartcardKeyConstrained(key) => self
                .add_smartcard_key_constrained(key)
                .await
                .map(|()| Response::Success),
        }
    }
}
async fn handle_socket<S>(
mut session: impl Session,
mut adapter: Framed<S::Stream, Codec<Request, Response>>,
) -> Result<(), AgentError>
where
S: ListeningSocket + fmt::Debug + Send,
{
loop {
if let Some(incoming_message) = adapter.try_next().await? {
log::debug!("Request: {incoming_message:?}");
let response = match session.handle(incoming_message).await {
Ok(message) => message,
Err(AgentError::ExtensionFailure) => {
log::error!("Extension failure handling message");
Response::ExtensionFailure
}
Err(e) => {
log::error!("Error handling message: {:?}", e);
Response::Failure
}
};
log::debug!("Response: {response:?}");
adapter.send(response).await?;
} else {
return Ok(());
}
}
}
/// Factory for per-connection [`Session`]s, parameterised by the listener type `S`.
///
/// [`listen`] calls [`Agent::new_session`] once for every accepted connection.
pub trait Agent<S: ListeningSocket + fmt::Debug + Send>: 'static + Send + Sync {
    /// Creates the session that will serve the freshly accepted connection `socket`.
    fn new_session(&mut self, socket: &S::Stream) -> impl Session;
}
/// Accepts connections on `socket` indefinitely, serving each one on its own tokio task.
///
/// For every accepted stream, `agent.new_session` is called on the current task to build
/// the [`Session`] for that connection. A spawned task then serves it. Protocol errors on
/// a connection are logged and do not stop the listener.
///
/// # Errors
///
/// Returns [`AgentError::IO`] as soon as accepting a connection fails. That is the only
/// way this function returns.
pub async fn listen<S>(mut socket: S, mut agent: impl Agent<S>) -> Result<(), AgentError>
where
    S: ListeningSocket + fmt::Debug + Send,
{
    log::info!("Listening; socket = {:?}", socket);
    loop {
        let stream = socket.accept().await.map_err(|e| {
            log::error!("Failed to accept socket: {:?}", e);
            AgentError::IO(e)
        })?;
        let session = agent.new_session(&stream);
        tokio::spawn(async move {
            let adapter = Framed::new(stream, Codec::<Request, Response>::default());
            if let Err(e) = handle_socket::<S>(session, adapter).await {
                log::error!("Agent protocol error: {:?}", e);
            }
        });
    }
}
/// Any cloneable [`Session`] can act as an [`Agent`] on a Unix listener: each accepted
/// connection is served by its own clone of the value.
#[cfg(unix)]
impl<A: Clone + Send + Sync + Session> Agent<tokio::net::UnixListener> for A {
    fn new_session(&mut self, _stream: &tokio::net::UnixStream) -> impl Session {
        A::clone(self)
    }
}
/// Any cloneable [`Session`] can act as an [`Agent`] on a TCP listener: each accepted
/// connection is served by its own clone of the value.
impl<A: Clone + Send + Sync + Session> Agent<tokio::net::TcpListener> for A {
    fn new_session(&mut self, _stream: &tokio::net::TcpStream) -> impl Session {
        A::clone(self)
    }
}
/// Any cloneable [`Session`] can act as an [`Agent`] on a named-pipe listener: each
/// accepted connection is served by its own clone of the value.
#[cfg(windows)]
impl<A: Clone + Send + Sync + Session> Agent<NamedPipeListener> for A {
    fn new_session(&mut self, _stream: &NamedPipeServer) -> impl Session {
        A::clone(self)
    }
}
#[cfg(unix)]
type PlatformSpecificListener = tokio::net::UnixListener;
#[cfg(windows)]
type PlatformSpecificListener = NamedPipeListener;
/// Runs `agent` on the listener described by `listener` (see [`listen`]).
///
/// Unix domain sockets (Unix only) and TCP listeners arrive as already-bound std
/// listeners. They are switched to non-blocking mode and converted to their tokio
/// equivalents. For a named pipe (Windows only), the pipe is created here with
/// [`NamedPipeListener::bind`].
///
/// # Errors
///
/// Returns an error in these cases:
/// - the std listener cannot be made non-blocking or converted;
/// - the named pipe cannot be created;
/// - the listener kind is not supported on this platform;
/// - [`listen`] itself returns an error.
pub async fn bind<A>(listener: service_binding::Listener, agent: A) -> Result<(), AgentError>
where
    A: Agent<PlatformSpecificListener> + Agent<tokio::net::TcpListener>,
{
    match listener {
        #[cfg(unix)]
        service_binding::Listener::Unix(listener) => {
            // tokio's `from_std` leaves non-blocking mode to the caller. With a blocking fd,
            // a later `accept` can block the runtime worker thread until the next client
            // arrives. Setting the flag is idempotent.
            listener.set_nonblocking(true)?;
            listen(UnixListener::from_std(listener)?, agent).await
        }
        service_binding::Listener::Tcp(listener) => {
            // Same non-blocking requirement as above for `TcpListener::from_std`.
            listener.set_nonblocking(true)?;
            listen(TcpListener::from_std(listener)?, agent).await
        }
        #[cfg(windows)]
        service_binding::Listener::NamedPipe(pipe) => {
            listen(NamedPipeListener::bind(pipe)?, agent).await
        }
        // Some arms above are compiled out per platform. Where the remaining arms already
        // cover every variant, this fallback is unreachable, hence the `allow`.
        #[allow(unreachable_patterns)]
        _ => Err(AgentError::IO(std::io::Error::other(
            "Unsupported type of a listener.",
        ))),
    }
}