// docker_volume/handler.rs

use anyhow::anyhow;
use std::fmt::Debug;
use std::io::ErrorKind;
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::sync::Arc;

use crate::driver::VolumeDriver;
use axum::body::Bytes;
use axum::http::{HeaderMap, HeaderValue, Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use axum::{middleware, routing::post, Router};
use hyper::{Body, Server};
use hyperlocal::UnixServerExt;
use tokio::fs;
use tracing::{debug, info};

17pub struct VolumeHandler<T: VolumeDriver> {
18    driver: Arc<T>,
19}
20
21impl<T: VolumeDriver> VolumeHandler<T> {
22    pub fn new(driver: T) -> Self {
23        Self {
24            driver: Arc::new(driver),
25        }
26    }
27
28    pub async fn run_tcp(&self, port: u16) -> Result<(), anyhow::Error> {
29        info!("Starting Volume handler on port: {}", port);
30        let app = self.build_router();
31
32        let addr = format!("0.0.0.0:{port}").parse()?;
33        axum::Server::bind(&addr)
34            .serve(app.into_make_service())
35            .await?;
36        Ok(())
37    }
38
39    pub async fn run_unix_socket(&self, socket_path: PathBuf) -> Result<(), anyhow::Error> {
40        info!("Starting Volume handler on unix socket: {:?}", socket_path);
41        // setup socket file
42        if socket_path.exists() {
43            fs::remove_file(&socket_path).await?;
44        }
45        fs::create_dir_all(socket_path.parent().ok_or(anyhow!("no parent dir"))?).await?;
46
47        let app = self.build_router();
48        Server::bind_unix(socket_path)?
49            .serve(app.into_make_service())
50            .await?;
51
52        Ok(())
53    }
54
55    fn build_router(&self) -> Router {
56        Router::new()
57            .route("/Plugin.Activate", post(T::activate))
58            .route("/VolumeDriver.Create", post(T::create))
59            .route("/VolumeDriver.Remove", post(T::remove))
60            .route("/VolumeDriver.Mount", post(T::mount))
61            .route("/VolumeDriver.Unmount", post(T::unmount))
62            .route("/VolumeDriver.Get", post(T::get))
63            .route("/VolumeDriver.List", post(T::list))
64            .route("/VolumeDriver.Path", post(T::path))
65            .route("/VolumeDriver.Capabilities", post(T::capabilities))
66            .with_state(self.driver.clone())
67            .layer(middleware::from_fn(print_request_response))
68    }
69}
70
71async fn print_request_response(
72    req: Request<Body>,
73    next: Next<Body>,
74) -> Result<impl IntoResponse, (StatusCode, String)> {
75    let (mut parts, body) = req.into_parts();
76    let (mut headers, uri) = (parts.headers.clone(), parts.uri.clone());
77    debug!("handling request for: {:?}", uri);
78    let bytes = buffer_and_print("request", &headers, body).await?;
79    headers.insert("content-type", HeaderValue::from_static("application/json"));
80    parts.headers = headers;
81    let req = Request::from_parts(parts, Body::from(bytes));
82
83    let res = next.run(req).await;
84
85    let (parts, body) = res.into_parts();
86    let headers = &parts.headers;
87    let bytes = buffer_and_print("response", headers, body).await?;
88    let res = Response::from_parts(parts, Body::from(bytes));
89
90    Ok(res)
91}
92
93async fn buffer_and_print<B, T: Debug>(
94    direction: &str,
95    headers: &HeaderMap<T>,
96    body: B,
97) -> Result<Bytes, (StatusCode, String)>
98where
99    B: axum::body::HttpBody<Data = Bytes>,
100    B::Error: std::fmt::Display,
101{
102    let bytes = match hyper::body::to_bytes(body).await {
103        Ok(bytes) => bytes,
104        Err(err) => {
105            return Err((
106                StatusCode::BAD_REQUEST,
107                format!("failed to read {direction}, err: {err}"),
108            ));
109        }
110    };
111
112    if let Ok(body) = std::str::from_utf8(&bytes) {
113        debug!("{} headers = {:?}, body = {:?}", direction, headers, body);
114    }
115
116    Ok(bytes)
117}