Skip to main content

rio_rs/
service.rs

//! Server services
2
3use futures::future::BoxFuture;
4use futures::sink::SinkExt;
5use futures::{FutureExt, Stream, StreamExt};
6use log::error;
7use std::fmt::Debug;
8use std::panic::AssertUnwindSafe;
9use std::sync::Arc;
10use tracing::{info_span, Instrument};
11
12use tokio::net::TcpStream;
13use tokio::sync::RwLock;
14use tokio_util::codec::{Framed, LengthDelimitedCodec};
15use tower::Service as TowerService;
16
17use crate::app_data::{AppData, AppDataExt};
18use crate::cluster::storage::MembershipStorage;
19use crate::message_router::MessageRouter;
20use crate::object_placement::{ObjectPlacement, ObjectPlacementItem};
21use crate::protocol::pubsub::{SubscriptionRequest, SubscriptionResponse};
22use crate::protocol::{RequestEnvelope, ResponseEnvelope, ResponseError};
23use crate::registry::Registry;
24use crate::{LifecycleMessage, ObjectId};
25
26/// Service to respond to Requests from [crate::client::Client]
27#[derive(Clone, Debug)]
28pub struct Service<S: MembershipStorage, P: ObjectPlacement> {
29    pub(crate) address: String,
30    pub(crate) registry: Arc<RwLock<Registry>>,
31    pub(crate) members_storage: S,
32    pub(crate) object_placement_provider: Arc<RwLock<P>>,
33    pub(crate) app_data: Arc<AppData>,
34}
35
36/// Service implementation to handle [RequestEnvelope] request
37impl<S: MembershipStorage + 'static, P: ObjectPlacement + 'static> TowerService<RequestEnvelope>
38    for Service<S, P>
39{
40    type Response = ResponseEnvelope;
41    type Error = ResponseError;
42    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
43
44    fn poll_ready(
45        &mut self,
46        _cx: &mut std::task::Context<'_>,
47    ) -> std::task::Poll<Result<(), Self::Error>> {
48        std::task::Poll::Ready(Ok(()))
49    }
50
51    /// Call a service locally, or return an error that will
52    /// indicate whether this service is allocated somewhere
53    /// else
54    fn call(&mut self, req: RequestEnvelope) -> Self::Future {
55        let this = self.clone();
56        let result = async move {
57            // Test if this object is in fact allocated in this instance
58            let server_address = this
59                .get_or_create_placement(req.handler_type.clone(), req.handler_id.clone())
60                .await;
61            this.check_address_mismatch(server_address).await?;
62
63            // Ensure the object is started in the registry
64            this.start_service_object(&req.handler_type, &req.handler_id)
65                .await
66                .map_err(|err| {
67                    // Transform some internal error types into better user facing errors
68                    // while retaining other error types
69                    match err {
70                        ResponseError::Unknown(_) => ResponseError::Allocate,
71                        e => e,
72                    }
73                })?;
74
75            // Req + Response to registry
76            let guard = this.registry.read().await;
77            let fut = guard.send(
78                &req.handler_type,
79                &req.handler_id,
80                &req.message_type,
81                &req.payload,
82                this.app_data.clone(),
83            );
84            // TODO review the use of `catch_unwind` and `AssertUnwindSafe`
85            let fut = AssertUnwindSafe(fut);
86            let response = fut.catch_unwind().await;
87
88            // Handle result, 'translating' it to the protocol
89            match response {
90                Ok(Ok(body)) => Ok(ResponseEnvelope::new(body)),
91                Ok(Err(err)) => Err(ResponseError::from(err)),
92                Err(_) => {
93                    // When there is a panic, we will 'remove' the service object
94                    // from both the registry and the ObjectPlacement
95                    this.registry
96                        .read()
97                        .await
98                        .remove(req.handler_type.clone(), req.handler_id.clone())
99                        .await;
100                    this.object_placement_provider
101                        .read()
102                        .await
103                        .remove(&ObjectId(req.handler_type.clone(), req.handler_id.clone()))
104                        .await;
105                    Err(ResponseError::Unknown("Panic".to_string()))
106                }
107            }
108        };
109        Box::pin(result)
110    }
111}
112
113/// This is a iterator to be used on the server to stream
114/// messages back to the client
115#[derive(Debug)]
116pub struct SubscriptionResponseIter {
117    receiver_stream: tokio_stream::wrappers::BroadcastStream<SubscriptionResponse>,
118}
119
120impl SubscriptionResponseIter {
121    pub fn new(channel: tokio::sync::broadcast::Receiver<SubscriptionResponse>) -> Self {
122        let receiver_stream = tokio_stream::wrappers::BroadcastStream::new(channel);
123        Self { receiver_stream }
124    }
125}
126
127impl Stream for SubscriptionResponseIter {
128    type Item = SubscriptionResponse;
129    fn poll_next(
130        self: std::pin::Pin<&mut Self>,
131        _cx: &mut std::task::Context<'_>,
132    ) -> std::task::Poll<Option<Self::Item>> {
133        let this = self.get_mut();
134        this.receiver_stream.poll_next_unpin(_cx).map(|i| {
135            if let Some(result) = i {
136                if result.is_err() {
137                    error!("Error on stream recv {:?}", result);
138                }
139                // TODO error handling
140                // TODO deal with redirect
141                // TODO deal with objects being removed from the current host!
142                result.ok()
143            } else {
144                None
145            }
146        })
147    }
148}
149
150/// Service implementation to handle [SubscriptionRequest] messages
151impl<S, P> TowerService<SubscriptionRequest> for Service<S, P>
152where
153    S: MembershipStorage + 'static,
154    P: ObjectPlacement + 'static,
155{
156    type Response = SubscriptionResponseIter;
157    type Error = ResponseError;
158    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
159
160    fn poll_ready(
161        &mut self,
162        _cx: &mut std::task::Context<'_>,
163    ) -> std::task::Poll<Result<(), Self::Error>> {
164        std::task::Poll::Ready(Ok(()))
165    }
166
167    fn call(&mut self, req: SubscriptionRequest) -> Self::Future {
168        let this = self.clone();
169        let result = async move {
170            let server_address = this
171                .get_or_create_placement(req.handler_type.clone(), req.handler_id.clone())
172                .await;
173
174            this.check_address_mismatch(server_address).await?;
175            // TODO deal with redirect
176            this.start_service_object(&req.handler_type, &req.handler_id)
177                .await
178                .expect("TODO");
179            let receiver = this
180                .app_data
181                .get_or_default::<MessageRouter>()
182                .create_subscription(req.handler_type.clone(), req.handler_id.clone());
183            Ok(SubscriptionResponseIter::new(receiver))
184        };
185        Box::pin(result)
186    }
187}
188
189impl<S: MembershipStorage + 'static, P: ObjectPlacement + 'static> Service<S, P> {
190    /// Returns the ip:port for where this object is placed
191    ///
192    /// If the object is not instantiated anywhere, it will allocate locally
193    #[tracing::instrument]
194    async fn get_or_create_placement(&self, handler_type: String, handler_id: String) -> String {
195        let object_id = ObjectId(handler_type, handler_id);
196        let placement_guard = self.object_placement_provider.read().await;
197        let mut maybe_server_address = placement_guard.lookup(&object_id).await.take();
198        drop(placement_guard);
199
200        // Ensures the placement is on an active server
201        if let Some(server_address) = maybe_server_address.as_ref() {
202            let mut addr_split = server_address.splitn(2, ":");
203            let ip = addr_split.next().unwrap_or_default();
204            let port = addr_split.next().unwrap_or_default();
205
206            // This case should never happen, but writing it here to be handled gracefuly.
207            // It means the placement was stored with bad data, so we remove the record and
208            // take `maybe_server_address` so it picks up a new placement at the end of this
209            // function
210            if ip.is_empty() || port.is_empty() {
211                error!(object_id:? = object_id,
212                       address:%  = server_address,
213                       ip:%  = ip,
214                       port:%  = port;
215                       "The object's placement is in a bad state. This is likely a bug on the object placement code");
216
217                let placement_guard = self.object_placement_provider.read().await;
218                placement_guard.remove(&object_id).await;
219                maybe_server_address.take();
220            }
221            // In case the server in which this object is allocated is inactive/unavailable,
222            // we clean up the server (disassociate all the objects from it), and take
223            // `maybe_server_address` so it picks up a new placement at the end of this function
224            else if !self
225                .members_storage
226                .is_active(ip, port)
227                .await
228                .unwrap_or(false)
229            {
230                let placement_guard = self.object_placement_provider.read().await;
231                placement_guard
232                    .clean_server(server_address.to_string())
233                    .await;
234                maybe_server_address.take();
235            }
236        }
237
238        if let Some(server_address) = maybe_server_address {
239            server_address
240        } else {
241            let new_placement = ObjectPlacementItem::new(object_id, Some(self.address.clone()));
242            {
243                self.object_placement_provider
244                    .write()
245                    .await
246                    .update(new_placement)
247                    .await;
248            };
249            self.address.clone()
250        }
251    }
252
253    /// Checks if the given address is from the local server.
254    /// There are various checks that needs to run.
255    ///
256    /// It returns an Error if it is not
257    #[tracing::instrument]
258    async fn check_address_mismatch(&self, server_address: String) -> Result<(), ResponseError> {
259        if server_address == self.address {
260            return Ok(());
261        }
262
263        let mut split_address = server_address.split(':');
264        let ip = split_address.next().ok_or_else(|| {
265            ResponseError::Unknown(format!(
266                "Malformed address: Missing IP in '{}'",
267                server_address
268            ))
269        })?;
270        let port = split_address.next().ok_or_else(|| {
271            ResponseError::Unknown(format!(
272                "Malformed address: Missing PORT in '{}'",
273                server_address
274            ))
275        })?;
276
277        let is_active = self
278            .members_storage
279            .is_active(ip, port)
280            .await
281            .map_err(|e| ResponseError::Unknown(e.to_string()))?;
282
283        // This object is active somewhere else
284        if is_active {
285            return Err(ResponseError::Redirect(server_address));
286        }
287
288        // This object is not allocated here, and it is not active either
289        self.object_placement_provider
290            .read()
291            .await
292            .clean_server(server_address)
293            .await;
294        Err(ResponseError::DeallocateServiceObject)
295    }
296
297    /// Startup a service object and insert it into registry
298    ///
299    /// If is already running, ignore it
300    #[tracing::instrument]
301    async fn start_service_object(
302        &self,
303        handler_type: &str,
304        handler_id: &str,
305    ) -> Result<(), ResponseError> {
306        // Allocate holding the same read lock as the test, it ensures there is no ongoing write
307        {
308            let registry_guard = self.registry.read().await;
309            if registry_guard.has(handler_type, handler_id).await {
310                return Ok(());
311            }
312
313            let new_object = registry_guard
314                .new_from_type(handler_type, handler_id.to_string())
315                .ok_or(ResponseError::NotSupported(handler_type.to_string()))?;
316
317            registry_guard
318                .insert_boxed_object(handler_type.to_string(), handler_id.to_string(), new_object)
319                .await;
320        };
321
322        let lifecycle_result = {
323            let object_guard = self.registry.read().await;
324            let lifecycle_msg = LifecycleMessage::Load;
325            let lifecycle_ser_msg = bincode::serialize(&lifecycle_msg).expect("TODO");
326            let lifecycle_fut = object_guard.send(
327                handler_type,
328                handler_id,
329                "LifecycleMessage",
330                &lifecycle_ser_msg,
331                self.app_data.clone(),
332            );
333
334            // Catch panics on LifecycleMessage::Load
335            let lifecycle_fut = AssertUnwindSafe(lifecycle_fut);
336            lifecycle_fut.catch_unwind().await
337        };
338
339        // TODO remove duplicated logic (Self::send)
340        if let Err(e) = lifecycle_result {
341            self.registry
342                .read()
343                .await
344                .remove(handler_type.to_string(), handler_id.to_string())
345                .await;
346            self.object_placement_provider
347                .read()
348                .await
349                .remove(&ObjectId(handler_type.to_string(), handler_id.to_string()))
350                .await;
351
352            return Err(ResponseError::Unknown(format!("Task panicked: {:?}", e)));
353        }
354        Ok(())
355    }
356
357    // TODO tune LenghtDelimitedCodec
358    // TODO move this into a transport struct
359    //
360    /// Main service loop
361    ///
362    /// Consumes a stream of frames, each containing a command sent from clients.
363    ///
364    /// The commands might be either a request/response request or a subscription request
365    #[tracing::instrument]
366    pub async fn run(&mut self, stream: TcpStream) {
367        let codec = LengthDelimitedCodec::new();
368        let mut frames = Framed::new(stream, codec);
369
370        while let Some(Ok(frame)) = StreamExt::next(&mut frames)
371            .instrument(info_span!("frame_receive"))
372            .await
373        {
374            let request: Result<RequestEnvelope, _> = bincode::deserialize(&frame);
375            let subscription: Result<SubscriptionRequest, _> = bincode::deserialize(&frame);
376
377            let either_request = match (request, subscription) {
378                (Ok(message), _) => AllRequest::ReqResp(message),
379                (_, Ok(message)) => AllRequest::PubSub(message),
380                _ => {
381                    unreachable!("Got both or neither requests")
382                }
383            };
384            match either_request {
385                AllRequest::ReqResp(message) => {
386                    let response = match self.call(message).await {
387                        Ok(x) => x,
388                        Err(err) => ResponseEnvelope::err(err),
389                    };
390                    let ser_result = bincode::serialize(&response);
391                    let ser_response = match ser_result {
392                        Ok(value) => value,
393                        Err(err) => {
394                            let new_return = ResponseEnvelope::err(
395                                ResponseError::SeralizationError(err.to_string()),
396                            );
397                            bincode::serialize(&new_return)
398                                .expect("Serialization of response error should be infalible")
399                        }
400                    };
401                    frames
402                        .send(ser_response.into())
403                        .instrument(info_span!("response_send"))
404                        .await
405                        .unwrap();
406                }
407                AllRequest::PubSub(message) => {
408                    let stream = self.call(message).await;
409
410                    // If there is an upstream error to establish the subscription,
411                    // wrapi it in a SubscriptionResponse and return earlier
412                    let mut stream = match stream {
413                        Ok(value) => value,
414                        Err(err) => {
415                            let sub_response = SubscriptionResponse::err(err);
416                            let ser_response = bincode::serialize(&sub_response)
417                                .expect("Error serialization should be infalible");
418                            frames
419                                .send(ser_response.into())
420                                .instrument(info_span!("response_send"))
421                                .await
422                                .ok();
423                            return;
424                        }
425                    };
426
427                    while let Some(value) = StreamExt::next(&mut stream).await {
428                        let ser_result = bincode::serialize(&value);
429                        let ser_response = match ser_result {
430                            Ok(value) => value,
431                            Err(err) => {
432                                let new_return = SubscriptionResponse::err(
433                                    ResponseError::SeralizationError(err.to_string()),
434                                );
435                                bincode::serialize(&new_return)
436                                    .expect("Serialization of response error should be infalible")
437                            }
438                        };
439
440                        let send_result = frames
441                            .send(ser_response.into())
442                            .instrument(info_span!("response_send"))
443                            .await;
444
445                        // Stop receiving messages if the sink we redirect messages to is
446                        // closed
447                        if let Err(err) = send_result {
448                            error!("Channel is closed due {}", err);
449                            break;
450                        }
451                    }
452                }
453            }
454        }
455    }
456}
457
458#[derive(Debug)]
459enum AllRequest {
460    ReqResp(RequestEnvelope),
461    PubSub(SubscriptionRequest),
462}
463
#[cfg(test)]
mod test {
    use std::time::Duration;

