1#![allow(clippy::needless_return)]
7
8use eventsource_stream::Eventsource;
9pub use futures::Stream;
10use futures::StreamExt;
11use parking_lot::RwLock;
12use reqwest::header::{HeaderMap, HeaderValue};
13use reqwest::Method;
14use std::borrow::Cow;
15use std::sync::Arc;
16use thiserror::Error;
17
18use serde::de::DeserializeOwned;
19use serde::{Deserialize, Serialize};
20
21#[derive(Debug, Error)]
22pub enum Error {
23 #[error("Reqwest: {0}")]
24 Reqwest(#[from] reqwest::Error),
25 #[error("JSON: {0}")]
26 Json(#[from] serde_json::Error),
27 #[error("JWT: {0}")]
28 Jwt(#[from] jsonwebtoken::errors::Error),
29 #[error("Url: {0}")]
30 Url(#[from] url::ParseError),
31 #[error("Precondition: {0}")]
32 Precondition(&'static str),
33}
34
/// Identity of the logged-in user, taken from the decoded auth token claims.
#[derive(Clone, Debug)]
pub struct User {
    /// The token's `sub` (subject) claim.
    pub sub: String,
    /// The token's `email` claim.
    pub email: String,
}
41
42#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
46pub struct Tokens {
47 pub auth_token: String,
48 pub refresh_token: Option<String>,
49 pub csrf_token: Option<String>,
50}
51
/// Cursor-based pagination options for list queries.
///
/// Both fields are optional; unset fields are not sent to the server.
#[derive(Clone, Debug, Default)]
pub struct Pagination {
    /// Opaque cursor from a previous list response.
    pub cursor: Option<String>,
    /// Maximum number of records to return.
    pub limit: Option<usize>,
}

impl Pagination {
    /// Builds pagination options from an optional limit and an optional cursor.
    pub fn with(limit: impl Into<Option<usize>>, cursor: impl Into<Option<String>>) -> Pagination {
        let limit: Option<usize> = limit.into();
        let cursor: Option<String> = cursor.into();
        Pagination { cursor, limit }
    }

    /// Pagination with only the limit set.
    pub fn with_limit(limit: impl Into<Option<usize>>) -> Pagination {
        Pagination {
            limit: limit.into(),
            ..Pagination::default()
        }
    }

    /// Pagination with only the cursor set.
    pub fn with_cursor(cursor: impl Into<Option<String>>) -> Pagination {
        Pagination {
            cursor: cursor.into(),
            ..Pagination::default()
        }
    }
}
74
75#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
76pub enum DbEvent {
77 Update(Option<serde_json::Value>),
78 Insert(Option<serde_json::Value>),
79 Delete(Option<serde_json::Value>),
80 Error(String),
81}
82
83#[derive(Clone, Debug, Deserialize)]
84pub struct ListResponse<T> {
85 pub cursor: Option<String>,
86 pub total_count: Option<usize>,
87 pub records: Vec<T>,
88}
89
/// A value usable as a record id in request paths.
///
/// Implemented for `String`, `&String`, `&str` and `i64`. Borrowed strings are
/// passed through without copying; `i64` ids are formatted in decimal.
pub trait RecordId<'a> {
    /// Converts the id into the text interpolated into the request path.
    fn serialized_id(self) -> Cow<'a, str>;
}

impl RecordId<'_> for String {
    fn serialized_id(self) -> Cow<'static, str> {
        Cow::from(self)
    }
}

impl<'a> RecordId<'a> for &'a String {
    fn serialized_id(self) -> Cow<'a, str> {
        Cow::from(self.as_str())
    }
}

impl<'a> RecordId<'a> for &'a str {
    fn serialized_id(self) -> Cow<'a, str> {
        Cow::from(self)
    }
}

impl RecordId<'_> for i64 {
    fn serialized_id(self) -> Cow<'static, str> {
        Cow::from(self.to_string())
    }
}

/// Arguments accepted by `RecordApi::read`: a record id plus optional expansions.
///
/// Every [`RecordId`] is a valid argument on its own (with no expansions); use
/// [`ReadArguments`] to also request expansions.
pub trait ReadArgumentsTrait<'a> {
    /// Converts the id into its path representation.
    fn serialized_id(self) -> Cow<'a, str>;
    /// Names to send as the `expand` query parameter, if any.
    fn expand(&self) -> Option<&'a [&'a str]>;
}

