1use std::collections::HashMap;
2use std::env;
3use std::time::Duration;
4
5use eyre::{Result, bail, eyre};
6use reqwest::{
7 Response, StatusCode, Url,
8 header::{AUTHORIZATION, HeaderMap, USER_AGENT},
9};
10
11use atuin_common::{
12 api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION},
13 record::{EncryptedData, HostId, Record, RecordIdx},
14 tls::ensure_crypto_provider,
15};
16use atuin_common::{
17 api::{
18 AddHistoryRequest, ChangePasswordRequest, CountResponse, DeleteHistoryRequest,
19 ErrorResponse, LoginRequest, LoginResponse, MeResponse, RegisterResponse, StatusResponse,
20 SyncHistoryResponse,
21 },
22 record::RecordStatus,
23};
24
25use semver::Version;
26use time::OffsetDateTime;
27use time::format_description::well_known::Rfc3339;
28
29use crate::{history::History, sync::hash_str, utils::get_host_user};
30
31static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION"),);
32
/// Credential that [`Client`] attaches to its requests via the `Authorization` header.
///
/// `Debug` is implemented by hand so that the secret never ends up in logs or
/// panic messages; only the scheme name is printed.
#[derive(Clone)]
pub enum AuthToken {
    /// Rendered as `Bearer <token>`.
    Bearer(String),
    /// Rendered as `Token <token>`.
    Token(String),
}

impl AuthToken {
    /// Formats the credential as an `Authorization` header value: the scheme
    /// name, a single space, then the raw token.
    fn to_header_value(&self) -> String {
        match self {
            AuthToken::Bearer(token) => format!("Bearer {token}"),
            AuthToken::Token(token) => format!("Token {token}"),
        }
    }
}

