use std::collections::HashMap;
use std::sync::Arc;
use std::sync::LazyLock;
use std::sync::Mutex;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use bytes::Buf;
use tokio::sync::Mutex as TokioMutex;
use tokio::sync::Notify;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use crate::StatusCodeError;
use crate::StatusError;
use crate::attributes::Attributes;
use crate::client::CallOptions;
use crate::client::DynRecvStream as ClientDynRecvStream;
use crate::client::DynSendStream as ClientDynSendStream;
use crate::client::Invoke;
use crate::client::RecvStream as ClientRecvStream;
use crate::client::ResponseStreamItem;
use crate::client::SendOptions as ClientSendOptions;
use crate::client::SendStream as ClientSendStream;
use crate::client::name_resolution::Address;
use crate::client::name_resolution::ChannelController as ResolverChannelController;
use crate::client::name_resolution::Endpoint;
use crate::client::name_resolution::Resolver;
use crate::client::name_resolution::ResolverBuilder;
use crate::client::name_resolution::ResolverOptions;
use crate::client::name_resolution::ResolverUpdate;
use crate::client::name_resolution::Target;
use crate::client::name_resolution::global_registry as global_resolver_registry;
use crate::client::service_config::ServiceConfig;
use crate::client::transport::GLOBAL_TRANSPORT_REGISTRY;
use crate::client::transport::SecurityOpts;
use crate::client::transport::Transport;
use crate::client::transport::TransportOptions;
use crate::core::RecvMessage;
use crate::core::RequestHeaders;
use crate::core::ResponseHeaders;
use crate::core::SendMessage;
use crate::core::Trailers;
use crate::credentials::SecurityLevel;
use crate::credentials::client::ClientConnectionSecurityContext;
use crate::credentials::client::ClientConnectionSecurityInfo;
use crate::credentials::client::DynClientConnectionSecurityInfo;
use crate::credentials::common::Authority;
use crate::rt::GrpcRuntime;
use crate::server::Call as ServerCall;
use crate::server::Listener as ServerListener;
use crate::server::RecvStream as ServerRecvStream;
use crate::server::ResponseStreamItem as ServerResponseStreamItem;
use crate::server::SendOptions as ServerSendOptions;
use crate::server::SendStream as ServerSendStream;
/// Process-wide table mapping a listener id to the sending half of that
/// listener's incoming-call channel.
///
/// Entries are inserted by `InMemoryListener::new`, removed by
/// `InMemoryListener::close`, and looked up by `InMemoryTransport::connect`.
static LISTENERS: LazyLock<Mutex<HashMap<String, mpsc::Sender<InMemoryServerCall>>>> =
    LazyLock::new(Default::default);
/// Counter used to hand out unique listener ids.
static NEXT_ID: AtomicU64 = AtomicU64::new(0);
/// A single incoming RPC handed from a client connection to a listener.
///
/// Built by `InMemoryConnection::invoke` and consumed by
/// `InMemoryListener::accept`, which splits it into the server-side stream
/// halves.
struct InMemoryServerCall {
    /// Request headers supplied by the client.
    headers: RequestHeaders,
    /// Receives request messages (and the close marker) from the client.
    req_rx: mpsc::UnboundedReceiver<InMemoryRequestStreamItem>,
    /// Sends response headers and messages back to the client.
    resp_tx: mpsc::UnboundedSender<InMemoryResponseStreamItem>,
    /// One-shot channel that delivers the call's trailers to the client.
    trailer_tx: oneshot::Sender<Trailers>,
}
/// Items flowing from client to server on a call's request channel.
enum InMemoryRequestStreamItem {
    /// An encoded request message.
    Message(Box<dyn Buf + Send + Sync>),
    /// Sent when the client's send stream is dropped: no more request
    /// messages will follow.
    StreamClosed,
}
/// Items flowing from server to client on a call's response channel.
/// Trailers are not part of this stream; they travel over a separate
/// one-shot channel.
enum InMemoryResponseStreamItem {
    /// Response headers.
    Headers(ResponseHeaders),
    /// An encoded response message.
    Message(Box<dyn Buf + Send + Sync>),
}
/// An in-process server listener addressed by a unique string id.
///
/// Clones share the same underlying state. The `inmemory` transport finds the
/// listener by the id returned from `InMemoryListener::id`.
#[derive(Clone)]
pub struct InMemoryListener {
    inner: Arc<InMemoryListenerInner>,
}
/// State shared by every clone of an `InMemoryListener`.
struct InMemoryListenerInner {
    /// Key under which this listener is registered in `LISTENERS`.
    id: String,
    /// Receiving half of the incoming-call channel; locked by `accept`.
    r: TokioMutex<mpsc::Receiver<InMemoryServerCall>>,
    /// Notified by `close` to wake an `accept` that is currently waiting.
    close_notify: Arc<Notify>,
    /// Notified when this shared state is dropped (the last clone is gone);
    /// `close` waits on it.
    drop_notify: Arc<Notify>,
}
impl Drop for InMemoryListenerInner {
fn drop(&mut self) {
self.drop_notify.notify_waiters();
}
}
impl Default for InMemoryListener {
fn default() -> Self {
Self::new()
}
}
impl InMemoryListener {
    /// Creates a listener with a fresh unique id.
    ///
    /// The listener is registered in the global listener table so that the
    /// `inmemory` transport can connect to it by id.
    pub fn new() -> Self {
        let id = NEXT_ID.fetch_add(1, Ordering::Relaxed).to_string();
        let (s, r) = mpsc::channel(1);
        LISTENERS.lock().unwrap().insert(id.clone(), s);
        Self {
            inner: Arc::new(InMemoryListenerInner {
                id,
                r: TokioMutex::new(r),
                close_notify: Arc::new(Notify::new()),
                drop_notify: Arc::new(Notify::new()),
            }),
        }
    }

