1use std::collections::BTreeMap;
2use std::convert::TryInto;
3use std::time::{Duration, SystemTime};
4
5use base64::prelude::*;
6use log::{debug, error};
7use percent_encoding::{utf8_percent_encode, AsciiSet, NON_ALPHANUMERIC};
8
9use http::header::{ACCEPT, CONTENT_TYPE};
10use http::status::StatusCode;
11use http::{HeaderName, HeaderValue, Request};
12use hyper::{body::Bytes, Body};
13use hyper::{client::connect::HttpConnector, Client as HttpClient};
14use hyper_rustls::HttpsConnector;
15
16use aws_sigv4::http_request::{sign, SignableRequest, SigningParams, SigningSettings};
17
18use serde::de::Error as DeError;
19use serde::{Deserialize, Deserializer, Serialize, Serializer};
20
21mod error;
22
23pub use error::Error;
24
/// Timeout applied by `dispatch` when the caller passes `None`; the poll methods also
/// add it as a margin on top of the poll duration they send to the server.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
/// Poll duration used by `poll_item` and `poll_range` when the caller passes `None`.
const DEFAULT_POLL_TIMEOUT: Duration = Duration::from_secs(300);
/// Service name given to the request signer.
const SERVICE: &str = "k2v";
/// Request header that carries the hex-encoded SHA-256 digest of the request body.
const AMZ_CONTENT_SHA256: HeaderName = HeaderName::from_static("x-amz-content-sha256");
/// Header that carries the causality token, read from responses and sent on writes.
const GARAGE_CAUSALITY_TOKEN: HeaderName = HeaderName::from_static("x-garage-causality-token");
30
/// Percent-encoding set used for query-string keys and values: starts from
/// `NON_ALPHANUMERIC` and exempts `_`, `-`, `.` and `~`.
const STRICT_ENCODE_SET: AsciiSet = NON_ALPHANUMERIC
    .remove(b'_')
    .remove(b'-')
    .remove(b'.')
    .remove(b'~');
/// Percent-encoding set used for the partition key in the URL path: same exemptions
/// as `STRICT_ENCODE_SET`, plus `/`.
const PATH_ENCODE_SET: AsciiSet = NON_ALPHANUMERIC
    .remove(b'/')
    .remove(b'_')
    .remove(b'-')
    .remove(b'.')
    .remove(b'~');
