use async_trait::async_trait;
use backon::{BlockingRetryable, ExponentialBuilder, Retryable};
use log::*;
use lru::LruCache;
use std::num::NonZeroUsize;
use std::process;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use std::time::SystemTime;
use tokio::sync::{mpsc, oneshot};
use tokio::task::spawn_blocking;
use triggered::{Listener, Trigger};
use lightning_signer::bitcoin::hashes::sha256::Hash as Sha256Hash;
use lightning_signer::bitcoin::hashes::Hash;
use super::adapter::{ChannelReply, ChannelRequest, ClientId};
use vls_common::*;
use vls_protocol::{
msgs, msgs::DeBolt as _, msgs::Message, msgs::SerBolt as _, Error as ProtocolError,
};
use vls_protocol_client::{ClientResult as Result, Error, SignerPort};
use vls_protocol_signer::vls_protocol;
use crate::client::Client;
use crate::*;
/// Upper bound handed to `ExponentialBuilder::with_max_times` by `backoff()`
/// when the signer keeps reporting transient failures.
const BACKOFF_RETRY_TIMES: usize = 8;

/// How long an approved preapproval reply may be replayed for a byte-identical request.
const PREAPPROVE_CACHE_TTL: Duration = Duration::from_secs(60);

/// Capacity (number of entries) of each signer loop's preapproval LRU cache.
const PREAPPROVE_CACHE_SIZE: usize = 6;
/// A signer reply to a preapproval request, kept so an identical request can be
/// answered locally while the entry is younger than `PREAPPROVE_CACHE_TTL`.
struct PreapprovalCacheEntry {
    /// Wall-clock time at which the request that produced this reply was received.
    tstamp: SystemTime,
    /// Raw reply bytes exactly as they were written back to the node.
    reply_bytes: Vec<u8>,
}
/// Client-side handle for sending requests to the gRPC-connected signer.
pub struct GrpcSignerPort {
    /// Queue of outgoing requests; each request carries a oneshot sender for its reply.
    sender: mpsc::Sender<ChannelRequest>,
    /// Shared readiness flag, flipped to `true` by `set_ready`.
    is_ready: Arc<AtomicBool>,
}
/// Retry policy shared by the async and blocking request paths: exponential
/// backoff capped at `BACKOFF_RETRY_TIMES` retries.
fn backoff() -> ExponentialBuilder {
    let policy = ExponentialBuilder::new();
    policy.with_max_times(BACKOFF_RETRY_TIMES)
}
#[async_trait]
impl SignerPort for GrpcSignerPort {
async fn handle_message(&self, message: Vec<u8>) -> Result<Vec<u8>> {
let result = (|| async {
let reply_rx = self.send_request(message.clone()).await?;
let reply = reply_rx.await.map_err(|_| Error::Transport)?;
if reply.is_temporary_failure {
info!("temporary error, retrying");
return Err(Error::TransportTransient);
}
return Ok(reply.reply);
})
.retry(backoff())
.when(|e| *e == Error::TransportTransient)
.await
.map_err(|e| {
error!("signer retry failed: {:?}", e);
e
})?;
Ok(result)
}
fn is_ready(&self) -> bool {
self.is_ready.load(Ordering::Relaxed)
}
}
impl GrpcSignerPort {
    /// Wraps the request channel to the signer. The port reports not-ready
    /// until `set_ready` is called.
    pub fn new(sender: mpsc::Sender<ChannelRequest>) -> Self {
        let is_ready = Arc::new(AtomicBool::new(false));
        Self { sender, is_ready }
    }

    /// Marks the port ready. In this file it is called once the HsmdInit reply
    /// has been processed and the init message cached.
    pub(crate) fn set_ready(&self) {
        info!("setting is_ready true");
        self.is_ready.store(true, Ordering::Relaxed)
    }