    /// Returns the id clients use to address this listener.
    pub fn id(&self) -> String {
        self.inner.id.clone()
    }

    /// Shuts the listener down and waits until every clone has been dropped.
    ///
    /// Steps:
    /// 1. Unregister the listener so new connections can no longer find it.
    /// 2. Wake a pending `accept`, which then returns `None`.
    /// 3. Close the incoming-call receiver.
    pub async fn close(self) {
        let id = self.inner.id.clone();
        let drop_notify = self.inner.drop_notify.clone();
        let weak = Arc::downgrade(&self.inner);
        LISTENERS.lock().unwrap().remove(&id);
        // Wakes an `accept` that is waiting right now, so it returns `None`
        // and releases the receiver lock.
        self.inner.close_notify.notify_waiters();
        // `notify_waiters` stores no permit. Without this step, an `accept`
        // that was not waiting at that instant would later block on `recv`
        // for as long as any existing connection still holds a sender clone.
        // Closing the receiver makes `recv` return `None` once any buffered
        // call is drained. It also makes `invoke` on existing connections
        // fail immediately instead of waiting for channel capacity.
        self.inner.r.lock().await.close();
        drop(self);
        loop {
            // Register for the notification *before* checking, so a drop that
            // happens between the check and the await is not missed.
            let notified = drop_notify.notified();
            if weak.upgrade().is_none() {
                return;
            }
            notified.await;
        }
    }