42
/// Settings needed to build a [`K2vClient`].
pub struct K2vClientConfig {
    /// Base URL of the server; request URLs are built as `{endpoint}/{bucket}[/...]`.
    pub endpoint: String,
    /// Region passed to the request signer.
    pub region: String,
    /// Access key passed to the request signer.
    pub aws_access_key_id: String,
    /// Secret key passed to the request signer.
    pub aws_secret_access_key: String,
    /// Bucket name, used as the first path segment of every request URL.
    pub bucket: String,
    /// `User-Agent` header value; when `None`, `k2v/<crate version>` is used.
    pub user_agent: Option<String>,
}
51
/// Client that sends signed HTTP requests to a K2V endpoint.
pub struct K2vClient {
    /// Endpoint, bucket and credentials supplied at construction.
    config: K2vClientConfig,
    /// Pre-validated `User-Agent` header value attached to every request.
    user_agent: HeaderValue,
    /// Underlying hyper client (HTTP or HTTPS, HTTP/1 and HTTP/2 enabled).
    client: HttpClient<HttpsConnector<HttpConnector>>,
}
58
59impl K2vClient {
60 pub fn new(config: K2vClientConfig) -> Result<Self, Error> {
62 let connector = hyper_rustls::HttpsConnectorBuilder::new()
63 .with_native_roots()
64 .https_or_http()
65 .enable_http1()
66 .enable_http2()
67 .build();
68 let client = HttpClient::builder().build(connector);
69 let user_agent: std::borrow::Cow<str> = match &config.user_agent {
70 Some(ua) => ua.into(),
71 None => format!("k2v/{}", env!("CARGO_PKG_VERSION")).into(),
72 };
73 let user_agent = HeaderValue::from_str(&user_agent)
74 .map_err(|_| Error::Message("invalid user agent".into()))?;
75 Ok(K2vClient {
76 config,
77 client,
78 user_agent,
79 })
80 }
81
82 pub async fn read_item(
84 &self,
85 partition_key: &str,
86 sort_key: &str,
87 ) -> Result<CausalValue, Error> {
88 let url = self.build_url(Some(partition_key), &[("sort_key", sort_key)]);
89 let req = Request::get(url)
90 .header(ACCEPT, "application/octet-stream, application/json")
91 .body(Bytes::new())?;
92 let res = self.dispatch(req, None).await?;
93
94 let causality = res
95 .causality_token
96 .ok_or_else(|| Error::InvalidResponse("missing causality token".into()))?;
97
98 if res.status == StatusCode::NO_CONTENT {
99 return Ok(CausalValue {
100 causality,
101 value: vec![K2vValue::Tombstone],
102 });
103 }
104
105 match res.content_type.as_deref() {
106 Some("application/octet-stream") => Ok(CausalValue {
107 causality,
108 value: vec![K2vValue::Value(res.body.to_vec())],
109 }),
110 Some("application/json") => {
111 let value = serde_json::from_slice(&res.body)?;
112 Ok(CausalValue { causality, value })
113 }
114 Some(ct) => Err(Error::InvalidResponse(
115 format!("invalid content type: {}", ct).into(),
116 )),
117 None => Err(Error::InvalidResponse("missing content type".into())),
118 }
119 }
120
121 pub async fn poll_item(
124 &self,
125 partition_key: &str,
126 sort_key: &str,
127 causality: CausalityToken,
128 timeout: Option<Duration>,
129 ) -> Result<Option<CausalValue>, Error> {
130 let timeout = timeout.unwrap_or(DEFAULT_POLL_TIMEOUT);
131
132 let url = self.build_url(
133 Some(partition_key),
134 &[
135 ("sort_key", sort_key),
136 ("causality_token", &causality.0),
137 ("timeout", &timeout.as_secs().to_string()),
138 ],
139 );
140 let req = Request::get(url)
141 .header(ACCEPT, "application/octet-stream, application/json")
142 .body(Bytes::new())?;
143
144 let res = self.dispatch(req, Some(timeout + DEFAULT_TIMEOUT)).await?;
145
146 if res.status == StatusCode::NOT_MODIFIED {
147 return Ok(None);
148 }
149
150 let causality = res
151 .causality_token
152 .ok_or_else(|| Error::InvalidResponse("missing causality token".into()))?;
153
154 if res.status == StatusCode::NO_CONTENT {
155 return Ok(Some(CausalValue {
156 causality,
157 value: vec![K2vValue::Tombstone],
158 }));
159 }
160
161 match res.content_type.as_deref() {
162 Some("application/octet-stream") => Ok(Some(CausalValue {
163 causality,
164 value: vec![K2vValue::Value(res.body.to_vec())],
165 })),
166 Some("application/json") => {
167 let value = serde_json::from_slice(&res.body)?;
168 Ok(Some(CausalValue { causality, value }))
169 }
170 Some(ct) => Err(Error::InvalidResponse(
171 format!("invalid content type: {}", ct).into(),
172 )),
173 None => Err(Error::InvalidResponse("missing content type".into())),
174 }
175 }
176
177 pub async fn poll_range(
180 &self,
181 partition_key: &str,
182 filter: Option<PollRangeFilter<'_>>,
183 seen_marker: Option<&str>,
184 timeout: Option<Duration>,
185 ) -> Result<Option<(BTreeMap<String, CausalValue>, String)>, Error> {
186 let timeout = timeout.unwrap_or(DEFAULT_POLL_TIMEOUT);
187
188 let request = PollRangeRequest {
189 filter: filter.unwrap_or_default(),
190 seen_marker,
191 timeout: timeout.as_secs(),
192 };
193
194 let url = self.build_url(Some(partition_key), &[("poll_range", "")]);
195 let payload = serde_json::to_vec(&request)?;
196 let req = Request::post(url).body(Bytes::from(payload))?;
197
198 let res = self.dispatch(req, Some(timeout + DEFAULT_TIMEOUT)).await?;
199
200 if res.status == StatusCode::NOT_MODIFIED {
201 return Ok(None);
202 }
203
204 let resp: PollRangeResponse = serde_json::from_slice(&res.body)?;
205
206 let items = resp
207 .items
208 .into_iter()
209 .map(|BatchReadItem { sk, ct, v }| {
210 (
211 sk,
212 CausalValue {
213 causality: ct,
214 value: v,
215 },
216 )
217 })
218 .collect::<BTreeMap<_, _>>();
219
220 Ok(Some((items, resp.seen_marker)))
221 }
222
223 pub async fn insert_item(
225 &self,
226 partition_key: &str,
227 sort_key: &str,
228 value: Vec<u8>,
229 causality: Option<CausalityToken>,
230 ) -> Result<(), Error> {
231 let url = self.build_url(Some(partition_key), &[("sort_key", sort_key)]);
232 let mut req = Request::put(url);
233 if let Some(causality) = causality {
234 req = req.header(GARAGE_CAUSALITY_TOKEN, &causality.0);
235 }
236 let req = req.body(Bytes::from(value))?;
237
238 self.dispatch(req, None).await?;
239 Ok(())
240 }
241
242 pub async fn delete_item(
244 &self,
245 partition_key: &str,
246 sort_key: &str,
247 causality: CausalityToken,
248 ) -> Result<(), Error> {
249 let url = self.build_url(Some(partition_key), &[("sort_key", sort_key)]);
250 let req = Request::delete(url)
251 .header(GARAGE_CAUSALITY_TOKEN, &causality.0)
252 .body(Bytes::new())?;
253
254 self.dispatch(req, None).await?;
255 Ok(())
256 }
257
258 pub async fn read_index(
261 &self,
262 filter: Filter<'_>,
263 ) -> Result<PaginatedRange<PartitionInfo>, Error> {
264 let params = filter.query_params();
265 let url = self.build_url(None, ¶ms);
266 let req = Request::get(url).body(Bytes::new())?;
267
268 let res = self.dispatch(req, None).await?;
269
270 let resp: ReadIndexResponse = serde_json::from_slice(&res.body)?;
271
272 let items = resp
273 .partition_keys
274 .into_iter()
275 .map(|ReadIndexItem { pk, info }| (pk, info))
276 .collect();
277
278 Ok(PaginatedRange {
279 items,
280 next_start: resp.next_start,
281 })
282 }
283
284 pub async fn insert_batch(&self, operations: &[BatchInsertOp<'_>]) -> Result<(), Error> {
288 let url = self.build_url::<&str>(None, &[]);
289 let payload = serde_json::to_vec(operations)?;
290 let req = Request::post(url).body(payload.into())?;
291
292 self.dispatch(req, None).await?;
293 Ok(())
294 }
295
296 pub async fn read_batch(
298 &self,
299 operations: &[BatchReadOp<'_>],
300 ) -> Result<Vec<PaginatedRange<CausalValue>>, Error> {
301 let url = self.build_url(None, &[("search", "")]);
302 let payload = serde_json::to_vec(operations)?;
303 let req = Request::post(url).body(payload.into())?;
304
305 let res = self.dispatch(req, None).await?;
306
307 let resp: Vec<BatchReadResponse> = serde_json::from_slice(&res.body)?;
308
309 Ok(resp
310 .into_iter()
311 .map(|e| PaginatedRange {
312 items: e
313 .items
314 .into_iter()
315 .map(|BatchReadItem { sk, ct, v }| {
316 (
317 sk,
318 CausalValue {
319 causality: ct,
320 value: v,
321 },
322 )
323 })
324 .collect(),
325 next_start: e.next_start,
326 })
327 .collect())
328 }
329
330 pub async fn delete_batch(&self, operations: &[BatchDeleteOp<'_>]) -> Result<Vec<u64>, Error> {
333 let url = self.build_url(None, &[("delete", "")]);
334 let payload = serde_json::to_vec(operations)?;
335 let req = Request::post(url).body(payload.into())?;
336
337 let res = self.dispatch(req, None).await?;
338
339 let resp: Vec<BatchDeleteResponse> = serde_json::from_slice(&res.body)?;
340
341 Ok(resp.into_iter().map(|r| r.deleted_items).collect())
342 }
343
    /// Sign and send `req`, then convert the HTTP response into a [`Response`].
    ///
    /// Steps, in order: set the `User-Agent` header, set the body-digest header, sign
    /// the request with the configured credentials, send it, and map the status code.
    /// `timeout` defaults to `DEFAULT_TIMEOUT` when `None`.
    ///
    /// Status mapping:
    /// - `200`: the body is read and returned.
    /// - `204` and `304`: an empty body is returned.
    /// - `404`: `Error::NotFound`.
    /// - other 4xx/5xx: `Error::Remote`, filled from the JSON error body when it parses,
    ///   otherwise with code `"unknown"`, the raw body as message and path `"?"`.
    /// - anything else: `Error::InvalidResponse`.
    ///
    /// # Errors
    /// In addition to the above, returns `Error::Timeout` when the timer fires before
    /// the request future completes, and propagates signing, transport and header
    /// decoding errors.
    async fn dispatch(
        &self,
        mut req: Request<Bytes>,
        timeout: Option<Duration>,
    ) -> Result<Response, Error> {
        req.headers_mut()
            .insert(http::header::USER_AGENT, self.user_agent.clone());

        // Hex-encoded SHA-256 of the body. Inserted before signing so that the signer
        // sees this header on the request.
        use sha2::{Digest, Sha256};
        let mut hasher = Sha256::new();
        hasher.update(req.body());
        let hash = hex::encode(&hasher.finalize());
        // The unwrap converts a hex string into a header value.
        req.headers_mut()
            .insert(AMZ_CONTENT_SHA256, hash.try_into().unwrap());

        debug!("request uri: {:?}", req.uri());

        // Sign the request with the configured key, region and service name.
        let signing_settings = SigningSettings::default();
        let signing_params = SigningParams::builder()
            .access_key(&self.config.aws_access_key_id)
            .secret_key(&self.config.aws_secret_access_key)
            .region(&self.config.region)
            .service_name(SERVICE)
            .time(SystemTime::now())
            .settings(signing_settings)
            .build()?;
        let signable_request = SignableRequest::from(&req);

        let (signing_instructions, _signature) =
            sign(signable_request, &signing_params)?.into_parts();
        signing_instructions.apply_to_request(&mut req);

        // Race the request against the timer.
        // NOTE(review): only the `request` future is raced; the body reads further down
        // do not appear to be bounded by `timeout` — confirm this is intended.
        let res = tokio::select! {
            res = self.client.request(req.map(Body::from)) => res?,
            _ = tokio::time::sleep(timeout.unwrap_or(DEFAULT_TIMEOUT)) => {
                return Err(Error::Timeout);
            }
        };

        // Pull out the two headers callers care about; non-string header values are
        // reported as errors through `?`.
        let (mut res, body) = res.into_parts();
        let causality_token = match res.headers.remove(GARAGE_CAUSALITY_TOKEN) {
            Some(v) => Some(CausalityToken(v.to_str()?.to_string())),
            None => None,
        };
        let content_type = match res.headers.remove(CONTENT_TYPE) {
            Some(v) => Some(v.to_str()?.to_string()),
            None => None,
        };

        let body = match res.status {
            StatusCode::OK => hyper::body::to_bytes(body).await?,
            StatusCode::NO_CONTENT => Bytes::new(),
            StatusCode::NOT_FOUND => return Err(Error::NotFound),
            StatusCode::NOT_MODIFIED => Bytes::new(),
            s => {
                // Best effort: a failure to read the error body yields an empty body.
                let err_body = hyper::body::to_bytes(body).await.unwrap_or_default();
                // Render the body as text when it is UTF-8, as base64 otherwise.
                let err_body_str = std::str::from_utf8(&err_body)
                    .map(String::from)
                    .unwrap_or_else(|_| BASE64_STANDARD.encode(&err_body));

                if s.is_client_error() || s.is_server_error() {
                    error!("Error response {}: {}", res.status, err_body_str);
                    let err = match serde_json::from_slice::<ErrorResponse>(&err_body) {
                        Ok(err) => Error::Remote(
                            res.status,
                            err.code.into(),
                            err.message.into(),
                            err.path.into(),
                        ),
                        Err(_) => Error::Remote(
                            res.status,
                            "unknown".into(),
                            err_body_str.into(),
                            "?".into(),
                        ),
                    };
                    return Err(err);
                } else {
                    // Any other status (including 2xx codes not listed above) is
                    // treated as a protocol violation.
                    let msg = format!(
                        "Unexpected response code {}. Response body: {}",
                        res.status, err_body_str
                    );
                    error!("{}", msg);
                    return Err(Error::InvalidResponse(msg.into()));
                }
            }
        };
        debug!(
            "Response body: {}",
            std::str::from_utf8(&body)
                .map(String::from)
                .unwrap_or_else(|_| BASE64_STANDARD.encode(&body))
        );

        Ok(Response {
            body,
            status: res.status,
            causality_token,
            content_type,
        })
    }
449
450 fn build_url<V: AsRef<str>>(&self, partition_key: Option<&str>, query: &[(&str, V)]) -> String {
451 let mut url = format!("{}/{}", self.config.endpoint, self.config.bucket);
452 if let Some(pk) = partition_key {
453 url.push('/');
454 url.extend(utf8_percent_encode(pk, &PATH_ENCODE_SET));
455 }
456 if !query.is_empty() {
457 url.push('?');
458 for (i, (k, v)) in query.iter().enumerate() {
459 if i > 0 {
460 url.push('&');
461 }
462 url.extend(utf8_percent_encode(k, &STRICT_ENCODE_SET));
463 url.push('=');
464 url.extend(utf8_percent_encode(v.as_ref(), &STRICT_ENCODE_SET));
465 }
466 }
467 url
468 }
469}
470
/// Causality token returned by the server and sent back on writes and polls.
///
/// Serialized transparently as its inner string.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
#[serde(transparent)]
pub struct CausalityToken(String);
475
476impl From<String> for CausalityToken {
477 fn from(v: String) -> Self {
478 CausalityToken(v)
479 }
480}
481
482impl From<CausalityToken> for String {
483 fn from(v: CausalityToken) -> Self {
484 v.0
485 }
486}
487
488impl AsRef<str> for CausalityToken {
489 fn as_ref(&self) -> &str {
490 &self.0
491 }
492}
493
/// A single stored value: either raw bytes or a tombstone.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum K2vValue {
    /// No value; serialized as JSON `null`.
    Tombstone,
    /// Raw bytes; serialized as a base64 string.
    Value(Vec<u8>),
}

/// Wrap raw bytes as a [`K2vValue::Value`].
impl From<Vec<u8>> for K2vValue {
    fn from(bytes: Vec<u8>) -> Self {
        Self::Value(bytes)
    }
}

/// `Some(bytes)` becomes a value, `None` becomes a tombstone.
impl From<Option<Vec<u8>>> for K2vValue {
    fn from(maybe_bytes: Option<Vec<u8>>) -> Self {
        maybe_bytes.map_or(Self::Tombstone, Self::Value)
    }
}
515
516impl<'de> Deserialize<'de> for K2vValue {
517 fn deserialize<D>(d: D) -> Result<Self, D::Error>
518 where
519 D: Deserializer<'de>,
520 {
521 let val: Option<&str> = Option::deserialize(d)?;
522 Ok(match val {
523 Some(s) => K2vValue::Value(
524 BASE64_STANDARD
525 .decode(s)
526 .map_err(|_| DeError::custom("invalid base64"))?,
527 ),
528 None => K2vValue::Tombstone,
529 })
530 }
531}
532
533impl Serialize for K2vValue {
534 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
535 where
536 S: Serializer,
537 {
538 match self {
539 K2vValue::Tombstone => serializer.serialize_none(),
540 K2vValue::Value(v) => {
541 let b64 = BASE64_STANDARD.encode(v);
542 serializer.serialize_str(&b64)
543 }
544 }
545 }
546}
547
/// The value(s) of one item together with the causality token returned with them.
#[derive(Debug, Clone, Serialize)]
pub struct CausalValue {
    /// Token taken from the response (header or `ct` field) for this item.
    pub causality: CausalityToken,
    /// Values of the item; the single-item reads produce exactly one element for
    /// tombstones and raw bodies, and a parsed list for JSON bodies.
    pub value: Vec<K2vValue>,
}
554
/// One page of results from a range read.
#[derive(Debug, Clone)]
pub struct PaginatedRange<V> {
    /// Results keyed by partition key (`read_index`) or sort key (`read_batch`).
    pub items: BTreeMap<String, V>,
    /// `nextStart` value copied from the server response, if any.
    /// Presumably the `start` to use for the following page — confirm against the API.
    pub next_start: Option<String>,
}
561
/// Range selection sent with index and batch reads.
///
/// For `read_index`, each set field is sent as a query parameter of the same name.
/// The exact range semantics (inclusive or exclusive bounds) are defined by the server.
#[derive(Debug, Default, Clone, Deserialize, Serialize)]
pub struct Filter<'a> {
    /// Start of the range, if any.
    pub start: Option<&'a str>,
    /// End of the range, if any.
    pub end: Option<&'a str>,
    /// Key prefix to match, if any.
    pub prefix: Option<&'a str>,
    /// Maximum number of results, if any.
    pub limit: Option<u64>,
    /// Reverse flag; defaults to `false` when absent from deserialized input.
    #[serde(default)]
    pub reverse: bool,
}
572
/// Range selection for `poll_range`; flattened into the JSON request body.
#[derive(Debug, Default, Clone, Serialize)]
pub struct PollRangeFilter<'a> {
    /// Start of the range, if any.
    pub start: Option<&'a str>,
    /// End of the range, if any.
    pub end: Option<&'a str>,
    /// Key prefix to match, if any.
    pub prefix: Option<&'a str>,
}
579
/// JSON body of a `poll_range` request (field names are camelCase on the wire).
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct PollRangeRequest<'a> {
    /// Filter fields, inlined at the top level of the JSON object.
    #[serde(flatten)]
    filter: PollRangeFilter<'a>,
    /// Marker from a previous `poll_range` response, if the caller has one.
    seen_marker: Option<&'a str>,
    /// Poll timeout in whole seconds.
    timeout: u64,
}
588
/// JSON body of a `poll_range` response (field names are camelCase on the wire).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct PollRangeResponse {
    /// Items returned by the poll.
    items: Vec<BatchReadItem>,
    /// Marker returned to the caller alongside the items.
    seen_marker: String,
}
595
596impl<'a> Filter<'a> {
597 fn query_params(&self) -> Vec<(&'static str, std::borrow::Cow<str>)> {
598 let mut res = Vec::<(&'static str, std::borrow::Cow<str>)>::with_capacity(8);
599 if let Some(start) = self.start.as_deref() {
600 res.push(("start", start.into()));
601 }
602 if let Some(end) = self.end.as_deref() {
603 res.push(("end", end.into()));
604 }
605 if let Some(prefix) = self.prefix.as_deref() {
606 res.push(("prefix", prefix.into()));
607 }
608 if let Some(limit) = &self.limit {
609 res.push(("limit", limit.to_string().into()));
610 }
611 if self.reverse {
612 res.push(("reverse", "true".into()));
613 }
614 res
615 }
616}
617
/// JSON body of a `read_index` response (field names are camelCase on the wire).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ReadIndexResponse<'a> {
    /// Filter fields found at the top level of the response; parsed but not read.
    #[serde(flatten, borrow)]
    #[allow(dead_code)]
    filter: Filter<'a>,
    /// One entry per partition key.
    partition_keys: Vec<ReadIndexItem>,
    /// Parsed but not read by this client; callers get `next_start` instead.
    #[allow(dead_code)]
    more: bool,
    /// Copied into `PaginatedRange::next_start`.
    next_start: Option<String>,
}
629
/// One entry of a `read_index` response: a partition key plus its counters.
#[derive(Debug, Clone, Deserialize)]
struct ReadIndexItem {
    /// Partition key.
    pk: String,
    /// Counters, read from the same JSON object as `pk`.
    #[serde(flatten)]
    info: PartitionInfo,
}
636
/// Per-partition counters as reported by the server in `read_index`.
///
/// NOTE(review): the field meanings below are inferred from their names — confirm
/// against the K2V API specification.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PartitionInfo {
    /// Presumably the number of entries in the partition.
    pub entries: u64,
    /// Presumably the number of entries with conflicting values.
    pub conflicts: u64,
    /// Presumably the total number of values.
    pub values: u64,
    /// Presumably the total size in bytes.
    pub bytes: u64,
}
645
/// One operation of an `insert_batch` request; serialized with the short keys
/// `pk`, `sk`, `ct` and `v`.
#[derive(Debug, Clone, Serialize)]
pub struct BatchInsertOp<'a> {
    /// Partition key of the item.
    #[serde(rename = "pk")]
    pub partition_key: &'a str,
    /// Sort key of the item.
    #[serde(rename = "sk")]
    pub sort_key: &'a str,
    /// Causality token to send with the write, if any.
    #[serde(rename = "ct")]
    pub causality: Option<CausalityToken>,
    /// Value to store; a tombstone serializes as `null`.
    #[serde(rename = "v")]
    pub value: K2vValue,
}
658
/// One operation of a `read_batch` request (field names are camelCase on the wire).
///
/// NOTE(review): the effect of the boolean flags is decided by the server; the
/// descriptions below are inferred from the field names — confirm against the API.
#[derive(Debug, Default, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct BatchReadOp<'a> {
    /// Partition to read from.
    pub partition_key: &'a str,
    /// Range selection, inlined at the top level of the JSON object.
    #[serde(flatten, borrow)]
    pub filter: Filter<'a>,
    /// Presumably selects a single item instead of a range.
    #[serde(default)]
    pub single_item: bool,
    /// Presumably restricts results to items with conflicting values.
    #[serde(default)]
    pub conflicts_only: bool,
    /// Presumably includes tombstones in the results.
    #[serde(default)]
    pub tombstones: bool,
}
673
/// One entry of a `read_batch` response (field names are camelCase on the wire).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct BatchReadResponse<'a> {
    /// Operation fields found at the top level of the entry; parsed but not read.
    #[serde(flatten, borrow)]
    #[allow(dead_code)]
    op: BatchReadOp<'a>,
    /// Items matched by the operation.
    items: Vec<BatchReadItem>,
    /// Parsed but not read by this client; callers get `next_start` instead.
    #[allow(dead_code)]
    more: bool,
    /// Copied into `PaginatedRange::next_start`.
    next_start: Option<String>,
}
685
/// One item of a batch-read or poll-range response.
#[derive(Debug, Clone, Deserialize)]
struct BatchReadItem {
    /// Sort key of the item.
    sk: String,
    /// Causality token of the item.
    ct: CausalityToken,
    /// Values of the item.
    v: Vec<K2vValue>,
}
692
/// One operation of a `delete_batch` request (field names are camelCase on the wire).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct BatchDeleteOp<'a> {
    /// Partition to delete from.
    pub partition_key: &'a str,
    /// Key prefix to match, if any.
    pub prefix: Option<&'a str>,
    /// Start of the range, if any.
    pub start: Option<&'a str>,
    /// End of the range, if any.
    pub end: Option<&'a str>,
    /// Presumably selects a single item instead of a range — confirm against the API.
    #[serde(default)]
    pub single_item: bool,
}
704
705impl<'a> BatchDeleteOp<'a> {
706 pub fn new(partition_key: &'a str) -> Self {
707 BatchDeleteOp {
708 partition_key,
709 prefix: None,
710 start: None,
711 end: None,
712 single_item: false,
713 }
714 }
715}
716
/// One entry of a `delete_batch` response (field names are camelCase on the wire).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct BatchDeleteResponse<'a> {
    /// Operation fields found at the top level of the entry; parsed but not read.
    #[serde(flatten, borrow)]
    #[allow(dead_code)]
    filter: BatchDeleteOp<'a>,
    /// Count returned to the caller of `delete_batch`.
    deleted_items: u64,
}
725
/// JSON error body returned by the server on 4xx/5xx responses.
#[derive(Deserialize)]
struct ErrorResponse {
    /// Error code, forwarded in `Error::Remote`.
    code: String,
    /// Error message, forwarded in `Error::Remote`.
    message: String,
    /// Required in the JSON body but not read by this client.
    #[allow(dead_code)]
    region: String,
    /// Request path, forwarded in `Error::Remote`.
    path: String,
}
734
/// Result of `dispatch`: the parts of an HTTP response the client methods use.
struct Response {
    /// Response body; empty for `204` and `304` responses.
    body: Bytes,
    /// HTTP status code.
    status: StatusCode,
    /// Value of the causality-token header, when present.
    causality_token: Option<CausalityToken>,
    /// Value of the `Content-Type` header, when present.
    content_type: Option<String>,
}