1use anyhow::anyhow;
2use std::fmt::Debug;
3use std::path::PathBuf;
4use std::sync::Arc;
5
6use crate::driver::VolumeDriver;
7use axum::body::Bytes;
8use axum::http::{HeaderMap, HeaderValue, Request, StatusCode};
9use axum::middleware::Next;
10use axum::response::{IntoResponse, Response};
11use axum::{middleware, routing::post, Router};
12use hyper::{Body, Server};
13use hyperlocal::UnixServerExt;
14use tokio::fs;
15use tracing::{debug, info};
16
17pub struct VolumeHandler<T: VolumeDriver> {
18 driver: Arc<T>,
19}
20
21impl<T: VolumeDriver> VolumeHandler<T> {
22 pub fn new(driver: T) -> Self {
23 Self {
24 driver: Arc::new(driver),
25 }
26 }
27
28 pub async fn run_tcp(&self, port: u16) -> Result<(), anyhow::Error> {
29 info!("Starting Volume handler on port: {}", port);
30 let app = self.build_router();
31
32 let addr = format!("0.0.0.0:{port}").parse()?;
33 axum::Server::bind(&addr)
34 .serve(app.into_make_service())
35 .await?;
36 Ok(())
37 }
38
39 pub async fn run_unix_socket(&self, socket_path: PathBuf) -> Result<(), anyhow::Error> {
40 info!("Starting Volume handler on unix socket: {:?}", socket_path);
41 if socket_path.exists() {
43 fs::remove_file(&socket_path).await?;
44 }
45 fs::create_dir_all(socket_path.parent().ok_or(anyhow!("no parent dir"))?).await?;
46
47 let app = self.build_router();
48 Server::bind_unix(socket_path)?
49 .serve(app.into_make_service())
50 .await?;
51
52 Ok(())
53 }
54
55 fn build_router(&self) -> Router {
56 Router::new()
57 .route("/Plugin.Activate", post(T::activate))
58 .route("/VolumeDriver.Create", post(T::create))
59 .route("/VolumeDriver.Remove", post(T::remove))
60 .route("/VolumeDriver.Mount", post(T::mount))
61 .route("/VolumeDriver.Unmount", post(T::unmount))
62 .route("/VolumeDriver.Get", post(T::get))
63 .route("/VolumeDriver.List", post(T::list))
64 .route("/VolumeDriver.Path", post(T::path))
65 .route("/VolumeDriver.Capabilities", post(T::capabilities))
66 .with_state(self.driver.clone())
67 .layer(middleware::from_fn(print_request_response))
68 }
69}
70
71async fn print_request_response(
72 req: Request<Body>,
73 next: Next<Body>,
74) -> Result<impl IntoResponse, (StatusCode, String)> {
75 let (mut parts, body) = req.into_parts();
76 let (mut headers, uri) = (parts.headers.clone(), parts.uri.clone());
77 debug!("handling request for: {:?}", uri);
78 let bytes = buffer_and_print("request", &headers, body).await?;
79 headers.insert("content-type", HeaderValue::from_static("application/json"));
80 parts.headers = headers;
81 let req = Request::from_parts(parts, Body::from(bytes));
82
83 let res = next.run(req).await;
84
85 let (parts, body) = res.into_parts();
86 let headers = &parts.headers;
87 let bytes = buffer_and_print("response", headers, body).await?;
88 let res = Response::from_parts(parts, Body::from(bytes));
89
90 Ok(res)
91}
92
93async fn buffer_and_print<B, T: Debug>(
94 direction: &str,
95 headers: &HeaderMap<T>,
96 body: B,
97) -> Result<Bytes, (StatusCode, String)>
98where
99 B: axum::body::HttpBody<Data = Bytes>,
100 B::Error: std::fmt::Display,
101{
102 let bytes = match hyper::body::to_bytes(body).await {
103 Ok(bytes) => bytes,
104 Err(err) => {
105 return Err((
106 StatusCode::BAD_REQUEST,
107 format!("failed to read {direction}, err: {err}"),
108 ));
109 }
110 };
111
112 if let Ok(body) = std::str::from_utf8(&bytes) {
113 debug!("{} headers = {:?}, body = {:?}", direction, headers, body);
114 }
115
116 Ok(bytes)
117}