atuin_client/
api_client.rs

1use std::collections::HashMap;
2use std::env;
3use std::time::Duration;
4
5use eyre::{Result, bail, eyre};
6use reqwest::{
7    Response, StatusCode, Url,
8    header::{AUTHORIZATION, HeaderMap, USER_AGENT},
9};
10
11use atuin_common::{
12    api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION},
13    record::{EncryptedData, HostId, Record, RecordIdx},
14};
15use atuin_common::{
16    api::{
17        AddHistoryRequest, ChangePasswordRequest, CountResponse, DeleteHistoryRequest,
18        ErrorResponse, LoginRequest, LoginResponse, MeResponse, RegisterResponse,
19        SendVerificationResponse, StatusResponse, SyncHistoryResponse, VerificationTokenRequest,
20        VerificationTokenResponse,
21    },
22    record::RecordStatus,
23};
24
25use semver::Version;
26use time::OffsetDateTime;
27use time::format_description::well_known::Rfc3339;
28
29use crate::{history::History, sync::hash_str, utils::get_host_user};
30
/// `User-Agent` value sent with requests made from this module: the literal
/// `atuin/` followed by this crate's `CARGO_PKG_VERSION`, resolved at compile time.
static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION"),);
32
/// HTTP client for a sync server.
///
/// Borrows the server address for `'a` and owns a `reqwest::Client` that is
/// reused for every request issued through the methods of this type.
pub struct Client<'a> {
    /// Base address of the sync server; endpoint paths are joined onto it via `make_url`.
    sync_addr: &'a str,
    /// Underlying HTTP client used to send all requests.
    client: reqwest::Client,
}
37
38fn make_url(address: &str, path: &str) -> Result<String> {
39    // `join()` expects a trailing `/` in order to join paths
40    // e.g. it treats `http://host:port/subdir` as a file called `subdir`
41    let address = if address.ends_with("/") {
42        address
43    } else {
44        &format!("{address}/")
45    };
46
47    // passing a path with a leading `/` will cause `join()` to replace the entire URL path
48    let path = path.strip_prefix("/").unwrap_or(path);
49
50    let url = Url::parse(address)
51        .map(|url| url.join(path))?
52        .map_err(|_| eyre!("invalid address"))?;
53
54    Ok(url.to_string())
55}
56
57pub async fn register(
58    address: &str,
59    username: &str,
60    email: &str,
61    password: &str,
62) -> Result<RegisterResponse> {
63    let mut map = HashMap::new();
64    map.insert("username", username);
65    map.insert("email", email);
66    map.insert("password", password);
67
68    let url = make_url(address, &format!("/user/{username}"))?;
69    let resp = reqwest::get(url).await?;
70
71    if resp.status().is_success() {
72        bail!("username already in use");
73    }
74
75    let url = make_url(address, "/register")?;
76    let client = reqwest::Client::new();
77    let resp = client
78        .post(url)
79        .header(USER_AGENT, APP_USER_AGENT)
80        .header(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION)
81        .json(&map)
82        .send()
83        .await?;
84    let resp = handle_resp_error(resp).await?;
85
86    if !ensure_version(&resp)? {
87        bail!("could not register user due to version mismatch");
88    }
89
90    let session = resp.json::<RegisterResponse>().await?;
91    Ok(session)
92}
93
94pub async fn login(address: &str, req: LoginRequest) -> Result<LoginResponse> {
95    let url = make_url(address, "/login")?;
96    let client = reqwest::Client::new();
97
98    let resp = client
99        .post(url)
100        .header(USER_AGENT, APP_USER_AGENT)
101        .json(&req)
102        .send()
103        .await?;
104    let resp = handle_resp_error(resp).await?;
105
106    if !ensure_version(&resp)? {
107        bail!("Could not login due to version mismatch");
108    }
109
110    let session = resp.json::<LoginResponse>().await?;
111    Ok(session)
112}
113
/// Fetches the index document from `https://api.atuin.sh` and parses the
/// version it reports.
///
/// Only compiled when the `check-update` feature is enabled.
///
/// # Errors
///
/// Fails when the request cannot be sent, when the server answers with an
/// error status, when the body cannot be decoded, or when the reported
/// version is not valid semver.
#[cfg(feature = "check-update")]
pub async fn latest_version() -> Result<Version> {
    use atuin_common::api::IndexResponse;

    const INDEX_URL: &str = "https://api.atuin.sh";

    let response = reqwest::Client::new()
        .get(INDEX_URL)
        .header(USER_AGENT, APP_USER_AGENT)
        .send()
        .await?;

    let index = handle_resp_error(response)
        .await?
        .json::<IndexResponse>()
        .await?;

    Ok(Version::parse(index.version.as_str())?)
}
133
134pub fn ensure_version(response: &Response) -> Result<bool> {
135    let version = response.headers().get(ATUIN_HEADER_VERSION);
136
137    let version = if let Some(version) = version {
138        match version.to_str() {
139            Ok(v) => Version::parse(v),
140            Err(e) => bail!("failed to parse server version: {:?}", e),
141        }
142    } else {
143        bail!("Server not reporting its version: it is either too old or unhealthy");
144    }?;
145
146    // If the client is newer than the server
147    if version.major < ATUIN_VERSION.major {
148        println!(
149            "Atuin version mismatch! In order to successfully sync, the server needs to run a newer version of Atuin"
150        );
151        println!("Client: {ATUIN_CARGO_VERSION}");
152        println!("Server: {version}");
153
154        return Ok(false);
155    }
156
157    Ok(true)
158}
159
160async fn handle_resp_error(resp: Response) -> Result<Response> {
161    let status = resp.status();
162
163    if status == StatusCode::SERVICE_UNAVAILABLE {
164        bail!(
165            "Service unavailable: check https://status.atuin.sh (or get in touch with your host)"
166        );
167    }
168
169    if status == StatusCode::TOO_MANY_REQUESTS {
170        bail!("Rate limited; please wait before doing that again");
171    }
172
173    if !status.is_success() {
174        if let Ok(error) = resp.json::<ErrorResponse>().await {
175            let reason = error.reason;
176
177            if status.is_client_error() {
178                bail!("Invalid request to the service: {status} - {reason}.")
179            }
180
181            bail!(
182                "There was an error with the atuin sync service, server error {status}: {reason}.\nIf the problem persists, contact the host"
183            )
184        }
185
186        bail!(
187            "There was an error with the atuin sync service: Status {status:?}.\nIf the problem persists, contact the host"
188        )
189    }
190
191    Ok(resp)
192}
193
194impl<'a> Client<'a> {
195    pub fn new(
196        sync_addr: &'a str,
197        session_token: &str,
198        connect_timeout: u64,
199        timeout: u64,
200    ) -> Result<Self> {
201        let mut headers = HeaderMap::new();
202        headers.insert(AUTHORIZATION, format!("Token {session_token}").parse()?);
203
204        // used for semver server check
205        headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?);
206
207        Ok(Client {
208            sync_addr,
209            client: reqwest::Client::builder()
210                .user_agent(APP_USER_AGENT)
211                .default_headers(headers)
212                .connect_timeout(Duration::new(connect_timeout, 0))
213                .timeout(Duration::new(timeout, 0))
214                .build()?,
215        })
216    }
217
218    pub async fn count(&self) -> Result<i64> {
219        let url = make_url(self.sync_addr, "/sync/count")?;
220        let url = Url::parse(url.as_str())?;
221
222        let resp = self.client.get(url).send().await?;
223        let resp = handle_resp_error(resp).await?;
224
225        if !ensure_version(&resp)? {
226            bail!("could not sync due to version mismatch");
227        }
228
229        if resp.status() != StatusCode::OK {
230            bail!("failed to get count (are you logged in?)");
231        }
232
233        let count = resp.json::<CountResponse>().await?;
234
235        Ok(count.count)
236    }
237
238    pub async fn status(&self) -> Result<StatusResponse> {
239        let url = make_url(self.sync_addr, "/sync/status")?;
240        let url = Url::parse(url.as_str())?;
241
242        let resp = self.client.get(url).send().await?;
243        let resp = handle_resp_error(resp).await?;
244
245        if !ensure_version(&resp)? {
246            bail!("could not sync due to version mismatch");
247        }
248
249        let status = resp.json::<StatusResponse>().await?;
250
251        Ok(status)
252    }
253
254    pub async fn me(&self) -> Result<MeResponse> {
255        let url = make_url(self.sync_addr, "/api/v0/me")?;
256        let url = Url::parse(url.as_str())?;
257
258        let resp = self.client.get(url).send().await?;
259        let resp = handle_resp_error(resp).await?;
260
261        let status = resp.json::<MeResponse>().await?;
262
263        Ok(status)
264    }
265
266    pub async fn get_history(
267        &self,
268        sync_ts: OffsetDateTime,
269        history_ts: OffsetDateTime,
270        host: Option<String>,
271    ) -> Result<SyncHistoryResponse> {
272        let host = host.unwrap_or_else(|| hash_str(&get_host_user()));
273
274        let url = make_url(
275            self.sync_addr,
276            &format!(
277                "/sync/history?sync_ts={}&history_ts={}&host={}",
278                urlencoding::encode(sync_ts.format(&Rfc3339)?.as_str()),
279                urlencoding::encode(history_ts.format(&Rfc3339)?.as_str()),
280                host,
281            ),
282        )?;
283
284        let resp = self.client.get(url).send().await?;
285        let resp = handle_resp_error(resp).await?;
286
287        let history = resp.json::<SyncHistoryResponse>().await?;
288        Ok(history)
289    }
290
291    pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> {
292        let url = make_url(self.sync_addr, "/history")?;
293        let url = Url::parse(url.as_str())?;
294
295        let resp = self.client.post(url).json(history).send().await?;
296        handle_resp_error(resp).await?;
297
298        Ok(())
299    }
300
301    pub async fn delete_history(&self, h: History) -> Result<()> {
302        let url = make_url(self.sync_addr, "/history")?;
303        let url = Url::parse(url.as_str())?;
304
305        let resp = self
306            .client
307            .delete(url)
308            .json(&DeleteHistoryRequest {
309                client_id: h.id.to_string(),
310            })
311            .send()
312            .await?;
313
314        handle_resp_error(resp).await?;
315
316        Ok(())
317    }
318
319    pub async fn delete_store(&self) -> Result<()> {
320        let url = make_url(self.sync_addr, "/api/v0/store")?;
321        let url = Url::parse(url.as_str())?;
322
323        let resp = self.client.delete(url).send().await?;
324
325        handle_resp_error(resp).await?;
326
327        Ok(())
328    }
329
330    pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> {
331        let url = make_url(self.sync_addr, "/api/v0/record")?;
332        let url = Url::parse(url.as_str())?;
333
334        debug!("uploading {} records to {url}", records.len());
335
336        let resp = self.client.post(url).json(records).send().await?;
337        handle_resp_error(resp).await?;
338
339        Ok(())
340    }
341
342    pub async fn next_records(
343        &self,
344        host: HostId,
345        tag: String,
346        start: RecordIdx,
347        count: u64,
348    ) -> Result<Vec<Record<EncryptedData>>> {
349        debug!("fetching record/s from host {}/{}/{}", host.0, tag, start);
350
351        let url = make_url(
352            self.sync_addr,
353            &format!(
354                "/api/v0/record/next?host={}&tag={}&count={}&start={}",
355                host.0, tag, count, start
356            ),
357        )?;
358
359        let url = Url::parse(url.as_str())?;
360
361        let resp = self.client.get(url).send().await?;
362        let resp = handle_resp_error(resp).await?;
363
364        let records = resp.json::<Vec<Record<EncryptedData>>>().await?;
365
366        Ok(records)
367    }
368
369    pub async fn record_status(&self) -> Result<RecordStatus> {
370        let url = make_url(self.sync_addr, "/api/v0/record")?;
371        let url = Url::parse(url.as_str())?;
372
373        let resp = self.client.get(url).send().await?;
374        let resp = handle_resp_error(resp).await?;
375
376        if !ensure_version(&resp)? {
377            bail!("could not sync records due to version mismatch");
378        }
379
380        let index = resp.json().await?;
381
382        debug!("got remote index {index:?}");
383
384        Ok(index)
385    }
386
387    pub async fn delete(&self) -> Result<()> {
388        let url = make_url(self.sync_addr, "/account")?;
389        let url = Url::parse(url.as_str())?;
390
391        let resp = self.client.delete(url).send().await?;
392
393        if resp.status() == 403 {
394            bail!("invalid login details");
395        } else if resp.status() == 200 {
396            Ok(())
397        } else {
398            bail!("Unknown error");
399        }
400    }
401
402    pub async fn change_password(
403        &self,
404        current_password: String,
405        new_password: String,
406    ) -> Result<()> {
407        let url = make_url(self.sync_addr, "/account/password")?;
408        let url = Url::parse(url.as_str())?;
409
410        let resp = self
411            .client
412            .patch(url)
413            .json(&ChangePasswordRequest {
414                current_password,
415                new_password,
416            })
417            .send()
418            .await?;
419
420        if resp.status() == 401 {
421            bail!("current password is incorrect")
422        } else if resp.status() == 403 {
423            bail!("invalid login details");
424        } else if resp.status() == 200 {
425            Ok(())
426        } else {
427            bail!("Unknown error");
428        }
429    }
430
431    // Either request a verification email if token is null, or validate a token
432    pub async fn verify(&self, token: Option<String>) -> Result<(bool, bool)> {
433        // could dedupe this a bit, but it's simple at the moment
434        let (email_sent, verified) = if let Some(token) = token {
435            let url = make_url(self.sync_addr, "/api/v0/account/verify")?;
436            let url = Url::parse(url.as_str())?;
437
438            let resp = self
439                .client
440                .post(url)
441                .json(&VerificationTokenRequest { token })
442                .send()
443                .await?;
444            let resp = handle_resp_error(resp).await?;
445            let resp = resp.json::<VerificationTokenResponse>().await?;
446
447            (false, resp.verified)
448        } else {
449            let url = make_url(self.sync_addr, "/api/v0/account/send-verification")?;
450            let url = Url::parse(url.as_str())?;
451
452            let resp = self.client.post(url).send().await?;
453            let resp = handle_resp_error(resp).await?;
454            let resp = resp.json::<SendVerificationResponse>().await?;
455
456            (resp.email_sent, resp.verified)
457        };
458
459        Ok((email_sent, verified))
460    }
461}