use cotton::prelude::*;
use redis::{Client, Connection, Commands};
use serde::{Serialize, Deserialize};
use maybe_string::MaybeString;
use multistream_batch::channel::buf_batch::{BufBatchChannel, Command, CommandSender};
use multistream_batch::channel::EndOfStreamError;
use std::thread;
use std::collections::{VecDeque, BTreeSet};
use cotton::directories::ProjectDirs;
pub mod crypto;
use crypto::{
PublicKey, SecretKey, SignPublicKey, SignKeypair,
generate_encryption_key, generate_challenge, Challenge,
load_secret_key, load_public_key,
load_signing_secret_key, load_signing_public_key,
default_base_staion_public_key_file, default_base_staion_secret_key_file,
default_network_public_key_file, default_network_secret_key_file,
make_keypair, make_box, open_box, make_signed_box, open_signed_box,
};
/// Timeout passed to `get_connection_with_timeout` when opening Redis connections.
const REDIS_TIMEOUT_CONNECTION: Duration = Duration::from_secs(4);
/// Read and write timeout applied to established Redis connections.
const REDIS_TIMEOUT_RXTX: Duration = Duration::from_secs(10);
/// First argument of `BufBatchChannel::new` for the agent's reply batching;
/// presumably the maximum number of replies in one batch - TODO confirm against multistream_batch.
const REPLY_BATCH_MAX_SIZE: usize = 1000;
/// Second argument of `BufBatchChannel::new`;
/// presumably the longest a batch is held open before it is flushed - TODO confirm.
const REPLY_BATCH_MAX_WAIT: Duration = Duration::from_millis(200);
/// Third argument of `BufBatchChannel::new`;
/// presumably the capacity of the command queue feeding the batch - TODO confirm.
const REPLY_BATCH_QUEUE_SIZE: usize = 1000;
/// Reply list length above which the agent's TX thread pauses pushing until
/// the list has drained back to at most this many entries.
const REPLY_REDIS_LIST_MAX_LEN: u64 = 2000;
/// Sleep between list length checks while the TX thread waits for the reply list to drain.
const REPLY_REDIS_LIST_MAX_LEN_WAIT: Duration = Duration::from_millis(200);
/// Upper bound on the number of extra list items fetched in bulk (LRANGE) after
/// a successful BLPOP; also the initial capacity of the local item buffer.
const REDIS_BLPOP_MAX: usize = 10;
/// Returns the project directories for this crate.
///
/// The lookup is qualified with `cbradio.crates.io` plus the package authors
/// and package name taken from Cargo at compile time. When no directories can
/// be determined the failure is handled by `or_failed_to`.
pub fn project_dir() -> ProjectDirs {
    let dirs = ProjectDirs::from(
        "cbradio.crates.io",
        env!("CARGO_PKG_AUTHORS"),
        env!("CARGO_PKG_NAME"),
    );
    dirs.or_failed_to("getting project directories")
}
/// Ensures the parent directory of `path` exists and hands `path` back.
///
/// # Errors
///
/// Fails (in the "creating directory" context) when `path` has no parent
/// or when the parent directory cannot be created.
pub fn with_create_dir(path: PathBuf) -> PResult<PathBuf> {
    in_context_of("creating directory", || {
        let parent_dir = path.parent();
        create_dir_all(parent_dir.ok_or_problem("no parent directory to crate")?)?;
        Ok(path)
    })
}
/// Current UTC time expressed as milliseconds (see `timestamp_millis`).
pub fn timestamp() -> i64 {
    let now = Utc::now();
    now.timestamp_millis()
}
/// Name of a Redis channel/list used to exchange messages, rendered as an
/// `event://` URL by its `Display` implementation.
pub struct Event<'s> {
    /// Optional station identity; when present it becomes the leading host label.
    station: Option<&'s str>,
    /// Channel name, placed in front of the `cbradio.crates.io` suffix.
    channel: &'s str,
    /// Path component appended after `/v0/`.
    path: &'s str,
}

