Skip to main content

openauth_plugins/device_authorization/routes/
code.rs

1use std::sync::Arc;
2
3use http::{header, Method, StatusCode};
4use openauth_core::api::{
5    create_auth_endpoint, parse_request_body, AuthEndpointOptions, BodyField, BodySchema,
6    JsonSchemaType,
7};
8use openauth_core::context::AuthContext;
9use openauth_core::crypto::random::generate_random_string;
10use openauth_core::db::DbAdapter;
11use openauth_core::error::OpenAuthError;
12use openauth_core::plugin::PluginEndpoint;
13use rand::rngs::OsRng;
14use rand::RngCore;
15use serde::{Deserialize, Serialize};
16use time::OffsetDateTime;
17use url::Url;
18
19use crate::device_authorization::errors::{oauth_error_response, OAuthDeviceError};
20use crate::device_authorization::options::{AsyncDeviceCodeGenerator, DeviceAuthorizationOptions};
21use crate::device_authorization::store::{CreateDeviceCodeInput, DeviceCodeStore};
22
/// Alphabet used by the default user-code generator.
///
/// Uppercase Latin letters without `I`/`O` plus the digits `2`–`9`, so the easily confused
/// glyphs `I`, `O`, `0` and `1` never appear in a code a human has to type. The alphabet has
/// exactly 32 symbols, which divides 256 evenly, so reducing a uniformly random byte with `%`
/// selects every symbol with equal probability.
const DEFAULT_USER_CODE_CHARSET: &[u8] = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789".as_bytes();
24
25#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
26pub struct DeviceCodeRequest {
27    pub client_id: String,
28    pub scope: Option<String>,
29}
30
31#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
32pub struct DeviceCodeResponse {
33    pub device_code: String,
34    pub user_code: String,
35    pub verification_uri: String,
36    pub verification_uri_complete: String,
37    pub expires_in: i64,
38    pub interval: i64,
39}
40
41pub fn device_code(options: Arc<DeviceAuthorizationOptions>) -> PluginEndpoint {
42    create_auth_endpoint(
43        "/device/code",
44        Method::POST,
45        AuthEndpointOptions::new()
46            .operation_id("deviceCode")
47            .allowed_media_types(["application/json", "application/x-www-form-urlencoded"])
48            .openapi(super::openapi::device_code_operation())
49            .body_schema(BodySchema::object([
50                BodyField::new("client_id", JsonSchemaType::String),
51                BodyField::optional("scope", JsonSchemaType::String),
52            ])),
53        move |context, request| {
54            let options = Arc::clone(&options);
55            Box::pin(async move {
56                let body = parse_request_body::<DeviceCodeRequest>(&request)?;
57                if let Some(validate_client) = &options.validate_client {
58                    if !(validate_client)(body.client_id.clone()).await? {
59                        return oauth_error_response(
60                            StatusCode::BAD_REQUEST,
61                            OAuthDeviceError::InvalidClient,
62                            "Invalid client ID",
63                        );
64                    }
65                }
66                if let Some(hook) = &options.on_device_auth_request {
67                    (hook)(body.client_id.clone(), body.scope.clone()).await?;
68                }
69
70                let adapter = required_adapter(context)?;
71                let device_code = generate_code(
72                    options.generate_device_code.as_ref(),
73                    options.device_code_length,
74                    default_device_code,
75                )
76                .await;
77                let user_code = generate_code(
78                    options.generate_user_code.as_ref(),
79                    options.user_code_length,
80                    default_user_code,
81                )
82                .await;
83                let expires_in = options.expires_in.whole_seconds();
84                let interval = options.interval.whole_seconds();
85                let polling_interval = i64::try_from(options.interval.whole_milliseconds())
86                    .map_err(|_| {
87                        OpenAuthError::InvalidConfig(
88                            "device authorization interval is too large".to_owned(),
89                        )
90                    })?;
91
92                DeviceCodeStore::new(adapter.as_ref())
93                    .create(CreateDeviceCodeInput {
94                        device_code: device_code.clone(),
95                        user_code: super::clean_user_code(&user_code),
96                        expires_at: OffsetDateTime::now_utc() + options.expires_in,
97                        polling_interval,
98                        client_id: body.client_id,
99                        scope: body.scope,
100                    })
101                    .await?;
102
103                let (verification_uri, verification_uri_complete) = build_verification_uris(
104                    &options.verification_uri,
105                    &context.base_url,
106                    &user_code,
107                )?;
108                let mut response = super::json_response(
109                    StatusCode::OK,
110                    &DeviceCodeResponse {
111                        device_code,
112                        user_code,
113                        verification_uri,
114                        verification_uri_complete,
115                        expires_in,
116                        interval,
117                    },
118                )?;
119                response.headers_mut().insert(
120                    header::CACHE_CONTROL,
121                    http::HeaderValue::from_static("no-store"),
122                );
123                Ok(response)
124            })
125        },
126    )
127}
128
129fn required_adapter(context: &AuthContext) -> Result<Arc<dyn DbAdapter>, OpenAuthError> {
130    context.adapter().ok_or_else(|| {
131        OpenAuthError::Adapter("device authorization requires a database adapter".to_owned())
132    })
133}
134
135async fn generate_code(
136    generator: Option<&AsyncDeviceCodeGenerator>,
137    length: usize,
138    fallback: fn(usize) -> String,
139) -> String {
140    match generator {
141        Some(generator) => generator().await,
142        None => fallback(length),
143    }
144}
145
/// Default device-code generator: delegates to `generate_random_string` with the
/// configured `length`.
// NOTE(review): alphabet and randomness source are defined by
// `openauth_core::crypto::random::generate_random_string`; confirm it is a CSPRNG.
fn default_device_code(length: usize) -> String {
    generate_random_string(length)
}
149
150fn default_user_code(length: usize) -> String {
151    let mut bytes = vec![0_u8; length];
152    OsRng.fill_bytes(&mut bytes);
153    bytes
154        .into_iter()
155        .map(|byte| {
156            let index = usize::from(byte) % DEFAULT_USER_CODE_CHARSET.len();
157            char::from(DEFAULT_USER_CODE_CHARSET[index])
158        })
159        .collect()
160}
161
162fn build_verification_uris(
163    verification_uri: &str,
164    base_url: &str,
165    user_code: &str,
166) -> Result<(String, String), OpenAuthError> {
167    let verification_url = Url::parse(verification_uri)
168        .or_else(|_| Url::parse(base_url).and_then(|base| base.join(verification_uri)))
169        .map_err(|error| OpenAuthError::InvalidConfig(error.to_string()))?;
170    let mut complete = verification_url.clone();
171    complete
172        .query_pairs_mut()
173        .append_pair("user_code", user_code);
174    Ok((verification_url.to_string(), complete.to_string()))
175}