//! Device authorization token endpoint (`POST /device/token`).
//!
//! Source path: `openauth_plugins/device_authorization/routes/token.rs`.

1use std::sync::Arc;
2
3use http::{header, Method, StatusCode};
4use openauth_core::api::{
5    create_auth_endpoint, parse_request_body, ApiRequest, ApiResponse, AuthEndpointOptions,
6    BodyField, BodySchema, JsonSchemaType,
7};
8use openauth_core::context::{request_state, AuthContext};
9use openauth_core::db::{DbAdapter, DbRecord, DbValue};
10use openauth_core::error::OpenAuthError;
11use openauth_core::plugin::PluginEndpoint;
12use openauth_core::session::{CreateSessionInput, DbSessionStore};
13use openauth_core::user::DbUserStore;
14use serde::{Deserialize, Serialize};
15use time::{Duration, OffsetDateTime};
16
17use crate::device_authorization::errors::{oauth_error_response, OAuthDeviceError};
18use crate::device_authorization::options::DeviceAuthorizationOptions;
19use crate::device_authorization::store::{
20    DeviceAuthorizationStatus, DeviceCodeRecord, DeviceCodeStore,
21};
22
/// `grant_type` value a client must send when redeeming a device code
/// (RFC 8628 §3.4, "Device Access Token Request").
const DEVICE_CODE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:device_code";
24
/// Body of a `POST /device/token` poll.
///
/// The endpoint accepts it as JSON or `application/x-www-form-urlencoded`; all three
/// fields are required strings.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct DeviceTokenRequest {
    /// Must equal `urn:ietf:params:oauth:grant-type:device_code`; any other value is
    /// rejected by the endpoint with `invalid_request`.
    pub grant_type: String,
    /// Device code to redeem; looked up via `DeviceCodeStore::find_by_device_code`.
    pub device_code: String,
    /// Caller's client ID; checked by `validate_client` (when configured) and against the
    /// client ID stored on the device code record (when one is stored).
    pub client_id: String,
}
31
/// Successful token response returned once a device code has been approved.
///
/// This endpoint sends it with `Cache-Control: no-store` and `Pragma: no-cache`.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct DeviceTokenResponse {
    /// Token of the session created for the approving user.
    pub access_token: String,
    /// Set to `"Bearer"` by this endpoint.
    pub token_type: String,
    /// Seconds until the session expires, clamped at zero.
    pub expires_in: i64,
    /// Scope stored on the device code record, or an empty string when none was stored.
    pub scope: String,
}
39
40pub fn device_token(options: Arc<DeviceAuthorizationOptions>) -> PluginEndpoint {
41    create_auth_endpoint(
42        "/device/token",
43        Method::POST,
44        AuthEndpointOptions::new()
45            .operation_id("deviceToken")
46            .allowed_media_types(["application/json", "application/x-www-form-urlencoded"])
47            .openapi(super::openapi::device_token_operation())
48            .body_schema(BodySchema::object([
49                BodyField::new("grant_type", JsonSchemaType::String),
50                BodyField::new("device_code", JsonSchemaType::String),
51                BodyField::new("client_id", JsonSchemaType::String),
52            ])),
53        move |context, request| {
54            let options = Arc::clone(&options);
55            Box::pin(async move {
56                let body = parse_request_body::<DeviceTokenRequest>(&request)?;
57                if body.grant_type != DEVICE_CODE_GRANT_TYPE {
58                    return token_oauth_error_response(
59                        StatusCode::BAD_REQUEST,
60                        OAuthDeviceError::InvalidRequest,
61                        "Invalid grant type",
62                    );
63                }
64                if let Some(validate_client) = &options.validate_client {
65                    if !(validate_client)(body.client_id.clone()).await? {
66                        return token_oauth_error_response(
67                            StatusCode::BAD_REQUEST,
68                            OAuthDeviceError::InvalidGrant,
69                            "Invalid client ID",
70                        );
71                    }
72                }
73
74                let adapter = required_adapter(context)?;
75                let store = DeviceCodeStore::new(adapter.as_ref());
76                let Some(record) = store.find_by_device_code(&body.device_code).await? else {
77                    return token_oauth_error_response(
78                        StatusCode::BAD_REQUEST,
79                        OAuthDeviceError::InvalidGrant,
80                        "Invalid device code",
81                    );
82                };
83                if record
84                    .client_id
85                    .as_ref()
86                    .is_some_and(|client_id| client_id != &body.client_id)
87                {
88                    return token_oauth_error_response(
89                        StatusCode::BAD_REQUEST,
90                        OAuthDeviceError::InvalidGrant,
91                        "Client ID mismatch",
92                    );
93                }
94                if polling_too_fast(&record) {
95                    return token_oauth_error_response(
96                        StatusCode::BAD_REQUEST,
97                        OAuthDeviceError::SlowDown,
98                        "Polling too frequently",
99                    );
100                }
101                store.mark_polled(&record.id).await?;
102                if record.expires_at < OffsetDateTime::now_utc() {
103                    store.delete(&record.id).await?;
104                    return token_oauth_error_response(
105                        StatusCode::BAD_REQUEST,
106                        OAuthDeviceError::ExpiredToken,
107                        "Device code has expired",
108                    );
109                }
110
111                match record.status {
112                    DeviceAuthorizationStatus::Pending => token_oauth_error_response(
113                        StatusCode::BAD_REQUEST,
114                        OAuthDeviceError::AuthorizationPending,
115                        "Authorization pending",
116                    ),
117                    DeviceAuthorizationStatus::Denied => {
118                        store.delete(&record.id).await?;
119                        token_oauth_error_response(
120                            StatusCode::BAD_REQUEST,
121                            OAuthDeviceError::AccessDenied,
122                            "Access denied",
123                        )
124                    }
125                    DeviceAuthorizationStatus::Approved => {
126                        approved_response(context, adapter.as_ref(), &store, &record, &request)
127                            .await
128                    }
129                }
130            })
131        },
132    )
133}
134
135pub(super) fn required_adapter(context: &AuthContext) -> Result<Arc<dyn DbAdapter>, OpenAuthError> {
136    context.adapter().ok_or_else(|| {
137        OpenAuthError::Adapter("device authorization requires a database adapter".to_owned())
138    })
139}
140
141async fn approved_response(
142    context: &AuthContext,
143    adapter: &dyn DbAdapter,
144    store: &DeviceCodeStore<'_>,
145    record: &DeviceCodeRecord,
146    request: &ApiRequest,
147) -> Result<ApiResponse, OpenAuthError> {
148    let Some(user_id) = &record.user_id else {
149        return token_oauth_error_response(
150            StatusCode::INTERNAL_SERVER_ERROR,
151            OAuthDeviceError::ServerError,
152            "Invalid device code status",
153        );
154    };
155    let Some(user) = DbUserStore::new(adapter).find_user_by_id(user_id).await? else {
156        return token_oauth_error_response(
157            StatusCode::INTERNAL_SERVER_ERROR,
158            OAuthDeviceError::ServerError,
159            "User not found",
160        );
161    };
162    let expires_at =
163        OffsetDateTime::now_utc() + Duration::seconds(context.session_config.expires_in as i64);
164    let mut input = CreateSessionInput::new(user.id.clone(), expires_at)
165        .additional_fields(additional_session_create_values(context));
166    if let Some(ip_address) = request_ip(request) {
167        input = input.ip_address(ip_address);
168    }
169    if let Some(user_agent) = request_user_agent(request) {
170        input = input.user_agent(user_agent);
171    }
172    let session = match DbSessionStore::new(adapter).create_session(input).await {
173        Ok(session) => session,
174        Err(_) => {
175            return token_oauth_error_response(
176                StatusCode::INTERNAL_SERVER_ERROR,
177                OAuthDeviceError::ServerError,
178                "Failed to create session",
179            );
180        }
181    };
182    if request_state::has_request_state() {
183        request_state::set_current_new_session(session.clone(), user)?;
184    }
185    store.delete(&record.id).await?;
186
187    let expires_in = (session.expires_at - OffsetDateTime::now_utc())
188        .whole_seconds()
189        .max(0);
190    let mut response = super::json_response(
191        StatusCode::OK,
192        &DeviceTokenResponse {
193            access_token: session.token,
194            token_type: "Bearer".to_owned(),
195            expires_in,
196            scope: record.scope.clone().unwrap_or_default(),
197        },
198    )?;
199    add_token_cache_headers(&mut response);
200    Ok(response)
201}
202
203fn polling_too_fast(record: &DeviceCodeRecord) -> bool {
204    let (Some(last_polled_at), Some(interval)) = (record.last_polled_at, record.polling_interval)
205    else {
206        return false;
207    };
208    let elapsed = OffsetDateTime::now_utc() - last_polled_at;
209    elapsed.whole_milliseconds() < i128::from(interval)
210}
211
212fn token_oauth_error_response(
213    status: StatusCode,
214    error: OAuthDeviceError,
215    description: &str,
216) -> Result<ApiResponse, OpenAuthError> {
217    let mut response = oauth_error_response(status, error, description)?;
218    add_token_cache_headers(&mut response);
219    Ok(response)
220}
221
222fn add_token_cache_headers(response: &mut ApiResponse) {
223    response.headers_mut().insert(
224        header::CACHE_CONTROL,
225        http::HeaderValue::from_static("no-store"),
226    );
227    response
228        .headers_mut()
229        .insert(header::PRAGMA, http::HeaderValue::from_static("no-cache"));
230}
231
232fn additional_session_create_values(context: &AuthContext) -> DbRecord {
233    context
234        .options
235        .session
236        .additional_fields
237        .iter()
238        .map(|(name, field)| {
239            (
240                name.clone(),
241                field.default_value.clone().unwrap_or(DbValue::Null),
242            )
243        })
244        .collect()
245}
246
247fn request_user_agent(request: &ApiRequest) -> Option<String> {
248    request
249        .headers()
250        .get(header::USER_AGENT)
251        .and_then(|value| value.to_str().ok())
252        .map(str::to_owned)
253}
254
255fn request_ip(request: &ApiRequest) -> Option<String> {
256    request
257        .headers()
258        .get("x-forwarded-for")
259        .and_then(|value| value.to_str().ok())
260        .and_then(|value| value.split(',').next())
261        .map(str::trim)
262        .filter(|value| !value.is_empty())
263        .map(str::to_owned)
264        .or_else(|| {
265            request
266                .headers()
267                .get("x-real-ip")
268                .and_then(|value| value.to_str().ok())
269                .map(str::to_owned)
270        })
271}