impl Display for Event<'_> {
    /// Writes `event://[<station>.]<channel>.cbradio.crates.io/v0/<path>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.station {
            Some(station) => write!(f, "event://{}.{}.cbradio.crates.io/v0/{}", station, self.channel, self.path),
            None => write!(f, "event://{}.cbradio.crates.io/v0/{}", self.channel, self.path),
        }
    }
}
/// Set of tag strings attached to requests and agents.
///
/// Backed by a `BTreeSet`, so duplicate tags collapse into one entry.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
pub struct Tags(BTreeSet<String>);
impl From<Vec<String>> for Tags {
fn from(value: Vec<String>) -> Tags {
Tags(value.into_iter().collect())
}
}
impl Tags {
    /// Tells whether an agent carrying `agent_tags` is addressed by these
    /// (request) tags.
    ///
    /// Every tag in `self` must be present in `agent_tags`; an empty `self`
    /// therefore matches every agent.
    pub fn agent_tags_match(&self, agent_tags: &Tags) -> bool {
        let requested = &self.0;
        let offered = &agent_tags.0;
        requested.is_subset(offered)
    }
}
/// Request sent from a base station to agents.
///
/// NOTE(review): serialized with `serde_bare`; field order is presumably part
/// of the wire format - confirm before reordering.
#[derive(Debug, Serialize, Deserialize)]
pub struct RequestMessage {
/// Identity of the sender; agents use it to address the reply list.
pub from: String,
/// Tags an agent has to carry to handle this request (see `Tags::agent_tags_match`).
pub tags: Tags,
/// The actual request.
pub request: Request,
}
/// Kinds of request handed to the agent's `on_request` callback.
///
/// NOTE(review): variant order is presumably part of the `serde_bare` wire format - confirm before reordering.
#[derive(Debug, Serialize, Deserialize)]
pub enum Request {
/// Presumably a liveness probe answered with `Reply::Pong` - handled outside this file.
Ping,
/// Presumably something to execute on the agent; the meaning of the string is defined by the handler - TODO confirm.
Run(String),
}
/// Single reply produced by an agent; replies travel in batches (`Vec<ReplyMessage>`).
///
/// NOTE(review): field order is presumably part of the `serde_bare` wire format - confirm before reordering.
#[derive(Debug, Serialize, Deserialize)]
pub struct ReplyMessage {
/// Identity of the agent that produced the reply.
pub from: String,
/// Tags of the agent that produced the reply.
pub tags: Tags,
/// Value of `timestamp()` (milliseconds) taken when the reply was queued by `ResponseHandler::reply`.
pub timestamp: i64,
/// The reply content.
pub reply: Reply,
}
/// Reply content sent back by an agent through `ResponseHandler::reply`.
///
/// The variants are produced by request handlers outside this file, so the
/// per-variant notes below are assumptions to be confirmed.
/// NOTE(review): variant order is presumably part of the `serde_bare` wire format.
#[derive(Debug, Serialize, Deserialize)]
pub enum Reply {
/// Presumably the answer to `Request::Ping`.
Pong,
/// Presumably a description of a failure while handling the request.
Error(String),
/// Presumably announces what is being run for `Request::Run` - TODO confirm the meaning of `path`.
Run {
path: String,
},
/// Presumably a chunk of standard output (may not be valid UTF-8, hence `MaybeString`).
Stdout(MaybeString),
/// Presumably a chunk of standard error output.
Stderr(MaybeString),
/// Presumably the final process status: exit code and/or terminating signal.
Status {
code: Option<i32>,
signal: Option<i32>,
}
}
/// Envelope that is serialized with `serde_bare` and sealed in a crypto box
/// before being sent through Redis.
///
/// NOTE(review): field order is presumably part of the wire format - confirm before reordering.
#[derive(Debug, Serialize, Deserialize)]
pub struct DataFrame {
/// Challenge generated by the base station for a request; agents copy it into
/// their reply frames and the base station rejects replies carrying a different one.
challenge: Challenge,
/// Sequence number; receivers reject frames whose number is not greater than the last accepted one.
sequence_number: u64,
/// Compressed, serialized message(s).
payload: Payload,
}
impl DataFrame {
fn into_payload(self) -> Payload {
self.payload
}
fn from_bytes(value: &[u8], private_key: &SecretKey) -> PResult<(PublicKey, DataFrame)> {
let (public_key, plaintext) = open_box(value, private_key)?;
Ok((public_key, serde_bare::from_slice(&plaintext).problem_while("deserializing reply frame")?))
}
fn to_bytes(&self, public_key: &PublicKey, private_key: &SecretKey) -> PResult<Vec<u8>> {
Ok(make_box(serde_bare::to_vec(&self).problem_while("serializing reply frame")?, public_key, private_key)?)
}
fn from_signed_bytes(value: &[u8], signing_public_key: &SignPublicKey, private_key: &SecretKey) -> PResult<(PublicKey, DataFrame)> {
let (public_key, plaintext) = open_signed_box(value, signing_public_key, private_key)?;
Ok((public_key, serde_bare::from_slice(&plaintext).problem_while("deserializing reply frame")?))
}
fn to_signed_bytes(&self, signing_keypair: &SignKeypair, public_key: &PublicKey, private_key: &SecretKey) -> PResult<Vec<u8>> {
Ok(make_signed_box(serde_bare::to_vec(&self).problem_while("serializing reply frame")?, signing_keypair, public_key, private_key)?)
}
}
/// LZ4-compressed bytes of a serialized message (see `Payload::from_bytes` / `Payload::into_bytes`).
#[derive(Debug, Serialize, Deserialize)]
pub struct Payload(Vec<u8>);
impl TryFrom<RequestMessage> for Payload {
    type Error = Problem;

