// capybara_core/protocol/stream/mod.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use arc_swap::ArcSwap;
5use async_trait::async_trait;
6use tokio::net::TcpStream;
7use tokio::sync::Notify;
8
9use capybara_util::cachestr::Cachestr;
10
11use crate::error::CapybaraError;
12use crate::pipeline::stream::load;
13use crate::pipeline::stream::StreamPipelineFactoryExt;
14use crate::pipeline::{PipelineConf, StreamContext};
15use crate::proto::{Listener, Signal, Signals};
16use crate::resolver::DEFAULT_RESOLVER;
17use crate::transport::TcpListenerBuilder;
18use crate::upstream::ClientStream;
19
20pub struct StreamListenerBuilder {
21    addr: SocketAddr,
22    id: Option<Cachestr>,
23    pipelines: Vec<(Cachestr, PipelineConf)>,
24}
25
26impl StreamListenerBuilder {
27    pub fn id<A>(mut self, id: A) -> Self
28    where
29        A: AsRef<str>,
30    {
31        self.id.replace(Cachestr::from(id.as_ref()));
32        self
33    }
34
35    pub fn pipeline<N>(mut self, name: N, c: &PipelineConf) -> Self
36    where
37        N: AsRef<str>,
38    {
39        self.pipelines
40            .push((Cachestr::from(name.as_ref()), Clone::clone(c)));
41        self
42    }
43
44    pub fn build(self) -> crate::Result<StreamListener> {
45        let Self {
46            addr,
47            id,
48            pipelines,
49        } = self;
50
51        Ok(StreamListener {
52            id: id.unwrap_or_else(|| Cachestr::from(uuid::Uuid::new_v4().to_string())),
53            addr,
54            pipelines: ArcSwap::from_pointee(pipelines),
55        })
56    }
57}
58
59pub struct StreamListener {
60    id: Cachestr,
61    addr: SocketAddr,
62    pipelines: ArcSwap<Vec<(Cachestr, PipelineConf)>>,
63}
64
65impl StreamListener {
66    pub fn builder(addr: SocketAddr) -> StreamListenerBuilder {
67        StreamListenerBuilder {
68            id: None,
69            addr,
70            pipelines: vec![],
71        }
72    }
73
74    #[inline]
75    fn build_pipeline_factories(&self) -> anyhow::Result<Vec<Box<dyn StreamPipelineFactoryExt>>> {
76        let r = self.pipelines.load();
77
78        let mut factories = Vec::with_capacity(r.len());
79        for (k, v) in r.iter() {
80            let factory = load(k, v)?;
81            factories.push(factory);
82        }
83
84        Ok(factories)
85    }
86
87    fn build_context(
88        client_addr: SocketAddr,
89        factories: &[Box<dyn StreamPipelineFactoryExt>],
90    ) -> anyhow::Result<StreamContext> {
91        let mut b = StreamContext::builder(client_addr);
92
93        for factory in factories {
94            let next = factory.generate_arc()?;
95            b = b.pipeline_arc(next);
96        }
97
98        Ok(b.build())
99    }
100}
101
102#[async_trait]
103impl Listener for StreamListener {
104    fn id(&self) -> &str {
105        self.id.as_ref()
106    }
107
108    async fn listen(&self, signals: &mut Signals) -> crate::Result<()> {
109        let l = TcpListenerBuilder::new(self.addr).build()?;
110
111        info!("'{}' is listening on {}", &self.id, &self.addr);
112
113        let closer = Arc::new(Notify::new());
114
115        let mut pipelines = self.build_pipeline_factories()?;
116
117        loop {
118            tokio::select! {
119                signal = signals.recv() => {
120                    match signal {
121                        None => {
122                            info!("listener '{}' is stopping....", &self.id);
123                            return Ok(());
124                        }
125                        Some(Signal::Shutdown) => {
126                            info!("listener '{}' is stopping...", &self.id);
127                            return Ok(());
128                        }
129                        Some(Signal::Reload) => {
130                            info!("listener '{}' is reloading...", &self.id);
131                            // TODO: reload the current listener
132                            pipelines = self.build_pipeline_factories()?;
133                        }
134                    }
135                }
136                accept = l.accept() => {
137                    let (stream,addr) = accept?;
138
139                    let ctx = Self::build_context(addr,&pipelines[..])?;
140                    let closer = Clone::clone(&closer);
141
142                    let mut h = Handler::new(ctx,stream);
143                    tokio::spawn(async move {
144                        if let Err(e) = h.handle(closer).await{
145                            error!("stream handler occurs an error: {}", e);
146                        }
147                    });
148                }
149            }
150        }
151    }
152}
153
154struct Handler {
155    ctx: StreamContext,
156    downstream: TcpStream,
157}
158
159impl Handler {
160    const BUFF_SIZE: usize = 8192;
161
162    fn new(ctx: StreamContext, downstream: TcpStream) -> Self {
163        Self { ctx, downstream }
164    }
165
166    async fn resolve(upstream: &str) -> anyhow::Result<SocketAddr> {
167        if let Ok(addr) = upstream.parse::<SocketAddr>() {
168            return Ok(addr);
169        }
170
171        let host_and_port = upstream.split(':').collect::<Vec<&str>>();
172        if host_and_port.len() != 2 {
173            bail!(CapybaraError::NoAddressResolved(
174                upstream.to_string().into()
175            ))
176        }
177        let port: u16 = host_and_port.last().unwrap().parse()?;
178        let ip = DEFAULT_RESOLVER
179            .resolve_one(host_and_port.first().unwrap())
180            .await?;
181
182        Ok(SocketAddr::new(ip, port))
183    }
184
185    async fn handle(&mut self, _closer: Arc<Notify>) -> anyhow::Result<()> {
186        if let Some(p) = self.ctx.pipeline() {
187            p.handle_connect(&mut self.ctx).await?;
188        }
189
190        let mut upstream = match self.ctx.upstream() {
191            None => bail!(CapybaraError::InvalidRoute),
192            Some(upstream) => crate::upstream::establish(&upstream, Self::BUFF_SIZE).await?,
193        };
194
195        use tokio::io::copy_bidirectional_with_sizes as copy;
196
197        let (in_bytes, out_bytes) = match &mut upstream {
198            ClientStream::Tcp(inner) => {
199                copy(
200                    &mut self.downstream,
201                    inner,
202                    Self::BUFF_SIZE,
203                    Self::BUFF_SIZE,
204                )
205                .await?
206            }
207            ClientStream::Tls(inner) => {
208                copy(
209                    &mut self.downstream,
210                    inner,
211                    Self::BUFF_SIZE,
212                    Self::BUFF_SIZE,
213                )
214                .await?
215            }
216        };
217
218        debug!("copy bidirectional ok: in={}, out={}", in_bytes, out_bytes);
219
220        Ok(())
221    }
222}
223
#[cfg(test)]
mod tests {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::*;

