openauth_plugins/device_authorization/routes/
token.rs1use std::sync::Arc;
2
3use http::{header, Method, StatusCode};
4use openauth_core::api::{
5 create_auth_endpoint, parse_request_body, ApiRequest, ApiResponse, AuthEndpointOptions,
6 BodyField, BodySchema, JsonSchemaType,
7};
8use openauth_core::context::{request_state, AuthContext};
9use openauth_core::db::{DbAdapter, DbRecord, DbValue};
10use openauth_core::error::OpenAuthError;
11use openauth_core::plugin::PluginEndpoint;
12use openauth_core::session::{CreateSessionInput, DbSessionStore};
13use openauth_core::user::DbUserStore;
14use serde::{Deserialize, Serialize};
15use time::{Duration, OffsetDateTime};
16
17use crate::device_authorization::errors::{oauth_error_response, OAuthDeviceError};
18use crate::device_authorization::options::DeviceAuthorizationOptions;
19use crate::device_authorization::store::{
20 DeviceAuthorizationStatus, DeviceCodeRecord, DeviceCodeStore,
21};
22
/// The only `grant_type` value the device token endpoint accepts; any other
/// value is answered with an `InvalidRequest` OAuth error.
const DEVICE_CODE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:device_code";
24
25#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
26pub struct DeviceTokenRequest {
27 pub grant_type: String,
28 pub device_code: String,
29 pub client_id: String,
30}
31
32#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
33pub struct DeviceTokenResponse {
34 pub access_token: String,
35 pub token_type: String,
36 pub expires_in: i64,
37 pub scope: String,
38}
39
40pub fn device_token(options: Arc<DeviceAuthorizationOptions>) -> PluginEndpoint {
41 create_auth_endpoint(
42 "/device/token",
43 Method::POST,
44 AuthEndpointOptions::new()
45 .operation_id("deviceToken")
46 .allowed_media_types(["application/json", "application/x-www-form-urlencoded"])
47 .openapi(super::openapi::device_token_operation())
48 .body_schema(BodySchema::object([
49 BodyField::new("grant_type", JsonSchemaType::String),
50 BodyField::new("device_code", JsonSchemaType::String),
51 BodyField::new("client_id", JsonSchemaType::String),
52 ])),
53 move |context, request| {
54 let options = Arc::clone(&options);
55 Box::pin(async move {
56 let body = parse_request_body::<DeviceTokenRequest>(&request)?;
57 if body.grant_type != DEVICE_CODE_GRANT_TYPE {
58 return token_oauth_error_response(
59 StatusCode::BAD_REQUEST,
60 OAuthDeviceError::InvalidRequest,
61 "Invalid grant type",
62 );
63 }
64 if let Some(validate_client) = &options.validate_client {
65 if !(validate_client)(body.client_id.clone()).await? {
66 return token_oauth_error_response(
67 StatusCode::BAD_REQUEST,
68 OAuthDeviceError::InvalidGrant,
69 "Invalid client ID",
70 );
71 }
72 }
73
74 let adapter = required_adapter(context)?;
75 let store = DeviceCodeStore::new(adapter.as_ref());
76 let Some(record) = store.find_by_device_code(&body.device_code).await? else {
77 return token_oauth_error_response(
78 StatusCode::BAD_REQUEST,
79 OAuthDeviceError::InvalidGrant,
80 "Invalid device code",
81 );
82 };
83 if record
84 .client_id
85 .as_ref()
86 .is_some_and(|client_id| client_id != &body.client_id)
87 {
88 return token_oauth_error_response(
89 StatusCode::BAD_REQUEST,
90 OAuthDeviceError::InvalidGrant,
91 "Client ID mismatch",
92 );
93 }
94 if polling_too_fast(&record) {
95 return token_oauth_error_response(
96 StatusCode::BAD_REQUEST,
97 OAuthDeviceError::SlowDown,
98 "Polling too frequently",
99 );
100 }
101 store.mark_polled(&record.id).await?;
102 if record.expires_at < OffsetDateTime::now_utc() {
103 store.delete(&record.id).await?;
104 return token_oauth_error_response(
105 StatusCode::BAD_REQUEST,
106 OAuthDeviceError::ExpiredToken,
107 "Device code has expired",
108 );
109 }
110
111 match record.status {
112 DeviceAuthorizationStatus::Pending => token_oauth_error_response(
113 StatusCode::BAD_REQUEST,
114 OAuthDeviceError::AuthorizationPending,
115 "Authorization pending",
116 ),
117 DeviceAuthorizationStatus::Denied => {
118 store.delete(&record.id).await?;
119 token_oauth_error_response(
120 StatusCode::BAD_REQUEST,
121 OAuthDeviceError::AccessDenied,
122 "Access denied",
123 )
124 }
125 DeviceAuthorizationStatus::Approved => {
126 approved_response(context, adapter.as_ref(), &store, &record, &request)
127 .await
128 }
129 }
130 })
131 },
132 )
133}
134
135pub(super) fn required_adapter(context: &AuthContext) -> Result<Arc<dyn DbAdapter>, OpenAuthError> {
136 context.adapter().ok_or_else(|| {
137 OpenAuthError::Adapter("device authorization requires a database adapter".to_owned())
138 })
139}
140
141async fn approved_response(
142 context: &AuthContext,
143 adapter: &dyn DbAdapter,
144 store: &DeviceCodeStore<'_>,
145 record: &DeviceCodeRecord,
146 request: &ApiRequest,
147) -> Result<ApiResponse, OpenAuthError> {
148 let Some(user_id) = &record.user_id else {
149 return token_oauth_error_response(
150 StatusCode::INTERNAL_SERVER_ERROR,
151 OAuthDeviceError::ServerError,
152 "Invalid device code status",
153 );
154 };
155 let Some(user) = DbUserStore::new(adapter).find_user_by_id(user_id).await? else {
156 return token_oauth_error_response(
157 StatusCode::INTERNAL_SERVER_ERROR,
158 OAuthDeviceError::ServerError,
159 "User not found",
160 );
161 };
162 let expires_at =
163 OffsetDateTime::now_utc() + Duration::seconds(context.session_config.expires_in as i64);
164 let mut input = CreateSessionInput::new(user.id.clone(), expires_at)
165 .additional_fields(additional_session_create_values(context));
166 if let Some(ip_address) = request_ip(request) {
167 input = input.ip_address(ip_address);
168 }
169 if let Some(user_agent) = request_user_agent(request) {
170 input = input.user_agent(user_agent);
171 }
172 let session = match DbSessionStore::new(adapter).create_session(input).await {
173 Ok(session) => session,
174 Err(_) => {
175 return token_oauth_error_response(
176 StatusCode::INTERNAL_SERVER_ERROR,
177 OAuthDeviceError::ServerError,
178 "Failed to create session",
179 );
180 }
181 };
182 if request_state::has_request_state() {
183 request_state::set_current_new_session(session.clone(), user)?;
184 }
185 store.delete(&record.id).await?;
186
187 let expires_in = (session.expires_at - OffsetDateTime::now_utc())
188 .whole_seconds()
189 .max(0);
190 let mut response = super::json_response(
191 StatusCode::OK,
192 &DeviceTokenResponse {
193 access_token: session.token,
194 token_type: "Bearer".to_owned(),
195 expires_in,
196 scope: record.scope.clone().unwrap_or_default(),
197 },
198 )?;
199 add_token_cache_headers(&mut response);
200 Ok(response)
201}
202
203fn polling_too_fast(record: &DeviceCodeRecord) -> bool {
204 let (Some(last_polled_at), Some(interval)) = (record.last_polled_at, record.polling_interval)
205 else {
206 return false;
207 };
208 let elapsed = OffsetDateTime::now_utc() - last_polled_at;
209 elapsed.whole_milliseconds() < i128::from(interval)
210}
211
212fn token_oauth_error_response(
213 status: StatusCode,
214 error: OAuthDeviceError,
215 description: &str,
216) -> Result<ApiResponse, OpenAuthError> {
217 let mut response = oauth_error_response(status, error, description)?;
218 add_token_cache_headers(&mut response);
219 Ok(response)
220}
221
222fn add_token_cache_headers(response: &mut ApiResponse) {
223 response.headers_mut().insert(
224 header::CACHE_CONTROL,
225 http::HeaderValue::from_static("no-store"),
226 );
227 response
228 .headers_mut()
229 .insert(header::PRAGMA, http::HeaderValue::from_static("no-cache"));
230}
231
232fn additional_session_create_values(context: &AuthContext) -> DbRecord {
233 context
234 .options
235 .session
236 .additional_fields
237 .iter()
238 .map(|(name, field)| {
239 (
240 name.clone(),
241 field.default_value.clone().unwrap_or(DbValue::Null),
242 )
243 })
244 .collect()
245}
246
247fn request_user_agent(request: &ApiRequest) -> Option<String> {
248 request
249 .headers()
250 .get(header::USER_AGENT)
251 .and_then(|value| value.to_str().ok())
252 .map(str::to_owned)
253}
254
255fn request_ip(request: &ApiRequest) -> Option<String> {
256 request
257 .headers()
258 .get("x-forwarded-for")
259 .and_then(|value| value.to_str().ok())
260 .and_then(|value| value.split(',').next())
261 .map(str::trim)
262 .filter(|value| !value.is_empty())
263 .map(str::to_owned)
264 .or_else(|| {
265 request
266 .headers()
267 .get("x-real-ip")
268 .and_then(|value| value.to_str().ok())
269 .map(str::to_owned)
270 })
271}