    /// Queues `message` without a client id, awaiting channel capacity, and
    /// returns the receiver on which its reply will arrive.
    ///
    /// Fails with `ProtocolError::Eof` if the request cannot be queued.
    async fn send_request(&self, message: Vec<u8>) -> Result<oneshot::Receiver<ChannelReply>> {
        let (reply_rx, request) = Self::prepare_request(message, None);
        match self.sender.send(request).await {
            Ok(()) => Ok(reply_rx),
            Err(_) => Err(ProtocolError::Eof.into()),
        }
    }

    /// Blocking counterpart of `send_request` for use off the async runtime,
    /// tagging the request with `client_id`.
    ///
    /// Fails with `ProtocolError::Eof` if the request cannot be queued.
    fn send_request_blocking(
        &self,
        message: Vec<u8>,
        client_id: Option<ClientId>,
    ) -> Result<oneshot::Receiver<ChannelReply>> {
        let (reply_rx, request) = Self::prepare_request(message, client_id);
        match self.sender.blocking_send(request) {
            Ok(()) => Ok(reply_rx),
            Err(_) => Err(ProtocolError::Eof.into()),
        }
    }

    /// Builds a request together with the oneshot receiver for its reply.
    fn prepare_request(
        message: Vec<u8>,
        client_id: Option<ClientId>,
    ) -> (oneshot::Receiver<ChannelReply>, ChannelRequest) {
        let (reply_tx, reply_rx) = oneshot::channel();
        (reply_rx, ChannelRequest { client_id, message, reply_tx })
    }
}
/// Holds the serialized init message (`HsmdInit` or `HsmdInit2`) recorded by the
/// root signer loop. It is shared between owners through `Arc<Mutex<..>>`.
///
/// `Default` is derived so the type satisfies clippy's `new_without_default`;
/// `InitMessageCache::default()` is equivalent to `InitMessageCache::new()`.
#[derive(Clone, Debug, Default)]
pub struct InitMessageCache {
    /// Serialized init message; `None` until one has been recorded.
    pub init_message: Option<Vec<u8>>,
}
impl InitMessageCache {
pub fn new() -> Self {
Self { init_message: None }
}
}
/// Reads requests from one node connection and proxies them to the gRPC signer.
///
/// The root loop (`client_id == None`) also handles the init messages and owns
/// the shutdown hooks; per-peer loops are spawned from it on `ClientHsmFd`.
pub struct SignerLoop<C: 'static + Client> {
    /// Transport that requests are read from and replies are written to.
    client: C,
    /// `"<pid>/<client id>/<dbid>"`, prepended to this loop's log lines.
    log_prefix: String,
    /// Shared port used to forward requests to the signer.
    signer_port: Arc<GrpcSignerPort>,
    /// `None` for the root loop, `Some` for per-peer loops.
    client_id: Option<ClientId>,
    /// Fired when the root loop terminates (root loop only).
    shutdown_trigger: Option<Trigger>,
    /// Awaited to exit the process on shutdown (root loop only).
    shutdown_signal: Option<Listener>,
    /// Recently approved preapproval replies, keyed by SHA-256 of the raw request.
    preapproval_cache: LruCache<Sha256Hash, PreapprovalCacheEntry>,
    /// Where the root loop records the init message it forwarded.
    init_message_cache: Arc<Mutex<InitMessageCache>>,
}
impl<C: 'static + Client> SignerLoop<C> {
pub fn new(
client: C,
signer_port: Arc<GrpcSignerPort>,
shutdown_trigger: Trigger,
shutdown_signal: Listener,
init_message_cache: Arc<Mutex<InitMessageCache>>,
) -> Self {
let log_prefix = format!("{}/{}/{}", std::process::id(), client.id(), 0);
let preapproval_cache = LruCache::new(NonZeroUsize::new(PREAPPROVE_CACHE_SIZE).unwrap());
Self {
client,
log_prefix,
signer_port,
client_id: None,
shutdown_trigger: Some(shutdown_trigger),
shutdown_signal: Some(shutdown_signal),
preapproval_cache,
init_message_cache,
}
}
fn new_for_client(client: C, signer_port: Arc<GrpcSignerPort>, client_id: ClientId) -> Self {
let log_prefix = format!("{}/{}/{}", std::process::id(), client.id(), client_id.dbid);
let preapproval_cache = LruCache::new(NonZeroUsize::new(PREAPPROVE_CACHE_SIZE).unwrap());
Self {
client,
log_prefix,
signer_port,
client_id: Some(client_id),
shutdown_trigger: None,
shutdown_signal: None,
preapproval_cache,
init_message_cache: Arc::new(Mutex::new(InitMessageCache::new())),
}
}
fn is_root(&self) -> bool {
self.client_id.is_none()
}
pub fn init_message_cache(&self) -> Arc<Mutex<InitMessageCache>> {
self.init_message_cache.clone()
}
pub fn start(&mut self) {
info!("read loop {}: start", self.log_prefix);
if let Some(shutdown_signal) = self.shutdown_signal.as_ref() {
let shutdown_signal_clone = shutdown_signal.clone();
let log_prefix_clone = self.log_prefix.clone();
tokio::spawn(async move {
info!("read loop {} waiting for shutdown", log_prefix_clone);
tokio::select! {
_ = shutdown_signal_clone => {
info!("read loop {} saw shutdown, calling exit", log_prefix_clone);
process::exit(0);
}
}
});
}
match self.do_loop() {
Ok(()) => info!("read loop {} done", self.log_prefix),
Err(Error::Protocol(ProtocolError::Eof)) => {
info!("read loop {} saw EOF; ending", self.log_prefix)
}
Err(e) => error!("read loop {} saw error {:?}; ending", self.log_prefix, e),
}
if let Some(trigger) = self.shutdown_trigger.as_ref() {
warn!("read loop {} terminated; triggering shutdown", self.log_prefix);
trigger.trigger();
}
}
fn do_loop(&mut self) -> Result<()> {
loop {
let raw_msg = self.client.read_raw()?;
let msg = msgs::from_vec(raw_msg.clone())?;
log_request!(msg);
match msg {
Message::ClientHsmFd(m) => {
self.client.write(msgs::ClientHsmFdReply {}).unwrap();
let new_client = self.client.new_client();
info!("new client {} -> {}", self.log_prefix, new_client.id());
let peer_id = m.peer_id.0;
let client_id = ClientId { peer_id, dbid: m.dbid };
let mut new_loop =
SignerLoop::new_for_client(new_client, self.signer_port.clone(), client_id);
spawn_blocking(move || new_loop.start());
}
Message::PreapproveInvoice(_) | Message::PreapproveKeysend(_) => {
let now = SystemTime::now();
let req_hash = Sha256Hash::hash(&raw_msg);
if let Some(entry) = self.preapproval_cache.get(&req_hash) {
let age = now.duration_since(entry.tstamp).expect("age");
if age < PREAPPROVE_CACHE_TTL {
debug!("{} found in preapproval cache", self.log_prefix);
let reply = entry.reply_bytes.clone();
log_reply!(reply, self);
self.client.write_vec(reply)?;
continue;
}
}
let reply_bytes = self.do_proxy_msg(raw_msg, false)?;
let reply = msgs::from_vec(reply_bytes.clone()).expect("parse reply failed");
match reply {
Message::PreapproveKeysendReply(pkr) =>
if pkr.result == true {
debug!("{} adding keysend to preapproval cache", self.log_prefix);
self.preapproval_cache.put(
req_hash,
PreapprovalCacheEntry { tstamp: now, reply_bytes },
);
},
Message::PreapproveInvoiceReply(pir) =>
if pir.result == true {
debug!("{} adding invoice to preapproval cache", self.log_prefix);
self.preapproval_cache.put(
req_hash,
PreapprovalCacheEntry { tstamp: now, reply_bytes },
);
},
_ => {} }
}
#[cfg(feature = "developer")]
Message::HsmdDevPreinit2(_) => {
if !self.is_root() {
error!(
"read loop {}: unexpected HsmdDevPreinit2 on non-root connection",
self.log_prefix
);
return Err(Error::Protocol(ProtocolError::UnexpectedType(
msgs::HsmdInit::TYPE,
)));
}
_ = self.do_proxy_msg(raw_msg, true)?;
}
Message::HsmdInit(mut m) => {
if !self.is_root() {
error!(
"read loop {}: unexpected HsmdInit on non-root connection",
self.log_prefix
);
return Err(Error::Protocol(ProtocolError::UnexpectedType(
msgs::HsmdInit::TYPE,
)));
}
let raw_reply = self.do_proxy_msg(raw_msg, false)?;
let reply = msgs::from_vec(raw_reply)?;
let init_reply = match reply {
Message::HsmdInitReplyV4(m) => m,
x => {
error!(
"read loop {}: unexpected reply to HsmdInit {:?}",
self.log_prefix, x
);
return Err(Error::Protocol(ProtocolError::UnexpectedType(0)));
}
};
m.hsm_wire_max_version = init_reply.hsm_version;
m.hsm_wire_min_version = init_reply.hsm_version;
let mut init_message_cache = self.init_message_cache.lock().unwrap();
if init_message_cache.init_message.is_some() {
error!("read loop {}: unexpected duplicate HsmdInit", self.log_prefix);
return Err(Error::Protocol(ProtocolError::UnexpectedType(
msgs::HsmdInit::TYPE,
)));
}
init_message_cache.init_message = Some(m.as_vec());
self.signer_port.set_ready();
}
Message::HsmdInit2(m) => {
if !self.is_root() {
error!(
"read loop {}: unexpected HsmdInit on non-root connection",
self.log_prefix
);
return Err(Error::Protocol(ProtocolError::UnexpectedType(
msgs::HsmdInit2::TYPE,
)));
}
self.do_proxy_msg(raw_msg, false)?;
let mut init_message_cache = self.init_message_cache.lock().unwrap();
if init_message_cache.init_message.is_some() {
error!("read loop {}: unexpected duplicate HsmdInit", self.log_prefix);
return Err(Error::Protocol(ProtocolError::UnexpectedType(
msgs::HsmdInit2::TYPE,
)));
}
init_message_cache.init_message = Some(m.as_vec());
}
_ => {
self.do_proxy_msg(raw_msg, false)?;
}
}
}
}
fn do_proxy_msg(&mut self, raw_msg: Vec<u8>, is_oneway: bool) -> Result<Vec<u8>> {
let result = self.handle_message(raw_msg, is_oneway);
if let Err(ref err) = result {
log_error!(err, self);
}
let reply = result?;
if is_oneway {
debug!("oneway sent {}", self.log_prefix);
} else {
log_reply!(reply, self);
self.client.write_vec(reply.clone())?;
debug!("replied {}", self.log_prefix);
}
Ok(reply)
}
fn handle_message(&mut self, message: Vec<u8>, is_oneway: bool) -> Result<Vec<u8>> {
let result = (|| {
info!(
"read loop {}: request {}{}",
self.log_prefix,
msgs::message_name_from_vec(&message),
if is_oneway { " (oneway)" } else { "" }
);
let reply_rx = self.send_request(message.clone())?;
if is_oneway {
Ok(vec![])
} else {
let reply = reply_rx.blocking_recv().map_err(|_| Error::Transport)?;
if reply.is_temporary_failure {
info!("read loop {}: temporary error, retrying", self.log_prefix);
return Err(Error::TransportTransient);
};
info!(
"read loop {}: reply {}",
self.log_prefix,
msgs::message_name_from_vec(&reply.reply)
);
Ok(reply.reply)
}
})
.retry(backoff())
.when(|e| *e == Error::TransportTransient)
.call()
.map_err(|e| {
error!("read loop {}: signer retry failed: {:?}", self.log_prefix, e);
e
})?;
Ok(result)
}
fn send_request(&mut self, message: Vec<u8>) -> Result<oneshot::Receiver<ChannelReply>> {
self.signer_port.send_request_blocking(message, self.client_id.clone())
}
}