#![doc = include_str!("../README.md")]
use std::collections::HashMap;
use std::fmt::Debug;
use one_two_eight::{generate_id, generate_id_prototk};
use prototk_derive::Message;
use zerror::{Z, iotoz};
use zerror_core::ErrorCore;
// NOTE(review): all three limits are presumably byte counts enforced by transport code that is
// not visible in this file -- confirm against the callers.
/// Upper bound on request size: `1 << 20` = 1,048,576.  Reported as the `limit` when an
/// `Error::RequestTooLarge` is converted into an `indicio::Value`.
pub const MAX_REQUEST_SIZE: usize = 1usize << 20;
/// Upper bound on response size: `1 << 20` = 1,048,576.
pub const MAX_RESPONSE_SIZE: usize = 1usize << 20;
/// Upper bound on body size: `1 << 20` = 1,048,576.
pub const MAX_BODY_SIZE: usize = 1usize << 20;
// Identifier types used throughout this file: `TraceID`, `ClientID` and `HostID`.  Each
// `generate_id!` invocation pairs a type name with a string prefix ("trace:", "client:",
// "host:"), and each is followed by a `generate_id_prototk!` invocation for the same type.
// NOTE(review): presumably `generate_id!` defines the type and uses the prefix in its readable
// form, while `generate_id_prototk!` adds the prototk support that lets the type appear in
// `#[prototk(.., message)]` fields below -- confirm against the one_two_eight crate.
generate_id! {TraceID, "trace:"}
generate_id_prototk! {TraceID}
generate_id! {ClientID, "client:"}
generate_id_prototk! {ClientID}
generate_id! {HostID, "host:"}
generate_id_prototk! {HostID}
/// A host: a [`HostID`] paired with a connect string.
///
/// The textual form accepted by `FromStr` and produced by `Debug`/`Display` is
/// `<host_id>=<connect>`.
#[derive(Clone, Default, Eq, PartialEq, prototk_derive::Message)]
pub struct Host {
/// Identifier of the host (prototk field 1).
#[prototk(1, message)]
host_id: HostID,
/// Connect string (prototk field 2).  `hostname_or_ip` interprets it as `host:port` or
/// `[addr]:port`.
#[prototk(2, string)]
connect: String,
}
impl Host {
    /// Create a host from its identifier and connect string.
    pub fn new(host_id: HostID, connect: String) -> Self {
        Self { host_id, connect }
    }

    /// The identifier of this host.
    pub fn host_id(&self) -> HostID {
        self.host_id
    }

    /// The connect string, exactly as stored.
    pub fn connect(&self) -> &str {
        &self.connect
    }

