use crate::{
ManagerPeerEntry, Message,
connection::{
CommunicationContext, ManagerContext, NetworkContext, SecurityContext, StateContext,
handle_manager_as_client, handle_manager_as_server, worker::handle_worker_join,
},
peers::AliveTable,
workers::WorkerTable,
};
use ed25519_dalek::SigningKey;
use eyre::{Report, eyre};
use getrandom::fill;
use ipnet::IpNet;
use std::{
collections::HashMap,
net::{IpAddr, Ipv4Addr, SocketAddr},
sync::{
Arc,
atomic::{AtomicU32, AtomicU64, Ordering},
},
};
use tokio::{
sync::{Mutex, RwLock, broadcast},
task::JoinHandle,
};
use volli_core::{
ManagerPeerEntry as CorePeer, Role, WorkerConfig,
token::{encode_token, issue_token},
};
use volli_transport::{MemoryTransport, MessageTransportExt, Transport};
/// Factory for in-memory managers and workers that share one tenant, one
/// cluster, and one key (`csk`) with its version counter.
pub struct InMemoryHarness {
/// Tenant name passed to `issue_token` and copied into manager metadata.
tenant: String,
/// Cluster name passed to `issue_token` and copied into manager metadata.
cluster: String,
/// Key bytes passed to `issue_token` whenever a token is minted.
csk_seed: [u8; 32],
/// Shared key handle; every manager created by this harness clones this `Arc`.
csk: Arc<RwLock<[u8; 32]>>,
/// Shared key-version counter; every manager created by this harness clones this `Arc`.
csk_ver: Arc<AtomicU32>,
}
/// Optional settings accepted by `InMemoryHarness::manager_with_options`.
///
/// The derived `Default` leaves both lists empty.
#[derive(Default)]
pub struct ManagerOptions {
/// Networks handed to the manager's server context as `worker_nets`.
pub worker_whitelist: Vec<IpNet>,
/// Networks handed to the manager's server context as `manager_nets`.
pub manager_whitelist: Vec<IpNet>,
}
impl InMemoryHarness {
    /// Creates a harness for `tenant` / `cluster`.
    ///
    /// The key is fixed to thirty-two bytes of value `7` and the key version
    /// starts at `1`; use [`InMemoryHarness::with_csk`] to override both.
    pub fn new(tenant: impl Into<String>, cluster: impl Into<String>) -> Self {
        let seed = [7u8; 32];
        Self {
            tenant: tenant.into(),
            cluster: cluster.into(),
            csk_seed: seed,
            csk: Arc::new(RwLock::new(seed)),
            csk_ver: Arc::new(AtomicU32::new(1)),
        }
    }

    /// Replaces the key material and key version, keeping tenant and cluster.
    ///
    /// Fresh shared handles are allocated, so managers created before this
    /// call keep the handles they were given.
    pub fn with_csk(self, key: [u8; 32], version: u32) -> Self {
        Self {
            csk_seed: key,
            csk: Arc::new(RwLock::new(key)),
            csk_ver: Arc::new(AtomicU32::new(version)),
            ..self
        }
    }

    /// Issues and encodes a token for `worker_id` with this harness's key,
    /// tenant and cluster.
    ///
    /// `3600` is forwarded as the last `issue_token` argument (presumably a
    /// lifetime in seconds; confirm against `volli_core::token`).
    ///
    /// # Panics
    ///
    /// Panics if issuing or encoding the token fails.
    fn issue_worker_token(&self, worker_id: &str) -> String {
        let issued = issue_token(
            &self.csk_seed,
            &self.tenant,
            &self.cluster,
            worker_id,
            3600,
        )
        .expect("token issue");
        encode_token(&issued).expect("token encode")
    }

    /// Creates a manager handle with default [`ManagerOptions`].
    pub fn manager(&self, manager_id: impl Into<String>) -> ManagerHandle {
        let defaults = ManagerOptions::default();
        self.manager_with_options(manager_id, defaults)
    }

