1use std::{
2 fmt::{Debug, Display},
3 path::Path,
4 sync::Arc,
5};
6
7use tokio::{
8 io::{AsyncRead, AsyncWrite, BufStream},
9 net::{TcpStream, ToSocketAddrs},
10};
11pub use tokio_rustls::rustls::ServerName;
12use tokio_rustls::{client::TlsStream, TlsConnector};
13use webrtc::{
14 api::{
15 interceptor_registry::register_default_interceptors, media_engine::MediaEngine, APIBuilder,
16 API,
17 },
18 data_channel::{data_channel_init::RTCDataChannelInit, RTCDataChannel},
19 ice_transport::{ice_candidate::RTCIceCandidate, ice_server::RTCIceServer},
20 interceptor::registry::Registry,
21 peer_connection::{configuration::RTCConfiguration, RTCPeerConnection},
22};
23
24use crate::{
25 tls::{new_tls_connector, TlsInitError},
26 transport::{IDUpgradeTransport, RecvError, StreamTransport, UpgradeTransport},
27 RTCMessage, STUN_SERVERS,
28};
29
30pub struct UpgradeWebRTCClient<C: UpgradeTransport> {
31 client: C,
32 api: API,
33 config: RTCConfiguration,
34}
35
36impl<C: Debug + UpgradeTransport> Debug for UpgradeWebRTCClient<C> {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 f.debug_struct("UpgradeWebRTCClient")
39 .field("client", &self.client)
40 .finish()
41 }
42}
43
44pub struct PeerAndChannels<'a> {
45 pub peer: RTCPeerConnection,
46 pub channels: Vec<(&'a str, Arc<RTCDataChannel>)>,
47}
48
49#[derive(Debug)]
50pub enum ClientError<DE> {
51 WebRTCError(webrtc::Error),
52 IOError(std::io::Error),
53 DeserializeError(DE),
54 UnexpectedMessage,
55}
56
57impl<DE> From<webrtc::Error> for ClientError<DE> {
58 fn from(value: webrtc::Error) -> Self {
59 ClientError::WebRTCError(value)
60 }
61}
62
63impl<DE> From<std::io::Error> for ClientError<DE> {
64 fn from(value: std::io::Error) -> Self {
65 ClientError::IOError(value)
66 }
67}
68
69impl<DE> From<RecvError<DE>> for ClientError<DE> {
70 fn from(value: RecvError<DE>) -> Self {
71 match value {
72 RecvError::DeserializeError(e) => Self::DeserializeError(e),
73 RecvError::IOError(e) => Self::IOError(e),
74 }
75 }
76}
77
78impl<DE: Display> Display for ClientError<DE> {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 match self {
81 ClientError::WebRTCError(e) => write!(f, "{e}"),
82 ClientError::IOError(e) => write!(f, "{e}"),
83 ClientError::DeserializeError(e) => write!(f, "{e}"),
84 ClientError::UnexpectedMessage => write!(f, "Unexpected WebRTC SDP type"),
85 }
86 }
87}
88
89impl<DE: Display + Debug> std::error::Error for ClientError<DE> {}
90
91impl<C: UpgradeTransport> UpgradeWebRTCClient<C> {
92 pub fn new(client: C) -> Self {
93 let mut m = MediaEngine::default();
94 m.register_default_codecs()
95 .expect("Default codecs should have registered safely");
96
97 let mut registry = Registry::new();
98
99 registry = register_default_interceptors(registry, &mut m)
101 .expect("Default interceptors should have registered safely");
102
103 Self {
104 client,
105 api: APIBuilder::new()
106 .with_media_engine(m)
107 .with_interceptor_registry(registry)
108 .build(),
109 config: RTCConfiguration {
110 ice_servers: vec![RTCIceServer {
111 urls: STUN_SERVERS.map(Into::into).to_vec(),
112 ..Default::default()
113 }],
114 ..Default::default()
115 },
116 }
117 }
118
119 pub async fn upgrade<'a>(
120 &mut self,
121 channel_configs: impl IntoIterator<Item = (&'a str, RTCDataChannelInit)>,
122 ) -> Result<PeerAndChannels<'a>, ClientError<C::DeserializationError>> {
123 let peer = self.api.new_peer_connection(self.config.clone()).await?;
124 let mut channels = vec![];
125
126 for (label, option) in channel_configs {
127 channels.push((label, peer.create_data_channel(label, Some(option)).await?));
128 }
129
130 let offer = peer.create_offer(None).await?;
131 self.client.send_obj(&offer).await?;
132 peer.set_local_description(offer).await?;
133 let mut ices = vec![];
134 let answer = loop {
135 let msg: RTCMessage = self.client.recv_obj().await?;
136 match msg {
137 RTCMessage::SDPAnswer(x) => break x,
138 RTCMessage::ICE(x) => ices.push(x),
139 }
140 };
141 peer.set_remote_description(answer).await?;
142 for ice in ices {
143 peer.add_ice_candidate(ice).await?;
144 }
145
146 let (ice_sender, mut ice_receiver) = tokio::sync::mpsc::channel(3);
147
148 peer.on_ice_candidate(Box::new(move |c: Option<RTCIceCandidate>| {
149 let ice_sender = ice_sender.clone();
150 Box::pin(async move {
151 let _ = ice_sender.send(c).await;
152 })
153 }));
154
155 let mut done_sending_ice = false;
156 let mut done_receiving_ice = false;
157
158 loop {
159 tokio::select! {
160 ice_to_send = ice_receiver.recv() => {
161 let ice_to_send = ice_to_send.unwrap();
162 self.client.send_obj(&ice_to_send).await?;
163 if ice_to_send.is_none() {
164 done_sending_ice = true;
165 if done_receiving_ice {
166 break
167 }
168 };
169 }
170 received_msg = self.client.recv_obj::<Option<RTCMessage>>() => {
171 let received_msg = received_msg?;
172 let received_ice = match received_msg {
173 Some(RTCMessage::ICE(x)) => Some(x),
174 None => None,
175 _ => return Err(ClientError::UnexpectedMessage)
176 };
177 let Some(received_ice) = received_ice else {
178 done_receiving_ice = true;
179 if done_sending_ice {
180 break
181 }
182 continue
183 };
184 peer.add_ice_candidate(received_ice).await?;
185 }
186 }
187 }
188
189 Ok(PeerAndChannels { peer, channels })
190 }
191}
192
193impl<C> UpgradeWebRTCClient<StreamTransport<C>>
194where
195 C: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static + IDUpgradeTransport,
196{
197 pub async fn add_tls(
198 self,
199 domain: ServerName,
200 connector: &TlsConnector,
201 ) -> std::io::Result<UpgradeWebRTCClient<StreamTransport<TlsStream<C>>>> {
202 let stream = connector.connect(domain, self.client.stream).await?;
203 Ok(UpgradeWebRTCClient {
204 client: StreamTransport::from(stream),
205 api: self.api,
206 config: self.config,
207 })
208 }
209
210 pub async fn add_tls_from_config(
211 self,
212 domain: ServerName,
213 root_cert_path: Option<impl AsRef<Path>>,
214 ) -> Result<UpgradeWebRTCClient<StreamTransport<TlsStream<C>>>, TlsInitError> {
215 let connector = new_tls_connector(root_cert_path)?;
216 self.add_tls(domain, &connector).await.map_err(Into::into)
217 }
218}
219
220pub async fn client_new_tcp(
221 addr: impl ToSocketAddrs,
222) -> std::io::Result<UpgradeWebRTCClient<StreamTransport<BufStream<TcpStream>>>> {
223 Ok(UpgradeWebRTCClient::new(
224 BufStream::new(TcpStream::connect(addr).await?).into(),
225 ))
226}
227
/// Connects to the local socket `addr` and wraps the buffered stream in a
/// new [`UpgradeWebRTCClient`].
///
/// Only compiled when the `local_sockets` feature is enabled.
///
/// # Errors
///
/// Returns the I/O error from the local socket connection attempt.
#[cfg(feature = "local_sockets")]
pub async fn client_new_local_socket<'a>(
    addr: impl interprocess::local_socket::ToLocalSocketName<'a>,
) -> std::io::Result<
    UpgradeWebRTCClient<
        StreamTransport<
            BufStream<
                tokio_util::compat::Compat<interprocess::local_socket::tokio::LocalSocketStream>,
            >,
        >,
    >,
> {
    use interprocess::local_socket::tokio::LocalSocketStream;
    use tokio_util::compat::FuturesAsyncWriteCompatExt;

    let socket = LocalSocketStream::connect(addr).await?;
    let transport: StreamTransport<_> = BufStream::new(socket.compat_write()).into();
    Ok(UpgradeWebRTCClient::new(transport))
}