// atuin_client/api_client.rs

1use std::collections::HashMap;
2use std::env;
3use std::time::Duration;
4
5use eyre::{Result, bail, eyre};
6use reqwest::{
7    Response, StatusCode, Url,
8    header::{AUTHORIZATION, HeaderMap, USER_AGENT},
9};
10
11use atuin_common::{
12    api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION},
13    record::{EncryptedData, HostId, Record, RecordIdx},
14    tls::ensure_crypto_provider,
15};
16use atuin_common::{
17    api::{
18        AddHistoryRequest, ChangePasswordRequest, CountResponse, DeleteHistoryRequest,
19        ErrorResponse, LoginRequest, LoginResponse, MeResponse, RegisterResponse, StatusResponse,
20        SyncHistoryResponse,
21    },
22    record::RecordStatus,
23};
24
25use semver::Version;
26use time::OffsetDateTime;
27use time::format_description::well_known::Rfc3339;
28
29use crate::{history::History, sync::hash_str, utils::get_host_user};
30
31static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION"),);
32
/// Authentication token for sync API requests.
///
/// The sync API supports two authentication methods:
/// - `Bearer`: Hub API tokens (for users authenticated via Atuin Hub)
/// - `Token`: Legacy CLI session tokens (for users registered via CLI or self-hosted)
///
/// When both are available, Hub tokens are preferred as they provide unified
/// authentication across CLI and Hub features.
///
/// `Debug` is implemented by hand rather than derived so that the secret is
/// never written to logs or panic messages: only the variant name is shown,
/// e.g. `Bearer("<redacted>")`.
#[derive(Clone)]
pub enum AuthToken {
    /// Hub API token, used with "Bearer {token}" header
    Bearer(String),
    /// Legacy CLI session token, used with "Token {token}" header
    Token(String),
}

impl std::fmt::Debug for AuthToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Print which scheme is in use, but never the credential itself.
        let scheme = match self {
            AuthToken::Bearer(_) => "Bearer",
            AuthToken::Token(_) => "Token",
        };
        f.debug_tuple(scheme).field(&"<redacted>").finish()
    }
}