    use async_trait::async_trait;
    use rio_macros::{Message, TypeName, WithId};
    use serde::{Deserialize, Serialize};
    use tokio::time::timeout;
    use tower::ServiceExt;

    use super::*;
    use crate::cluster::storage::local::LocalStorage;
    use crate::object_placement::local::LocalObjectPlacement;

    use crate::registry::Handler;

    /// Minimal service object registered in the test registry
    #[derive(Default, WithId, TypeName)]
    #[rio_path = "crate"]
    struct MockService {
        id: String,
    }

    /// Message accepted by [MockService]
    #[derive(Default, Debug, Message, TypeName, Serialize, Deserialize)]
    #[rio_path = "crate"]
    struct MockMessage {
        text: String,
    }

    /// Reply produced by [MockService]
    #[derive(Default, Debug, Message, TypeName, Serialize, Deserialize)]
    #[rio_path = "crate"]
    struct MockResponse {
        text: String,
    }

    #[async_trait]
    impl Handler<MockMessage> for MockService {
        type Returns = MockResponse;
        type Error = ();

        /// Replies with the object's id followed by the received text
        async fn handle(
            &mut self,
            message: MockMessage,
            _: Arc<AppData>,
        ) -> Result<Self::Returns, Self::Error> {
            let text = format!("{} received {}", self.id, message.text);
            Ok(MockResponse { text })
        }
    }

