use crate::tarantool::tlua::{AnyLuaString, CallError, LuaThread};
use crate::tarantool::tuple::Tuple;
use crate::tlua;
use crate::transport::Context;
use serde::Serialize;
use std::fmt::Debug;
use std::ops::Deref;
use std::time::Duration;
use thiserror::Error;
/// Errors produced while preparing or executing a remote Lua call.
#[derive(Error, Debug)]
pub enum RemoteCallError {
/// The call itself failed: the Lua entry point was not found, the Lua
/// function reported an error, or the arguments could not be pushed.
#[error("remote call: {0}")]
RemoteCallError(String),
/// The call arguments could not be serialized by `rmp_serde`.
#[error("prepare request: {0}")]
PrepareRequestError(#[from] rmp_serde::encode::Error),
}
/// Factory for RPC endpoints whose calls are dispatched through Lua functions
/// looked up on the borrowed Lua thread.
pub struct Builder<'a> {
// Lua thread on which the `call_*` entry points are looked up.
tlua: &'a LuaThread,
// Handler name that endpoint closures forward as the first element of the
// Lua call arguments.
handler_name: String,
}
impl<'a> Builder<'a> {
const DEFAULT_HANDLER: &'static str = "rpc_handler";
pub fn new(tlua: &'a LuaThread) -> Self {
Self {
tlua,
handler_name: Self::DEFAULT_HANDLER.to_string(),
}
}
pub fn with_custom_handler(tlua: &'a LuaThread, handler: &'static str) -> Self {
Self {
tlua,
handler_name: handler.to_string(),
}
}
pub fn shard_endpoint<A: Serialize>(self, path: &'static str) -> ShardEndpoint<'a, A> {
let f = move |ctx: &mut Context, bucket_id, path, opts, args| {
let args = rmp_serde::to_vec_named(&args)?;
let args = AnyLuaString(args);
self.tlua
.get::<tlua::LuaFunction<_>, _>("call_shard")
.ok_or_else(|| RemoteCallError::RemoteCallError("call_shard: not found".into()))?
.call_with_args(&(
&self.handler_name,
bucket_id,
opts,
(path, ctx.clone(), args),
))
.map_err(map_lua_fn_err)
};
ShardEndpoint {
timeout: Duration::from_secs(10),
vshard_group: "default",
route: path,
balance: false,
handler: RemoteLuaCall(Box::new(f)),
}
}
pub fn async_shard_endpoint<A: Serialize>(
self,
path: &'static str,
) -> AsyncShardEndpoint<'a, A> {
let f = move |ctx: &mut Context, bucket_id, path, opts, args| {
let args = rmp_serde::to_vec_named(&args)?;
let args = AnyLuaString(args);
self.tlua
.get::<tlua::LuaFunction<_>, _>("call_shard_async")
.ok_or_else(|| {
RemoteCallError::RemoteCallError("call_shard_async: not found".into())
})?
.call_with_args(&(
&self.handler_name,
bucket_id,
opts,
(path, ctx.clone(), args),
))
.map_err(map_lua_fn_err)
};
AsyncShardEndpoint {
timeout: Duration::from_secs(10),
vshard_group: "default",
route: path,
balance: false,
handler: RemoteLuaCall(Box::new(f)),
}
}
pub fn replicaset_endpoint<A: Serialize>(
self,
path: &'static str,
) -> ReplicasetEndpoint<'a, A> {
let f = move |ctx: &mut Context, rs_uuid, path, opts, args| {
let args = rmp_serde::to_vec_named(&args)?;
let args = AnyLuaString(args);
self.tlua
.get::<tlua::LuaFunction<_>, _>("call_rs")
.ok_or_else(|| RemoteCallError::RemoteCallError("call_rs: not found".into()))?
.call_with_args(&(&self.handler_name, rs_uuid, opts, (path, ctx.clone(), args)))
.map_err(map_lua_fn_err)
};
ReplicasetEndpoint {
timeout: Duration::from_secs(10),
vshard_group: "default",
prefer_replica: false,
route: path,
handler: RemoteLuaCall(Box::new(f)),
}
}
pub fn role_endpoint<A: Serialize>(
self,
role: &'static str,
path: &'static str,
) -> RoleEndpoint<'a, A> {
let f = move |ctx: &mut Context, _, path, opts, args| {
let args = rmp_serde::to_vec_named(&args)?;
let args = AnyLuaString(args);
self.tlua
.get::<tlua::LuaFunction<_>, _>("call_role")
.ok_or_else(|| RemoteCallError::RemoteCallError("call_role: not found".into()))?
.call_with_args(&(Self::DEFAULT_HANDLER, role, opts, (path, ctx.clone(), args)))
.map_err(map_lua_fn_err)
};
RoleEndpoint {
timeout: Duration::from_secs(10),
route: path,
route_mode: RouteMode::RandomMaster,
handler: RemoteLuaCall(Box::new(f)),
}
}
}
/// Per-call options that endpoints pass to the Lua entry point alongside the
/// handler name, the target id and the request payload.
///
/// NOTE(review): how the Lua side interprets each field is not visible in
/// this file — confirm against the Lua implementation.
#[derive(Clone, tlua::Push, Default)]
pub struct Options {
// Timeout in seconds; endpoints fill it from `Duration::as_secs_f64`.
timeout: f64,
// vshard group name; set by shard and replicaset endpoints.
vshard_group: &'static str,
// Explicit target URI; set by `RoleEndpoint` when a custom URI is configured.
uri: Option<String>,
// Set to `true` by `RoleEndpoint` when no custom URI is configured.
leader_only: bool,
// Set by `ReplicasetEndpoint` from its `prefer_replica` flag.
prefer_replica: bool,
// Set by shard endpoints from their `balance` flag.
balance: bool,
}
// Signature of a remote call: (context, target id, route, options, args)
// mapped to the resulting tuple or a `RemoteCallError`.
type RemoteLuaCallFn<'a, ID, A> =
dyn Fn(&mut Context, ID, &'static str, Options, A) -> Result<Tuple, RemoteCallError> + 'a;
/// Boxed remote-call closure; `ID` is the target identifier type and `A` the
/// serializable argument type.
pub struct RemoteLuaCall<'a, ID, A: Serialize>(pub Box<RemoteLuaCallFn<'a, ID, A>>);
/// Exposes the boxed closure so a `RemoteLuaCall` can be invoked directly.
impl<'a, ID, A: Serialize> Deref for RemoteLuaCall<'a, ID, A> {
    type Target = Box<RemoteLuaCallFn<'a, ID, A>>;

    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}
