use std::collections::HashMap;
use reqwest::Method;
use serde::{Deserialize, Serialize};
use crate::client::{HeyoClient, HeyoClientOptions, RequestOptions};
use crate::commands::encode_path;
use crate::errors::HeyoError;
/// Metadata record for a database, deserialized from the responses of the
/// `/sqlite-databases` endpoints (`Database::create`, `Database::get`,
/// `Database::info`, and each element returned by `Database::list`).
///
/// Fields carrying `#[serde(default)]` may be absent from the payload, in
/// which case the `Option` fields become `None`. Every other field is
/// required: deserialization fails when one is missing.
#[derive(Debug, Clone, Deserialize)]
pub struct DatabaseInfo {
/// Identifier; `Database` keeps this value and uses it to build request paths.
pub id: String,
pub name: String,
pub user_id: String,
#[serde(default)]
pub account_id: Option<String>,
#[serde(default)]
pub backend_server_id: Option<String>,
#[serde(default)]
pub backend_database_id: Option<String>,
#[serde(default)]
pub region: Option<String>,
/// Status exactly as sent by the server; the set of possible values is not
/// visible in this file.
pub status: String,
/// Falls back to `default_engine()` (`"sqlite"`) when absent from the payload.
#[serde(default = "default_engine")]
pub engine: String,
#[serde(default)]
pub size_class: Option<String>,
#[serde(default)]
pub s3_key: Option<String>,
#[serde(default)]
pub wal_s3_prefix: Option<String>,
#[serde(default)]
pub error_message: Option<String>,
/// Kept as the raw string from the response (as are `updated_at` and
/// `status_changed_at`); the timestamp format is not visible in this file.
pub created_at: String,
pub updated_at: String,
pub status_changed_at: String,
}
/// Serde fallback for [`DatabaseInfo::engine`]: yields `"sqlite"` when the
/// field is absent from a response.
fn default_engine() -> String {
    String::from("sqlite")
}
/// Request body for `Database::create`, serialized as the payload of
/// `POST /sqlite-databases`.
///
/// `name` and `region` are always written. Each `Option` field is left out of
/// the payload entirely when it is `None`.
#[derive(Debug, Clone, Default, Serialize)]
pub struct DatabaseCreateOptions {
    pub name: String,
    pub region: String,
    // The Rust field names already equal the wire keys, so no `rename` is needed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub size_class: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub env_vars: Option<HashMap<String, String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub engine: Option<String>,
}
/// A single SQL value. It is serialized as a statement argument
/// (`SqlStatement::args`) and deserialized as a result cell
/// (`ExecResult::rows`).
///
/// `#[serde(untagged)]` means values travel as bare scalars with no variant
/// tag. When deserializing, serde tries the variants in declaration order and
/// keeps the first that matches, so the order below is significant.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum SqlValue {
Null,
Bool(bool),
Int(i64),
Float(f64),
Text(String),
}
impl SqlValue {
    /// Borrows the inner string when this is [`SqlValue::Text`]; returns
    /// `None` for every other variant.
    pub fn as_text(&self) -> Option<&str> {
        match self {
            SqlValue::Text(text) => Some(text.as_str()),
            _ => None,
        }
    }

    /// Returns the inner integer when this is [`SqlValue::Int`]; returns
    /// `None` for every other variant (no conversion from `Float`, `Bool` or
    /// `Text` is attempted).
    pub fn as_i64(&self) -> Option<i64> {
        match self {
            SqlValue::Int(number) => Some(*number),
            _ => None,
        }
    }
}
impl From<&str> for SqlValue {
fn from(s: &str) -> Self {
SqlValue::Text(s.to_string())
}
}
impl From<String> for SqlValue {
fn from(s: String) -> Self {
SqlValue::Text(s)
}
}
impl From<i64> for SqlValue {
fn from(n: i64) -> Self {
SqlValue::Int(n)
}
}
impl From<i32> for SqlValue {
fn from(n: i32) -> Self {
SqlValue::Int(n as i64)
}
}
impl From<bool> for SqlValue {
fn from(b: bool) -> Self {
SqlValue::Bool(b)
}
}
impl From<f64> for SqlValue {
fn from(f: f64) -> Self {
SqlValue::Float(f)
}
}
/// One SQL statement together with its argument values; serialized as an
/// element of the `statements` array sent by `Database::batch`.
#[derive(Debug, Clone, Serialize)]
pub struct SqlStatement {
    /// SQL text of the statement.
    pub sql: String,
    /// Argument values; always serialized, including when the list is empty.
    pub args: Vec<SqlValue>,
}
impl SqlStatement {
    /// Builds a statement that carries no arguments.
    pub fn new(sql: impl Into<String>) -> Self {
        Self::with_args(sql, Vec::new())
    }

