/// Protobuf message types and gRPC client/server stubs for the `message`
/// proto package, pulled in via `tonic::include_proto!` (the request/response
/// types and `message_service_client` / `message_service_server` used below).
pub(crate) mod message_rpc {
tonic::include_proto!("message");
}
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use message_rpc::message_service_client::MessageServiceClient;
use message_rpc::message_service_server::{MessageService, MessageServiceServer};
use message_rpc::{
FetchMessagesRequest, FetchMessagesResponse, PushMessageRequest, PushMessageResponse,
};
use tonic::transport::{Certificate, Identity, Server, ServerTlsConfig, Uri};
use tonic::transport::{Channel, ClientTlsConfig};
use tonic::{Request, Response, Result as RpcResult};
use super::Transport;
use crate::error::{Result, TransportError};
/// gRPC-backed [`Transport`] that exchanges encrypted messages through a
/// message server.
///
/// Outgoing messages are stored on the server under
/// `my_identity_key ++ peer_identity_key`; incoming messages are read from
/// `peer_identity_key ++ my_identity_key`.
pub struct RpcTransport {
    /// Client handle, cloned into each request future.
    rpc_client: MessageServiceClient<Channel>,
    /// Index of the next message to fetch from `fetch_target`; sent to the
    /// server as the fetch offset. Shared with in-flight fetch futures.
    last_sync_id: Arc<AtomicU64>,
    /// Queue key for outgoing messages: `my_identity_key ++ peer_identity_key`.
    push_target: Vec<u8>,
    /// Queue key for incoming messages: `peer_identity_key ++ my_identity_key`.
    fetch_target: Vec<u8>,
}

impl RpcTransport {
    /// Connects to the message server at `msg_server_addr` for a conversation
    /// between `my_identity_key` and `peer_identity_key`.
    ///
    /// TLS is enabled when the URI scheme is `https`, trusting the native root
    /// certificates plus `ca` when one is given.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::Connect`] if `msg_server_addr` is not a valid
    /// URI, the TLS configuration is rejected, or the connection fails.
    pub async fn connect(
        msg_server_addr: impl TryInto<Uri>,
        my_identity_key: &[u8],
        peer_identity_key: &[u8],
        ca: Option<Certificate>,
    ) -> Result<Self> {
        // A malformed address is a caller error: report it instead of panicking.
        let uri: Uri = msg_server_addr
            .try_into()
            .map_err(|_| TransportError::Connect)?;
        // Decide on TLS before `uri` is moved into the endpoint builder.
        let use_tls = uri.scheme_str() == Some("https");
        let mut endpoint = Channel::builder(uri);
        if use_tls {
            let mut tls = ClientTlsConfig::new().with_native_roots();
            if let Some(ca) = ca {
                tls = tls.ca_certificate(ca);
            }
            endpoint = endpoint
                .tls_config(tls)
                .map_err(|_| TransportError::Connect)?;
        }
        let channel = endpoint
            .connect()
            .await
            .map_err(|_| TransportError::Connect)?;
        Ok(Self {
            rpc_client: MessageServiceClient::new(channel),
            last_sync_id: Arc::new(AtomicU64::new(0)),
            // `concat` already returns an owned `Vec<u8>`; no extra copy needed.
            push_target: [my_identity_key, peer_identity_key].concat(),
            fetch_target: [peer_identity_key, my_identity_key].concat(),
        })
    }
}
#[allow(clippy::manual_async_fn)]
impl Transport for RpcTransport {
fn push_bytes(&mut self, bytes: Vec<u8>) -> impl Future<Output = Result<()>> + Send + 'static {
let req = PushMessageRequest {
target: self.push_target.clone(),
enc_message: bytes,
};
let mut client = self.rpc_client.clone();
async move {
let _resp = client
.push_message(req)
.await
.map_err(|_| TransportError::Push)?;
Ok(())
}
}
fn fetch_bytes(
&mut self,
limit: Option<usize>,
) -> impl Future<Output = Result<Vec<Vec<u8>>>> + Send + 'static {
let req = FetchMessagesRequest {
target: self.fetch_target.clone(),
last_sync_id: self.last_sync_id.load(Ordering::Relaxed),
limit: limit.map(|limit| limit as u64),
};
let mut client = self.rpc_client.clone();
let last_sync_id = self.last_sync_id.clone();
async move {
let resp = client
.fetch_messages(req)
.await
.map_err(|_| TransportError::Fetch)?;
last_sync_id.fetch_add(resp.get_ref().enc_messages.len() as u64, Ordering::Relaxed);
Ok(resp.into_inner().enc_messages)
}
}
}
/// Entry point for running the in-memory gRPC message server.
pub struct RpcMessageServer {}

