1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
//! An implementation of the [inertia.js] protocol for [axum].
//!
//! The basic idea is that any axum handler that accepts the `Inertia`
//! struct as a function parameter is an inertia endpoint. For
//! instance:
//!
//! ```rust
//! use axum_inertia::Inertia;
//! use axum::{Json, response::IntoResponse};
//! use serde_json::json;
//!
//! async fn my_handler_fn(i: Inertia) -> impl IntoResponse {
//!     i.render("Pages/MyPageComponent", json!({"myPageProps": "true"}))
//! }
//! ```
//!
//! This does the following:
//!
//! - If the incoming request is the initial page load (i.e., does not
//! have the `X-Inertia` header set to `true`), the
//! [render](Inertia::render) method responds with an html page, which
//! is configurable when setting up the initial Inertia state (see
//! [Getting started](#getting-started) below).
//!
//! - Otherwise, the handler responds with the standard inertia
//! "Page" object json, with the included component and page props
//! passed to [render](Inertia::render).
//!
//! - If the request has a mismatching asset version (again, this is
//! configurable), the handler responds with a `409 Conflict` to tell
//! the client to reload the page. The function body of the handler is
//! not executed in this case.
//!
//! # Getting started
//!
//! First, you'll need to provide your axum routes with
//! [InertiaConfig] state. This state boils down to two things: an
//! optional string representing the [asset version] and a function
//! that takes serialized props and returns an HTML string for the
//! initial page load.
//!
//! The [vite] module provides a convenient way to set up this state
//! with [axum::Router::with_state]. For instance, the following code
//! sets up a standard development server:
//!
//! ```rust
//! use axum_inertia::{vite, Inertia};
//! use axum::{Router, routing::get, response::IntoResponse};
//!
//! // Configuration for Inertia when using `vite dev`:
//! let inertia = vite::Development::default()
//!     .port(5173)
//!     .main("src/main.ts")
//!     .lang("en")
//!     .title("My inertia app")
//!     .into_config();
//! let app: Router = Router::new()
//!     .route("/", get(get_root))
//!     .with_state(inertia);
//!
//! # async fn get_root(_i: Inertia) -> impl IntoResponse { "foo" }
//! ```
//!
//! The [Inertia] struct is then available as an axum [Extractor] and
//! can be used in handlers like so:
//!
//! ```rust
//! use axum::response::IntoResponse;
//! # use axum_inertia::Inertia;
//! use serde_json::json;
//!
//! async fn get_root(i: Inertia) -> impl IntoResponse {
//!     i.render("Pages/Home", json!({ "posts": vec!["post one", "post two"] }))
//! }
//! ```
//!
//! The [Inertia::render] method takes care of building a response
//! conforming to the [inertia.js protocol]. It takes two parameters:
//! the name of the component to render, and the page props
//! (serializable to json).
//!
//! Using the extractor in a handler *requires* that you use
//! [axum::Router::with_state] to initialize Inertia in your
//! routes. In fact, it won't compile if you don't!
//!
//! # Using InertiaConfig as substate
//!
//! It's likely you'll want other pieces of state beyond
//! [InertiaConfig]. In that case, implement
//! [`FromRef<YourState>`](axum::extract::FromRef) for [InertiaConfig]
//! so that the config can be extracted from your state type. For
//! instance:
//!
//! ```rust
//! use axum_inertia::{vite, Inertia, InertiaConfig};
//! use axum::{Router, routing::get, extract::FromRef};
//! # use axum::response::IntoResponse;
//!
//! #[derive(Clone)]
//! struct AppState {
//!     inertia: InertiaConfig,
//!     name: String
//! }
//!
//! impl FromRef<AppState> for InertiaConfig {
//!     fn from_ref(app_state: &AppState) -> InertiaConfig {
//!         app_state.inertia.clone()
//!     }
//! }
//!
//! let inertia = vite::Development::default()
//!     .port(5173)
//!     .main("src/main.ts")
//!     .lang("en")
//!     .title("My inertia app")
//!     .into_config();
//! let app_state = AppState { inertia, name: "foo".to_string() };
//! let app: Router = Router::new()
//!     .route("/", get(get_root))
//!     .with_state(app_state);
//!
//! # async fn get_root(_i: Inertia) -> impl IntoResponse { "foo" }
//! ```
//!
//! # Configuring development and production
//!
//! See the [vite] module for more information.
//!
//! [Router::with_state]: https://docs.rs/axum/latest/axum/struct.Router.html#method.with_state
//! [asset version]: https://inertiajs.com/the-protocol#asset-versioning
//! [inertia.js]: https://inertiajs.com
//! [inertia.js protocol]: https://inertiajs.com/the-protocol
//! [axum]: https://crates.io/crates/axum
//! [Extractor]: https://docs.rs/axum/latest/axum/#extractors

