1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
mod types;
use reqwest::{header, Response};
use serde::Serialize;
use std::time::Duration;
use types::ApiResult;
pub use types::{Error, NotFoundMapping, Result};

// Performs one API round-trip for the request builder `$send`:
// sends it, turns a non-success HTTP status into an error, decodes the body as
// `crate::ApiResult<_>`, and converts that with `Into` into the result type the
// call site expects.
//
// Must be expanded inside an `async fn` whose error type can absorb
// `reqwest::Error` through `?`.
macro_rules! execute {
    ($send: expr) => {{
        let sent: crate::Response = $send.send().await?;
        let checked = sent.error_for_status()?;
        let body: crate::ApiResult<_> = checked.json().await?;
        body.into()
    }};
}

// Client-side cache module, compiled only when the `cache` feature is enabled;
// supplies the `cache::Cache` type and `cache::new_cache` constructor used by
// `Client`.
#[cfg(feature = "cache")]
mod cache;

/// KV Client.
///
/// Holds the service endpoint and an HTTP client preconfigured with the
/// `Authorization` header; construct it with `Client::new`.
#[derive(Debug)]
pub struct Client {
    /// Base URL, normalized by `new` to end with `/`; request URLs are formed
    /// by appending the key directly to it.
    endpoint: String,
    /// HTTP client whose default headers carry the `Authorization` token.
    client: reqwest::Client,

    /// Local cache, present only when the `cache` feature is enabled.
    #[cfg(feature = "cache")]
    cache: cache::Cache,
}

/// Error when creating KV Client.
#[derive(thiserror::Error, Debug)]
pub enum ClientError {
    #[error("token format error {0}")]
    Token(#[from] reqwest::header::InvalidHeaderValue),
    #[error("client build error {0}")]
    Client(#[from] reqwest::Error),
}

impl Client {
    /// Create a client for the KV service at `endpoint`, authenticating every
    /// request with `token`.
    ///
    /// `token` is sent verbatim as the `Authorization` header value, so it must
    /// already contain whatever scheme prefix the server expects. A trailing `/`
    /// is appended to `endpoint` when missing, because keys are appended to it
    /// directly to form request URLs.
    ///
    /// If the `cache` feature is enabled, `cache_size` and `expire_ttl`
    /// configure the local cache.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::Token`] if `token` is not a valid header value, and
    /// [`ClientError::Client`] if the underlying HTTP client fails to build.
    pub fn new<T: Into<String>, E: Into<String>>(
        endpoint: E,
        token: T,
        #[cfg(feature = "cache")] cache_size: usize,
        #[cfg(feature = "cache")] expire_ttl: std::time::Duration,
    ) -> std::result::Result<Self, ClientError> {
        // Normalize the endpoint so `key_url` can simply append the key.
        let mut endpoint: String = endpoint.into();
        if !endpoint.ends_with('/') {
            endpoint.push('/');
        }

        let token: String = token.into();
        let mut auth = header::HeaderValue::from_str(&token)?;
        // `Client` derives `Debug`, and reqwest's client includes its default
        // headers when formatted with `Debug`; a sensitive value prints as
        // `Sensitive` instead of leaking the token into logs.
        auth.set_sensitive(true);
        let mut headers = header::HeaderMap::new();
        headers.insert(header::AUTHORIZATION, auth);

        Ok(Self {
            endpoint,
            client: reqwest::Client::builder()
                .default_headers(headers)
                .build()?,
            #[cfg(feature = "cache")]
            cache: cache::new_cache(cache_size, expire_ttl.into()),
        })
    }

    /// Build the request URL for `key` by appending it to the normalized endpoint.
    ///
    /// NOTE(review): `key` is inserted without percent-encoding, so characters
    /// such as `?`, `#` or spaces alter the URL's meaning — confirm callers only
    /// pass URL-safe keys.
    fn key_url(&self, key: &str) -> String {
        format!("{}{}", self.endpoint, key)
    }

    /// Get value of a key, deserialized as `T`.
    #[cfg(not(feature = "cache"))]
    pub async fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Result<T> {
        execute!(self.client.get(self.key_url(key)))
    }

    /// Set a key value pair; `value` is sent as a JSON body.
    ///
    /// With the `cache` feature, the value is cached only after the server
    /// accepted the write.
    pub async fn put<T: Serialize + ?Sized>(&self, key: &str, value: &T) -> Result<()> {
        let result: Result<()> = execute!(self.client.put(self.key_url(key)).json(value));
        #[cfg(feature = "cache")]
        if result.is_ok() {
            self.set_cache(key, value);
        }
        result
    }

    /// Set a key value pair with ttl; `value` is sent as a JSON body.
    ///
    /// `ttl` is sent in the `ttl` request header as whole seconds; any
    /// sub-second part is truncated (so a TTL below one second is sent as `0`).
    ///
    /// NOTE(review): with the `cache` feature the value is cached via
    /// `set_cache`, which does not receive `ttl` — confirm the cache cannot
    /// outlive a shorter server-side TTL.
    pub async fn put_with_ttl<T: Serialize + ?Sized>(
        &self,
        key: &str,
        value: &T,
        ttl: Duration,
    ) -> Result<()> {
        let result: Result<()> = execute!(self
            .client
            .put(self.key_url(key))
            .header("ttl", ttl.as_secs())
            .json(value));
        #[cfg(feature = "cache")]
        if result.is_ok() {
            self.set_cache(key, value);
        }
        result
    }

    /// Delete a key value pair.
    ///
    /// With the `cache` feature, the cached entry is pruned only after the
    /// server accepted the delete.
    pub async fn delete(&self, key: &str) -> Result<()> {
        let result: Result<()> = execute!(self.client.delete(self.key_url(key)));
        #[cfg(feature = "cache")]
        if result.is_ok() {
            self.prune_cached(key);
        }
        result
    }
}