    /// Creates a manager handle configured with `options`.
    pub fn manager_with_options(
        &self,
        manager_id: impl Into<String>,
        options: ManagerOptions,
    ) -> ManagerHandle {
        let inner = ManagerHandleInner::new(manager_id.into(), self, options);
        ManagerHandle {
            inner: Arc::new(inner),
        }
    }

    /// Creates a worker handle carrying a freshly issued token for
    /// `worker_id` and default [`WorkerOptions`].
    pub fn worker(&self, worker_id: impl Into<String>) -> WorkerHandle {
        let id: String = worker_id.into();
        WorkerHandle {
            token: self.issue_worker_token(&id),
            options: WorkerOptions::default(),
        }
    }

    /// Returns a freshly issued, encoded token for `worker_id`.
    pub fn worker_token(&self, worker_id: &str) -> String {
        self.issue_worker_token(worker_id)
    }
}
/// Handle to an in-memory manager; clones share the same `ManagerHandleInner`.
#[derive(Clone)]
pub struct ManagerHandle {
/// Shared manager state behind an `Arc`.
inner: Arc<ManagerHandleInner>,
}
impl ManagerHandle {
    /// Address reported for in-memory connections: IPv4 loopback, port 0.
    fn loopback_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)
    }

    /// Connects this manager (dialing side) to `other` (accepting side) over
    /// a new in-memory transport pair.
    ///
    /// The accepting task is spawned first, then the dialing task; the
    /// returned handle owns both.
    pub fn connect(&self, other: &ManagerHandle) -> ConnectionHandle {
        let (dial_end, accept_end) = MemoryTransport::pair();
        let dial_ctx = self.inner.client_context();
        let accept_ctx = other.inner.server_context();
        let accept_addr = Self::loopback_addr();

        let server_task: JoinHandle<Result<(), Report>> = tokio::spawn(async move {
            handle_manager_as_server(&accept_ctx, Box::new(accept_end), accept_addr).await
        });

        let target_id = other.inner.self_meta.manager_id.clone();
        let client_task: JoinHandle<Result<(), Report>> = tokio::spawn(async move {
            handle_manager_as_client(dial_ctx, Box::new(dial_end), target_id).await
        });

        ConnectionHandle {
            client_task: Some(client_task),
            server_task: Some(server_task),
        }
    }

    /// Spawns the worker-join handler for `transport`, reporting the peer
    /// address as IPv4 loopback with port 0.
    pub fn spawn_worker_task<T>(&self, transport: T) -> JoinHandle<Result<(), Report>>
    where
        T: Transport + Send + 'static,
    {
        self.spawn_worker_task_with_addr(transport, Self::loopback_addr())
    }

    /// Spawns the worker-join handler for `transport`, reporting `addr` as
    /// the peer address.
    pub fn spawn_worker_task_with_addr<T>(
        &self,
        transport: T,
        addr: SocketAddr,
    ) -> JoinHandle<Result<(), Report>>
    where
        T: Transport + Send + 'static,
    {
        let server_ctx = self.inner.server_context();
        tokio::spawn(async move {
            handle_worker_join(&server_ctx, Box::new(transport), addr).await
        })
    }

    /// Returns this manager's own peer entry.
    pub fn meta(&self) -> &ManagerPeerEntry {
        &self.inner.self_meta
    }

    /// Returns the manager ids of the peers cached for this manager's
    /// profile; yields an empty list when the lookup produces no value.
    pub fn known_peer_ids(&self) -> Vec<String> {
        let cached =
            crate::peers::cached_peers_for_profile(&self.inner.profile).unwrap_or_default();
        cached.into_iter().map(|peer| peer.manager_id).collect()
    }

    /// Returns a shared handle to the peer-version counter.
    pub fn peer_version(&self) -> Arc<AtomicU64> {
        Arc::clone(&self.inner.peer_version)
    }

    /// Publishes the current peer version on the alive channel.
    pub fn broadcast_version(&self) {
        let current = self.inner.peer_version.load(Ordering::SeqCst);
        // The send result is deliberately discarded.
        let _ = self.inner.alive_tx.send(current);
    }

    /// Sets the worker count on the health collector.
    pub async fn set_worker_count(&self, count: u32) {
        self.inner
            .health
            .collector
            .lock()
            .await
            .set_worker_count(count);
    }
}
/// State behind a `ManagerHandle`; the source of every `ManagerContext`
/// produced by `server_context` and `client_context`.
struct ManagerHandleInner {
/// This manager's own peer entry.
self_meta: CorePeer,
/// Peer table cloned into each context's `StateContext`; starts empty.
peers: AliveTable,
/// Worker table cloned into each context's `StateContext`; starts empty.
workers: WorkerTable,
/// Peer-version counter; initialised to 1.
peer_version: Arc<AtomicU64>,
/// Broadcast sender used to publish the peer version; the receiver
/// created alongside it is dropped in `new`.
alive_tx: broadcast::Sender<u64>,
/// Sender half of an unbounded channel; its receiver is dropped in `new`.
dial_tx: crate::connection::mesh::DialTx,
/// Key handle shared with the harness that created this manager.
csk: Arc<RwLock<[u8; 32]>>,
/// Key-version counter shared with the harness that created this manager.
csk_ver: Arc<AtomicU32>,
/// Config with `Role::Manager` and this manager's token; passed to client contexts.
worker_cfg: WorkerConfig,
/// Profile name; set to the manager id in `new`.
profile: String,
/// Signing key built from random bytes in `new`.
signing: Arc<SigningKey>,
/// Health context built from the default `HealthConfig`.
health: crate::connection::HealthContext,
/// Taken from `ManagerOptions::worker_whitelist`.
worker_nets: Arc<Vec<IpNet>>,
/// Taken from `ManagerOptions::manager_whitelist`.
manager_nets: Arc<Vec<IpNet>>,
}
impl ManagerHandleInner {
fn new(manager_id: String, harness: &InMemoryHarness, options: ManagerOptions) -> Self {
let peers: AliveTable = Arc::new(Mutex::new(HashMap::new()));
let workers: WorkerTable = Arc::new(Mutex::new(HashMap::new()));
let peer_version = Arc::new(AtomicU64::new(1));
let (alive_tx, _) = broadcast::channel(32);
let (dial_tx, _) = tokio::sync::mpsc::unbounded_channel();
let health = crate::connection::HealthContext::new(crate::health::HealthConfig::default());
let signing = Arc::new({
let mut seed = [0u8; 32];
fill(&mut seed).expect("rng");
SigningKey::from_bytes(&seed)
});
let token = encode_token(
&issue_token(
&harness.csk_seed,
&harness.tenant,
&harness.cluster,
&manager_id,
3600,
)
.expect("token issue"),
)
.expect("token encode");
let worker_cfg = WorkerConfig {
role: Role::Manager,
token,
..Default::default()
};
let self_meta = ManagerPeerEntry {
manager_id: manager_id.clone(),
manager_name: manager_id.clone(),
tenant: harness.tenant.clone(),
cluster: harness.cluster.clone(),
host: "127.0.0.1".into(),
tcp_port: 0,
quic_port: 0,
pub_fp: "0".repeat(64),
csk_ver: harness.csk_ver.load(Ordering::SeqCst),
tls_cert: String::new(),
tls_fp: String::new(),
health: None,
};
Self {
worker_cfg,
csk: harness.csk.clone(),
csk_ver: harness.csk_ver.clone(),
self_meta,
peers,
workers,
peer_version,
alive_tx,
dial_tx,
profile: manager_id,
signing,
health,
worker_nets: Arc::new(options.worker_whitelist),
manager_nets: Arc::new(options.manager_whitelist),
}
}
fn server_context(&self) -> ManagerContext {
ManagerContext::new_server(
SecurityContext {
signing: Some(self.signing.clone()),
csk: self.csk.clone(),
csk_ver: self.csk_ver.clone(),
},
NetworkContext {
worker_nets: self.worker_nets.clone(),
manager_nets: self.manager_nets.clone(),
},
StateContext {
peers: self.peers.clone(),
workers: self.workers.clone(),
self_meta: self.self_meta.clone(),
peer_version: self.peer_version.clone(),
command_distributor: None,
},
CommunicationContext {
alive_tx: self.alive_tx.clone(),
dial_tx: self.dial_tx.clone(),
profile: self.profile.clone(),
},
self.health.clone(),
self.self_meta.manager_id.clone(),
)
}
fn client_context(&self) -> ManagerContext {
ManagerContext::new_peer(
SecurityContext {
signing: Some(self.signing.clone()),
csk: self.csk.clone(),
csk_ver: self.csk_ver.clone(),
},
StateContext {
peers: self.peers.clone(),
workers: self.workers.clone(),
self_meta: self.self_meta.clone(),
peer_version: self.peer_version.clone(),
command_distributor: None,
},
CommunicationContext {
alive_tx: self.alive_tx.clone(),
dial_tx: self.dial_tx.clone(),
profile: self.profile.clone(),
},
self.worker_cfg.clone(),
self.health.clone(),
)
}
}
/// Owns the two tasks spawned for one in-memory connection.
///
/// `shutdown` aborts and awaits both tasks; dropping the handle aborts any
/// task it still holds.
pub struct ConnectionHandle {
/// Task running the connecting side; `None` once taken.
client_task: Option<JoinHandle<Result<(), Report>>>,
/// Task running the accepting side; `None` once taken.
server_task: Option<JoinHandle<Result<(), Report>>>,
}
/// Optional behaviour for a harness worker; the derived default has no handler.
#[derive(Default, Clone)]
pub struct WorkerOptions {
/// Handler invoked for each incoming `WorkerCommandRequest`. When `None`,
/// such requests are ignored and no response is sent.
pub command_handler: Option<Arc<dyn WorkerCommandHandler + Send + Sync>>,
}
/// Callback used by the harness worker client to answer command requests.
pub trait WorkerCommandHandler {
/// Handles one `WorkerCommandRequest`.
///
/// `request_id`, `command` and `args` are taken from the received message.
/// The returned result is sent back as a `WorkerCommandResponse`.
fn handle_request(
&self,
request_id: &str,
command: &str,
args: &[String],
) -> WorkerCommandResult;
}
/// Outcome returned by a [`WorkerCommandHandler`].
///
/// Derives `Debug`, `PartialEq` and `Eq` so test code can print results and
/// compare them with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkerCommandResult {
    /// Whether the command is reported as successful.
    pub success: bool,
    /// Text returned as the command output.
    pub output: String,
    /// Reported duration in milliseconds. The harness worker client replaces
    /// `0` with `1` before sending the response.
    pub duration_millis: u64,
}

