sozu_client/
lib.rs

//! # Sōzu client
//!
//! This library provides a client to interact with Sōzu.
//! The client can send one-off requests or batches of requests.
5
6use bb8::Pool;
7use sozu_command_lib::{
8    channel::ChannelError,
9    proto::command::{request::RequestType, Request, Response, ResponseStatus, WorkerRequest},
10};
11use tempdir::TempDir;
12use tokio::{
13    fs::File,
14    io::{AsyncWriteExt, BufWriter},
15    task::{spawn_blocking as blocking, JoinError},
16};
17use tracing::trace;
18
19use crate::channel::{ConnectionManager, ConnectionProperties};
20
21pub mod channel;
22pub mod config;
23pub mod socket;
24#[cfg(feature = "unpooled")]
25pub mod unpooled;
26
27// -----------------------------------------------------------------------------
28// Error
29
30#[derive(thiserror::Error, Debug)]
31pub enum Error {
32    #[error("failed to create connection pool over unix socket, {0}")]
33    CreatePool(channel::Error),
34    #[error("failed to execute blocking task, {0}")]
35    Join(JoinError),
36    #[error("failed to get connection to socket, {0}")]
37    GetConnection(bb8::RunError<channel::Error>),
38    #[error("failed to send request, {0}")]
39    Send(ChannelError),
40    #[error("failed to read response, {0}")]
41    Receive(ChannelError),
42    #[error("got an invalid status code, {0}")]
43    InvalidStatusCode(i32),
44    #[error("failed to execute request, got status '{0}', {1}")]
45    Failure(String, String, Response),
46    #[error("failed to create temporary directory, {0}")]
47    CreateTempDir(std::io::Error),
48    #[error("failed to create temporary file, {0}")]
49    CreateTempFile(std::io::Error),
50    #[error("failed to serialize worker request, {0}")]
51    Serialize(serde_json::Error),
52    #[error("failed to write worker request, {0}")]
53    Write(std::io::Error),
54    #[error("failed to flush worker request buffer, {0}")]
55    Flush(std::io::Error),
56}
57
58impl From<JoinError> for Error {
59    #[tracing::instrument]
60    fn from(err: JoinError) -> Self {
61        Self::Join(err)
62    }
63}
64
65impl Error {
66    #[tracing::instrument]
67    pub fn is_recoverable(&self) -> bool {
68        !matches!(self, Self::Send(_) | Self::Receive(_) | Self::CreatePool(_) | Self::GetConnection(_))
69    }
70}
71
72// -----------------------------------------------------------------------------
73// Sender
74
75#[async_trait::async_trait]
76pub trait Sender {
77    type Error;
78
79    async fn send(&self, request: RequestType) -> Result<Response, Self::Error>;
80
81    async fn send_all(&self, requests: &[RequestType]) -> Result<Response, Self::Error>;
82}
83
84// -----------------------------------------------------------------------------
85// Client
86
87#[derive(Clone, Debug)]
88pub struct Client {
89    pool: Pool<ConnectionManager>,
90}
91
92#[async_trait::async_trait]
93impl Sender for Client {
94    type Error = Error;
95
96    #[tracing::instrument(skip_all)]
97    async fn send(&self, request: RequestType) -> Result<Response, Self::Error> {
98        trace!("Retrieve a connection to Sōzu's socket");
99        let mut conn = self.pool.get().await.map_err(Error::GetConnection)?;
100
101        trace!("Send request to Sōzu");
102        conn.write_message(&Request {
103            request_type: Some(request),
104        })
105        .map_err(Error::Send)?;
106
107        loop {
108            trace!("Read request to Sōzu");
109            let response = conn.read_message().map_err(Error::Receive)?;
110
111            let status = ResponseStatus::try_from(response.status)
112                .map_err(|_| Error::InvalidStatusCode(response.status))?;
113
114            match status {
115                ResponseStatus::Processing => continue,
116                ResponseStatus::Failure => {
117                    return Err(Error::Failure(status.as_str_name().to_string(), response.message.to_string().to_lowercase(), response));
118                }
119                ResponseStatus::Ok => {
120                    return Ok(response);
121                }
122            }
123        }
124    }
125
126    #[tracing::instrument(skip_all)]
127    async fn send_all(&self, requests: &[RequestType]) -> Result<Response, Self::Error> {
128        // -------------------------------------------------------------------------
129        // Create temporary folder and writer to batch requests
130        let tmpdir =
131            blocking(|| TempDir::new(env!("CARGO_PKG_NAME")).map_err(Error::CreateTempDir))
132                .await??;
133
134        let path = tmpdir.path().join("requests.json");
135        let mut writer = BufWriter::new(File::create(&path).await.map_err(Error::CreateTempFile)?);
136
137        for (idx, request) in requests.iter().cloned().enumerate() {
138            let worker_request = WorkerRequest {
139                id: format!("{}-{idx}", env!("CARGO_PKG_NAME")).to_uppercase(),
140                content: Request::from(request),
141            };
142
143            let payload =
144                blocking(move || serde_json::to_string(&worker_request).map_err(Error::Serialize))
145                    .await??;
146
147            writer
148                .write_all(format!("{payload}\n\0").as_bytes())
149                .await
150                .map_err(Error::Write)?;
151        }
152
153        writer.flush().await.map_err(Error::Flush)?;
154
155        // -------------------------------------------------------------------------
156        // Send a LoadState request with the file that we have created.
157        self.send(RequestType::LoadState(path.to_string_lossy().to_string()))
158            .await
159    }
160}
161
162impl From<Pool<ConnectionManager>> for Client {
163    #[tracing::instrument(skip_all)]
164    fn from(pool: Pool<ConnectionManager>) -> Self {
165        Self { pool }
166    }
167}
168
169impl Client {
170    #[tracing::instrument]
171    pub async fn try_new(opts: ConnectionProperties) -> Result<Self, Error> {
172        let pool = Pool::builder()
173            .build(ConnectionManager::new(opts))
174            .await
175            .map_err(Error::CreatePool)?;
176
177        Ok(Self::from(pool))
178    }
179}