1use std::collections::HashMap;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::time::Duration;
5
6use serde::de::DeserializeOwned;
7use serde::{Deserialize, Serialize};
8
9use futures::prelude::*;
10use futures::TryStreamExt;
11use services::{RunInfo, RequestInfo, IterationInfo};
12use std::future::Future;
13use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
14use tokio::sync::RwLock;
15use tokio::sync::{
16 mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
17 oneshot,
18};
19use tokio_serde::formats::SymmetricalJson;
20use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
21use uuid::Uuid;
22
23pub use serde;
24pub use tokio;
25pub mod services;
26pub use crows_service;
27
28pub struct Server {
29 listener: TcpListener,
30}
31
32impl Server {
33 pub async fn accept(
34 &self,
35 ) -> Option<(
36 UnboundedSender<Message>,
37 UnboundedReceiver<Message>,
38 oneshot::Receiver<()>,
39 )> {
40 let (socket, _) = self.listener.accept().await.ok()?;
41 let (reader, writer) = socket.into_split();
42
43 let length_delimited = FramedRead::new(reader, LengthDelimitedCodec::new());
45
46 let mut deserialized =
48 tokio_serde::SymmetricallyFramed::new(length_delimited, SymmetricalJson::default());
49
50 let length_delimited = FramedWrite::new(writer, LengthDelimitedCodec::new());
51
52 let mut serialized = tokio_serde::SymmetricallyFramed::new(
53 length_delimited,
54 SymmetricalJson::<Message>::default(),
55 );
56
57 let (serialized_sender, mut serialized_receiver) = unbounded_channel::<Message>();
58 let (deserialized_sender, deserialized_receiver) = unbounded_channel::<Message>();
59 let (close_sender, close_receiver) = oneshot::channel::<()>();
60
61 tokio::spawn(async move {
62 while let Some(message) = serialized_receiver.recv().await {
63 if let Err(err) = serialized.send(message).await {
64 println!("Error while sending message: {err:?}");
65 break;
66 }
67 }
68 });
69
70 tokio::spawn(async move {
71 while let Ok(Some(message)) = deserialized.try_next().await {
73 if let Err(err) = deserialized_sender.send(message) {
74 println!("Error while sending message: {err:?}");
75 break;
76 }
77 }
78
79 if let Err(e) = close_sender.send(()) {
80 println!("Got an error when sending to a close_sender: {e:?}");
81 }
82 });
83
84 Some((serialized_sender, deserialized_receiver, close_receiver))
85 }
86}
87
88pub async fn create_server<A>(addr: A) -> Result<Server, std::io::Error>
89where
90 A: ToSocketAddrs,
91{
92 let listener = TcpListener::bind(addr).await?;
94
95 Ok(Server { listener })
98}
99
100pub async fn create_client<A>(
101 addr: A,
102) -> Result<(UnboundedSender<Message>, UnboundedReceiver<Message>), std::io::Error>
103where
104 A: ToSocketAddrs,
105{
106 let socket = TcpStream::connect(addr).await?;
108
109 let (reader, writer) = socket.into_split();
110
111 let length_delimited = FramedWrite::new(writer, LengthDelimitedCodec::new());
113
114 let mut serialized =
116 tokio_serde::SymmetricallyFramed::new(length_delimited, SymmetricalJson::default());
117
118 let length_delimited = FramedRead::new(reader, LengthDelimitedCodec::new());
120
121 let mut deserialized =
123 tokio_serde::SymmetricallyFramed::new(length_delimited, SymmetricalJson::default());
124
125 let (serialized_sender, mut serialized_receiver) = unbounded_channel::<Message>();
126 let (deserialized_sender, deserialized_receiver) = unbounded_channel::<Message>();
127
128 tokio::spawn(async move {
129 while let Some(message) = serialized_receiver.recv().await {
130 if let Err(err) = serialized.send(message).await {
131 println!("Error while sending message: {err:?}");
132 break;
133 }
134 }
135 });
136
137 tokio::spawn(async move {
138 while let Ok(Some(message)) = deserialized.try_next().await {
140 if let Err(err) = deserialized_sender.send(message) {
141 println!("Error while sending message: {err:?}");
142 break;
143 }
144 }
145 });
146
147 Ok((serialized_sender, deserialized_receiver))
148}
149
150#[derive(Debug)]
151struct RegisterListener {
152 respond_to: oneshot::Sender<String>,
153 message_id: Uuid,
154}
155
156#[derive(Debug)]
157enum InternalMessage {
158 RegisterListener(RegisterListener),
159}
160
161#[derive(Clone)]
162pub struct Client {
163 inner: Arc<RwLock<ClientInner>>,
164 sender: UnboundedSender<Message>,
165 internal_sender: UnboundedSender<InternalMessage>,
166}
167
168struct ClientInner {
169 close_receiver: Option<oneshot::Receiver<()>>,
170}
171
172impl Client {
173 pub async fn request<
174 T: Serialize + std::fmt::Debug + DeserializeOwned + Send + 'static,
175 Y: Serialize + std::fmt::Debug + DeserializeOwned + Send + 'static,
176 >(
177 &self,
178 message: T,
179 ) -> anyhow::Result<Y> {
180 let message = Message {
181 id: Uuid::new_v4(),
182 reply_to: None,
183 message: serde_json::to_string(&message)?,
184 message_type: std::any::type_name::<T>().to_string(),
185 };
186
187 let (tx, rx) = oneshot::channel::<String>();
188 let register_listener = RegisterListener {
189 respond_to: tx,
190 message_id: message.id,
191 };
192 self.send_internal(InternalMessage::RegisterListener(register_listener))
193 .await?;
194 self.send(message).await?;
195
196 match rx.await {
198 Ok(reply) => Ok(serde_json::from_str(&reply)?),
199 Err(e) => Err(e)?,
200 }
201 }
202
203 async fn send(&self, message: Message) -> anyhow::Result<()> {
204 Ok(self.sender.send(message)?)
205 }
206
207 async fn send_internal(&self, message: InternalMessage) -> anyhow::Result<()> {
208 Ok(self.internal_sender.send(message)?)
209 }
210
211 pub fn new<T, DummyType>(
212 sender: UnboundedSender<Message>,
213 mut receiver: UnboundedReceiver<Message>,
214 mut service: T,
215 close_receiver: Option<oneshot::Receiver<()>>,
216 ) -> <T as Service<DummyType>>::Client
217 where
218 T: Service<DummyType> + Send + Sync + 'static + Clone,
219 <T as Service<DummyType>>::Request: Send,
220 <T as Service<DummyType>>::Response: Send,
221 <T as Service<DummyType>>::Client: ClientTrait + Clone + Send + Sync + 'static,
222 {
223 let (internal_sender, mut internal_receiver) = unbounded_channel();
224 let client = T::Client::new(Self {
225 inner: Arc::new(RwLock::new(ClientInner { close_receiver })),
226 sender: sender.clone(),
227 internal_sender,
228 });
229
230 let client_clone = client.clone();
231 tokio::spawn(async move {
232 let mut listeners: HashMap<Uuid, oneshot::Sender<String>> = HashMap::new();
233 loop {
234 tokio::select! {
235 message = receiver.recv() => {
236 match message {
237 Some(message) => {
238 if let Some(reply_to) = message.reply_to {
239 let reply = listeners.remove(&reply_to).unwrap();
240 if reply.send(message.message).is_err() {
241 break;
242 }
243 } else {
244 let service_clone = service.clone();
245 let sender_clone = sender.clone();
246 let client_clone = client_clone.clone();
247 tokio::spawn(async move {
248 let deserialized = serde_json::from_str::<<T as Service<DummyType>>::Request>(&message.message).unwrap();
249 let response = service_clone.handle_request(client_clone, deserialized).await;
250
251 let message = Message {
252 id: Uuid::new_v4(),
253 reply_to: Some(message.id),
254 message: serde_json::to_string(&response).unwrap(),
255 message_type: std::any::type_name::<T>().to_string(),
256 };
257 sender_clone.send(message).unwrap();
258 });
259 }
260 },
261 None => break,
262 }
263 }
264 internal_message = internal_receiver.recv() => {
265 match internal_message {
266 Some(internal_message) => {
267 match internal_message {
268 InternalMessage::RegisterListener(register_listener) => {
269 listeners.insert(register_listener.message_id, register_listener.respond_to);
270 }
271 }
272 },
273 None => break
274 }
275 }
276 }
277 }
278 });
279
280 client
281 }
282
283 pub async fn get_close_receiver(&self) -> Option<oneshot::Receiver<()>> {
284 let mut inner = self.inner.write().await;
285 inner.close_receiver.take()
286 }
287
288 pub async fn wait(&self) {
289 let mut inner = self.inner.write().await;
290 if let Some(receiver) = inner.close_receiver.take() {
291 if let Err(e) = receiver.await {
292 println!("Got an error when waiting for oneshot receiver: {e:?}");
293 }
294 }
295 }
296}
297
298#[derive(Debug, Serialize, Deserialize, Clone)]
299pub struct Message {
300 pub id: Uuid,
301 pub reply_to: Option<Uuid>,
302 pub message: String,
303 pub message_type: String,
304}
305
306pub trait ClientTrait {
307 fn new(client: Client) -> Self;
308}
309
310pub trait Service<DummyType>: Send + Sync {
372 type Response: Send + Serialize;
373 type Request: DeserializeOwned + Send;
374 type Client: ClientTrait + Clone + Send + Sync;
375
376 fn handle_request(
377 &self,
378 client: Self::Client,
379 message: Self::Request,
380 ) -> Pin<Box<dyn Future<Output = Self::Response> + Send + '_>>;
381}
382
383pub async fn process_info_handle(handle: &mut InfoHandle) -> RunInfo {
384 let mut run_info: RunInfo = Default::default();
385 run_info.done = false;
386
387 while let Ok(update) = handle.receiver.try_recv() {
388 match update {
389 InfoMessage::Stderr(buf) => run_info.stderr.push(buf),
390 InfoMessage::Stdout(buf) => run_info.stdout.push(buf),
391 InfoMessage::RequestInfo(info) => run_info.request_stats.push(info),
392 InfoMessage::IterationInfo(info) => run_info.iteration_stats.push(info),
393 InfoMessage::InstanceCheckedOut => run_info.active_instances_delta += 1,
394 InfoMessage::InstanceReserved => run_info.capacity_delta += 1,
395 InfoMessage::InstanceCheckedIn => run_info.active_instances_delta -= 1,
396 InfoMessage::TimingUpdate((elapsed, left)) => {
397 run_info.elapsed = Some(elapsed);
398 run_info.left = Some(left);
399 }
400 InfoMessage::Done => run_info.done = true,
401 }
402 }
403
404 run_info
405}
406
407pub enum InfoMessage {
409 Stderr(Vec<u8>),
410 Stdout(Vec<u8>),
411 RequestInfo(RequestInfo),
412 IterationInfo(IterationInfo),
413 InstanceCheckedOut,
418 InstanceReserved,
419 InstanceCheckedIn,
420 TimingUpdate((Duration, Duration)),
422 Done,
423}
424
425pub struct InfoHandle {
426 pub receiver: UnboundedReceiver<InfoMessage>,
427}