    /// Serializes the request with `serde_bare` and compresses the result.
    fn try_from(value: RequestMessage) -> Result<Payload, Problem> {
        let serialized = serde_bare::to_vec(&value).problem_while("serializing request messages")?;
        let compressed = Payload::from_bytes(serialized)?;
        Ok(compressed)
    }
}
impl TryFrom<Payload> for RequestMessage {
    type Error = Problem;

    /// Decompresses the payload and deserializes a request from it.
    fn try_from(value: Payload) -> Result<RequestMessage, Problem> {
        let serialized = value.into_bytes()?;
        let request = serde_bare::from_slice(&serialized).problem_while("deserializing request messages")?;
        Ok(request)
    }
}
impl TryFrom<Vec<ReplyMessage>> for Payload {
    type Error = Problem;

    /// Serializes a batch of replies with `serde_bare` and compresses the result.
    fn try_from(value: Vec<ReplyMessage>) -> Result<Payload, Problem> {
        let serialized = serde_bare::to_vec(&value).problem_while("serializing reply messages")?;
        let compressed = Payload::from_bytes(serialized)?;
        Ok(compressed)
    }
}
impl TryFrom<Payload> for Vec<ReplyMessage> {
    type Error = Problem;

    /// Decompresses the payload and deserializes a batch of replies from it.
    fn try_from(value: Payload) -> Result<Vec<ReplyMessage>, Problem> {
        let serialized = value.into_bytes()?;
        let replies = serde_bare::from_slice(&serialized).problem_while("deserializing reply messages")?;
        Ok(replies)
    }
}
impl Payload {
    /// Compresses `payload` with LZ4 (level 9) and wraps the compressed bytes.
    fn from_bytes(payload: Vec<u8>) -> PResult<Payload> {
        let mut output = Vec::new();
        let mut encoder = lz4::EncoderBuilder::new()
            .level(9)
            .build(&mut output)
            .problem_while("creating LZ4 encoder")?;
        encoder.write_all(&payload).problem_while("compressing payload with LZ4")?;
        // `finish` hands back the writer and the final result; only the result matters here.
        let (_, finished) = encoder.finish();
        finished.problem_while("compressing payload with LZ4")?;
        Ok(Payload(output))
    }

    /// Decompresses the wrapped bytes and returns the original data.
    fn into_bytes(self) -> PResult<Vec<u8>> {
        let mut source = Cursor::new(self.0);
        let mut decoder = lz4::Decoder::new(&mut source).problem_while("creating LZ4 decoder")?;
        let mut output = Vec::new();
        decoder.read_to_end(&mut output).problem_while("decompressing payload with LZ4")?;
        let (_, finished) = decoder.finish();
        finished.problem_while("decompressing payload with LZ4")?;
        Ok(output)
    }

