crows_utils/
lib.rs

1use std::collections::HashMap;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::time::Duration;
5
6use serde::de::DeserializeOwned;
7use serde::{Deserialize, Serialize};
8
9use futures::prelude::*;
10use futures::TryStreamExt;
11use services::{RunInfo, RequestInfo, IterationInfo};
12use std::future::Future;
13use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
14use tokio::sync::RwLock;
15use tokio::sync::{
16    mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
17    oneshot,
18};
19use tokio_serde::formats::SymmetricalJson;
20use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
21use uuid::Uuid;
22
23pub use serde;
24pub use tokio;
25pub mod services;
26pub use crows_service;
27
28pub struct Server {
29    listener: TcpListener,
30}
31
32impl Server {
33    pub async fn accept(
34        &self,
35    ) -> Option<(
36        UnboundedSender<Message>,
37        UnboundedReceiver<Message>,
38        oneshot::Receiver<()>,
39    )> {
40        let (socket, _) = self.listener.accept().await.ok()?;
41        let (reader, writer) = socket.into_split();
42
43        // Delimit frames using a length header
44        let length_delimited = FramedRead::new(reader, LengthDelimitedCodec::new());
45
46        // Deserialize frames
47        let mut deserialized =
48            tokio_serde::SymmetricallyFramed::new(length_delimited, SymmetricalJson::default());
49
50        let length_delimited = FramedWrite::new(writer, LengthDelimitedCodec::new());
51
52        let mut serialized = tokio_serde::SymmetricallyFramed::new(
53            length_delimited,
54            SymmetricalJson::<Message>::default(),
55        );
56
57        let (serialized_sender, mut serialized_receiver) = unbounded_channel::<Message>();
58        let (deserialized_sender, deserialized_receiver) = unbounded_channel::<Message>();
59        let (close_sender, close_receiver) = oneshot::channel::<()>();
60
61        tokio::spawn(async move {
62            while let Some(message) = serialized_receiver.recv().await {
63                if let Err(err) = serialized.send(message).await {
64                    println!("Error while sending message: {err:?}");
65                    break;
66                }
67            }
68        });
69
70        tokio::spawn(async move {
71            // TODO: handle Err
72            while let Ok(Some(message)) = deserialized.try_next().await {
73                if let Err(err) = deserialized_sender.send(message) {
74                    println!("Error while sending message: {err:?}");
75                    break;
76                }
77            }
78
79            if let Err(e) = close_sender.send(()) {
80                println!("Got an error when sending to a close_sender: {e:?}");
81            }
82        });
83
84        Some((serialized_sender, deserialized_receiver, close_receiver))
85    }
86}
87
88pub async fn create_server<A>(addr: A) -> Result<Server, std::io::Error>
89where
90    A: ToSocketAddrs,
91{
92    // Bind a server socket
93    let listener = TcpListener::bind(addr).await?;
94
95    // println!("listening on {:?}", listener.local_addr());
96
97    Ok(Server { listener })
98}
99
100pub async fn create_client<A>(
101    addr: A,
102) -> Result<(UnboundedSender<Message>, UnboundedReceiver<Message>), std::io::Error>
103where
104    A: ToSocketAddrs,
105{
106    // Bind a server socket
107    let socket = TcpStream::connect(addr).await?;
108
109    let (reader, writer) = socket.into_split();
110
111    // Delimit frames using a length header
112    let length_delimited = FramedWrite::new(writer, LengthDelimitedCodec::new());
113
114    // Serialize frames with JSON
115    let mut serialized =
116        tokio_serde::SymmetricallyFramed::new(length_delimited, SymmetricalJson::default());
117
118    // Delimit frames using a length header
119    let length_delimited = FramedRead::new(reader, LengthDelimitedCodec::new());
120
121    // Deserialize frames
122    let mut deserialized =
123        tokio_serde::SymmetricallyFramed::new(length_delimited, SymmetricalJson::default());
124
125    let (serialized_sender, mut serialized_receiver) = unbounded_channel::<Message>();
126    let (deserialized_sender, deserialized_receiver) = unbounded_channel::<Message>();
127
128    tokio::spawn(async move {
129        while let Some(message) = serialized_receiver.recv().await {
130            if let Err(err) = serialized.send(message).await {
131                println!("Error while sending message: {err:?}");
132                break;
133            }
134        }
135    });
136
137    tokio::spawn(async move {
138        // TODO: handle Err
139        while let Ok(Some(message)) = deserialized.try_next().await {
140            if let Err(err) = deserialized_sender.send(message) {
141                println!("Error while sending message: {err:?}");
142                break;
143            }
144        }
145    });
146
147    Ok((serialized_sender, deserialized_receiver))
148}
149
150#[derive(Debug)]
151struct RegisterListener {
152    respond_to: oneshot::Sender<String>,
153    message_id: Uuid,
154}
155
156#[derive(Debug)]
157enum InternalMessage {
158    RegisterListener(RegisterListener),
159}
160
161#[derive(Clone)]
162pub struct Client {
163    inner: Arc<RwLock<ClientInner>>,
164    sender: UnboundedSender<Message>,
165    internal_sender: UnboundedSender<InternalMessage>,
166}
167
168struct ClientInner {
169    close_receiver: Option<oneshot::Receiver<()>>,
170}
171
172impl Client {
173    pub async fn request<
174        T: Serialize + std::fmt::Debug + DeserializeOwned + Send + 'static,
175        Y: Serialize + std::fmt::Debug + DeserializeOwned + Send + 'static,
176    >(
177        &self,
178        message: T,
179    ) -> anyhow::Result<Y> {
180        let message = Message {
181            id: Uuid::new_v4(),
182            reply_to: None,
183            message: serde_json::to_string(&message)?,
184            message_type: std::any::type_name::<T>().to_string(),
185        };
186
187        let (tx, rx) = oneshot::channel::<String>();
188        let register_listener = RegisterListener {
189            respond_to: tx,
190            message_id: message.id,
191        };
192        self.send_internal(InternalMessage::RegisterListener(register_listener))
193            .await?;
194        self.send(message).await?;
195
196        // TODO: rewrite to map
197        match rx.await {
198            Ok(reply) => Ok(serde_json::from_str(&reply)?),
199            Err(e) => Err(e)?,
200        }
201    }
202
203    async fn send(&self, message: Message) -> anyhow::Result<()> {
204        Ok(self.sender.send(message)?)
205    }
206
207    async fn send_internal(&self, message: InternalMessage) -> anyhow::Result<()> {
208        Ok(self.internal_sender.send(message)?)
209    }
210
211    pub fn new<T, DummyType>(
212        sender: UnboundedSender<Message>,
213        mut receiver: UnboundedReceiver<Message>,
214        mut service: T,
215        close_receiver: Option<oneshot::Receiver<()>>,
216    ) -> <T as Service<DummyType>>::Client
217    where
218        T: Service<DummyType> + Send + Sync + 'static + Clone,
219        <T as Service<DummyType>>::Request: Send,
220        <T as Service<DummyType>>::Response: Send,
221        <T as Service<DummyType>>::Client: ClientTrait + Clone + Send + Sync + 'static,
222    {
223        let (internal_sender, mut internal_receiver) = unbounded_channel();
224        let client = T::Client::new(Self {
225            inner: Arc::new(RwLock::new(ClientInner { close_receiver })),
226            sender: sender.clone(),
227            internal_sender,
228        });
229
230        let client_clone = client.clone();
231        tokio::spawn(async move {
232            let mut listeners: HashMap<Uuid, oneshot::Sender<String>> = HashMap::new();
233            loop {
234                tokio::select! {
235                    message = receiver.recv() => {
236                        match message {
237                            Some(message) => {
238                                if let Some(reply_to) = message.reply_to {
239                                    let reply = listeners.remove(&reply_to).unwrap();
240                                    if reply.send(message.message).is_err() {
241                                        break;
242                                    }
243                                } else {
244                                    let service_clone = service.clone();
245                                    let sender_clone = sender.clone();
246                                    let client_clone = client_clone.clone();
247                                    tokio::spawn(async move {
248                                        let deserialized = serde_json::from_str::<<T as Service<DummyType>>::Request>(&message.message).unwrap();
249                                        let response = service_clone.handle_request(client_clone, deserialized).await;
250
251                                        let message = Message {
252                                            id: Uuid::new_v4(),
253                                            reply_to: Some(message.id),
254                                            message: serde_json::to_string(&response).unwrap(),
255                                            message_type: std::any::type_name::<T>().to_string(),
256                                        };
257                                        sender_clone.send(message).unwrap();
258                                    });
259                                }
260                            },
261                            None => break,
262                        }
263                    }
264                    internal_message = internal_receiver.recv() => {
265                        match internal_message {
266                            Some(internal_message) => {
267                                match internal_message {
268                                    InternalMessage::RegisterListener(register_listener) => {
269                                        listeners.insert(register_listener.message_id, register_listener.respond_to);
270                                    }
271                                }
272                            },
273                            None => break
274                        }
275                    }
276                }
277            }
278        });
279
280        client
281    }
282
283    pub async fn get_close_receiver(&self) -> Option<oneshot::Receiver<()>> {
284        let mut inner = self.inner.write().await;
285        inner.close_receiver.take()
286    }
287
288    pub async fn wait(&self) {
289        let mut inner = self.inner.write().await;
290        if let Some(receiver) = inner.close_receiver.take() {
291            if let Err(e) = receiver.await {
292                println!("Got an error when waiting for oneshot receiver: {e:?}");
293            }
294        }
295    }
296}
297
298#[derive(Debug, Serialize, Deserialize, Clone)]
299pub struct Message {
300    pub id: Uuid,
301    pub reply_to: Option<Uuid>,
302    pub message: String,
303    pub message_type: String,
304}
305
306pub trait ClientTrait {
307    fn new(client: Client) -> Self;
308}
309
/// Server-side handler for one RPC interface.
///
/// The `DummyType` parameter is needed because the `service` macro implements
/// `Service` on a generic type. That, in turn, is needed so that a service
/// implementation can be defined separately from (and after) its interface.
/// For example, we define a Worker RPC service in a shared crate/file. There we
/// only want to define the interface, but for the service to work properly the
/// `Service` trait also has to be implemented. It's best to do that in the macro
/// itself, because it requires a lot of boilerplate. However, when the macro
/// runs, the actual service type isn't defined yet.
///
/// If we could define all of it in one file, it would look something like:
///
/// ```ignore
/// trait Worker {
///     async fn ping(&self) -> String;
/// }
///
/// struct WorkerService {}
///
/// impl Worker for WorkerService {
///     async fn ping(&self) -> String { todo!() }
/// }
///
/// impl Service for WorkerService {
///     type Request = WorkerRequest;
///     type Response = WorkerResponse;
///
///     fn handle_request(...) { .... }
/// }
/// ```
///
/// The problem is that we don't want to require the service implementation to
/// live in the same place as the definition. That's why it's better to
/// implement `Service` only for a generic type, so that it applies once the
/// concrete type is actually created, for example:
///
/// ```ignore
/// impl<T> Service for T
/// where T: Worker + Send + Sync { }
/// ```
///
/// The issue is that this results in a "conflicting implementations" error if
/// more than one such `impl` is present. The reason is future-proofing.
/// For example, consider the previous impl and another one for another service:
///
/// ```ignore
/// impl<T> Service for T
/// where T: Coordinator + Send + Sync { }
/// ```
///
/// While we know that we never want to implement both `Coordinator` and `Worker`
/// on the same type, Rust doesn't. The solution is to add a "dummy type" to the
/// service implementation, narrowing each impl to a specific generic argument,
/// for example:
///
/// ```ignore
/// struct DummyWorkerService {}
///
/// impl<T> Service<DummyWorkerService> for T
/// where T: Worker + Send + Sync { }
/// ```
///
/// Now each impl is only considered for a specific `Service` type. The only
/// additional requirement is that the dummy type must be included when
/// specifying the service. For example, to accept the Worker service as an
/// argument we write:
///
/// ```ignore
/// fn foo<T>(service: T)
///     where T: Service<DummyWorkerService> { }
/// ```
pub trait Service<DummyType>: Send + Sync {
    /// Value returned by `handle_request`; `Client::new` serializes it to JSON
    /// and sends it back as the reply.
    type Response: Send + Serialize;
    /// Incoming request payload, deserialized from a message's JSON body.
    type Request: DeserializeOwned + Send;
    /// Typed client handle passed to `handle_request`, so a handler can issue
    /// requests back over the same connection.
    type Client: ClientTrait + Clone + Send + Sync;

