1use std::{
2 error::Error as StdError, io, net::SocketAddr, ops::Deref, path::Path, task::Poll,
3 time::Duration,
4};
5
6use bytes::{Buf, Bytes, BytesMut};
7use corro_api_types::{sqlite::RusqliteConnManager, QueryEvent, RqliteResponse, Statement};
8use futures::{ready, Stream, StreamExt, TryStreamExt};
9use hyper::{client::HttpConnector, http::HeaderName, Body, StatusCode};
10use tokio_serde::{formats::SymmetricalJson, SymmetricallyFramed};
11use tokio_util::{
12 codec::{Decoder, FramedRead, LinesCodec, LinesCodecError},
13 io::StreamReader,
14};
15use tracing::warn;
16use uuid::Uuid;
17
/// HTTP client for the Corrosion API served at a single socket address.
#[derive(Clone)]
pub struct CorrosionApiClient {
    /// Address every request URI is built from (`http://{api_addr}/...`).
    api_addr: SocketAddr,
    /// Underlying hyper client used to send all requests.
    api_client: hyper::Client<HttpConnector, Body>,
}
23
24impl CorrosionApiClient {
25 pub fn new(api_addr: SocketAddr) -> Self {
26 Self {
27 api_addr,
28 api_client: hyper::Client::builder().http2_only(true).build_http(),
29 }
30 }
31
32 pub async fn query(&self, statement: &Statement) -> Result<hyper::Body, Error> {
33 let req = hyper::Request::builder()
34 .method(hyper::Method::POST)
35 .uri(format!("http://{}/v1/queries", self.api_addr))
36 .header(hyper::header::CONTENT_TYPE, "application/json")
37 .header(hyper::header::ACCEPT, "application/json")
38 .body(Body::from(serde_json::to_vec(statement)?))?;
39
40 let res = self.api_client.request(req).await?;
41
42 if !res.status().is_success() {
43 return Err(Error::UnexpectedStatusCode(res.status()));
44 }
45
46 Ok(res.into_body())
47 }
48
49 pub async fn watch(
50 &self,
51 statement: &Statement,
52 ) -> Result<(Uuid, impl Stream<Item = io::Result<QueryEvent>>), Error> {
53 let req = hyper::Request::builder()
54 .method(hyper::Method::POST)
55 .uri(format!("http://{}/v1/watches", self.api_addr))
56 .header(hyper::header::CONTENT_TYPE, "application/json")
57 .header(hyper::header::ACCEPT, "application/json")
58 .body(Body::from(serde_json::to_vec(statement)?))?;
59
60 let res = self.api_client.request(req).await?;
61
62 if !res.status().is_success() {
63 return Err(Error::UnexpectedStatusCode(res.status()));
64 }
65
66 let id = res
68 .headers()
69 .get(HeaderName::from_static("corro-query-id"))
70 .and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok()))
71 .ok_or(Error::ExpectedQueryId)?;
72
73 Ok((id, watch_stream(res.into_body())))
74 }
75
76 pub async fn watched_query(
77 &self,
78 id: Uuid,
79 ) -> Result<impl Stream<Item = io::Result<QueryEvent>>, Error> {
80 let req = hyper::Request::builder()
81 .method(hyper::Method::GET)
82 .uri(format!("http://{}/v1/watches/{}", self.api_addr, id))
83 .header(hyper::header::ACCEPT, "application/json")
84 .body(hyper::Body::empty())?;
85
86 let res = self.api_client.request(req).await?;
87
88 if !res.status().is_success() {
89 return Err(Error::UnexpectedStatusCode(res.status()));
90 }
91
92 Ok(watch_stream(res.into_body()))
93 }
94
95 pub async fn execute(&self, statements: &[Statement]) -> Result<RqliteResponse, Error> {
96 let req = hyper::Request::builder()
97 .method(hyper::Method::POST)
98 .uri(format!("http://{}/v1/transactions", self.api_addr))
99 .header(hyper::header::CONTENT_TYPE, "application/json")
100 .header(hyper::header::ACCEPT, "application/json")
101 .body(Body::from(serde_json::to_vec(statements)?))?;
102
103 let res = self.api_client.request(req).await?;
104
105 if !res.status().is_success() {
106 return Err(Error::UnexpectedStatusCode(res.status()));
107 }
108
109 let bytes = hyper::body::to_bytes(res.into_body()).await?;
110
111 Ok(serde_json::from_slice(&bytes)?)
112 }
113
114 pub async fn schema(&self, statements: &[Statement]) -> Result<RqliteResponse, Error> {
115 let req = hyper::Request::builder()
116 .method(hyper::Method::POST)
117 .uri(format!("http://{}/v1/migrations", self.api_addr))
118 .header(hyper::header::CONTENT_TYPE, "application/json")
119 .header(hyper::header::ACCEPT, "application/json")
120 .body(Body::from(serde_json::to_vec(statements)?))?;
121
122 let res = self.api_client.request(req).await?;
123
124 if !res.status().is_success() {
125 return Err(Error::UnexpectedStatusCode(res.status()));
126 }
127
128 let bytes = hyper::body::to_bytes(res.into_body()).await?;
129
130 Ok(serde_json::from_slice(&bytes)?)
131 }
132
133 pub async fn schema_from_paths<P: AsRef<Path>>(
134 &self,
135 schema_paths: &[P],
136 ) -> Result<RqliteResponse, Error> {
137 let mut statements = vec![];
138
139 for schema_path in schema_paths.iter() {
140 match tokio::fs::metadata(schema_path).await {
141 Ok(meta) => {
142 if meta.is_dir() {
143 match tokio::fs::read_dir(schema_path).await {
144 Ok(mut dir) => {
145 let mut entries = vec![];
146
147 while let Ok(Some(entry)) = dir.next_entry().await {
148 entries.push(entry);
149 }
150
151 let mut entries: Vec<_> = entries
152 .into_iter()
153 .filter_map(|entry| {
154 entry.path().extension().and_then(|ext| {
155 if ext == "sql" {
156 Some(entry)
157 } else {
158 None
159 }
160 })
161 })
162 .collect();
163
164 entries.sort_by_key(|entry| entry.path());
165
166 for entry in entries.iter() {
167 match tokio::fs::read_to_string(entry.path()).await {
168 Ok(s) => {
169 statements.push(Statement::Simple(s));
170 }
174 Err(e) => {
175 warn!(
176 "could not read schema file '{}', error: {e}",
177 entry.path().display()
178 );
179 }
180 }
181 }
182 }
183 Err(e) => {
184 warn!(
185 "could not read dir '{}', error: {e}",
186 schema_path.as_ref().display()
187 );
188 }
189 }
190 } else if meta.is_file() {
191 match tokio::fs::read_to_string(schema_path).await {
192 Ok(s) => {
193 statements.push(Statement::Simple(s));
194 }
196 Err(e) => {
197 warn!(
198 "could not read schema file '{}', error: {e}",
199 schema_path.as_ref().display()
200 );
201 }
202 }
203 }
204 }
205
206 Err(e) => {
207 warn!(
208 "could not read schema file meta '{}', error: {e}",
209 schema_path.as_ref().display()
210 );
211 }
212 }
213 }
214
215 self.schema(&statements).await
216 }
217}
218
219fn watch_stream(body: hyper::Body) -> impl Stream<Item = io::Result<QueryEvent>> {
220 let body = StreamReader::new(body.map_err(|e| {
221 if let Some(io_error) = e
222 .source()
223 .and_then(|source| source.downcast_ref::<io::Error>())
224 {
225 return io::Error::from(io_error.kind());
226 }
227 io::Error::new(io::ErrorKind::Other, e)
228 }));
229
230 let framed = FramedRead::new(body, LinesBytesCodec::new())
231 .map_err(|e| io::Error::new(io::ErrorKind::Other, e));
232
233 SymmetricallyFramed::new(framed, SymmetricalJson::<QueryEvent>::default())
234}
235
/// Decoder that splits a byte buffer into newline-terminated frames, yielding
/// each line as raw `BytesMut` without its line ending.
struct LinesBytesCodec {
    /// Offset in the buffer at which the next newline search starts; bytes
    /// before it were already scanned by an earlier `decode` call.
    next_index: usize,