    /// Wraps this payload into a `DataFrame` with the given sequence number and challenge.
    fn into_data_frame(self, sequence_number: u64, challenge: Challenge) -> DataFrame {
        DataFrame {
            challenge,
            sequence_number,
            payload: self,
        }
    }
}
/// Handle given to the agent's request callback for sending replies.
///
/// Cloneable; each reply is queued on the batching channel consumed by the
/// agent's TX thread.
#[derive(Debug, Clone)]
pub struct ResponseHandler {
// Sending side of the reply batching channel; items are (timestamp in ms, reply).
sender: CommandSender<(i64, Reply)>,
}
impl ResponseHandler {
    /// Queues `reply` for sending, stamped with the current `timestamp()`.
    ///
    /// # Errors
    ///
    /// Fails when the reply cannot be handed to the batching channel.
    pub fn reply(&mut self, reply: Reply) -> PResult<()> {
        let stamped = (timestamp(), reply);
        let command = Command::Append(stamped);
        self.sender.send(command).problem_while("sending reply")?;
        Ok(())
    }
}
/// Receiving side of the system: listens for signed requests on Redis pub/sub
/// and pushes replies to a Redis list.
pub struct Agent {
// Signing public key used to verify request frames.
station_key: SignPublicKey,
// Secret key used to open request frames.
network_key: SecretKey,
// Identity used in the agent-specific subscription and as `from` in replies.
identity: String,
// Tags a request's tags are matched against.
tags: Tags,
// Channel name used to build event names.
channel: String,
// Redis client from which RX and TX connections are opened.
redis: Client,
}
impl Agent {
pub fn new(connection_string: &str, identity: String, tags: Tags, channel: String, base_station_public_key_file: Option<&Path>, network_secret_key_file: Option<&Path>) -> PResult<Agent> {
let redis = redis::Client::open(connection_string).problem_while("setting up agent Redis client")?;
Ok(Agent {
station_key: load_signing_public_key(base_station_public_key_file.unwrap_or(default_base_staion_public_key_file().as_ref()))?,
network_key: load_secret_key(network_secret_key_file.unwrap_or(default_network_secret_key_file().as_ref()))?,
identity,
tags,
channel,
redis,
})
}
pub fn rx<M>(&mut self, mut on_request: M) -> PResult<()> where M: FnMut(Request, ResponseHandler) -> PResult<()> {
let mut rx = self.redis.get_connection_with_timeout(REDIS_TIMEOUT_CONNECTION).problem_while("making RX Agent connection to Redis")?;
rx.set_read_timeout(Some(REDIS_TIMEOUT_RXTX)).problem_while("setting RX timeout")?;
rx.set_write_timeout(Some(REDIS_TIMEOUT_RXTX)).problem_while("setting RX timeout")?;
let mut rx = rx.as_pubsub();
rx.psubscribe(&Event {
station: Some(&self.identity),
channel: &self.channel,
path: "request"
}.to_string())?; rx.psubscribe(&Event {
station: None,
channel: &self.channel,
path: "request"
}.to_string())?; rx.set_read_timeout(Some(Duration::from_secs(10))).problem_while("setting RX timeout")?;
let mut previous_sequence_number = timestamp() as u64;
loop {
let message = match rx.get_message() {
Ok(message) => message,
Err(err) if err.is_timeout() => {
rx.unsubscribe("test-3424235sdfklf").problem_while("checking RX connection")?;
continue;
}
Err(err) => Err(err).problem_while("getting message")?,
};
let _source: String = message.get_channel().problem_while("getting message channel name")?;
let message: Vec<u8> = message.get_payload().problem_while("getting message payload")?;
let (session_public_key, message_frame) = DataFrame::from_signed_bytes(&message, &self.station_key, &self.network_key)?;
debug!("Got message frame: size: {} B sequence: {}", message.len(), message_frame.sequence_number);
if message_frame.sequence_number <= previous_sequence_number {
return problem!("Bad base station sequence number: got: {} expected greater than: {}", message_frame.sequence_number, previous_sequence_number)
}
previous_sequence_number = message_frame.sequence_number;
let mut reply_sequence_number = message_frame.sequence_number;
let challenge = message_frame.challenge;
let request_message: RequestMessage = message_frame.into_payload().try_into()?;
if !request_message.tags.agent_tags_match(&self.tags) {
debug!("Request with non-matching tags: request tags: {:?}, agent tags: {:?}", request_message.tags.0, self.tags.0);
return Ok(())
}
info!("Got message: from: {}, tags: {:?} request: {:?}", request_message.from, request_message.tags.0, request_message.request);
let mut tx = self.redis.get_connection_with_timeout(REDIS_TIMEOUT_CONNECTION).problem_while("making TX Agent connection to Redis")?;
tx.set_read_timeout(Some(REDIS_TIMEOUT_RXTX)).problem_while("setting TX timeout")?;
tx.set_write_timeout(Some(REDIS_TIMEOUT_RXTX)).problem_while("setting TX timeout")?;
let event = Event {
station: Some(&request_message.from),
channel: &self.channel,
path: "reply"
}.to_string();
let (sender, mut batch) = BufBatchChannel::new(REPLY_BATCH_MAX_SIZE, REPLY_BATCH_MAX_WAIT, REPLY_BATCH_QUEUE_SIZE);
let identity = self.identity.clone();
let tags = self.tags.clone();
let session_reply_key = generate_encryption_key();
let tx_thread = thread::spawn(move || {
in_context_of("TX thread", || loop {
match batch.next() {
Ok(replies) => {
let payload: Payload = replies
.map(|(timestamp, reply)| ReplyMessage {
from: identity.clone(),
tags: tags.clone(),
timestamp,
reply,
})
.collect_vec()
.try_into()?;
reply_sequence_number += 1;
let value = payload.into_data_frame(reply_sequence_number, challenge).to_bytes(&session_public_key, &session_reply_key)?;
let len: u64 = tx.rpush(event.as_str(), &value).problem_while("pushing reply to Redis")?;
if len > REPLY_REDIS_LIST_MAX_LEN {
warn!("Reply queue full: {}", len);
loop {
sleep(REPLY_REDIS_LIST_MAX_LEN_WAIT);
let len: u64 = tx.llen(event.as_str()).problem_while("getting reply queue len")?;
if len <= REPLY_REDIS_LIST_MAX_LEN {
warn!("Resuming sending replies; queue len: {}", len);
break
}
}
}
}
Err(EndOfStreamError) => break Ok(()),
}
}).ok_or_log_error();
});
on_request(request_message.request, ResponseHandler { sender }).ok_or_log_error();
tx_thread.join().ok().ok_or_problem("Joining TX thread")?;
}
}
}
/// Sending side of the system: publishes signed requests and collects replies.
pub struct BaseStation {
// Public key passed as the recipient key when sealing request frames.
network_key: PublicKey,
// Signing keypair used to sign request frames.
station_key: SignKeypair,
// Identity placed in `RequestMessage::from` and in the reply list name.
identity: String,
// Channel name used to build event names.
channel: String,
// Redis client from which connections are opened.
redis: Client,
}
/// Iterator over items of a Redis list: waits with BLPOP and then fetches
/// further items in bulk into a local buffer.
struct QuickBlopIter {
// Name of the Redis list to pop from.
event: String,
// Connection used for all list commands.
redis: Connection,
// Items already fetched from Redis but not yet yielded.
buffer: VecDeque<Vec<u8>>,
// Timeout passed to BLPOP; also drives the connection read timeout (see `set_timeout`).
timeout: usize,
}
impl QuickBlopIter {
    /// Creates an iterator over the Redis list `event` using `redis` and
    /// applies `timeout` through `set_timeout`.
    fn new(event: String, redis: Connection, timeout: usize) -> PResult<QuickBlopIter> {
        let buffer = VecDeque::with_capacity(REDIS_BLPOP_MAX);
        let mut iter = QuickBlopIter {
            event,
            redis,
            buffer,
            timeout: 0,
        };
        iter.set_timeout(timeout)?;
        Ok(iter)
    }