use async_trait::async_trait;
use axum::extract::{FromRef, FromRequestParts};
pub use config::InertiaConfig;
use http::{request::Parts, HeaderMap, HeaderValue, StatusCode};
use page::Page;
use request::Request;
use response::Response;
use serde::Serialize;

pub mod config;
mod page;
mod request;
mod response;
pub mod vite;

/// Axum extractor that turns a handler into an Inertia endpoint.
///
/// An `Inertia` value bundles the Inertia-related data read from the
/// incoming request with the [InertiaConfig] obtained from the router
/// state. Call [Inertia::render] to build the response for a page
/// component.
#[derive(Clone)]
pub struct Inertia {
    /// Inertia-specific request data (whether the request is an Inertia
    /// XHR visit, the asset version sent by the client, and the url).
    request: Request,
    /// Configuration taken from the router state via `FromRef`.
    config: InertiaConfig,
}

#[async_trait]
impl<S> FromRequestParts<S> for Inertia
where
    S: Send + Sync,
    InertiaConfig: FromRef<S>,
{
    type Rejection = (StatusCode, HeaderMap<HeaderValue>);

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let config = InertiaConfig::from_ref(state);
        let request = Request::from_request_parts(parts, state).await?;

        // Respond with a 409 conflict if X-Inertia-Version values
        // don't match for GET requests. See more at:
        // https://inertiajs.com/the-protocol#asset-versioning
        if parts.method == "GET"
            && request.is_xhr
            && config.version().is_some()
            && request.version != config.version()
        {
            let mut headers = HeaderMap::new();
            headers.insert("X-Inertia-Location", parts.uri.path().parse().unwrap());
            return Err((StatusCode::CONFLICT, headers));
        }

        Ok(Inertia::new(request, config))
    }
}

impl Inertia {
    /// Pairs the extracted request data with the Inertia configuration.
    fn new(request: Request, config: InertiaConfig) -> Inertia {
        Self { request, config }
    }

    /// Renders an Inertia response.
    ///
    /// `component` is the name of the page component to render and
    /// `props` are the page props, which are serialized to JSON. The
    /// resulting [Response] carries the page object together with the
    /// request data and configuration held by this extractor.
    ///
    /// # Panics
    ///
    /// Panics if `props` cannot be serialized to a JSON value.
    pub fn render<S: Serialize>(self, component: &'static str, props: S) -> Response {
        let Inertia { request, config } = self;
        let props = serde_json::to_value(props).expect("serialize");

        let page = Page {
            component,
            props,
            url: request.url.clone(),
            version: config.version().clone(),
        };

        Response {
            page,
            request,
            config,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{self, response::IntoResponse, routing::get, Router};
    use reqwest::StatusCode;
    use serde_json::json;
    use tokio::net::TcpListener;

    /// Serves a one-route app (`GET /test`) configured with asset version
    /// `"123"` on an ephemeral local port and returns the bound address.
    async fn spawn_test_server() -> std::net::SocketAddr {
        async fn handler(i: Inertia) -> impl IntoResponse {
            i.render("foo!", json!({"bar": "baz"}))
        }

        let layout =
            Box::new(|props| format!(r#"<html><body><div id="app" data-page='{}'></div>"#, props));
        let config = InertiaConfig::new(Some("123".to_string()), layout);
        let app = Router::new()
            .route("/test", get(handler))
            .with_state(config);

        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("Could not bind ephemeral socket");
        let addr = listener.local_addr().unwrap();

        // The server runs in the background for the rest of the test.
        tokio::spawn(async move {
            axum::serve(listener, app).await.expect("server error");
        });

        addr
    }

    /// Returns the named response header as a string slice, if present.
    fn header_str<'a>(res: &'a reqwest::Response, name: &str) -> Option<&'a str> {
        res.headers().get(name).map(|value| value.to_str().unwrap())
    }

    /// A plain (non-Inertia) GET succeeds and reports the asset version.
    #[tokio::test]
    async fn it_works() {
        let addr = spawn_test_server().await;

        let res = reqwest::get(format!("http://{}/test", &addr))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(header_str(&res, "X-Inertia-Version"), Some("123"));
    }

    /// An Inertia GET with a stale asset version is answered with
    /// `409 Conflict` and the location to reload.
    #[tokio::test]
    async fn it_responds_with_conflict_on_version_mismatch() {
        let addr = spawn_test_server().await;

        let res = reqwest::Client::new()
            .get(format!("http://{}/test", &addr))
            .header("X-Inertia", "true")
            .header("X-Inertia-Version", "456")
            .send()
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::CONFLICT);
        assert_eq!(header_str(&res, "X-Inertia-Location"), Some("/test"));
    }
}