    /// Builds a [Service] backed by local storage and local object placement,
    /// with [MockService] and its [MockMessage] handler registered
    fn svc() -> Service<LocalStorage, LocalObjectPlacement> {
        let mut registry = Registry::new();
        registry.add_type::<MockService>();
        registry.add_handler::<MockService, MockMessage>();

        let registry = Arc::new(RwLock::new(registry));
        let object_placement_provider = Arc::new(RwLock::new(LocalObjectPlacement::default()));

        Service {
            address: "0.0.0.0:5000".to_string(),
            registry,
            members_storage: LocalStorage::default(),
            object_placement_provider,
            app_data: Arc::new(AppData::new()),
        }
    }

    #[tokio::test]
    async fn test_poll_ready() {
        let mut service = svc();
        ServiceExt::<RequestEnvelope>::ready(&mut service)
            .await
            .expect("service ready");
    }

    #[tokio::test]
    async fn test_service_call() {
        let mut service = svc();
        ServiceExt::<RequestEnvelope>::ready(&mut service)
            .await
            .unwrap();

        let payload = bincode::serialize(&MockMessage { text: "hi".into() }).unwrap();
        let request = RequestEnvelope::new(
            "MockService".into(),
            "*".into(),
            "MockMessage".into(),
            payload,
        );

        let envelope = service.call(request).await.unwrap();
        let decoded: MockResponse = bincode::deserialize(&envelope.body.unwrap()).unwrap();
        assert_eq!(decoded.text, "* received hi".to_string());
    }

    #[tokio::test]
    async fn test_service_subscription() {
        let mut service = svc();
        ServiceExt::<SubscriptionRequest>::ready(&mut service)
            .await
            .unwrap();

        let request = SubscriptionRequest {
            handler_type: "MockService".into(),
            handler_id: "*".into(),
        };

        // Establishing the subscription must not hang
        let pending_call = timeout(Duration::from_secs(3), service.call(request));
        let _stream = pending_call.await.unwrap();
        // TODO assert_eq!(..., stream.next().await);
    }
}