// Signature of a middleware: takes the current remote call and returns the
// remote call that replaces it.
type MiddlewareFn<'a, ID, A> = dyn Fn(RemoteLuaCall<'a, ID, A>) -> RemoteLuaCall<'a, ID, A>;
/// Boxed middleware closure applied by the endpoints' `with_middleware` methods.
pub struct Middleware<'a, ID, A: Serialize>(pub Box<MiddlewareFn<'a, ID, A>>);
/// Exposes the boxed closure so a `Middleware` can be invoked directly.
impl<'a, ID, A: Serialize> Deref for Middleware<'a, ID, A> {
    type Target = Box<MiddlewareFn<'a, ID, A>>;

    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}
/// Endpoint addressed by bucket id (`i64`) that returns the call's `Tuple`.
pub struct ShardEndpoint<'a, A: Serialize> {
// Call timeout; sent to Lua as seconds.
timeout: Duration,
// vshard group name sent in the call options.
vshard_group: &'static str,
// Route passed to the handler and stored in the context under "path".
route: &'static str,
// Value of the `balance` call option; enabled by `with_balancer`.
balance: bool,
// Closure that performs the call, possibly wrapped by middleware.
handler: RemoteLuaCall<'a, i64, A>,
}
impl<'a, A: Serialize> ShardEndpoint<'a, A> {
pub fn with_middleware(self, mw: Middleware<'a, i64, A>) -> Self {
Self {
handler: (mw)(self.handler),
..self
}
}
pub fn with_vshard_group(self, group: &'static str) -> Self {
Self {
vshard_group: group,
..self
}
}
pub fn with_balancer(self) -> Self {
Self {
balance: true,
..self
}
}
pub fn with_timeout(self, timeout: Duration) -> Self {
Self { timeout, ..self }
}
pub fn call(
&self,
context: &mut Context,
bucket_id: i64,
args: A,
) -> Result<Tuple, RemoteCallError> {
context.put("path", self.route);
(self.handler)(
context,
bucket_id,
self.route,
Options {
timeout: self.timeout.as_secs_f64(),
vshard_group: self.vshard_group,
balance: self.balance,
..Default::default()
},
args,
)
}
}
/// Endpoint addressed by bucket id (`i64`) whose `call` discards the returned
/// tuple and yields `()`.
pub struct AsyncShardEndpoint<'a, A: Serialize> {
// Call timeout; sent to Lua as seconds.
timeout: Duration,
// vshard group name sent in the call options.
vshard_group: &'static str,
// Route passed to the handler and stored in the context under "path".
route: &'static str,
// Value of the `balance` call option; enabled by `with_balancer`.
balance: bool,
// Closure that performs the call, possibly wrapped by middleware.
handler: RemoteLuaCall<'a, i64, A>,
}
impl<'a, A: Serialize> AsyncShardEndpoint<'a, A> {
pub fn with_middleware(self, mw: Middleware<'a, i64, A>) -> Self {
Self {
handler: (mw)(self.handler),
..self
}
}
pub fn with_vshard_group(self, group: &'static str) -> Self {
Self {
vshard_group: group,
..self
}
}
pub fn with_balancer(self) -> Self {
Self {
balance: true,
..self
}
}
pub fn with_timeout(self, timeout: Duration) -> Self {
Self { timeout, ..self }
}
pub fn call(
&self,
context: &mut Context,
bucket_id: i64,
args: A,
) -> Result<(), RemoteCallError> {
context.put("path", self.route);
let opts = Options {
timeout: self.timeout.as_secs_f64(),
vshard_group: self.vshard_group,
balance: self.balance,
..Default::default()
};
(self.handler)(context, bucket_id, self.route, opts, args).map(|_| ())
}
}
/// Endpoint addressed by replicaset UUID (`&str`) that returns the call's `Tuple`.
pub struct ReplicasetEndpoint<'a, A: Serialize> {
// Call timeout; sent to Lua as seconds.
timeout: Duration,
// vshard group name sent in the call options.
vshard_group: &'static str,
// Value of the `prefer_replica` call option; enabled by `prefer_replica()`.
prefer_replica: bool,
// Route passed to the handler and stored in the context under "path".
route: &'static str,
// Closure that performs the call, possibly wrapped by middleware.
handler: RemoteLuaCall<'a, &'a str, A>,
}
impl<'a, A: Serialize> ReplicasetEndpoint<'a, A> {
pub fn with_middleware(self, mw: Middleware<'a, &'a str, A>) -> Self {
Self {
handler: (mw)(self.handler),
..self
}
}
pub fn with_vshard_group(self, group: &'static str) -> Self {
Self {
vshard_group: group,
..self
}
}
pub fn with_timeout(self, timeout: Duration) -> Self {
Self { timeout, ..self }
}
pub fn prefer_replica(self) -> Self {
Self {
prefer_replica: true,
..self
}
}
pub fn call(
&self,
context: &mut Context,
rs_uuid: &'a str,
args: A,
) -> Result<Tuple, RemoteCallError> {
context.put("path", self.route);
let opts = Options {
timeout: self.timeout.as_secs_f64(),
vshard_group: self.vshard_group,
prefer_replica: self.prefer_replica,
..Default::default()
};
(self.handler)(context, rs_uuid, self.route, opts, args)
}
}
/// Selects which call options a `RoleEndpoint` sets to pick its target.
#[derive(Default)]
enum RouteMode<'a> {
/// No explicit target: the call is issued with `leader_only = true`.
#[default]
RandomMaster,
/// The given URI is sent in the call options as `uri`.
CustomUri(&'a str),
}
/// Endpoint addressed by role name (captured by the handler closure) that
/// returns the call's `Tuple`.
pub struct RoleEndpoint<'a, A: Serialize> {
// Call timeout; sent to Lua as seconds.
timeout: Duration,
// Route passed to the handler and stored in the context under "path".
route: &'static str,
// Decides whether `leader_only` or `uri` is set in the call options.
route_mode: RouteMode<'a>,
// Closure that performs the call, possibly wrapped by middleware.
handler: RemoteLuaCall<'a, (), A>,
}
impl<'a, A: Serialize> RoleEndpoint<'a, A> {
pub fn with_middleware(self, mw: Middleware<'a, (), A>) -> Self {
Self {
handler: (mw)(self.handler),
..self
}
}
pub fn with_timeout(self, timeout: Duration) -> Self {
Self { timeout, ..self }
}
pub fn with_uri(self, uri: &'a str) -> Self {
Self {
route_mode: RouteMode::CustomUri(uri),
..self
}
}
pub fn call(&self, context: &mut Context, args: A) -> Result<Tuple, RemoteCallError> {
context.put("path", self.route);
let mut opts = Options {
timeout: self.timeout.as_secs_f64(),
..Default::default()
};
match self.route_mode {
RouteMode::RandomMaster => opts.leader_only = true,
RouteMode::CustomUri(uri) => opts.uri = Some(uri.to_string()),
}
(self.handler)(context, (), self.route, opts, args)
}
}
fn map_lua_fn_err<E>(e: CallError<E>) -> RemoteCallError {
match e {
CallError::LuaError(e) => RemoteCallError::RemoteCallError(format!("{}", e)),
CallError::PushError(_) => RemoteCallError::RemoteCallError("push error".into()),
}
}