use anyhow::{anyhow, Result};
use reqwest::Client;
use rustrtc::{
transports::sctp::{DataChannel, DataChannelConfig, DataChannelEvent},
PeerConnection, PeerConnectionEvent, SdpType, SessionDescription,
};
use serde_json::Value;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tracing::{error, info};
use crate::webrtc_config::WebRTCConfig;
use crate::{config::IceServerConfig, OfferMessage};
/// Upper bound on how long forwarding keeps waiting for the remote side's remaining
/// data once the local input has finished, before the session is torn down.
const DRAIN_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Bridges a local stream pair over an already-negotiated WebRTC data channel:
/// bytes read from `input` are sent on the channel, and channel messages are written
/// (and flushed) to `output`.
///
/// First waits up to `connect_timeout` seconds (default 30) for the data channel to
/// report `Open`. Forwarding then runs until one of the following happens:
/// - the peer connection reaches `Disconnected`/`Failed`/`Closed`;
/// - the data channel closes (queued remote data is flushed for up to `DRAIN_TIMEOUT`);
/// - `input` reaches EOF or fails (remote data keeps flowing for up to `DRAIN_TIMEOUT`);
/// - writing to `output` fails.
///
/// On every exit path the background output writer is stopped and the peer connection
/// is closed.
///
/// # Errors
///
/// Returns an error if the data channel does not open within the timeout, or if its
/// event stream ends before it ever opens.
pub async fn forward_stream_to_webrtc<R, W>(
    peer_connection: Arc<PeerConnection>,
    data_channel: Arc<DataChannel>,
    connect_timeout: Option<u32>,
    mut input: R,
    mut output: W,
) -> Result<()>
where
    R: tokio::io::AsyncRead + Unpin + Send + 'static,
    W: tokio::io::AsyncWrite + Unpin + Send + 'static,
{
    let (open_tx, open_rx) = tokio::sync::oneshot::channel();
    let (msg_tx, mut msg_rx) = tokio::sync::mpsc::unbounded_channel();
    let dc_closed = tokio_util::sync::CancellationToken::new();

    // Data-channel event pump: reports the first `Open`, hands incoming messages to the
    // output writer, and cancels `dc_closed` once the channel is gone.
    let dc_clone = data_channel.clone();
    let pc_disc = peer_connection.clone();
    let dc_closed_tx = dc_closed.clone();
    tokio::spawn(async move {
        let mut open_tx = Some(open_tx);
        while let Some(event) = dc_clone.recv().await {
            match event {
                DataChannelEvent::Open => {
                    if let Some(tx) = open_tx.take() {
                        let _ = tx.send(());
                    }
                }
                DataChannelEvent::Message(data) => {
                    // Sending only fails once the writer (which owns `msg_rx`) has stopped.
                    let _ = msg_tx.send(data);
                }
                DataChannelEvent::Close => {
                    if let Some(reason) = pc_disc.disconnect_reason() {
                        tracing::warn!("Data channel closed (reason: {})", reason);
                    }
                    break;
                }
            }
        }
        // Reached both on an explicit `Close` and when the event stream simply ends, so
        // the forwarding loop below never waits on a channel that cannot deliver data.
        dc_closed_tx.cancel();
    });

    let connect_timeout = connect_timeout.unwrap_or(30);
    let open_error = match tokio::time::timeout(
        Duration::from_secs(connect_timeout.into()),
        open_rx,
    )
    .await
    {
        Ok(Ok(())) => None,
        // The event pump exited (dropping the sender) without ever seeing `Open`.
        Ok(Err(_)) => Some(anyhow!("Data channel closed before opening")),
        Err(_) => Some(anyhow!("Data channel open timeout")),
    };
    if let Some(err) = open_error {
        peer_connection.close();
        return Err(err);
    }

    // Peer-connection watchdog: cancels `webrtc_dead` once the connection is unusable.
    let pc_monitor = peer_connection.clone();
    let webrtc_dead = tokio_util::sync::CancellationToken::new();
    let webrtc_dead_tx = webrtc_dead.clone();
    tokio::spawn(async move {
        let mut state_rx = pc_monitor.subscribe_peer_state();
        loop {
            // Inspect the current value before waiting: `changed()` presumably (tokio
            // `watch` semantics) only resolves on transitions after subscription, so a
            // connection that died earlier would otherwise go unnoticed.
            let state = *state_rx.borrow();
            if matches!(
                state,
                rustrtc::PeerConnectionState::Disconnected
                    | rustrtc::PeerConnectionState::Failed
                    | rustrtc::PeerConnectionState::Closed
            ) {
                if let Some(reason) = pc_monitor.disconnect_reason() {
                    tracing::warn!("WebRTC connection lost: {} (state: {:?})", reason, state);
                } else {
                    tracing::warn!("WebRTC connection lost: state {:?}", state);
                }
                webrtc_dead_tx.cancel();
                break;
            }
            if state_rx.changed().await.is_err() {
                // The state sender is gone; no further updates can arrive.
                break;
            }
        }
    });

    // Local input -> data channel, at most 1200 bytes per message.
    let pc_clone = peer_connection.clone();
    let dc_id = data_channel.id;
    let input_task = async move {
        let mut buffer = [0u8; 1200];
        loop {
            match input.read(&mut buffer).await {
                Ok(0) => {
                    tracing::debug!("forward_stream_to_webrtc: input EOF");
                    break;
                }
                Ok(n) => {
                    if let Err(e) = pc_clone.send_data(dc_id, &buffer[..n]).await {
                        tracing::error!("Failed to send data through WebRTC: {}", e);
                        break;
                    }
                }
                Err(e) => {
                    tracing::debug!("forward_stream_to_webrtc: input read failed: {}", e);
                    break;
                }
            }
        }
    };

    // Data channel -> local output; ends when the event pump drops `msg_tx` or a
    // write/flush fails.
    let mut output_task = tokio::spawn(async move {
        while let Some(data) = msg_rx.recv().await {
            if output.write_all(&data).await.is_err() {
                break;
            }
            if output.flush().await.is_err() {
                break;
            }
        }
    });

    tokio::select! {
        _ = webrtc_dead.cancelled() => {
            tracing::debug!("forward_stream_to_webrtc: exiting due to WebRTC disconnect");
        }
        _ = dc_closed.cancelled() => {
            tracing::debug!("forward_stream_to_webrtc: data channel closed by remote");
            // Messages that arrived just before the close may still be queued for the
            // writer; let it deliver them (bounded) before tearing everything down.
            let _ = tokio::time::timeout(DRAIN_TIMEOUT, &mut output_task).await;
        }
        _ = input_task => {
            tracing::debug!("forward_stream_to_webrtc: input closed, waiting for drain");
            tokio::select! {
                _ = tokio::time::sleep(DRAIN_TIMEOUT) => {
                    tracing::debug!("forward_stream_to_webrtc: drain timeout, closing");
                }
                _ = dc_closed.cancelled() => {
                    tracing::debug!("forward_stream_to_webrtc: data channel closed during drain");
                }
                _ = &mut output_task => {
                    tracing::debug!("forward_stream_to_webrtc: output finished during drain");
                }
            }
        }
        _ = &mut output_task => {
            tracing::debug!("forward_stream_to_webrtc: output closed");
        }
    }

    // Stop the detached writer (a no-op if it already finished) and release the
    // connection so the remote side learns the session is over.
    output_task.abort();
    peer_connection.close();
    Ok(())
}
/// Client for the rport signaling server: negotiates WebRTC peer connections to agents
/// and tunnels stdin/stdout or local TCP connections over a data channel.
pub struct CliClient {
/// Base URL of the signaling server; `/rport/offer` is appended when posting offers.
server_url: String,
/// Credential sent with signaling requests as the `token` query parameter.
token: String,
/// HTTP client used for signaling requests.
client: Client,
/// Peer-connection factory, built from `server_url`, `token` and the ICE server list.
webrtc_config: WebRTCConfig,
}
impl CliClient {
pub fn new(
server_url: String,
token: String,
ice_servers: Option<Vec<IceServerConfig>>,
) -> Self {
let webrtc_config = WebRTCConfig::new(
server_url.clone(),
token.clone(),
ice_servers.unwrap_or_default(),
);
Self {
server_url,
token,
client: Client::new(),
webrtc_config,
}
}
pub async fn connect_proxy_command(
&self,
connect_timeout: Option<u32>,
agent_id: String,
) -> Result<()> {
let (peer_connection, data_channel) =
self.create_webrtc_connection_silent(&agent_id).await?;
if let Err(e) = forward_stream_to_webrtc(
peer_connection,
data_channel,
connect_timeout,
tokio::io::stdin(),
tokio::io::stdout(),
)
.await
{
tracing::error!("forward_stream_to_webrtc failed: {}", e);
return Err(e);
}
Ok(())
}
pub async fn connect_port_forward(&self, agent_id: String, local_port: u16) -> Result<()> {
info!(
"Starting port forward from localhost:{} to agent {}",
local_port, agent_id
);
let listener = TcpListener::bind(format!("127.0.0.1:{}", local_port)).await?;
info!("Listening on localhost:{}", local_port);
loop {
match listener.accept().await {
Ok((tcp_stream, addr)) => {
info!("New connection from {}", addr);
let agent_id = agent_id.clone();
let client = self.clone();
tokio::spawn(async move {
if let Err(e) = client.handle_tcp_connection(tcp_stream, agent_id).await {
error!("Failed to handle TCP connection: {}", e);
}
});
}
Err(e) => {
error!("Failed to accept connection: {}", e);
}
}
}
}
async fn handle_tcp_connection(
&self,
mut tcp_stream: TcpStream,
agent_id: String,
) -> Result<()> {
let (close_tx, close_rx) = tokio::sync::oneshot::channel();
let max_read_timeout = Duration::from_secs(1800); let setup_result = async {
let (peer_connection, data_channel) = self.create_webrtc_connection(&agent_id).await?;
let pc_clone = peer_connection.clone();
tokio::spawn(async move {
while let Some(event) = pc_clone.recv().await {
match event {
PeerConnectionEvent::DataChannel(dc) => {
tracing::debug!(
"CliClient PC Event: DataChannel: id={}, label={}",
dc.id,
dc.label
);
}
_ => {}
}
}
});
let pc_monitor = peer_connection.clone();
tokio::spawn(async move {
let mut state_rx = pc_monitor.subscribe_peer_state();
while let Ok(()) = state_rx.changed().await {
let state = *state_rx.borrow();
match state {
rustrtc::PeerConnectionState::Disconnected
| rustrtc::PeerConnectionState::Failed
| rustrtc::PeerConnectionState::Closed => {
if let Some(reason) = pc_monitor.disconnect_reason() {
tracing::warn!(
"WebRTC connection ended: {} (state: {:?})",
reason,
state
);
} else {
tracing::warn!("WebRTC connection ended: state {:?}", state);
}
break;
}
_ => {
tracing::debug!("Peer connection state: {:?}", state);
}
}
}
});
if let Err(_) = tokio::time::timeout(
Duration::from_secs(30),
peer_connection.wait_for_connected(),
)
.await
{
return Err(anyhow!("WebRTC connection timeout"));
}
peer_connection.wait_for_connected().await?;
let (open_tx, open_rx) = tokio::sync::oneshot::channel();
let (msg_tx, msg_rx) = tokio::sync::mpsc::unbounded_channel();
let dc_clone = data_channel.clone();
tokio::spawn(async move {
let mut open_tx = Some(open_tx);
while let Some(event) = dc_clone.recv().await {
match event {
DataChannelEvent::Open => {
if let Some(tx) = open_tx.take() {
let _ = tx.send(());
}
}
DataChannelEvent::Message(data) => {
let _ = msg_tx.send(data);
}
DataChannelEvent::Close => {
let _ = close_tx.send(());
break;
}
}
}
});
if let Err(_) = tokio::time::timeout(Duration::from_secs(10), open_rx).await {
return Err(anyhow!("Data channel open timeout"));
}
Ok((peer_connection, data_channel, msg_rx))
}
.await;
let (peer_connection, data_channel, mut msg_rx) = match setup_result {
Ok(res) => res,
Err(e) => {
let msg = format!("RPORT_SETUP_ERROR: {}\n", e);
error!("{}", msg);
let _ = tcp_stream.write_all(msg.as_bytes()).await;
let _ = tcp_stream.flush().await;
tokio::time::sleep(Duration::from_millis(500)).await;
return Err(e);
}
};
let (mut tcp_read, mut tcp_write) = tcp_stream.into_split();
let pc_clone = peer_connection.clone();
let dc_id = data_channel.id;
let tcp_to_webrtc = async move {
let mut buffer = [0u8; 1024];
loop {
let r = tokio::time::timeout(max_read_timeout, tcp_read.read(&mut buffer)).await?;
match r {
Ok(0) => {
info!("TCP connection closed by client");
break;
}
Ok(n) => {
let data = &buffer[..n];
if let Err(e) = pc_clone.send_data(dc_id, data).await {
error!("Failed to send data through WebRTC: {}", e);
break;
}
}
Err(e) => {
error!("Failed to read from TCP: {}", e);
break;
}
}
}
Ok::<(), anyhow::Error>(())
};
let webrtc_to_tcp = async move {
while let Some(data) = msg_rx.recv().await {
if let Err(e) = tcp_write.write_all(&data).await {
error!("Failed to write to TCP: {}", e);
break;
}
if let Err(e) = tcp_write.flush().await {
error!("Failed to flush TCP: {}", e);
break;
}
}
};
tokio::select! {
_ = close_rx => {
let reason_str = peer_connection
.disconnect_reason()
.map(|r| format!("{}", r))
.unwrap_or_else(|| "normal close".to_string());
info!("Data channel closed (reason: {})", reason_str);
}
_ = tcp_to_webrtc => {
info!("TCP to WebRTC forwarding ended");
}
_ = webrtc_to_tcp => {
info!("WebRTC to TCP forwarding ended");
}
}
peer_connection.close();
Ok(())
}
async fn create_webrtc_connection(
&self,
agent_id: &str,
) -> Result<(Arc<PeerConnection>, Arc<DataChannel>)> {
info!("Creating WebRTC peer connection for agent: {}", agent_id);
let peer_connection = self.create_peer_connection().await?;
let data_channel_config = DataChannelConfig {
ordered: true,
..Default::default()
};
let data_channel =
peer_connection.create_data_channel("port-forward", Some(data_channel_config))?;
let offer = peer_connection.create_offer().await?;
peer_connection.set_local_description(offer.clone())?;
if let Err(_) = tokio::time::timeout(
Duration::from_secs(3),
peer_connection.wait_for_gathering_complete(),
)
.await
{
info!("ICE gathering timed out, proceeding with gathered candidates");
}
let offer = peer_connection
.local_description()
.ok_or_else(|| anyhow!("Failed to get local description after ICE gathering"))?;
let sdp = offer.to_sdp_string();
let offer_sdp = sdp
.lines()
.filter(|l| !l.contains("IP6") && !l.contains("::"))
.collect::<Vec<_>>()
.join("\r\n");
let offer_msg = OfferMessage {
id: agent_id.to_string(),
offer: offer_sdp,
};
info!("Sending offer to signaling server...");
let url = format!("{}/rport/offer?token={}", self.server_url, self.token);
let response = self.client.post(&url).json(&offer_msg).send().await?;
if !response.status().is_success() {
return Err(anyhow!("Failed to send offer: {}", response.status()));
}
let response_body: Value = response.json().await?;
let answer_sdp = response_body["answer"]
.as_str()
.ok_or_else(|| anyhow!("Missing answer in response"))?;
let answer = SessionDescription::parse(SdpType::Answer, &answer_sdp)?;
peer_connection.set_remote_description(answer).await?;
info!("WebRTC handshake completed successfully");
Ok((peer_connection, data_channel))
}
async fn create_webrtc_connection_silent(
&self,
agent_id: &str,
) -> Result<(Arc<PeerConnection>, Arc<DataChannel>)> {
let peer_connection = self.create_peer_connection().await?;
let pc_clone = peer_connection.clone();
tokio::spawn(async move {
while let Some(event) = pc_clone.recv().await {
match event {
PeerConnectionEvent::DataChannel(_) => {}
_ => {}
}
}
});
let data_channel_config = DataChannelConfig {
ordered: true,
..Default::default()
};
let data_channel =
peer_connection.create_data_channel("port-forward", Some(data_channel_config))?;
let offer = peer_connection.create_offer().await?;
peer_connection.set_local_description(offer.clone())?;
if let Err(_) = tokio::time::timeout(
Duration::from_secs(3),
peer_connection.wait_for_gathering_complete(),
)
.await
{
info!("ICE gathering timed out, proceeding with gathered candidates");
}
let offer = peer_connection
.local_description()
.ok_or_else(|| anyhow!("Failed to get local description after ICE gathering"))?;
let offer_sdp = offer.to_sdp_string();
let url = format!("{}/rport/offer?token={}", self.server_url, self.token);
tracing::debug!(
"create_webrtc_connection_silent: sending offer to {} \n {}",
url,
offer_sdp
);
let offer_msg = OfferMessage {
id: agent_id.to_string(),
offer: offer_sdp,
};
let response = self
.client
.post(&url)
.timeout(Duration::from_secs(10))
.json(&offer_msg)
.send()
.await?;
if !response.status().is_success() {
tracing::error!("Failed to send offer: {}", response.status());
return Err(anyhow!("Failed to send offer: {}", response.status()));
}
let response_body: Value = response.json().await?;
let answer_sdp = response_body["answer"]
.as_str()
.ok_or_else(|| anyhow!("Missing answer in response"))?;
let answer = SessionDescription::parse(SdpType::Answer, &answer_sdp)?;
peer_connection.set_remote_description(answer).await?;
Ok((peer_connection, data_channel))
}
async fn create_peer_connection(&self) -> Result<Arc<PeerConnection>> {
self.webrtc_config.create_peer_connection().await
}
}
impl Clone for CliClient {
fn clone(&self) -> Self {
Self {
server_url: self.server_url.clone(),
token: self.token.clone(),
client: Client::new(),
webrtc_config: self.webrtc_config.clone(),
}
}
}
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::OfferMessage;
    use rustrtc::{PeerConnection, RtcConfiguration};
    use std::time::Duration;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    /// Reads one complete HTTP/1.1 request: the headers plus a `Content-Length` body.
    ///
    /// A single `read` may return only part of the request (e.g. the headers without
    /// the body), so this keeps reading until the advertised body length is buffered.
    /// Returns `None` if the peer closes or a read fails first.
    async fn read_http_request(socket: &mut tokio::net::TcpStream) -> Option<String> {
        let mut request: Vec<u8> = Vec::new();
        let mut chunk = [0u8; 8192];
        loop {
            if let Some(header_end) = request.windows(4).position(|w| w == b"\r\n\r\n") {
                let headers = String::from_utf8_lossy(&request[..header_end]);
                let content_length = headers
                    .lines()
                    .filter_map(|line| line.split_once(':'))
                    .find(|(name, _)| name.trim().eq_ignore_ascii_case("content-length"))
                    .and_then(|(_, value)| value.trim().parse::<usize>().ok())
                    .unwrap_or(0);
                if request.len() >= header_end + 4 + content_length {
                    return Some(String::from_utf8_lossy(&request).into_owned());
                }
            }
            match socket.read(&mut chunk).await {
                Ok(0) | Err(_) => return None,
                Ok(n) => request.extend_from_slice(&chunk[..n]),
            }
        }
    }

    /// Returns a localhost port that was free a moment ago.
    ///
    /// The probe listener is dropped immediately, so another process could still take
    /// the port. That is far less likely than a clash on a hard-coded port.
    fn pick_free_port() -> std::io::Result<u16> {
        let probe = std::net::TcpListener::bind("127.0.0.1:0")?;
        Ok(probe.local_addr()?.port())
    }

    /// Connects to `127.0.0.1:port`, retrying every 50 ms for up to 5 s while the
    /// listener is still starting.
    async fn connect_with_retry(port: u16) -> Result<tokio::net::TcpStream> {
        let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
        loop {
            match tokio::net::TcpStream::connect(("127.0.0.1", port)).await {
                Ok(stream) => return Ok(stream),
                Err(e) if tokio::time::Instant::now() >= deadline => return Err(e.into()),
                Err(_) => tokio::time::sleep(Duration::from_millis(50)).await,
            }
        }
    }

    /// End-to-end check: a TCP client connecting to the port forwarder causes a data
    /// channel to open on a mock agent, with signaling served by a minimal HTTP stub.
    #[tokio::test]
    async fn test_connect_port_forward_integration() -> Result<()> {
        let _ = tracing_subscriber::fmt()
            .with_env_filter("debug")
            .try_init();
        let config = RtcConfiguration::default();
        let agent_pc = Arc::new(PeerConnection::new(config));
        agent_pc.add_transceiver(
            rustrtc::MediaKind::Application,
            rustrtc::TransceiverDirection::SendRecv,
        );

        // Minimal signaling stub: ICE-server list and offer/answer exchange.
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let local_addr = listener.local_addr()?;
        let server_url = format!("http://{}", local_addr);
        let agent_pc_clone = agent_pc.clone();
        tokio::spawn(async move {
            loop {
                let (mut socket, _) = match listener.accept().await {
                    Ok(conn) => conn,
                    Err(_) => break,
                };
                let agent_pc = agent_pc_clone.clone();
                tokio::spawn(async move {
                    let req = match read_http_request(&mut socket).await {
                        Some(req) => req,
                        None => return,
                    };
                    if req.contains("GET /rport/iceservers") {
                        let response = "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: 2\r\n\r\n[]";
                        socket.write_all(response.as_bytes()).await.unwrap();
                        return;
                    }
                    if req.contains("POST /rport/offer") {
                        if let Some(idx) = req.find("\r\n\r\n") {
                            let body = &req[idx + 4..];
                            if let Ok(offer_msg) = serde_json::from_str::<OfferMessage>(body) {
                                let offer =
                                    SessionDescription::parse(SdpType::Offer, &offer_msg.offer)
                                        .unwrap();
                                agent_pc.set_remote_description(offer).await.unwrap();
                                let answer = agent_pc.create_answer().await.unwrap();
                                agent_pc.set_local_description(answer.clone()).unwrap();
                                agent_pc.wait_for_gathering_complete().await;
                                let answer = agent_pc.local_description().unwrap();
                                let answer_sdp = answer.to_sdp_string();
                                let response_json = serde_json::json!({
                                    "uuid": uuid::Uuid::new_v4(),
                                    "offer": offer_msg.offer,
                                    "answer": answer_sdp
                                });
                                let response_body = response_json.to_string();
                                let response = format!(
                                    "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
                                    response_body.len(),
                                    response_body
                                );
                                socket.write_all(response.as_bytes()).await.unwrap();
                            }
                        }
                    }
                });
            }
        });

        let forward_port = pick_free_port()?;
        let client = CliClient::new(server_url, "test-token".to_string(), None);
        let client_clone = client.clone();
        tokio::spawn(async move {
            if let Err(e) = client_clone
                .connect_port_forward("gpu03".to_string(), forward_port)
                .await
            {
                eprintln!("connect_port_forward failed: {}", e);
            }
        });

        info!("Connecting to 127.0.0.1:{}", forward_port);
        // Poll until the forwarder is listening instead of sleeping a fixed amount.
        let _stream = connect_with_retry(forward_port).await?;

        let (dc_tx, dc_rx) = tokio::sync::oneshot::channel();
        let agent_pc_clone = agent_pc.clone();
        tokio::spawn(async move {
            let mut dc_tx = Some(dc_tx);
            while let Some(event) = agent_pc_clone.recv().await {
                if let PeerConnectionEvent::DataChannel(dc) = event {
                    if let Some(tx) = dc_tx.take() {
                        let _ = tx.send(dc);
                    }
                }
            }
        });

        info!("Waiting for Agent PC connection...");
        tokio::time::timeout(Duration::from_secs(10), agent_pc.wait_for_connected())
            .await
            .map_err(|_| anyhow!("agent peer connection did not connect within 10s"))??;
        info!("Agent PC connected!");
        let dc = tokio::time::timeout(Duration::from_secs(5), dc_rx).await??;

        let (open_tx, open_rx) = tokio::sync::oneshot::channel();
        let dc_clone = dc.clone();
        tokio::spawn(async move {
            let mut open_tx = Some(open_tx);
            while let Some(event) = dc_clone.recv().await {
                if let DataChannelEvent::Open = event {
                    if let Some(tx) = open_tx.take() {
                        let _ = tx.send(());
                    }
                }
            }
        });
        tokio::time::timeout(Duration::from_secs(5), open_rx).await??;
        info!("DataChannel Open!");
        Ok(())
    }
}