impl AuthToken {
    /// Format the token as an `Authorization` header value, e.g. `Bearer abc123`
    /// or `Token abc123`.
    fn to_header_value(&self) -> String {
        match self {
            AuthToken::Bearer(token) => format!("Bearer {token}"),
            AuthToken::Token(token) => format!("Token {token}"),
        }
    }
}
58
59pub struct Client<'a> {
60    sync_addr: &'a str,
61    client: reqwest::Client,
62}
63
64fn make_url(address: &str, path: &str) -> Result<String> {
65    // `join()` expects a trailing `/` in order to join paths
66    // e.g. it treats `http://host:port/subdir` as a file called `subdir`
67    let address = if address.ends_with("/") {
68        address
69    } else {
70        &format!("{address}/")
71    };
72
73    // passing a path with a leading `/` will cause `join()` to replace the entire URL path
74    let path = path.strip_prefix("/").unwrap_or(path);
75
76    let url = Url::parse(address)
77        .map(|url| url.join(path))?
78        .map_err(|_| eyre!("invalid address"))?;
79
80    Ok(url.to_string())
81}
82
83pub async fn register(
84    address: &str,
85    username: &str,
86    email: &str,
87    password: &str,
88) -> Result<RegisterResponse> {
89    ensure_crypto_provider();
90    let mut map = HashMap::new();
91    map.insert("username", username);
92    map.insert("email", email);
93    map.insert("password", password);
94
95    let url = make_url(address, &format!("/user/{username}"))?;
96    let resp = reqwest::get(url).await?;
97
98    if resp.status().is_success() {
99        bail!("username already in use");
100    }
101
102    let url = make_url(address, "/register")?;
103    let client = reqwest::Client::new();
104    let resp = client
105        .post(url)
106        .header(USER_AGENT, APP_USER_AGENT)
107        .header(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION)
108        .json(&map)
109        .send()
110        .await?;
111    let resp = handle_resp_error(resp).await?;
112
113    if !ensure_version(&resp)? {
114        bail!("could not register user due to version mismatch");
115    }
116
117    let session = resp.json::<RegisterResponse>().await?;
118    Ok(session)
119}
120
121pub async fn login(address: &str, req: LoginRequest) -> Result<LoginResponse> {
122    ensure_crypto_provider();
123    let url = make_url(address, "/login")?;
124    let client = reqwest::Client::new();
125
126    let resp = client
127        .post(url)
128        .header(USER_AGENT, APP_USER_AGENT)
129        .json(&req)
130        .send()
131        .await?;
132    let resp = handle_resp_error(resp).await?;
133
134    if !ensure_version(&resp)? {
135        bail!("Could not login due to version mismatch");
136    }
137
138    let session = resp.json::<LoginResponse>().await?;
139    Ok(session)
140}
141
/// Fetch the newest released Atuin version from the index at `https://api.atuin.sh`.
///
/// Only compiled with the `check-update` feature.
///
/// # Errors
///
/// Fails if the request fails, the API returns an error status, the body is
/// not an `IndexResponse`, or its `version` field is not valid semver.
#[cfg(feature = "check-update")]
pub async fn latest_version() -> Result<Version> {
    use atuin_common::api::IndexResponse;

    ensure_crypto_provider();

    let response = reqwest::Client::new()
        .get("https://api.atuin.sh")
        .header(USER_AGENT, APP_USER_AGENT)
        .send()
        .await?;
    let response = handle_resp_error(response).await?;

    let index: IndexResponse = response.json().await?;
    Ok(Version::parse(index.version.as_str())?)
}
162
163pub fn ensure_version(response: &Response) -> Result<bool> {
164    let version = response.headers().get(ATUIN_HEADER_VERSION);
165
166    let version = if let Some(version) = version {
167        match version.to_str() {
168            Ok(v) => Version::parse(v),
169            Err(e) => bail!("failed to parse server version: {:?}", e),
170        }
171    } else {
172        bail!("Server not reporting its version: it is either too old or unhealthy");
173    }?;
174
175    // If the client is newer than the server
176    if version.major < ATUIN_VERSION.major {
177        println!(
178            "Atuin version mismatch! In order to successfully sync, the server needs to run a newer version of Atuin"
179        );
180        println!("Client: {ATUIN_CARGO_VERSION}");
181        println!("Server: {version}");
182
183        return Ok(false);
184    }
185
186    Ok(true)
187}
188
189async fn handle_resp_error(resp: Response) -> Result<Response> {
190    let status = resp.status();
191    let url = resp.url().to_string();
192
193    if status == StatusCode::SERVICE_UNAVAILABLE {
194        bail!(
195            "Service unavailable: check https://status.atuin.sh (or get in touch with your host)"
196        );
197    }
198
199    if status == StatusCode::TOO_MANY_REQUESTS {
200        bail!("Rate limited; please wait before doing that again");
201    }
202
203    if !status.is_success() {
204        if let Ok(error) = resp.json::<ErrorResponse>().await {
205            let reason = error.reason;
206
207            if status.is_client_error() {
208                bail!("Invalid request to the service at {url}, {status} - {reason}.")
209            }
210
211            bail!(
212                "There was an error with the atuin sync service at {url}, server error {status}: {reason}.\nIf the problem persists, contact the host"
213            )
214        }
215
216        bail!(
217            "There was an error with the atuin sync service at {url}, Status {status:?}.\nIf the problem persists, contact the host"
218        )
219    }
220
221    Ok(resp)
222}
223
224impl<'a> Client<'a> {
225    pub fn new(
226        sync_addr: &'a str,
227        auth: AuthToken,
228        connect_timeout: u64,
229        timeout: u64,
230    ) -> Result<Self> {
231        ensure_crypto_provider();
232        let mut headers = HeaderMap::new();
233        headers.insert(AUTHORIZATION, auth.to_header_value().parse()?);
234
235        // used for semver server check
236        headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?);
237
238        Ok(Client {
239            sync_addr,
240            client: reqwest::Client::builder()
241                .user_agent(APP_USER_AGENT)
242                .default_headers(headers)
243                .connect_timeout(Duration::new(connect_timeout, 0))
244                .timeout(Duration::new(timeout, 0))
245                .build()?,
246        })
247    }
248
249    pub async fn count(&self) -> Result<i64> {
250        let url = make_url(self.sync_addr, "/sync/count")?;
251        let url = Url::parse(url.as_str())?;
252
253        let resp = self.client.get(url).send().await?;
254        let resp = handle_resp_error(resp).await?;
255
256        if !ensure_version(&resp)? {
257            bail!("could not sync due to version mismatch");
258        }
259
260        if resp.status() != StatusCode::OK {
261            bail!("failed to get count (are you logged in?)");
262        }
263
264        let count = resp.json::<CountResponse>().await?;
265
266        Ok(count.count)
267    }
268
269    pub async fn status(&self) -> Result<StatusResponse> {
270        let url = make_url(self.sync_addr, "/sync/status")?;
271        let url = Url::parse(url.as_str())?;
272
273        let resp = self.client.get(url).send().await?;
274        let resp = handle_resp_error(resp).await?;
275
276        if !ensure_version(&resp)? {
277            bail!("could not sync due to version mismatch");
278        }
279
280        let status = resp.json::<StatusResponse>().await?;
281
282        Ok(status)
283    }
284
285    pub async fn me(&self) -> Result<MeResponse> {
286        let url = make_url(self.sync_addr, "/api/v0/me")?;
287        let url = Url::parse(url.as_str())?;
288
289        let resp = self.client.get(url).send().await?;
290        let resp = handle_resp_error(resp).await?;
291
292        let status = resp.json::<MeResponse>().await?;
293
294        Ok(status)
295    }
296
297    pub async fn get_history(
298        &self,
299        sync_ts: OffsetDateTime,
300        history_ts: OffsetDateTime,
301        host: Option<String>,
302    ) -> Result<SyncHistoryResponse> {
303        let host = host.unwrap_or_else(|| hash_str(&get_host_user()));
304
305        let url = make_url(
306            self.sync_addr,
307            &format!(
308                "/sync/history?sync_ts={}&history_ts={}&host={}",
309                urlencoding::encode(sync_ts.format(&Rfc3339)?.as_str()),
310                urlencoding::encode(history_ts.format(&Rfc3339)?.as_str()),
311                host,
312            ),
313        )?;
314
315        let resp = self.client.get(url).send().await?;
316        let resp = handle_resp_error(resp).await?;
317
318        let history = resp.json::<SyncHistoryResponse>().await?;
319        Ok(history)
320    }
321
322    pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> {
323        let url = make_url(self.sync_addr, "/history")?;
324        let url = Url::parse(url.as_str())?;
325
326        let resp = self.client.post(url).json(history).send().await?;
327        handle_resp_error(resp).await?;
328
329        Ok(())
330    }
331
332    pub async fn delete_history(&self, h: History) -> Result<()> {
333        let url = make_url(self.sync_addr, "/history")?;
334        let url = Url::parse(url.as_str())?;
335
336        let resp = self
337            .client
338            .delete(url)
339            .json(&DeleteHistoryRequest {
340                client_id: h.id.to_string(),
341            })
342            .send()
343            .await?;
344
345        handle_resp_error(resp).await?;
346
347        Ok(())
348    }
349
350    pub async fn delete_store(&self) -> Result<()> {
351        let url = make_url(self.sync_addr, "/api/v0/store")?;
352        let url = Url::parse(url.as_str())?;
353
354        let resp = self.client.delete(url).send().await?;
355
356        handle_resp_error(resp).await?;
357
358        Ok(())
359    }
360
361    pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> {
362        let url = make_url(self.sync_addr, "/api/v0/record")?;
363        let url = Url::parse(url.as_str())?;
364
365        debug!("uploading {} records to {url}", records.len());
366
367        let resp = self.client.post(url).json(records).send().await?;
368        handle_resp_error(resp).await?;
369
370        Ok(())
371    }
372
373    pub async fn next_records(
374        &self,
375        host: HostId,
376        tag: String,
377        start: RecordIdx,
378        count: u64,
379    ) -> Result<Vec<Record<EncryptedData>>> {
380        debug!("fetching record/s from host {}/{}/{}", host.0, tag, start);
381
382        let url = make_url(
383            self.sync_addr,
384            &format!(
385                "/api/v0/record/next?host={}&tag={}&count={}&start={}",
386                host.0, tag, count, start
387            ),
388        )?;
389
390        let url = Url::parse(url.as_str())?;
391
392        let resp = self.client.get(url).send().await?;
393        let resp = handle_resp_error(resp).await?;
394
395        let records = resp.json::<Vec<Record<EncryptedData>>>().await?;
396
397        Ok(records)
398    }
399
400    pub async fn record_status(&self) -> Result<RecordStatus> {
401        let url = make_url(self.sync_addr, "/api/v0/record")?;
402        let url = Url::parse(url.as_str())?;
403
404        let resp = self.client.get(url).send().await?;
405        let resp = handle_resp_error(resp).await?;
406
407        if !ensure_version(&resp)? {
408            bail!("could not sync records due to version mismatch");
409        }
410
411        let index = resp.json().await?;
412
413        debug!("got remote index {index:?}");
414
415        Ok(index)
416    }
417
418    pub async fn delete(&self) -> Result<()> {
419        let url = make_url(self.sync_addr, "/account")?;
420        let url = Url::parse(url.as_str())?;
421
422        let resp = self.client.delete(url).send().await?;
423
424        if resp.status() == 403 {
425            bail!("invalid login details");
426        } else if resp.status() == 200 {
427            Ok(())
428        } else {
429            bail!("Unknown error");
430        }
431    }
432
433    pub async fn change_password(
434        &self,
435        current_password: String,
436        new_password: String,
437    ) -> Result<()> {
438        let url = make_url(self.sync_addr, "/account/password")?;
439        let url = Url::parse(url.as_str())?;
440
441        let resp = self
442            .client
443            .patch(url)
444            .json(&ChangePasswordRequest {
445                current_password,
446                new_password,
447            })
448            .send()
449            .await?;
450
451        if resp.status() == 401 {
452            bail!("current password is incorrect")
453        } else if resp.status() == 403 {
454            bail!("invalid login details");
455        } else if resp.status() == 200 {
456            Ok(())
457        } else {
458            bail!("Unknown error");
459        }
460    }
461}