1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
use async_trait::async_trait;
use base64::Engine;
use super::{parse_query_result, QueryResult, Statement};
/// HTTP client that POSTs batched statements as JSON to a single endpoint,
/// authenticating with HTTP Basic auth.
#[derive(Clone)]
pub struct Client {
    /// Endpoint URL that batch requests are POSTed to.
    url: String,
    /// Precomputed `Authorization` header value (`Basic <base64(user:pass)>`).
    auth: String,
}

impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The auth header is merely base64-encoded, i.e. effectively plaintext
        // credentials, so it must never end up in debug/log output.
        f.debug_struct("Client")
            .field("url", &self.url)
            .field("auth", &"<redacted>")
            .finish()
    }
}
impl Client {
pub fn new(
url: impl Into<String>,
username: impl Into<String>,
pass: impl Into<String>,
) -> Self {
let username = username.into();
let pass = pass.into();
let url = url.into();
let url = if !url.contains("://") {
"https://".to_owned() + &url
} else {
url
};
Self {
url,
auth: format!(
"Basic {}",
base64::engine::general_purpose::STANDARD.encode(format!("{username}:{pass}"))
),
}
}
pub fn from_url(url: &url::Url) -> anyhow::Result<Client> {
let username = url.username();
let password = url.password().unwrap_or_default();
let mut url = url.clone();
url.set_username("")
.map_err(|_| anyhow::anyhow!("Could not extract username from URL. Invalid URL?"))?;
url.set_password(None)
.map_err(|_| anyhow::anyhow!("Could not extract password from URL. Invalid URL?"))?;
Ok(Client::new(url.as_str(), username, password))
}
fn batch(
&self,
stmts: impl IntoIterator<Item = impl Into<Statement>>,
) -> anyhow::Result<Vec<QueryResult>> {
let mut body = "{\"statements\": [".to_string();
let mut stmts_count = 0;
for stmt in stmts {
body += &format!("{},", stmt.into());
stmts_count += 1;
}
if stmts_count > 0 {
body.pop();
}
body += "]}";
let req = http::Request::builder()
.uri(&self.url)
.header("Authorization", &self.auth)
.method("POST")
.body(Some(bytes::Bytes::copy_from_slice(body.as_bytes())))?;
let response = spin_sdk::outbound_http::send_request(req);
let resp: String =
std::str::from_utf8(&response?.into_body().unwrap_or_default())?.to_string();
let response_json: serde_json::Value = serde_json::from_str(&resp)?;
match response_json {
serde_json::Value::Array(results) => {
if results.len() != stmts_count {
Err(anyhow::anyhow!(
"Response array did not contain expected {stmts_count} results"
))
} else {
let mut query_results: Vec<QueryResult> = Vec::with_capacity(stmts_count);
for (idx, result) in results.into_iter().enumerate() {
query_results.push(parse_query_result(result, idx)?);
}
Ok(query_results)
}
}
e => Err(anyhow::anyhow!("Error: {} ({:?})", e, body)),
}
}
}
#[async_trait(?Send)]
impl super::DatabaseClient for Client {
    /// Executes `stmts` as a single batch by delegating to the inherent,
    /// synchronous `Client::batch`.
    ///
    /// Nothing is awaited: the HTTP request is performed synchronously inside
    /// the inherent method.
    async fn batch(
        &self,
        stmts: impl IntoIterator<Item = impl Into<Statement>>,
    ) -> anyhow::Result<Vec<QueryResult>> {
        // Inherent methods take precedence over trait methods, so this calls the
        // synchronous implementation, not itself. Its error is already an
        // `anyhow::Error`; pass it through unchanged. Re-wrapping with
        // `anyhow!("{e}")` would keep only the outermost message and drop the
        // context chain, source and backtrace.
        self.batch(stmts)
    }
}