    /// Sets the pop timeout (logged as seconds).
    ///
    /// A positive timeout also sets the connection read timeout to that many
    /// seconds plus `REDIS_TIMEOUT_RXTX`; zero removes the read timeout.
    fn set_timeout(&mut self, timeout: usize) -> PResult<()> {
        debug!("Setting receive timeout to: {} seconds", timeout);
        self.timeout = timeout;

        let read_timeout = if timeout > 0 {
            Some(Duration::from_secs(timeout as u64) + REDIS_TIMEOUT_RXTX)
        } else {
            None
        };
        self.redis.set_read_timeout(read_timeout).problem_while("setting TX timeout")
    }
}
impl Iterator for QuickBlopIter {
    type Item = PResult<Vec<u8>>;

    /// Yields the next list item.
    ///
    /// Buffered items are served first. Otherwise BLPOP waits for one item;
    /// when it delivers, up to `REDIS_BLPOP_MAX` further items are read with
    /// LRANGE, removed with LTRIM and buffered. Returns `None` when BLPOP
    /// yields nothing.
    fn next(&mut self) -> Option<PResult<Vec<u8>>> {
        if let Some(buffered) = self.buffer.pop_front() {
            return Some(Ok(buffered));
        }

        in_context_of("popping items from queue", || {
            let popped = self.redis.blpop::<_, Option<(String, Vec<u8>)>>(&self.event, self.timeout).problem_while("awaiting replies")?;
            let first = match popped {
                Some((_list, first)) => first,
                None => return Ok(None),
            };

            if REDIS_BLPOP_MAX > 1 {
                let pending = self.redis.lrange::<_, Option<Vec<Vec<u8>>>>(&self.event, 0, REDIS_BLPOP_MAX as isize - 1).problem_while("bulk-getting replies")?;
                if let Some(pending) = pending {
                    // Drop exactly the items that were just read from the head of the list.
                    self.redis.ltrim(&self.event, pending.len() as isize, -1).problem_while("trimming consumed value")?;
                    self.buffer.extend(pending);
                }
            }

            Ok(Some(first))
        }).transpose()
    }
}
/// Iterator over batches of replies to a single request.
pub struct ReplyIter {
// Source of raw reply frames popped from the reply list.
queue_iter: QuickBlopIter,
// Last accepted sequence number per agent, keyed by the public key bytes the agent's frames were opened with.
agent_sequences: HashMap<[u8; 32], u64>,
// Secret key used to open reply frames.
session_key: SecretKey,
// Sequence number of the request; initial floor for every agent's reply sequence.
sequence_number: u64,
// Challenge every reply frame has to carry.
challenge: [u8; 24],
}
impl ReplyIter {
fn new(event: String, redis: Connection, timeout: usize, session_key: SecretKey, sequence_number: u64, challenge: [u8; 24]) -> PResult<ReplyIter> {
Ok(ReplyIter {
queue_iter: QuickBlopIter::new(event, redis, timeout)?,
agent_sequences: Default::default(),
session_key,
sequence_number,
challenge,
})
}
pub fn set_timeout(&mut self, timeout: usize) -> PResult<()> {
self.queue_iter.set_timeout(timeout)
}
}
impl Iterator for ReplyIter {
    type Item = PResult<Vec<ReplyMessage>>;