impl std::fmt::Debug for AuthToken {
    /// Prints the scheme with a redacted payload, e.g. `Bearer("<redacted>")`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let scheme = match self {
            AuthToken::Bearer(_) => "Bearer",
            AuthToken::Token(_) => "Token",
        };
        f.debug_tuple(scheme).field(&"<redacted>").finish()
    }
}
58
59pub struct Client<'a> {
60 sync_addr: &'a str,
61 client: reqwest::Client,
62}
63
64fn make_url(address: &str, path: &str) -> Result<String> {
65 let address = if address.ends_with("/") {
68 address
69 } else {
70 &format!("{address}/")
71 };
72
73 let path = path.strip_prefix("/").unwrap_or(path);
75
76 let url = Url::parse(address)
77 .map(|url| url.join(path))?
78 .map_err(|_| eyre!("invalid address"))?;
79
80 Ok(url.to_string())
81}
82
83pub async fn register(
84 address: &str,
85 username: &str,
86 email: &str,
87 password: &str,
88) -> Result<RegisterResponse> {
89 ensure_crypto_provider();
90 let mut map = HashMap::new();
91 map.insert("username", username);
92 map.insert("email", email);
93 map.insert("password", password);
94
95 let url = make_url(address, &format!("/user/{username}"))?;
96 let resp = reqwest::get(url).await?;
97
98 if resp.status().is_success() {
99 bail!("username already in use");
100 }
101
102 let url = make_url(address, "/register")?;
103 let client = reqwest::Client::new();
104 let resp = client
105 .post(url)
106 .header(USER_AGENT, APP_USER_AGENT)
107 .header(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION)
108 .json(&map)
109 .send()
110 .await?;
111 let resp = handle_resp_error(resp).await?;
112
113 if !ensure_version(&resp)? {
114 bail!("could not register user due to version mismatch");
115 }
116
117 let session = resp.json::<RegisterResponse>().await?;
118 Ok(session)
119}
120
121pub async fn login(address: &str, req: LoginRequest) -> Result<LoginResponse> {
122 ensure_crypto_provider();
123 let url = make_url(address, "/login")?;
124 let client = reqwest::Client::new();
125
126 let resp = client
127 .post(url)
128 .header(USER_AGENT, APP_USER_AGENT)
129 .json(&req)
130 .send()
131 .await?;
132 let resp = handle_resp_error(resp).await?;
133
134 if !ensure_version(&resp)? {
135 bail!("Could not login due to version mismatch");
136 }
137
138 let session = resp.json::<LoginResponse>().await?;
139 Ok(session)
140}
141
/// Fetches the latest published Atuin version from `https://api.atuin.sh`.
///
/// Only compiled with the `check-update` feature. The endpoint's JSON body is
/// decoded as an `IndexResponse` and its `version` field parsed as semver.
///
/// # Errors
/// Fails on network errors, non-success responses, an undecodable body, or a
/// version string that is not valid semver.
#[cfg(feature = "check-update")]
pub async fn latest_version() -> Result<Version> {
    use atuin_common::api::IndexResponse;

    ensure_crypto_provider();

    let response = reqwest::Client::new()
        .get("https://api.atuin.sh")
        .header(USER_AGENT, APP_USER_AGENT)
        .send()
        .await?;

    let index: IndexResponse = handle_resp_error(response).await?.json().await?;

    Ok(Version::parse(index.version.as_str())?)
}
162
163pub fn ensure_version(response: &Response) -> Result<bool> {
164 let version = response.headers().get(ATUIN_HEADER_VERSION);
165
166 let version = if let Some(version) = version {
167 match version.to_str() {
168 Ok(v) => Version::parse(v),
169 Err(e) => bail!("failed to parse server version: {:?}", e),
170 }
171 } else {
172 bail!("Server not reporting its version: it is either too old or unhealthy");
173 }?;
174
175 if version.major < ATUIN_VERSION.major {
177 println!(
178 "Atuin version mismatch! In order to successfully sync, the server needs to run a newer version of Atuin"
179 );
180 println!("Client: {ATUIN_CARGO_VERSION}");
181 println!("Server: {version}");
182
183 return Ok(false);
184 }
185
186 Ok(true)
187}
188
189async fn handle_resp_error(resp: Response) -> Result<Response> {
190 let status = resp.status();
191 let url = resp.url().to_string();
192
193 if status == StatusCode::SERVICE_UNAVAILABLE {
194 bail!(
195 "Service unavailable: check https://status.atuin.sh (or get in touch with your host)"
196 );
197 }
198
199 if status == StatusCode::TOO_MANY_REQUESTS {
200 bail!("Rate limited; please wait before doing that again");
201 }
202
203 if !status.is_success() {
204 if let Ok(error) = resp.json::<ErrorResponse>().await {
205 let reason = error.reason;
206
207 if status.is_client_error() {
208 bail!("Invalid request to the service at {url}, {status} - {reason}.")
209 }
210
211 bail!(
212 "There was an error with the atuin sync service at {url}, server error {status}: {reason}.\nIf the problem persists, contact the host"
213 )
214 }
215
216 bail!(
217 "There was an error with the atuin sync service at {url}, Status {status:?}.\nIf the problem persists, contact the host"
218 )
219 }
220
221 Ok(resp)
222}
223
224impl<'a> Client<'a> {
225 pub fn new(
226 sync_addr: &'a str,
227 auth: AuthToken,
228 connect_timeout: u64,
229 timeout: u64,
230 ) -> Result<Self> {
231 ensure_crypto_provider();
232 let mut headers = HeaderMap::new();
233 headers.insert(AUTHORIZATION, auth.to_header_value().parse()?);
234
235 headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?);
237
238 Ok(Client {
239 sync_addr,
240 client: reqwest::Client::builder()
241 .user_agent(APP_USER_AGENT)
242 .default_headers(headers)
243 .connect_timeout(Duration::new(connect_timeout, 0))
244 .timeout(Duration::new(timeout, 0))
245 .build()?,
246 })
247 }
248
249 pub async fn count(&self) -> Result<i64> {
250 let url = make_url(self.sync_addr, "/sync/count")?;
251 let url = Url::parse(url.as_str())?;
252
253 let resp = self.client.get(url).send().await?;
254 let resp = handle_resp_error(resp).await?;
255
256 if !ensure_version(&resp)? {
257 bail!("could not sync due to version mismatch");
258 }
259
260 if resp.status() != StatusCode::OK {
261 bail!("failed to get count (are you logged in?)");
262 }
263
264 let count = resp.json::<CountResponse>().await?;
265
266 Ok(count.count)
267 }
268
269 pub async fn status(&self) -> Result<StatusResponse> {
270 let url = make_url(self.sync_addr, "/sync/status")?;
271 let url = Url::parse(url.as_str())?;
272
273 let resp = self.client.get(url).send().await?;
274 let resp = handle_resp_error(resp).await?;
275
276 if !ensure_version(&resp)? {
277 bail!("could not sync due to version mismatch");
278 }
279
280 let status = resp.json::<StatusResponse>().await?;
281
282 Ok(status)
283 }
284
285 pub async fn me(&self) -> Result<MeResponse> {
286 let url = make_url(self.sync_addr, "/api/v0/me")?;
287 let url = Url::parse(url.as_str())?;
288
289 let resp = self.client.get(url).send().await?;
290 let resp = handle_resp_error(resp).await?;
291
292 let status = resp.json::<MeResponse>().await?;
293
294 Ok(status)
295 }
296
297 pub async fn get_history(
298 &self,
299 sync_ts: OffsetDateTime,
300 history_ts: OffsetDateTime,
301 host: Option<String>,
302 ) -> Result<SyncHistoryResponse> {
303 let host = host.unwrap_or_else(|| hash_str(&get_host_user()));
304
305 let url = make_url(
306 self.sync_addr,
307 &format!(
308 "/sync/history?sync_ts={}&history_ts={}&host={}",
309 urlencoding::encode(sync_ts.format(&Rfc3339)?.as_str()),
310 urlencoding::encode(history_ts.format(&Rfc3339)?.as_str()),
311 host,
312 ),
313 )?;
314
315 let resp = self.client.get(url).send().await?;
316 let resp = handle_resp_error(resp).await?;
317
318 let history = resp.json::<SyncHistoryResponse>().await?;
319 Ok(history)
320 }
321
322 pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> {
323 let url = make_url(self.sync_addr, "/history")?;
324 let url = Url::parse(url.as_str())?;
325
326 let resp = self.client.post(url).json(history).send().await?;
327 handle_resp_error(resp).await?;
328
329 Ok(())
330 }
331
332 pub async fn delete_history(&self, h: History) -> Result<()> {
333 let url = make_url(self.sync_addr, "/history")?;
334 let url = Url::parse(url.as_str())?;
335
336 let resp = self
337 .client
338 .delete(url)
339 .json(&DeleteHistoryRequest {
340 client_id: h.id.to_string(),
341 })
342 .send()
343 .await?;
344
345 handle_resp_error(resp).await?;
346
347 Ok(())
348 }
349
350 pub async fn delete_store(&self) -> Result<()> {
351 let url = make_url(self.sync_addr, "/api/v0/store")?;
352 let url = Url::parse(url.as_str())?;
353
354 let resp = self.client.delete(url).send().await?;
355
356 handle_resp_error(resp).await?;
357
358 Ok(())
359 }
360
361 pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> {
362 let url = make_url(self.sync_addr, "/api/v0/record")?;
363 let url = Url::parse(url.as_str())?;
364
365 debug!("uploading {} records to {url}", records.len());
366
367 let resp = self.client.post(url).json(records).send().await?;
368 handle_resp_error(resp).await?;
369
370 Ok(())
371 }
372
373 pub async fn next_records(
374 &self,
375 host: HostId,
376 tag: String,
377 start: RecordIdx,
378 count: u64,
379 ) -> Result<Vec<Record<EncryptedData>>> {
380 debug!("fetching record/s from host {}/{}/{}", host.0, tag, start);
381
382 let url = make_url(
383 self.sync_addr,
384 &format!(
385 "/api/v0/record/next?host={}&tag={}&count={}&start={}",
386 host.0, tag, count, start
387 ),
388 )?;
389
390 let url = Url::parse(url.as_str())?;
391
392 let resp = self.client.get(url).send().await?;
393 let resp = handle_resp_error(resp).await?;
394
395 let records = resp.json::<Vec<Record<EncryptedData>>>().await?;
396
397 Ok(records)
398 }
399
400 pub async fn record_status(&self) -> Result<RecordStatus> {
401 let url = make_url(self.sync_addr, "/api/v0/record")?;
402 let url = Url::parse(url.as_str())?;
403
404 let resp = self.client.get(url).send().await?;
405 let resp = handle_resp_error(resp).await?;
406
407 if !ensure_version(&resp)? {
408 bail!("could not sync records due to version mismatch");
409 }
410
411 let index = resp.json().await?;
412
413 debug!("got remote index {index:?}");
414
415 Ok(index)
416 }
417
418 pub async fn delete(&self) -> Result<()> {
419 let url = make_url(self.sync_addr, "/account")?;
420 let url = Url::parse(url.as_str())?;
421
422 let resp = self.client.delete(url).send().await?;
423
424 if resp.status() == 403 {
425 bail!("invalid login details");
426 } else if resp.status() == 200 {
427 Ok(())
428 } else {
429 bail!("Unknown error");
430 }
431 }
432
433 pub async fn change_password(
434 &self,
435 current_password: String,
436 new_password: String,
437 ) -> Result<()> {
438 let url = make_url(self.sync_addr, "/account/password")?;
439 let url = Url::parse(url.as_str())?;
440
441 let resp = self
442 .client
443 .patch(url)
444 .json(&ChangePasswordRequest {
445 current_password,
446 new_password,
447 })
448 .send()
449 .await?;
450
451 if resp.status() == 401 {
452 bail!("current password is incorrect")
453 } else if resp.status() == 403 {
454 bail!("invalid login details");
455 } else if resp.status() == 200 {
456 Ok(())
457 } else {
458 bail!("Unknown error");
459 }
460 }
461}