1use std::sync::Arc;
7
8use axum::extract::{Path, State};
9use axum::http::StatusCode;
10use axum::routing::{get, post};
11use axum::{Extension, Json, Router};
12use base64::Engine as _;
13use base64::engine::general_purpose::STANDARD as BASE64;
14use serde::{Deserialize, Serialize};
15
16use crate::error::AppError;
17use crate::middleware::AuthContext;
18use crate::state::AppState;
19use zvault_core::policy::Capability;
20use zvault_core::transit::TransitEngine;
21
22pub fn router() -> Router<Arc<AppState>> {
34 Router::new()
35 .route("/keys", get(list_keys))
36 .route("/keys/{name}", get(key_info).post(create_key))
37 .route("/keys/{name}/rotate", post(rotate_key))
38 .route("/encrypt/{name}", post(encrypt))
39 .route("/decrypt/{name}", post(decrypt))
40 .route("/rewrap/{name}", post(rewrap))
41 .route("/datakey/{name}", post(generate_data_key))
42}
43
44#[derive(Debug, Deserialize)]
47pub struct EncryptRequest {
48 pub plaintext: String,
50}
51
/// JSON body returned by `POST /encrypt/{name}`.
#[derive(Debug, Serialize)]
pub struct EncryptResponse {
    /// Ciphertext string exactly as returned by the transit engine's `encrypt`.
    pub ciphertext: String,
}
56
/// JSON body accepted by `POST /decrypt/{name}`.
#[derive(Debug, Deserialize)]
pub struct DecryptRequest {
    /// Ciphertext to decrypt; handed to the transit engine unchanged.
    /// NOTE(review): presumably the string previously returned by `encrypt`
    /// or `rewrap` — confirm the accepted format in `TransitEngine::decrypt`.
    pub ciphertext: String,
}
62
63#[derive(Debug, Serialize)]
64pub struct DecryptResponse {
65 pub plaintext: String,
67}
68
/// JSON body accepted by `POST /rewrap/{name}`.
#[derive(Debug, Deserialize)]
pub struct RewrapRequest {
    /// Existing ciphertext to re-encrypt; handed to the engine's `rewrap`
    /// unchanged.
    pub ciphertext: String,
}
74
/// JSON body returned by `POST /rewrap/{name}`.
#[derive(Debug, Serialize)]
pub struct RewrapResponse {
    /// Ciphertext string exactly as returned by the engine's `rewrap`.
    pub ciphertext: String,
}
79
80#[derive(Debug, Serialize)]
81pub struct DataKeyResponse {
82 pub plaintext: String,
84 pub ciphertext: String,
86}
87
/// JSON body returned by `GET /keys`.
#[derive(Debug, Serialize)]
pub struct KeyListResponse {
    /// Key names as returned by the engine's `list_keys`.
    pub keys: Vec<String>,
}
92
/// JSON body returned by `GET /keys/{name}`.
///
/// All fields except `created_at` are copied verbatim from the engine's key
/// info. `created_at` is converted to a string by the `key_info` handler.
#[derive(Debug, Serialize)]
pub struct KeyInfoResponse {
    /// Name of the transit key.
    pub name: String,
    /// Highest key version, as reported by the engine.
    pub latest_version: u32,
    /// Minimum key version the engine will decrypt with, as reported by the
    /// engine.
    pub min_decryption_version: u32,
    /// Whether the key can be used to encrypt, as reported by the engine.
    pub supports_encryption: bool,
    /// Whether the key can be used to decrypt, as reported by the engine.
    pub supports_decryption: bool,
    /// Number of key versions, as reported by the engine.
    pub version_count: u32,
    /// Key creation timestamp, formatted with `to_rfc3339()` (RFC 3339).
    pub created_at: String,
}
103
/// JSON body returned by `POST /keys/{name}/rotate`.
#[derive(Debug, Serialize)]
pub struct RotateResponse {
    /// Version number returned by the engine's `rotate_key`.
    pub new_version: u32,
}
108
109async fn create_key(
113 State(state): State<Arc<AppState>>,
114 Extension(auth): Extension<AuthContext>,
115 Path(name): Path<String>,
116) -> Result<StatusCode, AppError> {
117 state
118 .policy_store
119 .check(&auth.policies, &format!("transit/keys/{name}"), &Capability::Create)
120 .await?;
121
122 let engine = get_transit_engine(&state).await?;
123 engine.create_key(&name).await?;
124
125 Ok(StatusCode::NO_CONTENT)
126}
127
128async fn rotate_key(
130 State(state): State<Arc<AppState>>,
131 Extension(auth): Extension<AuthContext>,
132 Path(name): Path<String>,
133) -> Result<Json<RotateResponse>, AppError> {
134 state
135 .policy_store
136 .check(&auth.policies, &format!("transit/keys/{name}"), &Capability::Update)
137 .await?;
138
139 let engine = get_transit_engine(&state).await?;
140 let new_version = engine.rotate_key(&name).await?;
141
142 Ok(Json(RotateResponse { new_version }))
143}
144
145async fn encrypt(
147 State(state): State<Arc<AppState>>,
148 Extension(auth): Extension<AuthContext>,
149 Path(name): Path<String>,
150 Json(body): Json<EncryptRequest>,
151) -> Result<Json<EncryptResponse>, AppError> {
152 state
153 .policy_store
154 .check(&auth.policies, &format!("transit/encrypt/{name}"), &Capability::Update)
155 .await?;
156
157 let plaintext_bytes = base64_decode(&body.plaintext)?;
158 let engine = get_transit_engine(&state).await?;
159 let ciphertext = engine.encrypt(&name, &plaintext_bytes).await?;
160
161 Ok(Json(EncryptResponse { ciphertext }))
162}
163
164async fn decrypt(
166 State(state): State<Arc<AppState>>,
167 Extension(auth): Extension<AuthContext>,
168 Path(name): Path<String>,
169 Json(body): Json<DecryptRequest>,
170) -> Result<Json<DecryptResponse>, AppError> {
171 state
172 .policy_store
173 .check(&auth.policies, &format!("transit/decrypt/{name}"), &Capability::Update)
174 .await?;
175
176 let engine = get_transit_engine(&state).await?;
177 let plaintext = engine.decrypt(&name, &body.ciphertext).await?;
178
179 let plaintext_b64 = BASE64.encode(&plaintext);
180
181 Ok(Json(DecryptResponse { plaintext: plaintext_b64 }))
182}
183
184async fn rewrap(
186 State(state): State<Arc<AppState>>,
187 Extension(auth): Extension<AuthContext>,
188 Path(name): Path<String>,
189 Json(body): Json<RewrapRequest>,
190) -> Result<Json<RewrapResponse>, AppError> {
191 state
192 .policy_store
193 .check(&auth.policies, &format!("transit/rewrap/{name}"), &Capability::Update)
194 .await?;
195
196 let engine = get_transit_engine(&state).await?;
197 let ciphertext = engine.rewrap(&name, &body.ciphertext).await?;
198
199 Ok(Json(RewrapResponse { ciphertext }))
200}
201
202async fn generate_data_key(
204 State(state): State<Arc<AppState>>,
205 Extension(auth): Extension<AuthContext>,
206 Path(name): Path<String>,
207) -> Result<Json<DataKeyResponse>, AppError> {
208 state
209 .policy_store
210 .check(&auth.policies, &format!("transit/datakey/{name}"), &Capability::Update)
211 .await?;
212
213 let engine = get_transit_engine(&state).await?;
214 let dk = engine.generate_data_key(&name).await?;
215
216 Ok(Json(DataKeyResponse {
217 plaintext: dk.plaintext,
218 ciphertext: dk.ciphertext,
219 }))
220}
221
222async fn list_keys(
224 State(state): State<Arc<AppState>>,
225 Extension(auth): Extension<AuthContext>,
226) -> Result<Json<KeyListResponse>, AppError> {
227 state
228 .policy_store
229 .check(&auth.policies, "transit/keys", &Capability::List)
230 .await?;
231
232 let engine = get_transit_engine(&state).await?;
233 let keys = engine.list_keys().await?;
234
235 Ok(Json(KeyListResponse { keys }))
236}
237
238async fn key_info(
240 State(state): State<Arc<AppState>>,
241 Extension(auth): Extension<AuthContext>,
242 Path(name): Path<String>,
243) -> Result<Json<KeyInfoResponse>, AppError> {
244 state
245 .policy_store
246 .check(&auth.policies, &format!("transit/keys/{name}"), &Capability::Read)
247 .await?;
248
249 let engine = get_transit_engine(&state).await?;
250 let info = engine.key_info(&name).await?;
251
252 Ok(Json(KeyInfoResponse {
253 name: info.name,
254 latest_version: info.latest_version,
255 min_decryption_version: info.min_decryption_version,
256 supports_encryption: info.supports_encryption,
257 supports_decryption: info.supports_decryption,
258 version_count: info.version_count,
259 created_at: info.created_at.to_rfc3339(),
260 }))
261}
262
263async fn get_transit_engine(state: &AppState) -> Result<Arc<TransitEngine>, AppError> {
267 state
268 .transit_engines
269 .read()
270 .await
271 .get("transit/")
272 .cloned()
273 .ok_or_else(|| AppError::NotFound("no transit engine mounted".to_owned()))
274}
275
276fn base64_decode(input: &str) -> Result<Vec<u8>, AppError> {
278 BASE64
279 .decode(input)
280 .map_err(|e| AppError::BadRequest(format!("invalid base64 input: {e}")))
281}