capybara_core/protocol/stream/
mod.rs1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use arc_swap::ArcSwap;
5use async_trait::async_trait;
6use tokio::net::TcpStream;
7use tokio::sync::Notify;
8
9use capybara_util::cachestr::Cachestr;
10
11use crate::error::CapybaraError;
12use crate::pipeline::stream::load;
13use crate::pipeline::stream::StreamPipelineFactoryExt;
14use crate::pipeline::{PipelineConf, StreamContext};
15use crate::proto::{Listener, Signal, Signals};
16use crate::resolver::DEFAULT_RESOLVER;
17use crate::transport::TcpListenerBuilder;
18use crate::upstream::ClientStream;
19
20pub struct StreamListenerBuilder {
21 addr: SocketAddr,
22 id: Option<Cachestr>,
23 pipelines: Vec<(Cachestr, PipelineConf)>,
24}
25
26impl StreamListenerBuilder {
27 pub fn id<A>(mut self, id: A) -> Self
28 where
29 A: AsRef<str>,
30 {
31 self.id.replace(Cachestr::from(id.as_ref()));
32 self
33 }
34
35 pub fn pipeline<N>(mut self, name: N, c: &PipelineConf) -> Self
36 where
37 N: AsRef<str>,
38 {
39 self.pipelines
40 .push((Cachestr::from(name.as_ref()), Clone::clone(c)));
41 self
42 }
43
44 pub fn build(self) -> crate::Result<StreamListener> {
45 let Self {
46 addr,
47 id,
48 pipelines,
49 } = self;
50
51 Ok(StreamListener {
52 id: id.unwrap_or_else(|| Cachestr::from(uuid::Uuid::new_v4().to_string())),
53 addr,
54 pipelines: ArcSwap::from_pointee(pipelines),
55 })
56 }
57}
58
59pub struct StreamListener {
60 id: Cachestr,
61 addr: SocketAddr,
62 pipelines: ArcSwap<Vec<(Cachestr, PipelineConf)>>,
63}
64
65impl StreamListener {
66 pub fn builder(addr: SocketAddr) -> StreamListenerBuilder {
67 StreamListenerBuilder {
68 id: None,
69 addr,
70 pipelines: vec![],
71 }
72 }
73
74 #[inline]
75 fn build_pipeline_factories(&self) -> anyhow::Result<Vec<Box<dyn StreamPipelineFactoryExt>>> {
76 let r = self.pipelines.load();
77
78 let mut factories = Vec::with_capacity(r.len());
79 for (k, v) in r.iter() {
80 let factory = load(k, v)?;
81 factories.push(factory);
82 }
83
84 Ok(factories)
85 }
86
87 fn build_context(
88 client_addr: SocketAddr,
89 factories: &[Box<dyn StreamPipelineFactoryExt>],
90 ) -> anyhow::Result<StreamContext> {
91 let mut b = StreamContext::builder(client_addr);
92
93 for factory in factories {
94 let next = factory.generate_arc()?;
95 b = b.pipeline_arc(next);
96 }
97
98 Ok(b.build())
99 }
100}
101
102#[async_trait]
103impl Listener for StreamListener {
104 fn id(&self) -> &str {
105 self.id.as_ref()
106 }
107
108 async fn listen(&self, signals: &mut Signals) -> crate::Result<()> {
109 let l = TcpListenerBuilder::new(self.addr).build()?;
110
111 info!("'{}' is listening on {}", &self.id, &self.addr);
112
113 let closer = Arc::new(Notify::new());
114
115 let mut pipelines = self.build_pipeline_factories()?;
116
117 loop {
118 tokio::select! {
119 signal = signals.recv() => {
120 match signal {
121 None => {
122 info!("listener '{}' is stopping....", &self.id);
123 return Ok(());
124 }
125 Some(Signal::Shutdown) => {
126 info!("listener '{}' is stopping...", &self.id);
127 return Ok(());
128 }
129 Some(Signal::Reload) => {
130 info!("listener '{}' is reloading...", &self.id);
131 pipelines = self.build_pipeline_factories()?;
133 }
134 }
135 }
136 accept = l.accept() => {
137 let (stream,addr) = accept?;
138
139 let ctx = Self::build_context(addr,&pipelines[..])?;
140 let closer = Clone::clone(&closer);
141
142 let mut h = Handler::new(ctx,stream);
143 tokio::spawn(async move {
144 if let Err(e) = h.handle(closer).await{
145 error!("stream handler occurs an error: {}", e);
146 }
147 });
148 }
149 }
150 }
151 }
152}
153
154struct Handler {
155 ctx: StreamContext,
156 downstream: TcpStream,
157}
158
159impl Handler {
160 const BUFF_SIZE: usize = 8192;
161
162 fn new(ctx: StreamContext, downstream: TcpStream) -> Self {
163 Self { ctx, downstream }
164 }
165
166 async fn resolve(upstream: &str) -> anyhow::Result<SocketAddr> {
167 if let Ok(addr) = upstream.parse::<SocketAddr>() {
168 return Ok(addr);
169 }
170
171 let host_and_port = upstream.split(':').collect::<Vec<&str>>();
172 if host_and_port.len() != 2 {
173 bail!(CapybaraError::NoAddressResolved(
174 upstream.to_string().into()
175 ))
176 }
177 let port: u16 = host_and_port.last().unwrap().parse()?;
178 let ip = DEFAULT_RESOLVER
179 .resolve_one(host_and_port.first().unwrap())
180 .await?;
181
182 Ok(SocketAddr::new(ip, port))
183 }
184
185 async fn handle(&mut self, _closer: Arc<Notify>) -> anyhow::Result<()> {
186 if let Some(p) = self.ctx.pipeline() {
187 p.handle_connect(&mut self.ctx).await?;
188 }
189
190 let mut upstream = match self.ctx.upstream() {
191 None => bail!(CapybaraError::InvalidRoute),
192 Some(upstream) => crate::upstream::establish(&upstream, Self::BUFF_SIZE).await?,
193 };
194
195 use tokio::io::copy_bidirectional_with_sizes as copy;
196
197 let (in_bytes, out_bytes) = match &mut upstream {
198 ClientStream::Tcp(inner) => {
199 copy(
200 &mut self.downstream,
201 inner,
202 Self::BUFF_SIZE,
203 Self::BUFF_SIZE,
204 )
205 .await?
206 }
207 ClientStream::Tls(inner) => {
208 copy(
209 &mut self.downstream,
210 inner,
211 Self::BUFF_SIZE,
212 Self::BUFF_SIZE,
213 )
214 .await?
215 }
216 };
217
218 debug!("copy bidirectional ok: in={}, out={}", in_bytes, out_bytes);
219
220 Ok(())
221 }
222}
223
#[cfg(test)]
mod tests {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::*;

