use std::{
collections::BTreeMap,
env,
io::Cursor,
path::{Path, PathBuf},
process::Stdio,
sync::{Arc, RwLock},
time::Instant,
};
use anyhow::{Context, Result, anyhow};
use capnp::message::ReaderOptions;
use capnp::serialize::{OwnedSegments, write_message_to_words};
use capnp_futures::serialize::{read_message, try_read_message};
use futures::AsyncWriteExt;
use serde_json::{self, Value};
use snow::TransportState;
use tokio::{
net::{TcpListener, TcpStream},
process::Command,
task::{JoinHandle, spawn_blocking},
};
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
use tracing::{debug, error, info, warn};
use crate::{
cli_capnp::{cli_request, cli_response},
config::Config,
control_capnp::{
AggregateSortField as CapnpAggregateSortField, control_hello, control_hello_response,
control_request, control_response,
},
error::EventError,
filter::{self, FilterExpr},
observability,
plugin::PluginManager,
replication_noise::{
MAX_NOISE_FRAME_PAYLOAD, perform_server_handshake, read_encrypted_frame,
write_encrypted_frame,
},
schema::AggregateSchema,
service::{
AppendEventInput, CoreContext, CreateAggregateInput, SetAggregateArchiveInput,
normalize_optional_comment,
},
store::{AggregateQueryScope, AggregateSort, AggregateSortField, AggregateState, EventRecord},
};
#[derive(Debug, Clone)]
pub struct CliCommandResult {
pub exit_code: i32,
pub stdout: String,
pub stderr: String,
}
enum ControlReply {
ListAggregates {
aggregates: Vec<AggregateState>,
},
GetAggregate {
found: bool,
aggregate_json: Option<String>,
},
ListEvents {
aggregate_type: String,
aggregate_id: String,
events: Vec<EventRecord>,
},
AppendEvent {
event_json: Option<String>,
},
VerifyAggregate(String),
SelectAggregate {
found: bool,
selection_json: Option<String>,
},
CreateAggregate {
aggregate_json: Option<String>,
},
SetAggregateArchive {
aggregate_json: Option<String>,
},
ListSchemas(String),
ReplaceSchemas(u32),
}
enum ControlCommand {
ListAggregates {
token: String,
skip: usize,
take: Option<usize>,
filter: Option<FilterExpr>,
sort: Option<Vec<AggregateSort>>,
include_archived: bool,
archived_only: bool,
},
GetAggregate {
token: String,
aggregate_type: String,
aggregate_id: String,
},
ListEvents {
token: String,
aggregate_type: String,
aggregate_id: String,
skip: usize,
take: Option<usize>,
},
AppendEvent {
token: String,
aggregate_type: String,
aggregate_id: String,
event_type: String,
payload: Option<Value>,
metadata: Option<Value>,
note: Option<String>,
},
PatchEvent {
token: String,
aggregate_type: String,
aggregate_id: String,
event_type: String,
patch: Value,
metadata: Option<Value>,
note: Option<String>,
},
VerifyAggregate {
aggregate_type: String,
aggregate_id: String,
},
SelectAggregate {
token: String,
aggregate_type: String,
aggregate_id: String,
fields: Vec<String>,
},
CreateAggregate {
token: String,
aggregate_type: String,
aggregate_id: String,
event_type: String,
payload: Value,
metadata: Option<Value>,
note: Option<String>,
},
SetAggregateArchive {
token: String,
aggregate_type: String,
aggregate_id: String,
archived: bool,
comment: Option<String>,
},
ListSchemas {
token: String,
},
ReplaceSchemas {
token: String,
schemas_json: String,
},
}
const CONTROL_PROTOCOL_VERSION: u16 = 1;
const CONTROL_RESPONSE_HEADROOM: usize = 4 * 1024;
enum FirstMessage {
Cli(capnp::message::Reader<OwnedSegments>),
Control {
protocol_version: u16,
token_result: Result<String, String>,
},
}
fn control_error_code(err: &EventError) -> &'static str {
match err {
EventError::InvalidToken => "invalid_token",
EventError::TokenExpired => "token_expired",
EventError::TokenLimitReached => "token_limit_reached",
EventError::Unauthorized => "unauthorized",
EventError::AggregateNotFound => "aggregate_not_found",
EventError::AggregateArchived => "aggregate_archived",
EventError::SchemaExists => "schema_exists",
EventError::SchemaNotFound => "schema_not_found",
EventError::InvalidSchema(_) => "invalid_schema",
EventError::SchemaViolation(_) => "schema_violation",
EventError::Config(_) => "config",
EventError::Storage(_) => "storage",
EventError::Io(_) => "io",
EventError::Serialization(_) => "serialization",
}
}
fn control_command_name(command: &ControlCommand) -> &'static str {
match command {
ControlCommand::ListAggregates { .. } => "list_aggregates",
ControlCommand::GetAggregate { .. } => "get_aggregate",
ControlCommand::ListEvents { .. } => "list_events",
ControlCommand::AppendEvent { .. } => "append_event",
ControlCommand::PatchEvent { .. } => "patch_event",
ControlCommand::VerifyAggregate { .. } => "verify_aggregate",
ControlCommand::SelectAggregate { .. } => "select_aggregate",
ControlCommand::CreateAggregate { .. } => "create_aggregate",
ControlCommand::SetAggregateArchive { .. } => "set_aggregate_archive",
ControlCommand::ListSchemas { .. } => "list_schemas",
ControlCommand::ReplaceSchemas { .. } => "replace_schemas",
}
}
fn read_control_text(
field: capnp::Result<capnp::text::Reader<'_>>,
label: &str,
) -> std::result::Result<String, EventError> {
let reader =
field.map_err(|err| EventError::Serialization(format!("failed to read {label}: {err}")))?;
reader
.to_string()
.map_err(|err| EventError::Serialization(format!("invalid utf-8 in {label}: {err}")))
}
#[cfg_attr(not(test), doc(hidden))]
pub mod test_support {
use super::*;
use tokio::io::{DuplexStream, ReadHalf, WriteHalf, duplex, split};
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
pub fn spawn_control_session(
core: CoreContext,
shared_config: Arc<RwLock<Config>>,
) -> (
Compat<WriteHalf<DuplexStream>>,
Compat<ReadHalf<DuplexStream>>,
tokio::task::JoinHandle<Result<()>>,
) {
let (client, server) = duplex(4096);
let (server_reader_raw, server_writer_raw) = split(server);
let mut server_reader = server_reader_raw.compat();
let mut server_writer = server_writer_raw.compat_write();
let task = tokio::spawn(async move {
let message = read_message(&mut server_reader, ReaderOptions::new())
.await
.context("failed to read control hello")?;
let (protocol_version, token_result) = {
let hello = message
.get_root::<control_hello::Reader>()
.context("failed to decode control hello")?;
let protocol_version = hello.get_protocol_version();
let token_result = hello
.get_token()
.map_err(|err| format!("failed to read control token: {err}"))
.and_then(|reader| {
reader
.to_str()
.map_err(|err| format!("invalid UTF-8 in control token: {err}"))
.map(|value| value.trim().to_string())
});
(protocol_version, token_result)
};
drop(message);
handle_control_handshake(
protocol_version,
token_result,
&mut server_reader,
&mut server_writer,
core,
shared_config,
)
.await
});
let (client_reader_raw, client_writer_raw) = split(client);
let client_writer = client_writer_raw.compat_write();
let client_reader = client_reader_raw.compat();
(client_writer, client_reader, task)
}
}
pub async fn start(
bind_addr: &str,
config_path: Arc<PathBuf>,
core: CoreContext,
shared_config: Arc<RwLock<Config>>,
) -> Result<JoinHandle<()>> {
let listener = TcpListener::bind(bind_addr)
.await
.with_context(|| format!("failed to bind CLI Cap'n Proto listener on {bind_addr}"))?;
let display_addr = listener
.local_addr()
.map(|addr| addr.to_string())
.unwrap_or_else(|_| bind_addr.to_string());
info!("CLI Cap'n Proto server listening on {}", display_addr);
let handle = tokio::spawn({
let core = core.clone();
let shared_config = Arc::clone(&shared_config);
let config_path = Arc::clone(&config_path);
async move {
if let Err(err) = serve(listener, config_path, core, shared_config).await {
warn!("CLI proxy server terminated: {err:?}");
}
}
});
Ok(handle)
}
async fn serve(
listener: TcpListener,
config_path: Arc<PathBuf>,
core: CoreContext,
shared_config: Arc<RwLock<Config>>,
) -> Result<()> {
loop {
let (stream, peer) = listener
.accept()
.await
.context("failed to accept CLI proxy connection")?;
let config_path = Arc::clone(&config_path);
let core = core.clone();
let shared_config = Arc::clone(&shared_config);
tokio::spawn(async move {
if let Err(err) = handle_connection(stream, config_path, core, shared_config).await {
warn!(target: "cli_proxy", peer = %peer, "CLI proxy connection error: {err:?}");
}
});
}
}
async fn handle_connection(
stream: TcpStream,
config_path: Arc<PathBuf>,
core: CoreContext,
shared_config: Arc<RwLock<Config>>,
) -> Result<()> {
let (reader, writer) = stream.into_split();
let mut reader = reader.compat();
let mut writer = writer.compat_write();
let first_message = match try_read_message(&mut reader, ReaderOptions::new()).await {
Ok(Some(message)) => message,
Ok(None) => return Ok(()),
Err(err) => {
return Err(anyhow::Error::new(err).context("failed to read initial message"));
}
};
match classify_first_message(first_message)? {
FirstMessage::Cli(message) => {
handle_cli_loop(
Some(message),
&mut reader,
&mut writer,
Arc::clone(&config_path),
)
.await?;
Ok(())
}
FirstMessage::Control {
protocol_version,
token_result,
} => {
handle_control_handshake(
protocol_version,
token_result,
&mut reader,
&mut writer,
core,
shared_config,
)
.await
}
}
}
async fn handle_cli_loop<R, W>(
mut pending: Option<capnp::message::Reader<OwnedSegments>>,
reader: &mut R,
writer: &mut W,
config_path: Arc<PathBuf>,
) -> Result<()>
where
R: futures::AsyncRead + Unpin,
W: futures::AsyncWrite + Unpin,
{
loop {
let message = if let Some(message) = pending.take() {
message
} else {
match try_read_message(&mut *reader, ReaderOptions::new()).await {
Ok(Some(message)) => message,
Ok(None) => break,
Err(err) => {
return Err(anyhow::Error::new(err).context("failed to read CLI request"));
}
}
};
let response_result = process_request(message, Arc::clone(&config_path)).await;
let response_bytes = {
let mut response_message = capnp::message::Builder::new_default();
{
let mut response = response_message.init_root::<cli_response::Builder>();
match response_result {
Ok(result) => {
response.set_exit_code(result.exit_code);
response.set_stdout(&result.stdout);
response.set_stderr(&result.stderr);
}
Err(err) => {
response.set_exit_code(-1);
response.set_stdout("");
response.set_stderr(&err.to_string());
}
}
}
write_message_to_words(&response_message)
};
writer
.write_all(&response_bytes)
.await
.context("failed to write CLI response")?;
writer
.flush()
.await
.context("failed to flush CLI response")?;
}
Ok(())
}
async fn handle_control_handshake<R, W>(
protocol_version: u16,
token_result: Result<String, String>,
reader: &mut R,
writer: &mut W,
core: CoreContext,
shared_config: Arc<RwLock<Config>>,
) -> Result<()>
where
R: futures::AsyncRead + Unpin,
W: futures::AsyncWrite + Unpin,
{
let handshake_start = Instant::now();
let tokens = core.tokens();
let (accepted, response_text, claims, token) = if protocol_version != CONTROL_PROTOCOL_VERSION {
(
false,
format!("unsupported control protocol version {}", protocol_version),
None,
None,
)
} else {
match token_result {
Err(message) => (false, message, None, None),
Ok(token) if token.is_empty() => {
(false, "missing control token".to_string(), None, None)
}
Ok(token) => match tokens.verify(&token) {
Ok(claims) => (true, "ok".to_string(), Some(claims), Some(token)),
Err(err) => {
warn!("control handshake rejected due to invalid token: {}", err);
(false, "invalid control token".to_string(), None, None)
}
},
}
};
let response_bytes = {
let mut message = capnp::message::Builder::new_default();
{
let mut response = message.init_root::<control_hello_response::Builder>();
response.set_accepted(accepted);
response.set_message(&response_text);
}
write_message_to_words(&message)
};
writer
.write_all(&response_bytes)
.await
.context("failed to write control hello response")?;
writer
.flush()
.await
.context("failed to flush control hello response")?;
let duration = handshake_start.elapsed().as_secs_f64();
observability::record_capnp_control_request(
"hello",
if accepted { "accepted" } else { "rejected" },
duration,
);
if !accepted {
println!("control handshake rejected: {}", response_text);
return Ok(());
}
if let Some(claims) = &claims {
debug!(
subject = %claims.sub,
token_group = %claims.group,
token_user = %claims.user,
"control handshake accepted"
);
}
let token = token.expect("token must be present when handshake accepted");
println!(
"control handshake accepted for token subject {}",
claims.as_ref().map(|c| c.sub.as_str()).unwrap_or("unknown")
);
let noise = perform_server_handshake(reader, writer, token.as_bytes())
.await
.context("failed to establish encrypted control channel")?;
handle_control_session_encrypted(reader, writer, noise, core, shared_config).await
}
fn classify_first_message(message: capnp::message::Reader<OwnedSegments>) -> Result<FirstMessage> {
if message
.get_root::<cli_request::Reader>()
.and_then(|request| request.get_args())
.is_ok()
{
return Ok(FirstMessage::Cli(message));
}
if let Ok(hello) = message.get_root::<control_hello::Reader>() {
let token_result = hello
.get_token()
.map_err(|err| format!("failed to read control token: {err}"))
.and_then(|reader| {
reader
.to_str()
.map_err(|err| format!("invalid UTF-8 in control token: {err}"))
.map(|value| value.trim().to_string())
});
return Ok(FirstMessage::Control {
protocol_version: hello.get_protocol_version(),
token_result,
});
}
Err(anyhow!("unexpected first message"))
}
async fn handle_control_session_encrypted<R, W>(
reader: &mut R,
writer: &mut W,
mut noise: TransportState,
core: CoreContext,
shared_config: Arc<RwLock<Config>>,
) -> Result<()>
where
R: futures::AsyncRead + Unpin,
W: futures::AsyncWrite + Unpin,
{
loop {
let request_bytes = match read_encrypted_frame(reader, &mut noise).await? {
Some(bytes) => bytes,
None => break,
};
let mut cursor = Cursor::new(&request_bytes);
let message = capnp::serialize::read_message(&mut cursor, ReaderOptions::new())
.context("failed to decode control request")?;
let request = message
.get_root::<control_request::Reader>()
.context("failed to decode control request")?;
let request_id = request.get_id();
let command_start = Instant::now();
let (command_label, status_label, response_result) = match parse_control_command(request) {
Ok(command) => {
let label = control_command_name(&command);
let result =
execute_control_command(command, core.clone(), Arc::clone(&shared_config))
.await;
let status = match &result {
Ok(_) => "ok",
Err(err) => control_error_code(err),
};
(label, status, result)
}
Err(err) => {
let status = control_error_code(&err);
("parse", status, Err(err))
}
};
let duration = command_start.elapsed().as_secs_f64();
observability::record_capnp_control_request(command_label, status_label, duration);
let response_bytes = {
let mut response_message = capnp::message::Builder::new_default();
{
let mut response = response_message.init_root::<control_response::Builder>();
response.set_id(request_id);
match response_result {
Ok(reply) => {
let payload = response.reborrow().init_payload();
match reply {
ControlReply::ListAggregates { aggregates } => {
let encoded = encode_json_list_response(
request_id,
&aggregates,
|payload, json| {
let mut builder = payload.init_list_aggregates();
builder.set_aggregates_json(json);
},
)?;
if encoded.included < aggregates.len() {
debug!(
target: "cli_proxy",
requested = aggregates.len(),
included = encoded.included,
"truncated list_aggregates response to satisfy Noise frame limit"
);
}
let mut builder = payload.init_list_aggregates();
builder.set_aggregates_json(&encoded.json);
}
ControlReply::GetAggregate {
found,
aggregate_json,
} => {
let mut builder = payload.init_get_aggregate();
builder.set_found(found);
if let Some(json) = aggregate_json {
builder.set_aggregate_json(&json);
} else {
builder.set_aggregate_json("");
}
}
ControlReply::ListEvents {
aggregate_type,
aggregate_id,
events,
} => {
let encoded = encode_json_list_response(
request_id,
&events,
|payload, json| {
let mut builder = payload.init_list_events();
builder.set_events_json(json);
},
)?;
if encoded.included < events.len() {
debug!(
target: "cli_proxy",
aggregate_type = %aggregate_type,
aggregate_id = %aggregate_id,
requested = events.len(),
included = encoded.included,
"truncated list_events response to satisfy Noise frame limit"
);
}
let mut builder = payload.init_list_events();
builder.set_events_json(&encoded.json);
}
ControlReply::AppendEvent { event_json } => {
let mut builder = payload.init_append_event();
if let Some(json) = event_json {
builder.set_event_json(&json);
} else {
builder.set_event_json("");
}
}
ControlReply::VerifyAggregate(merkle_root) => {
let mut builder = payload.init_verify_aggregate();
builder.set_merkle_root(&merkle_root);
}
ControlReply::SelectAggregate {
found,
selection_json,
} => {
let mut builder = payload.init_select_aggregate();
builder.set_found(found);
if let Some(json) = selection_json {
builder.set_selection_json(&json);
} else {
builder.set_selection_json("");
}
}
ControlReply::CreateAggregate { aggregate_json } => {
let mut builder = payload.init_create_aggregate();
if let Some(json) = aggregate_json {
builder.set_aggregate_json(&json);
} else {
builder.set_aggregate_json("");
}
}
ControlReply::SetAggregateArchive { aggregate_json } => {
let mut builder = payload.init_set_aggregate_archive();
if let Some(json) = aggregate_json {
builder.set_aggregate_json(&json);
} else {
builder.set_aggregate_json("");
}
}
ControlReply::ListSchemas(json) => {
let mut builder = payload.init_list_schemas();
builder.set_schemas_json(&json);
}
ControlReply::ReplaceSchemas(replaced) => {
let mut builder = payload.init_replace_schemas();
builder.set_replaced(replaced);
}
}
}
Err(err) => {
let payload = response.reborrow().init_payload();
let mut builder = payload.init_error();
builder.set_code(control_error_code(&err));
builder.set_message(&err.to_string());
}
}
}
write_message_to_words(&response_message)
};
write_encrypted_frame(writer, &mut noise, &response_bytes)
.await
.context("failed to send control response")?;
}
Ok(())
}
fn parse_control_command(
request: control_request::Reader<'_>,
) -> std::result::Result<ControlCommand, EventError> {
use control_request::payload;
match request
.get_payload()
.which()
.map_err(|err| EventError::Serialization(err.to_string()))?
{
payload::ListAggregates(req) => {
let req = req.map_err(|err| EventError::Serialization(err.to_string()))?;
let skip = usize::try_from(req.get_skip())
.map_err(|_| EventError::InvalidSchema("skip exceeds platform limits".into()))?;
let take = if req.get_has_take() {
Some(usize::try_from(req.get_take()).map_err(|_| {
EventError::InvalidSchema("take exceeds platform limits".into())
})?)
} else {
None
};
let filter = if req.get_has_filter() {
let raw = read_control_text(req.get_filter(), "filter")?;
let trimmed = raw.trim();
if trimmed.is_empty() {
None
} else {
Some(filter::parse_shorthand(trimmed).map_err(|err| {
EventError::InvalidSchema(format!("invalid filter expression: {err}"))
})?)
}
} else {
None
};
let sort = if req.get_has_sort() {
let list = req
.get_sort()
.map_err(|err| EventError::Serialization(err.to_string()))?;
let mut directives = Vec::with_capacity(list.len() as usize);
for entry in list.iter() {
let field = match entry.get_field().map_err(|err| {
EventError::InvalidSchema(format!(
"unknown aggregate sort field value: {err}"
))
})? {
CapnpAggregateSortField::AggregateType => AggregateSortField::AggregateType,
CapnpAggregateSortField::AggregateId => AggregateSortField::AggregateId,
CapnpAggregateSortField::Version => AggregateSortField::Version,
CapnpAggregateSortField::MerkleRoot => AggregateSortField::MerkleRoot,
CapnpAggregateSortField::Archived => AggregateSortField::Archived,
};
directives.push(AggregateSort {
field,
descending: entry.get_descending(),
});
}
if directives.is_empty() {
None
} else {
Some(directives)
}
} else {
None
};
let include_archived = req.get_include_archived();
let archived_only = req.get_archived_only();
let token = read_control_text(req.get_token(), "token")?
.trim()
.to_string();
Ok(ControlCommand::ListAggregates {
token,
skip,
take,
filter,
sort,
include_archived,
archived_only,
})
}
payload::GetAggregate(req) => {
let req = req.map_err(|err| EventError::Serialization(err.to_string()))?;
let aggregate_type = read_control_text(req.get_aggregate_type(), "aggregate_type")?;
let aggregate_id = read_control_text(req.get_aggregate_id(), "aggregate_id")?;
let token = read_control_text(req.get_token(), "token")?
.trim()
.to_string();
Ok(ControlCommand::GetAggregate {
token,
aggregate_type,
aggregate_id,
})
}
payload::ListEvents(req) => {
let req = req.map_err(|err| EventError::Serialization(err.to_string()))?;
let aggregate_type = read_control_text(req.get_aggregate_type(), "aggregate_type")?;
let aggregate_id = read_control_text(req.get_aggregate_id(), "aggregate_id")?;
let skip = usize::try_from(req.get_skip())
.map_err(|_| EventError::InvalidSchema("skip exceeds platform limits".into()))?;
let take = if req.get_has_take() {
Some(usize::try_from(req.get_take()).map_err(|_| {
EventError::InvalidSchema("take exceeds platform limits".into())
})?)
} else {
None
};
Ok(ControlCommand::ListEvents {
token: read_control_text(req.get_token(), "token")?
.trim()
.to_string(),
aggregate_type,
aggregate_id,
skip,
take,
})
}
payload::AppendEvent(req) => {
let req = req.map_err(|err| EventError::Serialization(err.to_string()))?;
let token = read_control_text(req.get_token(), "token")?
.trim()
.to_string();
let aggregate_type = read_control_text(req.get_aggregate_type(), "aggregate_type")?;
let aggregate_id = read_control_text(req.get_aggregate_id(), "aggregate_id")?;
let event_type = read_control_text(req.get_event_type(), "event_type")?;
let payload_raw = read_control_text(req.get_payload_json(), "payload_json")?;
let payload_trimmed = payload_raw.trim();
let payload = if payload_trimmed.is_empty() {
None
} else {
Some(
serde_json::from_str::<Value>(payload_trimmed).map_err(|err| {
EventError::InvalidSchema(format!("invalid payload_json: {err}"))
})?,
)
};
let metadata = if req.get_has_metadata() {
let metadata_raw = read_control_text(req.get_metadata_json(), "metadata_json")?;
let metadata_trimmed = metadata_raw.trim();
if metadata_trimmed.is_empty() {
None
} else {
Some(
serde_json::from_str::<Value>(metadata_trimmed).map_err(|err| {
EventError::InvalidSchema(format!("invalid metadata_json: {err}"))
})?,
)
}
} else {
None
};
let note = if req.get_has_note() {
Some(read_control_text(req.get_note(), "note")?)
} else {
None
};
Ok(ControlCommand::AppendEvent {
token,
aggregate_type,
aggregate_id,
event_type,
payload,
metadata,
note,
})
}
payload::PatchEvent(req) => {
let req = req.map_err(|err| EventError::Serialization(err.to_string()))?;
let token = read_control_text(req.get_token(), "token")?
.trim()
.to_string();
let aggregate_type = read_control_text(req.get_aggregate_type(), "aggregate_type")?;
let aggregate_id = read_control_text(req.get_aggregate_id(), "aggregate_id")?;
let event_type = read_control_text(req.get_event_type(), "event_type")?;
let patch_raw = read_control_text(req.get_patch_json(), "patch_json")?;
let patch_trimmed = patch_raw.trim();
let patch = if patch_trimmed.is_empty() {
return Err(EventError::InvalidSchema(
"patch_json must be provided for patch events".into(),
));
} else {
serde_json::from_str::<Value>(patch_trimmed).map_err(|err| {
EventError::InvalidSchema(format!("invalid patch_json: {err}"))
})?
};
let metadata = if req.get_has_metadata() {
let metadata_raw = read_control_text(req.get_metadata_json(), "metadata_json")?;
let metadata_trimmed = metadata_raw.trim();
if metadata_trimmed.is_empty() {
None
} else {
Some(
serde_json::from_str::<Value>(metadata_trimmed).map_err(|err| {
EventError::InvalidSchema(format!("invalid metadata_json: {err}"))
})?,
)
}
} else {
None
};
let note = if req.get_has_note() {
Some(read_control_text(req.get_note(), "note")?)
} else {
None
};
Ok(ControlCommand::PatchEvent {
token,
aggregate_type,
aggregate_id,
event_type,
patch,
metadata,
note,
})
}
payload::CreateAggregate(req) => {
let req = req.map_err(|err| EventError::Serialization(err.to_string()))?;
let token = read_control_text(req.get_token(), "token")?
.trim()
.to_string();
let aggregate_type = read_control_text(req.get_aggregate_type(), "aggregate_type")?;
let aggregate_id = read_control_text(req.get_aggregate_id(), "aggregate_id")?;
let event_type = read_control_text(req.get_event_type(), "event_type")?;
let payload_raw = read_control_text(req.get_payload_json(), "payload_json")?;
let payload_trimmed = payload_raw.trim();
if payload_trimmed.is_empty() {
return Err(EventError::InvalidSchema(
"payload_json must be provided when creating an aggregate".into(),
));
}
let payload = serde_json::from_str::<Value>(payload_trimmed)
.map_err(|err| EventError::InvalidSchema(format!("invalid payload_json: {err}")))?;
let metadata = if req.get_has_metadata() {
let metadata_raw = read_control_text(req.get_metadata_json(), "metadata_json")?;
let metadata_trimmed = metadata_raw.trim();
if metadata_trimmed.is_empty() {
None
} else {
Some(
serde_json::from_str::<Value>(metadata_trimmed).map_err(|err| {
EventError::InvalidSchema(format!("invalid metadata_json: {err}"))
})?,
)
}
} else {
None
};
let note = if req.get_has_note() {
Some(read_control_text(req.get_note(), "note")?)
} else {
None
};
Ok(ControlCommand::CreateAggregate {
token,
aggregate_type,
aggregate_id,
event_type,
payload,
metadata,
note,
})
}
payload::VerifyAggregate(req) => {
let req = req.map_err(|err| EventError::Serialization(err.to_string()))?;
let aggregate_type = read_control_text(req.get_aggregate_type(), "aggregate_type")?;
let aggregate_id = read_control_text(req.get_aggregate_id(), "aggregate_id")?;
Ok(ControlCommand::VerifyAggregate {
aggregate_type,
aggregate_id,
})
}
payload::SelectAggregate(req) => {
let req = req.map_err(|err| EventError::Serialization(err.to_string()))?;
let aggregate_type = read_control_text(req.get_aggregate_type(), "aggregate_type")?;
let aggregate_id = read_control_text(req.get_aggregate_id(), "aggregate_id")?;
let fields_reader = req.get_fields().map_err(|err| {
EventError::Serialization(format!("failed to read fields: {err}"))
})?;
let mut fields = Vec::new();
for field in fields_reader.iter() {
let value = field.map_err(|err| {
EventError::Serialization(format!("invalid utf-8 in fields entry: {err}"))
})?;
let field_str = value.to_str().map_err(|err| {
EventError::Serialization(format!("invalid utf-8 in fields entry: {err}"))
})?;
fields.push(field_str.to_string());
}
let token = read_control_text(req.get_token(), "token")?
.trim()
.to_string();
Ok(ControlCommand::SelectAggregate {
token,
aggregate_type,
aggregate_id,
fields,
})
}
payload::SetAggregateArchive(req) => {
let req = req.map_err(|err| EventError::Serialization(err.to_string()))?;
let token = read_control_text(req.get_token(), "token")?
.trim()
.to_string();
let aggregate_type = read_control_text(req.get_aggregate_type(), "aggregate_type")?;
let aggregate_id = read_control_text(req.get_aggregate_id(), "aggregate_id")?;
let archived = req.get_archived();
let comment = normalize_optional_comment(if req.get_has_comment() {
Some(read_control_text(req.get_comment(), "comment")?)
} else {
None
});
Ok(ControlCommand::SetAggregateArchive {
token,
aggregate_type,
aggregate_id,
archived,
comment,
})
}
payload::ListSchemas(req) => {
let req = req.map_err(|err| EventError::Serialization(err.to_string()))?;
let token = read_control_text(req.get_token(), "token")?
.trim()
.to_string();
Ok(ControlCommand::ListSchemas { token })
}
payload::ReplaceSchemas(req) => {
let req = req.map_err(|err| EventError::Serialization(err.to_string()))?;
let token = read_control_text(req.get_token(), "token")?
.trim()
.to_string();
let schemas_json = read_control_text(req.get_schemas_json(), "schemas_json")?;
Ok(ControlCommand::ReplaceSchemas {
token,
schemas_json,
})
}
}
}
async fn execute_control_command(
command: ControlCommand,
core: CoreContext,
shared_config: Arc<RwLock<Config>>,
) -> std::result::Result<ControlReply, EventError> {
match command {
ControlCommand::ListAggregates {
token,
skip,
take,
filter,
sort,
include_archived,
archived_only,
} => {
let mut scope = if archived_only {
AggregateQueryScope::ArchivedOnly
} else if include_archived {
AggregateQueryScope::IncludeArchived
} else {
AggregateQueryScope::ActiveOnly
};
if matches!(scope, AggregateQueryScope::ActiveOnly) {
if let Some(expr) = filter.as_ref() {
if expr.references_field("archived") {
scope = AggregateQueryScope::IncludeArchived;
}
}
}
let aggregates = spawn_blocking({
let core = core.clone();
let filter = filter;
let sort = sort;
let token = token.clone();
move || {
let sort_ref = sort.as_ref().map(|keys| keys.as_slice());
core.list_aggregates(&token, skip, take, filter, sort_ref, scope)
}
})
.await
.map_err(|err| EventError::Storage(format!("list aggregates task failed: {err}")))?;
let aggregates = aggregates?;
Ok(ControlReply::ListAggregates { aggregates })
}
ControlCommand::GetAggregate {
token,
aggregate_type,
aggregate_id,
} => {
let aggregate = spawn_blocking({
let core = core.clone();
let aggregate_type = aggregate_type.clone();
let aggregate_id = aggregate_id.clone();
let token = token.clone();
move || core.get_aggregate(&token, &aggregate_type, &aggregate_id)
})
.await
.map_err(|err| EventError::Storage(format!("get aggregate task failed: {err}")))??;
let json = aggregate
.as_ref()
.map(|aggregate| serde_json::to_string(aggregate))
.transpose()?;
Ok(ControlReply::GetAggregate {
found: aggregate.is_some(),
aggregate_json: json,
})
}
ControlCommand::ListEvents {
token,
aggregate_type,
aggregate_id,
skip,
take,
} => {
let events = spawn_blocking({
let core = core.clone();
let aggregate_type = aggregate_type.clone();
let aggregate_id = aggregate_id.clone();
let token = token.clone();
move || core.list_events(&token, &aggregate_type, &aggregate_id, skip, take)
})
.await
.map_err(|err| EventError::Storage(format!("list events task failed: {err}")))??;
Ok(ControlReply::ListEvents {
aggregate_type,
aggregate_id,
events,
})
}
ControlCommand::AppendEvent {
token,
aggregate_type,
aggregate_id,
event_type,
payload,
metadata,
note,
} => {
let input = AppendEventInput {
token,
aggregate_type,
aggregate_id,
event_type,
payload,
patch: None,
metadata,
note,
};
let verbose = config_verbose(&shared_config)?;
handle_append_event_command(
core.clone(),
Arc::clone(&shared_config),
input,
"append event",
verbose,
)
.await
}
ControlCommand::PatchEvent {
token,
aggregate_type,
aggregate_id,
event_type,
patch,
metadata,
note,
} => {
let input = AppendEventInput {
token,
aggregate_type,
aggregate_id,
event_type,
payload: None,
patch: Some(patch),
metadata,
note,
};
let verbose = config_verbose(&shared_config)?;
handle_append_event_command(
core.clone(),
Arc::clone(&shared_config),
input,
"append patch event",
verbose,
)
.await
}
ControlCommand::VerifyAggregate {
aggregate_type,
aggregate_id,
} => {
let merkle = spawn_blocking({
let core = core.clone();
let aggregate_type = aggregate_type.clone();
let aggregate_id = aggregate_id.clone();
move || core.verify_aggregate(&aggregate_type, &aggregate_id)
})
.await
.map_err(|err| EventError::Storage(format!("verify aggregate task failed: {err}")))??;
Ok(ControlReply::VerifyAggregate(merkle))
}
ControlCommand::SelectAggregate {
token,
aggregate_type,
aggregate_id,
fields,
} => {
let selection = spawn_blocking({
let core = core.clone();
let aggregate_type = aggregate_type.clone();
let aggregate_id = aggregate_id.clone();
let fields = fields.clone();
let token = token.clone();
move || {
core.select_aggregate_fields(&token, &aggregate_type, &aggregate_id, &fields)
}
})
.await
.map_err(|err| EventError::Storage(format!("select aggregate task failed: {err}")))??;
match selection {
Some(map) => {
let json = serde_json::to_string(&map)?;
Ok(ControlReply::SelectAggregate {
found: true,
selection_json: Some(json),
})
}
None => Ok(ControlReply::SelectAggregate {
found: false,
selection_json: None,
}),
}
}
ControlCommand::CreateAggregate {
token,
aggregate_type,
aggregate_id,
event_type,
payload,
metadata,
note,
} => {
let verbose = config_verbose(&shared_config)?;
let aggregate = spawn_blocking({
let core = core.clone();
let aggregate_type = aggregate_type.clone();
let aggregate_id = aggregate_id.clone();
let token = token.clone();
let event_type = event_type.clone();
let payload = payload.clone();
let metadata = metadata.clone();
let note = note.clone();
move || {
core.create_aggregate(CreateAggregateInput {
token,
aggregate_type,
aggregate_id,
event_type,
payload,
metadata,
note,
})
}
})
.await
.map_err(|err| EventError::Storage(format!("create aggregate task failed: {err}")))??;
let json = if verbose {
Some(serde_json::to_string(&aggregate)?)
} else {
None
};
Ok(ControlReply::CreateAggregate {
aggregate_json: json,
})
}
ControlCommand::SetAggregateArchive {
token,
aggregate_type,
aggregate_id,
archived,
comment,
} => {
let verbose = config_verbose(&shared_config)?;
let aggregate = spawn_blocking({
let core = core.clone();
let aggregate_type = aggregate_type.clone();
let aggregate_id = aggregate_id.clone();
let token = token.clone();
let comment = comment.clone();
move || {
core.set_aggregate_archive(SetAggregateArchiveInput {
token,
aggregate_type,
aggregate_id,
archived,
comment,
})
}
})
.await
.map_err(|err| {
EventError::Storage(format!("set aggregate archive task failed: {err}"))
})??;
let json = if verbose {
Some(serde_json::to_string(&aggregate)?)
} else {
None
};
Ok(ControlReply::SetAggregateArchive {
aggregate_json: json,
})
}
ControlCommand::ListSchemas { token } => {
let tokens = core.tokens();
let schemas = core.schemas();
let token_clone = token.clone();
let json = spawn_blocking(move || {
tokens
.authorize_action(&token_clone, "schema.read", None)
.and_then(|_| {
let snapshot = schemas.snapshot();
serde_json::to_string(&snapshot)
.map_err(|err| EventError::Serialization(err.to_string()))
})
})
.await
.map_err(|err| EventError::Storage(format!("list schemas task failed: {err}")))??;
Ok(ControlReply::ListSchemas(json))
}
ControlCommand::ReplaceSchemas {
token,
schemas_json,
} => {
let tokens = core.tokens();
let schemas = core.schemas();
let token_clone = token.clone();
let json = schemas_json.clone();
let replaced = spawn_blocking(move || -> std::result::Result<u32, EventError> {
tokens.authorize_action(&token_clone, "schema.write", None)?;
let map: BTreeMap<String, AggregateSchema> = serde_json::from_str(&json)
.map_err(|err| EventError::Serialization(err.to_string()))?;
let count = map.len() as u32;
schemas.replace_all(map)?;
Ok(count)
})
.await
.map_err(|err| EventError::Storage(format!("replace schemas task failed: {err}")))??;
Ok(ControlReply::ReplaceSchemas(replaced))
}
}
}
struct EncodedJsonList {
json: String,
included: usize,
}
fn encode_json_list_response<T, F>(
request_id: u64,
items: &[T],
mut set_payload: F,
) -> std::result::Result<EncodedJsonList, EventError>
where
T: serde::Serialize,
F: FnMut(control_response::payload::Builder<'_>, &str),
{
let max_payload_len = MAX_NOISE_FRAME_PAYLOAD.saturating_sub(CONTROL_RESPONSE_HEADROOM);
if max_payload_len == 0 {
return Err(EventError::Serialization(
"control frame limit too small to encode responses".to_string(),
));
}
if items.is_empty() {
let empty_json = "[]".to_string();
ensure_response_size_with_payload(
request_id,
&empty_json,
max_payload_len,
&mut set_payload,
)?;
return Ok(EncodedJsonList {
json: empty_json,
included: 0,
});
}
let mut low = 1usize;
let mut high = items.len();
let mut best: Option<(usize, String)> = None;
while low <= high {
let mid = (low + high) / 2;
let json = serde_json::to_string(&items[..mid])?;
match ensure_response_size_with_payload(
request_id,
&json,
max_payload_len,
&mut set_payload,
) {
Ok(()) => {
best = Some((mid, json));
low = mid + 1;
}
Err(err) => {
if mid == 1 {
return Err(err);
}
high = mid - 1;
}
}
}
match best {
Some((included, json)) => Ok(EncodedJsonList { json, included }),
None => Ok(EncodedJsonList {
json: "[]".to_string(),
included: 0,
}),
}
}
fn ensure_response_size_with_payload<F>(
request_id: u64,
json: &str,
max_payload_len: usize,
set_payload: &mut F,
) -> std::result::Result<(), EventError>
where
F: FnMut(control_response::payload::Builder<'_>, &str),
{
let mut message = capnp::message::Builder::new_default();
{
let mut response = message.init_root::<control_response::Builder>();
response.set_id(request_id);
let payload = response.reborrow().init_payload();
set_payload(payload, json);
}
let words = write_message_to_words(&message);
let byte_len = words.len() * 8;
if byte_len > max_payload_len {
return Err(EventError::Serialization(format!(
"control response exceeds frame limit ({} bytes)",
MAX_NOISE_FRAME_PAYLOAD
)));
}
Ok(())
}
async fn handle_append_event_command(
core: CoreContext,
shared_config: Arc<RwLock<Config>>,
input: AppendEventInput,
task_label: &'static str,
verbose: bool,
) -> std::result::Result<ControlReply, EventError> {
let record = spawn_blocking({
let core = core.clone();
move || core.append_event(input)
})
.await
.map_err(|err| EventError::Storage(format!("{task_label} task failed: {err}")))??;
let record_json = if verbose {
Some(serde_json::to_string(&record)?)
} else {
None
};
let schemas = core.schemas();
if schemas.should_snapshot(&record.aggregate_type, record.version) {
let aggregate_type = record.aggregate_type.clone();
let aggregate_id = record.aggregate_id.clone();
let version = record.version;
let snapshot_result = spawn_blocking({
let store = core.store();
let aggregate_type_key = aggregate_type.clone();
let aggregate_id_key = aggregate_id.clone();
move || {
store.create_snapshot(
&aggregate_type_key,
&aggregate_id_key,
Some(format!("auto snapshot v{}", version)),
)
}
})
.await
.map_err(|err| EventError::Storage(format!("failed to create auto snapshot: {err}")))?;
if let Err(err) = snapshot_result {
warn!(
target: "server",
"failed to create auto snapshot for {}::{} v{}: {}",
aggregate_type, aggregate_id, version, err
);
}
}
let plugins = {
let guard = shared_config
.read()
.map_err(|_| EventError::Storage("failed to acquire config lock".into()))?;
PluginManager::from_config(&*guard)?
};
if !plugins.is_empty() {
let schemas = core.schemas();
let schema = schemas.get(&record.aggregate_type).ok();
let aggregate_type = record.aggregate_type.clone();
let aggregate_id = record.aggregate_id.clone();
let state_result = spawn_blocking({
let store = core.store();
let aggregate_type_key = aggregate_type.clone();
let aggregate_id_key = aggregate_id.clone();
move || store.get_aggregate_state(&aggregate_type_key, &aggregate_id_key)
})
.await
.map_err(|err| EventError::Storage(format!("failed to load aggregate state: {err}")))?;
match state_result {
Ok(current_state) => {
if let Err(err) = plugins.notify_event(&record, ¤t_state, schema.as_ref()) {
error!("plugin notification failed: {}", err);
}
}
Err(err) => {
warn!(
"plugin notification skipped (failed to load state for {}::{}): {}",
aggregate_type, aggregate_id, err
);
}
}
}
Ok(ControlReply::AppendEvent {
event_json: record_json,
})
}
fn config_verbose(shared_config: &Arc<RwLock<Config>>) -> std::result::Result<bool, EventError> {
let guard = shared_config
.read()
.map_err(|_| EventError::Storage("failed to acquire config lock".into()))?;
Ok(guard.verbose_responses())
}
async fn process_request(
message: capnp::message::Reader<OwnedSegments>,
config_path: Arc<PathBuf>,
) -> Result<CliCommandResult> {
let request = message
.get_root::<cli_request::Reader>()
.context("failed to decode CLI request")?;
let args = {
let args_reader = request
.get_args()
.context("failed to read CLI request arguments")?;
let mut collected = Vec::with_capacity(args_reader.len() as usize);
for arg in args_reader.iter() {
let value = arg.context("failed to read CLI argument")?;
collected.push(value.to_string()?);
}
collected
};
execute_cli_command(args, &config_path).await
}
async fn execute_cli_command(args: Vec<String>, config_path: &PathBuf) -> Result<CliCommandResult> {
let exe = resolve_cli_executable().context("failed to resolve CLI executable")?;
let command_label = args.first().cloned().unwrap_or_else(|| "help".to_string());
let final_args = augment_args_with_config(args, config_path);
let mut command = Command::new(exe);
command.args(&final_args);
command.stdout(Stdio::piped());
command.stderr(Stdio::piped());
let start = Instant::now();
let output = match command.output().await {
Ok(output) => output,
Err(err) => {
let duration = start.elapsed().as_secs_f64();
observability::record_cli_proxy_command(&command_label, "spawn_error", None, duration);
return Err(anyhow!("failed to execute CLI command: {err}"));
}
};
let exit_code = output
.status
.code()
.unwrap_or_else(|| if output.status.success() { 0 } else { -1 });
let status_label = if output.status.success() {
"ok"
} else {
"exit"
};
let duration = start.elapsed().as_secs_f64();
observability::record_cli_proxy_command(
&command_label,
status_label,
Some(exit_code),
duration,
);
let stdout = String::from_utf8_lossy(&output.stdout).into_owned();
let stderr = String::from_utf8_lossy(&output.stderr).into_owned();
debug!(
target: "cli_proxy",
args = ?final_args,
exit_code,
"CLI command executed"
);
Ok(CliCommandResult {
exit_code,
stdout,
stderr,
})
}
fn resolve_cli_executable() -> Result<PathBuf> {
if let Ok(path) = env::var("EVENTDBX_CLI") {
return Ok(PathBuf::from(path));
}
if let Ok(path) = env::var("CARGO_BIN_EXE_eventdbx") {
return Ok(PathBuf::from(path));
}
if let Ok(path) = env::var("CARGO_BIN_EXE_dbx") {
return Ok(PathBuf::from(path));
}
let current = std::env::current_exe().context("failed to resolve current executable")?;
if let Some(dir) = current.parent() {
if let Some(candidate) = probe_dir_for_cli(dir) {
return Ok(candidate);
}
if let Some(parent) = dir.parent() {
if let Some(candidate) = probe_dir_for_cli(parent) {
return Ok(candidate);
}
}
}
Ok(current)
}
fn probe_dir_for_cli(dir: &Path) -> Option<PathBuf> {
let candidates = ["eventdbx", "dbx"];
for candidate in candidates {
let unix_path = dir.join(candidate);
if unix_path.exists() {
return Some(unix_path);
}
}
#[cfg(windows)]
{
let candidates = ["eventdbx.exe", "dbx.exe"];
for candidate in candidates {
let windows_path = dir.join(candidate);
if windows_path.exists() {
return Some(windows_path);
}
}
}
None
}
fn augment_args_with_config(mut args: Vec<String>, config_path: &PathBuf) -> Vec<String> {
if has_config_arg(&args) {
return args;
}
let mut final_args = Vec::with_capacity(args.len() + 2);
final_args.push("--config".to_string());
final_args.push(config_path.to_string_lossy().into_owned());
final_args.extend(args.drain(..));
final_args
}
fn has_config_arg(args: &[String]) -> bool {
args.iter()
.any(|arg| arg == "--config" || arg.starts_with("--config="))
}
pub async fn invoke(args: &[String], addr: &str) -> Result<CliCommandResult> {
let stream = TcpStream::connect(addr)
.await
.with_context(|| format!("failed to connect to CLI proxy at {addr}"))?;
let (reader, writer) = stream.into_split();
let mut writer = writer.compat_write();
let message_bytes = {
let mut message = capnp::message::Builder::new_default();
{
let mut request = message.init_root::<cli_request::Builder>();
let mut list = request.init_args(args.len() as u32);
for (idx, arg) in args.iter().enumerate() {
list.set(idx as u32, arg);
}
}
write_message_to_words(&message)
};
writer
.write_all(&message_bytes)
.await
.context("failed to send CLI request")?;
writer
.flush()
.await
.context("failed to flush CLI request")?;
let mut reader = reader.compat();
let response_message = read_message(&mut reader, ReaderOptions::new())
.await
.context("failed to read CLI response")?;
let response = response_message
.get_root::<cli_response::Reader>()
.context("failed to decode CLI response")?;
let stdout = response
.get_stdout()
.context("missing stdout field in CLI response")?
.to_string()
.map_err(|err| anyhow::Error::new(err).context("invalid UTF-8 in CLI stdout"))?;
let stderr = response
.get_stderr()
.context("missing stderr field in CLI response")?
.to_string()
.map_err(|err| anyhow::Error::new(err).context("invalid UTF-8 in CLI stderr"))?;
Ok(CliCommandResult {
exit_code: response.get_exit_code(),
stdout,
stderr,
})
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
#[test]
fn probe_dir_prefers_eventdbx() {
let dir = TempDir::new().unwrap();
let eventdbx = dir.path().join("eventdbx");
let dbx = dir.path().join("dbx");
std::fs::write(&eventdbx, b"").unwrap();
std::fs::write(&dbx, b"").unwrap();
let resolved = probe_dir_for_cli(dir.path()).expect("expected CLI path");
assert_eq!(resolved, eventdbx);
}
#[test]
fn probe_dir_falls_back_to_dbx() {
let dir = TempDir::new().unwrap();
let dbx = dir.path().join("dbx");
std::fs::write(&dbx, b"").unwrap();
let resolved = probe_dir_for_cli(dir.path()).expect("expected CLI path");
assert_eq!(resolved, dbx);
}
}