use std::sync::{Arc, Mutex};
use std::time::Duration;
use anyhow::Result;
use futures_util::StreamExt;
use serde::{Deserialize, Serialize};
use tokio::io::{self, AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_util::codec::{Framed, LinesCodec};
pub const NETWORK_TIMEOUT: Duration = Duration::from_secs(3);
pub const MAX_FRAME_LENGTH: usize = 1024;
#[derive(Debug, Serialize, Deserialize)]
struct ProxyResponse {
id: String,
port: u16,
max_conn_count: u8,
url: String,
}
pub async fn open_tunnel(
server: &str,
subdomain: Option<&str>,
local_port: u16,
) -> Result<(), Box<dyn std::error::Error>> {
println!("start connect to: {}, {}", "localhost", "12000");
let assigned_domain = subdomain.unwrap_or("?new");
let uri = format!("{}/{}", server, assigned_domain);
println!("assigned domain: {}", uri);
let resp = reqwest::get(uri).await?.json::<ProxyResponse>().await?;
println!("{:#?}", resp);
let remote_stream = TcpStream::connect(format!("proxy.ad4m.dev:{}", resp.port)).await?;
println!("remote stream connectted");
let codec = LinesCodec::new();
let mut framed_stream = Framed::new(remote_stream, codec);
let counter = Arc::new(Mutex::new(0));
loop {
let _message = framed_stream.next().await;
let mut locked_counter = counter.lock().unwrap();
if *locked_counter < resp.max_conn_count {
println!("spawn new proxy");
*locked_counter += 1;
let counter2 = Arc::clone(&counter);
tokio::spawn(async move { handle_conn(resp.port, local_port, counter2).await });
}
}
}
async fn handle_conn(remote_port: u16, local_port: u16, counter: Arc<Mutex<u8>>) -> Result<()> {
let remote_stream_in = TcpStream::connect(format!("proxy.ad4m.dev:{}", remote_port)).await?;
let local_stream_in = TcpStream::connect(format!("127.0.0.1:{}", local_port)).await?;
proxy(remote_stream_in, local_stream_in, counter).await?;
Ok(())
}
pub async fn proxy<S1, S2>(stream1: S1, stream2: S2, counter: Arc<Mutex<u8>>) -> io::Result<()>
where
S1: AsyncRead + AsyncWrite + Unpin,
S2: AsyncRead + AsyncWrite + Unpin,
{
let (mut s1_read, mut s1_write) = io::split(stream1);
let (mut s2_read, mut s2_write) = io::split(stream2);
tokio::select! {
res = io::copy(&mut s1_read, &mut s2_write) => res,
res = io::copy(&mut s2_read, &mut s1_write) => res,
}?;
let mut locked_counter = counter.lock().unwrap();
*locked_counter -= 1;
Ok(())
}
#[cfg(test)]
mod tests {
#[test]
fn it_works() {
let result = 2 + 2;
assert_eq!(result, 4);
}
}