    /// Yields the next batch of replies.
    ///
    /// Each raw frame is opened with the session key, then checked for a
    /// sequence number greater than the last one accepted from the same agent
    /// and for the expected challenge, before its payload is decoded.
    fn next(&mut self) -> Option<PResult<Vec<ReplyMessage>>> {
        self.queue_iter.next().map(|received| -> PResult<Vec<ReplyMessage>> {
            let raw = received?;

            let (agent_key, frame) = DataFrame::from_bytes(&raw, &self.session_key)?;
            let agent_id = agent_key.as_bytes().to_owned();
            debug!("[{}] Got reply frame: size: {} B sequence: {}", hex::encode(agent_id), raw.len(), frame.sequence_number);

            // An agent seen for the first time starts from the request's sequence number.
            let last_seen = self.agent_sequences.entry(agent_id).or_insert(self.sequence_number);
            if *last_seen >= frame.sequence_number {
                return problem!("[{}] Bad agent sequence number: got: {} expected greater than: {}", hex::encode(agent_id), frame.sequence_number, last_seen)
            }
            *last_seen = frame.sequence_number;

            if frame.challenge != self.challenge {
                return problem!("Bad reply challenge")
            }

            let messages: Vec<ReplyMessage> = frame.into_payload().try_into()?;
            debug!("Got reply messages: {}", messages.len());
            Ok(messages)
        })
    }
}
impl BaseStation {
    /// Creates a base station.
    ///
    /// * `connection_string` - Redis connection string used to set up the client.
    /// * `identity`, `channel` - stored as given.
    /// * `network_public_key_file` - network public key file; falls back to
    ///   `default_network_public_key_file()`.
    /// * `base_station_secret_key_file` - signing secret key file; falls back to
    ///   `default_base_staion_secret_key_file()`.
    ///
    /// # Errors
    ///
    /// Fails when the Redis client cannot be set up or a key cannot be loaded.
    pub fn new(connection_string: &str, identity: String, channel: String, network_public_key_file: Option<&Path>, base_station_secret_key_file: Option<&Path>) -> PResult<BaseStation> {
        let redis = redis::Client::open(connection_string).problem_while("setting up base station Redis connection")?;
        let station_key = make_keypair(load_signing_secret_key(base_station_secret_key_file.unwrap_or(default_base_staion_secret_key_file().as_ref()))?);
        let network_key = load_public_key(network_public_key_file.unwrap_or(default_network_public_key_file().as_ref()))?;
        Ok(BaseStation {
            station_key,
            network_key,
            identity,
            channel,
            redis,
        })
    }