    /// Builds a statement from SQL text and an already-assembled argument list.
    pub fn with_args(sql: impl Into<String>, args: Vec<SqlValue>) -> Self {
        let sql = sql.into();
        Self { sql, args }
    }
}
/// Transaction mode sent in the `transaction` field of the `Database::batch`
/// request body.
///
/// With `rename_all = "lowercase"` the variants serialize as `"deferred"`,
/// `"immediate"` and `"exclusive"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum SqlTransactionMode {
Deferred,
Immediate,
Exclusive,
}
/// Optional settings for `Database::exec` and `Database::batch`.
///
/// Both fields are copied into the request body and omitted from it when `None`.
#[derive(Debug, Clone, Default)]
pub struct ExecOptions {
/// Sent as the `transaction` field of the request body.
pub transaction: Option<SqlTransactionMode>,
/// Sent as the `max_rows` field of the request body.
pub max_rows: Option<u32>,
}
/// Result of a single statement, converted from the per-statement entry of the
/// exec response (`RawStatementResult`).
#[derive(Debug, Clone)]
pub struct ExecResult {
pub columns: Vec<String>,
pub rows: Vec<Vec<SqlValue>>,
/// `0` when the response omitted the field.
pub rows_affected: u64,
/// Filled from the wire field `last_insert_rowid` (note the different spelling).
pub last_insert_row_id: Option<i64>,
/// `false` when the response omitted the field.
pub truncated: bool,
}
/// Outcome of `Database::batch`.
#[derive(Debug, Clone)]
pub struct BatchResult {
/// One entry per element of the response's `results` array, in the same order.
pub results: Vec<ExecResult>,
/// The `elapsed_ms` value from the response; `0` when the response omitted it.
pub elapsed_ms: u64,
}
/// Wire shape of one entry in the exec response's `results` array.
///
/// Every field may be absent from the payload (`#[serde(default)]`). The value
/// is converted to the public `ExecResult` through its `From` impl.
#[derive(Deserialize, Default)]
struct RawStatementResult {
#[serde(default)]
columns: Vec<String>,
#[serde(default)]
rows: Vec<Vec<SqlValue>>,
#[serde(default)]
rows_affected: Option<u64>,
// Wire spelling is `last_insert_rowid`; the public struct uses `last_insert_row_id`.
#[serde(default)]
last_insert_rowid: Option<i64>,
#[serde(default)]
truncated: Option<bool>,
}
/// Wire shape of the response to `POST /sqlite-databases/{id}/exec`.
///
/// Both fields fall back to their defaults when absent: an empty `results`
/// list and an `elapsed_ms` of `0`.
#[derive(Deserialize, Default)]
struct RawExecResponse {
#[serde(default)]
results: Vec<RawStatementResult>,
#[serde(default)]
elapsed_ms: u64,
}
impl From<RawStatementResult> for ExecResult {
fn from(r: RawStatementResult) -> Self {
ExecResult {
columns: r.columns,
rows: r.rows,
rows_affected: r.rows_affected.unwrap_or(0),
last_insert_row_id: r.last_insert_rowid,
truncated: r.truncated.unwrap_or(false),
}
}
}
/// Scope attached to a connection token.
///
/// With `rename_all = "lowercase"` the variants are serialized and
/// deserialized as `"read"` and `"write"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ConnectionScope {
Read,
Write,
}
/// Optional settings for `Database::connect_token`.
///
/// Each field is written to the request body only when it is `Some`.
#[derive(Debug, Clone, Default)]
pub struct ConnectionTokenOptions {
/// Sent as the `ttl_seconds` field of the request body.
pub ttl_seconds: Option<u64>,
/// Sent as the `scopes` field of the request body.
pub scopes: Option<Vec<ConnectionScope>>,
}
/// Connection token returned by `Database::connect_token`, copied field for
/// field from the response payload (`RawConnectionToken`).
#[derive(Debug, Clone)]
pub struct ConnectionToken {
pub id: String,
pub database_id: String,
pub url: String,
pub auth_token: String,
/// Empty when the response carried no `scopes` field.
pub scopes: Vec<ConnectionScope>,
/// Kept as the raw string from the response; the format is not visible in this file.
pub expires_at: String,
}
/// Wire shape of the response to `POST /sqlite-databases/{id}/connection`.
///
/// `scopes` defaults to an empty list when absent; every other field is
/// required.
#[derive(Deserialize)]
struct RawConnectionToken {
id: String,
database_id: String,
url: String,
auth_token: String,
#[serde(default)]
scopes: Vec<ConnectionScope>,
expires_at: String,
}
impl From<RawConnectionToken> for ConnectionToken {
fn from(r: RawConnectionToken) -> Self {
ConnectionToken {
id: r.id,
database_id: r.database_id,
url: r.url,
auth_token: r.auth_token,
scopes: r.scopes,
expires_at: r.expires_at,
}
}
}
/// One entry of the list returned by `Database::list_connections`.
///
/// Unlike `ConnectionToken`, it carries no `url` or `auth_token`.
#[derive(Debug, Clone)]
pub struct ConnectionTokenInfo {
pub id: String,
pub database_id: String,
/// Empty when the response carried no `scopes` field.
pub scopes: Vec<ConnectionScope>,
/// `false` when the response carried no `revoked` field.
pub revoked: bool,
pub expires_at: String,
pub created_at: String,
/// `None` when the response carried no `last_used_at` field.
pub last_used_at: Option<String>,
}
/// Wire shape of one element of the `tokens` array in the
/// `/connection-tokens` list response.
///
/// `scopes`, `revoked` and `last_used_at` fall back to empty / `false` /
/// `None` when absent; the remaining fields are required.
#[derive(Deserialize)]
struct RawConnectionTokenInfo {
id: String,
database_id: String,
#[serde(default)]
scopes: Vec<ConnectionScope>,
#[serde(default)]
revoked: bool,
expires_at: String,
created_at: String,
#[serde(default)]
last_used_at: Option<String>,
}
impl From<RawConnectionTokenInfo> for ConnectionTokenInfo {
fn from(r: RawConnectionTokenInfo) -> Self {
ConnectionTokenInfo {
id: r.id,
database_id: r.database_id,
scopes: r.scopes,
revoked: r.revoked,
expires_at: r.expires_at,
created_at: r.created_at,
last_used_at: r.last_used_at,
}
}
}
/// Result of `Database::checkout`.
#[derive(Debug, Clone)]
pub struct CheckoutResult {
/// Id of the `Database` handle the checkout was issued on.
pub database_id: String,
/// Value parsed from the `x-heyo-data-version` response header.
pub data_version: i64,
/// The complete response body.
pub bytes: Vec<u8>,
}
/// Options for `Database::checkin`.
///
/// `checkin` rejects the call before sending anything unless
/// `expected_version` is set or `force` is `true`, so the `Default` value is
/// not accepted as-is.
#[derive(Debug, Clone, Default)]
pub struct CheckinOptions {
/// Sent as the `expected_version` query parameter when `Some`.
pub expected_version: Option<i64>,
/// When `true`, sent as the query parameter `force=true`.
pub force: bool,
}
/// Successful result of `Database::checkin`, copied field for field from the
/// response body (`CheckinResponse`).
#[derive(Debug, Clone)]
pub struct CheckinResult {
pub database_id: String,
pub data_version: i64,
pub s3_key: String,
/// `false` when the response carried no `forced` field.
pub forced: bool,
}
/// Wire shape of a successful checkin response body, parsed as JSON by
/// `Database::checkin`. `forced` defaults to `false` when absent.
#[derive(Deserialize)]
struct CheckinResponse {
database_id: String,
data_version: i64,
s3_key: String,
#[serde(default)]
forced: bool,
}
/// Wrapper around the `GET /sqlite-databases` response, which nests the list
/// under a `databases` key. A missing key yields an empty list.
#[derive(Deserialize)]
struct DatabasesEnvelope {
#[serde(default)]
databases: Vec<DatabaseInfo>,
}
/// Wrapper around the `GET /sqlite-regions` response, which nests the list
/// under a `regions` key. A missing key yields an empty list.
#[derive(Deserialize)]
struct RegionsEnvelope {
#[serde(default)]
regions: Vec<String>,
}
/// Wire shape of the `/connection-info` response. Both fields are required;
/// `Database::connection_info` returns them as a `(database_id, url)` tuple.
#[derive(Deserialize)]
struct ConnectionInfoRaw {
database_id: String,
url: String,
}
/// Wrapper around the `/connection-tokens` list response, which nests the
/// entries under a `tokens` key. A missing key yields an empty list.
#[derive(Deserialize)]
struct TokensEnvelope {
#[serde(default)]
tokens: Vec<RawConnectionTokenInfo>,
}
/// Handle to one remote database: its id plus the `HeyoClient` that every
/// request made through the handle goes through.
#[derive(Clone)]
pub struct Database {
// Inserted (via `encode_path`) into every `/sqlite-databases/{id}...` request path.
id: String,
client: HeyoClient,
}
impl Database {
fn from_raw(client: HeyoClient, info: DatabaseInfo) -> Self {
Self {
id: info.id,
client,
}
}
pub fn id(&self) -> &str {
&self.id
}
pub fn client(&self) -> &HeyoClient {
&self.client
}
pub async fn create(
options: DatabaseCreateOptions,
client_options: HeyoClientOptions,
) -> Result<Self, HeyoError> {
let client = HeyoClient::new(client_options)?;
let raw: DatabaseInfo = client
.request(
Method::POST,
"/sqlite-databases",
Some(&options),
RequestOptions::default(),
)
.await?;
Ok(Database::from_raw(client, raw))
}
pub async fn list(
client_options: HeyoClientOptions,
) -> Result<Vec<DatabaseInfo>, HeyoError> {
let client = HeyoClient::new(client_options)?;
let env: DatabasesEnvelope = client
.request(Method::GET, "/sqlite-databases", None::<&()>, RequestOptions::default())
.await?;
Ok(env.databases)
}
pub async fn get(id: &str, client_options: HeyoClientOptions) -> Result<Self, HeyoError> {
let client = HeyoClient::new(client_options)?;
let path = format!("/sqlite-databases/{}", encode_path(id));
let raw: DatabaseInfo = client
.request(Method::GET, &path, None::<&()>, RequestOptions::default())
.await?;
Ok(Database::from_raw(client, raw))
}
pub async fn regions(client_options: HeyoClientOptions) -> Result<Vec<String>, HeyoError> {
let client = HeyoClient::new(client_options)?;
let env: RegionsEnvelope = client
.request(Method::GET, "/sqlite-regions", None::<&()>, RequestOptions::default())
.await?;
Ok(env.regions)
}
pub async fn info(&self) -> Result<DatabaseInfo, HeyoError> {
let path = format!("/sqlite-databases/{}", encode_path(&self.id));
self.client
.request(Method::GET, &path, None::<&()>, RequestOptions::default())
.await
}
pub async fn delete(&self) -> Result<(), HeyoError> {
let path = format!("/sqlite-databases/{}", encode_path(&self.id));
self.client
.request::<serde_json::Value>(Method::DELETE, &path, None::<&()>, RequestOptions::default())
.await?;
Ok(())
}
pub async fn exec(
&self,
sql: &str,
args: Vec<SqlValue>,
options: ExecOptions,
) -> Result<ExecResult, HeyoError> {
let batch = self
.batch(vec![SqlStatement::with_args(sql, args)], options)
.await?;
batch
.results
.into_iter()
.next()
.ok_or_else(|| HeyoError::api(0, "empty batch result"))
}
pub async fn batch(
&self,
statements: Vec<SqlStatement>,
options: ExecOptions,
) -> Result<BatchResult, HeyoError> {
#[derive(Serialize)]
struct Body<'a> {
statements: &'a [SqlStatement],
#[serde(skip_serializing_if = "Option::is_none")]
transaction: Option<SqlTransactionMode>,
#[serde(skip_serializing_if = "Option::is_none", rename = "max_rows")]
max_rows: Option<u32>,
}
let body = Body {
statements: &statements,
transaction: options.transaction,
max_rows: options.max_rows,
};
let path = format!("/sqlite-databases/{}/exec", encode_path(&self.id));
let raw: RawExecResponse = self
.client
.request(Method::POST, &path, Some(&body), RequestOptions::default())
.await?;
Ok(BatchResult {
results: raw.results.into_iter().map(ExecResult::from).collect(),
elapsed_ms: raw.elapsed_ms,
})
}
pub async fn connect_token(
&self,
options: ConnectionTokenOptions,
) -> Result<ConnectionToken, HeyoError> {
#[derive(Serialize)]
struct Body {
#[serde(skip_serializing_if = "Option::is_none", rename = "ttl_seconds")]
ttl_seconds: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
scopes: Option<Vec<ConnectionScope>>,
}
let body = Body {
ttl_seconds: options.ttl_seconds,
scopes: options.scopes,
};
let path = format!("/sqlite-databases/{}/connection", encode_path(&self.id));
let raw: RawConnectionToken = self
.client
.request(Method::POST, &path, Some(&body), RequestOptions::default())
.await?;
Ok(raw.into())
}
pub async fn connection_info(&self) -> Result<(String, String), HeyoError> {
let path = format!("/sqlite-databases/{}/connection-info", encode_path(&self.id));
let raw: ConnectionInfoRaw = self
.client
.request(Method::GET, &path, None::<&()>, RequestOptions::default())
.await?;
Ok((raw.database_id, raw.url))
}
pub async fn list_connections(&self) -> Result<Vec<ConnectionTokenInfo>, HeyoError> {
let path = format!("/sqlite-databases/{}/connection-tokens", encode_path(&self.id));
let env: TokensEnvelope = self
.client
.request(Method::GET, &path, None::<&()>, RequestOptions::default())
.await?;
Ok(env.tokens.into_iter().map(ConnectionTokenInfo::from).collect())
}
pub async fn revoke_connection(&self, token_id: &str) -> Result<(), HeyoError> {
let path = format!(
"/sqlite-databases/{}/connection-tokens/{}",
encode_path(&self.id),
encode_path(token_id)
);
self.client
.request::<serde_json::Value>(Method::DELETE, &path, None::<&()>, RequestOptions::default())
.await?;
Ok(())
}
pub async fn checkout(&self) -> Result<CheckoutResult, HeyoError> {
let path = format!("/sqlite-databases/{}/file", encode_path(&self.id));
let response = self
.client
.raw_request(Method::GET, &path, None::<&()>, RequestOptions::default())
.await?;
if !response.status().is_success() {
let status = response.status().as_u16();
let body = response.bytes().await.unwrap_or_default();
return Err(HeyoError::api(
status,
format!(
"checkout failed for {}: {}",
self.id,
String::from_utf8_lossy(&body)
),
));
}
let version = response
.headers()
.get("x-heyo-data-version")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<i64>().ok())
.ok_or_else(|| {
HeyoError::api(0, "checkout response missing X-Heyo-Data-Version header")
})?;
let bytes = response
.bytes()
.await
.map_err(|e| HeyoError::api(0, format!("read checkout body: {}", e)))?;
Ok(CheckoutResult {
database_id: self.id.clone(),
data_version: version,
bytes: bytes.to_vec(),
})
}
pub async fn checkin(
&self,
bytes: Vec<u8>,
options: CheckinOptions,
) -> Result<CheckinResult, HeyoError> {
if !options.force && options.expected_version.is_none() {
return Err(HeyoError::invalid(
"checkin() requires `expected_version` unless `force = true`",
));
}
let mut req_opts = RequestOptions::default();
if let Some(v) = options.expected_version {
req_opts
.query
.push(("expected_version".to_string(), v.to_string()));
}
if options.force {
req_opts.query.push(("force".to_string(), "true".to_string()));
}
let path = format!("/sqlite-databases/{}/file", encode_path(&self.id));
let response = self
.client
.put_bytes(&path, bytes, "application/gzip", req_opts)
.await?;
let status = response.status();
let body = response
.bytes()
.await
.map_err(|e| HeyoError::api(0, format!("read checkin body: {}", e)))?;
if status.as_u16() == 409 {
let mut current = -1_i64;
if let Ok(v) = serde_json::from_slice::<serde_json::Value>(&body) {
if let Some(n) = v.get("current_version").and_then(|x| x.as_i64()) {
current = n;
}
}
return Err(HeyoError::CheckinConflict {
expected: options.expected_version,
current,
});
}
if !status.is_success() {
return Err(HeyoError::api(
status.as_u16(),
format!("checkin failed: {}", String::from_utf8_lossy(&body)),
));
}
let resp: CheckinResponse = serde_json::from_slice(&body)
.map_err(|e| HeyoError::api(0, format!("parse checkin response: {}", e)))?;
Ok(CheckinResult {
database_id: resp.database_id,
data_version: resp.data_version,
s3_key: resp.s3_key,
forced: resp.forced,
})
}
}