impl<'a, T: RecordId<'a>> ReadArgumentsTrait<'a> for T {
    fn serialized_id(self) -> Cow<'a, str> {
        // Fully qualified so it is obvious this forwards to `RecordId` rather
        // than recursing into `ReadArgumentsTrait::serialized_id`.
        <T as RecordId<'a>>::serialized_id(self)
    }

    fn expand(&self) -> Option<&'a [&'a str]> {
        None
    }
}

/// A record id together with optional expansions, for `RecordApi::read`.
#[derive(Debug, Default)]
pub struct ReadArguments<'a, T: RecordId<'a>> {
    /// The record to read.
    pub id: T,
    /// Names sent, comma-joined, as the `expand` query parameter.
    pub expand: Option<&'a [&'a str]>,
}

impl<'a, T: RecordId<'a>> ReadArgumentsTrait<'a> for ReadArguments<'a, T> {
    fn serialized_id(self) -> Cow<'a, str> {
        <T as RecordId<'a>>::serialized_id(self.id)
    }

    fn expand(&self) -> Option<&'a [&'a str]> {
        self.expand
    }
}
148
149struct ThinClient {
150 client: reqwest::Client,
151 url: url::Url,
152}
153
154impl ThinClient {
155 async fn fetch<T: Serialize>(
156 &self,
157 path: &str,
158 headers: HeaderMap,
159 method: Method,
160 body: Option<&T>,
161 query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
162 ) -> Result<reqwest::Response, Error> {
163 assert!(path.starts_with("/"));
164
165 let mut url = self.url.clone();
166 url.set_path(path);
167
168 if let Some(query_params) = query_params {
169 let mut params = url.query_pairs_mut();
170 for (key, value) in query_params {
171 params.append_pair(key, value);
172 }
173 }
174
175 let request = {
176 let mut builder = self.client.request(method, url).headers(headers);
177 if let Some(ref body) = body {
178 builder = builder.body(serde_json::to_string(body)?);
179 }
180 builder.build()?
181 };
182
183 return Ok(self.client.execute(request).await?);
184 }
185}
186
187#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
188struct JwtTokenClaims {
189 sub: String,
190 iat: i64,
191 exp: i64,
192 email: String,
193 csrf_token: String,
194}
195
196fn decode_auth_token<T: DeserializeOwned>(token: &str) -> Result<T, Error> {
197 let decoding_key = jsonwebtoken::DecodingKey::from_secret(&[]);
198
199 let mut validation = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::EdDSA);
201 validation.insecure_disable_signature_validation();
202
203 return Ok(jsonwebtoken::decode::<T>(token, &decoding_key, &validation).map(|data| data.claims)?);
204}
205
206#[derive(Clone)]
207pub struct RecordApi {
208 client: Arc<ClientState>,
209 name: String,
210}
211
212#[derive(Default)]
213pub struct ListArguments<'a> {
214 pub pagination: Pagination,
215 pub order: Option<&'a [&'a str]>,
216 pub filters: Option<&'a [&'a str]>,
217 pub expand: Option<&'a [&'a str]>,
218 pub count: bool,
219}
220
221impl RecordApi {
222 pub async fn list<T: DeserializeOwned>(
223 &self,
224 args: ListArguments<'_>,
225 ) -> Result<ListResponse<T>, Error> {
226 let mut params: Vec<(Cow<'static, str>, Cow<'static, str>)> = vec![];
227 if let Some(cursor) = args.pagination.cursor {
228 params.push((Cow::Borrowed("cursor"), Cow::Owned(cursor)));
229 }
230
231 if let Some(limit) = args.pagination.limit {
232 params.push((Cow::Borrowed("limit"), Cow::Owned(limit.to_string())));
233 }
234
235 #[inline]
236 fn to_list(slice: &[&str]) -> String {
237 return slice.join(",");
238 }
239
240 if let Some(order) = args.order {
241 if !order.is_empty() {
242 params.push((Cow::Borrowed("order"), Cow::Owned(to_list(order))));
243 }
244 }
245
246 if let Some(expand) = args.expand {
247 if !expand.is_empty() {
248 params.push((Cow::Borrowed("expand"), Cow::Owned(to_list(expand))));
249 }
250 }
251
252 if args.count {
253 params.push((Cow::Borrowed("count"), Cow::Borrowed("true")));
254 }
255
256 if let Some(filters) = args.filters {
257 for filter in filters {
258 let Some((name_op, value)) = filter.split_once("=") else {
259 panic!("Filter '{filter}' does not match: 'name[op]=value'");
260 };
261
262 params.push((
263 Cow::Owned(name_op.to_string()),
264 Cow::Owned(value.to_string()),
265 ));
266 }
267 }
268
269 let response = self
270 .client
271 .fetch(
272 &format!("/{RECORD_API}/{}", self.name),
273 Method::GET,
274 None::<&()>,
275 Some(¶ms),
276 )
277 .await?;
278
279 return Ok(response.json().await?);
280 }
281
282 pub async fn read<'a, T: DeserializeOwned>(
283 &self,
284 args: impl ReadArgumentsTrait<'a>,
285 ) -> Result<T, Error> {
286 let expand = args
287 .expand()
288 .map(|e| vec![(Cow::Borrowed("expand"), Cow::Owned(e.join(",")))]);
289
290 let response = self
291 .client
292 .fetch(
293 &format!(
294 "/{RECORD_API}/{name}/{id}",
295 name = self.name,
296 id = args.serialized_id()
297 ),
298 Method::GET,
299 None::<&()>,
300 expand.as_deref(),
301 )
302 .await?;
303
304 return Ok(response.json().await?);
305 }
306
307 pub async fn create<T: Serialize>(&self, record: T) -> Result<String, Error> {
308 return Ok(self.create_impl(record).await?.swap_remove(0));
309 }
310
311 pub async fn create_bulk<T: Serialize>(&self, record: &[T]) -> Result<Vec<String>, Error> {
312 return self.create_impl(record).await;
313 }
314
315 async fn create_impl<T: Serialize>(&self, record: T) -> Result<Vec<String>, Error> {
316 let response = self
317 .client
318 .fetch(
319 &format!("/{RECORD_API}/{name}", name = self.name),
320 Method::POST,
321 Some(&record),
322 None,
323 )
324 .await?;
325
326 #[derive(Deserialize)]
327 pub struct RecordIdResponse {
328 pub ids: Vec<String>,
329 }
330
331 return Ok(response.json::<RecordIdResponse>().await?.ids);
332 }
333
334 pub async fn update<'a, T: Serialize>(
335 &self,
336 id: impl RecordId<'a>,
337 record: T,
338 ) -> Result<(), Error> {
339 self
340 .client
341 .fetch(
342 &format!(
343 "/{RECORD_API}/{name}/{id}",
344 name = self.name,
345 id = id.serialized_id()
346 ),
347 Method::PATCH,
348 Some(&record),
349 None,
350 )
351 .await?;
352
353 return Ok(());
354 }
355
356 pub async fn delete<'a>(&self, id: impl RecordId<'a>) -> Result<(), Error> {
357 self
358 .client
359 .fetch(
360 &format!(
361 "/{RECORD_API}/{name}/{id}",
362 name = self.name,
363 id = id.serialized_id()
364 ),
365 Method::DELETE,
366 None::<&()>,
367 None,
368 )
369 .await?;
370
371 return Ok(());
372 }
373
374 pub async fn subscribe<'a>(
375 &self,
376 id: impl RecordId<'a>,
377 ) -> Result<impl Stream<Item = DbEvent>, Error> {
378 let response = self
380 .client
381 .fetch(
382 &format!(
383 "/{RECORD_API}/{name}/subscribe/{id}",
384 name = self.name,
385 id = id.serialized_id()
386 ),
387 Method::GET,
388 None::<&()>,
389 None,
390 )
391 .await?;
392
393 return Ok(
394 response
395 .bytes_stream()
396 .eventsource()
397 .filter_map(|event_or| async {
398 if let Ok(event) = event_or {
399 if let Ok(db_event) = serde_json::from_str::<DbEvent>(&event.data) {
400 return Some(db_event);
401 }
402 }
403 return None;
404 }),
405 );
406 }
407}
408
409#[derive(Clone, Debug)]
410struct TokenState {
411 state: Option<(Tokens, JwtTokenClaims)>,
412 headers: HeaderMap,
413}
414
415impl TokenState {
416 fn build(tokens: Option<&Tokens>) -> TokenState {
417 let headers = build_headers(tokens);
418 return TokenState {
419 state: tokens.and_then(|tokens| {
420 let Ok(jwt_token) = decode_auth_token::<JwtTokenClaims>(&tokens.auth_token) else {
421 log::error!("Failed to decode auth token.");
422 return None;
423 };
424 return Some((tokens.clone(), jwt_token));
425 }),
426 headers,
427 };
428 }
429}
430
431struct ClientState {
432 client: ThinClient,
433 site: String,
434 tokens: RwLock<TokenState>,
435}
436
437impl ClientState {
438 #[inline]
439 async fn fetch<T: Serialize>(
440 &self,
441 path: &str,
442 method: Method,
443 body: Option<&T>,
444 query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
445 ) -> Result<reqwest::Response, Error> {
446 let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp();
447 if let Some(refresh_token) = refresh_token {
448 let new_tokens = ClientState::refresh_tokens(&self.client, headers, refresh_token).await?;
449
450 headers = new_tokens.headers.clone();
451 *self.tokens.write() = new_tokens;
452 }
453
454 return Ok(
455 self
456 .client
457 .fetch(path, headers, method, body, query_params)
458 .await?
459 .error_for_status()?,
460 );
461 }
462
463 #[inline]
464 fn extract_headers_and_refresh_token_if_exp(&self) -> (HeaderMap, Option<String>) {
465 #[inline]
466 fn should_refresh(jwt: &JwtTokenClaims) -> bool {
467 return jwt.exp - 60 < now() as i64;
468 }
469
470 let tokens = self.tokens.read();
471 let headers = tokens.headers.clone();
472 return match tokens.state {
473 Some(ref state) if should_refresh(&state.1) => (headers, state.0.refresh_token.clone()),
474 _ => (headers, None),
475 };
476 }
477
478 fn extract_headers_refresh_token(&self) -> Result<(HeaderMap, String), Error> {
479 let tokens = self.tokens.read();
480 let Some(ref state) = tokens.state else {
481 return Err(Error::Precondition("Not logged int?"));
482 };
483
484 let Some(ref refresh_token) = state.0.refresh_token else {
485 return Err(Error::Precondition("Missing refresh token"));
486 };
487
488 return Ok((tokens.headers.clone(), refresh_token.clone()));
489 }
490
491 async fn refresh_tokens(
492 client: &ThinClient,
493 headers: HeaderMap,
494 refresh_token: String,
495 ) -> Result<TokenState, Error> {
496 #[derive(Serialize)]
497 struct RefreshRequest<'a> {
498 refresh_token: &'a str,
499 }
500
501 let response = client
502 .fetch(
503 &format!("/{AUTH_API}/refresh"),
504 headers,
505 Method::POST,
506 Some(&RefreshRequest {
507 refresh_token: &refresh_token,
508 }),
509 None,
510 )
511 .await?;
512
513 #[derive(Deserialize)]
514 struct RefreshResponse {
515 auth_token: String,
516 csrf_token: Option<String>,
517 }
518
519 let refresh_response: RefreshResponse = response.json().await?;
520 return Ok(TokenState::build(Some(&Tokens {
521 auth_token: refresh_response.auth_token,
522 refresh_token: Some(refresh_token),
523 csrf_token: refresh_response.csrf_token,
524 })));
525 }
526}
527
528#[derive(Clone)]
529pub struct Client {
530 state: Arc<ClientState>,
531}
532
533impl Client {
534 pub fn new(site: &str, tokens: Option<Tokens>) -> Result<Client, Error> {
535 return Ok(Client {
536 state: Arc::new(ClientState {
537 client: ThinClient {
538 client: reqwest::Client::new(),
539 url: url::Url::parse(site)?,
540 },
541 site: site.to_string(),
542 tokens: RwLock::new(TokenState::build(tokens.as_ref())),
543 }),
544 });
545 }
546
547 pub fn site(&self) -> String {
548 return self.state.site.clone();
549 }
550
551 pub fn tokens(&self) -> Option<Tokens> {
552 return self.state.tokens.read().state.as_ref().map(|x| x.0.clone());
553 }
554
555 pub fn user(&self) -> Option<User> {
556 if let Some(state) = &self.state.tokens.read().state {
557 return Some(User {
558 sub: state.1.sub.clone(),
559 email: state.1.email.clone(),
560 });
561 }
562 return None;
563 }
564
565 pub fn records(&self, api_name: &str) -> RecordApi {
566 return RecordApi {
567 client: self.state.clone(),
568 name: api_name.to_string(),
569 };
570 }
571
572 pub async fn refresh(&self) -> Result<(), Error> {
573 let (headers, refresh_token) = self.state.extract_headers_refresh_token()?;
574 let new_tokens =
575 ClientState::refresh_tokens(&self.state.client, headers, refresh_token).await?;
576
577 *self.state.tokens.write() = new_tokens;
578 return Ok(());
579 }
580
581 pub async fn login(&self, email: &str, password: &str) -> Result<Tokens, Error> {
582 #[derive(Serialize)]
583 struct Credentials<'a> {
584 email: &'a str,
585 password: &'a str,
586 }
587
588 let response = self
589 .state
590 .fetch(
591 &format!("/{AUTH_API}/login"),
592 Method::POST,
593 Some(&Credentials { email, password }),
594 None,
595 )
596 .await?;
597
598 let tokens: Tokens = response.json().await?;
599 self.update_tokens(Some(&tokens));
600 return Ok(tokens);
601 }
602
603 pub async fn logout(&self) -> Result<(), Error> {
604 #[derive(Serialize)]
605 struct LogoutRequest {
606 refresh_token: String,
607 }
608
609 let response_or = match self.state.extract_headers_refresh_token() {
610 Ok((_headers, refresh_token)) => {
611 self
612 .state
613 .fetch(
614 &format!("/{AUTH_API}/logout"),
615 Method::POST,
616 Some(&LogoutRequest { refresh_token }),
617 None,
618 )
619 .await
620 }
621 _ => {
622 self
623 .state
624 .fetch(
625 &format!("/{AUTH_API}/logout"),
626 Method::GET,
627 None::<&()>,
628 None,
629 )
630 .await
631 }
632 };
633
634 self.update_tokens(None);
635
636 return response_or.map(|_| ());
637 }
638
639 fn update_tokens(&self, tokens: Option<&Tokens>) -> TokenState {
640 let state = TokenState::build(tokens);
641
642 *self.state.tokens.write() = state.clone();
643 if let Some(ref s) = state.state {
646 let now = now();
647 if s.1.exp < now as i64 {
648 log::warn!("Token expired");
649 }
650 }
651
652 return state;
653 }
654}
655
656fn build_headers(tokens: Option<&Tokens>) -> HeaderMap {
657 let mut base = HeaderMap::new();
658 base.insert("Content-Type", HeaderValue::from_static("application/json"));
659
660 if let Some(tokens) = tokens {
661 if let Ok(value) = HeaderValue::from_str(&format!("Bearer {}", tokens.auth_token)) {
662 base.insert("Authorization", value);
663 } else {
664 log::error!("Failed to build bearer token.");
665 }
666
667 if let Some(ref refresh) = tokens.refresh_token {
668 if let Ok(value) = HeaderValue::from_str(refresh) {
669 base.insert("Refresh-Token", value);
670 } else {
671 log::error!("Failed to build refresh token header.");
672 }
673 }
674
675 if let Some(ref csrf) = tokens.csrf_token {
676 if let Ok(value) = HeaderValue::from_str(csrf) {
677 base.insert("CSRF-Token", value);
678 } else {
679 log::error!("Failed to build refresh token header.");
680 }
681 }
682 }
683
684 return base;
685}
686
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// # Panics
/// If the system clock reports a time before the epoch.
fn now() -> u64 {
    let since_epoch = std::time::UNIX_EPOCH
        .elapsed()
        .expect("Duration since epoch");
    since_epoch.as_secs()
}
693
/// Path prefix of the auth API, relative to the site root.
const AUTH_API: &str = "api/auth/v1";
/// Path prefix of the record API, relative to the site root.
const RECORD_API: &str = "api/records/v1";
696
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `RecordApi` futures can be moved to another task:
    /// `tokio::spawn` only accepts `Send` futures.
    ///
    /// The test assumes nothing is serving at the address, so each read must fail.
    #[tokio::test]
    async fn is_send_test() {
        let client = Client::new("http://127.0.0.1:4000", None).unwrap();
        let api = client.records("simple_strict_table");

        for _attempt in 0..2 {
            let task_api = api.clone();
            let handle = tokio::spawn(async move {
                let result = task_api.read::<serde_json::Value>(0).await;
                assert!(result.is_err());
            });
            handle.await.unwrap();
        }
    }
}
718}