    /// The host portion of the connect string, without the port and without IPv6 brackets.
    ///
    /// - `host:port` and `host` yield `host`.
    /// - `[addr]:port` and `[addr]` yield `addr`.
    /// - A string that starts with `[` but matches neither bracketed shape is returned unchanged.
    pub fn hostname_or_ip(&self) -> &str {
        /// Drop everything from the last `:` onwards; return the input when there is no `:`.
        fn strip_port(connect: &str) -> &str {
            connect.rsplit_once(':').map_or(connect, |(host, _port)| host)
        }

        let connect = self.connect.as_str();
        let bracketed = match connect.strip_prefix('[') {
            Some(rest) => rest,
            None => return strip_port(connect),
        };
        // "[addr]" with no port.  This must be recognised before looking for a port separator:
        // an IPv6 literal contains ':' itself, so splitting on the last ':' would cut the
        // address in half and the brackets would never be removed.
        if let Some(addr) = bracketed.strip_suffix(']') {
            if !addr.contains(']') {
                return addr;
            }
        }
        // "[addr]:port": remove the port, then require the closing bracket.  Anything else is
        // handed back untouched rather than guessed at.
        strip_port(bracketed).strip_suffix(']').unwrap_or(connect)
    }
}
impl std::str::FromStr for Host {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let parts: Vec<String> = s.split('=').map(String::from).collect();
if parts.len() != 2 {
return Err(Error::ResolveFailure {
core: ErrorCore::default(),
what: "could not parse string".to_owned(),
}
.with_info("parts", parts));
}
let host_id: HostID = match parts[0].parse::<HostID>() {
Ok(host_id) => host_id,
Err(err) => {
return Err(Error::ResolveFailure {
core: ErrorCore::default(),
what: "could not parse HostID".to_owned(),
}
.with_info("err", err)
.with_info("host_id", parts[0].to_owned()));
}
};
Ok(Host {
host_id,
connect: parts[1].to_owned(),
})
}
}
impl std::fmt::Debug for Host {
    /// Write the host as `<host_id>=<connect>`, using the human-readable form of the id.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let id = self.host_id();
        let connect = self.connect();
        f.write_fmt(format_args!("{}={connect}", id.human_readable()))
    }
}
impl std::fmt::Display for Host {
    /// Write exactly the text produced by the `Debug` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("{:?}", self))
    }
}
/// Convert a [`Host`] into an `indicio::Value` with two entries, `host_id` and `connect`.
/// Only compiled when the `indicio` feature is enabled.
#[cfg(feature = "indicio")]
impl From<Host> for indicio::Value {
fn from(host: Host) -> Self {
// The id is rendered with `prefix_free_readable()` here, whereas the `Debug` impl uses
// `human_readable()`.
indicio::value!({
host_id: host.host_id().prefix_free_readable(),
connect: host.connect(),
})
}
}
/// Per-call context handed to `Server::call` and `Client::call`: a list of client IDs and an
/// optional trace ID.  The default context has no clients and no trace ID.
#[derive(Clone, Debug, Default)]
pub struct Context {
// Client IDs, in the order they were added; populated from `Request::caller` when the context is
// built from a request.
clients: Vec<ClientID>,
// Trace ID, if any; populated from `Request::trace` when the context is built from a request.
trace_id: Option<TraceID>,
}
impl Context {
    /// A copy of the client IDs held by this context, in insertion order.
    pub fn clients(&self) -> Vec<ClientID> {
        self.clients.to_vec()
    }

    /// A new context equal to this one with `client` appended to the client list.
    /// The receiver is not modified.
    pub fn with_client(&self, client: ClientID) -> Self {
        let mut clients = self.clients.clone();
        clients.push(client);
        Self {
            clients,
            trace_id: self.trace_id,
        }
    }

    /// The trace ID of this context, if one has been set.
    pub fn trace_id(&self) -> Option<TraceID> {
        self.trace_id
    }