impl WorkerCommandResult {
    /// Builds a successful result carrying `output` and a duration of `0`.
    pub fn success_text(output: impl Into<String>) -> Self {
        Self {
            success: true,
            output: output.into(),
            duration_millis: 0,
        }
    }

    /// Builds a failed result carrying `output` and a duration of `0`.
    pub fn failure_text(output: impl Into<String>) -> Self {
        Self {
            success: false,
            output: output.into(),
            duration_millis: 0,
        }
    }
}
/// Description of a harness worker: its token plus optional behaviour.
#[derive(Clone)]
pub struct WorkerHandle {
/// Encoded token sent in the `Auth` message when the worker connects.
token: String,
/// Options applied to connections made from this handle.
options: WorkerOptions,
}
impl WorkerHandle {
    /// Returns this handle with its options replaced by `options`.
    pub fn with_options(self, options: WorkerOptions) -> Self {
        Self { options, ..self }
    }

    /// Connects this worker to `manager` over a new in-memory transport pair.
    ///
    /// The manager-side join task is spawned first, then the worker client
    /// task; the returned handle owns both.
    pub fn connect(&self, manager: &ManagerHandle) -> ConnectionHandle {
        let (worker_end, manager_end) = MemoryTransport::pair();
        let server_task = manager.spawn_worker_task(manager_end);

        let auth_token = self.token.clone();
        let command_handler = self.options.command_handler.clone();
        let client_task: JoinHandle<Result<(), Report>> = tokio::spawn(async move {
            run_worker_client(Box::new(worker_end), auth_token, command_handler).await
        });

        ConnectionHandle {
            client_task: Some(client_task),
            server_task: Some(server_task),
        }
    }

