k2v_client/
lib.rs

1use std::collections::BTreeMap;
2use std::convert::TryInto;
3use std::time::{Duration, SystemTime};
4
5use base64::prelude::*;
6use log::{debug, error};
7use percent_encoding::{utf8_percent_encode, AsciiSet, NON_ALPHANUMERIC};
8
9use http::header::{ACCEPT, CONTENT_TYPE};
10use http::status::StatusCode;
11use http::{HeaderName, HeaderValue, Request};
12use hyper::{body::Bytes, Body};
13use hyper::{client::connect::HttpConnector, Client as HttpClient};
14use hyper_rustls::HttpsConnector;
15
16use aws_sigv4::http_request::{sign, SignableRequest, SigningParams, SigningSettings};
17
18use serde::de::Error as DeError;
19use serde::{Deserialize, Deserializer, Serialize, Serializer};
20
21mod error;
22
23pub use error::Error;
24
/// Timeout applied to a request when the caller does not provide one. It is also added on top
/// of the poll duration to obtain the client-side timeout of polling requests.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
/// Poll duration used by `poll_item` and `poll_range` when the caller does not provide one.
const DEFAULT_POLL_TIMEOUT: Duration = Duration::from_secs(300);
/// Service name handed to the request signer.
const SERVICE: &str = "k2v";
/// Header carrying the hex-encoded SHA-256 digest of the request body.
const AMZ_CONTENT_SHA256: HeaderName = HeaderName::from_static("x-amz-content-sha256");
/// Header carrying the causality token, both in requests and in responses.
const GARAGE_CAUSALITY_TOKEN: HeaderName = HeaderName::from_static("x-garage-causality-token");

/// Percent-encoding set used for query-string keys and values: every byte is escaped except
/// ASCII alphanumerics and `_`, `-`, `.`, `~`.
const STRICT_ENCODE_SET: AsciiSet = NON_ALPHANUMERIC
	.remove(b'_')
	.remove(b'-')
	.remove(b'.')
	.remove(b'~');
/// Percent-encoding set used for the partition key in the URL path: same as
/// `STRICT_ENCODE_SET`, except that `/` is also left unescaped.
const PATH_ENCODE_SET: AsciiSet = NON_ALPHANUMERIC
	.remove(b'/')
	.remove(b'_')
	.remove(b'-')
	.remove(b'.')
	.remove(b'~');
