use crate::auth::{
certificates::VerifiableCertificate,
session_manager::SessionManager,
transports::Transport,
types::{
AuthMessage, MessageType, PeerSession, RequestedCertificateSet, AUTH_PROTOCOL_ID,
AUTH_VERSION,
},
utils::{create_nonce, get_verifiable_certificates, validate_certificates, verify_nonce},
};
use crate::primitives::to_base64;
use crate::primitives::PublicKey;
use crate::wallet::{
Counterparty, CreateSignatureArgs, GetPublicKeyArgs, Protocol, SecurityLevel,
VerifySignatureArgs, WalletInterface,
};
use crate::{Error, Result};
use rand::RngCore;
use std::collections::HashMap;
use std::future::Future;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{oneshot, RwLock};
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
async fn wait_with_timeout<F>(future: F, timeout_ms: u64) -> Result<F::Output>
where
F: Future,
{
tokio::time::timeout(Duration::from_millis(timeout_ms), future)
.await
.map_err(|_| Error::AuthError("Handshake timeout".into()))
}
/// Awaits `future`, abandoning it once `timeout_ms` milliseconds have elapsed
/// (wasm build).
///
/// The tokio timer is not used here. Instead, the future is raced against a
/// `futures_timer::Delay`, and whichever finishes first decides the outcome.
///
/// # Errors
/// Returns `Error::AuthError("Handshake timeout")` if the delay fires first.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
async fn wait_with_timeout<F>(future: F, timeout_ms: u64) -> Result<F::Output>
where
    F: Future,
{
    use futures::future::{select, Either};
    use futures::pin_mut;

    let timer = futures_timer::Delay::new(Duration::from_millis(timeout_ms));
    // `select` needs `Unpin` operands, so pin both on the stack.
    pin_mut!(future, timer);
    match select(future, timer).await {
        Either::Left((output, _unfinished_timer)) => Ok(output),
        Either::Right((_, _unfinished_future)) => {
            Err(Error::AuthError("Handshake timeout".into()))
        }
    }
}
pub type GeneralMessageCallback = Box<
dyn Fn(
PublicKey,
Vec<u8>,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
+ Send
+ Sync,
>;
pub type CertificateCallback = Box<
dyn Fn(
PublicKey,
Vec<VerifiableCertificate>,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
+ Send
+ Sync,
>;
pub type CertificateRequestCallback = Box<
dyn Fn(
PublicKey,
RequestedCertificateSet,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
+ Send
+ Sync,
>;
/// Construction options for [`Peer`]; see [`Peer::new`] for how unset fields
/// are defaulted.
pub struct PeerOptions<W, T>
where
    W: WalletInterface,
    T: Transport,
{
    /// Wallet used for the identity key, nonces, and signing/verification.
    pub wallet: W,
    /// Transport that carries auth messages to and from the remote peer.
    pub transport: T,
    /// Certificates requested in handshakes this peer initiates. Incoming
    /// certificate responses are validated against this set.
    pub certificates_to_request: Option<RequestedCertificateSet>,
    /// Existing session store to reuse. `None` means `SessionManager::default()`.
    pub session_manager: Option<SessionManager>,
    /// Copied onto the peer. NOTE(review): nothing in the peer reads it yet.
    pub auto_persist_last_session: bool,
    /// Originator string passed to wallet operations. `None` means `"unknown"`.
    pub originator: Option<String>,
}
/// Peer that authenticates a remote party through a nonce handshake and then
/// exchanges signed [`AuthMessage`]s with it over a [`Transport`].
pub struct Peer<W, T>
where
    W: WalletInterface,
    T: Transport,
{
    /// Wallet used for the identity key, nonces, and signing/verification.
    wallet: W,
    /// Transport used to send messages; incoming ones arrive via an installed callback.
    transport: Arc<T>,
    /// Session store, shared with the installed transport callback.
    session_manager: Arc<RwLock<SessionManager>>,
    /// Certificates requested in handshakes this peer initiates.
    certificates_to_request: Option<RequestedCertificateSet>,
    /// `General`-message listeners, keyed by listener id.
    general_message_callbacks: Arc<RwLock<HashMap<u32, GeneralMessageCallback>>>,
    /// Received-certificate listeners, keyed by listener id.
    certificate_callbacks: Arc<RwLock<HashMap<u32, CertificateCallback>>>,
    /// Certificate-request listeners, keyed by listener id.
    certificate_request_callbacks: Arc<RwLock<HashMap<u32, CertificateRequestCallback>>>,
    /// Id source shared by all three listener maps.
    next_callback_id: AtomicU32,
    /// Handshakes awaiting an `InitialResponse`, keyed by the nonce this peer sent.
    pending_handshakes: Arc<RwLock<HashMap<String, oneshot::Sender<Result<PeerSession>>>>>,
    /// Mirrors `PeerOptions::auto_persist_last_session`. Nothing in this file
    /// reads it yet, which is why the dead-code lint is silenced.
    #[allow(dead_code)]
    auto_persist_last_session: bool,
    /// Originator string passed to wallet operations.
    originator: String,
    /// Cache for the wallet's identity key, filled on first use.
    identity_key: Arc<RwLock<Option<PublicKey>>>,
}
impl<W: WalletInterface + 'static, T: Transport + 'static> Peer<W, T> {
pub fn new(options: PeerOptions<W, T>) -> Self {
let originator = options.originator.unwrap_or_else(|| "unknown".to_string());
Self {
wallet: options.wallet,
transport: Arc::new(options.transport),
session_manager: Arc::new(RwLock::new(options.session_manager.unwrap_or_default())),
certificates_to_request: options.certificates_to_request,
general_message_callbacks: Arc::new(RwLock::new(HashMap::new())),
certificate_callbacks: Arc::new(RwLock::new(HashMap::new())),
certificate_request_callbacks: Arc::new(RwLock::new(HashMap::new())),
next_callback_id: AtomicU32::new(1),
pending_handshakes: Arc::new(RwLock::new(HashMap::new())),
auto_persist_last_session: options.auto_persist_last_session,
originator,
identity_key: Arc::new(RwLock::new(None)),
}
}
pub fn start(&self) {
let session_manager = self.session_manager.clone();
let pending_handshakes = self.pending_handshakes.clone();
let general_message_callbacks = self.general_message_callbacks.clone();
let certificate_callbacks = self.certificate_callbacks.clone();
let certificate_request_callbacks = self.certificate_request_callbacks.clone();
self.transport.set_callback(Box::new(move |message| {
let session_manager = session_manager.clone();
let pending_handshakes = pending_handshakes.clone();
let general_message_callbacks = general_message_callbacks.clone();
let certificate_callbacks = certificate_callbacks.clone();
let certificate_request_callbacks = certificate_request_callbacks.clone();
Box::pin(async move {
match message.message_type {
MessageType::InitialResponse => {
let result: Result<(String, String)> = (|| {
let client_nonce = message.your_nonce.as_ref().ok_or_else(|| {
Error::AuthError("InitialResponse missing your_nonce".into())
})?;
let server_nonce = message
.nonce
.as_ref()
.or(message.initial_nonce.as_ref())
.ok_or_else(|| {
Error::AuthError(
"InitialResponse missing nonce and initial_nonce".into(),
)
})?;
Ok((client_nonce.clone(), server_nonce.clone()))
})();
match result {
Ok((client_nonce, server_nonce)) => {
let session_result = {
let mut mgr = session_manager.write().await;
if let Some(existing) = mgr.get_session(&client_nonce).cloned()
{
let mut updated = existing;
updated.peer_identity_key =
Some(message.identity_key.clone());
updated.peer_nonce = Some(server_nonce.clone());
updated.is_authenticated = true;
updated.touch();
let session_clone = updated.clone();
mgr.update_session(updated);
Ok(session_clone)
} else {
let mut session =
PeerSession::with_nonce(client_nonce.clone());
session.peer_identity_key =
Some(message.identity_key.clone());
session.peer_nonce = Some(server_nonce.clone());
session.is_authenticated = true;
session.touch();
mgr.add_session(session.clone()).map(|_| session)
}
};
let mut pending = pending_handshakes.write().await;
if let Some(tx) = pending.remove(&client_nonce) {
let _ = tx.send(session_result);
}
}
Err(e) => {
if let Some(ref client_nonce) = message.your_nonce {
let mut pending = pending_handshakes.write().await;
if let Some(tx) = pending.remove(client_nonce) {
let _ = tx.send(Err(e.clone()));
}
}
return Err(e);
}
}
}
MessageType::General => {
let payload = message.payload.clone().unwrap_or_default();
let sender = message.identity_key.clone();
let cbs = general_message_callbacks.read().await;
for (_, callback) in cbs.iter() {
callback(sender.clone(), payload.clone()).await?;
}
}
MessageType::CertificateRequest => {
let sender = message.identity_key.clone();
let requested = message.requested_certificates.clone().unwrap_or_default();
let cbs = certificate_request_callbacks.read().await;
for (_, callback) in cbs.iter() {
callback(sender.clone(), requested.clone()).await?;
}
}
MessageType::CertificateResponse => {
let sender = message.identity_key.clone();
let certs = message.certificates.clone().unwrap_or_default();
let cbs = certificate_callbacks.read().await;
for (_, callback) in cbs.iter() {
callback(sender.clone(), certs.clone()).await?;
}
}
MessageType::InitialRequest => {
}
}
Ok(())
})
}));
}
pub fn start_server(self: &Arc<Self>) {
let peer = Arc::clone(self);
self.transport.set_callback(Box::new(move |message| {
let peer = Arc::clone(&peer);
Box::pin(async move { peer.handle_incoming_message(message).await })
}));
}
pub async fn to_peer(
&self,
message: &[u8],
identity_key: Option<&str>,
max_wait_time: Option<u64>,
) -> Result<()> {
let session = self
.get_authenticated_session(identity_key, max_wait_time)
.await?;
let my_identity = self.get_identity_key().await?;
let mut msg = AuthMessage::new(MessageType::General, my_identity);
let mut random_bytes = [0u8; 32];
rand::thread_rng().fill_bytes(&mut random_bytes);
let msg_nonce = to_base64(&random_bytes);
msg.nonce = Some(msg_nonce);
msg.your_nonce = session.peer_nonce.clone();
msg.payload = Some(message.to_vec());
self.sign_message(&mut msg, &session).await?;
self.transport.send(&msg).await?;
Ok(())
}
pub async fn get_authenticated_session(
&self,
identity_key: Option<&str>,
max_wait_time: Option<u64>,
) -> Result<PeerSession> {
if let Some(key) = identity_key {
let mgr = self.session_manager.read().await;
if let Some(session) = mgr.get_session(key) {
if session.is_authenticated {
return Ok(session.clone());
}
}
}
self.initiate_handshake(identity_key, max_wait_time).await
}
pub async fn request_certificates(
&self,
requested: RequestedCertificateSet,
identity_key: Option<&str>,
max_wait_time: Option<u64>,
) -> Result<()> {
let session = self
.get_authenticated_session(identity_key, max_wait_time)
.await?;
let my_identity = self.get_identity_key().await?;
let mut msg = AuthMessage::new(MessageType::CertificateRequest, my_identity);
let mut random_bytes = [0u8; 32];
rand::thread_rng().fill_bytes(&mut random_bytes);
msg.nonce = Some(to_base64(&random_bytes));
msg.your_nonce = session.peer_nonce.clone();
msg.requested_certificates = Some(requested);
self.sign_message(&mut msg, &session).await?;
self.transport.send(&msg).await
}
pub async fn send_certificate_response(
&self,
verifier_identity_key: &str,
certificates: Vec<VerifiableCertificate>,
) -> Result<()> {
let mgr = self.session_manager.read().await;
let session = mgr
.get_session(verifier_identity_key)
.ok_or_else(|| Error::AuthError("No session with peer".into()))?
.clone();
drop(mgr);
let my_identity = self.get_identity_key().await?;
let mut msg = AuthMessage::new(MessageType::CertificateResponse, my_identity);
let mut random_bytes = [0u8; 32];
rand::thread_rng().fill_bytes(&mut random_bytes);
msg.nonce = Some(to_base64(&random_bytes));
msg.your_nonce = session.peer_nonce.clone();
msg.certificates = Some(certificates);
self.sign_message(&mut msg, &session).await?;
self.transport.send(&msg).await
}
pub async fn listen_for_general_messages<F>(&self, callback: F) -> u32
where
F: Fn(
PublicKey,
Vec<u8>,
)
-> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
+ Send
+ Sync
+ 'static,
{
let id = self.next_callback_id.fetch_add(1, Ordering::SeqCst);
let mut cbs = self.general_message_callbacks.write().await;
cbs.insert(id, Box::new(callback));
id
}
pub async fn stop_listening_for_general_messages(&self, callback_id: u32) {
let mut cbs = self.general_message_callbacks.write().await;
cbs.remove(&callback_id);
}
pub async fn listen_for_certificates_received<F>(&self, callback: F) -> u32
where
F: Fn(
PublicKey,
Vec<VerifiableCertificate>,
)
-> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
+ Send
+ Sync
+ 'static,
{
let id = self.next_callback_id.fetch_add(1, Ordering::SeqCst);
let mut cbs = self.certificate_callbacks.write().await;
cbs.insert(id, Box::new(callback));
id
}
pub async fn stop_listening_for_certificates_received(&self, callback_id: u32) {
let mut cbs = self.certificate_callbacks.write().await;
cbs.remove(&callback_id);
}
pub async fn listen_for_certificates_requested<F>(&self, callback: F) -> u32
where
F: Fn(
PublicKey,
RequestedCertificateSet,
)
-> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
+ Send
+ Sync
+ 'static,
{
let id = self.next_callback_id.fetch_add(1, Ordering::SeqCst);
let mut cbs = self.certificate_request_callbacks.write().await;
cbs.insert(id, Box::new(callback));
id
}
pub async fn stop_listening_for_certificates_requested(&self, callback_id: u32) {
let mut cbs = self.certificate_request_callbacks.write().await;
cbs.remove(&callback_id);
}
pub fn session_manager(&self) -> &Arc<RwLock<SessionManager>> {
&self.session_manager
}
pub async fn get_identity_key(&self) -> Result<PublicKey> {
{
let cached = self.identity_key.read().await;
if let Some(ref key) = *cached {
return Ok(key.clone());
}
}
let result = self
.wallet
.get_public_key(
GetPublicKeyArgs {
identity_key: true,
protocol_id: None,
key_id: None,
counterparty: None,
for_self: None,
},
&self.originator,
)
.await?;
let key = PublicKey::from_hex(&result.public_key)?;
let mut cached = self.identity_key.write().await;
*cached = Some(key.clone());
Ok(key)
}
pub async fn handle_incoming_message(&self, message: AuthMessage) -> Result<()> {
if message.version != AUTH_VERSION {
return Err(Error::AuthError(format!(
"Invalid auth version: expected {}, got {}",
AUTH_VERSION, message.version
)));
}
match message.message_type {
MessageType::InitialRequest => self.process_initial_request(message).await,
MessageType::InitialResponse => self.process_initial_response(message).await,
MessageType::CertificateRequest => self.process_certificate_request(message).await,
MessageType::CertificateResponse => self.process_certificate_response(message).await,
MessageType::General => self.process_general_message(message).await,
}
}
async fn sign_message(&self, message: &mut AuthMessage, session: &PeerSession) -> Result<()> {
let data = message.signing_data();
let key_id = message.get_key_id(session.peer_nonce.as_deref());
let protocol = Protocol::new(SecurityLevel::Counterparty, AUTH_PROTOCOL_ID);
let counterparty = session
.peer_identity_key
.as_ref()
.map(|k| Counterparty::Other(k.clone()));
let result = self
.wallet
.create_signature(
CreateSignatureArgs {
data: Some(data.clone()),
hash_to_directly_sign: None,
protocol_id: protocol,
key_id: key_id.clone(),
counterparty,
},
&self.originator,
)
.await?;
message.signature = Some(result.signature);
Ok(())
}
async fn verify_message_signature(
&self,
message: &AuthMessage,
session: &PeerSession,
) -> Result<bool> {
let data = message.signing_data();
let key_id = message.get_key_id(session.session_nonce.as_deref());
let signature = message
.signature
.as_ref()
.ok_or_else(|| Error::AuthError("Message not signed".into()))?;
let protocol = Protocol::new(SecurityLevel::Counterparty, AUTH_PROTOCOL_ID);
let result = self
.wallet
.verify_signature(
VerifySignatureArgs {
data: Some(data),
hash_to_directly_verify: None,
signature: signature.clone(),
protocol_id: protocol,
key_id,
counterparty: Some(Counterparty::Other(message.identity_key.clone())),
for_self: None,
},
&self.originator,
)
.await?;
Ok(result.valid)
}
async fn initiate_handshake(
&self,
_identity_key: Option<&str>,
max_wait_time: Option<u64>,
) -> Result<PeerSession> {
let my_identity = self.get_identity_key().await?;
let session_nonce = create_nonce(&self.wallet, None, &self.originator).await?;
let session = PeerSession::with_nonce(session_nonce.clone());
{
let mut mgr = self.session_manager.write().await;
mgr.add_session(session.clone())?;
}
let mut msg = AuthMessage::new(MessageType::InitialRequest, my_identity);
msg.initial_nonce = Some(session_nonce.clone());
if let Some(ref req) = self.certificates_to_request {
msg.requested_certificates = Some(req.clone());
}
let (tx, rx) = oneshot::channel();
{
let mut pending = self.pending_handshakes.write().await;
pending.insert(session_nonce.clone(), tx);
}
self.transport.send(&msg).await?;
let timeout = max_wait_time.unwrap_or(30000);
let result = wait_with_timeout(rx, timeout)
.await?
.map_err(|_| Error::AuthError("Handshake cancelled".into()))??;
Ok(result)
}
async fn process_initial_request(&self, message: AuthMessage) -> Result<()> {
let my_identity = self.get_identity_key().await?;
let session_nonce = create_nonce(&self.wallet, None, &self.originator).await?;
let mut session = PeerSession::with_nonce(session_nonce.clone());
session.peer_identity_key = Some(message.identity_key.clone());
session.peer_nonce = message.initial_nonce.clone();
session.is_authenticated = true;
session.touch();
if let Some(ref req) = message.requested_certificates {
if !req.is_empty() {
session.certificates_required = true;
}
}
{
let mut mgr = self.session_manager.write().await;
mgr.add_session(session.clone())?;
}
let mut response = AuthMessage::new(MessageType::InitialResponse, my_identity);
response.nonce = Some(session_nonce.clone());
response.initial_nonce = Some(session_nonce);
response.your_nonce = message.initial_nonce.clone();
self.sign_message(&mut response, &session).await?;
self.transport.send(&response).await?;
if let Some(ref req) = message.requested_certificates {
if !req.is_empty() {
let certs = get_verifiable_certificates(
&self.wallet,
req,
&message.identity_key,
&self.originator,
)
.await
.unwrap_or_default();
if !certs.is_empty() {
self.send_certificate_response(&message.identity_key.to_hex(), certs)
.await?;
}
}
}
Ok(())
}
async fn process_initial_response(&self, message: AuthMessage) -> Result<()> {
let our_nonce = message
.your_nonce
.as_ref()
.ok_or_else(|| Error::AuthError("InitialResponse missing your_nonce".into()))?
.clone();
let result = self
.process_initial_response_inner(&message, &our_nonce)
.await;
if let Err(ref e) = result {
let mut pending = self.pending_handshakes.write().await;
if let Some(tx) = pending.remove(&our_nonce) {
let _ = tx.send(Err(e.clone()));
}
}
result
}
async fn process_initial_response_inner(
&self,
message: &AuthMessage,
our_nonce: &str,
) -> Result<()> {
let temp_session = PeerSession {
session_nonce: Some(our_nonce.to_string()),
peer_identity_key: Some(message.identity_key.clone()),
peer_nonce: message.initial_nonce.clone(),
..Default::default()
};
if !self
.verify_message_signature(message, &temp_session)
.await?
{
return Err(Error::AuthError("InitialResponse signature invalid".into()));
}
let nonce_to_verify = message
.nonce
.as_deref()
.or(message.initial_nonce.as_deref())
.unwrap_or("");
if !verify_nonce(
nonce_to_verify,
&self.wallet,
Some(&message.identity_key),
&self.originator,
)
.await
.unwrap_or(false)
{
}
{
let mut mgr = self.session_manager.write().await;
if let Some(existing) = mgr.get_session(our_nonce).cloned() {
let mut updated = existing;
updated.peer_identity_key = Some(message.identity_key.clone());
updated.peer_nonce = message.initial_nonce.clone();
updated.is_authenticated = true;
updated.touch();
if let Some(ref req) = self.certificates_to_request {
if !req.is_empty() {
updated.certificates_required = true;
}
}
let session_clone = updated.clone();
mgr.update_session(updated);
let mut pending = self.pending_handshakes.write().await;
if let Some(tx) = pending.remove(our_nonce) {
let _ = tx.send(Ok(session_clone));
}
}
}
Ok(())
}
async fn process_certificate_request(&self, message: AuthMessage) -> Result<()> {
let sender_hex = message.identity_key.to_hex();
let mgr = self.session_manager.read().await;
let session = mgr
.get_session(&sender_hex)
.ok_or_else(|| Error::AuthError("No session with sender".into()))?;
if !self.verify_message_signature(&message, session).await? {
return Err(Error::AuthError(
"CertificateRequest signature invalid".into(),
));
}
drop(mgr);
if let Some(ref requested) = message.requested_certificates {
let cbs = self.certificate_request_callbacks.read().await;
for (_, callback) in cbs.iter() {
let _ = callback(message.identity_key.clone(), requested.clone()).await;
}
}
Ok(())
}
async fn process_certificate_response(&self, message: AuthMessage) -> Result<()> {
let sender_hex = message.identity_key.to_hex();
let mgr = self.session_manager.read().await;
let session = mgr
.get_session(&sender_hex)
.ok_or_else(|| Error::AuthError("No session with sender".into()))?
.clone();
drop(mgr);
if !self.verify_message_signature(&message, &session).await? {
return Err(Error::AuthError(
"CertificateResponse signature invalid".into(),
));
}
validate_certificates(
&self.wallet,
&message,
self.certificates_to_request.as_ref(),
&self.originator,
)
.await?;
{
let mut mgr = self.session_manager.write().await;
if let Some(session) = mgr.get_session_mut(&sender_hex) {
session.certificates_validated = true;
session.touch();
}
}
if let Some(ref certs) = message.certificates {
let cbs = self.certificate_callbacks.read().await;
for (_, callback) in cbs.iter() {
let _ = callback(message.identity_key.clone(), certs.clone()).await;
}
}
Ok(())
}
async fn process_general_message(&self, message: AuthMessage) -> Result<()> {
let sender_hex = message.identity_key.to_hex();
let mgr = self.session_manager.read().await;
let session = mgr
.get_session(&sender_hex)
.ok_or_else(|| Error::AuthError("No session with sender".into()))?
.clone();
drop(mgr);
if !session.is_authenticated {
return Err(Error::AuthError("Session not authenticated".into()));
}
if !self.verify_message_signature(&message, &session).await? {
return Err(Error::AuthError("General message signature invalid".into()));
}
{
let mut mgr = self.session_manager.write().await;
if let Some(s) = mgr.get_session_mut(&sender_hex) {
s.peer_nonce = message.nonce.clone();
s.touch();
}
}
if let Some(ref payload) = message.payload {
let cbs = self.general_message_callbacks.read().await;
for (_, callback) in cbs.iter() {
let _ = callback(message.identity_key.clone(), payload.clone()).await;
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::transports::MockTransport;
    use crate::primitives::PrivateKey;
    use crate::wallet::ProtoWallet;

    /// Builds a peer backed by a random-key `ProtoWallet` and a `MockTransport`,
    /// with every optional setting left at its default.
    fn make_peer() -> Peer<ProtoWallet, MockTransport> {
        let options = PeerOptions {
            wallet: ProtoWallet::new(Some(PrivateKey::random())),
            transport: MockTransport::new(),
            certificates_to_request: None,
            session_manager: None,
            auto_persist_last_session: false,
            originator: Some("test".into()),
        };
        Peer::new(options)
    }

    #[tokio::test]
    async fn test_peer_creation() {
        let peer = make_peer();
        let compressed = peer
            .get_identity_key()
            .await
            .expect("wallet should provide an identity key")
            .to_compressed();
        // Compressed public-key encoding is 33 bytes long.
        assert_eq!(compressed.len(), 33);
    }

    #[tokio::test]
    async fn test_listener_registration() {
        let peer = make_peer();
        let callback_id = peer
            .listen_for_general_messages(|_, _| Box::pin(async { Ok(()) }))
            .await;
        // Listener ids start at 1, so a freshly registered one is never 0.
        assert!(callback_id > 0);
        peer.stop_listening_for_general_messages(callback_id).await;
    }

    #[tokio::test]
    async fn test_session_manager_access() {
        let peer = make_peer();
        let sessions = peer.session_manager().read().await;
        assert!(sessions.is_empty());
    }
}