1use fastapi_rust::core::{
61 App, Request, RequestContext, Response, ResponseBody, SecureCompare, StatusCode, TestClient,
62};
63use serde::Serialize;
64
/// Fixed bearer value that `POST /login` hands out and `GET /protected` accepts.
///
/// Hard-coded for this demo only; a real service would issue and verify per-user tokens.
const DEMO_BEARER_VALUE: &str = "demo-bearer-value";
68
/// JSON body returned by `POST /login` when the request is accepted.
#[derive(Debug, Serialize)]
struct LoginResponse {
    /// Value the client should send back as `Authorization: Bearer <access_token>`.
    access_token: String,
    /// Token scheme name; `login_handler` always sets this to `"bearer"`.
    token_type: &'static str,
}
75
/// JSON body returned by `GET /protected` once the bearer value has been accepted.
#[derive(Debug, Serialize)]
struct UserInfo {
    /// Name of the authenticated user (`protected_handler` always reports `"demo_user"`).
    username: String,
    /// Human-readable confirmation message.
    message: String,
}
82
83fn public_handler(_ctx: &RequestContext, _req: &mut Request) -> std::future::Ready<Response> {
87 let body = serde_json::json!({
88 "message": "This is a public endpoint - no authentication required!"
89 });
90 std::future::ready(
91 Response::ok()
92 .header("content-type", b"application/json".to_vec())
93 .body(ResponseBody::Bytes(body.to_string().into_bytes())),
94 )
95}
96
97fn login_handler(_ctx: &RequestContext, req: &mut Request) -> std::future::Ready<Response> {
107 let is_json = req
112 .headers()
113 .get("content-type")
114 .is_some_and(|ct| ct.starts_with(b"application/json"));
115
116 if !is_json {
117 let error = serde_json::json!({
118 "detail": "Content-Type must be application/json"
119 });
120 return std::future::ready(
121 Response::with_status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
122 .header("content-type", b"application/json".to_vec())
123 .body(ResponseBody::Bytes(error.to_string().into_bytes())),
124 );
125 }
126
127 let response = LoginResponse {
135 access_token: DEMO_BEARER_VALUE.to_string(),
136 token_type: "bearer",
137 };
138
139 std::future::ready(
140 Response::ok()
141 .header("content-type", b"application/json".to_vec())
142 .body(ResponseBody::Bytes(json_bytes(&response))),
143 )
144}
145
146fn protected_handler(_ctx: &RequestContext, req: &mut Request) -> std::future::Ready<Response> {
155 let Some(auth_header) = req.headers().get("authorization") else {
157 let body = serde_json::json!({
159 "detail": "Not authenticated"
160 });
161 return std::future::ready(
162 Response::with_status(StatusCode::UNAUTHORIZED)
163 .header("www-authenticate", b"Bearer".to_vec())
164 .header("content-type", b"application/json".to_vec())
165 .body(ResponseBody::Bytes(body.to_string().into_bytes())),
166 );
167 };
168
169 let Ok(auth_str) = std::str::from_utf8(auth_header) else {
171 let body = serde_json::json!({
173 "detail": "Invalid authentication credentials"
174 });
175 return std::future::ready(
176 Response::with_status(StatusCode::UNAUTHORIZED)
177 .header("www-authenticate", b"Bearer".to_vec())
178 .header("content-type", b"application/json".to_vec())
179 .body(ResponseBody::Bytes(body.to_string().into_bytes())),
180 );
181 };
182
183 let Some(bearer_value) = auth_str
185 .strip_prefix("Bearer ")
186 .or_else(|| auth_str.strip_prefix("bearer "))
187 else {
188 let body = serde_json::json!({
190 "detail": "Invalid authentication credentials"
191 });
192 return std::future::ready(
193 Response::with_status(StatusCode::UNAUTHORIZED)
194 .header("www-authenticate", b"Bearer".to_vec())
195 .header("content-type", b"application/json".to_vec())
196 .body(ResponseBody::Bytes(body.to_string().into_bytes())),
197 );
198 };
199
200 let bearer_value = bearer_value.trim();
201 if bearer_value.is_empty() {
202 let body = serde_json::json!({
204 "detail": "Invalid authentication credentials"
205 });
206 return std::future::ready(
207 Response::with_status(StatusCode::UNAUTHORIZED)
208 .header("www-authenticate", b"Bearer".to_vec())
209 .header("content-type", b"application/json".to_vec())
210 .body(ResponseBody::Bytes(body.to_string().into_bytes())),
211 );
212 }
213
214 if !bearer_value.secure_eq(DEMO_BEARER_VALUE) {
216 let body = serde_json::json!({
218 "detail": "Invalid token"
219 });
220 return std::future::ready(
221 Response::with_status(StatusCode::FORBIDDEN)
222 .header("content-type", b"application/json".to_vec())
223 .body(ResponseBody::Bytes(body.to_string().into_bytes())),
224 );
225 }
226
227 let user_info = UserInfo {
229 username: "demo_user".to_string(),
230 message: "You have accessed a protected resource!".to_string(),
231 };
232
233 std::future::ready(
234 Response::ok()
235 .header("content-type", b"application/json".to_vec())
236 .body(ResponseBody::Bytes(json_bytes(&user_info))),
237 )
238}
239
240fn json_bytes<T: Serialize>(value: &T) -> Vec<u8> {
241 match serde_json::to_string(value) {
242 Ok(text) => text.into_bytes(),
243 Err(err) => format!(r#"{{"detail":"json serialize error: {err}"}}"#).into_bytes(),
244 }
245}
246
/// Passes `condition` through unchanged.
///
/// When `condition` is false, it first writes `Check failed: <message>` to stderr.
fn check(condition: bool, message: &str) -> bool {
    if !condition {
        eprintln!("Check failed: {message}");
    }
    condition
}
254}
255
/// Returns whether `left == right`.
///
/// On a mismatch, it writes `message` and both values (Debug-formatted) to stderr.
// Arguments are taken by value so call sites can pass temporaries such as
// `response.status().as_u16()` without borrowing them; that is why the lint is silenced.
#[allow(clippy::needless_pass_by_value)]
fn check_eq<T: PartialEq + std::fmt::Debug>(left: T, right: T, message: &str) -> bool {
    let matched = left == right;
    if !matched {
        eprintln!("Check failed: {message}. left={left:?} right={right:?}");
    }
    matched
}
264}
265
266#[allow(clippy::too_many_lines)]
267fn main() {
268 println!("fastapi_rust Authentication Example");
269 println!("====================================\n");
270
271 let app = App::builder()
273 .get("/public", public_handler)
275 .post("/login", login_handler)
277 .get("/protected", protected_handler)
279 .build();
280
281 println!("App created with {} route(s)\n", app.route_count());
282
283 let client = TestClient::new(app);
285
286 println!("1. Public endpoint - no auth required");
290 let response = client.get("/public").send();
291 println!(
292 " GET /public -> {} {}",
293 response.status().as_u16(),
294 response.status().canonical_reason()
295 );
296 if !check_eq(
297 response.status().as_u16(),
298 200,
299 "GET /public should return 200",
300 ) {
301 return;
302 }
303 if !check(
304 response.text().contains("public endpoint"),
305 "GET /public should include the public endpoint body",
306 ) {
307 return;
308 }
309
310 println!("\n2. Protected endpoint - without token");
314 let response = client.get("/protected").send();
315 println!(
316 " GET /protected -> {} {}",
317 response.status().as_u16(),
318 response.status().canonical_reason()
319 );
320 if !check_eq(
321 response.status().as_u16(),
322 401,
323 "Protected endpoint should return 401 without token",
324 ) {
325 return;
326 }
327
328 let has_www_auth = response
330 .headers()
331 .iter()
332 .any(|(name, value)| name == "www-authenticate" && value == b"Bearer");
333 if !check(
334 has_www_auth,
335 "401 response should include WWW-Authenticate: Bearer header",
336 ) {
337 return;
338 }
339
340 println!("\n3. Login endpoint - get a token");
344 let response = client
345 .post("/login")
346 .header("content-type", "application/json")
347 .body(r#"{"username":"test","password":"test123"}"#)
348 .send();
349 println!(
350 " POST /login -> {} {}",
351 response.status().as_u16(),
352 response.status().canonical_reason()
353 );
354 if !check_eq(
355 response.status().as_u16(),
356 200,
357 "POST /login should return 200",
358 ) {
359 return;
360 }
361
362 let body_text = response.text();
364 let body: serde_json::Value = match serde_json::from_str(body_text) {
365 Ok(body) => body,
366 Err(err) => {
367 eprintln!("Failed to parse login response JSON: {err}");
368 return;
369 }
370 };
371 let Some(bearer_value) = body.get("access_token").and_then(|value| value.as_str()) else {
372 eprintln!("Login response missing access_token");
373 return;
374 };
375 println!(" Bearer value: {bearer_value}");
376 if !check_eq(
377 bearer_value,
378 DEMO_BEARER_VALUE,
379 "Login should return the expected bearer value",
380 ) {
381 return;
382 }
383
384 println!("\n4. Protected endpoint - with valid token");
388 let response = client
389 .get("/protected")
390 .header("authorization", format!("Bearer {DEMO_BEARER_VALUE}"))
391 .send();
392 println!(
393 " GET /protected (Authorization: Bearer {}) -> {} {}",
394 DEMO_BEARER_VALUE,
395 response.status().as_u16(),
396 response.status().canonical_reason()
397 );
398 if !check_eq(
399 response.status().as_u16(),
400 200,
401 "Protected endpoint should return 200 with valid token",
402 ) {
403 return;
404 }
405 if !check(
406 response.text().contains("protected resource"),
407 "Protected endpoint should include protected resource body",
408 ) {
409 return;
410 }
411
412 println!("\n5. Protected endpoint - with invalid token");
416 let response = client
417 .get("/protected")
418 .header("authorization", "Bearer wrong_token")
419 .send();
420 println!(
421 " GET /protected (Authorization: Bearer wrong_token) -> {} {}",
422 response.status().as_u16(),
423 response.status().canonical_reason()
424 );
425 if !check_eq(
426 response.status().as_u16(),
427 403,
428 "Protected endpoint should return 403 with invalid token",
429 ) {
430 return;
431 }
432
433 println!("\n6. Protected endpoint - with wrong auth scheme");
437 let response = client
438 .get("/protected")
439 .header("authorization", "Basic dXNlcjpwYXNz")
440 .send();
441 println!(
442 " GET /protected (Authorization: Basic ...) -> {} {}",
443 response.status().as_u16(),
444 response.status().canonical_reason()
445 );
446 if !check_eq(
447 response.status().as_u16(),
448 401,
449 "Protected endpoint should return 401 with wrong auth scheme",
450 ) {
451 return;
452 }
453
454 println!("\n7. Login with wrong Content-Type");
458 let response = client
459 .post("/login")
460 .header("content-type", "text/plain")
461 .body("demo=true")
462 .send();
463 println!(
464 " POST /login (Content-Type: text/plain) -> {} {}",
465 response.status().as_u16(),
466 response.status().canonical_reason()
467 );
468 if !check_eq(
469 response.status().as_u16(),
470 415,
471 "Login should return 415 with wrong Content-Type",
472 ) {
473 return;
474 }
475
476 println!("\n8. Token case sensitivity (lowercase 'bearer')");
480 let response = client
481 .get("/protected")
482 .header("authorization", format!("bearer {DEMO_BEARER_VALUE}"))
483 .send();
484 println!(
485 " GET /protected (Authorization: bearer {}) -> {} {}",
486 DEMO_BEARER_VALUE,
487 response.status().as_u16(),
488 response.status().canonical_reason()
489 );
490 if !check_eq(
491 response.status().as_u16(),
492 200,
493 "Bearer scheme should be case-insensitive (lowercase accepted)",
494 ) {
495 return;
496 }
497
498 println!("\nAll authentication tests passed!");
499}