use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use futures::prelude::*;
use futures::TryStreamExt;
use services::{RunInfo, RequestInfo, IterationInfo};
use std::future::Future;
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio::sync::RwLock;
use tokio::sync::{
mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
oneshot,
};
use tokio_serde::formats::SymmetricalJson;
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
use uuid::Uuid;
pub use serde;
pub use tokio;
pub mod services;
pub use crows_service;
/// A TCP server whose connections exchange length-delimited, JSON-encoded [`Message`]s.
///
/// Create one with [`create_server`] and call [`Server::accept`] for each incoming
/// connection.
pub struct Server {
    /// Bound listener that connections are accepted from.
    listener: TcpListener,
}
impl Server {
    /// Accepts the next incoming connection and bridges it to channels.
    ///
    /// The socket is split and framed with [`LengthDelimitedCodec`]; each frame carries
    /// one JSON-encoded [`Message`]. Two background tasks are spawned:
    ///
    /// * a writer that forwards everything sent on the returned sender to the socket,
    ///   stopping on the first write error or once every sender has been dropped;
    /// * a reader that forwards decoded frames to the returned receiver, stopping at end
    ///   of stream, on the first read/decode error (which is logged), or when a message
    ///   cannot be delivered because the returned receiver was dropped.
    ///
    /// Returns `None` if accepting on the listener fails. Otherwise returns
    /// `(outgoing sender, incoming receiver, close receiver)`; the close receiver resolves
    /// when the reader task finishes.
    pub async fn accept(
        &self,
    ) -> Option<(
        UnboundedSender<Message>,
        UnboundedReceiver<Message>,
        oneshot::Receiver<()>,
    )> {
        let (socket, _) = self.listener.accept().await.ok()?;
        let (reader, writer) = socket.into_split();
        let length_delimited = FramedRead::new(reader, LengthDelimitedCodec::new());
        let mut deserialized =
            tokio_serde::SymmetricallyFramed::new(length_delimited, SymmetricalJson::default());
        let length_delimited = FramedWrite::new(writer, LengthDelimitedCodec::new());
        let mut serialized = tokio_serde::SymmetricallyFramed::new(
            length_delimited,
            SymmetricalJson::<Message>::default(),
        );
        let (serialized_sender, mut serialized_receiver) = unbounded_channel::<Message>();
        let (deserialized_sender, deserialized_receiver) = unbounded_channel::<Message>();
        let (close_sender, close_receiver) = oneshot::channel::<()>();

        // Writer: outgoing channel -> socket.
        tokio::spawn(async move {
            while let Some(message) = serialized_receiver.recv().await {
                if let Err(err) = serialized.send(message).await {
                    println!("Error while sending message: {err:?}");
                    break;
                }
            }
        });

        // Reader: socket -> incoming channel; signals `close_receiver` when it stops.
        tokio::spawn(async move {
            loop {
                match deserialized.try_next().await {
                    Ok(Some(message)) => {
                        if let Err(err) = deserialized_sender.send(message) {
                            println!("Error while sending message: {err:?}");
                            break;
                        }
                    }
                    // Clean end of stream: the peer closed its side.
                    Ok(None) => break,
                    // Log I/O and decode failures so a dropped connection is diagnosable
                    // instead of ending silently.
                    Err(err) => {
                        println!("Error while receiving message: {err:?}");
                        break;
                    }
                }
            }
            if let Err(e) = close_sender.send(()) {
                println!("Got an error when sending to a close_sender: {e:?}");
            }
        });

        Some((serialized_sender, deserialized_receiver, close_receiver))
    }
}
/// Binds a [`Server`] to `addr`.
///
/// # Errors
///
/// Returns the I/O error from binding the TCP listener (e.g. address in use).
pub async fn create_server<A>(addr: A) -> Result<Server, std::io::Error>
where
    A: ToSocketAddrs,
{
    TcpListener::bind(addr)
        .await
        .map(|listener| Server { listener })
}
/// Connects to `addr` and bridges the connection to a pair of channels.
///
/// Uses the same framing as [`Server::accept`]: length-delimited frames, each holding
/// one JSON-encoded [`Message`]. Two background tasks are spawned: a writer that sends
/// everything from the returned sender to the socket, and a reader that forwards
/// decoded messages to the returned receiver. The reader stops at end of stream, on the
/// first read/decode error (which is logged), or when the receiver has been dropped.
///
/// Returns `(outgoing sender, incoming receiver)`.
///
/// # Errors
///
/// Returns the I/O error if the TCP connection cannot be established.
pub async fn create_client<A>(
    addr: A,
) -> Result<(UnboundedSender<Message>, UnboundedReceiver<Message>), std::io::Error>
where
    A: ToSocketAddrs,
{
    let socket = TcpStream::connect(addr).await?;
    let (reader, writer) = socket.into_split();
    let length_delimited = FramedWrite::new(writer, LengthDelimitedCodec::new());
    let mut serialized =
        tokio_serde::SymmetricallyFramed::new(length_delimited, SymmetricalJson::default());
    let length_delimited = FramedRead::new(reader, LengthDelimitedCodec::new());
    let mut deserialized =
        tokio_serde::SymmetricallyFramed::new(length_delimited, SymmetricalJson::default());
    let (serialized_sender, mut serialized_receiver) = unbounded_channel::<Message>();
    let (deserialized_sender, deserialized_receiver) = unbounded_channel::<Message>();

    // Writer: outgoing channel -> socket.
    tokio::spawn(async move {
        while let Some(message) = serialized_receiver.recv().await {
            if let Err(err) = serialized.send(message).await {
                println!("Error while sending message: {err:?}");
                break;
            }
        }
    });

    // Reader: socket -> incoming channel.
    tokio::spawn(async move {
        loop {
            match deserialized.try_next().await {
                Ok(Some(message)) => {
                    if let Err(err) = deserialized_sender.send(message) {
                        println!("Error while sending message: {err:?}");
                        break;
                    }
                }
                // Clean end of stream: the peer closed its side.
                Ok(None) => break,
                // Log I/O and decode failures instead of ending the reader silently.
                Err(err) => {
                    println!("Error while receiving message: {err:?}");
                    break;
                }
            }
        }
    });

    Ok((serialized_sender, deserialized_receiver))
}
/// Asks the dispatcher task to deliver the reply to `message_id` through `respond_to`.
#[derive(Debug)]
struct RegisterListener {
    /// Receives the reply's raw payload (the `message` field of the reply [`Message`]).
    respond_to: oneshot::Sender<String>,
    /// Id of the outgoing request; a reply whose `reply_to` equals this id is routed here.
    message_id: Uuid,
}
/// Control messages from [`Client`] handles to the dispatcher task spawned in
/// [`Client::new`]; unlike [`Message`], these never leave the process.
#[derive(Debug)]
enum InternalMessage {
    /// Register a one-shot listener for a reply.
    RegisterListener(RegisterListener),
}
/// Handle for exchanging requests and replies with the peer of one connection.
///
/// Clones share the same channels and the same close receiver (all fields are
/// reference-counted or channel senders). Usually wrapped in a service-specific type
/// implementing [`ClientTrait`]; see [`Client::new`].
#[derive(Clone)]
pub struct Client {
    /// Shared mutable state (the connection's close receiver).
    inner: Arc<RwLock<ClientInner>>,
    /// Outgoing messages for the peer (e.g. the sender from [`Server::accept`] or
    /// [`create_client`]).
    sender: UnboundedSender<Message>,
    /// Control channel to this client's dispatcher task.
    internal_sender: UnboundedSender<InternalMessage>,
}
/// Mutable state shared between clones of a [`Client`].
struct ClientInner {
    /// Close signal supplied to [`Client::new`] (e.g. the one returned by
    /// [`Server::accept`]); taken at most once by `get_close_receiver` or `wait`.
    close_receiver: Option<oneshot::Receiver<()>>,
}
impl Client {
pub async fn request<
T: Serialize + std::fmt::Debug + DeserializeOwned + Send + 'static,
Y: Serialize + std::fmt::Debug + DeserializeOwned + Send + 'static,
>(
&self,
message: T,
) -> anyhow::Result<Y> {
let message = Message {
id: Uuid::new_v4(),
reply_to: None,
message: serde_json::to_string(&message)?,
message_type: std::any::type_name::<T>().to_string(),
};
let (tx, rx) = oneshot::channel::<String>();
let register_listener = RegisterListener {
respond_to: tx,
message_id: message.id,
};
self.send_internal(InternalMessage::RegisterListener(register_listener))
.await?;
self.send(message).await?;
match rx.await {
Ok(reply) => Ok(serde_json::from_str(&reply)?),
Err(e) => Err(e)?,
}
}
async fn send(&self, message: Message) -> anyhow::Result<()> {
Ok(self.sender.send(message)?)
}
async fn send_internal(&self, message: InternalMessage) -> anyhow::Result<()> {
Ok(self.internal_sender.send(message)?)
}
pub fn new<T, DummyType>(
sender: UnboundedSender<Message>,
mut receiver: UnboundedReceiver<Message>,
mut service: T,
close_receiver: Option<oneshot::Receiver<()>>,
) -> <T as Service<DummyType>>::Client
where
T: Service<DummyType> + Send + Sync + 'static + Clone,
<T as Service<DummyType>>::Request: Send,
<T as Service<DummyType>>::Response: Send,
<T as Service<DummyType>>::Client: ClientTrait + Clone + Send + Sync + 'static,
{
let (internal_sender, mut internal_receiver) = unbounded_channel();
let client = T::Client::new(Self {
inner: Arc::new(RwLock::new(ClientInner { close_receiver })),
sender: sender.clone(),
internal_sender,
});
let client_clone = client.clone();
tokio::spawn(async move {
let mut listeners: HashMap<Uuid, oneshot::Sender<String>> = HashMap::new();
loop {
tokio::select! {
message = receiver.recv() => {
match message {
Some(message) => {
if let Some(reply_to) = message.reply_to {
let reply = listeners.remove(&reply_to).unwrap();
if reply.send(message.message).is_err() {
break;
}
} else {
let service_clone = service.clone();
let sender_clone = sender.clone();
let client_clone = client_clone.clone();
tokio::spawn(async move {
let deserialized = serde_json::from_str::<<T as Service<DummyType>>::Request>(&message.message).unwrap();
let response = service_clone.handle_request(client_clone, deserialized).await;
let message = Message {
id: Uuid::new_v4(),
reply_to: Some(message.id),
message: serde_json::to_string(&response).unwrap(),
message_type: std::any::type_name::<T>().to_string(),
};
sender_clone.send(message).unwrap();
});
}
},
None => break,
}
}
internal_message = internal_receiver.recv() => {
match internal_message {
Some(internal_message) => {
match internal_message {
InternalMessage::RegisterListener(register_listener) => {
listeners.insert(register_listener.message_id, register_listener.respond_to);
}
}
},
None => break
}
}
}
}
});
client
}
pub async fn get_close_receiver(&self) -> Option<oneshot::Receiver<()>> {
let mut inner = self.inner.write().await;
inner.close_receiver.take()
}
pub async fn wait(&self) {
let mut inner = self.inner.write().await;
if let Some(receiver) = inner.close_receiver.take() {
if let Err(e) = receiver.await {
println!("Got an error when waiting for oneshot receiver: {e:?}");
}
}
}
}
/// Envelope exchanged between peers; serialized as JSON on the wire.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Message {
    /// Id of this message (a random v4 UUID when created by [`Client`]).
    pub id: Uuid,
    /// For replies, the `id` of the request being answered; `None` for requests.
    pub reply_to: Option<Uuid>,
    /// JSON-serialized request or response payload.
    pub message: String,
    /// `std::any::type_name` string: the request payload type for requests, the service
    /// type for replies sent by [`Client::new`]. Not read anywhere in this file.
    pub message_type: String,
}
/// Implemented by service-specific client wrappers so [`Client::new`] can build them
/// from a raw [`Client`].
pub trait ClientTrait {
    /// Wraps the given raw [`Client`].
    fn new(client: Client) -> Self;
}
/// Handler for requests arriving from the peer of a connection.
///
/// `DummyType` is not used by any item of the trait; presumably a marker that lets one
/// type implement `Service` several times or disambiguates impls — NOTE(review): confirm.
pub trait Service<DummyType>: Send + Sync {
    /// Value produced by the handler; serialized to JSON for the reply.
    type Response: Send + Serialize;
    /// Incoming request payload; deserialized from JSON.
    type Request: DeserializeOwned + Send;
    /// Client wrapper passed to each `handle_request` call so the handler can send
    /// requests back to the same peer.
    type Client: ClientTrait + Clone + Send + Sync;
    /// Handles one request and produces its response.
    ///
    /// `client` talks to the peer that sent `message`. The returned future may borrow
    /// `self`.
    fn handle_request(
        &self,
        client: Self::Client,
        message: Self::Request,
    ) -> Pin<Box<dyn Future<Output = Self::Response> + Send + '_>>;
}
pub async fn process_info_handle(handle: &mut InfoHandle) -> RunInfo {
let mut run_info: RunInfo = Default::default();
run_info.done = false;
while let Ok(update) = handle.receiver.try_recv() {
match update {
InfoMessage::Stderr(buf) => run_info.stderr.push(buf),
InfoMessage::Stdout(buf) => run_info.stdout.push(buf),
InfoMessage::RequestInfo(info) => run_info.request_stats.push(info),
InfoMessage::IterationInfo(info) => run_info.iteration_stats.push(info),
InfoMessage::InstanceCheckedOut => run_info.active_instances_delta += 1,
InfoMessage::InstanceReserved => run_info.capacity_delta += 1,
InfoMessage::InstanceCheckedIn => run_info.active_instances_delta -= 1,
InfoMessage::TimingUpdate((elapsed, left)) => {
run_info.elapsed = Some(elapsed);
run_info.left = Some(left);
}
InfoMessage::Done => run_info.done = true,
}
}
run_info
}
/// Status update folded into a `RunInfo` by [`process_info_handle`].
pub enum InfoMessage {
    /// Chunk of stderr output; appended to `RunInfo::stderr`.
    Stderr(Vec<u8>),
    /// Chunk of stdout output; appended to `RunInfo::stdout`.
    Stdout(Vec<u8>),
    /// Per-request statistics; appended to `RunInfo::request_stats`.
    RequestInfo(RequestInfo),
    /// Per-iteration statistics; appended to `RunInfo::iteration_stats`.
    IterationInfo(IterationInfo),
    /// Increments `RunInfo::active_instances_delta`.
    InstanceCheckedOut,
    /// Increments `RunInfo::capacity_delta`.
    InstanceReserved,
    /// Decrements `RunInfo::active_instances_delta`.
    InstanceCheckedIn,
    /// `(elapsed, left)`; overwrites `RunInfo::elapsed` and `RunInfo::left`.
    TimingUpdate((Duration, Duration)),
    /// Sets `RunInfo::done` to `true`.
    Done,
}
/// Receiving end of a stream of [`InfoMessage`]s, drained by [`process_info_handle`].
pub struct InfoHandle {
    /// Source of pending status updates.
    pub receiver: UnboundedReceiver<InfoMessage>,
}