1use std::collections::HashMap;
4
5use reqwest::Method;
6use serde::{Deserialize, Serialize};
7
8use crate::client::{HeyoClient, HeyoClientOptions, RequestOptions};
9use crate::commands::encode_path;
10use crate::errors::HeyoError;
11
12#[derive(Debug, Clone, Deserialize)]
13pub struct DatabaseInfo {
14 pub id: String,
15 pub name: String,
16 pub user_id: String,
17 #[serde(default)]
18 pub account_id: Option<String>,
19 #[serde(default)]
20 pub backend_server_id: Option<String>,
21 #[serde(default)]
22 pub backend_database_id: Option<String>,
23 #[serde(default)]
24 pub region: Option<String>,
25 pub status: String,
26 #[serde(default = "default_engine")]
30 pub engine: String,
31 #[serde(default)]
32 pub size_class: Option<String>,
33 #[serde(default)]
34 pub s3_key: Option<String>,
35 #[serde(default)]
36 pub wal_s3_prefix: Option<String>,
37 #[serde(default)]
38 pub error_message: Option<String>,
39 pub created_at: String,
40 pub updated_at: String,
41 pub status_changed_at: String,
42}
43
/// Serde fallback for [`DatabaseInfo::engine`] when the response has no `engine` key.
fn default_engine() -> String {
    String::from("sqlite")
}
47
48#[derive(Debug, Clone, Default, Serialize)]
49pub struct DatabaseCreateOptions {
50 pub name: String,
51 pub region: String,
52 #[serde(skip_serializing_if = "Option::is_none", rename = "size_class")]
53 pub size_class: Option<String>,
54 #[serde(skip_serializing_if = "Option::is_none", rename = "env_vars")]
55 pub env_vars: Option<HashMap<String, String>>,
56 #[serde(skip_serializing_if = "Option::is_none")]
58 pub engine: Option<String>,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63#[serde(untagged)]
64pub enum SqlValue {
65 Null,
66 Bool(bool),
67 Int(i64),
68 Float(f64),
69 Text(String),
70}
71
72impl SqlValue {
73 pub fn as_text(&self) -> Option<&str> {
74 if let SqlValue::Text(s) = self {
75 Some(s)
76 } else {
77 None
78 }
79 }
80 pub fn as_i64(&self) -> Option<i64> {
81 if let SqlValue::Int(n) = self {
82 Some(*n)
83 } else {
84 None
85 }
86 }
87}
88
89impl From<&str> for SqlValue {
90 fn from(s: &str) -> Self {
91 SqlValue::Text(s.to_string())
92 }
93}
94impl From<String> for SqlValue {
95 fn from(s: String) -> Self {
96 SqlValue::Text(s)
97 }
98}
99impl From<i64> for SqlValue {
100 fn from(n: i64) -> Self {
101 SqlValue::Int(n)
102 }
103}
104impl From<i32> for SqlValue {
105 fn from(n: i32) -> Self {
106 SqlValue::Int(n as i64)
107 }
108}
109impl From<bool> for SqlValue {
110 fn from(b: bool) -> Self {
111 SqlValue::Bool(b)
112 }
113}
114impl From<f64> for SqlValue {
115 fn from(f: f64) -> Self {
116 SqlValue::Float(f)
117 }
118}
119
120#[derive(Debug, Clone, Serialize)]
121pub struct SqlStatement {
122 pub sql: String,
123 #[serde(default)]
124 pub args: Vec<SqlValue>,
125}
126
127impl SqlStatement {
128 pub fn new(sql: impl Into<String>) -> Self {
129 Self {
130 sql: sql.into(),
131 args: Vec::new(),
132 }
133 }
134 pub fn with_args(sql: impl Into<String>, args: Vec<SqlValue>) -> Self {
135 Self {
136 sql: sql.into(),
137 args,
138 }
139 }
140}
141
142#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
143#[serde(rename_all = "lowercase")]
144pub enum SqlTransactionMode {
145 Deferred,
146 Immediate,
147 Exclusive,
148}
149
150#[derive(Debug, Clone, Default)]
151pub struct ExecOptions {
152 pub transaction: Option<SqlTransactionMode>,
154 pub max_rows: Option<u32>,
156}
157
158#[derive(Debug, Clone)]
159pub struct ExecResult {
160 pub columns: Vec<String>,
161 pub rows: Vec<Vec<SqlValue>>,
162 pub rows_affected: u64,
163 pub last_insert_row_id: Option<i64>,
164 pub truncated: bool,
165}
166
167#[derive(Debug, Clone)]
168pub struct BatchResult {
169 pub results: Vec<ExecResult>,
170 pub elapsed_ms: u64,
171}
172
173#[derive(Deserialize, Default)]
174struct RawStatementResult {
175 #[serde(default)]
176 columns: Vec<String>,
177 #[serde(default)]
178 rows: Vec<Vec<SqlValue>>,
179 #[serde(default)]
180 rows_affected: Option<u64>,
181 #[serde(default)]
182 last_insert_rowid: Option<i64>,
183 #[serde(default)]
184 truncated: Option<bool>,
185}
186
187#[derive(Deserialize, Default)]
188struct RawExecResponse {
189 #[serde(default)]
190 results: Vec<RawStatementResult>,
191 #[serde(default)]
192 elapsed_ms: u64,
193}
194
195impl From<RawStatementResult> for ExecResult {
196 fn from(r: RawStatementResult) -> Self {
197 ExecResult {
198 columns: r.columns,
199 rows: r.rows,
200 rows_affected: r.rows_affected.unwrap_or(0),
201 last_insert_row_id: r.last_insert_rowid,
202 truncated: r.truncated.unwrap_or(false),
203 }
204 }
205}
206
207#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
208#[serde(rename_all = "lowercase")]
209pub enum ConnectionScope {
210 Read,
211 Write,
212}
213
214#[derive(Debug, Clone, Default)]
215pub struct ConnectionTokenOptions {
216 pub ttl_seconds: Option<u64>,
217 pub scopes: Option<Vec<ConnectionScope>>,
218}
219
220#[derive(Debug, Clone)]
221pub struct ConnectionToken {
222 pub id: String,
223 pub database_id: String,
224 pub url: String,
226 pub auth_token: String,
228 pub scopes: Vec<ConnectionScope>,
229 pub expires_at: String,
230}
231
232#[derive(Deserialize)]
233struct RawConnectionToken {
234 id: String,
235 database_id: String,
236 url: String,
237 auth_token: String,
238 #[serde(default)]
239 scopes: Vec<ConnectionScope>,
240 expires_at: String,
241}
242
243impl From<RawConnectionToken> for ConnectionToken {
244 fn from(r: RawConnectionToken) -> Self {
245 ConnectionToken {
246 id: r.id,
247 database_id: r.database_id,
248 url: r.url,
249 auth_token: r.auth_token,
250 scopes: r.scopes,
251 expires_at: r.expires_at,
252 }
253 }
254}
255
256#[derive(Debug, Clone)]
257pub struct ConnectionTokenInfo {
258 pub id: String,
259 pub database_id: String,
260 pub scopes: Vec<ConnectionScope>,
261 pub revoked: bool,
262 pub expires_at: String,
263 pub created_at: String,
264 pub last_used_at: Option<String>,
265}
266
267#[derive(Deserialize)]
268struct RawConnectionTokenInfo {
269 id: String,
270 database_id: String,
271 #[serde(default)]
272 scopes: Vec<ConnectionScope>,
273 #[serde(default)]
274 revoked: bool,
275 expires_at: String,
276 created_at: String,
277 #[serde(default)]
278 last_used_at: Option<String>,
279}
280
281impl From<RawConnectionTokenInfo> for ConnectionTokenInfo {
282 fn from(r: RawConnectionTokenInfo) -> Self {
283 ConnectionTokenInfo {
284 id: r.id,
285 database_id: r.database_id,
286 scopes: r.scopes,
287 revoked: r.revoked,
288 expires_at: r.expires_at,
289 created_at: r.created_at,
290 last_used_at: r.last_used_at,
291 }
292 }
293}
294
/// Result of [`Database::checkout`]: the downloaded file body plus the data
/// version the server reported for it.
///
/// `Debug` is written by hand instead of derived. The derived impl would print
/// every byte of the downloaded file (potentially megabytes of numbers in a
/// single log line), so this impl prints the payload length instead.
#[derive(Clone)]
pub struct CheckoutResult {
    /// Id of the database that was checked out.
    pub database_id: String,
    /// Parsed value of the `X-Heyo-Data-Version` response header.
    /// Presumably the value to pass back as `CheckinOptions::expected_version`.
    pub data_version: i64,
    /// Response body exactly as received (not decompressed or validated here).
    pub bytes: Vec<u8>,
}

impl std::fmt::Debug for CheckoutResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CheckoutResult")
            .field("database_id", &self.database_id)
            .field("data_version", &self.data_version)
            .field("bytes", &format_args!("<{} bytes>", self.bytes.len()))
            .finish()
    }
}
302
/// Options for [`Database::checkin`].
///
/// `checkin` rejects the all-default value (no `expected_version` and
/// `force == false`), so callers must set at least one of the two.
#[derive(Clone, Debug, Default)]
pub struct CheckinOptions {
    /// Sent as the `expected_version` query parameter when set.
    pub expected_version: Option<i64>,
    /// When `true`, sent as `force=true`. Presumably this bypasses the
    /// server-side version check; confirm with the server.
    pub force: bool,
}
308
/// Server acknowledgement of a successful [`Database::checkin`].
#[derive(Clone, Debug)]
pub struct CheckinResult {
    pub database_id: String,
    /// Data version the server reports after the upload.
    pub data_version: i64,
    /// Storage key the server reports for the uploaded file.
    pub s3_key: String,
    /// Whether the server reports the upload as forced.
    pub forced: bool,
}
316
317#[derive(Deserialize)]
318struct CheckinResponse {
319 database_id: String,
320 data_version: i64,
321 s3_key: String,
322 #[serde(default)]
323 forced: bool,
324}
325
326#[derive(Deserialize)]
327struct DatabasesEnvelope {
328 #[serde(default)]
329 databases: Vec<DatabaseInfo>,
330}
331
332#[derive(Deserialize)]
333struct RegionsEnvelope {
334 #[serde(default)]
335 regions: Vec<String>,
336}
337
338#[derive(Deserialize)]
339struct ConnectionInfoRaw {
340 database_id: String,
341 url: String,
342}
343
344#[derive(Deserialize)]
345struct TokensEnvelope {
346 #[serde(default)]
347 tokens: Vec<RawConnectionTokenInfo>,
348}
349
350#[derive(Clone)]
351pub struct Database {
352 id: String,
353 client: HeyoClient,
354}
355
356impl Database {
357 fn from_raw(client: HeyoClient, info: DatabaseInfo) -> Self {
358 Self {
359 id: info.id,
360 client,
361 }
362 }
363
364 pub fn id(&self) -> &str {
365 &self.id
366 }
367
368 pub fn client(&self) -> &HeyoClient {
369 &self.client
370 }
371
372 pub async fn create(
373 options: DatabaseCreateOptions,
374 client_options: HeyoClientOptions,
375 ) -> Result<Self, HeyoError> {
376 let client = HeyoClient::new(client_options)?;
377 let raw: DatabaseInfo = client
378 .request(
379 Method::POST,
380 "/sqlite-databases",
381 Some(&options),
382 RequestOptions::default(),
383 )
384 .await?;
385 Ok(Database::from_raw(client, raw))
386 }
387
388 pub async fn list(
389 client_options: HeyoClientOptions,
390 ) -> Result<Vec<DatabaseInfo>, HeyoError> {
391 let client = HeyoClient::new(client_options)?;
392 let env: DatabasesEnvelope = client
393 .request(Method::GET, "/sqlite-databases", None::<&()>, RequestOptions::default())
394 .await?;
395 Ok(env.databases)
396 }
397
398 pub async fn get(id: &str, client_options: HeyoClientOptions) -> Result<Self, HeyoError> {
399 let client = HeyoClient::new(client_options)?;
400 let path = format!("/sqlite-databases/{}", encode_path(id));
401 let raw: DatabaseInfo = client
402 .request(Method::GET, &path, None::<&()>, RequestOptions::default())
403 .await?;
404 Ok(Database::from_raw(client, raw))
405 }
406
407 pub async fn regions(client_options: HeyoClientOptions) -> Result<Vec<String>, HeyoError> {
409 let client = HeyoClient::new(client_options)?;
410 let env: RegionsEnvelope = client
411 .request(Method::GET, "/sqlite-regions", None::<&()>, RequestOptions::default())
412 .await?;
413 Ok(env.regions)
414 }
415
416 pub async fn info(&self) -> Result<DatabaseInfo, HeyoError> {
417 let path = format!("/sqlite-databases/{}", encode_path(&self.id));
418 self.client
419 .request(Method::GET, &path, None::<&()>, RequestOptions::default())
420 .await
421 }
422
423 pub async fn delete(&self) -> Result<(), HeyoError> {
424 let path = format!("/sqlite-databases/{}", encode_path(&self.id));
425 self.client
426 .request::<serde_json::Value>(Method::DELETE, &path, None::<&()>, RequestOptions::default())
427 .await?;
428 Ok(())
429 }
430
431 pub async fn exec(
433 &self,
434 sql: &str,
435 args: Vec<SqlValue>,
436 options: ExecOptions,
437 ) -> Result<ExecResult, HeyoError> {
438 let batch = self
439 .batch(vec![SqlStatement::with_args(sql, args)], options)
440 .await?;
441 batch
442 .results
443 .into_iter()
444 .next()
445 .ok_or_else(|| HeyoError::api(0, "empty batch result"))
446 }
447
448 pub async fn batch(
450 &self,
451 statements: Vec<SqlStatement>,
452 options: ExecOptions,
453 ) -> Result<BatchResult, HeyoError> {
454 #[derive(Serialize)]
455 struct Body<'a> {
456 statements: &'a [SqlStatement],
457 #[serde(skip_serializing_if = "Option::is_none")]
458 transaction: Option<SqlTransactionMode>,
459 #[serde(skip_serializing_if = "Option::is_none", rename = "max_rows")]
460 max_rows: Option<u32>,
461 }
462 let body = Body {
463 statements: &statements,
464 transaction: options.transaction,
465 max_rows: options.max_rows,
466 };
467 let path = format!("/sqlite-databases/{}/exec", encode_path(&self.id));
468 let raw: RawExecResponse = self
469 .client
470 .request(Method::POST, &path, Some(&body), RequestOptions::default())
471 .await?;
472 Ok(BatchResult {
473 results: raw.results.into_iter().map(ExecResult::from).collect(),
474 elapsed_ms: raw.elapsed_ms,
475 })
476 }
477
478 pub async fn connect_token(
479 &self,
480 options: ConnectionTokenOptions,
481 ) -> Result<ConnectionToken, HeyoError> {
482 #[derive(Serialize)]
483 struct Body {
484 #[serde(skip_serializing_if = "Option::is_none", rename = "ttl_seconds")]
485 ttl_seconds: Option<u64>,
486 #[serde(skip_serializing_if = "Option::is_none")]
487 scopes: Option<Vec<ConnectionScope>>,
488 }
489 let body = Body {
490 ttl_seconds: options.ttl_seconds,
491 scopes: options.scopes,
492 };
493 let path = format!("/sqlite-databases/{}/connection", encode_path(&self.id));
494 let raw: RawConnectionToken = self
495 .client
496 .request(Method::POST, &path, Some(&body), RequestOptions::default())
497 .await?;
498 Ok(raw.into())
499 }
500
501 pub async fn connection_info(&self) -> Result<(String, String), HeyoError> {
503 let path = format!("/sqlite-databases/{}/connection-info", encode_path(&self.id));
504 let raw: ConnectionInfoRaw = self
505 .client
506 .request(Method::GET, &path, None::<&()>, RequestOptions::default())
507 .await?;
508 Ok((raw.database_id, raw.url))
509 }
510
511 pub async fn list_connections(&self) -> Result<Vec<ConnectionTokenInfo>, HeyoError> {
512 let path = format!("/sqlite-databases/{}/connection-tokens", encode_path(&self.id));
513 let env: TokensEnvelope = self
514 .client
515 .request(Method::GET, &path, None::<&()>, RequestOptions::default())
516 .await?;
517 Ok(env.tokens.into_iter().map(ConnectionTokenInfo::from).collect())
518 }
519
520 pub async fn revoke_connection(&self, token_id: &str) -> Result<(), HeyoError> {
521 let path = format!(
522 "/sqlite-databases/{}/connection-tokens/{}",
523 encode_path(&self.id),
524 encode_path(token_id)
525 );
526 self.client
527 .request::<serde_json::Value>(Method::DELETE, &path, None::<&()>, RequestOptions::default())
528 .await?;
529 Ok(())
530 }
531
532 pub async fn checkout(&self) -> Result<CheckoutResult, HeyoError> {
535 let path = format!("/sqlite-databases/{}/file", encode_path(&self.id));
536 let response = self
537 .client
538 .raw_request(Method::GET, &path, None::<&()>, RequestOptions::default())
539 .await?;
540 if !response.status().is_success() {
541 let status = response.status().as_u16();
542 let body = response.bytes().await.unwrap_or_default();
543 return Err(HeyoError::api(
544 status,
545 format!(
546 "checkout failed for {}: {}",
547 self.id,
548 String::from_utf8_lossy(&body)
549 ),
550 ));
551 }
552 let version = response
553 .headers()
554 .get("x-heyo-data-version")
555 .and_then(|v| v.to_str().ok())
556 .and_then(|s| s.parse::<i64>().ok())
557 .ok_or_else(|| {
558 HeyoError::api(0, "checkout response missing X-Heyo-Data-Version header")
559 })?;
560 let bytes = response
561 .bytes()
562 .await
563 .map_err(|e| HeyoError::api(0, format!("read checkout body: {}", e)))?;
564 Ok(CheckoutResult {
565 database_id: self.id.clone(),
566 data_version: version,
567 bytes: bytes.to_vec(),
568 })
569 }
570
571 pub async fn checkin(
574 &self,
575 bytes: Vec<u8>,
576 options: CheckinOptions,
577 ) -> Result<CheckinResult, HeyoError> {
578 if !options.force && options.expected_version.is_none() {
579 return Err(HeyoError::invalid(
580 "checkin() requires `expected_version` unless `force = true`",
581 ));
582 }
583 let mut req_opts = RequestOptions::default();
584 if let Some(v) = options.expected_version {
585 req_opts
586 .query
587 .push(("expected_version".to_string(), v.to_string()));
588 }
589 if options.force {
590 req_opts.query.push(("force".to_string(), "true".to_string()));
591 }
592 let path = format!("/sqlite-databases/{}/file", encode_path(&self.id));
593 let response = self
594 .client
595 .put_bytes(&path, bytes, "application/gzip", req_opts)
596 .await?;
597 let status = response.status();
598 let body = response
599 .bytes()
600 .await
601 .map_err(|e| HeyoError::api(0, format!("read checkin body: {}", e)))?;
602 if status.as_u16() == 409 {
603 let mut current = -1_i64;
604 if let Ok(v) = serde_json::from_slice::<serde_json::Value>(&body) {
605 if let Some(n) = v.get("current_version").and_then(|x| x.as_i64()) {
606 current = n;
607 }
608 }
609 return Err(HeyoError::CheckinConflict {
610 expected: options.expected_version,
611 current,
612 });
613 }
614 if !status.is_success() {
615 return Err(HeyoError::api(
616 status.as_u16(),
617 format!("checkin failed: {}", String::from_utf8_lossy(&body)),
618 ));
619 }
620 let resp: CheckinResponse = serde_json::from_slice(&body)
621 .map_err(|e| HeyoError::api(0, format!("parse checkin response: {}", e)))?;
622 Ok(CheckinResult {
623 database_id: resp.database_id,
624 data_version: resp.data_version,
625 s3_key: resp.s3_key,
626 forced: resp.forced,
627 })
628 }
629}