corro_client/
lib.rs

1use std::{
2    error::Error as StdError, io, net::SocketAddr, ops::Deref, path::Path, task::Poll,
3    time::Duration,
4};
5
6use bytes::{Buf, Bytes, BytesMut};
7use corro_api_types::{sqlite::RusqliteConnManager, QueryEvent, RqliteResponse, Statement};
8use futures::{ready, Stream, StreamExt, TryStreamExt};
9use hyper::{client::HttpConnector, http::HeaderName, Body, StatusCode};
10use tokio_serde::{formats::SymmetricalJson, SymmetricallyFramed};
11use tokio_util::{
12    codec::{Decoder, FramedRead, LinesCodec, LinesCodecError},
13    io::StreamReader,
14};
15use tracing::warn;
16use uuid::Uuid;
17
18#[derive(Clone)]
19pub struct CorrosionApiClient {
20    api_addr: SocketAddr,
21    api_client: hyper::Client<HttpConnector, Body>,
22}
23
24impl CorrosionApiClient {
25    pub fn new(api_addr: SocketAddr) -> Self {
26        Self {
27            api_addr,
28            api_client: hyper::Client::builder().http2_only(true).build_http(),
29        }
30    }
31
32    pub async fn query(&self, statement: &Statement) -> Result<hyper::Body, Error> {
33        let req = hyper::Request::builder()
34            .method(hyper::Method::POST)
35            .uri(format!("http://{}/v1/queries", self.api_addr))
36            .header(hyper::header::CONTENT_TYPE, "application/json")
37            .header(hyper::header::ACCEPT, "application/json")
38            .body(Body::from(serde_json::to_vec(statement)?))?;
39
40        let res = self.api_client.request(req).await?;
41
42        if !res.status().is_success() {
43            return Err(Error::UnexpectedStatusCode(res.status()));
44        }
45
46        Ok(res.into_body())
47    }
48
49    pub async fn watch(
50        &self,
51        statement: &Statement,
52    ) -> Result<(Uuid, impl Stream<Item = io::Result<QueryEvent>>), Error> {
53        let req = hyper::Request::builder()
54            .method(hyper::Method::POST)
55            .uri(format!("http://{}/v1/watches", self.api_addr))
56            .header(hyper::header::CONTENT_TYPE, "application/json")
57            .header(hyper::header::ACCEPT, "application/json")
58            .body(Body::from(serde_json::to_vec(statement)?))?;
59
60        let res = self.api_client.request(req).await?;
61
62        if !res.status().is_success() {
63            return Err(Error::UnexpectedStatusCode(res.status()));
64        }
65
66        // TODO: make that header name a const in corro-types
67        let id = res
68            .headers()
69            .get(HeaderName::from_static("corro-query-id"))
70            .and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok()))
71            .ok_or(Error::ExpectedQueryId)?;
72
73        Ok((id, watch_stream(res.into_body())))
74    }
75
76    pub async fn watched_query(
77        &self,
78        id: Uuid,
79    ) -> Result<impl Stream<Item = io::Result<QueryEvent>>, Error> {
80        let req = hyper::Request::builder()
81            .method(hyper::Method::GET)
82            .uri(format!("http://{}/v1/watches/{}", self.api_addr, id))
83            .header(hyper::header::ACCEPT, "application/json")
84            .body(hyper::Body::empty())?;
85
86        let res = self.api_client.request(req).await?;
87
88        if !res.status().is_success() {
89            return Err(Error::UnexpectedStatusCode(res.status()));
90        }
91
92        Ok(watch_stream(res.into_body()))
93    }
94
95    pub async fn execute(&self, statements: &[Statement]) -> Result<RqliteResponse, Error> {
96        let req = hyper::Request::builder()
97            .method(hyper::Method::POST)
98            .uri(format!("http://{}/v1/transactions", self.api_addr))
99            .header(hyper::header::CONTENT_TYPE, "application/json")
100            .header(hyper::header::ACCEPT, "application/json")
101            .body(Body::from(serde_json::to_vec(statements)?))?;
102
103        let res = self.api_client.request(req).await?;
104
105        if !res.status().is_success() {
106            return Err(Error::UnexpectedStatusCode(res.status()));
107        }
108
109        let bytes = hyper::body::to_bytes(res.into_body()).await?;
110
111        Ok(serde_json::from_slice(&bytes)?)
112    }
113
114    pub async fn schema(&self, statements: &[Statement]) -> Result<RqliteResponse, Error> {
115        let req = hyper::Request::builder()
116            .method(hyper::Method::POST)
117            .uri(format!("http://{}/v1/migrations", self.api_addr))
118            .header(hyper::header::CONTENT_TYPE, "application/json")
119            .header(hyper::header::ACCEPT, "application/json")
120            .body(Body::from(serde_json::to_vec(statements)?))?;
121
122        let res = self.api_client.request(req).await?;
123
124        if !res.status().is_success() {
125            return Err(Error::UnexpectedStatusCode(res.status()));
126        }
127
128        let bytes = hyper::body::to_bytes(res.into_body()).await?;
129
130        Ok(serde_json::from_slice(&bytes)?)
131    }
132
133    pub async fn schema_from_paths<P: AsRef<Path>>(
134        &self,
135        schema_paths: &[P],
136    ) -> Result<RqliteResponse, Error> {
137        let mut statements = vec![];
138
139        for schema_path in schema_paths.iter() {
140            match tokio::fs::metadata(schema_path).await {
141                Ok(meta) => {
142                    if meta.is_dir() {
143                        match tokio::fs::read_dir(schema_path).await {
144                            Ok(mut dir) => {
145                                let mut entries = vec![];
146
147                                while let Ok(Some(entry)) = dir.next_entry().await {
148                                    entries.push(entry);
149                                }
150
151                                let mut entries: Vec<_> = entries
152                                    .into_iter()
153                                    .filter_map(|entry| {
154                                        entry.path().extension().and_then(|ext| {
155                                            if ext == "sql" {
156                                                Some(entry)
157                                            } else {
158                                                None
159                                            }
160                                        })
161                                    })
162                                    .collect();
163
164                                entries.sort_by_key(|entry| entry.path());
165
166                                for entry in entries.iter() {
167                                    match tokio::fs::read_to_string(entry.path()).await {
168                                        Ok(s) => {
169                                            statements.push(Statement::Simple(s));
170                                            // pushed.push(
171                                            //     entry.path().to_string_lossy().to_string().into(),
172                                            // );
173                                        }
174                                        Err(e) => {
175                                            warn!(
176                                                "could not read schema file '{}', error: {e}",
177                                                entry.path().display()
178                                            );
179                                        }
180                                    }
181                                }
182                            }
183                            Err(e) => {
184                                warn!(
185                                    "could not read dir '{}', error: {e}",
186                                    schema_path.as_ref().display()
187                                );
188                            }
189                        }
190                    } else if meta.is_file() {
191                        match tokio::fs::read_to_string(schema_path).await {
192                            Ok(s) => {
193                                statements.push(Statement::Simple(s));
194                                // pushed.push(schema_path.clone());
195                            }
196                            Err(e) => {
197                                warn!(
198                                    "could not read schema file '{}', error: {e}",
199                                    schema_path.as_ref().display()
200                                );
201                            }
202                        }
203                    }
204                }
205
206                Err(e) => {
207                    warn!(
208                        "could not read schema file meta '{}', error: {e}",
209                        schema_path.as_ref().display()
210                    );
211                }
212            }
213        }
214
215        self.schema(&statements).await
216    }
217}
218
219fn watch_stream(body: hyper::Body) -> impl Stream<Item = io::Result<QueryEvent>> {
220    let body = StreamReader::new(body.map_err(|e| {
221        if let Some(io_error) = e
222            .source()
223            .and_then(|source| source.downcast_ref::<io::Error>())
224        {
225            return io::Error::from(io_error.kind());
226        }
227        io::Error::new(io::ErrorKind::Other, e)
228    }));
229
230    let framed = FramedRead::new(body, LinesBytesCodec::new())
231        .map_err(|e| io::Error::new(io::ErrorKind::Other, e));
232
233    SymmetricallyFramed::new(framed, SymmetricalJson::<QueryEvent>::default())
234}
235
/// A `\n`-delimited line decoder that yields raw `BytesMut` frames.
///
/// It works like `tokio_util`'s `LinesCodec` but skips UTF-8 validation.
struct LinesBytesCodec {
    /// Upper bound on a buffered line's length; `usize::MAX` means unbounded,
    /// so bytes are buffered until a `\n` appears.
    max_length: usize,