    /// A new context equal to this one with its trace ID set to `trace`.
    /// The receiver is not modified.
    pub fn with_trace_id(&self, trace: TraceID) -> Self {
        Self {
            clients: self.clients.clone(),
            trace_id: Some(trace),
        }
    }
}
impl<'a> From<&Request<'a>> for Context {
fn from(req: &Request<'a>) -> Self {
Self {
clients: req.caller.clone(),
trace_id: req.trace,
}
}
}
/// Errors that can arise while making or serving an RPC.
///
/// Every variant carries an `ErrorCore` as prototk field 1; any remaining field describes the
/// specific failure.
// NOTE(review): the per-variant descriptions below that are not tied to code in this file are
// inferred from the variant names -- confirm against the call sites in other crates.
#[derive(Clone, Message, zerror_derive::Z)]
pub enum Error {
/// No error.  This is the variant produced by `Error::default()`.
#[prototk(278528, message)]
Success {
/// The error core.
#[prototk(1, message)]
core: ErrorCore,
},
/// Packing or unpacking failed.  Produced by the `From<buffertk::Error>` and
/// `From<prototk::Error>` conversions.
#[prototk(278529, message)]
SerializationError {
/// The error core.
#[prototk(1, message)]
core: ErrorCore,
/// The underlying prototk error.
#[prototk(2, message)]
err: prototk::Error,
/// Free-form description of what was being (de)serialized.
#[prototk(3, string)]
context: String,
},
/// The named server is not known.
#[prototk(278530, message)]
UnknownServerName {
/// The error core.
#[prototk(1, message)]
core: ErrorCore,
/// The server name that was requested.
#[prototk(2, string)]
name: String,
},
/// The named method is not known.  Returned by the `call` generated by `server_methods!` when
/// the method string matches none of the declared rpcs.
#[prototk(278531, message)]
UnknownMethodName {
/// The error core.
#[prototk(1, message)]
core: ErrorCore,
/// The method name that was requested.
#[prototk(2, string)]
name: String,
},
/// The request was too large.
#[prototk(278532, message)]
RequestTooLarge {
/// The error core.
#[prototk(1, message)]
core: ErrorCore,
/// The size of the offending request.
#[prototk(2, uint64)]
size: u64,
},
/// The transport failed.
#[prototk(278533, message)]
TransportFailure {
/// The error core.
#[prototk(1, message)]
core: ErrorCore,
/// Free-form description of the failure.
#[prototk(2, string)]
what: String,
},
/// Encryption is misconfigured.
#[prototk(278534, message)]
EncryptionMisconfiguration {
/// The error core.
#[prototk(1, message)]
core: ErrorCore,
/// Free-form description of the failure.
#[prototk(2, string)]
what: String,
},
/// A ulimit could not be parsed.
#[prototk(278535, message)]
UlimitParseError {
/// The error core.
#[prototk(1, message)]
core: ErrorCore,
/// Free-form description of the failure.
#[prototk(2, string)]
what: String,
},
/// An operating-system error.  Produced by the `From<std::io::Error>` conversion, which stores
/// the I/O error's display text in `what`.
#[prototk(278536, message)]
OsError {
/// The error core.
#[prototk(1, message)]
core: ErrorCore,
/// Free-form description of the failure.
#[prototk(2, string)]
what: String,
},
/// A logic error.
#[prototk(278537, message)]
LogicError {
/// The error core.
#[prototk(1, message)]
core: ErrorCore,
/// Free-form description of the failure.
#[prototk(2, string)]
what: String,
},
/// Something was not found.
#[prototk(278538, message)]
NotFound {
/// The error core.
#[prototk(1, message)]
core: ErrorCore,
/// Free-form description of what was not found.
#[prototk(2, string)]
what: String,
},
/// Resolution failed.  Produced by `Error::resolve_failure` and by `Host::from_str` when the
/// input cannot be parsed.
#[prototk(278539, message)]
ResolveFailure {
/// The error core.
#[prototk(1, message)]
core: ErrorCore,
/// Free-form description of the failure.
#[prototk(2, string)]
what: String,
},
}
impl Error {
pub fn resolve_failure(what: impl Into<String>) -> Self {
Self::ResolveFailure {
core: ErrorCore::default(),
what: what.into(),
}
}
}
impl Default for Error {
fn default() -> Error {
Error::Success {
core: ErrorCore::default(),
}
}
}
impl From<buffertk::Error> for Error {
fn from(err: buffertk::Error) -> Error {
Error::SerializationError {
core: ErrorCore::default(),
err: err.into(),
context: "buffertk unpack error".to_string(),
}
}
}
impl From<prototk::Error> for Error {
fn from(err: prototk::Error) -> Error {
Error::SerializationError {
core: ErrorCore::default(),
err,
context: "prototk unpack error".to_string(),
}
}
}
impl From<std::io::Error> for Error {
fn from(err: std::io::Error) -> Error {
Error::OsError {
core: ErrorCore::default(),
what: format!("{err}"),
}
}
}
// NOTE(review): `iotoz!` is imported from the `zerror` crate; presumably it generates the glue
// for turning `std::io::Result` values into results carrying this `Error` -- confirm against the
// zerror documentation.
iotoz! {Error}
#[cfg(feature = "indicio")]
impl From<Error> for indicio::Value {
    /// Convert an error into an `indicio::Value` with a single top-level key naming the kind of
    /// error.  The `ErrorCore` of every variant is left out of the value.
    fn from(err: Error) -> Self {
        match err {
            Error::Success { .. } => indicio::value!({
                success: true,
            }),
            Error::SerializationError { err, context, .. } => {
                // The prototk error is rendered through its `Debug` form.
                let what = format!("{err:?}");
                indicio::value!({
                    serialization: {
                        what: what,
                        context: context,
                    },
                })
            }
            Error::UnknownServerName { name, .. } => indicio::value!({
                unknown_server: name,
            }),
            Error::UnknownMethodName { name, .. } => indicio::value!({
                unknown_method: name,
            }),
            Error::RequestTooLarge { size, .. } => indicio::value!({
                request_too_large: {
                    size: size,
                    limit: MAX_REQUEST_SIZE,
                },
            }),
            Error::TransportFailure { what, .. } => indicio::value!({
                transport_failure: what,
            }),
            Error::EncryptionMisconfiguration { what, .. } => indicio::value!({
                encryption_misconfiguration: what,
            }),
            Error::UlimitParseError { what, .. } => indicio::value!({
                ulimit_parse_error: what,
            }),
            Error::OsError { what, .. } => indicio::value!({
                os_error: what,
            }),
            Error::LogicError { what, .. } => indicio::value!({
                logic_error: what,
            }),
            Error::NotFound { what, .. } => indicio::value!({
                not_found: what,
            }),
            Error::ResolveFailure { what, .. } => indicio::value!({
                resolve_failure: what,
            }),
        }
    }
}
/// Outcome of an RPC, as three cases:
///
/// - `Ok(Ok(bytes))`: the call succeeded; the generated client unpacks `bytes` as the response.
/// - `Ok(Err(bytes))`: the service returned its own error; the generated client unpacks `bytes`
///   as the service's error type.
/// - `Err(error)`: the RPC layer itself failed with an [`Error`].
pub type Status = Result<Result<Vec<u8>, Vec<u8>>, Error>;
/// Something that can answer RPCs: given a context, a method name and a serialized request, it
/// returns a [`Status`].  The `service!` macro generates an implementation for each service.
pub trait Server {
/// Dispatch the serialized request `req` to the method named `method`.
fn call(&self, ctx: &Context, method: &str, req: &[u8]) -> Status;
}
/// Something that can issue RPCs: given a context, a server name, a method name and a serialized
/// request, it returns a [`Status`].  The clients generated by `service!` wrap a `dyn Client`.
pub trait Client {
/// Send the serialized request `req` to `method` on `server`.
fn call(&self, ctx: &Context, server: &str, method: &str, req: &[u8]) -> Status;
}
/// Size and checksum describing a buffer; see `Frame::from_buffer`.
// NOTE(review): presumably written ahead of each message on the wire so the receiver knows how
// much to read and can verify it -- the framing code is not in this file; confirm.
#[derive(Clone, Debug, Default, Message)]
pub struct Frame {
/// Length of the described buffer (prototk field 1).
#[prototk(1, uint64)]
pub size: u64,
/// CRC32C of the described buffer (prototk field 2).
#[prototk(2, fixed32)]
pub crc32c: u32,
}
impl Frame {
    /// Describe `buf`: record its length and its CRC32C checksum.
    pub fn from_buffer(buf: &[u8]) -> Self {
        let size = buf.len() as u64;
        let checksum = crc32c::crc32c(buf);
        Self {
            size,
            crc32c: checksum,
        }
    }
}
/// An RPC request.  String and byte fields borrow from a buffer with lifetime `'a`.
#[derive(Clone, Debug, Default, Message)]
pub struct Request<'a> {
/// Name of the service being called (prototk field 1).
#[prototk(1, string)]
pub service: &'a str,
/// Name of the method being called (prototk field 2).
#[prototk(2, string)]
pub method: &'a str,
/// Sequence number (prototk field 3); `Response` carries a `seq_no` under the same tag.
// NOTE(review): presumably used to match responses to requests -- confirm in the transport.
#[prototk(3, uint64)]
pub seq_no: u64,
/// Serialized request payload (prototk field 4).
#[prototk(4, bytes)]
pub body: &'a [u8],
/// Client IDs of the caller(s) (prototk field 5); becomes `Context::clients`.
#[prototk(5, message)]
pub caller: Vec<ClientID>,
/// Optional trace ID (prototk field 6); becomes `Context::trace_id`.
#[prototk(6, message)]
pub trace: Option<TraceID>,
}
/// An RPC response.  Byte fields borrow from a buffer with lifetime `'a`.
///
/// `seq_no` and `trace` use the same prototk tags (3 and 6) as the fields of the same meaning in
/// `Request`.
// NOTE(review): `body`, `service_error` and `rpc_error` presumably correspond to the three cases
// of `Status` (`Ok(Ok(_))`, `Ok(Err(_))` and `Err(_)`), with exactly one of them set -- the code
// that fills them in is not in this file; confirm.
#[derive(Clone, Debug, Default, Message)]
pub struct Response<'a> {
/// Sequence number (prototk field 3).
#[prototk(3, uint64)]
pub seq_no: u64,
/// Optional trace ID (prototk field 6).
#[prototk(6, message)]
pub trace: Option<TraceID>,
/// Serialized response payload, if any (prototk field 7).
#[prototk(7, bytes)]
pub body: Option<&'a [u8]>,
/// Serialized service-level error, if any (prototk field 8).
#[prototk(8, bytes)]
pub service_error: Option<&'a [u8]>,
/// Serialized RPC-level error, if any (prototk field 9).
#[prototk(9, bytes)]
pub rpc_error: Option<&'a [u8]>,
}
/// Generate the trait, client and server for an RPC service.
///
/// The invocation has the shape:
///
/// ```text
/// service! {
///     name = ServiceTrait;
///     server = ServerStruct;
///     client = ClientStruct;
///     error = ErrorType;
///     rpc method_one(RequestOne) -> ResponseOne;
///     rpc method_two(RequestTwo) -> ResponseTwo;
/// }
/// ```
///
/// and expands to:
///
/// - a trait `$service: Send + Sync + 'static` with one method per `rpc`, each of the form
///   `fn method(&self, ctx: &Context, req: Req) -> Result<Resp, ErrorType>`;
/// - a struct `$client` wrapping an `Arc<dyn Client + Send + Sync>`, constructed with `new`, that
///   implements `$service` by forwarding every method through `client_method!`;
/// - a struct `$server<S: $service>`, constructed with `bind`, that implements `Server` through
///   `server_methods!`.
#[macro_export]
macro_rules! service {
(name = $service:ident; server = $server:ident; client = $client:ident; error = $error:ty; $(rpc $method:ident ($req:ty) -> $resp:ty;)*) => {
// The service trait itself: one method per declared rpc.
pub trait $service: Send + Sync + 'static {
$(
/// Auto-generated service method generated by service!.
fn $method(&self, ctx: &$crate::Context, req: $req) -> Result<$resp, $error>;
)*
}
// The client: holds the transport and implements the service trait over it.
pub struct $client {
client: std::sync::Arc<dyn $crate::Client + Send + Sync + 'static>,
}
impl $client where {
pub fn new(client: std::sync::Arc<dyn $crate::Client + Send + Sync + 'static>) -> Self {
Self {
client,
}
}
}
impl $service for $client where
$client: Send + Sync + 'static
{
$(
$crate::client_method! { $service, $method, $req, $resp, $error }
)*
}
// The server: wraps an implementation of the service trait and exposes it as a `Server`.
pub struct $server<S: $service> {
server: S,
}
impl<S: $service> $server<S> {
pub fn bind(server: S) -> $server<S> {
$server {
server,
}
}
}
impl<S: $service> $crate::Server for $server<S> {
$crate::server_methods! { $service, $error, $($method, $req, $resp),* }
}
};
}
/// Generate one client-side method of a service; used by `service!`.
///
/// The generated method packs `req` with `buffertk::stack_pack`, calls `self.client.call` with
/// the stringified service and method names, and then maps the resulting status:
///
/// - `Ok(Ok(msg))`: `msg` is unpacked as `$resp` and returned as `Ok`.
/// - `Ok(Err(msg))`: `msg` is unpacked as `$error` and returned as `Err`.
/// - `Err(err)`: the RPC error is converted into `$error` with `.into()`.
///
/// An unpack failure in either of the first two cases is propagated with `?`, so `$error` must
/// be constructible from the unpack error as well as from the RPC error.
#[macro_export]
macro_rules! client_method {
($service:ident, $method:ident, $req:ty, $resp:ty, $error:ty) => {
fn $method(&self, ctx: &$crate::Context, req: $req) -> Result<$resp, $error> {
let req = ::buffertk::stack_pack(req).to_vec();
let status = self
.client
.call(ctx, stringify!($service), stringify!($method), &req);
match status {
Ok(Ok(msg)) => Ok(<$resp as ::buffertk::Unpackable>::unpack(&msg)?.0),
Ok(Err(msg)) => Err(<$error as ::buffertk::Unpackable>::unpack(&msg)?.0),
Err(err) => Err(err.into()),
}
}
};
}
/// Generate the `Server::call` implementation for a service; used by `service!`.
///
/// The generated `call` matches the method string against the stringified name of every declared
/// rpc.  On a match it unpacks the request as `$req` (an unpack failure is propagated with `?`),
/// invokes the method on `self.server`, and packs the outcome: a successful response becomes
/// `Ok(Ok(bytes))` and a service error becomes `Ok(Err(bytes))`.  A method string that matches
/// nothing yields `Error::UnknownMethodName` carrying that string.
// NOTE(review): `use buffertk::stack_pack` and `zerror_core::ErrorCore` are written without a
// leading `::` (unlike `::buffertk::Unpackable` below), so they resolve relative to the scope
// where the macro is expanded, and the invoking crate must be able to name `zerror_core`.
#[macro_export]
macro_rules! server_methods {
($service:ident, $error:ty, $($method:ident, $req:ty, $resp:ty),*) => {
fn call(&self, ctx: &$crate::Context, method: &str, req: &[u8]) -> $crate::Status {
use buffertk::stack_pack;
match method {
$(
stringify!($method) => {
let req = <$req as ::buffertk::Unpackable>::unpack(req)?.0;
let ans: Result<$resp, $error> = self.server.$method(ctx, req);
match ans {
Ok(resp) => {
Ok(Ok(stack_pack(resp).to_vec()))
}
Err(err) => {
Ok(Err(stack_pack(err).to_vec()))
}
}
}
),*
_ => {
Err($crate::Error::UnknownMethodName {
core: zerror_core::ErrorCore::default(),
name: method.to_string(),
}.into())
},
}
}
};
}
/// A table of servers keyed by name.
// NOTE(review): no constructor or `Default` implementation is visible in this file and the only
// field is private -- confirm how callers are expected to create a registry.
pub struct ServerRegistry {
// Maps a server name to the boxed server registered under it.
registry: HashMap<&'static str, Box<dyn Server>>,
}
impl ServerRegistry {
    /// Register `server` under `name`.
    ///
    /// # Panics
    ///
    /// Panics if a server is already registered under `name`; the existing registration is left
    /// in place.
    pub fn register<S: Server + 'static>(&mut self, name: &'static str, server: S) {
        use std::collections::hash_map::Entry;

        // One lookup through the entry API instead of `contains_key` followed by `insert`.
        match self.registry.entry(name) {
            Entry::Vacant(slot) => {
                slot.insert(Box::new(server));
            }
            Entry::Occupied(_) => panic!("server {name:?} is already registered"),
        }
    }

    /// Look up the server registered under `name`, if any.
    pub fn get_server(&self, name: &str) -> Option<&dyn Server> {
        self.registry.get(name).map(|boxed| &**boxed)
    }
}
/// A source of hosts.
// NOTE(review): presumably successive calls may return different hosts (the method takes
// `&mut self`), e.g. for rotating through replicas -- confirm against the implementations.
pub trait Resolver {
/// Produce a [`Host`], or an [`Error`] when no host can be resolved.
fn resolve(&mut self) -> Result<Host, Error>;
}
#[cfg(test)]
mod tests {
    use buffertk::{Unpackable, stack_pack};

