1use futures::future::BoxFuture;
4use futures::sink::SinkExt;
5use futures::{FutureExt, Stream, StreamExt};
6use log::error;
7use std::fmt::Debug;
8use std::panic::AssertUnwindSafe;
9use std::sync::Arc;
10use tracing::{info_span, Instrument};
11
12use tokio::net::TcpStream;
13use tokio::sync::RwLock;
14use tokio_util::codec::{Framed, LengthDelimitedCodec};
15use tower::Service as TowerService;
16
17use crate::app_data::{AppData, AppDataExt};
18use crate::cluster::storage::MembershipStorage;
19use crate::message_router::MessageRouter;
20use crate::object_placement::{ObjectPlacement, ObjectPlacementItem};
21use crate::protocol::pubsub::{SubscriptionRequest, SubscriptionResponse};
22use crate::protocol::{RequestEnvelope, ResponseEnvelope, ResponseError};
23use crate::registry::Registry;
24use crate::{LifecycleMessage, ObjectId};
25
26#[derive(Clone, Debug)]
28pub struct Service<S: MembershipStorage, P: ObjectPlacement> {
29 pub(crate) address: String,
30 pub(crate) registry: Arc<RwLock<Registry>>,
31 pub(crate) members_storage: S,
32 pub(crate) object_placement_provider: Arc<RwLock<P>>,
33 pub(crate) app_data: Arc<AppData>,
34}
35
36impl<S: MembershipStorage + 'static, P: ObjectPlacement + 'static> TowerService<RequestEnvelope>
38 for Service<S, P>
39{
40 type Response = ResponseEnvelope;
41 type Error = ResponseError;
42 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
43
44 fn poll_ready(
45 &mut self,
46 _cx: &mut std::task::Context<'_>,
47 ) -> std::task::Poll<Result<(), Self::Error>> {
48 std::task::Poll::Ready(Ok(()))
49 }
50
51 fn call(&mut self, req: RequestEnvelope) -> Self::Future {
55 let this = self.clone();
56 let result = async move {
57 let server_address = this
59 .get_or_create_placement(req.handler_type.clone(), req.handler_id.clone())
60 .await;
61 this.check_address_mismatch(server_address).await?;
62
63 this.start_service_object(&req.handler_type, &req.handler_id)
65 .await
66 .map_err(|err| {
67 match err {
70 ResponseError::Unknown(_) => ResponseError::Allocate,
71 e => e,
72 }
73 })?;
74
75 let guard = this.registry.read().await;
77 let fut = guard.send(
78 &req.handler_type,
79 &req.handler_id,
80 &req.message_type,
81 &req.payload,
82 this.app_data.clone(),
83 );
84 let fut = AssertUnwindSafe(fut);
86 let response = fut.catch_unwind().await;
87
88 match response {
90 Ok(Ok(body)) => Ok(ResponseEnvelope::new(body)),
91 Ok(Err(err)) => Err(ResponseError::from(err)),
92 Err(_) => {
93 this.registry
96 .read()
97 .await
98 .remove(req.handler_type.clone(), req.handler_id.clone())
99 .await;
100 this.object_placement_provider
101 .read()
102 .await
103 .remove(&ObjectId(req.handler_type.clone(), req.handler_id.clone()))
104 .await;
105 Err(ResponseError::Unknown("Panic".to_string()))
106 }
107 }
108 };
109 Box::pin(result)
110 }
111}
112
113#[derive(Debug)]
116pub struct SubscriptionResponseIter {
117 receiver_stream: tokio_stream::wrappers::BroadcastStream<SubscriptionResponse>,
118}
119
120impl SubscriptionResponseIter {
121 pub fn new(channel: tokio::sync::broadcast::Receiver<SubscriptionResponse>) -> Self {
122 let receiver_stream = tokio_stream::wrappers::BroadcastStream::new(channel);
123 Self { receiver_stream }
124 }
125}
126
127impl Stream for SubscriptionResponseIter {
128 type Item = SubscriptionResponse;
129 fn poll_next(
130 self: std::pin::Pin<&mut Self>,
131 _cx: &mut std::task::Context<'_>,
132 ) -> std::task::Poll<Option<Self::Item>> {
133 let this = self.get_mut();
134 this.receiver_stream.poll_next_unpin(_cx).map(|i| {
135 if let Some(result) = i {
136 if result.is_err() {
137 error!("Error on stream recv {:?}", result);
138 }
139 result.ok()
143 } else {
144 None
145 }
146 })
147 }
148}
149
150impl<S, P> TowerService<SubscriptionRequest> for Service<S, P>
152where
153 S: MembershipStorage + 'static,
154 P: ObjectPlacement + 'static,
155{
156 type Response = SubscriptionResponseIter;
157 type Error = ResponseError;
158 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
159
160 fn poll_ready(
161 &mut self,
162 _cx: &mut std::task::Context<'_>,
163 ) -> std::task::Poll<Result<(), Self::Error>> {
164 std::task::Poll::Ready(Ok(()))
165 }
166
167 fn call(&mut self, req: SubscriptionRequest) -> Self::Future {
168 let this = self.clone();
169 let result = async move {
170 let server_address = this
171 .get_or_create_placement(req.handler_type.clone(), req.handler_id.clone())
172 .await;
173
174 this.check_address_mismatch(server_address).await?;
175 this.start_service_object(&req.handler_type, &req.handler_id)
177 .await
178 .expect("TODO");
179 let receiver = this
180 .app_data
181 .get_or_default::<MessageRouter>()
182 .create_subscription(req.handler_type.clone(), req.handler_id.clone());
183 Ok(SubscriptionResponseIter::new(receiver))
184 };
185 Box::pin(result)
186 }
187}
188
189impl<S: MembershipStorage + 'static, P: ObjectPlacement + 'static> Service<S, P> {
190 #[tracing::instrument]
194 async fn get_or_create_placement(&self, handler_type: String, handler_id: String) -> String {
195 let object_id = ObjectId(handler_type, handler_id);
196 let placement_guard = self.object_placement_provider.read().await;
197 let mut maybe_server_address = placement_guard.lookup(&object_id).await.take();
198 drop(placement_guard);
199
200 if let Some(server_address) = maybe_server_address.as_ref() {
202 let mut addr_split = server_address.splitn(2, ":");
203 let ip = addr_split.next().unwrap_or_default();
204 let port = addr_split.next().unwrap_or_default();
205
206 if ip.is_empty() || port.is_empty() {
211 error!(object_id:? = object_id,
212 address:% = server_address,
213 ip:% = ip,
214 port:% = port;
215 "The object's placement is in a bad state. This is likely a bug on the object placement code");
216
217 let placement_guard = self.object_placement_provider.read().await;
218 placement_guard.remove(&object_id).await;
219 maybe_server_address.take();
220 }
221 else if !self
225 .members_storage
226 .is_active(ip, port)
227 .await
228 .unwrap_or(false)
229 {
230 let placement_guard = self.object_placement_provider.read().await;
231 placement_guard
232 .clean_server(server_address.to_string())
233 .await;
234 maybe_server_address.take();
235 }
236 }
237
238 if let Some(server_address) = maybe_server_address {
239 server_address
240 } else {
241 let new_placement = ObjectPlacementItem::new(object_id, Some(self.address.clone()));
242 {
243 self.object_placement_provider
244 .write()
245 .await
246 .update(new_placement)
247 .await;
248 };
249 self.address.clone()
250 }
251 }
252
253 #[tracing::instrument]
258 async fn check_address_mismatch(&self, server_address: String) -> Result<(), ResponseError> {
259 if server_address == self.address {
260 return Ok(());
261 }
262
263 let mut split_address = server_address.split(':');
264 let ip = split_address.next().ok_or_else(|| {
265 ResponseError::Unknown(format!(
266 "Malformed address: Missing IP in '{}'",
267 server_address
268 ))
269 })?;
270 let port = split_address.next().ok_or_else(|| {
271 ResponseError::Unknown(format!(
272 "Malformed address: Missing PORT in '{}'",
273 server_address
274 ))
275 })?;
276
277 let is_active = self
278 .members_storage
279 .is_active(ip, port)
280 .await
281 .map_err(|e| ResponseError::Unknown(e.to_string()))?;
282
283 if is_active {
285 return Err(ResponseError::Redirect(server_address));
286 }
287
288 self.object_placement_provider
290 .read()
291 .await
292 .clean_server(server_address)
293 .await;
294 Err(ResponseError::DeallocateServiceObject)
295 }
296
297 #[tracing::instrument]
301 async fn start_service_object(
302 &self,
303 handler_type: &str,
304 handler_id: &str,
305 ) -> Result<(), ResponseError> {
306 {
308 let registry_guard = self.registry.read().await;
309 if registry_guard.has(handler_type, handler_id).await {
310 return Ok(());
311 }
312
313 let new_object = registry_guard
314 .new_from_type(handler_type, handler_id.to_string())
315 .ok_or(ResponseError::NotSupported(handler_type.to_string()))?;
316
317 registry_guard
318 .insert_boxed_object(handler_type.to_string(), handler_id.to_string(), new_object)
319 .await;
320 };
321
322 let lifecycle_result = {
323 let object_guard = self.registry.read().await;
324 let lifecycle_msg = LifecycleMessage::Load;
325 let lifecycle_ser_msg = bincode::serialize(&lifecycle_msg).expect("TODO");
326 let lifecycle_fut = object_guard.send(
327 handler_type,
328 handler_id,
329 "LifecycleMessage",
330 &lifecycle_ser_msg,
331 self.app_data.clone(),
332 );
333
334 let lifecycle_fut = AssertUnwindSafe(lifecycle_fut);
336 lifecycle_fut.catch_unwind().await
337 };
338
339 if let Err(e) = lifecycle_result {
341 self.registry
342 .read()
343 .await
344 .remove(handler_type.to_string(), handler_id.to_string())
345 .await;
346 self.object_placement_provider
347 .read()
348 .await
349 .remove(&ObjectId(handler_type.to_string(), handler_id.to_string()))
350 .await;
351
352 return Err(ResponseError::Unknown(format!("Task panicked: {:?}", e)));
353 }
354 Ok(())
355 }
356
357 #[tracing::instrument]
366 pub async fn run(&mut self, stream: TcpStream) {
367 let codec = LengthDelimitedCodec::new();
368 let mut frames = Framed::new(stream, codec);
369
370 while let Some(Ok(frame)) = StreamExt::next(&mut frames)
371 .instrument(info_span!("frame_receive"))
372 .await
373 {
374 let request: Result<RequestEnvelope, _> = bincode::deserialize(&frame);
375 let subscription: Result<SubscriptionRequest, _> = bincode::deserialize(&frame);
376
377 let either_request = match (request, subscription) {
378 (Ok(message), _) => AllRequest::ReqResp(message),
379 (_, Ok(message)) => AllRequest::PubSub(message),
380 _ => {
381 unreachable!("Got both or neither requests")
382 }
383 };
384 match either_request {
385 AllRequest::ReqResp(message) => {
386 let response = match self.call(message).await {
387 Ok(x) => x,
388 Err(err) => ResponseEnvelope::err(err),
389 };
390 let ser_result = bincode::serialize(&response);
391 let ser_response = match ser_result {
392 Ok(value) => value,
393 Err(err) => {
394 let new_return = ResponseEnvelope::err(
395 ResponseError::SeralizationError(err.to_string()),
396 );
397 bincode::serialize(&new_return)
398 .expect("Serialization of response error should be infalible")
399 }
400 };
401 frames
402 .send(ser_response.into())
403 .instrument(info_span!("response_send"))
404 .await
405 .unwrap();
406 }
407 AllRequest::PubSub(message) => {
408 let stream = self.call(message).await;
409
410 let mut stream = match stream {
413 Ok(value) => value,
414 Err(err) => {
415 let sub_response = SubscriptionResponse::err(err);
416 let ser_response = bincode::serialize(&sub_response)
417 .expect("Error serialization should be infalible");
418 frames
419 .send(ser_response.into())
420 .instrument(info_span!("response_send"))
421 .await
422 .ok();
423 return;
424 }
425 };
426
427 while let Some(value) = StreamExt::next(&mut stream).await {
428 let ser_result = bincode::serialize(&value);
429 let ser_response = match ser_result {
430 Ok(value) => value,
431 Err(err) => {
432 let new_return = SubscriptionResponse::err(
433 ResponseError::SeralizationError(err.to_string()),
434 );
435 bincode::serialize(&new_return)
436 .expect("Serialization of response error should be infalible")
437 }
438 };
439
440 let send_result = frames
441 .send(ser_response.into())
442 .instrument(info_span!("response_send"))
443 .await;
444
445 if let Err(err) = send_result {
448 error!("Channel is closed due {}", err);
449 break;
450 }
451 }
452 }
453 }
454 }
455 }
456}
457
458#[derive(Debug)]
459enum AllRequest {
460 ReqResp(RequestEnvelope),
461 PubSub(SubscriptionRequest),
462}
463
#[cfg(test)]
mod test {
    use std::time::Duration;