    /// Set after a line exceeded `max_length`; the rest of that line is being
    /// thrown away until the next `\n`.
    is_discarding: bool,

    /// Position in the buffer where the next search for `\n` starts.
    ///
    /// Bytes before it were already scanned by an earlier `decode` call and
    /// contained no newline. For example, after `decode` saw `abc` this is `3`;
    /// a later call with `abcde\n` only scans `de\n`.
    next_index: usize,
}
253
254impl LinesBytesCodec {
255    /// Returns a `LinesBytesCodec` for splitting up data into lines.
256    ///
257    /// # Note
258    ///
259    /// The returned `LinesBytesCodec` will not have an upper bound on the length
260    /// of a buffered line. See the documentation for [`new_with_max_length`]
261    /// for information on why this could be a potential security risk.
262    ///
263    /// [`new_with_max_length`]: crate::codec::LinesBytesCodec::new_with_max_length()
264    pub fn new() -> LinesBytesCodec {
265        LinesBytesCodec {
266            next_index: 0,
267            max_length: usize::MAX,
268            is_discarding: false,
269        }
270    }
271
272    /// Returns a `LinesBytesCodec` with a maximum line length limit.
273    ///
274    /// If this is set, calls to `LinesBytesCodec::decode` will return a
275    /// [`LinesCodecError`] when a line exceeds the length limit. Subsequent calls
276    /// will discard up to `limit` bytes from that line until a newline
277    /// character is reached, returning `None` until the line over the limit
278    /// has been fully discarded. After that point, calls to `decode` will
279    /// function as normal.
280    ///
281    /// # Note
282    ///
283    /// Setting a length limit is highly recommended for any `LinesBytesCodec` which
284    /// will be exposed to untrusted input. Otherwise, the size of the buffer
285    /// that holds the line currently being read is unbounded. An attacker could
286    /// exploit this unbounded buffer by sending an unbounded amount of input
287    /// without any `\n` characters, causing unbounded memory consumption.
288    ///
289    /// [`LinesCodecError`]: crate::codec::LinesCodecError
290    pub fn new_with_max_length(max_length: usize) -> Self {
291        LinesBytesCodec {
292            max_length,
293            ..LinesBytesCodec::new()
294        }
295    }
296}
297
298impl Decoder for LinesBytesCodec {
299    type Item = BytesMut;
300    type Error = LinesCodecError;
301
302    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, LinesCodecError> {
303        loop {
304            // Determine how far into the buffer we'll search for a newline. If
305            // there's no max_length set, we'll read to the end of the buffer.
306            let read_to = std::cmp::min(self.max_length.saturating_add(1), buf.len());
307
308            let newline_offset = buf[self.next_index..read_to]
309                .iter()
310                .position(|b| *b == b'\n');
311
312            match (self.is_discarding, newline_offset) {
313                (true, Some(offset)) => {
314                    // If we found a newline, discard up to that offset and
315                    // then stop discarding. On the next iteration, we'll try
316                    // to read a line normally.
317                    buf.advance(offset + self.next_index + 1);
318                    self.is_discarding = false;
319                    self.next_index = 0;
320                }
321                (true, None) => {
322                    // Otherwise, we didn't find a newline, so we'll discard
323                    // everything we read. On the next iteration, we'll continue
324                    // discarding up to max_len bytes unless we find a newline.
325                    buf.advance(read_to);
326                    self.next_index = 0;
327                    if buf.is_empty() {
328                        return Ok(None);
329                    }
330                }
331                (false, Some(offset)) => {
332                    // Found a line!
333                    let newline_index = offset + self.next_index;
334                    self.next_index = 0;
335                    let mut line = buf.split_to(newline_index + 1);
336                    line.truncate(line.len() - 1);
337                    without_carriage_return(&mut line);
338                    return Ok(Some(line));
339                }
340                (false, None) if buf.len() > self.max_length => {
341                    // Reached the maximum length without finding a
342                    // newline, return an error and start discarding on the
343                    // next call.
344                    self.is_discarding = true;
345                    return Err(LinesCodecError::MaxLineLengthExceeded);
346                }
347                (false, None) => {
348                    // We didn't find a line or reach the length limit, so the next
349                    // call will resume searching at the current offset.
350                    self.next_index = read_to;
351                    return Ok(None);
352                }
353            }
354        }
355    }
356
357    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, LinesCodecError> {
358        Ok(match self.decode(buf)? {
359            Some(frame) => Some(frame),
360            None => {
361                // No terminating newline - return remaining data, if any
362                if buf.is_empty() || buf == &b"\r"[..] {
363                    None
364                } else {
365                    let mut line = buf.split_to(buf.len());
366                    line.truncate(line.len() - 1);
367                    without_carriage_return(&mut line);
368                    self.next_index = 0;
369                    Some(line)
370                }
371            }
372        })
373    }
374}
375
376fn without_carriage_return(s: &mut BytesMut) {
377    if let Some(&b'\r') = s.last() {
378        s.truncate(s.len() - 1);
379    }
380}
381
382#[derive(Clone)]
383pub struct CorrosionClient {
384    api_client: CorrosionApiClient,
385    pool: bb8::Pool<RusqliteConnManager>,
386}
387
388impl CorrosionClient {
389    pub fn new<P: AsRef<Path>>(api_addr: SocketAddr, db_path: P) -> Self {
390        Self {
391            api_client: CorrosionApiClient::new(api_addr),
392            pool: bb8::Pool::builder()
393                .max_size(5)
394                .max_lifetime(Some(Duration::from_secs(30)))
395                .build_unchecked(RusqliteConnManager::new(&db_path)),
396        }
397    }
398
399    pub fn pool(&self) -> &bb8::Pool<RusqliteConnManager> {
400        &self.pool
401    }
402}
403
404impl Deref for CorrosionClient {
405    type Target = CorrosionApiClient;
406
407    fn deref(&self) -> &Self::Target {
408        &self.api_client
409    }
410}
411
412#[derive(Debug, thiserror::Error)]
413pub enum Error {
414    #[error(transparent)]
415    Hyper(#[from] hyper::Error),
416    #[error(transparent)]
417    Http(#[from] hyper::http::Error),
418    #[error(transparent)]
419    Serde(#[from] serde_json::Error),
420
421    #[error("received unexpected response code: {0}")]
422    UnexpectedStatusCode(StatusCode),
423
424    #[error("could not retrieve watch id from headers")]
425    ExpectedQueryId,
426}