    use super::*;

    /// Check that `exp` displays as `s` and survives a pack/unpack round trip unchanged.
    fn do_test(s: &str, exp: Error) {
        assert_eq!(s, exp.to_string());
        let buf = stack_pack(&exp).to_vec();
        let got = Error::unpack(&buf).unwrap().0;
        assert_eq!(exp, got);
    }

    #[test]
    fn success() {
        do_test(
            "Success",
            Error::Success {
                core: ErrorCore::default(),
            },
        );
    }

    #[test]
    fn serialization_error() {
        do_test(
            "SerializationError { err: Success, context: \"Some context\" }",
            Error::SerializationError {
                core: ErrorCore::default(),
                context: "Some context".to_owned(),
                err: prototk::Error::Success,
            },
        );
    }

    #[test]
    fn unknown_server_name() {
        do_test(
            "UnknownServerName { name: \"hostname\" }",
            Error::UnknownServerName {
                core: ErrorCore::default(),
                name: "hostname".to_owned(),
            },
        );
    }

    #[test]
    fn unknown_method_name() {
        do_test(
            "UnknownMethodName { name: \"method\" }",
            Error::UnknownMethodName {
                core: ErrorCore::default(),
                name: "method".to_owned(),
            },
        );
    }

    #[test]
    fn request_too_large() {
        do_test(
            "RequestTooLarge { size: 10 }",
            Error::RequestTooLarge {
                core: ErrorCore::default(),
                size: 10,
            },
        );
    }