    /// Publishes `request` to the agents on the channel and returns an
    /// iterator over their replies.
    ///
    /// The reply list of this station is deleted first. The request is then
    /// sent in a signed frame carrying a fresh challenge and a sequence number
    /// taken from `timestamp()`. `quiet_timeout` is the timeout the returned
    /// iterator starts with.
    ///
    /// # Errors
    ///
    /// Fails on Redis connection or command errors and when the request cannot
    /// be serialized or sealed.
    pub fn request(&mut self, tags: Tags, quiet_timeout: usize, request: Request) -> PResult<ReplyIter> {
        let mut tx = self.redis.get_connection_with_timeout(REDIS_TIMEOUT_CONNECTION).problem_while("making TX Station connection to Redis")?;
        tx.set_read_timeout(Some(REDIS_TIMEOUT_RXTX)).problem_while("setting TX timeout")?;
        tx.set_write_timeout(Some(REDIS_TIMEOUT_RXTX)).problem_while("setting TX timeout")?;

        let outgoing = RequestMessage {
            from: self.identity.clone(),
            tags,
            request,
        };

        // List the agents push their replies to.
        let reply_list = Event {
            station: Some(self.identity.as_str()),
            channel: self.channel.as_str(),
            path: "reply"
        }.to_string();

        tx.del(&reply_list).problem_while("flushing messages")?;

        let session_key = generate_encryption_key();
        let sequence_number = timestamp() as u64;
        let challenge = generate_challenge();

        info!("Publishing message: from: {}, tags: {:?}, request: {:?}", outgoing.from, outgoing.tags.0, outgoing.request);

        let payload: Payload = outgoing.try_into()?;
        let frame: DataFrame = payload.into_data_frame(sequence_number, challenge);
        let sealed = frame.to_signed_bytes(&self.station_key, &self.network_key, &session_key)?;

        // Channel-wide request event (no station part).
        let request_topic = Event {
            station: None,
            channel: self.channel.as_str(),
            path: "request",
        }.to_string();
        tx.publish(&request_topic, sealed)?;

        ReplyIter::new(reply_list, tx, quiet_timeout, session_key, sequence_number, challenge)
    }
}