openauth_plugins/device_authorization/routes/
code.rs1use std::sync::Arc;
2
3use http::{header, Method, StatusCode};
4use openauth_core::api::{
5 create_auth_endpoint, parse_request_body, AuthEndpointOptions, BodyField, BodySchema,
6 JsonSchemaType,
7};
8use openauth_core::context::AuthContext;
9use openauth_core::crypto::random::generate_random_string;
10use openauth_core::db::DbAdapter;
11use openauth_core::error::OpenAuthError;
12use openauth_core::plugin::PluginEndpoint;
13use rand::rngs::OsRng;
14use rand::RngCore;
15use serde::{Deserialize, Serialize};
16use time::OffsetDateTime;
17use url::Url;
18
19use crate::device_authorization::errors::{oauth_error_response, OAuthDeviceError};
20use crate::device_authorization::options::{AsyncDeviceCodeGenerator, DeviceAuthorizationOptions};
21use crate::device_authorization::store::{CreateDeviceCodeInput, DeviceCodeStore};
22
/// Alphabet for default user codes: the uppercase Latin letters without `I` and
/// `O`, followed by the digits `2`–`9` (no `0` or `1`) — 32 symbols in total.
const DEFAULT_USER_CODE_CHARSET: &[u8] = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789".as_bytes();
24
25#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
26pub struct DeviceCodeRequest {
27 pub client_id: String,
28 pub scope: Option<String>,
29}
30
31#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
32pub struct DeviceCodeResponse {
33 pub device_code: String,
34 pub user_code: String,
35 pub verification_uri: String,
36 pub verification_uri_complete: String,
37 pub expires_in: i64,
38 pub interval: i64,
39}
40
41pub fn device_code(options: Arc<DeviceAuthorizationOptions>) -> PluginEndpoint {
42 create_auth_endpoint(
43 "/device/code",
44 Method::POST,
45 AuthEndpointOptions::new()
46 .operation_id("deviceCode")
47 .allowed_media_types(["application/json", "application/x-www-form-urlencoded"])
48 .openapi(super::openapi::device_code_operation())
49 .body_schema(BodySchema::object([
50 BodyField::new("client_id", JsonSchemaType::String),
51 BodyField::optional("scope", JsonSchemaType::String),
52 ])),
53 move |context, request| {
54 let options = Arc::clone(&options);
55 Box::pin(async move {
56 let body = parse_request_body::<DeviceCodeRequest>(&request)?;
57 if let Some(validate_client) = &options.validate_client {
58 if !(validate_client)(body.client_id.clone()).await? {
59 return oauth_error_response(
60 StatusCode::BAD_REQUEST,
61 OAuthDeviceError::InvalidClient,
62 "Invalid client ID",
63 );
64 }
65 }
66 if let Some(hook) = &options.on_device_auth_request {
67 (hook)(body.client_id.clone(), body.scope.clone()).await?;
68 }
69
70 let adapter = required_adapter(context)?;
71 let device_code = generate_code(
72 options.generate_device_code.as_ref(),
73 options.device_code_length,
74 default_device_code,
75 )
76 .await;
77 let user_code = generate_code(
78 options.generate_user_code.as_ref(),
79 options.user_code_length,
80 default_user_code,
81 )
82 .await;
83 let expires_in = options.expires_in.whole_seconds();
84 let interval = options.interval.whole_seconds();
85 let polling_interval = i64::try_from(options.interval.whole_milliseconds())
86 .map_err(|_| {
87 OpenAuthError::InvalidConfig(
88 "device authorization interval is too large".to_owned(),
89 )
90 })?;
91
92 DeviceCodeStore::new(adapter.as_ref())
93 .create(CreateDeviceCodeInput {
94 device_code: device_code.clone(),
95 user_code: super::clean_user_code(&user_code),
96 expires_at: OffsetDateTime::now_utc() + options.expires_in,
97 polling_interval,
98 client_id: body.client_id,
99 scope: body.scope,
100 })
101 .await?;
102
103 let (verification_uri, verification_uri_complete) = build_verification_uris(
104 &options.verification_uri,
105 &context.base_url,
106 &user_code,
107 )?;
108 let mut response = super::json_response(
109 StatusCode::OK,
110 &DeviceCodeResponse {
111 device_code,
112 user_code,
113 verification_uri,
114 verification_uri_complete,
115 expires_in,
116 interval,
117 },
118 )?;
119 response.headers_mut().insert(
120 header::CACHE_CONTROL,
121 http::HeaderValue::from_static("no-store"),
122 );
123 Ok(response)
124 })
125 },
126 )
127}
128
129fn required_adapter(context: &AuthContext) -> Result<Arc<dyn DbAdapter>, OpenAuthError> {
130 context.adapter().ok_or_else(|| {
131 OpenAuthError::Adapter("device authorization requires a database adapter".to_owned())
132 })
133}
134
135async fn generate_code(
136 generator: Option<&AsyncDeviceCodeGenerator>,
137 length: usize,
138 fallback: fn(usize) -> String,
139) -> String {
140 match generator {
141 Some(generator) => generator().await,
142 None => fallback(length),
143 }
144}
145
146fn default_device_code(length: usize) -> String {
147 generate_random_string(length)
148}
149
150fn default_user_code(length: usize) -> String {
151 let mut bytes = vec![0_u8; length];
152 OsRng.fill_bytes(&mut bytes);
153 bytes
154 .into_iter()
155 .map(|byte| {
156 let index = usize::from(byte) % DEFAULT_USER_CODE_CHARSET.len();
157 char::from(DEFAULT_USER_CODE_CHARSET[index])
158 })
159 .collect()
160}
161
162fn build_verification_uris(
163 verification_uri: &str,
164 base_url: &str,
165 user_code: &str,
166) -> Result<(String, String), OpenAuthError> {
167 let verification_url = Url::parse(verification_uri)
168 .or_else(|_| Url::parse(base_url).and_then(|base| base.join(verification_uri)))
169 .map_err(|error| OpenAuthError::InvalidConfig(error.to_string()))?;
170 let mut complete = verification_url.clone();
171 complete
172 .query_pairs_mut()
173 .append_pair("user_code", user_code);
174 Ok((verification_url.to_string(), complete.to_string()))
175}