    use async_trait::async_trait;
    use rio_macros::{Message, TypeName, WithId};
    use serde::{Deserialize, Serialize};
    use tokio::time::timeout;
    use tower::ServiceExt;

    use super::*;
    use crate::cluster::storage::local::LocalStorage;
    use crate::object_placement::local::LocalObjectPlacement;

    use crate::registry::Handler;

    /// Minimal service object used as the target of the test requests.
    #[derive(Default, WithId, TypeName)]
    #[rio_path = "crate"]
    struct MockService {
        id: String,
    }

    /// Message sent to `MockService`.
    #[derive(Default, Debug, Message, TypeName, Serialize, Deserialize)]
    #[rio_path = "crate"]
    struct MockMessage {
        text: String,
    }

    /// Reply produced by `MockService`.
    #[derive(Default, Debug, Message, TypeName, Serialize, Deserialize)]
    #[rio_path = "crate"]
    struct MockResponse {
        text: String,
    }

    #[async_trait]
    impl Handler<MockMessage> for MockService {
        type Returns = MockResponse;
        type Error = ();

        /// Echoes the message text back, prefixed with the object's id.
        async fn handle(
            &mut self,
            message: MockMessage,
            _: Arc<AppData>,
        ) -> Result<Self::Returns, Self::Error> {
            Ok(MockResponse {
                text: format!("{} received {}", self.id, message.text),
            })
        }
    }

