Skip to main content

oauth2_test_server/handlers/device.rs

1use axum::{
2    extract::{Form, State},
3    response::IntoResponse,
4    Json,
5};
6use chrono::{Duration, Utc};
7use rand::Rng;
8use serde::Deserialize;
9use serde_json::json;
10
11use crate::{
12    crypto::{generate_token_string, issue_jwt},
13    error::OauthError,
14    models::{DeviceAuthorization, DeviceCodeResponse, DeviceTokenRequest},
15    store::AppState,
16};
17
18#[derive(Deserialize, Debug)]
19pub struct DeviceCodeRequest {
20    pub client_id: String,
21    pub scope: Option<String>,
22}
23
24const DEVICE_CODE_CHARSET: &[u8] = b"BCDFGHJKLMNPQRSTUVWXYZ23456789";
25
26fn generate_user_code() -> String {
27    let mut rng = rand::thread_rng();
28    let code: String = (0..8)
29        .map(|_| {
30            let idx = rng.gen_range(0..DEVICE_CODE_CHARSET.len());
31            DEVICE_CODE_CHARSET[idx] as char
32        })
33        .collect();
34    code.chars()
35        .collect::<Vec<_>>()
36        .chunks(4)
37        .map(|chunk| chunk.iter().collect::<String>())
38        .collect::<Vec<_>>()
39        .join("-")
40}
41
42pub async fn device_code(
43    State(state): State<AppState>,
44    Form(form): Form<DeviceCodeRequest>,
45) -> Result<impl IntoResponse, OauthError> {
46    let client = state
47        .store
48        .get_client(&form.client_id)
49        .await
50        .ok_or(OauthError::InvalidClient)?;
51
52    let scope = form.scope.clone().unwrap_or_else(|| client.scope.clone());
53
54    if let Err(e) = state.config.validate_scope(&scope) {
55        return Err(OauthError::InvalidScope(e));
56    }
57
58    let device_code = generate_token_string();
59    let user_code = generate_user_code();
60    let expires_in = state.config.authorization_code_expires_in;
61    let interval = 5;
62
63    let device_auth = DeviceAuthorization {
64        device_code: device_code.clone(),
65        user_code: user_code.clone(),
66        client_id: form.client_id.clone(),
67        scope: scope.clone(),
68        expires_at: Utc::now() + Duration::seconds(expires_in as i64),
69        user_id: None,
70        approved: false,
71    };
72
73    state
74        .store
75        .insert_device_code(device_code.clone(), device_auth)
76        .await;
77
78    let verification_uri = format!("{}/device", state.issuer());
79    let verification_uri_complete = Some(format!("{}?user_code={}", verification_uri, user_code));
80
81    Ok(Json(DeviceCodeResponse {
82        device_code,
83        user_code,
84        verification_uri,
85        verification_uri_complete,
86        expires_in,
87        interval,
88    }))
89}
90
91pub async fn device_token(
92    State(state): State<AppState>,
93    Form(form): Form<DeviceTokenRequest>,
94) -> Result<impl IntoResponse, OauthError> {
95    if form.grant_type != "urn:ietf:params:oauth:grant-type:device_code" {
96        return Err(OauthError::UnsupportedGrantType);
97    }
98
99    let device_auth = state
100        .store
101        .get_device_code(&form.device_code)
102        .await
103        .ok_or(OauthError::InvalidGrant)?;
104
105    if device_auth.expires_at < Utc::now() {
106        return Err(OauthError::InvalidGrant);
107    }
108
109    if device_auth.client_id != form.client_id {
110        return Err(OauthError::InvalidClient);
111    }
112
113    if !device_auth.approved {
114        return Err(OauthError::AuthorizationPending);
115    }
116
117    let client = state
118        .store
119        .get_client(&form.client_id)
120        .await
121        .ok_or(OauthError::InvalidClient)?;
122
123    let user_id = device_auth
124        .user_id
125        .clone()
126        .unwrap_or_else(|| "device-user".to_string());
127
128    let jwt = issue_jwt(
129        state.issuer(),
130        &client.client_id,
131        &user_id,
132        &device_auth.scope,
133        state.config.access_token_expires_in as i64,
134        &state.keys,
135    )
136    .map_err(|_| OauthError::ServerError)?;
137
138    let refresh_token = generate_token_string();
139
140    let token = crate::models::Token {
141        access_token: jwt.clone(),
142        refresh_token: Some(refresh_token.clone()),
143        client_id: client.client_id.clone(),
144        scope: device_auth.scope.clone(),
145        expires_at: Utc::now() + Duration::seconds(state.config.access_token_expires_in as i64),
146        user_id: user_id.clone(),
147        revoked: false,
148    };
149
150    state.store.insert_token(jwt.clone(), token.clone()).await;
151    state
152        .store
153        .insert_refresh_token(refresh_token.clone(), token)
154        .await;
155
156    Ok(Json(json!({
157        "access_token": jwt,
158        "token_type": "Bearer",
159        "expires_in": state.config.access_token_expires_in,
160        "refresh_token": refresh_token,
161        "scope": device_auth.scope
162    })))
163}