Skip to main content

openauth_plugins/device_authorization/
options.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use openauth_core::error::OpenAuthError;
6use time::Duration;
7
8pub type DeviceCodeGeneratorFuture = Pin<Box<dyn Future<Output = String> + Send>>;
9pub type AsyncDeviceCodeGenerator = Arc<dyn Fn() -> DeviceCodeGeneratorFuture + Send + Sync>;
10pub type ClientValidationFuture = Pin<Box<dyn Future<Output = Result<bool, OpenAuthError>> + Send>>;
11pub type ClientValidator = Arc<dyn Fn(String) -> ClientValidationFuture + Send + Sync>;
12pub type DeviceAuthRequestFuture = Pin<Box<dyn Future<Output = Result<(), OpenAuthError>> + Send>>;
13pub type DeviceAuthRequestHook =
14    Arc<dyn Fn(String, Option<String>) -> DeviceAuthRequestFuture + Send + Sync>;
15
/// Reasons `DeviceAuthorizationOptions::validate` can reject a configuration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeviceAuthorizationOptionsError {
    /// `device_code_length` is zero.
    EmptyDeviceCodeLength,
    /// `user_code_length` is zero.
    EmptyUserCodeLength,
    /// `expires_in` is zero or negative.
    NonPositiveExpiresIn,
    /// `interval` is zero or negative.
    NonPositiveInterval,
}

impl std::fmt::Display for DeviceAuthorizationOptionsError {
    /// Writes a short, lowercase description of the validation failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match *self {
            DeviceAuthorizationOptionsError::EmptyDeviceCodeLength => {
                "device code length must be greater than zero"
            }
            DeviceAuthorizationOptionsError::EmptyUserCodeLength => {
                "user code length must be greater than zero"
            }
            DeviceAuthorizationOptionsError::NonPositiveExpiresIn => "expires_in must be positive",
            DeviceAuthorizationOptionsError::NonPositiveInterval => "interval must be positive",
        })
    }
}

/// Lets the error be used through `Box<dyn std::error::Error>`; it has no underlying source.
impl std::error::Error for DeviceAuthorizationOptionsError {}
37
38#[derive(Clone)]
39pub struct DeviceAuthorizationOptions {
40    pub expires_in: Duration,
41    pub interval: Duration,
42    pub device_code_length: usize,
43    pub user_code_length: usize,
44    pub generate_device_code: Option<AsyncDeviceCodeGenerator>,
45    pub generate_user_code: Option<AsyncDeviceCodeGenerator>,
46    pub validate_client: Option<ClientValidator>,
47    pub on_device_auth_request: Option<DeviceAuthRequestHook>,
48    pub verification_uri: String,
49    pub schema: DeviceAuthorizationSchemaOptions,
50}
51
52#[derive(Debug, Clone, Default, PartialEq, Eq)]
53pub struct DeviceAuthorizationSchemaOptions {
54    pub table_name: Option<String>,
55    pub fields: DeviceAuthorizationSchemaFields,
56}
57
58impl DeviceAuthorizationSchemaOptions {
59    pub fn new() -> Self {
60        Self::default()
61    }
62
63    #[must_use]
64    pub fn table_name(mut self, table_name: impl Into<String>) -> Self {
65        self.table_name = Some(table_name.into());
66        self
67    }
68
69    #[must_use]
70    pub fn field_name(
71        mut self,
72        logical_name: impl Into<String>,
73        physical_name: impl Into<String>,
74    ) -> Self {
75        self.fields.set(logical_name.into(), physical_name.into());
76        self
77    }
78}
79
/// Optional physical column names for each logical field of the device-code record.
///
/// `None` means no override was configured for that field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DeviceAuthorizationSchemaFields {
    /// Column override for logical field `id`.
    pub id: Option<String>,
    /// Column override for logical field `deviceCode`.
    pub device_code: Option<String>,
    /// Column override for logical field `userCode`.
    pub user_code: Option<String>,
    /// Column override for logical field `userId`.
    pub user_id: Option<String>,
    /// Column override for logical field `expiresAt`.
    pub expires_at: Option<String>,
    /// Column override for logical field `status`.
    pub status: Option<String>,
    /// Column override for logical field `lastPolledAt`.
    pub last_polled_at: Option<String>,
    /// Column override for logical field `pollingInterval`.
    pub polling_interval: Option<String>,
    /// Column override for logical field `clientId`.
    pub client_id: Option<String>,
    /// Column override for logical field `scope`.
    pub scope: Option<String>,
    /// Column override for logical field `createdAt`.
    pub created_at: Option<String>,
    /// Column override for logical field `updatedAt`.
    pub updated_at: Option<String>,
}