    /// Handles one decoded request. Returns a boxed future that may borrow
    /// `self` and resolves to the response.
    fn handle_request(
        &self,
        client: Self::Client,
        message: Self::Request,
    ) -> Pin<Box<dyn Future<Output = Self::Response> + Send + '_>>;
}
382
383pub async fn process_info_handle(handle: &mut InfoHandle) -> RunInfo {
384    let mut run_info: RunInfo = Default::default();
385    run_info.done = false;
386
387    while let Ok(update) = handle.receiver.try_recv() {
388        match update {
389            InfoMessage::Stderr(buf) => run_info.stderr.push(buf),
390            InfoMessage::Stdout(buf) => run_info.stdout.push(buf),
391            InfoMessage::RequestInfo(info) => run_info.request_stats.push(info),
392            InfoMessage::IterationInfo(info) => run_info.iteration_stats.push(info),
393            InfoMessage::InstanceCheckedOut => run_info.active_instances_delta += 1,
394            InfoMessage::InstanceReserved => run_info.capacity_delta += 1,
395            InfoMessage::InstanceCheckedIn => run_info.active_instances_delta -= 1,
396            InfoMessage::TimingUpdate((elapsed, left)) => {
397                run_info.elapsed = Some(elapsed);
398                run_info.left = Some(left);
399            }
400            InfoMessage::Done => run_info.done = true,
401        }
402    }
403
404    run_info
405}
406
407// TODO: I don't like that name, I think it should be changed
408pub enum InfoMessage {
409    Stderr(Vec<u8>),
410    Stdout(Vec<u8>),
411    RequestInfo(RequestInfo),
412    IterationInfo(IterationInfo),
413    // TODO: I'm not sure if shoving any kind of update here is a good idea,
414    // but at the moment it's the easiest way to pass data back to the client,
415    // so I'm going with it. I'd like to revisit it in the future, though and
416    // consider alternatives
417    InstanceCheckedOut,
418    InstanceReserved,
419    InstanceCheckedIn,
420    // elapsed, left
421    TimingUpdate((Duration, Duration)),
422    Done,
423}
424
425pub struct InfoHandle {
426    pub receiver: UnboundedReceiver<InfoMessage>,
427}