    /// Builds a service backed by in-memory storage with `MockService` registered.
    fn svc() -> Service<LocalStorage, LocalObjectPlacement> {
        let mut registry = Registry::new();
        registry.add_type::<MockService>();
        registry.add_handler::<MockService, MockMessage>();

        let registry = Arc::new(RwLock::new(registry));
        let object_placement_provider = Arc::new(RwLock::new(LocalObjectPlacement::default()));

        Service {
            address: "0.0.0.0:5000".to_string(),
            registry,
            members_storage: LocalStorage::default(),
            object_placement_provider,
            app_data: Arc::new(AppData::new()),
        }
    }

    #[tokio::test]
    async fn test_poll_ready() {
        let mut service = svc();
        ServiceExt::<RequestEnvelope>::ready(&mut service)
            .await
            .expect("service ready");
    }

    #[tokio::test]
    async fn test_service_call() {
        let mut service = svc();
        ServiceExt::<RequestEnvelope>::ready(&mut service)
            .await
            .unwrap();

        let payload = bincode::serialize(&MockMessage { text: "hi".into() }).unwrap();
        let request = RequestEnvelope::new(
            "MockService".into(),
            "*".into(),
            "MockMessage".into(),
            payload,
        );

        let envelope = service.call(request).await.unwrap();
        let body = envelope.body.unwrap();
        let decoded: MockResponse = bincode::deserialize(&body).unwrap();
        assert_eq!(decoded.text, "* received hi".to_string());
    }

    #[tokio::test]
    async fn test_service_subscription() {
        let mut service = svc();
        ServiceExt::<SubscriptionRequest>::ready(&mut service)
            .await
            .unwrap();

        let request = SubscriptionRequest {
            handler_type: "MockService".into(),
            handler_id: "*".into(),
        };

        // Establishing the subscription must complete within the timeout.
        let _stream = timeout(Duration::from_secs(3), service.call(request))
            .await
            .unwrap();
    }
}