    #[test]
    fn transport_failure() {
        do_test(
            "TransportFailure { what: \"socket closed\" }",
            Error::TransportFailure {
                core: ErrorCore::default(),
                what: "socket closed".to_owned(),
            },
        );
    }

    #[test]
    fn encryption_misconfiguration() {
        do_test(
            "EncryptionMisconfiguration { what: \"ssl misconfig\" }",
            Error::EncryptionMisconfiguration {
                core: ErrorCore::default(),
                what: "ssl misconfig".to_owned(),
            },
        );
    }

    #[test]
    fn ulimit_parse_error() {
        do_test(
            "UlimitParseError { what: \"could not read\" }",
            Error::UlimitParseError {
                core: ErrorCore::default(),
                what: "could not read".to_owned(),
            },
        );
    }

    #[test]
    fn os_error() {
        do_test(
            "OsError { what: \"some I/O error\" }",
            Error::OsError {
                core: ErrorCore::default(),
                what: "some I/O error".to_owned(),
            },
        );
    }

    #[test]
    fn logic_error() {
        do_test(
            "LogicError { what: \"some logic error\" }",
            Error::LogicError {
                core: ErrorCore::default(),
                what: "some logic error".to_owned(),
            },
        );
    }

    #[test]
    fn not_found_error() {
        do_test(
            "NotFound { what: \"deployment\" }",
            Error::NotFound {
                core: ErrorCore::default(),
                what: "deployment".to_owned(),
            },
        );
    }

    // `ResolveFailure` was the only variant without a round-trip test; it follows the same
    // `Variant { what: ".." }` display shape as the other `what`-carrying variants.
    #[test]
    fn resolve_failure() {
        do_test(
            "ResolveFailure { what: \"unknown host\" }",
            Error::ResolveFailure {
                core: ErrorCore::default(),
                what: "unknown host".to_owned(),
            },
        );
    }
}