1use std::collections::HashMap;
2use std::env;
3use std::time::Duration;
4
5use eyre::{Result, bail, eyre};
6use reqwest::{
7 Response, StatusCode, Url,
8 header::{AUTHORIZATION, HeaderMap, USER_AGENT},
9};
10
11use atuin_common::{
12 api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION},
13 record::{EncryptedData, HostId, Record, RecordIdx},
14};
15use atuin_common::{
16 api::{
17 AddHistoryRequest, ChangePasswordRequest, CountResponse, DeleteHistoryRequest,
18 ErrorResponse, LoginRequest, LoginResponse, MeResponse, RegisterResponse,
19 SendVerificationResponse, StatusResponse, SyncHistoryResponse, VerificationTokenRequest,
20 VerificationTokenResponse,
21 },
22 record::RecordStatus,
23};
24
25use semver::Version;
26use time::OffsetDateTime;
27use time::format_description::well_known::Rfc3339;
28
29use crate::{history::History, sync::hash_str, utils::get_host_user};
30
/// User agent attached to outgoing API requests: `atuin/` followed by this
/// crate's version as reported by Cargo (`CARGO_PKG_VERSION`).
static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION"),);
32
/// HTTP client for the authenticated endpoints of an Atuin sync server.
///
/// Pairs the server's base address with a `reqwest::Client` that
/// [`Client::new`] builds with the session token, the Atuin version header,
/// the user agent and the configured timeouts applied to every request.
pub struct Client<'a> {
    /// Base address of the sync server; endpoint paths are joined onto it with `make_url`.
    sync_addr: &'a str,
    /// Underlying HTTP client carrying the default headers and timeouts.
    client: reqwest::Client,
}
37
38fn make_url(address: &str, path: &str) -> Result<String> {
39 let address = if address.ends_with("/") {
42 address
43 } else {
44 &format!("{address}/")
45 };
46
47 let path = path.strip_prefix("/").unwrap_or(path);
49
50 let url = Url::parse(address)
51 .map(|url| url.join(path))?
52 .map_err(|_| eyre!("invalid address"))?;
53
54 Ok(url.to_string())
55}
56
57pub async fn register(
58 address: &str,
59 username: &str,
60 email: &str,
61 password: &str,
62) -> Result<RegisterResponse> {
63 let mut map = HashMap::new();
64 map.insert("username", username);
65 map.insert("email", email);
66 map.insert("password", password);
67
68 let url = make_url(address, &format!("/user/{username}"))?;
69 let resp = reqwest::get(url).await?;
70
71 if resp.status().is_success() {
72 bail!("username already in use");
73 }
74
75 let url = make_url(address, "/register")?;
76 let client = reqwest::Client::new();
77 let resp = client
78 .post(url)
79 .header(USER_AGENT, APP_USER_AGENT)
80 .header(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION)
81 .json(&map)
82 .send()
83 .await?;
84 let resp = handle_resp_error(resp).await?;
85
86 if !ensure_version(&resp)? {
87 bail!("could not register user due to version mismatch");
88 }
89
90 let session = resp.json::<RegisterResponse>().await?;
91 Ok(session)
92}
93
94pub async fn login(address: &str, req: LoginRequest) -> Result<LoginResponse> {
95 let url = make_url(address, "/login")?;
96 let client = reqwest::Client::new();
97
98 let resp = client
99 .post(url)
100 .header(USER_AGENT, APP_USER_AGENT)
101 .json(&req)
102 .send()
103 .await?;
104 let resp = handle_resp_error(resp).await?;
105
106 if !ensure_version(&resp)? {
107 bail!("Could not login due to version mismatch");
108 }
109
110 let session = resp.json::<LoginResponse>().await?;
111 Ok(session)
112}
113
/// Ask `https://api.atuin.sh` for the most recently released Atuin version.
///
/// Only compiled with the `check-update` feature.
///
/// # Errors
/// Fails on transport errors, a non-success response (see
/// `handle_resp_error`), an undecodable index body, or a version string that
/// is not valid semver.
#[cfg(feature = "check-update")]
pub async fn latest_version() -> Result<Version> {
    use atuin_common::api::IndexResponse;

    let response = reqwest::Client::new()
        .get("https://api.atuin.sh")
        .header(USER_AGENT, APP_USER_AGENT)
        .send()
        .await?;
    let response = handle_resp_error(response).await?;

    let index: IndexResponse = response.json().await?;

    Ok(Version::parse(index.version.as_str())?)
}
133
134pub fn ensure_version(response: &Response) -> Result<bool> {
135 let version = response.headers().get(ATUIN_HEADER_VERSION);
136
137 let version = if let Some(version) = version {
138 match version.to_str() {
139 Ok(v) => Version::parse(v),
140 Err(e) => bail!("failed to parse server version: {:?}", e),
141 }
142 } else {
143 bail!("Server not reporting its version: it is either too old or unhealthy");
144 }?;
145
146 if version.major < ATUIN_VERSION.major {
148 println!(
149 "Atuin version mismatch! In order to successfully sync, the server needs to run a newer version of Atuin"
150 );
151 println!("Client: {ATUIN_CARGO_VERSION}");
152 println!("Server: {version}");
153
154 return Ok(false);
155 }
156
157 Ok(true)
158}
159
160async fn handle_resp_error(resp: Response) -> Result<Response> {
161 let status = resp.status();
162
163 if status == StatusCode::SERVICE_UNAVAILABLE {
164 bail!(
165 "Service unavailable: check https://status.atuin.sh (or get in touch with your host)"
166 );
167 }
168
169 if status == StatusCode::TOO_MANY_REQUESTS {
170 bail!("Rate limited; please wait before doing that again");
171 }
172
173 if !status.is_success() {
174 if let Ok(error) = resp.json::<ErrorResponse>().await {
175 let reason = error.reason;
176
177 if status.is_client_error() {
178 bail!("Invalid request to the service: {status} - {reason}.")
179 }
180
181 bail!(
182 "There was an error with the atuin sync service, server error {status}: {reason}.\nIf the problem persists, contact the host"
183 )
184 }
185
186 bail!(
187 "There was an error with the atuin sync service: Status {status:?}.\nIf the problem persists, contact the host"
188 )
189 }
190
191 Ok(resp)
192}
193
194impl<'a> Client<'a> {
195 pub fn new(
196 sync_addr: &'a str,
197 session_token: &str,
198 connect_timeout: u64,
199 timeout: u64,
200 ) -> Result<Self> {
201 let mut headers = HeaderMap::new();
202 headers.insert(AUTHORIZATION, format!("Token {session_token}").parse()?);
203
204 headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?);
206
207 Ok(Client {
208 sync_addr,
209 client: reqwest::Client::builder()
210 .user_agent(APP_USER_AGENT)
211 .default_headers(headers)
212 .connect_timeout(Duration::new(connect_timeout, 0))
213 .timeout(Duration::new(timeout, 0))
214 .build()?,
215 })
216 }
217
218 pub async fn count(&self) -> Result<i64> {
219 let url = make_url(self.sync_addr, "/sync/count")?;
220 let url = Url::parse(url.as_str())?;
221
222 let resp = self.client.get(url).send().await?;
223 let resp = handle_resp_error(resp).await?;
224
225 if !ensure_version(&resp)? {
226 bail!("could not sync due to version mismatch");
227 }
228
229 if resp.status() != StatusCode::OK {
230 bail!("failed to get count (are you logged in?)");
231 }
232
233 let count = resp.json::<CountResponse>().await?;
234
235 Ok(count.count)
236 }
237
238 pub async fn status(&self) -> Result<StatusResponse> {
239 let url = make_url(self.sync_addr, "/sync/status")?;
240 let url = Url::parse(url.as_str())?;
241
242 let resp = self.client.get(url).send().await?;
243 let resp = handle_resp_error(resp).await?;
244
245 if !ensure_version(&resp)? {
246 bail!("could not sync due to version mismatch");
247 }
248
249 let status = resp.json::<StatusResponse>().await?;
250
251 Ok(status)
252 }
253
254 pub async fn me(&self) -> Result<MeResponse> {
255 let url = make_url(self.sync_addr, "/api/v0/me")?;
256 let url = Url::parse(url.as_str())?;
257
258 let resp = self.client.get(url).send().await?;
259 let resp = handle_resp_error(resp).await?;
260
261 let status = resp.json::<MeResponse>().await?;
262
263 Ok(status)
264 }
265
266 pub async fn get_history(
267 &self,
268 sync_ts: OffsetDateTime,
269 history_ts: OffsetDateTime,
270 host: Option<String>,
271 ) -> Result<SyncHistoryResponse> {
272 let host = host.unwrap_or_else(|| hash_str(&get_host_user()));
273
274 let url = make_url(
275 self.sync_addr,
276 &format!(
277 "/sync/history?sync_ts={}&history_ts={}&host={}",
278 urlencoding::encode(sync_ts.format(&Rfc3339)?.as_str()),
279 urlencoding::encode(history_ts.format(&Rfc3339)?.as_str()),
280 host,
281 ),
282 )?;
283
284 let resp = self.client.get(url).send().await?;
285 let resp = handle_resp_error(resp).await?;
286
287 let history = resp.json::<SyncHistoryResponse>().await?;
288 Ok(history)
289 }
290
291 pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> {
292 let url = make_url(self.sync_addr, "/history")?;
293 let url = Url::parse(url.as_str())?;
294
295 let resp = self.client.post(url).json(history).send().await?;
296 handle_resp_error(resp).await?;
297
298 Ok(())
299 }
300
301 pub async fn delete_history(&self, h: History) -> Result<()> {
302 let url = make_url(self.sync_addr, "/history")?;
303 let url = Url::parse(url.as_str())?;
304
305 let resp = self
306 .client
307 .delete(url)
308 .json(&DeleteHistoryRequest {
309 client_id: h.id.to_string(),
310 })
311 .send()
312 .await?;
313
314 handle_resp_error(resp).await?;
315
316 Ok(())
317 }
318
319 pub async fn delete_store(&self) -> Result<()> {
320 let url = make_url(self.sync_addr, "/api/v0/store")?;
321 let url = Url::parse(url.as_str())?;
322
323 let resp = self.client.delete(url).send().await?;
324
325 handle_resp_error(resp).await?;
326
327 Ok(())
328 }
329
330 pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> {
331 let url = make_url(self.sync_addr, "/api/v0/record")?;
332 let url = Url::parse(url.as_str())?;
333
334 debug!("uploading {} records to {url}", records.len());
335
336 let resp = self.client.post(url).json(records).send().await?;
337 handle_resp_error(resp).await?;
338
339 Ok(())
340 }
341
342 pub async fn next_records(
343 &self,
344 host: HostId,
345 tag: String,
346 start: RecordIdx,
347 count: u64,
348 ) -> Result<Vec<Record<EncryptedData>>> {
349 debug!("fetching record/s from host {}/{}/{}", host.0, tag, start);
350
351 let url = make_url(
352 self.sync_addr,
353 &format!(
354 "/api/v0/record/next?host={}&tag={}&count={}&start={}",
355 host.0, tag, count, start
356 ),
357 )?;
358
359 let url = Url::parse(url.as_str())?;
360
361 let resp = self.client.get(url).send().await?;
362 let resp = handle_resp_error(resp).await?;
363
364 let records = resp.json::<Vec<Record<EncryptedData>>>().await?;
365
366 Ok(records)
367 }
368
369 pub async fn record_status(&self) -> Result<RecordStatus> {
370 let url = make_url(self.sync_addr, "/api/v0/record")?;
371 let url = Url::parse(url.as_str())?;
372
373 let resp = self.client.get(url).send().await?;
374 let resp = handle_resp_error(resp).await?;
375
376 if !ensure_version(&resp)? {
377 bail!("could not sync records due to version mismatch");
378 }
379
380 let index = resp.json().await?;
381
382 debug!("got remote index {index:?}");
383
384 Ok(index)
385 }
386
387 pub async fn delete(&self) -> Result<()> {
388 let url = make_url(self.sync_addr, "/account")?;
389 let url = Url::parse(url.as_str())?;
390
391 let resp = self.client.delete(url).send().await?;
392
393 if resp.status() == 403 {
394 bail!("invalid login details");
395 } else if resp.status() == 200 {
396 Ok(())
397 } else {
398 bail!("Unknown error");
399 }
400 }
401
402 pub async fn change_password(
403 &self,
404 current_password: String,
405 new_password: String,
406 ) -> Result<()> {
407 let url = make_url(self.sync_addr, "/account/password")?;
408 let url = Url::parse(url.as_str())?;
409
410 let resp = self
411 .client
412 .patch(url)
413 .json(&ChangePasswordRequest {
414 current_password,
415 new_password,
416 })
417 .send()
418 .await?;
419
420 if resp.status() == 401 {
421 bail!("current password is incorrect")
422 } else if resp.status() == 403 {
423 bail!("invalid login details");
424 } else if resp.status() == 200 {
425 Ok(())
426 } else {
427 bail!("Unknown error");
428 }
429 }
430
431 pub async fn verify(&self, token: Option<String>) -> Result<(bool, bool)> {
433 let (email_sent, verified) = if let Some(token) = token {
435 let url = make_url(self.sync_addr, "/api/v0/account/verify")?;
436 let url = Url::parse(url.as_str())?;
437
438 let resp = self
439 .client
440 .post(url)
441 .json(&VerificationTokenRequest { token })
442 .send()
443 .await?;
444 let resp = handle_resp_error(resp).await?;
445 let resp = resp.json::<VerificationTokenResponse>().await?;
446
447 (false, resp.verified)
448 } else {
449 let url = make_url(self.sync_addr, "/api/v0/account/send-verification")?;
450 let url = Url::parse(url.as_str())?;
451
452 let resp = self.client.post(url).send().await?;
453 let resp = handle_resp_error(resp).await?;
454 let resp = resp.json::<SendVerificationResponse>().await?;
455
456 (resp.email_sent, resp.verified)
457 };
458
459 Ok((email_sent, verified))
460 }
461}