http_client_unix_domain_socket/
client.rs1#[cfg(feature = "json")]
2use crate::error::ErrorAndResponseJson;
3use crate::{Error, error::ErrorAndResponse};
4use axum_core::body::Body;
5use http_body_util::BodyExt;
6use hyper::{
7 Method, Request, StatusCode,
8 client::conn::http1::{self, SendRequest},
9};
10use hyper_util::rt::TokioIo;
11#[cfg(feature = "json")]
12use serde::{Serialize, de::DeserializeOwned};
13use std::path::PathBuf;
14use tokio::{net::UnixStream, task::JoinHandle};
15
/// HTTP/1 client that talks to a server over a single Unix domain socket connection.
///
/// A value is obtained through `ClientUnix::try_new` and owns both the request
/// sender and the background task that drives the underlying connection.
#[derive(Debug)]
pub struct ClientUnix {
    /// Path of the socket this client connected to; reused when reconnecting.
    socket_path: PathBuf,
    /// hyper request sender produced by the HTTP/1 handshake.
    sender: SendRequest<Body>,
    /// Handle of the spawned task that polls the connection. The task resolves
    /// to the `Error` describing why the connection ended.
    join_handle: JoinHandle<Error>,
}
23
24impl ClientUnix {
25 pub async fn try_new(socket_path: &str) -> Result<Self, Error> {
36 let socket_path = PathBuf::from(socket_path);
37 ClientUnix::try_connect(socket_path).await
38 }
39
40 pub async fn try_reconnect(self) -> Result<Self, Error> {
62 let socket_path = self.socket_path.clone();
63 self.abort().await;
64 ClientUnix::try_connect(socket_path).await
65 }
66
67 pub async fn abort(self) -> Option<Error> {
71 self.join_handle.abort();
72 self.join_handle.await.ok()
73 }
74
75 async fn try_connect(socket_path: PathBuf) -> Result<Self, Error> {
76 let stream = TokioIo::new(
77 UnixStream::connect(socket_path.clone())
78 .await
79 .map_err(Error::SocketConnectionInitiation)?,
80 );
81
82 let (sender, connection) = http1::handshake(stream).await.map_err(Error::Handhsake)?;
83
84 let join_handle =
85 tokio::task::spawn(
86 async move { Error::SocketConnectionClosed(connection.await.err()) },
87 );
88
89 Ok(ClientUnix {
90 socket_path,
91 sender,
92 join_handle,
93 })
94 }
95
96 pub async fn send_request(
146 &mut self,
147 endpoint: &str,
148 method: Method,
149 headers: &[(&str, &str)],
150 body_request: Option<Body>,
151 ) -> Result<(StatusCode, Vec<u8>), ErrorAndResponse> {
152 let mut request_builder = Request::builder();
153 for header in headers {
154 request_builder = request_builder.header(header.0, header.1);
155 }
156 let request = request_builder
157 .method(method)
158 .uri(format!("http://unix.socket{}", endpoint))
159 .body(body_request.unwrap_or(Body::empty()))
160 .map_err(|e| ErrorAndResponse::InternalError(Error::RequestBuild(e)))?;
161
162 let response = self
163 .sender
164 .send_request(request)
165 .await
166 .map_err(|e| ErrorAndResponse::InternalError(Error::RequestSend(e)))?;
167
168 let status_code = response.status();
169 let body_response = response
170 .collect()
171 .await
172 .map_err(|e| ErrorAndResponse::InternalError(Error::ResponseCollect(e)))?
173 .to_bytes();
174
175 if !status_code.is_success() {
176 return Err(ErrorAndResponse::ResponseUnsuccessful(
177 status_code,
178 body_response.to_vec(),
179 ));
180 }
181 Ok((status_code, body_response.to_vec()))
182 }
183
184 #[cfg(feature = "json")]
236 pub async fn send_request_json<IN: Serialize, OUT: DeserializeOwned, ERR: DeserializeOwned>(
237 &mut self,
238 endpoint: &str,
239 method: Method,
240 headers: &[(&str, &str)],
241 body_request: Option<&IN>,
242 ) -> Result<(StatusCode, OUT), ErrorAndResponseJson<ERR>> {
243 let mut headers = headers.to_vec();
244 headers.push(("Content-Type", "application/json"));
245
246 let body_request = match body_request {
247 Some(body_request) => Body::from(
248 serde_json::to_vec(body_request)
249 .map_err(|e| ErrorAndResponseJson::InternalError(Error::RequestParsing(e)))?,
250 ),
251 None => Body::empty(),
252 };
253
254 match self
255 .send_request(endpoint, method, &headers, Some(body_request))
256 .await
257 {
258 Ok((status_code, response)) => Ok((
259 status_code,
260 serde_json::from_slice(&response)
261 .map_err(|e| ErrorAndResponseJson::InternalError(Error::ResponseParsing(e)))?,
262 )),
263 Err(ErrorAndResponse::InternalError(e)) => Err(ErrorAndResponseJson::InternalError(e)),
264 Err(ErrorAndResponse::ResponseUnsuccessful(status_code, response)) => {
265 Err(ErrorAndResponseJson::ResponseUnsuccessful(
266 status_code,
267 serde_json::from_slice(&response).map_err(|e| {
268 ErrorAndResponseJson::InternalError(Error::ResponseParsing(e))
269 })?,
270 ))
271 }
272 }
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::{server::Server, util::*};
    use hyper::Method;

    /// Issues `GET /nolanv` and checks for a 200 response whose body is `Hello nolanv`.
    async fn expect_hello_nolanv(client: &mut ClientUnix) {
        let outcome = client
            .send_request("/nolanv", Method::GET, &[], None)
            .await;
        let (status_code, response) = outcome.expect("client.send_request");

        assert_eq!(status_code, StatusCode::OK);
        assert_eq!(response, "Hello nolanv".as_bytes());
    }

    /// A single GET against a running test server succeeds.
    #[tokio::test]
    async fn simple_request() {
        let (_, mut client) = make_client_server("simple_request").await;

        expect_hello_nolanv(&mut client).await;
    }

    /// `GET /nolanv/nope` is reported as an unsuccessful response with status 404.
    #[tokio::test]
    async fn simple_404_request() {
        let (_, mut client) = make_client_server("simple_404_request").await;

        let outcome = client
            .send_request("/nolanv/nope", Method::GET, &[], None)
            .await;

        assert!(matches!(
            outcome,
            Err(ErrorAndResponse::ResponseUnsuccessful(code, _))
                if code == StatusCode::NOT_FOUND
        ));
    }

    /// Twenty sequential requests can be sent through the same client.
    #[tokio::test]
    async fn multiple_request() {
        let (_, mut client) = make_client_server("multiple_request").await;

        for attempt in 0..20 {
            let endpoint = format!("/nolanv{}", attempt);
            let expected_body = format!("Hello nolanv{}", attempt);

            let (status_code, response) = client
                .send_request(&endpoint, Method::GET, &[], None)
                .await
                .expect("client.send_request");

            assert_eq!(status_code, StatusCode::OK);
            assert_eq!(response, expected_body.as_bytes());
        }
    }

    /// Creating a client for a socket path with no server started by this test
    /// fails with `SocketConnectionInitiation`.
    #[tokio::test]
    async fn server_not_started() {
        let socket_path = make_socket_path_test("client", "server_not_started");

        let connection = ClientUnix::try_new(&socket_path).await;

        assert!(matches!(
            connection,
            Err(Error::SocketConnectionInitiation(_))
        ));
    }

    /// After the server is aborted a request fails as canceled; once a new server
    /// is created and the client has reconnected, requests succeed again.
    #[tokio::test]
    async fn server_stopped() {
        let (server, mut client) = make_client_server("server_stopped").await;
        server.abort().await;

        let failed_attempt = client.send_request("/nolanv", Method::GET, &[], None).await;
        assert!(matches!(
            failed_attempt,
            Err(ErrorAndResponse::InternalError(Error::RequestSend(e)))
                if e.is_canceled()
        ));

        let socket_path = make_socket_path_test("client", "server_stopped");
        let _ = Server::try_new(&socket_path)
            .await
            .expect("Server::try_new");
        let mut http_client = client.try_reconnect().await.expect("client.try_reconnect");

        expect_hello_nolanv(&mut http_client).await;
    }

    /// Same scenario as `server_stopped`, except the new server is created before
    /// the stale client sends its request; that request still fails as canceled
    /// until the client reconnects.
    #[tokio::test]
    async fn server_rebooted() {
        let (server, mut client) = make_client_server("server_rebooted").await;
        server.abort().await;

        let socket_path = make_socket_path_test("client", "server_rebooted");
        let _ = Server::try_new(&socket_path)
            .await
            .expect("Server::try_new");

        let failed_attempt = client.send_request("/nolanv", Method::GET, &[], None).await;
        assert!(matches!(
            failed_attempt,
            Err(ErrorAndResponse::InternalError(Error::RequestSend(e)))
                if e.is_canceled()
        ));
        let mut http_client = client.try_reconnect().await.expect("client.try_reconnect");

        expect_hello_nolanv(&mut http_client).await;
    }
}
389
#[cfg(feature = "json")]
#[cfg(test)]
mod json_tests {
    use hyper::{Method, StatusCode};
    use serde::{Deserialize, Serialize};
    use serde_json::{Value, json};

    use crate::{error::ErrorAndResponseJson, test_helpers::util::make_client_server};

    /// Error payload these tests deserialize from unsuccessful responses.
    #[derive(Deserialize, Debug)]
    struct ErrorJson {
        msg: String,
    }

    /// `GET /json/nolanv` answers 200 with a JSON object whose `hello` is `"nolanv"`.
    #[tokio::test]
    async fn simple_get_request() {
        let (_, mut client) = make_client_server("simple_get_request").await;

        let (status_code, response) = client
            .send_request_json::<(), Value, Value>("/json/nolanv", Method::GET, &[], None)
            .await
            .expect("client.send_request_json");

        assert_eq!(status_code, StatusCode::OK);
        assert_eq!(response.get("hello"), Some(&json!("nolanv")))
    }

    /// `GET /json/nolanv/nop` is reported as an unsuccessful 404 response whose body
    /// deserializes into `ErrorJson` with `msg == "not found"`.
    #[tokio::test]
    async fn simple_get_404_request() {
        let (_, mut client) = make_client_server("simple_get_404_request").await;

        let result = client
            .send_request_json::<(), Value, ErrorJson>("/json/nolanv/nop", Method::GET, &[], None)
            .await;

        // The stray `dbg!(&result)` debugging print that used to sit here was removed.
        assert!(matches!(
            result.err(),
            Some(ErrorAndResponseJson::ResponseUnsuccessful(status_code, body))
                if status_code == StatusCode::NOT_FOUND && body.msg == "not found"
        ));
    }

    /// `POST /json` with `{"name": "nolanv"}` answers 200 with `hello == "nolanv"`.
    #[tokio::test]
    async fn simple_post_request() {
        let (_, mut client) = make_client_server("simple_post_request").await;

        #[derive(Serialize)]
        struct NameJson {
            name: String,
        }

        #[derive(Deserialize)]
        struct HelloJson {
            hello: String,
        }

        let (status_code, response) = client
            .send_request_json::<NameJson, HelloJson, Value>(
                "/json",
                Method::POST,
                &[],
                Some(&NameJson {
                    name: "nolanv".into(),
                }),
            )
            .await
            .expect("client.send_request_json");

        assert_eq!(status_code, StatusCode::OK);
        assert_eq!(response.hello, "nolanv")
    }

    /// `POST /json` with a payload using the key `nom` instead of `name` is reported
    /// as an unsuccessful 400 response with `msg == "bad request"`.
    #[tokio::test]
    async fn simple_post_bad_request() {
        let (_, mut client) = make_client_server("simple_post_bad_request").await;

        #[derive(Serialize)]
        struct NameBadJson {
            nom: String,
        }

        let result = client
            .send_request_json::<NameBadJson, Value, ErrorJson>(
                "/json",
                Method::POST,
                &[],
                Some(&NameBadJson {
                    nom: "nolanv".into(),
                }),
            )
            .await;

        assert!(matches!(
            result.err(),
            Some(ErrorAndResponseJson::ResponseUnsuccessful(status_code, body))
                if status_code == StatusCode::BAD_REQUEST && body.msg == "bad request"
        ));
    }
}