Skip to main content

shuttle_rama/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use rama::{
4    Service,
5    error::OpaqueError,
6    http::{Request, server::HttpServer, service::web::response::IntoResponse},
7    tcp::server::TcpListener,
8};
9use shuttle_runtime::{CustomError, Error, tokio};
10use std::{convert::Infallible, fmt, net::SocketAddr};
11
/// A wrapper type for [`Service`] so we can implement [`shuttle_runtime::Service`] for it.
///
/// `shuttle_runtime::Service` is only implemented when `T` is either [`Transport`]
/// (served over raw TCP) or [`Application`] (served over HTTP/1). Construct it with
/// [`RamaService::transport`] or [`RamaService::application`], and optionally attach
/// state with [`RamaService::with_state`].
pub struct RamaService<T, State> {
    /// The user's service, wrapped in [`Transport`] or [`Application`].
    svc: T,
    /// State handed to the rama server when binding; `()` until
    /// [`RamaService::with_state`] is called.
    state: State,
}
17
18impl<T: Clone, State: Clone> Clone for RamaService<T, State> {
19    fn clone(&self) -> Self {
20        Self {
21            svc: self.svc.clone(),
22            state: self.state.clone(),
23        }
24    }
25}
26
27impl<T: fmt::Debug, State: fmt::Debug> fmt::Debug for RamaService<T, State> {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        f.debug_struct("RamaService")
30            .field("svc", &self.svc)
31            .field("state", &self.state)
32            .finish()
33    }
34}
35
/// Type wrapper (with a private field) to indicate [`RamaService`]
/// is used by the user from the Transport layer (tcp).
///
/// The wrapped service can only be set from within this crate; build one via
/// [`RamaService::transport`].
pub struct Transport<S>(S);
39
/// Type wrapper (with a private field) to indicate [`RamaService`]
/// is used by the user from the Application layer (http(s)).
///
/// The wrapped service can only be set from within this crate; build one via
/// [`RamaService::application`].
pub struct Application<S>(S);
43
// Generates `Clone` and `Debug` impls for a one-field tuple wrapper
// (`$wrapper<Inner>(Inner)`), each bounded on the same trait for `Inner`.
// The `Debug` output has the tuple form `Wrapper(inner)`.
macro_rules! impl_wrapper_derive_traits {
    ($wrapper:ident) => {
        impl<Inner> Clone for $wrapper<Inner>
        where
            Inner: Clone,
        {
            fn clone(&self) -> Self {
                $wrapper(Clone::clone(&self.0))
            }
        }

        impl<Inner> fmt::Debug for $wrapper<Inner>
        where
            Inner: fmt::Debug,
        {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let name = stringify!($wrapper);
                f.debug_tuple(name).field(&self.0).finish()
            }
        }
    };
}
59
60impl_wrapper_derive_traits!(Transport);
61impl_wrapper_derive_traits!(Application);
62
63impl<S> RamaService<Transport<S>, ()> {
64    pub fn transport(svc: S) -> Self {
65        Self {
66            svc: Transport(svc),
67            state: (),
68        }
69    }
70}
71
72impl<S> RamaService<Application<S>, ()> {
73    pub fn application(svc: S) -> Self {
74        Self {
75            svc: Application(svc),
76            state: (),
77        }
78    }
79}
80
81impl<T> RamaService<T, ()> {
82    /// Attach state to this [`RamaService`], such that it will be passed
83    /// as part of each request's [`Context`].
84    ///
85    /// [`Context`]: rama::Context
86    pub fn with_state<State>(self, state: State) -> RamaService<T, State>
87    where
88        State: Clone + Send + Sync + 'static,
89    {
90        RamaService {
91            svc: self.svc,
92            state,
93        }
94    }
95}
96
97#[shuttle_runtime::async_trait]
98impl<S, State> shuttle_runtime::Service for RamaService<Transport<S>, State>
99where
100    S: Service<State, tokio::net::TcpStream>,
101    State: Clone + Send + Sync + 'static,
102{
103    /// Takes the service that is returned by the user in their [shuttle_runtime::main] function
104    /// and binds to an address passed in by shuttle.
105    async fn bind(self, addr: SocketAddr) -> Result<(), Error> {
106        TcpListener::build_with_state(self.state)
107            .bind(addr)
108            .await
109            .map_err(|err| Error::BindPanic(err.to_string()))?
110            .serve(self.svc.0)
111            .await;
112        Ok(())
113    }
114}
115
116#[shuttle_runtime::async_trait]
117impl<S, State, Response> shuttle_runtime::Service for RamaService<Application<S>, State>
118where
119    S: Service<State, Request, Response = Response, Error = Infallible>,
120    Response: IntoResponse + Send + 'static,
121    State: Clone + Send + Sync + 'static,
122{
123    /// Takes the service that is returned by the user in their [shuttle_runtime::main] function
124    /// and binds to an address passed in by shuttle.
125    async fn bind(self, addr: SocketAddr) -> Result<(), Error> {
126        // shuttle only supports h1 between load balancer <=> web service,
127        // h2 is terminated by shuttle's load balancer
128        HttpServer::http1()
129            .listen_with_state(self.state, addr, self.svc.0)
130            .await
131            .map_err(|err| CustomError::new(OpaqueError::from_boxed(err)))?;
132        Ok(())
133    }
134}
135
136#[doc = include_str!("../README.md")]
137pub type ShuttleRamaTransport<S, State = ()> = Result<RamaService<Transport<S>, State>, Error>;
138
139#[doc = include_str!("../README.md")]
140pub type ShuttleRamaApplication<S, State = ()> = Result<RamaService<Application<S>, State>, Error>;
141
142pub use shuttle_runtime::{Error as ShuttleError, Service as ShuttleService};