    /// Maximum number of buffered bytes allowed without a newline before
    /// `decode` fails with `MaxLineLengthExceeded`.
    max_length: usize,

    /// Set after an over-long line was reported: input is then dropped up to
    /// and including the next newline.
    is_discarding: bool,
}
253
254impl LinesBytesCodec {
255 pub fn new() -> LinesBytesCodec {
265 LinesBytesCodec {
266 next_index: 0,
267 max_length: usize::MAX,
268 is_discarding: false,
269 }
270 }
271
272 pub fn new_with_max_length(max_length: usize) -> Self {
291 LinesBytesCodec {
292 max_length,
293 ..LinesBytesCodec::new()
294 }
295 }
296}
297
298impl Decoder for LinesBytesCodec {
299 type Item = BytesMut;
300 type Error = LinesCodecError;
301
302 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, LinesCodecError> {
303 loop {
304 let read_to = std::cmp::min(self.max_length.saturating_add(1), buf.len());
307
308 let newline_offset = buf[self.next_index..read_to]
309 .iter()
310 .position(|b| *b == b'\n');
311
312 match (self.is_discarding, newline_offset) {
313 (true, Some(offset)) => {
314 buf.advance(offset + self.next_index + 1);
318 self.is_discarding = false;
319 self.next_index = 0;
320 }
321 (true, None) => {
322 buf.advance(read_to);
326 self.next_index = 0;
327 if buf.is_empty() {
328 return Ok(None);
329 }
330 }
331 (false, Some(offset)) => {
332 let newline_index = offset + self.next_index;
334 self.next_index = 0;
335 let mut line = buf.split_to(newline_index + 1);
336 line.truncate(line.len() - 1);
337 without_carriage_return(&mut line);
338 return Ok(Some(line));
339 }
340 (false, None) if buf.len() > self.max_length => {
341 self.is_discarding = true;
345 return Err(LinesCodecError::MaxLineLengthExceeded);
346 }
347 (false, None) => {
348 self.next_index = read_to;
351 return Ok(None);
352 }
353 }
354 }
355 }
356
357 fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, LinesCodecError> {
358 Ok(match self.decode(buf)? {
359 Some(frame) => Some(frame),
360 None => {
361 if buf.is_empty() || buf == &b"\r"[..] {
363 None
364 } else {
365 let mut line = buf.split_to(buf.len());
366 line.truncate(line.len() - 1);
367 without_carriage_return(&mut line);
368 self.next_index = 0;
369 Some(line)
370 }
371 }
372 })
373 }
374}
375
376fn without_carriage_return(s: &mut BytesMut) {
377 if let Some(&b'\r') = s.last() {
378 s.truncate(s.len() - 1);
379 }
380}
381
/// A [`CorrosionApiClient`] bundled with a `bb8` pool of connections managed
/// by `RusqliteConnManager`.
#[derive(Clone)]
pub struct CorrosionClient {
    /// HTTP API client; also reachable through `Deref`.
    api_client: CorrosionApiClient,
    /// Connection pool exposed through `pool()`.
    pool: bb8::Pool<RusqliteConnManager>,
}
387
388impl CorrosionClient {
389 pub fn new<P: AsRef<Path>>(api_addr: SocketAddr, db_path: P) -> Self {
390 Self {
391 api_client: CorrosionApiClient::new(api_addr),
392 pool: bb8::Pool::builder()
393 .max_size(5)
394 .max_lifetime(Some(Duration::from_secs(30)))
395 .build_unchecked(RusqliteConnManager::new(&db_path)),
396 }
397 }
398
399 pub fn pool(&self) -> &bb8::Pool<RusqliteConnManager> {
400 &self.pool
401 }
402}
403
404impl Deref for CorrosionClient {
405 type Target = CorrosionApiClient;
406
407 fn deref(&self) -> &Self::Target {
408 &self.api_client
409 }
410}
411
/// Errors returned by the API client methods.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Error from hyper while sending a request or reading a response body.
    #[error(transparent)]
    Hyper(#[from] hyper::Error),
    /// Error while building an HTTP request.
    #[error(transparent)]
    Http(#[from] hyper::http::Error),
    /// JSON encoding of a request or decoding of a response failed.
    #[error(transparent)]
    Serde(#[from] serde_json::Error),

    /// The server answered with a non-success status code.
    #[error("received unexpected response code: {0}")]
    UnexpectedStatusCode(StatusCode),

    /// The `corro-query-id` response header was missing or could not be
    /// parsed.
    #[error("could not retrieve watch id from headers")]
    ExpectedQueryId,
}