1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use async_trait::async_trait;
use base64::Engine;
use super::{QueryResult, Statement};
/// HTTP connection settings for a sqld server, authenticating with HTTP Basic auth.
///
/// Holds only configuration: the target URLs and the precomputed `Authorization`
/// header value. No network resources are held; requests are made per batch.
#[derive(Clone)]
pub struct Connection {
    /// Root URL of the server; also used as the fallback batch endpoint.
    base_url: String,
    /// The `/queries` endpoint derived from `base_url`; the primary batch endpoint.
    url_for_queries: String,
    /// Precomputed `Authorization` header value (`Basic <base64(user:pass)>`).
    auth: String,
}

// `Debug` is implemented by hand (instead of derived) so that formatting a
// `Connection` — e.g. in logs or panic messages — never reveals the
// base64-encoded credentials stored in `auth`.
impl std::fmt::Debug for Connection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Connection")
            .field("base_url", &self.base_url)
            .field("url_for_queries", &self.url_for_queries)
            .field("auth", &"<redacted>")
            .finish()
    }
}
impl Connection {
pub fn connect(
url: impl Into<String>,
username: impl Into<String>,
pass: impl Into<String>,
) -> Self {
let username = username.into();
let pass = pass.into();
let url = url.into();
let base_url = if !url.contains("://") {
"https://".to_owned() + &url
} else {
url
};
let url_for_queries = format!("{base_url}/queries");
Self {
base_url,
url_for_queries,
auth: format!(
"Basic {}",
base64::engine::general_purpose::STANDARD.encode(format!("{username}:{pass}"))
),
}
}
pub fn connect_from_url(url: &url::Url) -> anyhow::Result<Connection> {
let username = url.username();
let password = url.password().unwrap_or_default();
let mut url = url.clone();
url.set_username("")
.map_err(|_| anyhow::anyhow!("Could not extract username from URL. Invalid URL?"))?;
url.set_password(None)
.map_err(|_| anyhow::anyhow!("Could not extract password from URL. Invalid URL?"))?;
Ok(Connection::connect(url.as_str(), username, password))
}
pub fn connect_from_env() -> anyhow::Result<Connection> {
let url = std::env::var("LIBSQL_CLIENT_URL").map_err(|_| {
anyhow::anyhow!("LIBSQL_CLIENT_URL variable should point to your sqld database")
})?;
let user = match std::env::var("LIBSQL_CLIENT_USER") {
Ok(user) => user,
Err(_) => {
return Connection::connect_from_url(&url::Url::parse(&url)?);
}
};
let pass = std::env::var("LIBSQL_CLIENT_PASS").map_err(|_| {
anyhow::anyhow!("LIBSQL_CLIENT_PASS variable should be set to your sqld password")
})?;
Ok(Connection::connect(url, user, pass))
}
}
#[async_trait(?Send)]
impl super::Connection for Connection {
async fn batch(
&self,
stmts: impl IntoIterator<Item = impl Into<Statement>>,
) -> anyhow::Result<Vec<QueryResult>> {
let (body, stmts_count) = crate::connection::statements_to_string(stmts);
let client = reqwest::Client::new();
let response = match client
.post(&self.url_for_queries)
.body(body.clone())
.header("Authorization", &self.auth)
.send()
.await
{
Ok(resp) if resp.status() == reqwest::StatusCode::OK => resp,
_ => {
client
.post(&self.base_url)
.body(body)
.header("Authorization", &self.auth)
.send()
.await?
}
};
let resp: String = response.text().await?;
let response_json: serde_json::Value = serde_json::from_str(&resp)?;
crate::connection::json_to_query_result(response_json, stmts_count)
}
}