    /// Canned response returned by the fake upstream.
    const FAKE_RESPONSE: &[u8] = b"HTTP/1.0 200 OK\r\nContent-Length: 5\r\n\r\nhello";

    async fn init() {
        pretty_env_logger::try_init_timed().ok();
        crate::setup().await;
    }

    /// Starts a local upstream that answers one connection with `FAKE_RESPONSE` and then
    /// closes it, so the test does not depend on external network access.
    async fn spawn_fake_upstream() -> anyhow::Result<SocketAddr> {
        let l = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
        let addr = l.local_addr()?;

        tokio::spawn(async move {
            if let Ok((mut conn, _)) = l.accept().await {
                // Consume the request before answering; its content is irrelevant here.
                let mut buf = [0u8; 1024];
                let _ = conn.read(&mut buf).await;
                let _ = conn.write_all(FAKE_RESPONSE).await;
                let _ = conn.shutdown().await;
            }
        });

        Ok(addr)
    }

    #[tokio::test]
    async fn test_stream_listener() -> anyhow::Result<()> {
        init().await;

        let upstream = spawn_fake_upstream().await?;

        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
        let closed = Arc::new(Notify::new());

        let addr = "127.0.0.1:9999";

        {
            let closed = Clone::clone(&closed);
            let c: PipelineConf = {
                // language=yaml
                let s = format!("upstream: '{}'", upstream);

                serde_yaml::from_str(&s).unwrap()
            };

            let l = StreamListener::builder(addr.parse().unwrap())
                .id("fake-stream-listener")
                .pipeline("capybara.pipelines.stream.router", &c)
                .build()?;

            tokio::spawn(async move {
                let _ = l.listen(&mut rx).await;
                closed.notify_one();
            });
        }

        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        // establish conn
        let mut c = TcpStream::connect(addr).await?;

        // write data
        {
            // use HTTP/1.0, conn will be closed automatically
            c.write_all(&b"GET /ip HTTP/1.0\r\nHost: localhost\r\nAccept: *\r\n\r\n"[..])
                .await?;
            c.flush().await?;
        }

        // read data: the listener must relay the upstream response unchanged
        {
            let mut v = vec![];
            c.read_to_end(&mut v).await?;

            info!("read: {}", String::from_utf8_lossy(&v[..]));

            assert_eq!(&v[..], FAKE_RESPONSE);
        }

        // shutdown
        tx.send(Signal::Shutdown).await?;

        // wait for closed
        closed.notified().await;

        Ok(())
    }
}
297}