1mod responses;
2mod types;
3
4use anyhow::{bail, Result};
5use futures_util::{
6 stream::{SplitSink, SplitStream},
7 SinkExt, StreamExt,
8};
9pub use responses::*;
10use serde::{Deserialize, Serialize};
11use std::{
12 collections::HashMap,
13 sync::{
14 atomic::{AtomicU64, Ordering},
15 Arc, Mutex,
16 },
17};
18use std::{sync::mpsc, time::Duration};
19use tokio::net::TcpStream;
20use tokio::task::JoinHandle;
21use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
22
23type ChatWebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
24
25type CorrId = String;
26type MessageQueue = mpsc::Receiver<ChatSrvResponse>;
27
28#[derive(Debug)]
29pub struct ChatClient {
30 uri: String,
31 command_counter: AtomicU64,
32 timeout: Duration,
33 write_stream: SplitSink<ChatWebSocket, Message>,
34 listener_handle: JoinHandle<()>,
35 command_waiters: Arc<Mutex<HashMap<CorrId, mpsc::Sender<ChatResponse>>>>,
36 message_queue: MessageQueue, }
38
39#[derive(Serialize, Debug)]
40#[serde(rename_all = "camelCase")]
41struct ChatSrvRequest {
42 corr_id: CorrId,
43 cmd: String,
44}
45
46#[derive(Serialize, Deserialize, Debug, Clone)]
47#[serde(rename_all = "camelCase")]
48pub struct ChatSrvResponse {
49 pub corr_id: Option<CorrId>,
50 pub resp: ChatResponse,
51}
52
53impl ChatClient {
54 pub async fn start(uri: &str) -> Result<ChatClient> {
55 log::debug!("Connecting to SimpleX chat client at URI: {}", uri);
56 let (ws_stream, resp) = connect_async(uri).await?;
57
58 let (write_stream, read_stream) = ws_stream.split();
64
65 log::debug!(
66 "Successfully connected to SimpleX chat client with response: {:?}",
67 resp
68 );
69
70 let command_waiters = Arc::new(Mutex::new(HashMap::new()));
71 let command_waiters_copy = command_waiters.clone();
72 let uri_copy = uri.to_owned();
73 let (tx, rx) = mpsc::channel::<ChatSrvResponse>();
74 let listener_handle = tokio::spawn(async {
75 Self::message_listener(read_stream, uri_copy, command_waiters_copy, tx).await
76 });
77
78 let client = ChatClient {
79 uri: uri.to_owned(),
80 command_counter: AtomicU64::new(0),
81 write_stream,
82 listener_handle,
83 command_waiters,
84 message_queue: rx,
85 timeout: Duration::from_millis(3000),
86 };
87
88 Ok(client)
89 }
90
91 pub async fn message_listener(
92 read_stream: SplitStream<ChatWebSocket>,
93 uri: String,
94 command_waiters: Arc<Mutex<HashMap<CorrId, mpsc::Sender<ChatResponse>>>>,
95 message_queue: mpsc::Sender<ChatSrvResponse>,
96 ) {
97 read_stream
98 .for_each(|message| async {
99 let message = message.unwrap().into_text().unwrap();
100 log::debug!("New message for client '{}': {:?}", uri, message);
101
102 let srv_resp = serde_json::from_str::<ChatSrvResponse>(&message).unwrap();
103
104 log::trace!("Deserialized server resposne: {:?}", srv_resp);
105
106 match srv_resp.corr_id {
107 Some(ref corr_id) => {
108 let command_waiters = command_waiters.lock().unwrap();
111 match command_waiters.get(corr_id) {
112 Some(chan) => {
113 chan.send(srv_resp.resp).unwrap();
114 }
115 None => message_queue.send(srv_resp).unwrap(),
116 }
117 }
118 None => {
119 message_queue.send(srv_resp).unwrap()
122 }
123 };
124 })
125 .await;
126 }
127
128 pub async fn send_command(&mut self, command: &str) -> Result<ChatResponse> {
129 let corr_id = (self.command_counter.fetch_add(1, Ordering::Relaxed) + 1).to_string();
130
131 let (tx, rx) = mpsc::channel::<ChatResponse>();
133
134 {
135 let mut command_waiters = self.command_waiters.lock().unwrap();
136 command_waiters.insert(corr_id.clone(), tx);
137 log::trace!(
138 "Inserted '{}' to command waiters of client '{}': {:?}",
139 corr_id,
140 self.uri,
141 command_waiters
142 );
143 }
144
145 log::debug!(
146 "Sending command `{}` ({}) to '{}'",
147 command,
148 corr_id,
149 self.uri
150 );
151
152 let srv_req = ChatSrvRequest {
153 corr_id: corr_id.to_string(),
154 cmd: command.to_owned(),
155 };
156 let cmd_json = serde_json::to_string(&srv_req)?;
157 log::trace!("Serialized command: {}", cmd_json);
158
159 self.write_stream.send(Message::Text(cmd_json)).await?;
160
161 log::debug!("Command '{}' send successfully to '{}'", corr_id, self.uri);
162
163 log::debug!(
164 "Waiting for response to command '{}' on client '{}'... (timeout = {:?})",
165 corr_id,
166 self.uri,
167 self.timeout
168 );
169
170 let resp = rx.recv_timeout(self.timeout);
171
172 {
173 let mut command_waiters = self.command_waiters.lock().unwrap();
174 command_waiters.remove(&corr_id);
175 log::trace!(
176 "Removed '{}' from command waiters of client '{}': {:?}",
177 corr_id,
178 self.uri,
179 command_waiters
180 );
181 }
182
183 let resp = resp?;
184
185 Ok(resp)
186 }
187
188 pub async fn next_message(&mut self) -> Result<ChatSrvResponse> {
189 Ok(self.message_queue.recv()?)
190 }
191
192 pub async fn api_get_active_user(&mut self) -> Result<User> {
194 let resp = self.send_command("/u").await?;
195 let ChatResponse::ActiveUser { user, .. } = resp else {
196 bail!("The command response does not match the expected type");
197 };
198
199 Ok(user)
200 }
201
202 pub async fn api_chats(&mut self) -> Result<Vec<Chat>> {
203 let resp = self.send_command("/chats").await?;
204 let ChatResponse::Chats { chats, .. } = resp else {
205 bail!("The command response does not match the expected type");
206 };
207
208 Ok(chats)
209 }
210
211 pub async fn api_get_user_address(&mut self) -> Result<Option<String>> {
212 let resp = self.send_command("/show_address").await?;
213 match resp {
214 ChatResponse::UserContactLink { contact_link, .. } => {
215 Ok(Some(contact_link.conn_req_contact))
216 }
217 ChatResponse::ChatCmdError { .. } => Ok(None),
218 _ => {
219 bail!("The command response does not match the expected type");
220 }
221 }
222 }
223
224 pub async fn api_create_user_address(&mut self) -> Result<String> {
225 let resp = self.send_command("/address").await?;
226 let ChatResponse::UserContactLinkCreated {
227 conn_req_contact, ..
228 } = resp
229 else {
230 bail!("The command response does not match the expected type");
231 };
232
233 Ok(conn_req_contact)
234 }
235}
236
237impl Drop for ChatClient {
238 fn drop(&mut self) {
239 self.listener_handle.abort();
240 }
241}