1use std::marker::PhantomData;
2use std::sync::Arc;
3
4use arc_swap::ArcSwapOption;
5use async_trait::async_trait;
6
7use crate::error::Error;
8use crate::message::*;
9use crate::netapp::*;
10
11#[async_trait]
18pub trait StreamingEndpointHandler<M>: Send + Sync
19where
20 M: Message,
21{
22 async fn handle(self: &Arc<Self>, m: Req<M>, from: NodeID) -> Resp<M>;
23}
24
25#[async_trait]
30impl<M: Message> EndpointHandler<M> for () {
31 async fn handle(self: &Arc<()>, _m: &M, _from: NodeID) -> M::Response {
32 panic!("This endpoint should not have a local handler.");
33 }
34}
35
36#[async_trait]
42pub trait EndpointHandler<M>: Send + Sync
43where
44 M: Message,
45{
46 async fn handle(self: &Arc<Self>, m: &M, from: NodeID) -> M::Response;
47}
48
49#[async_trait]
50impl<T, M> StreamingEndpointHandler<M> for T
51where
52 T: EndpointHandler<M>,
53 M: Message,
54{
55 async fn handle(self: &Arc<Self>, mut m: Req<M>, from: NodeID) -> Resp<M> {
56 drop(m.take_stream());
59 Resp::new(EndpointHandler::handle(self, m.msg(), from).await)
60 }
61}
62
63pub struct Endpoint<M, H>
76where
77 M: Message,
78 H: StreamingEndpointHandler<M>,
79{
80 _phantom: PhantomData<M>,
81 netapp: Arc<NetApp>,
82 path: String,
83 handler: ArcSwapOption<H>,
84}
85
86impl<M, H> Endpoint<M, H>
87where
88 M: Message,
89 H: StreamingEndpointHandler<M>,
90{
91 pub(crate) fn new(netapp: Arc<NetApp>, path: String) -> Self {
92 Self {
93 _phantom: PhantomData::default(),
94 netapp,
95 path,
96 handler: ArcSwapOption::from(None),
97 }
98 }
99
100 pub fn path(&self) -> &str {
102 &self.path
103 }
104
105 pub fn set_handler(&self, h: Arc<H>) {
108 self.handler.swap(Some(h));
109 }
110
111 pub async fn call_streaming<T>(
116 &self,
117 target: &NodeID,
118 req: T,
119 prio: RequestPriority,
120 ) -> Result<Resp<M>, Error>
121 where
122 T: IntoReq<M>,
123 {
124 if *target == self.netapp.id {
125 match self.handler.load_full() {
126 None => Err(Error::NoHandler),
127 Some(h) => Ok(h.handle(req.into_req_local(), self.netapp.id).await),
128 }
129 } else {
130 let conn = self
131 .netapp
132 .client_conns
133 .read()
134 .unwrap()
135 .get(target)
136 .cloned();
137 match conn {
138 None => Err(Error::Message(format!(
139 "Not connected: {}",
140 hex::encode(&target[..8])
141 ))),
142 Some(c) => c.call(req.into_req()?, self.path.as_str(), prio).await,
143 }
144 }
145 }
146
147 pub async fn call(
151 &self,
152 target: &NodeID,
153 req: M,
154 prio: RequestPriority,
155 ) -> Result<<M as Message>::Response, Error> {
156 Ok(self.call_streaming(target, req, prio).await?.into_msg())
157 }
158}
159
/// Boxed, type-erased endpoint: any `GenericEndpoint` implementation that is
/// also `Send + Sync`.
pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>;
163
/// Type-erased interface to an endpoint.
///
/// None of the methods mention a message or handler type, so endpoints with
/// different type parameters can all be stored as `DynEndpoint`.
#[async_trait]
pub(crate) trait GenericEndpoint {
    /// Handles the encoded request `req_enc` received from node `from` and
    /// returns the encoded response, or an error.
    async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result<RespEnc, Error>;
    /// Drops the handler of this endpoint; the `EndpointArc` implementation
    /// clears the installed handler.
    fn drop_handler(&self);
    /// Returns a new boxed handle to this endpoint.
    fn clone_endpoint(&self) -> DynEndpoint;
}
170
171#[derive(Clone)]
172pub(crate) struct EndpointArc<M, H>(pub(crate) Arc<Endpoint<M, H>>)
173where
174 M: Message,
175 H: StreamingEndpointHandler<M>;
176
177#[async_trait]
178impl<M, H> GenericEndpoint for EndpointArc<M, H>
179where
180 M: Message,
181 H: StreamingEndpointHandler<M> + 'static,
182{
183 async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result<RespEnc, Error> {
184 match self.0.handler.load_full() {
185 None => Err(Error::NoHandler),
186 Some(h) => {
187 let req = Req::from_enc(req_enc)?;
188 let res = h.handle(req, from).await;
189 Ok(res.into_enc()?)
190 }
191 }
192 }
193
194 fn drop_handler(&self) {
195 self.0.handler.swap(None);
196 }
197
198 fn clone_endpoint(&self) -> DynEndpoint {
199 Box::new(Self(self.0.clone()))
200 }
201}