1pub mod document;
4pub mod filter;
5pub mod query;
6
7use std::sync::Arc;
8
9use anyhow::{Context, anyhow, bail};
10use base64ct::{Base64, Encoding};
11use futures::future::FutureExt;
12use hmac::{Hmac, KeyInit, Mac};
13use omnia_wasi_jsondb::{
14 Document, FilterTree, FutureResult, QueryOpts, QueryResult, WasiJsonDbCtx,
15};
16use reqwest::Client as HttpClient;
17use serde_json::Value;
18use sha2::Sha256;
19
20use crate::Client;
21
/// Value sent in the `x-ms-version` header of every request (see `azure_headers`).
const API_VERSION: &str = "2026-02-06";
/// Value sent in the `Accept` header of every request (see `azure_headers`).
const ACCEPT_HEADER: &str = "application/json;odata=fullmetadata";
24
25impl WasiJsonDbCtx for Client {
27 fn get(&self, collection: String, id: String) -> FutureResult<Option<Document>> {
28 let opts = Arc::clone(&self.options);
29 let http = self.http.clone();
30 let base = Arc::clone(&self.base_url);
31 let hmac_key = Arc::clone(&self.hmac_key);
32 async move {
33 let (table, pk) = parse_collection(&collection)?;
34 let pk = require_pk(&collection, pk.as_ref())?;
35 let uri = format!(
36 "{base}/{table}(PartitionKey='{pk}',RowKey='{rk}')",
37 pk = escape_entity_key(pk),
38 rk = escape_entity_key(&id),
39 );
40
41 let now = now_rfc1123();
42 let auth = sign_request(&opts.name, &hmac_key, &now, &uri)?;
43
44 let response = http
45 .get(&uri)
46 .headers(azure_headers(&now, &auth)?)
47 .send()
48 .await
49 .map_err(|e| anyhow!("HTTP request error: {e}"))?;
50
51 if response.status().as_u16() == 404 {
52 return Ok(None);
53 }
54 if !response.status().is_success() {
55 bail!(
56 "Azure Table get failed ({}): {}",
57 response.status(),
58 response.text().await.unwrap_or_default()
59 );
60 }
61
62 let entity: Value =
63 response.json().await.map_err(|e| anyhow!("failed to parse response JSON: {e}"))?;
64
65 Ok(Some(document::unflatten(&entity)?))
66 }
67 .boxed()
68 }
69
70 fn insert(&self, collection: String, doc: Document) -> FutureResult<()> {
71 let opts = Arc::clone(&self.options);
72 let http = self.http.clone();
73 let base = Arc::clone(&self.base_url);
74 let hmac_key = Arc::clone(&self.hmac_key);
75 async move {
76 let (table, pk) = parse_collection(&collection)?;
77 let pk = require_pk(&collection, pk.as_ref())?;
78 let uri = format!("{base}/{table}");
79 let body = document::flatten(&doc, pk)?;
80
81 let now = now_rfc1123();
82 let auth = sign_request(&opts.name, &hmac_key, &now, &uri)?;
83
84 let response = http
85 .post(&uri)
86 .headers(azure_headers(&now, &auth)?)
87 .json(&body)
88 .send()
89 .await
90 .map_err(|e| anyhow!("HTTP request error: {e}"))?;
91
92 if response.status().as_u16() == 409 {
93 bail!("entity already exists in '{collection}' with id '{}'", doc.id);
94 }
95 if !response.status().is_success() {
96 bail!(
97 "Azure Table insert failed ({}): {}",
98 response.status(),
99 response.text().await.unwrap_or_default()
100 );
101 }
102 Ok(())
103 }
104 .boxed()
105 }
106
107 fn put(&self, collection: String, doc: Document) -> FutureResult<()> {
108 let opts = Arc::clone(&self.options);
109 let http = self.http.clone();
110 let base = Arc::clone(&self.base_url);
111 let hmac_key = Arc::clone(&self.hmac_key);
112 async move {
113 let (table, pk) = parse_collection(&collection)?;
114 let pk = require_pk(&collection, pk.as_ref())?;
115 let uri = format!(
116 "{base}/{table}(PartitionKey='{epk}',RowKey='{rk}')",
117 epk = escape_entity_key(pk),
118 rk = escape_entity_key(&doc.id),
119 );
120 let body = document::flatten(&doc, pk)?;
121
122 let now = now_rfc1123();
123 let auth = sign_request(&opts.name, &hmac_key, &now, &uri)?;
124
125 let response = http
126 .put(&uri)
127 .headers(azure_headers(&now, &auth)?)
128 .json(&body)
129 .send()
130 .await
131 .map_err(|e| anyhow!("HTTP request error: {e}"))?;
132
133 if !response.status().is_success() {
134 bail!(
135 "Azure Table put failed ({}): {}",
136 response.status(),
137 response.text().await.unwrap_or_default()
138 );
139 }
140 Ok(())
141 }
142 .boxed()
143 }
144
145 fn delete(&self, collection: String, id: String) -> FutureResult<bool> {
146 let opts = Arc::clone(&self.options);
147 let http = self.http.clone();
148 let base = Arc::clone(&self.base_url);
149 let hmac_key = Arc::clone(&self.hmac_key);
150 async move {
151 let (table, pk) = parse_collection(&collection)?;
152 let pk = require_pk(&collection, pk.as_ref())?;
153 let uri = format!(
154 "{base}/{table}(PartitionKey='{epk}',RowKey='{rk}')",
155 epk = escape_entity_key(pk),
156 rk = escape_entity_key(&id),
157 );
158
159 let now = now_rfc1123();
160 let auth = sign_request(&opts.name, &hmac_key, &now, &uri)?;
161
162 let mut headers = azure_headers(&now, &auth)?;
163 headers.insert("If-Match", "*".parse().expect("valid header value"));
164
165 let response = http
166 .delete(&uri)
167 .headers(headers)
168 .send()
169 .await
170 .map_err(|e| anyhow!("HTTP request error: {e}"))?;
171
172 if response.status().as_u16() == 404 {
173 return Ok(false);
174 }
175 if !response.status().is_success() {
176 bail!(
177 "Azure Table delete failed ({}): {}",
178 response.status(),
179 response.text().await.unwrap_or_default()
180 );
181 }
182 Ok(true)
183 }
184 .boxed()
185 }
186
187 fn query(
194 &self, collection: String, filter: Option<FilterTree>, options: QueryOpts,
195 ) -> FutureResult<QueryResult> {
196 let opts = Arc::clone(&self.options);
197 let http = self.http.clone();
198 let base = Arc::clone(&self.base_url);
199 let hmac_key = Arc::clone(&self.hmac_key);
200 async move {
201 if options.offset.is_some_and(|o| o > 0) {
202 bail!(
203 "offset is not supported by Azure Table — \
204 use continuation tokens for pagination instead"
205 );
206 }
207
208 let fetch_limit = options.limit.map(|l| l as usize);
209 if fetch_limit == Some(0) {
210 return Ok(QueryResult {
211 documents: Vec::new(),
212 continuation: options.continuation,
213 });
214 }
215
216 let (table, pk) = parse_collection(&collection)?;
217 let user_filter = filter.as_ref().map(filter::to_odata).transpose()?;
218 let odata_filter = build_odata_filter(pk.as_deref(), user_filter.as_deref());
219
220 let mut all_documents: Vec<Document> = Vec::new();
221 let mut next_continuation = options.continuation.clone();
222
223 loop {
224 let remaining = fetch_limit.map(|lim| lim - all_documents.len());
225
226 let (body, continuation) = fetch_page(
227 &http,
228 &opts,
229 &base,
230 &hmac_key,
231 &table,
232 odata_filter.as_deref(),
233 remaining,
234 next_continuation.as_deref(),
235 )
236 .await?;
237
238 if let Some(entries) = body.get("value").and_then(Value::as_array) {
239 for entity in entries {
240 all_documents.push(document::unflatten(entity)?);
241 }
242 }
243
244 let has_more_pages = continuation.is_some();
245 next_continuation = continuation;
246
247 let reached_limit = fetch_limit.is_some_and(|lim| all_documents.len() >= lim);
248
249 if !has_more_pages || reached_limit {
250 break;
251 }
252 }
253
254 Ok(QueryResult {
255 documents: all_documents,
256 continuation: next_continuation,
257 })
258 }
259 .boxed()
260 }
261}
262
263impl Client {
265 pub async fn ensure_table(&self, table: &str) -> anyhow::Result<bool> {
275 let uri = format!("{}/Tables", self.base_url);
276 let now = now_rfc1123();
277 let auth = sign_request(&self.options.name, &self.hmac_key, &now, &uri)?;
278
279 let response = self
280 .http
281 .post(&uri)
282 .headers(azure_headers(&now, &auth)?)
283 .json(&serde_json::json!({"TableName": table}))
284 .send()
285 .await
286 .map_err(|e| anyhow!("create table request: {e}"))?;
287
288 match response.status().as_u16() {
289 201 | 204 => Ok(true),
290 409 => Ok(false),
291 _ => {
292 bail!(
293 "create table failed ({}): {}",
294 response.status(),
295 response.text().await.unwrap_or_default()
296 );
297 }
298 }
299 }
300}
301
302fn escape_entity_key(value: &str) -> String {
308 urlencoding::encode(&value.replace('\'', "''")).into_owned()
309}
310
311fn require_pk<'a>(collection: &str, pk: Option<&'a String>) -> anyhow::Result<&'a str> {
312 pk.map(String::as_str).ok_or_else(|| {
313 anyhow!(
314 "operation requires collection format '{{table}}/{{partitionKey}}', got '{collection}'"
315 )
316 })
317}
318
/// Combines the optional partition scope and the optional caller-supplied
/// filter into a single OData `$filter` expression.
///
/// * `pk` — partition key; single quotes are doubled before it is embedded in
///   a `PartitionKey eq '…'` clause.
/// * `server_filter` — an already-rendered OData filter expression.
///
/// Returns `None` when neither input is present. When both are present the
/// caller's filter is wrapped in parentheses: `and` binds tighter than `or`,
/// so without them a top-level `or` in the filter would match rows outside
/// the requested partition.
fn build_odata_filter(pk: Option<&str>, server_filter: Option<&str>) -> Option<String> {
    let partition_clause =
        pk.map(|key| format!("PartitionKey eq '{}'", key.replace('\'', "''")));

    match (partition_clause, server_filter) {
        (Some(partition), Some(filter)) => Some(format!("{partition} and ({filter})")),
        (Some(partition), None) => Some(partition),
        (None, Some(filter)) => Some(filter.to_owned()),
        (None, None) => None,
    }
}
329
330#[allow(clippy::similar_names, clippy::too_many_arguments)]
331async fn fetch_page(
332 http: &HttpClient, opts: &crate::ConnectOptions, base: &str, hmac_key: &[u8], table: &str,
333 odata_filter: Option<&str>, fetch_limit: Option<usize>, continuation: Option<&str>,
334) -> anyhow::Result<(Value, Option<String>)> {
335 let base_uri = format!("{base}/{table}()");
336
337 let mut query_params: Vec<String> = Vec::new();
338 if let Some(f) = odata_filter {
339 query_params.push(format!("$filter={}", urlencoding::encode(f)));
340 }
341 if let Some(limit) = fetch_limit {
342 query_params.push(format!("$top={limit}"));
343 }
344 if let Some(cont) = continuation {
345 let (next_pk, next_rk) = query::decode_continuation(cont);
346 query_params.push(format!("NextPartitionKey={}", urlencoding::encode(&next_pk)));
347 if let Some(rk) = next_rk {
348 query_params.push(format!("NextRowKey={}", urlencoding::encode(&rk)));
349 }
350 }
351
352 let uri = if query_params.is_empty() {
353 base_uri
354 } else {
355 format!("{base_uri}?{}", query_params.join("&"))
356 };
357
358 let now = now_rfc1123();
359 let auth = sign_request(&opts.name, hmac_key, &now, &uri)?;
360
361 let response = http
362 .get(&uri)
363 .headers(azure_headers(&now, &auth)?)
364 .send()
365 .await
366 .map_err(|e| anyhow!("HTTP request error: {e}"))?;
367
368 if !response.status().is_success() {
369 bail!(
370 "Azure Table query failed ({}): {}",
371 response.status(),
372 response.text().await.unwrap_or_default()
373 );
374 }
375
376 let continuation_pk = response
377 .headers()
378 .get("x-ms-continuation-NextPartitionKey")
379 .and_then(|v| v.to_str().ok())
380 .map(str::to_owned);
381 let continuation_rk = response
382 .headers()
383 .get("x-ms-continuation-NextRowKey")
384 .and_then(|v| v.to_str().ok())
385 .map(str::to_owned);
386
387 let body: Value =
388 response.json().await.map_err(|e| anyhow!("failed to parse response JSON: {e}"))?;
389
390 let token = query::encode_continuation(continuation_pk.as_deref(), continuation_rk.as_deref());
391 Ok((body, token))
392}
393
394fn validate_table_name(table: &str) -> anyhow::Result<()> {
398 let len = table.len();
399 if !(3..=63).contains(&len) {
400 bail!("invalid table name '{table}': length must be between 3 and 63 characters");
401 }
402 if !table.starts_with(|c: char| c.is_ascii_alphabetic()) {
403 bail!("invalid table name '{table}': first character must be a letter");
404 }
405 if !table.chars().all(|c| c.is_ascii_alphanumeric()) {
406 bail!("invalid table name '{table}': only ASCII letters and digits are allowed");
407 }
408 if table.eq_ignore_ascii_case("tables") {
409 bail!("invalid table name '{table}': reserved table name");
410 }
411 Ok(())
412}
413
414fn parse_collection(collection: &str) -> anyhow::Result<(String, Option<String>)> {
416 match collection.split_once('/') {
417 Some((table, pk)) if !table.is_empty() && !pk.is_empty() => {
418 validate_table_name(table)?;
419 Ok((table.to_owned(), Some(pk.to_owned())))
420 }
421 Some((table, _)) if !table.is_empty() => {
422 bail!("collection '{collection}' has an empty partition key after '/'")
423 }
424 Some(_) => bail!("collection '{collection}' has an empty table name"),
425 None if !collection.is_empty() => {
426 validate_table_name(collection)?;
427 Ok((collection.to_owned(), None))
428 }
429 _ => bail!("collection must not be empty"),
430 }
431}
432
433fn now_rfc1123() -> String {
434 chrono::Utc::now().format("%a, %d %b %Y %H:%M:%S GMT").to_string()
435}
436
437fn sign_request(
449 account_name: &str, hmac_key: &[u8], date_time: &str, uri: &str,
450) -> anyhow::Result<String> {
451 let uri_path = uri
452 .split("://")
453 .nth(1)
454 .and_then(|after_scheme| after_scheme.find('/').map(|i| &after_scheme[i..]))
455 .unwrap_or("/");
456 let uri_path = uri_path.split('?').next().unwrap_or(uri_path);
457 let resource = format!("/{account_name}{uri_path}");
458 let string_to_sign = format!("{date_time}\n{resource}");
459 let mut hmac = Hmac::<Sha256>::new_from_slice(hmac_key)
460 .map_err(|e| anyhow!("HMAC initialization error: {e}"))?;
461 hmac.update(string_to_sign.as_bytes());
462 let signature = hmac.finalize().into_bytes();
463 let encoded = Base64::encode_string(&signature);
464 Ok(format!("SharedKeyLite {account_name}:{encoded}"))
465}
466
467fn azure_headers(date: &str, auth: &str) -> anyhow::Result<reqwest::header::HeaderMap> {
469 let mut headers = reqwest::header::HeaderMap::new();
470 headers.insert("x-ms-date", date.parse().context("invalid x-ms-date header value")?);
471 headers.insert("x-ms-version", API_VERSION.parse().expect("valid header value"));
472 headers.insert("Authorization", auth.parse().context("invalid Authorization header value")?);
473 headers.insert("Accept", ACCEPT_HEADER.parse().expect("valid header value"));
474 Ok(headers)
475}
476
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed timestamp so signatures are reproducible.
    const FIXED_DATE: &str = "Mon, 01 Jan 2026 00:00:00 GMT";

    /// Produces the HMAC key bytes by round-tripping a fake secret through Base64.
    fn fake_hmac_key() -> Vec<u8> {
        let encoded = Base64::encode_string(b"fake-key-for-unit-test-1234567!");
        Base64::decode_vec(&encoded).unwrap()
    }

    #[test]
    fn parse_collection_full() {
        let parsed = parse_collection("users/tenant-a").unwrap();
        assert_eq!(parsed.0, "users");
        assert_eq!(parsed.1.as_deref(), Some("tenant-a"));
    }

    #[test]
    fn parse_collection_table_only() {
        let parsed = parse_collection("users").unwrap();
        assert_eq!(parsed.0, "users");
        assert!(parsed.1.is_none());
    }

    #[test]
    fn parse_collection_empty_errors() {
        assert!(parse_collection("").is_err());
    }

    #[test]
    fn parse_collection_empty_pk_errors() {
        assert!(parse_collection("users/").is_err());
    }

    #[test]
    fn parse_collection_rejects_short_table_name() {
        assert!(parse_collection("ab/pk").is_err());
    }

    #[test]
    fn parse_collection_rejects_special_chars_in_table() {
        assert!(parse_collection("my-table/pk").is_err());
    }

    #[test]
    fn parse_collection_rejects_reserved_table_name() {
        assert!(parse_collection("Tables/pk").is_err());
    }

    #[test]
    fn sign_request_uses_shared_key_lite_format() {
        let hmac_key = fake_hmac_key();
        let auth = sign_request(
            "myaccount",
            &hmac_key,
            FIXED_DATE,
            "https://myaccount.table.core.windows.net/Tables",
        )
        .unwrap();
        assert!(auth.starts_with("SharedKeyLite myaccount:"), "{auth}");
    }

    #[test]
    fn sign_request_azurite_preserves_account_in_path() {
        let hmac_key = fake_hmac_key();
        // Same account, key and date; only the URI shape differs.
        let sign = |uri: &str| sign_request("devstoreaccount1", &hmac_key, FIXED_DATE, uri).unwrap();

        let auth_azurite = sign("http://127.0.0.1:10002/devstoreaccount1/myTable");
        let auth_cloud = sign("https://devstoreaccount1.table.core.windows.net/myTable");

        assert_ne!(
            auth_azurite, auth_cloud,
            "Azurite and cloud should differ — Azurite canonicalized resource includes account in URI path"
        );
    }
}