    /// Currently a no-op: returns immediately without waiting for anything.
    pub async fn await_connection(&self) {}
}
impl ServerListener for InMemoryListener {
type SendStream = InMemoryServerSendStream;
type RecvStream = InMemoryServerRecvStream;
async fn accept(&self) -> Option<ServerCall<Self::SendStream, Self::RecvStream>> {
let mut r = self.inner.r.lock().await;
tokio::select! {
call = r.recv() => {
let call = call?;
Some(ServerCall {
headers: call.headers,
send: InMemoryServerSendStream { tx: call.resp_tx },
recv: InMemoryServerRecvStream { rx: call.req_rx },
trailers_tx: call.trailer_tx,
})
}
_ = self.inner.close_notify.notified() => {
None
}
}
}
}
/// Server half of a call that pushes response headers and messages to the
/// client.
pub struct InMemoryServerSendStream {
    tx: mpsc::UnboundedSender<InMemoryResponseStreamItem>,
}
impl ServerSendStream for InMemoryServerSendStream {
    /// Forwards `item` to the client.
    ///
    /// Returns `Err(())` in either of these cases:
    /// - a message fails to encode;
    /// - the client's receiving half is gone.
    async fn send<'a>(
        &mut self,
        item: ServerResponseStreamItem<'a>,
        _options: ServerSendOptions,
    ) -> Result<(), ()> {
        let outgoing = match item {
            ServerResponseStreamItem::Message(m) => match m.encode() {
                Ok(encoded) => InMemoryResponseStreamItem::Message(encoded),
                Err(_) => return Err(()),
            },
            ServerResponseStreamItem::Headers(h) => InMemoryResponseStreamItem::Headers(h),
        };
        match self.tx.send(outgoing) {
            Ok(()) => Ok(()),
            Err(_) => Err(()),
        }
    }
}
/// Server half of a call that receives request messages from the client.
pub struct InMemoryServerRecvStream {
    rx: mpsc::UnboundedReceiver<InMemoryRequestStreamItem>,
}
impl ServerRecvStream for InMemoryServerRecvStream {
    /// Decodes the next request message into `msg`.
    ///
    /// Returns:
    /// - `Some(Ok(()))` on success;
    /// - `Some(Err(()))` if decoding fails;
    /// - `None` once the client has signalled close or its sender is gone.
    async fn next(&mut self, msg: &mut dyn RecvMessage) -> Option<Result<(), ()>> {
        let Some(InMemoryRequestStreamItem::Message(mut buf)) = self.rx.recv().await else {
            // Either the explicit `StreamClosed` marker or a closed channel.
            return None;
        };
        Some(msg.decode(&mut buf).map_err(|_| ()))
    }
}
/// Client-side connection to a single in-memory listener.
pub struct InMemoryConnection {
    /// Clone of the listener's incoming-call sender; each `invoke` pushes one
    /// call through it.
    s: mpsc::Sender<InMemoryServerCall>,
    /// Completes the receiver returned from `InMemoryTransport::connect`;
    /// taken and fired from `Drop`.
    closed_tx: Option<oneshot::Sender<Result<(), String>>>,
}
impl Invoke for InMemoryConnection {
    type SendStream = Box<dyn ClientDynSendStream>;
    type RecvStream = Box<dyn ClientDynRecvStream>;
    /// Starts a call by handing the server a fresh set of channels.
    ///
    /// Returns the client-side send and receive halves.
    async fn invoke(
        &self,
        headers: RequestHeaders,
        _options: CallOptions,
    ) -> (Self::SendStream, Self::RecvStream) {
        let (req_tx, req_rx) = mpsc::unbounded_channel::<InMemoryRequestStreamItem>();
        let (resp_tx, resp_rx) = mpsc::unbounded_channel::<InMemoryResponseStreamItem>();
        let (trailer_tx, trailer_rx) = oneshot::channel();
        let call = InMemoryServerCall {
            headers,
            req_rx,
            resp_tx,
            trailer_tx,
        };
        // If the listener is gone, the call (and its senders) is dropped here.
        // The client's receive half then sees the response channel end and
        // the trailer sender dropped.
        drop(self.s.send(call).await);
        let send_half: Self::SendStream = Box::new(InMemoryClientSendStream { tx: Some(req_tx) });
        let recv_half: Self::RecvStream = Box::new(InMemoryClientRecvStream {
            rx: resp_rx,
            trailer_rx: Some(trailer_rx),
        });
        (send_half, recv_half)
    }
}
impl Drop for InMemoryConnection {
    /// Reports the connection as closed through the receiver handed out by
    /// `connect`.
    fn drop(&mut self) {
        // `closed_tx` is only ever taken here, but a destructor must never
        // panic. A dropped receiver is not an error either.
        if let Some(closed_tx) = self.closed_tx.take() {
            let _ = closed_tx.send(Err("".into()));
        }
    }
}
/// Client half of a call that pushes request messages to the server.
///
/// `tx` becomes `None` in either of these cases:
/// - a send fails because the server dropped its receiver;
/// - the stream is dropped.
pub struct InMemoryClientSendStream {
    tx: Option<mpsc::UnboundedSender<InMemoryRequestStreamItem>>,
}
impl ClientSendStream for InMemoryClientSendStream {
    /// Encodes `msg` and forwards it to the server.
    ///
    /// # Errors
    /// Returns `Err(())` in any of these cases:
    /// - the message fails to encode;
    /// - the server side of the call has gone away;
    /// - an earlier send already failed.
    ///
    /// None of these cases panic.
    async fn send(&mut self, msg: &dyn SendMessage, _options: ClientSendOptions) -> Result<(), ()> {
        let Some(tx) = self.tx.as_ref() else {
            return Err(());
        };
        let buf = msg.encode().map_err(|_| ())?;
        if tx.send(InMemoryRequestStreamItem::Message(buf)).is_err() {
            // The server dropped its receiver; remember that so later sends
            // fail fast and `Drop` does not try to send the close marker.
            self.tx = None;
            return Err(());
        }
        Ok(())
    }
}
impl Drop for InMemoryClientSendStream {
    /// Tells the server that no more request messages will arrive.
    fn drop(&mut self) {
        if let Some(tx) = self.tx.take() {
            let _ = tx.send(InMemoryRequestStreamItem::StreamClosed);
        }
    }
}
/// Client half of a call that receives response headers, messages and
/// trailers.
pub struct InMemoryClientRecvStream {
    rx: mpsc::UnboundedReceiver<InMemoryResponseStreamItem>,
    /// Taken once trailers have been produced, so later calls report
    /// `StreamClosed`.
    trailer_rx: Option<oneshot::Receiver<Trailers>>,
}
impl ClientRecvStream for InMemoryClientRecvStream {
    /// Returns the next item of the response stream.
    ///
    /// - Headers and messages are forwarded as they arrive.
    /// - Once the response channel is exhausted, trailers are returned once.
    ///   If the server went away without sending them, an `Internal` error
    ///   status is synthesized.
    /// - After trailers, every call yields `StreamClosed`.
    /// - If a message fails to decode, an `Internal` error trailer is returned
    ///   instead of panicking. The server can then no longer send further
    ///   items, and its own trailers are discarded.
    async fn recv(&mut self, msg: &mut dyn RecvMessage) -> ResponseStreamItem {
        match self.rx.recv().await {
            Some(InMemoryResponseStreamItem::Headers(h)) => ResponseStreamItem::Headers(h),
            Some(InMemoryResponseStreamItem::Message(mut buf)) => {
                if msg.decode(&mut buf).is_ok() {
                    return ResponseStreamItem::Message;
                }
                // Terminate the call locally: reject further server sends and
                // drop the trailer receiver so the server's trailers are not
                // reported after this error.
                self.rx.close();
                self.trailer_rx = None;
                ResponseStreamItem::Trailers(Trailers::new(Err(StatusError::new(
                    StatusCodeError::Internal,
                    "failed to decode response message in in-memory transport",
                ))))
            }
            None => {
                let Some(trailer_rx) = self.trailer_rx.take() else {
                    return ResponseStreamItem::StreamClosed;
                };
                match trailer_rx.await {
                    Ok(trailers) => ResponseStreamItem::Trailers(trailers),
                    Err(_) => ResponseStreamItem::Trailers(Trailers::new(Err(StatusError::new(
                        StatusCodeError::Internal,
                        "stream ended without trailers in in-memory transport",
                    )))),
                }
            }
        }
    }
}
/// Transport that connects to `InMemoryListener`s registered in this process.
pub struct InMemoryTransport {}
impl Transport for InMemoryTransport {
    type Service = InMemoryConnection;
    /// Connects to the listener whose id equals `target`.
    ///
    /// # Errors
    /// Returns a message naming `target` if no such listener is registered.
    async fn connect(
        &self,
        target: String,
        _runtime: GrpcRuntime,
        _security_opts: &SecurityOpts,
        _options: &TransportOptions,
    ) -> Result<
        (
            Self::Service,
            DynClientConnectionSecurityInfo,
            oneshot::Receiver<Result<(), String>>,
        ),
        String,
    > {
        // The global lock is held only for the lookup itself.
        let sender = match LISTENERS.lock().unwrap().get(&target) {
            Some(found) => found.clone(),
            None => return Err(format!("no listener for target: {}", target)),
        };
        let (closed_tx, closed_rx) = oneshot::channel();
        let connection = InMemoryConnection {
            s: sender,
            closed_tx: Some(closed_tx),
        };
        let security_info = ClientConnectionSecurityInfo::new(
            "inmemory",
            SecurityLevel::PrivacyAndIntegrity,
            InMemoryChannelecurityContext {},
            Attributes::new(),
        )
        .into_boxed();
        Ok((connection, security_info, closed_rx))
    }
}
/// Security context attached to every in-memory connection.
///
/// NOTE(review): the name is a typo for `InMemoryChannelSecurityContext`.
/// Renaming it also requires updating `InMemoryTransport::connect`.
#[derive(Debug, Clone)]
struct InMemoryChannelecurityContext;
impl ClientConnectionSecurityContext for InMemoryChannelecurityContext {
    /// Accepts every authority unconditionally.
    fn validate_authority(&self, _: &Authority) -> bool {
        true
    }
}
/// Resolver builder for the `inmemory` scheme.
///
/// The target path (minus one leading `/`) is a comma-separated list of
/// listener ids. Each non-empty id becomes one endpoint.
pub struct InMemoryResolverBuilder {}
impl ResolverBuilder for InMemoryResolverBuilder {
    fn build(&self, target: &Target, options: ResolverOptions) -> Box<dyn Resolver> {
        let path = target.path();
        let path = path.strip_prefix('/').unwrap_or(path);
        // Skip empty segments. An empty path or stray commas would otherwise
        // yield endpoints with an empty address that can never connect.
        let ids: Vec<String> = path
            .split(',')
            .filter(|id| !id.is_empty())
            .map(str::to_owned)
            .collect();
        // Ask the channel to run `work` so the static endpoint list is
        // published.
        options.work_scheduler.schedule_work();
        Box::new(InMemoryResolver { ids })
    }
    fn scheme(&self) -> &str {
        "inmemory"
    }
    fn is_valid_uri(&self, _uri: &Target) -> bool {
        true
    }
}
/// Resolver that reports a fixed set of listener ids as endpoints.
struct InMemoryResolver {
    ids: Vec<String>,
}
impl Resolver for InMemoryResolver {
    /// No-op: the endpoint list is fixed at build time.
    fn resolve_now(&mut self) {}
    /// Publishes one `inmemory` endpoint per id and selects the round-robin
    /// load-balancing policy.
    fn work(&mut self, channel_controller: &mut dyn ResolverChannelController) {
        let endpoints = self
            .ids
            .iter()
            .map(|listener_id| {
                let address = Address {
                    network_type: "inmemory",
                    address: crate::byte_str::ByteStr::from(listener_id.clone()),
                    ..Default::default()
                };
                Endpoint {
                    addresses: vec![address],
                    ..Default::default()
                }
            })
            .collect();
        let service_config = ServiceConfig {
            load_balancing_policy: Some(crate::client::service_config::LbPolicyType::RoundRobin),
        };
        let update = ResolverUpdate {
            endpoints: Ok(endpoints),
            service_config: Ok(Some(service_config)),
            ..Default::default()
        };
        // The controller's response to the update is ignored.
        let _ = channel_controller.update(update);
    }
}
/// Registers the `inmemory` transport and resolver builder with the global
/// transport and resolver registries.
pub fn reg() {
    let transport = InMemoryTransport {};
    GLOBAL_TRANSPORT_REGISTRY.add_transport("inmemory", transport);
    let resolver_builder = Box::new(InMemoryResolverBuilder {});
    global_resolver_registry().add_builder(resolver_builder);
}
#[cfg(test)]
mod tests {
    use bytes::Buf;
    use super::*;
    use crate::core::RecvMessage;