    /// Initializes logging (ignoring a repeated initialization) and the crate's global setup.
    async fn init() {
        let _ = pretty_env_logger::try_init_timed();
        crate::setup().await;
    }

    /// End-to-end check: proxy a plain HTTP request to httpbin.org through a stream listener,
    /// then shut the listener down. Requires network access and a free local port 9999.
    #[tokio::test]
    async fn test_stream_listener() -> anyhow::Result<()> {
        init().await;

        let (signal_tx, mut signal_rx) = tokio::sync::mpsc::channel(1);
        let stopped = Arc::new(Notify::new());

        let addr = "127.0.0.1:9999";

        let conf: PipelineConf = serde_yaml::from_str(
            r#"
            upstream: 'httpbin.org:80'
            "#,
        )
        .unwrap();

        let listener = StreamListener::builder(addr.parse().unwrap())
            .id("fake-stream-listener")
            .pipeline("capybara.pipelines.stream.router", &conf)
            .build()?;

        let stopped_by_listener = Arc::clone(&stopped);
        tokio::spawn(async move {
            let _ = listener.listen(&mut signal_rx).await;
            stopped_by_listener.notify_one();
        });

        // Give the spawned listener a moment to bind before connecting.
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        let mut client = TcpStream::connect(addr).await?;

        client
            .write_all(&b"GET /ip HTTP/1.0\r\nHost: httpbin.org\r\nAccept: *\r\n\r\n"[..])
            .await?;
        client.flush().await?;

        // HTTP/1.0 defaults to a non-persistent connection, so read until the peer closes.
        let mut response = Vec::new();
        client.read_to_end(&mut response).await?;

        info!("read: {}", String::from_utf8_lossy(&response[..]));

        assert!(!response.is_empty());

        signal_tx.send(Signal::Shutdown).await?;

        stopped.notified().await;

        Ok(())
    }
}