tchannel_rs/
subchannel.rs

1use crate::channel::TResult;
2use crate::connection::pool::ConnectionPools;
3use crate::connection::{ConnectionResult, FrameInput, FrameOutput};
4use crate::defragmentation::ResponseDefragmenter;
5use crate::errors::{CodecError, ConnectionError, HandlerError, TChannelError};
6use crate::fragmentation::RequestFragmenter;
7use crate::frames::TFrameStream;
8use crate::handler::{
9    HandlerResult, MessageArgsHandler, RequestHandler, RequestHandlerAdapter, RequestHandlerAsync,
10    RequestHandlerAsyncAdapter,
11};
12use crate::messages::args::{MessageArgs, MessageArgsResponse, ResponseCode};
13use crate::messages::Message;
14use futures::StreamExt;
15use futures::{future, TryStreamExt};
16use log::{debug, error};
17use std::collections::HashMap;
18use std::net::{SocketAddr, ToSocketAddrs};
19use std::sync::Arc;
20use tokio::sync::{Mutex, RwLock};
21
/// Shared, mutex-protected request handler as stored in [`SubChannel`]'s handler map.
///
/// The `Arc` lets a handler be cloned out of the map (so the map lock can be released) and the
/// `Mutex` guards the handler while a request is being handled.
type HandlerRef = Arc<Mutex<Box<dyn MessageArgsHandler>>>;
23
/// TChannel protocol subchannel.
///
/// Allows sending [`Message`](crate::messages::Message)s and [`register`](Self::register)ing /
/// [`unregister`](Self::unregister)ing [`RequestHandler`](crate::handler::RequestHandler)s
/// (or [`RequestHandlerAsync`](crate::handler::RequestHandlerAsync)s).
#[derive(Debug, new)]
pub struct SubChannel {
    /// Service name attached to every request fragmented by this subchannel.
    service_name: String,
    /// Connection pools used to obtain a connection to a remote host.
    connection_pools: Arc<ConnectionPools>,
    /// Registered handlers keyed by endpoint name. Not a constructor argument:
    /// `#[new(default)]` starts it out as an empty map.
    #[new(default)]
    handlers: RwLock<HashMap<String, HandlerRef>>,
}
34
35impl SubChannel {
36    pub(super) async fn send<REQ: Message, RES: Message, ADDR: ToSocketAddrs>(
37        &self,
38        request: REQ,
39        host: ADDR,
40    ) -> HandlerResult<RES> {
41        let (frames_in, frames_out) = self.create_frame_io(host).await?;
42        let response_res = self.send_internal(request, frames_in, &frames_out).await;
43        frames_out.close().await; //TODO still ugly
44        match response_res {
45            Ok((code, response)) => match code {
46                ResponseCode::Ok => Ok(response),
47                ResponseCode::Error => Err(HandlerError::MessageError(response)),
48            },
49            Err(err) => Err(HandlerError::InternalError(err)),
50        }
51    }
52
53    async fn create_frame_io<ADDR: ToSocketAddrs>(
54        &self,
55        host: ADDR,
56    ) -> TResult<(FrameInput, FrameOutput)> {
57        let host = first_addr(host)?;
58        self.connect(host).await
59    }
60
61    pub(super) async fn send_internal<REQ: Message, RES: Message>(
62        &self,
63        request: REQ,
64        frames_in: FrameInput,
65        frames_out: &FrameOutput,
66    ) -> TResult<(ResponseCode, RES)> {
67        let frames = self.create_frames(request).await?;
68        send_frames(frames, frames_out).await?;
69        let response = ResponseDefragmenter::new(frames_in)
70            .read_response_msg()
71            .await;
72        frames_out.close().await; //TODO ugly and broken
73        response
74    }
75
76    /// Registers request handler.
77    pub async fn register<REQ, RES, HANDLER>(
78        &self,
79        endpoint: impl AsRef<str>,
80        request_handler: HANDLER,
81    ) -> TResult<()>
82    where
83        REQ: Message + 'static,
84        RES: Message + 'static,
85        HANDLER: RequestHandler<REQ = REQ, RES = RES> + 'static,
86    {
87        let handler_adapter = RequestHandlerAdapter::new(request_handler);
88        self.register_handler(endpoint, Arc::new(Mutex::new(Box::new(handler_adapter))))
89            .await
90    }
91
92    /// Registers async request handler.
93    pub async fn register_async<REQ, RES, HANDLER>(
94        &self,
95        endpoint: impl AsRef<str>,
96        request_handler: HANDLER,
97    ) -> TResult<()>
98    where
99        REQ: Message + 'static,
100        RES: Message + 'static,
101        HANDLER: RequestHandlerAsync<REQ = REQ, RES = RES> + 'static,
102    {
103        let handler_adapter = RequestHandlerAsyncAdapter::new(request_handler);
104        self.register_handler(endpoint, Arc::new(Mutex::new(Box::new(handler_adapter))))
105            .await
106    }
107
108    /// Unregisters request handler. Found handler will be dropped.
109    pub async fn unregister(&mut self, endpoint: impl AsRef<str>) -> TResult<()> {
110        let mut handlers = self.handlers.write().await;
111        match handlers.remove(endpoint.as_ref()) {
112            Some(_) => Ok(()),
113            None => Err(TChannelError::Error(format!(
114                "Handler '{}' is missing.",
115                endpoint.as_ref()
116            ))),
117        }
118    }
119
120    async fn register_handler(
121        &self,
122        endpoint: impl AsRef<str>,
123        request_handler: HandlerRef,
124    ) -> TResult<()> {
125        let mut handlers = self.handlers.write().await;
126        if handlers.contains_key(endpoint.as_ref()) {
127            return Err(TChannelError::Error(format!(
128                "Handler already registered for '{}'",
129                endpoint.as_ref()
130            )));
131        }
132        handlers.insert(endpoint.as_ref().to_string(), request_handler);
133        Ok(()) //TODO return &mut of nested handler?
134    }
135
136    async fn connect(&self, host: SocketAddr) -> TResult<(FrameInput, FrameOutput)> {
137        let pool = self.connection_pools.get(host).await?;
138        let connection = pool.get().await?;
139        Ok(connection.new_frames_io().await?)
140    }
141
142    async fn create_frames<REQ: Message>(&self, request: REQ) -> TResult<TFrameStream> {
143        let message_args = request.try_into()?;
144        RequestFragmenter::new(self.service_name.clone(), message_args).create_frames()
145    }
146
147    pub(crate) async fn handle(&self, request: MessageArgs) -> MessageArgsResponse {
148        let endpoint = Self::read_endpoint_name(&request)?;
149        let handler_locked = self.get_handler(endpoint).await?;
150        let mut handler = handler_locked.lock().await; //TODO do I really want Mutex? maybe handle(&self,..) instead of handle(&mut self,..) ?
151        handler.handle(request).await
152    }
153
154    async fn get_handler(&self, endpoint: String) -> TResult<HandlerRef> {
155        let handlers = self.handlers.read().await;
156        match handlers.get(&endpoint) {
157            Some(handler) => Ok(handler.clone()),
158            None => Err(TChannelError::Error(format!(
159                "No handler with name '{}'.",
160                endpoint
161            ))),
162        }
163    }
164
165    fn read_endpoint_name(request: &MessageArgs) -> Result<String, CodecError> {
166        match request.args.get(0) {
167            Some(arg) => Ok(String::from_utf8(arg.to_vec())?),
168            None => Err(CodecError::Error("Missing arg1/endpoint name".to_string())),
169        }
170    }
171}
172
173fn first_addr<ADDR: ToSocketAddrs>(addr: ADDR) -> ConnectionResult<SocketAddr> {
174    let mut addrs = addr.to_socket_addrs()?;
175    if let Some(addr) = addrs.next() {
176        return Ok(addr);
177    }
178    Err(ConnectionError::Error(
179        "Unable to get host addr".to_string(),
180    ))
181}
182
183async fn send_frames(frames: TFrameStream, frames_out: &FrameOutput) -> ConnectionResult<()> {
184    debug!("Sending frames");
185    frames
186        .then(|frame| frames_out.send(frame))
187        .inspect_err(|err| error!("Failed to send frame {:?}", err))
188        .try_for_each(|_res| future::ready(Ok(())))
189        .await
190}