    /// Returns the encoded token this worker authenticates with.
    pub fn token(&self) -> &str {
        self.token.as_str()
    }
}
impl ConnectionHandle {
    /// Aborts the client task and waits for it to finish, then does the same
    /// for the server task. Task results, including cancellation errors, are
    /// discarded.
    pub async fn shutdown(mut self) {
        // Each handle is taken only when its turn comes, so a task not yet
        // reached is still held by `self` if this future is dropped early.
        for slot in [&mut self.client_task, &mut self.server_task] {
            let Some(task) = slot.take() else {
                continue;
            };
            task.abort();
            let _ = task.await;
        }
    }
}
impl Drop for ConnectionHandle {
    /// Aborts whichever tasks are still held, client first, without waiting
    /// for them to finish.
    fn drop(&mut self) {
        let slots = [&mut self.client_task, &mut self.server_task];
        for task in slots.into_iter().filter_map(Option::take) {
            task.abort();
        }
    }
}
/// Drives the worker side of an in-memory connection.
///
/// Sends `Auth` with `token`, then processes incoming messages until the
/// transport yields `None` or an `Announce` arrives after the handshake:
///
/// - `AuthOk` marks the worker as authenticated.
/// - `Hello` is answered with a `Welcome` echoing the same fields and marks
///   the handshake as done.
/// - `WorkerCommandRequest` is passed to `handler` when one is configured and
///   answered with a `WorkerCommandResponse`; a reported duration of `0` is
///   sent as `1`. Without a handler the request is ignored.
/// - Any other message is ignored.
///
/// # Errors
///
/// Returns an error if sending or receiving fails, if `Hello` arrives before
/// `AuthOk`, or if `AuthErr` is received.
async fn run_worker_client(
    mut transport: Box<dyn Transport>,
    token: String,
    handler: Option<Arc<dyn WorkerCommandHandler + Send + Sync>>,
) -> Result<(), Report> {
    let auth = Message::Auth {
        token,
        worker_id: None,
        worker_name: None,
    };
    transport.send(&auth).await?;

    let mut authed = false;
    let mut handshake_done = false;

    while let Some(message) = transport.recv().await? {
        match message {
            Message::AuthOk => authed = true,
            Message::AuthErr => return Err(eyre!("auth rejected")),
            Message::Hello {
                manager_id,
                nonce,
                sig,
            } => {
                if !authed {
                    return Err(eyre!("handshake before auth"));
                }
                let welcome = Message::Welcome {
                    manager_id,
                    nonce,
                    sig,
                };
                transport.send(&welcome).await?;
                handshake_done = true;
            }
            // An announce before the handshake falls through to the
            // catch-all arm and is ignored.
            Message::Announce { .. } if handshake_done => break,
            Message::WorkerCommandRequest {
                request_id,
                command,
                args,
                ..
            } => {
                let Some(handler) = handler.as_ref() else {
                    continue;
                };
                let result = handler.handle_request(&request_id, &command, &args);
                let reply = Message::WorkerCommandResponse {
                    request_id,
                    worker_id: "worker-harness".into(),
                    success: result.success,
                    // A zero duration is reported as one millisecond.
                    duration_millis: result.duration_millis.max(1),
                    output: result.output,
                };
                transport.send(&reply).await?;
            }
            _ => {}
        }
    }

    Ok(())
}