42
/// Settings needed to build a `K2vClient`.
pub struct K2vClientConfig {
	/// Base URL of the server; request URLs are built as `<endpoint>/<bucket>[/<partition key>]`.
	pub endpoint: String,
	/// Region passed to the request signer.
	pub region: String,
	/// Access key id passed to the request signer.
	pub aws_access_key_id: String,
	/// Secret key passed to the request signer.
	pub aws_secret_access_key: String,
	/// Bucket name, used as the first path component of every request URL.
	pub bucket: String,
	/// Value of the `User-Agent` header; when `None`, `k2v/<crate version>` is used.
	pub user_agent: Option<String>,
}
51
/// Client used to query a K2V server.
pub struct K2vClient {
	// Endpoint, bucket and signing credentials supplied at construction time.
	config: K2vClientConfig,
	// Pre-validated `User-Agent` header value, cloned into every request.
	user_agent: HeaderValue,
	// Underlying hyper client used to send the signed requests.
	client: HttpClient<HttpsConnector<HttpConnector>>,
}
58
59impl K2vClient {
60	/// Create a new K2V client.
61	pub fn new(config: K2vClientConfig) -> Result<Self, Error> {
62		let connector = hyper_rustls::HttpsConnectorBuilder::new()
63			.with_native_roots()
64			.https_or_http()
65			.enable_http1()
66			.enable_http2()
67			.build();
68		let client = HttpClient::builder().build(connector);
69		let user_agent: std::borrow::Cow<str> = match &config.user_agent {
70			Some(ua) => ua.into(),
71			None => format!("k2v/{}", env!("CARGO_PKG_VERSION")).into(),
72		};
73		let user_agent = HeaderValue::from_str(&user_agent)
74			.map_err(|_| Error::Message("invalid user agent".into()))?;
75		Ok(K2vClient {
76			config,
77			client,
78			user_agent,
79		})
80	}
81
82	/// Perform a ReadItem request, reading the value(s) stored for a single pk+sk.
83	pub async fn read_item(
84		&self,
85		partition_key: &str,
86		sort_key: &str,
87	) -> Result<CausalValue, Error> {
88		let url = self.build_url(Some(partition_key), &[("sort_key", sort_key)]);
89		let req = Request::get(url)
90			.header(ACCEPT, "application/octet-stream, application/json")
91			.body(Bytes::new())?;
92		let res = self.dispatch(req, None).await?;
93
94		let causality = res
95			.causality_token
96			.ok_or_else(|| Error::InvalidResponse("missing causality token".into()))?;
97
98		if res.status == StatusCode::NO_CONTENT {
99			return Ok(CausalValue {
100				causality,
101				value: vec![K2vValue::Tombstone],
102			});
103		}
104
105		match res.content_type.as_deref() {
106			Some("application/octet-stream") => Ok(CausalValue {
107				causality,
108				value: vec![K2vValue::Value(res.body.to_vec())],
109			}),
110			Some("application/json") => {
111				let value = serde_json::from_slice(&res.body)?;
112				Ok(CausalValue { causality, value })
113			}
114			Some(ct) => Err(Error::InvalidResponse(
115				format!("invalid content type: {}", ct).into(),
116			)),
117			None => Err(Error::InvalidResponse("missing content type".into())),
118		}
119	}
120
121	/// Perform a PollItem request, waiting for the value(s) stored for a single pk+sk to be
122	/// updated.
123	pub async fn poll_item(
124		&self,
125		partition_key: &str,
126		sort_key: &str,
127		causality: CausalityToken,
128		timeout: Option<Duration>,
129	) -> Result<Option<CausalValue>, Error> {
130		let timeout = timeout.unwrap_or(DEFAULT_POLL_TIMEOUT);
131
132		let url = self.build_url(
133			Some(partition_key),
134			&[
135				("sort_key", sort_key),
136				("causality_token", &causality.0),
137				("timeout", &timeout.as_secs().to_string()),
138			],
139		);
140		let req = Request::get(url)
141			.header(ACCEPT, "application/octet-stream, application/json")
142			.body(Bytes::new())?;
143
144		let res = self.dispatch(req, Some(timeout + DEFAULT_TIMEOUT)).await?;
145
146		if res.status == StatusCode::NOT_MODIFIED {
147			return Ok(None);
148		}
149
150		let causality = res
151			.causality_token
152			.ok_or_else(|| Error::InvalidResponse("missing causality token".into()))?;
153
154		if res.status == StatusCode::NO_CONTENT {
155			return Ok(Some(CausalValue {
156				causality,
157				value: vec![K2vValue::Tombstone],
158			}));
159		}
160
161		match res.content_type.as_deref() {
162			Some("application/octet-stream") => Ok(Some(CausalValue {
163				causality,
164				value: vec![K2vValue::Value(res.body.to_vec())],
165			})),
166			Some("application/json") => {
167				let value = serde_json::from_slice(&res.body)?;
168				Ok(Some(CausalValue { causality, value }))
169			}
170			Some(ct) => Err(Error::InvalidResponse(
171				format!("invalid content type: {}", ct).into(),
172			)),
173			None => Err(Error::InvalidResponse("missing content type".into())),
174		}
175	}
176
177	/// Perform a PollRange request, waiting for any change in a given range of keys
178	/// to occur
179	pub async fn poll_range(
180		&self,
181		partition_key: &str,
182		filter: Option<PollRangeFilter<'_>>,
183		seen_marker: Option<&str>,
184		timeout: Option<Duration>,
185	) -> Result<Option<(BTreeMap<String, CausalValue>, String)>, Error> {
186		let timeout = timeout.unwrap_or(DEFAULT_POLL_TIMEOUT);
187
188		let request = PollRangeRequest {
189			filter: filter.unwrap_or_default(),
190			seen_marker,
191			timeout: timeout.as_secs(),
192		};
193
194		let url = self.build_url(Some(partition_key), &[("poll_range", "")]);
195		let payload = serde_json::to_vec(&request)?;
196		let req = Request::post(url).body(Bytes::from(payload))?;
197
198		let res = self.dispatch(req, Some(timeout + DEFAULT_TIMEOUT)).await?;
199
200		if res.status == StatusCode::NOT_MODIFIED {
201			return Ok(None);
202		}
203
204		let resp: PollRangeResponse = serde_json::from_slice(&res.body)?;
205
206		let items = resp
207			.items
208			.into_iter()
209			.map(|BatchReadItem { sk, ct, v }| {
210				(
211					sk,
212					CausalValue {
213						causality: ct,
214						value: v,
215					},
216				)
217			})
218			.collect::<BTreeMap<_, _>>();
219
220		Ok(Some((items, resp.seen_marker)))
221	}
222
223	/// Perform an InsertItem request, inserting a value for a single pk+sk.
224	pub async fn insert_item(
225		&self,
226		partition_key: &str,
227		sort_key: &str,
228		value: Vec<u8>,
229		causality: Option<CausalityToken>,
230	) -> Result<(), Error> {
231		let url = self.build_url(Some(partition_key), &[("sort_key", sort_key)]);
232		let mut req = Request::put(url);
233		if let Some(causality) = causality {
234			req = req.header(GARAGE_CAUSALITY_TOKEN, &causality.0);
235		}
236		let req = req.body(Bytes::from(value))?;
237
238		self.dispatch(req, None).await?;
239		Ok(())
240	}
241
242	/// Perform a DeleteItem request, deleting the value(s) stored for a single pk+sk.
243	pub async fn delete_item(
244		&self,
245		partition_key: &str,
246		sort_key: &str,
247		causality: CausalityToken,
248	) -> Result<(), Error> {
249		let url = self.build_url(Some(partition_key), &[("sort_key", sort_key)]);
250		let req = Request::delete(url)
251			.header(GARAGE_CAUSALITY_TOKEN, &causality.0)
252			.body(Bytes::new())?;
253
254		self.dispatch(req, None).await?;
255		Ok(())
256	}
257
258	/// Perform a ReadIndex request, listing partition key which have at least one associated
259	/// sort key, and which matches the filter.
260	pub async fn read_index(
261		&self,
262		filter: Filter<'_>,
263	) -> Result<PaginatedRange<PartitionInfo>, Error> {
264		let params = filter.query_params();
265		let url = self.build_url(None, &params);
266		let req = Request::get(url).body(Bytes::new())?;
267
268		let res = self.dispatch(req, None).await?;
269
270		let resp: ReadIndexResponse = serde_json::from_slice(&res.body)?;
271
272		let items = resp
273			.partition_keys
274			.into_iter()
275			.map(|ReadIndexItem { pk, info }| (pk, info))
276			.collect();
277
278		Ok(PaginatedRange {
279			items,
280			next_start: resp.next_start,
281		})
282	}
283
284	/// Perform an InsertBatch request, inserting multiple values at once. Note: this operation is
285	/// *not* atomic: it is possible for some sub-operations to fails and others to success. In
286	/// that case, failure is reported.
287	pub async fn insert_batch(&self, operations: &[BatchInsertOp<'_>]) -> Result<(), Error> {
288		let url = self.build_url::<&str>(None, &[]);
289		let payload = serde_json::to_vec(operations)?;
290		let req = Request::post(url).body(payload.into())?;
291
292		self.dispatch(req, None).await?;
293		Ok(())
294	}
295
296	/// Perform a ReadBatch request, reading multiple values or range of values at once.
297	pub async fn read_batch(
298		&self,
299		operations: &[BatchReadOp<'_>],
300	) -> Result<Vec<PaginatedRange<CausalValue>>, Error> {
301		let url = self.build_url(None, &[("search", "")]);
302		let payload = serde_json::to_vec(operations)?;
303		let req = Request::post(url).body(payload.into())?;
304
305		let res = self.dispatch(req, None).await?;
306
307		let resp: Vec<BatchReadResponse> = serde_json::from_slice(&res.body)?;
308
309		Ok(resp
310			.into_iter()
311			.map(|e| PaginatedRange {
312				items: e
313					.items
314					.into_iter()
315					.map(|BatchReadItem { sk, ct, v }| {
316						(
317							sk,
318							CausalValue {
319								causality: ct,
320								value: v,
321							},
322						)
323					})
324					.collect(),
325				next_start: e.next_start,
326			})
327			.collect())
328	}
329
330	/// Perform a DeleteBatch request, deleting mutiple values or range of values at once, without
331	/// providing causality information.
332	pub async fn delete_batch(&self, operations: &[BatchDeleteOp<'_>]) -> Result<Vec<u64>, Error> {
333		let url = self.build_url(None, &[("delete", "")]);
334		let payload = serde_json::to_vec(operations)?;
335		let req = Request::post(url).body(payload.into())?;
336
337		let res = self.dispatch(req, None).await?;
338
339		let resp: Vec<BatchDeleteResponse> = serde_json::from_slice(&res.body)?;
340
341		Ok(resp.into_iter().map(|r| r.deleted_items).collect())
342	}
343
344	async fn dispatch(
345		&self,
346		mut req: Request<Bytes>,
347		timeout: Option<Duration>,
348	) -> Result<Response, Error> {
349		req.headers_mut()
350			.insert(http::header::USER_AGENT, self.user_agent.clone());
351
352		use sha2::{Digest, Sha256};
353		let mut hasher = Sha256::new();
354		hasher.update(req.body());
355		let hash = hex::encode(&hasher.finalize());
356		req.headers_mut()
357			.insert(AMZ_CONTENT_SHA256, hash.try_into().unwrap());
358
359		debug!("request uri: {:?}", req.uri());
360
361		// Sign request
362		let signing_settings = SigningSettings::default();
363		let signing_params = SigningParams::builder()
364			.access_key(&self.config.aws_access_key_id)
365			.secret_key(&self.config.aws_secret_access_key)
366			.region(&self.config.region)
367			.service_name(SERVICE)
368			.time(SystemTime::now())
369			.settings(signing_settings)
370			.build()?;
371		// Convert the HTTP request into a signable request
372		let signable_request = SignableRequest::from(&req);
373
374		// Sign and then apply the signature to the request
375		let (signing_instructions, _signature) =
376			sign(signable_request, &signing_params)?.into_parts();
377		signing_instructions.apply_to_request(&mut req);
378
379		// Send and wait for timeout
380		let res = tokio::select! {
381			res = self.client.request(req.map(Body::from)) => res?,
382			_ = tokio::time::sleep(timeout.unwrap_or(DEFAULT_TIMEOUT)) => {
383				return Err(Error::Timeout);
384			}
385		};
386
387		let (mut res, body) = res.into_parts();
388		let causality_token = match res.headers.remove(GARAGE_CAUSALITY_TOKEN) {
389			Some(v) => Some(CausalityToken(v.to_str()?.to_string())),
390			None => None,
391		};
392		let content_type = match res.headers.remove(CONTENT_TYPE) {
393			Some(v) => Some(v.to_str()?.to_string()),
394			None => None,
395		};
396
397		let body = match res.status {
398			StatusCode::OK => hyper::body::to_bytes(body).await?,
399			StatusCode::NO_CONTENT => Bytes::new(),
400			StatusCode::NOT_FOUND => return Err(Error::NotFound),
401			StatusCode::NOT_MODIFIED => Bytes::new(),
402			s => {
403				let err_body = hyper::body::to_bytes(body).await.unwrap_or_default();
404				let err_body_str = std::str::from_utf8(&err_body)
405					.map(String::from)
406					.unwrap_or_else(|_| BASE64_STANDARD.encode(&err_body));
407
408				if s.is_client_error() || s.is_server_error() {
409					error!("Error response {}: {}", res.status, err_body_str);
410					let err = match serde_json::from_slice::<ErrorResponse>(&err_body) {
411						Ok(err) => Error::Remote(
412							res.status,
413							err.code.into(),
414							err.message.into(),
415							err.path.into(),
416						),
417						Err(_) => Error::Remote(
418							res.status,
419							"unknown".into(),
420							err_body_str.into(),
421							"?".into(),
422						),
423					};
424					return Err(err);
425				} else {
426					let msg = format!(
427						"Unexpected response code {}. Response body: {}",
428						res.status, err_body_str
429					);
430					error!("{}", msg);
431					return Err(Error::InvalidResponse(msg.into()));
432				}
433			}
434		};
435		debug!(
436			"Response body: {}",
437			std::str::from_utf8(&body)
438				.map(String::from)
439				.unwrap_or_else(|_| BASE64_STANDARD.encode(&body))
440		);
441
442		Ok(Response {
443			body,
444			status: res.status,
445			causality_token,
446			content_type,
447		})
448	}
449
450	fn build_url<V: AsRef<str>>(&self, partition_key: Option<&str>, query: &[(&str, V)]) -> String {
451		let mut url = format!("{}/{}", self.config.endpoint, self.config.bucket);
452		if let Some(pk) = partition_key {
453			url.push('/');
454			url.extend(utf8_percent_encode(pk, &PATH_ENCODE_SET));
455		}
456		if !query.is_empty() {
457			url.push('?');
458			for (i, (k, v)) in query.iter().enumerate() {
459				if i > 0 {
460					url.push('&');
461				}
462				url.extend(utf8_percent_encode(k, &STRICT_ENCODE_SET));
463				url.push('=');
464				url.extend(utf8_percent_encode(v.as_ref(), &STRICT_ENCODE_SET));
465			}
466		}
467		url
468	}
469}
470
471/// An opaque token used to convey causality between operations.
472#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
473#[serde(transparent)]
474pub struct CausalityToken(String);
475
476impl From<String> for CausalityToken {
477	fn from(v: String) -> Self {
478		CausalityToken(v)
479	}
480}
481
482impl From<CausalityToken> for String {
483	fn from(v: CausalityToken) -> Self {
484		v.0
485	}
486}
487
488impl AsRef<str> for CausalityToken {
489	fn as_ref(&self) -> &str {
490		&self.0
491	}
492}
493
/// A value in K2V: either a binary value, or a tombstone.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum K2vValue {
	/// No binary payload.
	Tombstone,
	/// A binary payload.
	Value(Vec<u8>),
}

/// Any byte vector, even an empty one, is a `Value`.
impl From<Vec<u8>> for K2vValue {
	fn from(bytes: Vec<u8>) -> Self {
		Self::Value(bytes)
	}
}

/// `Some(bytes)` is a `Value`, `None` is a `Tombstone`.
impl From<Option<Vec<u8>>> for K2vValue {
	fn from(maybe_bytes: Option<Vec<u8>>) -> Self {
		maybe_bytes.map_or(Self::Tombstone, Self::Value)
	}
}
515
516impl<'de> Deserialize<'de> for K2vValue {
517	fn deserialize<D>(d: D) -> Result<Self, D::Error>
518	where
519		D: Deserializer<'de>,
520	{
521		let val: Option<&str> = Option::deserialize(d)?;
522		Ok(match val {
523			Some(s) => K2vValue::Value(
524				BASE64_STANDARD
525					.decode(s)
526					.map_err(|_| DeError::custom("invalid base64"))?,
527			),
528			None => K2vValue::Tombstone,
529		})
530	}
531}
532
533impl Serialize for K2vValue {
534	fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
535	where
536		S: Serializer,
537	{
538		match self {
539			K2vValue::Tombstone => serializer.serialize_none(),
540			K2vValue::Value(v) => {
541				let b64 = BASE64_STANDARD.encode(v);
542				serializer.serialize_str(&b64)
543			}
544		}
545	}
546}
547
/// A set of K2vValue and associated causality information.
#[derive(Debug, Clone, Serialize)]
pub struct CausalValue {
	/// Causality token returned by the server together with the values.
	pub causality: CausalityToken,
	/// Values returned for the item; each one is either a binary value or a tombstone.
	pub value: Vec<K2vValue>,
}
554
/// Result of paginated requests.
#[derive(Debug, Clone)]
pub struct PaginatedRange<V> {
	/// Returned entries: keyed by partition key for `read_index`, by sort key for `read_batch`.
	pub items: BTreeMap<String, V>,
	/// `nextStart` value of the response, if any.
	/// NOTE(review): presumably the `start` to use to fetch the next page — confirm against
	/// the K2V specification.
	pub next_start: Option<String>,
}
561
/// Filter for batch operations.
///
/// For `read_index`, every field that is set is sent as the query parameter of the same name
/// (`reverse` only when `true`). In `BatchReadOp` the fields are flattened into the operation.
/// NOTE(review): the exact matching rules (bound inclusivity, ordering) are defined by the
/// server and are not visible in this file.
#[derive(Debug, Default, Clone, Deserialize, Serialize)]
pub struct Filter<'a> {
	/// Optional start of the key range.
	pub start: Option<&'a str>,
	/// Optional end of the key range.
	pub end: Option<&'a str>,
	/// Optional key prefix.
	pub prefix: Option<&'a str>,
	/// Optional limit on the number of results.
	pub limit: Option<u64>,
	/// Reverse flag; defaults to `false` when absent from deserialized input.
	#[serde(default)]
	pub reverse: bool,
}
572
/// Filter selecting the range of keys watched by `poll_range`.
///
/// NOTE(review): the exact matching rules (bound inclusivity) are defined by the server and
/// are not visible in this file.
#[derive(Debug, Default, Clone, Serialize)]
pub struct PollRangeFilter<'a> {
	/// Optional start of the key range.
	pub start: Option<&'a str>,
	/// Optional end of the key range.
	pub end: Option<&'a str>,
	/// Optional key prefix.
	pub prefix: Option<&'a str>,
}
579
/// JSON body of a PollRange request. Field names are serialized in camelCase.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct PollRangeRequest<'a> {
	// The filter fields are serialized at the top level of the object, not nested.
	#[serde(flatten)]
	filter: PollRangeFilter<'a>,
	// Marker taken from a previous PollRange response, if the caller has one.
	seen_marker: Option<&'a str>,
	// Poll duration, in whole seconds.
	timeout: u64,
}
588
/// JSON body of a PollRange response. Field names are deserialized from camelCase.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct PollRangeResponse {
	// Items of the response, each with its sort key, causality token and values.
	items: Vec<BatchReadItem>,
	// Marker handed back to the caller of `poll_range`.
	seen_marker: String,
}
595
596impl<'a> Filter<'a> {
597	fn query_params(&self) -> Vec<(&'static str, std::borrow::Cow<str>)> {
598		let mut res = Vec::<(&'static str, std::borrow::Cow<str>)>::with_capacity(8);
599		if let Some(start) = self.start.as_deref() {
600			res.push(("start", start.into()));
601		}
602		if let Some(end) = self.end.as_deref() {
603			res.push(("end", end.into()));
604		}
605		if let Some(prefix) = self.prefix.as_deref() {
606			res.push(("prefix", prefix.into()));
607		}
608		if let Some(limit) = &self.limit {
609			res.push(("limit", limit.to_string().into()));
610		}
611		if self.reverse {
612			res.push(("reverse", "true".into()));
613		}
614		res
615	}
616}
617
618#[derive(Debug, Clone, Deserialize)]
619#[serde(rename_all = "camelCase")]
620struct ReadIndexResponse<'a> {
621	#[serde(flatten, borrow)]
622	#[allow(dead_code)]
623	filter: Filter<'a>,
624	partition_keys: Vec<ReadIndexItem>,
625	#[allow(dead_code)]
626	more: bool,
627	next_start: Option<String>,
628}
629
/// One entry of a ReadIndex response: a partition key and its counters.
#[derive(Debug, Clone, Deserialize)]
struct ReadIndexItem {
	// Partition key of the entry.
	pk: String,
	// The counters are read from the same JSON object as `pk`, not from a nested one.
	#[serde(flatten)]
	info: PartitionInfo,
}
636
/// Information about data stored with a given partition key.
///
/// NOTE(review): the counters are reported by the server as is; their precise definitions are
/// not visible in this file and should be checked against the K2V specification.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PartitionInfo {
	/// Value of the `entries` counter.
	pub entries: u64,
	/// Value of the `conflicts` counter.
	pub conflicts: u64,
	/// Value of the `values` counter.
	pub values: u64,
	/// Value of the `bytes` counter.
	pub bytes: u64,
}
645
/// Single sub-operation of an InsertBatch.
///
/// Serialized with the short field names `pk`, `sk`, `ct` and `v`.
#[derive(Debug, Clone, Serialize)]
pub struct BatchInsertOp<'a> {
	/// Partition key of the item.
	#[serde(rename = "pk")]
	pub partition_key: &'a str,
	/// Sort key of the item.
	#[serde(rename = "sk")]
	pub sort_key: &'a str,
	/// Optional causality token for the item.
	#[serde(rename = "ct")]
	pub causality: Option<CausalityToken>,
	/// Value to store: a binary value, or a tombstone.
	#[serde(rename = "v")]
	pub value: K2vValue,
}
658
/// Single sub-operation of a ReadBatch.
///
/// Field names are (de)serialized in camelCase, and the filter fields sit at the top level of
/// the object. The three flags default to `false` when absent from deserialized input.
/// NOTE(review): the effect of each flag is defined by the server and is not visible in this
/// file — confirm against the K2V specification.
#[derive(Debug, Default, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct BatchReadOp<'a> {
	/// Partition key to read from.
	pub partition_key: &'a str,
	/// Filter applied within the partition.
	#[serde(flatten, borrow)]
	pub filter: Filter<'a>,
	/// `singleItem` flag of the operation.
	#[serde(default)]
	pub single_item: bool,
	/// `conflictsOnly` flag of the operation.
	#[serde(default)]
	pub conflicts_only: bool,
	/// `tombstones` flag of the operation.
	#[serde(default)]
	pub tombstones: bool,
}
673
674#[derive(Debug, Clone, Deserialize)]
675#[serde(rename_all = "camelCase")]
676struct BatchReadResponse<'a> {
677	#[serde(flatten, borrow)]
678	#[allow(dead_code)]
679	op: BatchReadOp<'a>,
680	items: Vec<BatchReadItem>,
681	#[allow(dead_code)]
682	more: bool,
683	next_start: Option<String>,
684}
685
/// One item of a ReadBatch or PollRange response.
#[derive(Debug, Clone, Deserialize)]
struct BatchReadItem {
	// Sort key of the item.
	sk: String,
	// Causality token of the item.
	ct: CausalityToken,
	// Values of the item; each one is a binary value or a tombstone.
	v: Vec<K2vValue>,
}
692
/// Single sub-operation of a DeleteBatch
///
/// Field names are (de)serialized in camelCase.
/// NOTE(review): range semantics and the effect of `single_item` are defined by the server
/// and are not visible in this file — confirm against the K2V specification.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct BatchDeleteOp<'a> {
	/// Partition key to delete from.
	pub partition_key: &'a str,
	/// Optional sort key prefix.
	pub prefix: Option<&'a str>,
	/// Optional start of the sort key range.
	pub start: Option<&'a str>,
	/// Optional end of the sort key range.
	pub end: Option<&'a str>,
	/// `singleItem` flag; defaults to `false` when absent from deserialized input.
	#[serde(default)]
	pub single_item: bool,
}
704
705impl<'a> BatchDeleteOp<'a> {
706	pub fn new(partition_key: &'a str) -> Self {
707		BatchDeleteOp {
708			partition_key,
709			prefix: None,
710			start: None,
711			end: None,
712			single_item: false,
713		}
714	}
715}
716
717#[derive(Debug, Clone, Deserialize)]
718#[serde(rename_all = "camelCase")]
719struct BatchDeleteResponse<'a> {
720	#[serde(flatten, borrow)]
721	#[allow(dead_code)]
722	filter: BatchDeleteOp<'a>,
723	deleted_items: u64,
724}
725
/// JSON error body that `dispatch` tries to parse on client and server error statuses.
#[derive(Deserialize)]
struct ErrorResponse {
	// Forwarded into `Error::Remote`.
	code: String,
	// Forwarded into `Error::Remote`.
	message: String,
	// Required for the body to parse, but not reported to the caller.
	#[allow(dead_code)]
	region: String,
	// Forwarded into `Error::Remote`.
	path: String,
}
734
/// Outcome of `dispatch` for a response that is not treated as an error.
struct Response {
	// Response body; empty for `204 No Content` and `304 Not Modified`.
	body: Bytes,
	// HTTP status of the response.
	status: StatusCode,
	// Value of the causality token header, when the response has one.
	causality_token: Option<CausalityToken>,
	// Value of the `Content-Type` header, when the response has one.
	content_type: Option<String>,
}