impl RpcMessageServer {
    /// Serves the message service on `addr` with a fresh, empty message store.
    /// The future resolves once serving ends.
    ///
    /// When `identity` is given, the server is configured for TLS with it.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::Server`] if `addr` is not a valid socket
    /// address, the TLS configuration is rejected, or serving fails.
    pub async fn run(addr: impl AsRef<str>, identity: Option<Identity>) -> Result<()> {
        // A bad address is a configuration error: report it instead of panicking.
        let addr: std::net::SocketAddr = addr
            .as_ref()
            .parse()
            .map_err(|_| TransportError::Server)?;
        let mut server = Server::builder();
        if let Some(identity) = identity {
            server = server
                .tls_config(ServerTlsConfig::new().identity(identity))
                .map_err(|_| TransportError::Server)?;
        }
        server
            .add_service(MessageServiceServer::new(RpcMessageServerInner::default()))
            .serve(addr)
            .await
            .map_err(|_| TransportError::Server)?;
        Ok(())
    }
}
/// A single target's message queue. It is reference-counted so a handler can
/// clone it out of the map and drop the map-wide lock before locking the queue.
type MessageQueue = Arc<RwLock<Vec<Vec<u8>>>>;

/// In-memory state behind the gRPC message service.
#[derive(Debug, Default)]
pub(crate) struct RpcMessageServerInner {
    /// Queues keyed by the request's `target` bytes; each queue keeps messages
    /// in the order they were pushed.
    db: RwLock<HashMap<Vec<u8>, MessageQueue>>,
}
#[tonic::async_trait]
impl MessageService for RpcMessageServerInner {
    /// Appends `enc_message` to the queue for `target`, creating the queue if
    /// it does not exist yet.
    async fn push_message(
        &self,
        request: Request<PushMessageRequest>,
    ) -> RpcResult<Response<PushMessageResponse>> {
        let req = request.into_inner();
        // Clone the queue handle so the map-wide write guard (a temporary) is
        // released at the end of this statement, before the queue is locked.
        let queue = self.db.write().entry(req.target).or_default().clone();
        queue.write().push(req.enc_message);
        Ok(Response::new(PushMessageResponse {}))
    }

    /// Returns the messages for `target` starting at index `last_sync_id`. When
    /// `limit` is set, at most `limit` messages are returned.
    ///
    /// An unknown target, or an offset at or past the end of the queue, yields
    /// an empty list.
    async fn fetch_messages(
        &self,
        request: Request<FetchMessagesRequest>,
    ) -> RpcResult<Response<FetchMessagesResponse>> {
        let req = request.into_inner();
        // Reading must not create queues for unknown targets, so a read lock on
        // the map is enough. The guard is dropped at the end of this statement.
        let queue = self.db.read().get(&req.target).cloned();
        let Some(queue) = queue else {
            return Ok(Response::new(FetchMessagesResponse {
                enc_messages: Vec::new(),
            }));
        };
        let queue = queue.read();
        let len = queue.len();
        // Offset and limit are client-supplied. Convert without truncation,
        // add without overflow, and clamp so that `start <= end <= len` holds.
        let start = usize::try_from(req.last_sync_id).map_or(len, |start| start.min(len));
        let end = match req.limit {
            Some(limit) => usize::try_from(limit)
                .map_or(len, |limit| start.saturating_add(limit).min(len)),
            None => len,
        };
        let enc_messages = queue[start..end].to_vec();
        Ok(Response::new(FetchMessagesResponse { enc_messages }))
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::transport::EncryptedMessage;

    /// Socket address the in-process server listens on.
    const SERVER_ADDR: &str = "[::1]:3000";
    /// URI the clients use to reach `SERVER_ADDR`.
    const SERVER_URI: &str = "http://[::1]:3000";

    /// End to end: each message Alice pushes is the first one Bob's next fetch returns.
    #[tokio::test]
    async fn grpc_transport() {
        tokio::spawn(async {
            RpcMessageServer::run(SERVER_ADDR, None).await.unwrap();
        });
        // Give the spawned server a moment to bind before the clients connect.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        let mut alice = RpcTransport::connect(SERVER_URI, b"Alice", b"Bob", None)
            .await
            .unwrap();
        let mut bob = RpcTransport::connect(SERVER_URI, b"Bob", b"Alice", None)
            .await
            .unwrap();

        let first = EncryptedMessage {
            enc_header: vec![1, 2, 3],
            enc_content: vec![4, 5, 6],
        };
        let second = EncryptedMessage {
            enc_header: vec![4, 5, 6],
            enc_content: vec![1, 2, 3],
        };

        // Push `first` twice and then `second`, fetching after every push.
        for expected in [first.clone(), first, second] {
            alice.push(expected.clone()).await.unwrap();
            let received = bob.fetch(None).await.unwrap();
            assert_eq!(received[0], expected);
        }
    }
}