    /// Message sink that accepts any payload without inspecting it.
    struct NopRecvMessage;
    impl RecvMessage for NopRecvMessage {
        fn decode(&mut self, _data: &mut dyn Buf) -> Result<(), String> {
            Ok(())
        }
    }

    /// If the server goes away without sending trailers, `recv` must
    /// synthesize an `Internal` error trailer.
    #[tokio::test]
    async fn test_in_memory_recv_stream_missing_trailers() {
        let (resp_tx, resp_rx) =
            tokio::sync::mpsc::unbounded_channel::<InMemoryResponseStreamItem>();
        let (trailer_tx, trailer_rx) = tokio::sync::oneshot::channel::<Trailers>();
        // Simulate the server vanishing before sending anything.
        drop(resp_tx);
        drop(trailer_tx);
        let mut stream = InMemoryClientRecvStream {
            rx: resp_rx,
            trailer_rx: Some(trailer_rx),
        };
        let mut sink = NopRecvMessage;
        let item = stream.recv(&mut sink).await;
        let ResponseStreamItem::Trailers(trailers) = &item else {
            panic!("expected trailers with error, got {:?}", item);
        };
        let status = trailers.status();
        let err = status.as_ref().unwrap_err();
        assert_eq!(err.code(), crate::StatusCodeError::Internal);
        assert!(err.message().contains("stream ended without trailers"));
    }
}