iterm2_client/
connection.rs1use crate::auth::{self, AppleScriptRunner, Credentials, OsascriptRunner};
8use crate::error::{self, Error, Result};
9use crate::proto;
10use crate::transport;
11use futures_util::{SinkExt, StreamExt};
12use prost::Message;
13use std::collections::HashMap;
14use std::sync::atomic::{AtomicI64, Ordering};
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::io::{AsyncRead, AsyncWrite};
18use tokio::sync::{broadcast, oneshot, Mutex};
19use tokio_tungstenite::tungstenite;
20
/// Deadline used by `Connection::call` when the caller does not pick one.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Capacity of the broadcast channel that fans notifications out to subscribers.
const NOTIFICATION_CHANNEL_SIZE: usize = 1_024;
/// Upper bound on requests that may be awaiting a response at the same time.
const MAX_PENDING_REQUESTS: usize = 4_096;
24
25type PendingMap = HashMap<i64, oneshot::Sender<proto::ServerOriginatedMessage>>;
26
27pub struct Connection<S> {
32 inner: Arc<Inner<S>>,
33 shared: Arc<Shared>,
34}
35
36struct Inner<S> {
37 sink: Mutex<transport::WsSink<S>>,
38 _dispatch_handle: tokio::task::JoinHandle<()>,
39}
40
41struct Shared {
42 pending: Mutex<PendingMap>,
43 notification_tx: broadcast::Sender<proto::Notification>,
44 next_id: AtomicI64,
45}
46
47impl<S> Clone for Connection<S> {
48 fn clone(&self) -> Self {
49 Connection {
50 inner: Arc::clone(&self.inner),
51 shared: Arc::clone(&self.shared),
52 }
53 }
54}
55
56impl Connection<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>> {
57 pub async fn connect_tcp(app_name: &str) -> Result<Self> {
61 let credentials = auth::resolve_credentials(app_name, &OsascriptRunner)?;
62 let (sink, source) = transport::connect_tcp(&credentials, app_name).await?;
63 Ok(Self::from_split(sink, source))
64 }
65
66 pub async fn connect_tcp_with_credentials(
68 app_name: &str,
69 credentials: &Credentials,
70 ) -> Result<Self> {
71 let (sink, source) = transport::connect_tcp(credentials, app_name).await?;
72 Ok(Self::from_split(sink, source))
73 }
74}
75
76impl Connection<tokio::net::UnixStream> {
77 pub async fn connect(app_name: &str) -> Result<Self> {
84 let credentials = auth::resolve_credentials(app_name, &OsascriptRunner)?;
85 let (sink, source) = transport::connect_unix(&credentials, app_name).await?;
86 Ok(Self::from_split(sink, source))
87 }
88
89 pub async fn connect_unix(app_name: &str) -> Result<Self> {
91 let credentials = auth::resolve_credentials(app_name, &OsascriptRunner)?;
92 let (sink, source) = transport::connect_unix(&credentials, app_name).await?;
93 Ok(Self::from_split(sink, source))
94 }
95
96 pub async fn connect_with_runner(
98 app_name: &str,
99 runner: &dyn AppleScriptRunner,
100 ) -> Result<Self> {
101 let credentials = auth::resolve_credentials(app_name, runner)?;
102 let (sink, source) = transport::connect_unix(&credentials, app_name).await?;
103 Ok(Self::from_split(sink, source))
104 }
105
106 pub async fn connect_with_credentials(
108 app_name: &str,
109 credentials: &Credentials,
110 ) -> Result<Self> {
111 let (sink, source) = transport::connect_unix(credentials, app_name).await?;
112 Ok(Self::from_split(sink, source))
113 }
114}
115
116impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Connection<S> {
117 pub fn from_split(sink: transport::WsSink<S>, source: transport::WsSource<S>) -> Self {
121 let (notification_tx, _) = broadcast::channel(NOTIFICATION_CHANNEL_SIZE);
122 let shared = Arc::new(Shared {
123 pending: Mutex::new(HashMap::new()),
124 notification_tx: notification_tx.clone(),
125 next_id: AtomicI64::new(1),
126 });
127
128 let shared_for_dispatch = Arc::clone(&shared);
129 let dispatch_handle = tokio::spawn(dispatch_loop(source, shared_for_dispatch));
130
131 let inner = Arc::new(Inner {
132 sink: Mutex::new(sink),
133 _dispatch_handle: dispatch_handle,
134 });
135
136 Connection { inner, shared }
137 }
138
139 pub async fn call(
141 &self,
142 request: proto::ClientOriginatedMessage,
143 ) -> Result<proto::ServerOriginatedMessage> {
144 self.call_with_timeout(request, DEFAULT_TIMEOUT).await
145 }
146
147 pub async fn call_with_timeout(
149 &self,
150 mut request: proto::ClientOriginatedMessage,
151 timeout: Duration,
152 ) -> Result<proto::ServerOriginatedMessage> {
153 let id = self.shared.next_id.fetch_add(1, Ordering::SeqCst);
154
155 let (tx, rx) = oneshot::channel();
156 {
157 let mut pending = self.shared.pending.lock().await;
158 if pending.len() >= MAX_PENDING_REQUESTS {
160 return Err(Error::Api(
161 "Too many pending requests (max 4096)".to_string(),
162 ));
163 }
164 pending.insert(id, tx);
165 }
166
167 request.id = Some(id);
168
169 let mut buf = Vec::new();
171 request
172 .encode(&mut buf)
173 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::Other, e)))?;
174
175 let send_result = {
176 let mut sink = self.inner.sink.lock().await;
177 SinkExt::<tungstenite::Message>::send(
178 &mut *sink,
179 tungstenite::Message::Binary(buf.into()),
180 )
181 .await
182 };
183
184 if let Err(e) = send_result {
185 let mut pending = self.shared.pending.lock().await;
187 pending.remove(&id);
188 return Err(Error::WebSocket(e));
189 }
190
191 match tokio::time::timeout(timeout, rx).await {
198 Ok(Ok(response)) => {
199 if let Some(proto::server_originated_message::Submessage::Error(err_str)) =
201 &response.submessage
202 {
203 return Err(error::api_error(err_str));
204 }
205 Ok(response)
206 }
207 Ok(Err(_)) => {
208 let mut pending = self.shared.pending.lock().await;
210 pending.remove(&id);
211 Err(Error::ConnectionClosed)
212 }
213 Err(_) => {
214 let mut pending = self.shared.pending.lock().await;
219 pending.remove(&id);
220 Err(Error::Timeout(timeout))
221 }
222 }
223 }
224
225 pub fn subscribe_notifications(&self) -> broadcast::Receiver<proto::Notification> {
230 self.shared.notification_tx.subscribe()
231 }
232}
233
234async fn dispatch_loop<S: AsyncRead + AsyncWrite + Unpin>(
235 mut source: transport::WsSource<S>,
236 shared: Arc<Shared>,
237) {
238 let mut decode_errors: u32 = 0;
239 const MAX_CONSECUTIVE_DECODE_ERRORS: u32 = 100;
240
241 while let Some(msg_result) = source.next().await {
242 let msg = match msg_result {
243 Ok(tungstenite::Message::Binary(data)) => {
244 match proto::ServerOriginatedMessage::decode(data.as_ref()) {
245 Ok(m) => {
246 decode_errors = 0;
247 m
248 }
249 Err(_) => {
250 decode_errors += 1;
251 if decode_errors >= MAX_CONSECUTIVE_DECODE_ERRORS {
252 break;
255 }
256 continue;
257 }
258 }
259 }
260 Ok(tungstenite::Message::Close(_)) => break,
261 Ok(_) => continue,
262 Err(_) => break,
263 };
264
265 if msg.id.is_none() {
267 if let Some(proto::server_originated_message::Submessage::Notification(notif)) =
268 msg.submessage
269 {
270 let _ = shared.notification_tx.send(notif);
271 }
272 continue;
273 }
274
275 if let Some(id) = msg.id {
276 let mut pending = shared.pending.lock().await;
277 if let Some(sender) = pending.remove(&id) {
278 let _ = sender.send(msg);
279 }
280 }
281 }
282}