impl DeviceAuthorizationSchemaFields {
    /// Records `physical_name` as the column override for `logical_name`.
    ///
    /// `logical_name` may be the camelCase logical name (`"deviceCode"`) or the snake_case
    /// Rust field name (`"device_code"`); both address the same slot. A later call for the
    /// same field overwrites the earlier one. Unrecognised names are ignored.
    fn set(&mut self, logical_name: String, physical_name: String) {
        let slot = match logical_name.as_str() {
            "id" => &mut self.id,
            "deviceCode" | "device_code" => &mut self.device_code,
            "userCode" | "user_code" => &mut self.user_code,
            "userId" | "user_id" => &mut self.user_id,
            "expiresAt" | "expires_at" => &mut self.expires_at,
            "status" => &mut self.status,
            "lastPolledAt" | "last_polled_at" => &mut self.last_polled_at,
            "pollingInterval" | "polling_interval" => &mut self.polling_interval,
            "clientId" | "client_id" => &mut self.client_id,
            "scope" => &mut self.scope,
            "createdAt" | "created_at" => &mut self.created_at,
            "updatedAt" | "updated_at" => &mut self.updated_at,
            // Unknown names are dropped so the public builder API stays infallible.
            _ => return,
        };
        *slot = Some(physical_name);
    }
}
115
116impl Default for DeviceAuthorizationOptions {
117    fn default() -> Self {
118        Self {
119            expires_in: Duration::minutes(30),
120            interval: Duration::seconds(5),
121            device_code_length: 40,
122            user_code_length: 8,
123            generate_device_code: None,
124            generate_user_code: None,
125            validate_client: None,
126            on_device_auth_request: None,
127            verification_uri: "/device".to_owned(),
128            schema: DeviceAuthorizationSchemaOptions::default(),
129        }
130    }
131}
132
133impl DeviceAuthorizationOptions {
134    pub fn new() -> Self {
135        Self::default()
136    }
137
138    pub fn validate(&self) -> Result<(), DeviceAuthorizationOptionsError> {
139        if self.device_code_length == 0 {
140            return Err(DeviceAuthorizationOptionsError::EmptyDeviceCodeLength);
141        }
142        if self.user_code_length == 0 {
143            return Err(DeviceAuthorizationOptionsError::EmptyUserCodeLength);
144        }
145        if self.expires_in <= Duration::ZERO {
146            return Err(DeviceAuthorizationOptionsError::NonPositiveExpiresIn);
147        }
148        if self.interval <= Duration::ZERO {
149            return Err(DeviceAuthorizationOptionsError::NonPositiveInterval);
150        }
151        Ok(())
152    }
153
154    #[must_use]
155    pub fn expires_in(mut self, expires_in: Duration) -> Self {
156        self.expires_in = expires_in;
157        self
158    }
159
160    #[must_use]
161    pub fn interval(mut self, interval: Duration) -> Self {
162        self.interval = interval;
163        self
164    }
165
166    #[must_use]
167    pub fn device_code_length(mut self, length: usize) -> Self {
168        self.device_code_length = length;
169        self
170    }
171
172    #[must_use]
173    pub fn user_code_length(mut self, length: usize) -> Self {
174        self.user_code_length = length;
175        self
176    }
177
178    #[must_use]
179    pub fn generate_device_code<F>(mut self, generator: F) -> Self
180    where
181        F: Fn() -> String + Send + Sync + 'static,
182    {
183        self.generate_device_code =
184            Some(Arc::new(move || Box::pin(std::future::ready(generator()))));
185        self
186    }
187
188    #[must_use]
189    pub fn generate_user_code<F>(mut self, generator: F) -> Self
190    where
191        F: Fn() -> String + Send + Sync + 'static,
192    {
193        self.generate_user_code = Some(Arc::new(move || Box::pin(std::future::ready(generator()))));
194        self
195    }
196
197    #[must_use]
198    pub fn generate_device_code_async<F, Fut>(mut self, generator: F) -> Self
199    where
200        F: Fn() -> Fut + Send + Sync + 'static,
201        Fut: Future<Output = String> + Send + 'static,
202    {
203        self.generate_device_code = Some(Arc::new(move || Box::pin(generator())));
204        self
205    }
206
207    #[must_use]
208    pub fn generate_user_code_async<F, Fut>(mut self, generator: F) -> Self
209    where
210        F: Fn() -> Fut + Send + Sync + 'static,
211        Fut: Future<Output = String> + Send + 'static,
212    {
213        self.generate_user_code = Some(Arc::new(move || Box::pin(generator())));
214        self
215    }
216
217    #[must_use]
218    pub fn validate_client<F, Fut>(mut self, validator: F) -> Self
219    where
220        F: Fn(String) -> Fut + Send + Sync + 'static,
221        Fut: Future<Output = Result<bool, OpenAuthError>> + Send + 'static,
222    {
223        self.validate_client = Some(Arc::new(move |client_id| Box::pin(validator(client_id))));
224        self
225    }
226
227    #[must_use]
228    pub fn on_device_auth_request<F, Fut>(mut self, hook: F) -> Self
229    where
230        F: Fn(String, Option<String>) -> Fut + Send + Sync + 'static,
231        Fut: Future<Output = Result<(), OpenAuthError>> + Send + 'static,
232    {
233        self.on_device_auth_request = Some(Arc::new(move |client_id, scope| {
234            Box::pin(hook(client_id, scope))
235        }));
236        self
237    }
238
239    #[must_use]
240    pub fn verification_uri(mut self, uri: impl Into<String>) -> Self {
241        self.verification_uri = uri.into();
242        self
243    }
244
245    #[must_use]
246    pub fn schema(mut self, schema: DeviceAuthorizationSchemaOptions) -> Self {
247        self.schema = schema;
248        self
249    }
250}